In [3]:
cd /kaggle/input/dataset-eng-por

/kaggle/input/dataset-eng-por


In [4]:
import pdb
import torch
import itertools
import numpy as np
import torch.nn as nn
from collections import Counter
from utils_PT import (sentences, train_dataset, val_dataset, train_loader, val_loader,
                   tokenizer_eng, tokenizer_por, masked_loss, masked_acc, ids_to_text, encode_sample, pt_lower_and_split_punct)

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data preparation

In [6]:
english_sentences, portuguese_sentences = sentences

print(f"English (to translate) sentence:\n\n{english_sentences[-5]}\n")
print(f"Portuguese (translation) sentence:\n\n{portuguese_sentences[-5]}")

English (to translate) sentence:

No matter how much you try to convince people that chocolate is vanilla, it'll still be chocolate, even though you may manage to convince yourself and a few others that it's vanilla.

Portuguese (translation) sentence:

Não importa o quanto você tenta convencer os outros de que chocolate é baunilha, ele ainda será chocolate, mesmo que você possa convencer a si mesmo e poucos outros de que é baunilha.


In [7]:
del portuguese_sentences
del english_sentences
del sentences

In [8]:
print(f"First 10 words of the english vocabulary:\n\n{sorted(tokenizer_eng.get_vocab().items(), key=lambda item: item[1])[:10]}\n")
print(f"First 10 words of the portuguese vocabulary:\n\n{sorted(tokenizer_por.get_vocab().items(), key=lambda item: item[1])[:10]}")

First 10 words of the english vocabulary:

[('[PAD]', 0), ('[UNK]', 1), ('[EOS]', 2), ('[SOS]', 3), ('.', 4), ('tom', 5), ('i', 6), ('to', 7), ('you', 8), ('the', 9)]

First 10 words of the portuguese vocabulary:

[('[PAD]', 0), ('[UNK]', 1), ('[EOS]', 2), ('[SOS]', 3), ('.', 4), ('tom', 5), ('que', 6), ('o', 7), ('nao', 8), ('eu', 9)]


In [9]:
# Size of the vocabulary
vocab_size_por = tokenizer_eng.get_vocab_size()
vocab_size_eng = tokenizer_eng.get_vocab_size()

print(f"Portuguese vocabulary is made up of {vocab_size_por} words")
print(f"English vocabulary is made up of {vocab_size_eng} words")

Portuguese vocabulary is made up of 12000 words
English vocabulary is made up of 12000 words


In [10]:
def word_to_id(token):
    return tokenizer_por.token_to_id(token)


def ids_to_words(id):
    return tokenizer_por.id_to_token(id)

In [11]:
unk_id = word_to_id("[UNK]")
sos_id = word_to_id("[SOS]")
eos_id = word_to_id("[EOS]")
baunilha_id = word_to_id("baunilha")

print(f"The id for the [UNK] token is {unk_id}")
print(f"The id for the [SOS] token is {sos_id}")
print(f"The id for the [EOS] token is {eos_id}")
print(f"The id for baunilha (vanilla) is {baunilha_id}")

The id for the [UNK] token is 1
The id for the [SOS] token is 3
The id for the [EOS] token is 2
The id for baunilha (vanilla) is 5242


## TODO: SO, there's 2 options why the inference is wrong:
1. Because the model learns the wrong pattern (either the preprocessing is incorrect or other) and the inference is correct.
2. Or because the model learns the right pattern and the inference is implemented incorrectly.


Checking #1 first:
in the cell bellow preprocess exactly the same data sample for both TF and PT versions to check if our version outputs the same 

In [12]:
(to_translate, sr_translation), translation = next(iter(train_loader))

print(f"Tokenized english sentence:\n{to_translate[0, :].numpy()}\n\n")
print(f"Tokenized portuguese sentence (shifted to the right):\n{sr_translation[0, :].numpy()}\n\n")
print(f"Tokenized portuguese sentence:\n{translation[0, :].numpy()}\n\n")

