In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np

In [2]:
# Data generation
def generate_data(num_samples=10000):
    data = []
    for _ in range(num_samples):
        a = random.randint(1, 999)
        b = random.randint(1, 999)
        question = f"{a}+{b}"
        answer = str(a + b)
        data.append((question, answer))
    return data

# Vocabulary
chars = '0123456789+'
char_to_idx = {c: i for i, c in enumerate(chars)}
idx_to_char = {i: c for i, c in enumerate(chars)}
vocab_size = len(chars)

def encode(text, max_len):
    encoded = [char_to_idx[c] for c in text]
    return encoded + [0] * (max_len - len(encoded))

def decode(indices):
    return ''.join([idx_to_char[i] for i in indices if i != 0])

# Prepare data
data = generate_data()
max_input_len = 7  # "999+999"
max_output_len = 4  # "1998"

X = torch.tensor([encode(q, max_input_len) for q, a in data])
y = torch.tensor([encode(a, max_output_len) for q, a in data])

print(f"Data shape: X={X.shape}, y={y.shape}")

Data shape: X=torch.Size([10000, 7]), y=torch.Size([10000, 4])


In [3]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
    
    def forward(self, x):
        embedded = self.embedding(x)
        outputs, (hidden, cell) = self.lstm(embedded)
        return outputs, hidden, cell

class Attention(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.attn = nn.Linear(hidden_size * 2, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)
    
    def forward(self, hidden, encoder_outputs):
        seq_len = encoder_outputs.size(1)
        hidden = hidden.unsqueeze(1).repeat(1, seq_len, 1)
        energy = torch.tanh(self.attn(torch.cat([hidden, encoder_outputs], dim=2)))
        attention = torch.softmax(self.v(energy).squeeze(2), dim=1)
        context = torch.bmm(attention.unsqueeze(1), encoder_outputs)
        return context.squeeze(1), attention

class Decoder(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.attention = Attention(hidden_size)
        self.lstm = nn.LSTM(hidden_size * 2, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, vocab_size)
    
    def forward(self, input, hidden, cell, encoder_outputs):
        embedded = self.embedding(input.unsqueeze(1))
        context, _ = self.attention(hidden[-1], encoder_outputs)
        lstm_input = torch.cat([embedded, context.unsqueeze(1)], dim=2)
        output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))
        prediction = self.out(output.squeeze(1))
        return prediction, hidden, cell

class Seq2Seq(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.encoder = Encoder(vocab_size, hidden_size)
        self.decoder = Decoder(vocab_size, hidden_size)
    
    def forward(self, src, trg):
        encoder_outputs, hidden, cell = self.encoder(src)
        outputs = []
        input = torch.zeros(src.size(0), dtype=torch.long, device=src.device)
        
        for t in range(trg.size(1)):
            output, hidden, cell = self.decoder(input, hidden, cell, encoder_outputs)
            outputs.append(output)
            input = trg[:, t]
        
        return torch.stack(outputs, dim=1)

In [4]:
# Training
model = Seq2Seq(vocab_size, 128)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Split data
train_size = int(0.8 * len(X))
X_train, y_train = X[:train_size], y[:train_size]
X_test, y_test = X[train_size:], y[train_size:]

batch_size = 64
epochs = 20

for epoch in range(epochs):
    model.train()
    total_loss = 0
    
    for i in range(0, len(X_train), batch_size):
        batch_X = X_train[i:i+batch_size]
        batch_y = y_train[i:i+batch_size]
        
        optimizer.zero_grad()
        output = model(batch_X, batch_y)
        loss = criterion(output.reshape(-1, vocab_size), batch_y.reshape(-1))
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    if (epoch + 1) % 5 == 0:
        print(f'Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(X_train)*batch_size:.4f}')

print("Training completed!")

Epoch 5/20, Loss: 1.4373
Epoch 10/20, Loss: 1.2674
Epoch 15/20, Loss: 1.1438
Epoch 20/20, Loss: 1.0056
Training completed!


In [5]:
def predict(model, question):
    model.eval()
    with torch.no_grad():
        # Encode input
        input_seq = torch.tensor([encode(question, max_input_len)])
        
        # Get encoder outputs
        encoder_outputs, hidden, cell = model.encoder(input_seq)
        
        # Decode step by step
        outputs = []
        input_token = torch.zeros(1, dtype=torch.long)
        
        for _ in range(max_output_len):
            output, hidden, cell = model.decoder(input_token, hidden, cell, encoder_outputs)
            predicted = output.argmax(dim=1)
            outputs.append(predicted.item())
            input_token = predicted
            
            if predicted.item() == 0:  # Stop at padding
                break
        
        return decode(outputs)

# Test on some examples
test_examples = ["123+456", "789+111", "50+25", "999+1"]
for example in test_examples:
    prediction = predict(model, example)
    actual = str(eval(example))
    print(f"{example} = {prediction} (actual: {actual})")

123+456 = 6181 (actual: 579)
789+111 = 9121 (actual: 900)
50+25 = 7132 (actual: 75)
999+1 = 1134 (actual: 1000)


In [None]:
# Interactive prediction cell
user_input = input("Enter an addition problem (e.g., '123+456'): ")
try:
    prediction = predict(model, user_input)
    actual = str(eval(user_input))
    print(f"\nModel prediction: {user_input} = {prediction}")
    print(f"Actual answer: {actual}")
    print(f"Correct: {'✓' if prediction == actual else '✗'}")
except:
    print("Invalid input. Please use format like '123+456'")