In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleModel(nn.Module):
    def __init__(self, v_size, embed_size):
        super(SimpleModel, self).__init__()
        self.embedding = nn.Embedding(v_size, embed_size)
        self.fc = nn.Linear(embed_size, v_size) 
    
    def forward(self, x):
        x = self.embedding(x) 
        x = torch.mean(x, dim=1)  
        logits = self.fc(x)  
        return logits

class BeamSearchDecoder:
    def __init__(self, model, beam_width, max_len, sostoken, eostoken, device):
        self.model = model
        self.beam_width = beam_width
        self.max_len = max_len
        self.sostoken = sostoken
        self.eostoken = eostoken
        self.device = device

    def decode(self, src_seq):
        if not isinstance(src_seq, torch.Tensor):
            src_seq = torch.tensor(src_seq, dtype=torch.long, device=self.device)
        
        batch_size = src_seq.size(0)
        beams = [(torch.tensor([self.sostoken], device=self.device), 0)] * self.beam_width
        
        for _ in range(self.max_len):
            all_candidates = []
            for seq, score in beams:
                if seq[-1] == self.eostoken:
                    all_candidates.append((seq, score))
                else:
                    input_seq = seq.unsqueeze(0)  #batch_dim
                    with torch.no_grad():
                        logits = self.model(input_seq)
                    
                    probs = F.softmax(logits, dim=-1)
                    top_probs, top_indices = probs.topk(self.beam_width)

                    for i in range(self.beam_width):
                        candidate = (torch.cat([seq, top_indices[0, i].unsqueeze(0)]), score - torch.log(top_probs[0, i]))
                        all_candidates.append(candidate)

            ordered = sorted(all_candidates, key=lambda x: x[1])
            beams = ordered[:self.beam_width]

        return beams[0][0].tolist()

# Defining models
vocab_size = 100
embed_size = 50

# Initialize the model
model = SimpleModel(vocab_size, embed_size)

# moving model 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

#moving seq
src_seq = torch.tensor(src_seq, dtype=torch.long).unsqueeze(0).to(device)


beam_width = 3
max_len = 50
sostoken = 1
eostoken = 9


decoder = BeamSearchDecoder(model, beam_width, max_len, sostoken, eostoken, device)

try:
    translation = decoder.decode(src_seq)
    print("Translated Sequence:", translation)
except Exception as e:
    print(f"An error occurred during decoding: {e}")


Translated Sequence: [1, 44, 44, 44, 78, 7, 60, 60, 60, 86, 86, 86, 86, 86, 86, 46, 90, 46, 90, 90, 46, 90, 90, 46, 25, 25, 90, 25, 25, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17]


  src_seq = torch.tensor(src_seq, dtype=torch.long).unsqueeze(0).to(device)