Tokenized english sentence:
[   3  173   46   66  282   66   22 2167  793    4    2    0    0    0
    0    0    0    0    0]


Tokenized portuguese sentence (shifted to the right):
[  3 103 171   6  12 744 378   4   0   0   0   0   0   0   0   0   0   0
   0]


Tokenized portuguese sentence:
[103 171   6  12 744 378   4   2   0   0   0   0   0   0   0   0   0   0
   0]




# Encoder

In [13]:
VOCAB_SIZE = 12000
UNITS = 256

In [14]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, units):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, units, padding_idx=0)
        self.rnn = nn.LSTM(units, units, bidirectional=True, batch_first=True)


    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.rnn(x)
        # Summarizing the bidirectional RNNs to follow the TF version
        forward_output = x[:, :, :UNITS]
        backward_output = x[:, :, UNITS:]
        x = forward_output + backward_output

        return x

In [15]:
encoder = Encoder(VOCAB_SIZE, UNITS)

encoder_output = encoder(to_translate)

print(f'Tensor of sentences in english has shape: {to_translate.shape}\n')
print(f'Encoder output has shape: {encoder_output.shape}')

Tensor of sentences in english has shape: torch.Size([64, 19])

Encoder output has shape: torch.Size([64, 19, 256])


In [16]:
print(to_translate[0].shape)
print(to_translate[0])

torch.Size([19])
tensor([   3,  173,   46,   66,  282,   66,   22, 2167,  793,    4,    2,    0,
           0,    0,    0,    0,    0,    0,    0])


In [17]:
print(ids_to_text([to_translate[0].tolist()], tokenizer_eng))

['[SOS] lets go as soon as it stops raining . [EOS]']


In [18]:
print(sr_translation[0].shape)
print(sr_translation[0])

torch.Size([19])
tensor([  3, 103, 171,   6,  12, 744, 378,   4,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0])


In [19]:
print(ids_to_text([sr_translation[0].tolist()], tokenizer_por))

['[SOS] vamos assim que a chuva parar .']


In [20]:
print(translation[0].shape)
print(translation[0])

torch.Size([19])
tensor([103, 171,   6,  12, 744, 378,   4,   2,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0])


In [21]:
print(ids_to_text([translation[0].tolist()], tokenizer_por))

['vamos assim que a chuva parar . [EOS]']


# Cross Attention

In [22]:
class CrossAttention(nn.Module):
    def __init__(self, units):
        super().__init__()

        self.mha = nn.MultiheadAttention(units, 1, batch_first=True)
        self.layernorm = nn.LayerNorm(units)

    def forward(self, context, target):
        attn_output = self.mha(query=target,key=context, value=context)
        x = target + attn_output[0] # [0] because we only need the attention output and no weights
        x = self.layernorm(x) 

        return x

In [23]:
attention_layer = CrossAttention(UNITS)

sr_translation_embed = nn.Embedding(VOCAB_SIZE, UNITS, 0)(sr_translation)

attention_result = attention_layer(encoder_output, sr_translation_embed)

print(f'Tensor of contexts has shape: {encoder_output.shape}')
print(f'Tensor of translations has shape: {sr_translation_embed.shape}')
print(f'Tensor of attention scores has shape: {attention_result.shape}')

Tensor of contexts has shape: torch.Size([64, 19, 256])
Tensor of translations has shape: torch.Size([64, 19, 256])
Tensor of attention scores has shape: torch.Size([64, 19, 256])


# Decoder

In [24]:
class Decoder(nn.Module):
    def __init__(self, vocab_size, units):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, units, padding_idx=0)
        self.pre_attention_rnn = nn.LSTM(units, units, batch_first=True)
        self.attention = CrossAttention(units)
        self.post_attention_rnn = nn.LSTM(units, units, batch_first=True)
        self.output_layer = nn.Linear(units, vocab_size)
        self.activation = nn.LogSoftmax(dim=-1)

    def forward(self, context, target_in, state=None, return_state=False):
        x = self.embedding(target_in)
        x, (hidden_state, cell_state) = self.pre_attention_rnn(x, state)
        x = self.attention(context, x)
        x, _ = self.post_attention_rnn(x)
        x = self.output_layer(x)
        logits = self.activation(x)

        if return_state:
            return logits, [hidden_state, cell_state]

        return logits

