In [2]:
import torch
import torch.nn as nn
from ncps.torch import LTC

class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.ltc_layers = nn.ModuleList([LTC(input_size, hidden_size) for _ in range(num_layers)])
    
    def forward(self, input_seq, hidden):
        for ltc in self.ltc_layers:
            output, hidden = ltc(input_seq, hidden)
            input_seq = output  # Pass the output to the next layer
        return hidden
    
    def initHidden(self, batch_size):
        # Initialize hidden state for all layers
        return torch.zeros(self.num_layers, batch_size, self.hidden_size)

class Decoder(nn.Module):
    def __init__(self, hidden_size, output_size, num_layers=1):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.ltc_layers = nn.ModuleList([LTC(hidden_size, hidden_size) for _ in range(num_layers)])
        self.fc = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=-1)
    
    def forward(self, input_seq, hidden):
        for ltc in self.ltc_layers:
            output, hidden = ltc(input_seq, hidden)
            input_seq = output  # Pass the output to the next layer
        output = self.fc(output)
        output = self.softmax(output)
        return output, hidden

class Seq2SeqLTC(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super(Seq2SeqLTC, self).__init__()
        self.encoder = Encoder(input_size, hidden_size, num_layers)
        self.decoder = Decoder(hidden_size, output_size, num_layers)
    
    def forward(self, input_seq, target_seq):
        # Encoder
        encoder_hidden = self.encoder.initHidden(input_seq.size(0))  # Assuming batch_size
        encoder_hidden = self.encoder(input_seq, encoder_hidden)
        
        # Decoder
        decoder_input = target_seq[0]  # Start with the first token
        decoder_hidden = encoder_hidden
        outputs = []
        
        for t in range(1, target_seq.size(0)):  # Loop through the target sequence
            output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            outputs.append(output)
            decoder_input = output  # Feed output back into the decoder
        
        return torch.stack(outputs)

    def initHidden(self, batch_size):
        # Initialize hidden state for all layers
        return torch.zeros(self.num_layers, batch_size, self.hidden_size)

# Example usage
input_size = 256  # Size of input features
hidden_size = 128  # Size of hidden state
output_size = 256  # Size of output features
num_layers = 2  # Number of LTC layers

seq2seq_model = Seq2SeqLTC(input_size, hidden_size, output_size, num_layers)

# Dummy input and target sequences (for illustration)
input_seq = torch.randn(10, 1, input_size)  # Sequence length of 10, batch size of 1
target_seq = torch.randn(10, 1, output_size)  # Sequence length of 10, batch size of 1

output = seq2seq_model(input_seq, target_seq)


RuntimeError: For batched 2-D input, hx and cx should also be 2-D but got (3-D) tensor