# In [1]:
import torch
import torch.nn as nn

# ---------------------------------------------
# Vocabulary Setup
# ---------------------------------------------
# Token -> index tables for the source and target languages.
SRC_VOCAB = {
    '<pad>': 0,
    '<eos>': 1,
    'hello': 2,
    'hi': 3,
}
TRG_VOCAB = {
    '<pad>': 0,
    '<eos>': 1,
    'bonjour': 2,
    'salut': 3,
}

# Index -> token tables (the inverse of the vocabularies above).
SRC_itos = dict(zip(SRC_VOCAB.values(), SRC_VOCAB.keys()))
TRG_itos = dict(zip(TRG_VOCAB.values(), TRG_VOCAB.keys()))

# Model sizes: the vocabulary sizes fix the embedding-table / output widths.
INPUT_DIM, OUTPUT_DIM = len(SRC_VOCAB), len(TRG_VOCAB)
EMB_DIM, HID_DIM = 8, 16

# Sample tokenized source (hello <eos>), shaped (batch=1, seq_len=2).
src_sentence = torch.tensor(
    [[SRC_VOCAB[token] for token in ('hello', '<eos>')]],
    dtype=torch.long,
)

# ---------------------------------------------
# Encoder
# ---------------------------------------------
class Encoder(nn.Module):
    """Embed source token indices and run them through a single-layer GRU.

    Args:
        input_dim: Number of rows in the embedding table (source vocab size).
        emb_dim: Width of each token embedding.
        hid_dim: Width of the GRU hidden state.
    """

    def __init__(self, input_dim, emb_dim, hid_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=input_dim, embedding_dim=emb_dim)
        self.rnn = nn.GRU(input_size=emb_dim, hidden_size=hid_dim, batch_first=True)

    def forward(self, src):
        """Return the GRU's final hidden state for ``src``.

        Args:
            src: Long tensor of token indices, laid out (batch, seq_len)
                because the GRU is built with ``batch_first=True``.

        Returns:
            Final hidden state of shape (1, batch, hid_dim); the per-step
            outputs of the GRU are discarded.
        """
        _, final_hidden = self.rnn(self.embedding(src))
        return final_hidden

# ---------------------------------------------
# Decoder
# ---------------------------------------------
class Decoder(nn.Module):
    """Single-step GRU decoder that scores the next target token.

    Args:
        output_dim: Target vocabulary size (embedding rows and logit width).
        emb_dim: Width of each token embedding.
        hid_dim: Width of the GRU hidden state.
    """

    def __init__(self, output_dim, emb_dim, hid_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=output_dim, embedding_dim=emb_dim)
        self.rnn = nn.GRU(input_size=emb_dim, hidden_size=hid_dim, batch_first=True)
        self.fc_out = nn.Linear(in_features=hid_dim, out_features=output_dim)

    def forward(self, input, hidden):
        """Advance the decoder by one time step.

        Args:
            input: Long tensor of shape (batch,) holding the previous token
                index for each sequence.
            hidden: GRU hidden state to continue from.

        Returns:
            A ``(prediction, hidden)`` pair: unnormalised scores of shape
            (batch, output_dim) and the updated GRU hidden state.
        """
        # Add a length-1 time axis so the batch-first GRU sees (batch, 1).
        step_tokens = torch.unsqueeze(input, dim=1)
        step_embedding = self.embedding(step_tokens)  # (batch, 1, emb_dim)
        rnn_out, next_hidden = self.rnn(step_embedding, hidden)
        # Drop the time axis again before projecting onto the vocabulary.
        logits = self.fc_out(torch.squeeze(rnn_out, dim=1))  # (batch, output_dim)
        return logits, next_hidden

# ---------------------------------------------
# Seq2Seq Wrapper with Inference
# ---------------------------------------------
class Seq2Seq(nn.Module):
    """Encoder/decoder wrapper that translates one sentence by greedy decoding.

    Args:
        encoder: Callable mapping a source tensor to an initial hidden state.
        decoder: Callable mapping ``(input_token, hidden)`` to
            ``(scores, hidden)`` for a single decoding step.
        device: Device handle kept on the instance for callers that read it.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, max_len=10, sos_idx=None, eos_idx=None):
        """Greedily decode a translation for a single source sentence.

        Args:
            src: Source token indices of shape (1, seq_len).
            max_len: Maximum number of decoding steps to run.
            sos_idx: Index fed to the decoder at the first step. Defaults to
                ``eos_idx`` because ``TRG_VOCAB`` has no dedicated start token.
            eos_idx: Index that terminates decoding. Defaults to
                ``TRG_VOCAB['<eos>']``.

        Returns:
            List of predicted target token indices (Python ints), excluding
            the terminating ``eos_idx``.

        Raises:
            ValueError: If ``src`` holds more than one sentence; the flat
                list return value and the per-step ``.item()`` call only
                make sense for a batch of exactly one.
        """
        # Resolve defaults lazily so explicit indices never touch the global.
        if eos_idx is None:
            eos_idx = TRG_VOCAB['<eos>']
        if sos_idx is None:
            sos_idx = eos_idx

        batch_size = src.shape[0]
        if batch_size != 1:
            raise ValueError(
                f"Seq2Seq greedy decoding expects a batch of 1, got {batch_size}"
            )

        translated_tokens = []
        # Only Python ints are returned, so no autograd graph is needed.
        with torch.no_grad():
            hidden = self.encoder(src)

            # Create the start token on the same device as the source tensor
            # so it always matches the hidden state produced from it.
            input_token = torch.full(
                (batch_size,), sos_idx, dtype=torch.long, device=src.device
            )

            for _ in range(max_len):
                output, hidden = self.decoder(input_token, hidden)
                top1 = output.argmax(1)  # Greedy decoding
                token = top1.item()
                if token == eos_idx:
                    break
                translated_tokens.append(token)
                input_token = top1  # Next input is current prediction

        return translated_tokens

# ---------------------------------------------
# Run Translation
# ---------------------------------------------
# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build encoder first, then decoder, then wrap them; everything lives on `device`.
encoder = Encoder(INPUT_DIM, EMB_DIM, HID_DIM).to(device)
decoder = Decoder(OUTPUT_DIM, EMB_DIM, HID_DIM).to(device)
model = Seq2Seq(encoder=encoder, decoder=decoder, device=device).to(device)

# Inference: the source tensor must sit on the same device as the model.
src_sentence = src_sentence.to(device)
predicted_indices = model(src_sentence)

# Map each predicted index back to its target-language token and print.
translated_sentence = list(map(TRG_itos.__getitem__, predicted_indices))
print("Predicted translation:", ' '.join(translated_sentence))


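# Output of the cell above (the model's weights are randomly initialised and
# no training step is run, so the predicted tokens are arbitrary):
# Predicted translation: <pad> bonjour bonjour bonjour bonjour bonjour bonjour bonjour bonjour bonjour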