In [25]:
decoder = Decoder(VOCAB_SIZE, UNITS)

logits = decoder(encoder_output, sr_translation)

print(f'Tensor of contexts has shape: {encoder_output.shape}')
print(f'Tensor of right-shifted translations has shape: {sr_translation.shape}')
print(f'Tensor of logits has shape: {logits.shape}')

Tensor of contexts has shape: torch.Size([64, 19, 256])
Tensor of right-shifted translations has shape: torch.Size([64, 19])
Tensor of logits has shape: torch.Size([64, 19, 12000])


# Translator

In [26]:
class Translator(nn.Module):
    def __init__(self, vocab_size, units):
        super().__init__()

        self.encoder = Encoder(vocab_size, units)
        self.decoder = Decoder(vocab_size, units)

    def forward(self, inputs):
        context, targets = inputs

        encoded_context = self.encoder(context)
        logits = self.decoder(encoded_context, targets)

        return logits

In [27]:
translator = Translator(VOCAB_SIZE, UNITS).to(device)

# Loading the model
#translator.load_state_dict(torch.load('/kaggle/working/model_weights.pth', map_location=torch.device(device), weights_only=True))

logits = translator((to_translate.to(device), sr_translation.to(device)))

print(f'Tensor of sentences to translate has shape: {to_translate.shape}')
print(f'Tensor of right-shifted translations has shape: {sr_translation.shape}')
print(f'Tensor of logits has shape: {logits.shape}')

Tensor of sentences to translate has shape: torch.Size([64, 19])
Tensor of right-shifted translations has shape: torch.Size([64, 19])
Tensor of logits has shape: torch.Size([64, 19, 12000])


In [28]:
optimizer = torch.optim.Adam(params=translator.parameters())
criterion = masked_loss
acc = masked_acc

# Training

In [29]:
"""NUM_EPOCHS = 20
STEPS_PER_EPOCH = 500
patience = 3
min_loss = float('inf')

for epoch in range(NUM_EPOCHS):
    # Mini batch loss
    running_loss = 0.0
    # Epoch loss for early stopping
    epoch_loss = 0.0
    translator.train()

    # Using itertools for fixed length iteration over non subscriptable DataLoader
    for i, data in enumerate(itertools.islice(train_loader,  STEPS_PER_EPOCH)):
        (context, target_in), target_out = data

        context, target_in, target_out = context.to(device), target_in.to(device), target_out.to(device)

        optimizer.zero_grad()
        outputs = translator((context, target_in))
        accuracy = acc(target_out, outputs)
        loss = criterion(target_out, outputs)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        #Getting the loss of the epoch
        if i+1 == STEPS_PER_EPOCH:
            epoch_loss = running_loss

        if i % 100 == 99:
            print(f"\n[epoch: {epoch+1}, mini batch: {i+1}] loss: {running_loss:.4f}, accuracy: {accuracy:.4f}\n")
            running_loss = 0

    # Update the best loss if it's better than the previous one
    if epoch_loss < min_loss:
        min_loss = epoch_loss
        patience = 3

    else:
        # Losing patience
        patience -= 1

        if patience == 0:
            print("Early stopping was triggered")"""

'NUM_EPOCHS = 20\nSTEPS_PER_EPOCH = 500\npatience = 3\nmin_loss = float(\'inf\')\n\nfor epoch in range(NUM_EPOCHS):\n    # Mini batch loss\n    running_loss = 0.0\n    # Epoch loss for early stopping\n    epoch_loss = 0.0\n    translator.train()\n\n    # Using itertools for fixed length iteration over non subscriptable DataLoader\n    for i, data in enumerate(itertools.islice(train_loader,  STEPS_PER_EPOCH)):\n        (context, target_in), target_out = data\n\n        context, target_in, target_out = context.to(device), target_in.to(device), target_out.to(device)\n\n        optimizer.zero_grad()\n        outputs = translator((context, target_in))\n        accuracy = acc(target_out, outputs)\n        loss = criterion(target_out, outputs)\n        loss.backward()\n        optimizer.step()\n\n        running_loss += loss.item()\n\n        #Getting the loss of the epoch\n        if i+1 == STEPS_PER_EPOCH:\n            epoch_loss = running_loss\n\n        if i % 100 == 99:\n            prin

