In [1]:
import inspect
from typing import Dict

import torch
from torch import nn

from neuralhydrology.modelzoo import get_model
from neuralhydrology.modelzoo.head import get_head
from neuralhydrology.modelzoo.basemodel import BaseModel
from neuralhydrology.utils.config import Config
from neuralhydrology.modelzoo.inputlayer import InputLayer

In [None]:
class LSTMAttention(BaseModel):
    """LSTM with dot-product attention over the hidden states of all time steps.

    The embedded input sequence is passed through a single-layer LSTM. The final hidden state is projected by a
    linear layer and used as the attention query; its dot product with each time step's hidden state yields the
    attention scores. The softmax-weighted sum of all hidden states (the context vector) is passed through dropout
    and the output head, so the model emits exactly one prediction step per sample.

    Parameters
    ----------
    cfg : Config
        The run configuration. This class reads ``cfg.hidden_size`` and ``cfg.output_dropout`` directly and passes
        ``cfg`` on to ``InputLayer`` and ``get_head``.
    """
    # Model parts that can be selected for finetuning. Every entry has to be the name of an attribute of this
    # model, because the finetuning code resolves the entries via getattr.
    module_parts = ['embedding_net', 'lstm', 'attention_query_transform', 'head']

    def __init__(self, cfg: Config):
        super().__init__(cfg=cfg)

        # Input layer: combines dynamic and static feature inputs
        self.embedding_net = InputLayer(cfg)

        # Single-layer LSTM operating on batch-first tensors
        self.lstm = nn.LSTM(
            input_size=self.embedding_net.output_size,
            hidden_size=cfg.hidden_size,
            batch_first=True
        )

        # Dot-product attention: linear projection of the final hidden state that serves as the query
        self.attention_query_transform = nn.Linear(cfg.hidden_size, cfg.hidden_size)

        # Dropout applied to the context vector before the head
        self.dropout = nn.Dropout(p=cfg.output_dropout)

        # Output head
        self.head = get_head(cfg=cfg, n_in=cfg.hidden_size, n_out=self.output_size)

    def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Perform a forward pass of the attention LSTM.

        Parameters
        ----------
        data : Dict[str, torch.Tensor]
            Dictionary with the model inputs, passed unchanged to the input layer.

        Returns
        -------
        Dict[str, torch.Tensor]
            Dictionary containing everything the output head returns (for a regression head: 'y_hat' of shape
            [batch_size, 1, n_targets]) and additionally 'attention_weights' of shape [batch_size, sequence_length].
            Since only a single output step is produced, the model is meant to be used with predict_last_n = 1.
        """
        # The input layer returns a sequence-first tensor [sequence_length, batch_size, n_features], whereas the
        # LSTM is configured with batch_first=True. Swap the first two axes so that batch and time are not mixed up.
        x = self.embedding_net(data).transpose(0, 1)  # shape [batch_size, sequence_length, n_features]

        # output: hidden states of all time steps, shape [batch_size, sequence_length, hidden_size]
        # h_n: final hidden state, shape [n_layers, batch_size, hidden_size]
        output, (h_n, _) = self.lstm(x)

        # Project the final hidden state of the last (here: only) layer and use it as the attention query
        query = self.attention_query_transform(h_n[-1])  # shape [batch_size, hidden_size]

        # Attention scores: dot product between the query and the hidden state of each time step
        # (query [batch, 1, hidden]) x (output^T [batch, hidden, sequence]) -> [batch, 1, sequence]
        scores = torch.bmm(query.unsqueeze(1), output.transpose(1, 2)).squeeze(1)  # shape [batch_size, sequence_length]

        # Normalize the scores across the sequence dimension
        attention_weights = torch.softmax(scores, dim=1)  # shape [batch_size, sequence_length]

        # Context vector: attention-weighted sum of the hidden states across the sequence dimension
        context_vector = torch.sum(attention_weights.unsqueeze(2) * output, dim=1)  # shape [batch_size, hidden_size]

        # Re-insert a sequence axis of length one, because losses and evaluation slice predictions as
        # y_hat[:, -predict_last_n:, :] and therefore require [batch_size, sequence_length, n_targets].
        head_input = self.dropout(context_vector).unsqueeze(1)  # shape [batch_size, 1, hidden_size]

        # The head itself returns a dictionary (e.g. {'y_hat': ...}); merge it instead of nesting it under 'y_hat'.
        pred = {'attention_weights': attention_weights}
        pred.update(self.head(head_input))

        return pred


NameError: name 'BaseModel' is not defined