In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import random

# Sample data (variable-length input/output)
data = [
    {"input": [1, 2, 3], "output": [4, 5]},
    {"input": [4, 5, 6, 7, 8], "output": [9, 10, 11]},
    {"input": [7, 8], "output": [9, 10, 11, 12, 13]},
]

# Encoder Class
class Encoder(nn.Module):
    def __init__(self, hidden_size=64):
        super().__init__()
        self.lstm = nn.LSTM(1, hidden_size, batch_first=True)
    
    def forward(self, x):
        # x shape: (batch_size, seq_len, input_size)
        _, (hidden, cell) = self.lstm(x)
        return hidden, cell

# Decoder Class
class Decoder(nn.Module):
    def __init__(self, hidden_size=64):
        super().__init__()
        self.lstm = nn.LSTM(1, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)
    
    def forward(self, x, hidden, cell):
        # x shape: (batch_size, 1, input_size)
        output, (hidden, cell) = self.lstm(x, (hidden, cell))
        prediction = self.fc(output)
        return prediction, hidden, cell

# Combined Seq2Seq Model
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
    
    def forward(self, src, trg=None, max_pred_len=5):
        # Encoder pass
        hidden, cell = self.encoder(src)
        
        # Prepare decoder
        batch_size = src.shape[0]
        decoder_input = src[:, -1:, :]  # Last input as first decoder input
        
        # Determine prediction length
        pred_len = len(trg[0]) if trg is not None else max_pred_len
        
        # Store outputs
        outputs = []
        
        # Decoder pass
        for t in range(pred_len):
            decoder_output, hidden, cell = self.decoder(decoder_input, hidden, cell)
            outputs.append(decoder_output.squeeze(0))
            
            # Teacher forcing (if training)
            if trg is not None and random.random() < 0.5:
                decoder_input = trg[:, t:t+1, :]
            else:
                decoder_input = decoder_output
        
        return torch.cat(outputs, dim=0).unsqueeze(0)

# Helper function to prepare samples
def prepare_sample(sample):
    input_tensor = torch.tensor(sample["input"], dtype=torch.float32).unsqueeze(0).unsqueeze(-1)
    output_tensor = torch.tensor(sample["output"], dtype=torch.float32).unsqueeze(0).unsqueeze(-1)
    return input_tensor, output_tensor

# Initialize models
encoder = Encoder()
decoder = Decoder()
model = Seq2Seq(encoder, decoder)

# Training setup
optimizer = optim.Adam(model.parameters())
criterion = nn.MSELoss()

# Training loop
for epoch in range(100):
    random.shuffle(data)
    total_loss = 0
    
    for sample in data:
        src, trg = prepare_sample(sample)
        optimizer.zero_grad()
        
        pred = model(src, trg=trg)
        loss = criterion(pred, trg)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    #print(f"Epoch {epoch}, Loss: {total_loss / len(data)}")

# Prediction function
def predict(model, input_seq, max_len=5):
    model.eval()
    with torch.no_grad():
        src = torch.tensor(input_seq, dtype=torch.float32).unsqueeze(0).unsqueeze(-1)
        pred = model(src, max_pred_len=max_len)
        return pred.squeeze().tolist()

# Test predictions
print(predict(model, [1, 2, 3], 2))       # Should output ~[4, 5]
print(predict(model, [4, 5, 6, 7, 8])) # Should output ~[9, 10, 11]
print(predict(model, [7, 8]))          # Should output ~[9, 10, 11, 12, 13]

[3.6442456245422363, 5.135899066925049, 6.6543192863464355, 8.053271293640137, 9.221745491027832]
[9.019563674926758, 10.173517227172852, 10.826087951660156, 11.196292877197266, 11.374700546264648]
[8.966129302978516, 10.64471435546875, 11.203575134277344, 11.390665054321289, 11.45847225189209]


In [9]:
# Test predictions
print(predict(model, [1, 2, 3], 2))       # Should output ~[4, 5]
print(predict(model, [4, 5, 6, 7, 8])) # Should output ~[9, 10, 11]
print(predict(model, [7, 8]))          # Should output ~[9, 10, 11, 12, 13]

[3.6442456245422363, 5.135899066925049]
[9.019563674926758, 10.173517227172852, 10.826087951660156, 11.196292877197266, 11.374700546264648]
[8.966129302978516, 10.64471435546875, 11.203575134277344, 11.390665054321289, 11.45847225189209]