## Validation

In [30]:
STEPS_PER_EPOCH = 500
patience = 3
min_loss = float('inf')

running_loss = 0.0
translator.eval()

with torch.no_grad():
    for i, data in enumerate(itertools.islice(val_loader,  STEPS_PER_EPOCH)):
        (context, target_in), target_out = data

        context, target_in, target_out = context.to(device), target_in.to(device), target_out.to(device)

        outputs = translator((context, target_in))
        loss = criterion(target_out, outputs)
        accuracy = acc(target_out, outputs)

        running_loss += loss.item()

        if i % 100 == 99:
            print(f"\n[mini batch: {i+1}] validation loss: {running_loss:.4f}, validation accuracy: {accuracy:.4f}\n")
            running_loss = 0


[mini batch: 100] validation loss: 939.8885, validation accuracy: 0.0000


[mini batch: 200] validation loss: 939.9164, validation accuracy: 0.0000


[mini batch: 300] validation loss: 939.9967, validation accuracy: 0.0000


[mini batch: 400] validation loss: 939.9012, validation accuracy: 0.0000


[mini batch: 500] validation loss: 939.9417, validation accuracy: 0.0000



# Using the model for inference

In [31]:
def generate_next_token(context, decoder, next_token, state, done, temperature=0.0):
    logits, state = decoder(context, next_token, state, return_state=True)
    logits = logits[:, -1, :]

    if temperature == 0.0:
        next_token = torch.argmax(logits, dim=-1)

    else:
        logits = torch.exp(logits)
        logits /= temperature
        next_token = torch.multinomial(logits, 1)
        logits = torch.log(logits)

    logits = torch.squeeze(logits)

    next_token = torch.squeeze(next_token)

    logit = logits[next_token].detach().numpy()

    next_token = torch.reshape(next_token, shape=(1,1))

    if next_token == eos_id:
        done = True

    return next_token, logit, state, done

In [32]:
eng_sentence = "I love languages"

context = torch.tensor(encode_sample(eng_sentence))
context = torch.unsqueeze(context, dim=0)
context = encoder(context)

next_token = torch.full((1,1), sos_id)

state = [torch.rand((1, 1, UNITS)), torch.rand((1, 1, UNITS))]
done = False

next_token, logit, state, done = generate_next_token(context, decoder, next_token, state, done, temperature=0.5)
print(f"Next token: {next_token}\nLogit: {logit:.4f}\nDone? {done}")
next_token = next_token.tolist()
print(ids_to_text(next_token, tokenizer_por))

Next token: tensor([[616]])
Logit: -8.6960
Done? False
['gostou']


# Translate

In [33]:
def translate(model, text, max_length=50, temperature=0.0):
    
    tokens, logits = [], []

    pre_text = text
    text = torch.tensor(encode_sample(pre_text))
    text = torch.unsqueeze(text, dim=0)

    context = encoder(text)

    next_token = torch.full((1,1), sos_id)

    # Try uniform instead of zeros here
    state = [torch.zeros((1, 1, UNITS)), torch.zeros((1, 1, UNITS))]

    done = False
    for iteration in range(max_length):
        try:
            next_token, logit, state, done = generate_next_token(
                context=context,
                decoder=model.decoder,
                next_token=next_token,
                state=state,
                done=done,
                temperature=temperature
            )
        except:
            raise Exception("Problem generating the next token")

        if done:
            break
            
        tokens.append(next_token)
        
        logits.append(logit)

    tokens = torch.cat(tokens, dim=-1).tolist()
    
    translation = ids_to_text(tokens, tokenizer_por)

    return translation, logits[-1], tokens

In [34]:
# Running this cell multiple times should return the same output since temp is 0

temp = 0.0
original_sentence = "I am Christian"

translation, logit, tokens = translate(translator.to("cpu"), original_sentence, temperature=temp)

print(f"Temperature: {temp}\n\nOriginal sentence: {original_sentence}\nTranslation: {translation}\nTranslation tokens:{tokens}\nLogit: {logit:.3f}")

Temperature: 0.0

Original sentence: I am Christian
Translation: ['visitara traduza escolar calmos dilema medieval tamanho esteve esteve comestivel controlada bicicletas curto nativo elefantes encontrar teus carvalho universitarios voltarmos reviver reviver guaxinim ileso dei controlada controlada bicicletas curto atica consciencia comunicam atica ensinasse bocejar dormido excecoes equipamentos inclinou inclinou gelada aceitaria fibras atualizar charutos vemos faleceu numero medieval gire']
Translation tokens:[[6538, 4160, 3881, 5626, 8514, 8889, 1523, 586, 586, 5657, 8404, 3836, 2264, 1698, 2570, 268, 1659, 7347, 11146, 9333, 10933, 10933, 6875, 10329, 854, 8404, 8404, 3836, 2264, 11591, 3439, 7392, 11591, 10051, 5615, 4692, 5021, 8584, 5819, 5819, 3248, 3546, 10216, 8221, 11875, 3190, 2117, 590, 8889, 10284]]
Logit: -9.114


In [35]:
# Running this cell multiple times should return different outputs since temp is not 0
# You can try different temperatures

temp = 0.7
original_sentence = "I love languages"

translation, logit, tokens = translate(translator.to("cpu"), original_sentence, temperature=temp)

print(f"Temperature: {temp}\n\nOriginal sentence: {original_sentence}\nTranslation: {translation}\nTranslation tokens:{tokens}\nLogit: {logit:.3f}")

Temperature: 0.7

Original sentence: I love languages
Translation: ['caminhonete fundamento trilha cale topo dirigisse cedo juntar ananas convencela retorica tirar largos russa recurso presidenciais voo cacador baleia lago perdoado antartica baixo sugerindo cantando gostarias matar bosque relaxar cuidei vencedores penhasco trabalhamos psicologia mostrarlhes escrevame impotente desanimar jurei importo exagera espirrar barro grecia atarefada luto fuso ganhou tristeza caido']
Translation tokens:[[6091, 10254, 8032, 7334, 2998, 5325, 232, 2416, 11438, 8405, 5961, 703, 7704, 7939, 5946, 10791, 935, 7328, 3834, 1176, 7008, 11450, 1054, 6486, 1611, 4275, 1008, 6081, 2533, 4973, 11173, 6391, 2386, 6423, 10549, 6223, 7662, 8480, 5413, 1354, 6244, 10121, 11669, 4731, 11581, 7724, 10257, 890, 7165, 6640]]
Logit: -9.080


# Minimum Bayes-Risk Decoding

In [111]:
def generate_samples(model, text, n_samples=4, temperature=0.6):
    samples, log_probs = [], []
    
    for _ in range(n_samples):
        _, log_prob, sample = translate(model, text, temperature=temperature)
        
        samples.append(sample)
        
        log_probs.append(log_prob)
        
    return samples, log_probs

In [112]:
samples, log_probs = generate_samples(translator, 'I love languages')

for s, l in zip(samples, log_probs):
    print(f"Translated tensor: {s} has logit: {l:.3f}")

Translated tensor: [[1011, 8825, 786, 455, 8227, 11800, 9455, 835, 6202, 9944, 4677, 1601, 8496, 7975, 105, 6434, 5956, 3894, 1334, 11818, 8215, 3936, 2118, 9073, 5334, 3641, 4459, 10467, 1674, 4272, 7239, 10243, 11423, 11122, 7035, 8976, 2605, 5549, 771, 5412, 4563, 11599, 9811, 11277, 3826, 519, 986, 2122, 11435, 9372]] has logit: -8.967
Translated tensor: [[11050, 3894, 9757, 3674, 7566, 5674, 8020, 111, 6804, 10729, 9593, 5570, 2119, 96, 5485, 5263, 9972, 316, 10605, 11926, 5123, 10295, 4584, 1845, 8900, 11130, 6444, 5995, 5007, 8439, 4678, 776, 11417, 4219, 10144, 5719, 4496, 7545, 3568, 3801, 5497, 8081, 9399, 3212, 1540, 4325, 419, 8290, 3624, 5575]] has logit: -8.985
Translated tensor: [[9500, 8851, 8305, 10237, 45, 8254, 6269, 8251, 10604, 9082, 11238, 9226, 9588, 7147, 11821, 1256, 1449, 1397, 1518, 3643, 8749, 5885, 6073, 2989, 10600, 1691, 5838, 6212, 10378, 6688, 1571, 6725, 295, 7776, 8293, 5828, 1332, 1714, 11037, 11563, 370, 4184, 1377, 5320, 6696, 5805, 9105, 11626, 43

# Comparing overlaps

In [151]:
def jaccard_similarity(candidate, reference):
    
    if (isinstance(candidate, list) and all(isinstance(i, list) for i in candidate)) and \
       (isinstance(reference, list) and all(isinstance(i, list) for i in reference)):
        candidate_set = set(candidate[0])
        reference_set = set(reference[0])

    else:
        candidate_set = set(candidate)
        reference_set = set(reference)    
    
    common_tokens = candidate_set.intersection(reference_set)
    
    all_tokens = candidate_set.union(reference_set)
    
    overlap = len(common_tokens) / len(all_tokens)
    
    return overlap

In [152]:
l1 = [1,2,3]
l2 = [1,2,3,4]

js = jaccard_similarity(l1, l2)

print(f"jaccard similarity between lists: {l1} and {l2} is {js:.3f}")

jaccard similarity between lists: [1, 2, 3] and [1, 2, 3, 4] is 0.750


# Rouge1 similarity

In [153]:
def rouge1_similarity(candidate, reference):
    candidate_word_counts = Counter(candidate)
    reference_word_counts = Counter(reference)    
    
    overlap = 0
    
    for token in candidate_word_counts.keys():
        token_count_candidate = candidate_word_counts[token]
        token_count_reference = reference_word_counts[token]        
        
        overlap += min(token_count_candidate, token_count_reference)
        
    precision = overlap / len(candidate)
    
    recall = overlap / len(reference)
    
    if precision + recall != 0:
        f1_score = 2 * (precision * recall) / (precision + recall)
        return f1_score
    
    return 0

In [154]:
l1 = [0, 1]
l2 = [5, 5, 7, 0, 232]

r1s = rouge1_similarity(l1, l2)

print(f"rouge 1 similarity between lists: {l1} and {l2} is {r1s:.3f}")

rouge 1 similarity between lists: [0, 1] and [5, 5, 7, 0, 232] is 0.286


In [155]:
l1 = [1, 2, 3]
l2 = [1, 2, 3, 4]

r1s = rouge1_similarity(l1, l2)

print(f"rouge 1 similarity between lists: {l1} and {l2} is {r1s:.3f}")

rouge 1 similarity between lists: [1, 2, 3] and [1, 2, 3, 4] is 0.857


# Computing the overall score

# Average overlap

In [156]:
def average_overlap(samples, similarity_fn):
    
    scores = {}
    
    for index_candidate, candidate in enumerate(samples):
        overlap = 0
        
        for index_sample, sample in enumerate(samples):
            
            if index_candidate == index_sample:
                continue
                
            overlap += similarity_fn(candidate, sample)
            
        score = overlap / (len(samples) - 1)
        
        score = round(score, 3)
        
        scores[index_candidate] = score
        
    return scores

In [157]:
# Test with Jaccard similarity

l1 = [1, 2, 3]
l2 = [1, 2, 4]
l3 = [1, 2, 4, 5]

avg_ovlp = average_overlap([l1, l2, l3], jaccard_similarity)

print(f"average overlap between lists: {l1}, {l2} and {l3} using Jaccard similarity is:\n\n{avg_ovlp}")

average overlap between lists: [1, 2, 3], [1, 2, 4] and [1, 2, 4, 5] using Jaccard similarity is:

{0: 0.45, 1: 0.625, 2: 0.575}


In [158]:
# Test with Rouge1 similarity

l1 = [1, 2, 3]
l2 = [1, 4]
l3 = [1, 2, 4, 5]
l4 = [5,6]

avg_ovlp = average_overlap([l1, l2, l3, l4], rouge1_similarity)

print(f"average overlap between lists: {l1}, {l2}, {l3} and {l4} using Rouge1 similarity is:\n\n{avg_ovlp}")

average overlap between lists: [1, 2, 3], [1, 4], [1, 2, 4, 5] and [5, 6] using Rouge1 similarity is:

{0: 0.324, 1: 0.356, 2: 0.524, 3: 0.111}


In [159]:
def weighted_avg_overlap(samples, log_probs, similarity_fn):
    scores = {}
    
    for index_candidate, candidate in enumerate(samples):
        overlap, weighted_sum = 0.0, 0.0
        
        for index_sample, (sample, logprob) in enumerate(zip(samples, log_probs)):
            if index_candidate == index_sample:
                continue
                
            sample_prob = float(np.exp(logprob))
            weighted_sum += sample_prob
            
            sample_overlap = similarity_fn(candidate, sample)
            overlap += sample_overlap * sample_prob
            
        score = overlap / weighted_sum
        score = round(score, 3)
        
        scores[index_candidate] = score
        
    return scores

In [160]:
l1 = [1, 2, 3]
l2 = [1, 2, 4]
l3 = [1, 2, 4, 5]
log_probs = [0.4, 0.2, 0.5]

w_avg_ovlp = weighted_avg_overlap([l1, l2, l3], log_probs, jaccard_similarity)

print(f"weighted average overlap using Jaccard similarity is:\n\n{w_avg_ovlp}")

weighted average overlap using Jaccard similarity is:

{0: 0.443, 1: 0.631, 2: 0.558}


In [161]:
def mbr_decode(model, text, n_samples=5, temperature=0.6, similarity_fn=jaccard_similarity):
    samples, log_probs = generate_samples(model, text, n_samples=n_samples, temperature=temperature)
    
    scores = weighted_avg_overlap(samples, log_probs, similarity_fn)
    
    decoded_translations = [ids_to_text(sample,tokenizer_por) for sample in samples]
    
    max_score_key = max(scores, key=lambda k: scores[k])
    
    translation = decoded_translations[max_score_key]
    
    return translation, decoded_translations

In [162]:
english_sentence = "I love languages"

translation, candidates = mbr_decode(translator, english_sentence, n_samples=10, temperature=0.6)

print("Translation candidates:")
for c in candidates:
    print(c)

print(f"\nSelected translation: {translation}")

Translation candidates:
['pareceu defender aprovou amputou zangado minhas propriedade cinico mandam idade afeccoes fracassou culpados dividas caso rios assado automatica recebemos racista orgulhe cansaco proibido surpresas brincarem virou incompetente acordalo responsavel seguiremos valor parava deslize espessa roubo acostuma sortudos opcional odiaria demissao precisas visita aprendizagem chovesse italianos aborreceu adicao brucos anotacoes experiencias']
['convencelo escuras estabulo bibliotecaria rico fritando vivido esvazie madre preocupacoes fundamento zonzo possibilidades machucaria diabetico ignorou oportunidades bandagem cancelado ensaiar acalmaram cobraram abraca dados croissant nevar mundial adicionar concluido corrigir casando mantem empresarios cuida agenda trabalhos costumo bronzeado daninhas mexer panama secou vi pro precisares desinteressado festejava ditador prestem cabelo']
['preta croissant prisional salva identificaram sobreviver exposto reais aberta preferida folga p