<a href="https://colab.research.google.com/github/mr-alamdari/NLP-Neural-Machine-Translation-Beginner/blob/main/NLP_Neural_Machine_Translation_Beginner.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
import os
import re
import random
import numpy as np

In [5]:
# !pip install trax
import trax
from trax import layers as tl
from trax.fastmath import numpy as fastnp
from trax.supervised import training

In [6]:
# Stream factories over the TFDS 'opus/medical' corpus, yielding ('en', 'de')
# pairs. Both use the same 1% evaluation hold-out; only `train` differs.
train_stream_function = trax.data.TFDS(
    'opus/medical',
    data_dir='./data/',
    keys=('en', 'de'),
    eval_holdout_size=0.01,
    train=True,
)

eval_stream_function = trax.data.TFDS(
    'opus/medical',
    data_dir='./data/',
    keys=('en', 'de'),
    eval_holdout_size=0.01,
    train=False,
)

  "jax.host_count has been renamed to jax.process_count. This alias "


[1mDownloading and preparing dataset opus/medical/0.1.0 (download: 34.29 MiB, generated: 188.85 MiB, total: 223.13 MiB) to ./data/opus/medical/0.1.0...[0m


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]






0 examples [00:00, ? examples/s]

Shuffling and writing examples to ./data/opus/medical/0.1.0.incomplete8ETHDE/opus-train.tfrecord


  0%|          | 0/1108752 [00:00<?, ? examples/s]

[1mDataset opus downloaded and prepared to ./data/opus/medical/0.1.0. Subsequent calls will reuse this data.[0m


In [11]:
# Instantiate the training stream and show one raw (en, de) pair.
# Note: next() consumes that pair, so it is not seen again downstream.
train_stream = train_stream_function()
print(next(train_stream))
print()

# Same for the held-out evaluation stream.
eval_stream = eval_stream_function()
print(next(eval_stream))

(b'In the pregnant rat the AUC for calculated free drug at this dose was approximately 18 times the human AUC at a 20 mg dose.\n', b'Bei tr\xc3\xa4chtigen Ratten war die AUC f\xc3\xbcr die berechnete ungebundene Substanz bei dieser Dosis etwa 18-mal h\xc3\xb6her als die AUC beim Menschen bei einer 20 mg Dosis.\n')

(b'Subcutaneous use and intravenous use.\n', b'Subkutane Anwendung und intraven\xc3\xb6se Anwendung.\n')


In [14]:
# Subword vocabulary file name and the directory it is looked up in.
VOCAB_FILE = 'ende_32k.subword'
VOCAB_DIR = 'data/'

# Route both raw text streams through trax's Tokenize step using that vocabulary.
tokenized_train_stream = trax.data.Tokenize(vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)(train_stream)
tokenized_eval_stream = trax.data.Tokenize(vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)(eval_stream)

In [16]:
EOS = 1  # token id appended to every sequence as the end-of-sentence marker


def add_eos(stream):
    """Yield every pair from `stream` with the EOS id appended to both sides.

    Args:
        stream: iterable of (inputs, targets) pairs, each an iterable of token ids.

    Yields:
        A tuple of two NumPy arrays: the input ids followed by EOS, and the
        target ids followed by EOS.
    """
    for source_ids, target_ids in stream:
        source_with_eos = [*source_ids, EOS]
        target_with_eos = [*target_ids, EOS]
        yield np.array(source_with_eos), np.array(target_with_eos)

# Re-bind both token streams to their EOS-terminated versions.
# NOTE(review): this rebinding is not idempotent -- re-running the cell wraps the
# already-wrapped stream again and appends a second EOS to every sequence.
tokenized_train_stream = add_eos(tokenized_train_stream)
tokenized_eval_stream = add_eos(tokenized_eval_stream)

In [19]:
# Length filtering on both elements of each pair (length_keys 0 and 1).
# The limit is 256 for training and 512 for evaluation.
filtered_train_stream = trax.data.FilterByLength(max_length=256, length_keys=[0, 1])(tokenized_train_stream)
filtered_eval_stream = trax.data.FilterByLength(max_length=512, length_keys=[0, 1])(tokenized_eval_stream)

# Peek at one tokenized training pair (this consumes it from the stream).
train_input, train_target = next(filtered_train_stream)
print(train_input)
print(train_target)

In [20]:
def tokenize(input_str, vocab_file=None, vocab_dir=None, EOS=1):
    """Encode a single string as a batch of one EOS-terminated id sequence.

    Args:
        input_str: the text to encode.
        vocab_file: vocabulary file name passed to `trax.data.tokenize`.
        vocab_dir: directory passed to `trax.data.tokenize`.
        EOS: id appended after the encoded tokens.

    Returns:
        A NumPy array reshaped to (1, -1): one row holding the token ids
        followed by EOS.
    """
    token_stream = trax.data.tokenize(iter([input_str]), vocab_file=vocab_file, vocab_dir=vocab_dir)
    token_ids = [*next(token_stream), EOS]
    # Add a leading batch axis so the result can be fed to the model directly.
    return np.array(token_ids).reshape(1, -1)

In [21]:
def detokenize(integers, vocab_file=None, vocab_dir=None, EOS=1):
    """Decode token ids back into text, dropping EOS and everything after it.

    Args:
        integers: token ids; singleton dimensions (e.g. a leading batch axis
            of size 1) are squeezed away before decoding.
        vocab_file: vocabulary file name passed to `trax.data.detokenize`.
        vocab_dir: directory passed to `trax.data.detokenize`.
        EOS: id marking the end of the sentence.

    Returns:
        Whatever `trax.data.detokenize` returns for the ids preceding the
        first EOS (or for all ids when no EOS is present).
    """
    # np.squeeze collapses a one-token input (e.g. just [EOS]) to a 0-d array,
    # which list() cannot iterate; atleast_1d restores a single axis in that
    # case and leaves every other shape untouched.
    integers = list(np.atleast_1d(np.squeeze(integers)))
    if EOS in integers:
        integers = integers[:integers.index(EOS)]
    return trax.data.detokenize(integers, vocab_file=vocab_file, vocab_dir=vocab_dir)

In [None]:
# Bucketing configuration. `batch_sizes` has one more entry than `boundaries`;
# presumably the extra entry applies to sequences longer than the last
# boundary -- confirm against the trax BucketByLength documentation.
boundaries = [8, 16, 32, 64, 128, 256, 512]
batch_sizes = [256, 128, 64, 32, 16, 8, 4, 2]

# Group pairs into batches, measuring length on both elements (keys 0 and 1).
train_batch_stream = trax.data.BucketByLength(
    boundaries, batch_sizes, length_keys=[0, 1]
)(filtered_train_stream)
eval_batch_stream = trax.data.BucketByLength(
    boundaries, batch_sizes, length_keys=[0, 1]
)(filtered_eval_stream)

# Attach loss weights, with id 0 passed as the id to mask.
train_batch_stream = trax.data.AddLossWeights(id_to_mask=0)(train_batch_stream)
eval_batch_stream = trax.data.AddLossWeights(id_to_mask=0)(eval_batch_stream)

In [None]:
# Pull one batch from the training stream; each item unpacks into three parts
# (inputs, targets, and the loss-weight mask added by AddLossWeights).
input_batch, target_batch, mask_batch = next(train_batch_stream)

In [23]:
# Pick a random row index within the batch (unseeded, so it differs per run).
index = random.randrange(len(input_batch))

In [24]:
def input_encoder_fn(input_vocab_size, d_model, n_encoder_layers):
    """Build the input encoder: an embedding followed by a stack of LSTMs.

    Args:
        input_vocab_size: vocabulary size given to the embedding layer.
        d_model: embedding feature size and number of units in each LSTM.
        n_encoder_layers: how many LSTM layers to stack.

    Returns:
        A `tl.Serial` combinator of the embedding and the LSTM layers.
    """
    embedding = tl.Embedding(vocab_size=input_vocab_size, d_feature=d_model)
    lstm_stack = [tl.LSTM(n_units=d_model) for _ in range(n_encoder_layers)]
    return tl.Serial(embedding, lstm_stack)

In [25]:
def pre_attention_decoder_fn(mode, target_vocab_size, d_model):
    """Build the decoder part that runs before attention.

    Args:
        mode: mode string forwarded to `tl.ShiftRight`.
        target_vocab_size: vocabulary size given to the embedding layer.
        d_model: embedding feature size and number of LSTM units.

    Returns:
        A `tl.Serial` of ShiftRight, Embedding and a single LSTM, in that order.
    """
    layers = (
        tl.ShiftRight(mode=mode),
        tl.Embedding(vocab_size=target_vocab_size, d_feature=d_model),
        tl.LSTM(n_units=d_model),
    )
    return tl.Serial(*layers)

In [26]:
def prepare_attention_input(encoder_activations, decoder_activations, inputs):
    """Arrange encoder/decoder activations and a padding mask for attention.

    Args:
        encoder_activations: used as both the keys and the values.
        decoder_activations: used as the queries; its second dimension sets
            the size of the mask's third axis.
        inputs: input token ids; positions equal to 0 are masked out.

    Returns:
        The tuple (queries, keys, values, mask).
    """
    # True wherever the input holds a non-zero id, reshaped so the two middle
    # axes are singletons.
    padding_mask = fastnp.reshape(
        inputs != 0, (inputs.shape[0], 1, 1, inputs.shape[1]))
    # Adding zeros broadcasts the mask along the third axis to the decoder length.
    padding_mask = padding_mask + fastnp.zeros(
        (1, 1, decoder_activations.shape[1], 1))
    return decoder_activations, encoder_activations, encoder_activations, padding_mask

In [29]:
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
    """Build an attention layer that takes separate queries, keys and values.

    This is a local, illustrative definition; `NMTAttn` in this notebook calls
    `tl.AttentionQKV` rather than this function.

    Args:
        d_feature: output size of every Dense projection.
        n_heads: number of attention heads passed to PureAttention.
        dropout: dropout rate passed to PureAttention.
        mode: mode string passed to PureAttention.

    Returns:
        A `tl.Serial` layer: three parallel Dense projections (one each for
        queries, keys and values), PureAttention, then a final Dense.
    """
    # The layers are referenced through the notebook's `tl` alias; the
    # previous `cb` / `core` / bare `PureAttention` names were never imported
    # here, so calling the function raised NameError.
    return tl.Serial(
        tl.Parallel(
            tl.Dense(d_feature),
            tl.Dense(d_feature),
            tl.Dense(d_feature),
        ),
        tl.PureAttention(n_heads=n_heads, dropout=dropout, mode=mode),
        tl.Dense(d_feature),
    )

In [30]:
def NMTAttn(input_vocab_size=33300,
            target_vocab_size=33300,
            d_model=1024,
            n_encoder_layers=2,
            n_decoder_layers=2,
            n_attention_heads=4,
            attention_dropout=0.0,
            mode='train'):
    """Assemble the LSTM encoder-decoder translation model with attention.

    Args:
        input_vocab_size: vocabulary size for the input embedding.
        target_vocab_size: vocabulary size for the target embedding and the
            width of the final Dense layer.
        d_model: feature size used by the embeddings, LSTMs and attention.
        n_encoder_layers: number of LSTM layers in the input encoder.
        n_decoder_layers: number of LSTM layers after attention.
        n_attention_heads: number of heads for `tl.AttentionQKV`.
        attention_dropout: dropout rate for `tl.AttentionQKV`.
        mode: mode string forwarded to ShiftRight and the attention layer.

    Returns:
        A `tl.Serial` model ending in Dense(target_vocab_size) + LogSoftmax.
    """
    encoder = input_encoder_fn(input_vocab_size, d_model, n_encoder_layers)
    decoder_front = pre_attention_decoder_fn(mode, target_vocab_size, d_model)
    attention = tl.AttentionQKV(
        d_model, n_heads=n_attention_heads, dropout=attention_dropout, mode=mode)
    decoder_lstms = [tl.LSTM(n_units=d_model) for _ in range(n_decoder_layers)]

    return tl.Serial(
        # Duplicate the two incoming stack items so copies remain available
        # after the encoder and pre-attention decoder have consumed theirs.
        tl.Select([0, 1, 0, 1]),
        tl.Parallel(encoder, decoder_front),
        tl.Fn('PrepareAttentionInput', prepare_attention_input, n_out=4),
        tl.Residual(attention),
        # Keep stack items 0 and 2, dropping item 1.
        tl.Select([0, 2]),
        decoder_lstms,
        tl.Dense(target_vocab_size),
        tl.LogSoftmax(),
    )

In [31]:
# Training task: cross-entropy loss on the bucketed training batches, Adam
# with a warmup-then-rsqrt-decay schedule, configured with
# n_steps_per_checkpoint=10.
train_task = training.TrainTask(
    labeled_data=train_batch_stream,
    loss_layer=tl.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam(0.01),
    lr_schedule=trax.lr.warmup_and_rsqrt_decay(1000, 0.01),
    n_steps_per_checkpoint=10,
)

In [32]:
# Evaluation task: report cross-entropy loss and accuracy on the eval batches.
eval_task = training.EvalTask(
    labeled_data=eval_batch_stream,
    metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
)

In [33]:
# Directory handed to the training loop as its output location.
output_dir = 'output_dir/'

# Build the loop around a fresh train-mode model plus the train/eval tasks,
# then run it with an argument of 10 (a very short run).
training_loop = training.Loop(NMTAttn(mode='train'),
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)
training_loop.run(10)

In [None]:
# Rebuild the network in eval mode and load weights from a file.
model = NMTAttn(mode='eval')

# NOTE(review): "model.pkl.gz" is resolved relative to the working directory,
# not `output_dir` -- presumably a separately obtained checkpoint; confirm
# which weights are intended here.
model.init_from_file("model.pkl.gz", weights_only=True)
model = tl.Accelerate(model)

In [34]:
def logsoftmax_sample(log_probs, temperature=1.0):  # pylint: disable=invalid-name
    """Pick indices along the last axis by adding scaled Gumbel-style noise.

    Args:
        log_probs: NumPy array of log-probabilities.
        temperature: multiplier on the noise; 0.0 removes it, so the result
            is the plain argmax.

    Returns:
        The argmax over the last axis of `log_probs` plus the scaled noise.
    """
    # Uniform draws are kept away from 0 and 1 so both logs stay finite.
    uniform = np.random.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)
    noise = -np.log(-np.log(uniform))
    perturbed = log_probs + noise * temperature
    return np.argmax(perturbed, axis=-1)

In [35]:
def next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature):
    """Run the model once and sample the next output token.

    Args:
        NMTAttn: callable model taking an (input_tokens, decoder_input) tuple
            and returning a pair whose first element is indexed as
            output[0, position, :].
        input_tokens: tokenized input, passed to the model unchanged.
        cur_output_tokens: list of token ids generated so far.
        temperature: sampling temperature for `tl.logsoftmax_sample`.

    Returns:
        (symbol, log_prob): the sampled token id as an int and its
        log-probability as a float.
    """
    n_generated = len(cur_output_tokens)
    # Zero-pad the decoder input to the smallest power of two that is at
    # least n_generated + 1.
    exponent = int(np.ceil(np.log2(n_generated + 1)))
    padded_length = np.power(2, exponent)
    zero_padding = [0] * (padded_length - n_generated)
    decoder_input = np.expand_dims(cur_output_tokens + zero_padding, axis=0)

    output, _ = NMTAttn((input_tokens, decoder_input))
    # The row at index n_generated holds the scores for the next token.
    log_probs = output[0, n_generated, :]
    symbol = int(tl.logsoftmax_sample(log_probs, temperature))
    return symbol, float(log_probs[symbol])

In [36]:
def sampling_decode(input_sentence, NMTAttn=None, temperature=0.0, vocab_file=None, vocab_dir=None, max_length=256):
    """Translate one sentence by sampling tokens until EOS is produced.

    Args:
        input_sentence: text to translate.
        NMTAttn: the model, forwarded to `next_symbol`.
        temperature: sampling temperature forwarded to `next_symbol`.
        vocab_file: vocabulary file name for tokenizing / detokenizing.
        vocab_dir: vocabulary directory for tokenizing / detokenizing.
        max_length: upper bound on the number of generated tokens, so that a
            model which never emits EOS cannot loop forever. Pass None for
            the previous, unbounded behaviour.

    Returns:
        (cur_output_tokens, log_prob, sentence): the generated token ids
        (including EOS when it was produced), the log-probability of the
        LAST generated token only, and the detokenized text.
    """
    input_tokens = tokenize(input_sentence, vocab_file, vocab_dir)
    cur_output_tokens = []
    cur_output = 0
    log_prob = 0.0  # only returned as-is when max_length allows no tokens
    EOS = 1
    while cur_output != EOS:
        if max_length is not None and len(cur_output_tokens) >= max_length:
            # Give up instead of hanging on a model that never emits EOS.
            break
        cur_output, log_prob = next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature)
        cur_output_tokens.append(cur_output)
    sentence = detokenize(cur_output_tokens, vocab_file, vocab_dir)
    return cur_output_tokens, log_prob, sentence

In [None]:
# Decode a sample sentence with temperature 0.0 using the loaded model.
sampling_decode("I love languages.", model, temperature=0.0, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)

In [38]:
def greedy_decode_test(sentence, NMTAttn=None, vocab_file=None, vocab_dir=None):
    """Translate `sentence`, print source and translation, return the translation.

    No temperature is passed on, so `sampling_decode` runs with its own
    default temperature.

    Args:
        sentence: text to translate.
        NMTAttn: the model, forwarded to `sampling_decode`.
        vocab_file: vocabulary file name.
        vocab_dir: vocabulary directory.

    Returns:
        The translated sentence.
    """
    _tokens, _log_prob, translated_sentence = sampling_decode(
        sentence, NMTAttn, vocab_file=vocab_file, vocab_dir=vocab_dir)
    print(f"English: {sentence}")
    print(f"German: {translated_sentence}")
    return translated_sentence

In [39]:
def generate_samples(sentence, n_samples, NMTAttn=None, temperature=0.6, vocab_file=None, vocab_dir=None):
    """Decode `sentence` n_samples times and collect the results.

    Args:
        sentence: text to translate.
        n_samples: number of independent decodes to run.
        NMTAttn: the model, forwarded to `sampling_decode`.
        temperature: sampling temperature for each decode.
        vocab_file: vocabulary file name.
        vocab_dir: vocabulary directory.

    Returns:
        (samples, log_probs): two parallel lists holding the token list and
        the log-probability value returned by each `sampling_decode` call.
    """
    decoded = [
        sampling_decode(sentence, NMTAttn, temperature, vocab_file=vocab_file, vocab_dir=vocab_dir)
        for _ in range(n_samples)
    ]
    samples = [tokens for tokens, _, _ in decoded]
    log_probs = [logp for _, logp, _ in decoded]
    return samples, log_probs

In [40]:
def jaccard_similarity(candidate, reference):
    """Jaccard similarity between the unigram sets of two token sequences.

    Args:
        candidate: iterable of hashable tokens.
        reference: iterable of hashable tokens.

    Returns:
        |intersection| / |union| of the two token sets, as a float in [0, 1].
        When both inputs are empty the union is empty too; 0.0 is returned
        in that case instead of dividing by zero.
    """
    can_unigram_set, ref_unigram_set = set(candidate), set(reference)
    all_elems = can_unigram_set | ref_unigram_set
    if not all_elems:
        return 0.0
    joint_elems = can_unigram_set & ref_unigram_set
    return len(joint_elems) / len(all_elems)

In [41]:
from collections import Counter

def rouge1_similarity(system, reference):
    """ROUGE-1 F1 score (unigram overlap) between two token sequences.

    Args:
        system: iterable of hashable tokens produced by the system.
        reference: iterable of hashable tokens to compare against.

    Returns:
        The harmonic mean of unigram precision and recall. Returns 0.0 when
        the sequences share no tokens, which also covers either being empty.
    """
    sys_counter = Counter(system)
    ref_counter = Counter(reference)
    # Multiset intersection keeps, per token, the smaller of the two counts.
    overlap = sum((sys_counter & ref_counter).values())
    if overlap == 0:
        # No shared tokens: precision and recall are both 0 (or undefined
        # for empty input), so there is nothing to average.
        return 0.0

    precision = overlap / sum(sys_counter.values())
    recall = overlap / sum(ref_counter.values())
    return 2 * ((precision * recall) / (precision + recall))

In [44]:
# Quick check: two of the three system tokens also occur in the reference.
rouge1_similarity([19, 5, 70], [9, 5, 70, 85])

0.5714285714285715

In [45]:
def average_overlap(similarity_fn, samples, *ignore_params):
    """Score each sample by its mean similarity to all the other samples.

    Args:
        similarity_fn: callable(candidate, sample) returning a number.
        samples: sequence of candidate token lists.
        *ignore_params: accepted and ignored, so this function can be called
            with the same arguments as `weighted_avg_overlap`.

    Returns:
        Dict mapping each sample's index to its mean similarity against
        every other sample. With a single sample there is nothing to compare
        against, so its score is 0.0; with no samples the dict is empty.
    """
    # Number of comparisons per candidate. Computed up front rather than read
    # from the inner loop variable after the loop has finished.
    n_others = len(samples) - 1
    scores = {}
    for index_candidate, candidate in enumerate(samples):
        overlap = 0.0
        for index_sample, sample in enumerate(samples):
            if index_candidate == index_sample:
                continue
            overlap += similarity_fn(candidate, sample)
        scores[index_candidate] = overlap / n_others if n_others > 0 else 0.0
    return scores

In [46]:
# Quick check; the trailing list is swallowed by *ignore_params and not used.
average_overlap(jaccard_similarity, [[1, 2, 3], [1, 2, 4], [1, 2, 4, 5]], [0.4, 0.2, 0.5])

{0: 0.45, 1: 0.625, 2: 0.575}

In [48]:
def weighted_avg_overlap(similarity_fn, samples, log_probs):
    """Score each sample by its probability-weighted mean similarity to the others.

    Args:
        similarity_fn: callable(candidate, sample) returning a number.
        samples: sequence of candidate token lists.
        log_probs: one value per sample; each is exponentiated to obtain
            that sample's weight.

    Returns:
        Dict mapping each sample's index to
        sum(w_j * sim(candidate, sample_j)) / sum(w_j) over all j other than
        the candidate. When the total weight is zero (e.g. a single sample),
        the score is 0.0 instead of dividing by zero.
    """
    scores = {}
    for index_candidate, candidate in enumerate(samples):
        overlap, weight_sum = 0.0, 0.0
        for index_sample, (sample, logp) in enumerate(zip(samples, log_probs)):
            if index_candidate == index_sample:
                continue

            sample_p = float(np.exp(logp))
            weight_sum += sample_p
            overlap += sample_p * similarity_fn(candidate, sample)
        scores[index_candidate] = overlap / weight_sum if weight_sum > 0 else 0.0
    return scores


In [49]:
# Quick check; the last list is passed as log_probs and exponentiated into weights.
weighted_avg_overlap(jaccard_similarity, [[1, 2, 3], [1, 2, 4], [1, 2, 4, 5]], [0.4, 0.2, 0.5])

{0: 0.44255574831883415, 1: 0.631244796869735, 2: 0.5575581009406329}

In [50]:
def mbr_decode(sentence, n_samples, score_fn, similarity_fn, NMTAttn=None, temperature=0.6, vocab_file=None, vocab_dir=None):
    """Translate by sampling several candidates and returning the best-scoring one.

    Args:
        sentence: text to translate.
        n_samples: number of candidate translations to sample.
        score_fn: callable(similarity_fn, samples, log_probs) returning a
            dict of index -> score, e.g. `average_overlap` or
            `weighted_avg_overlap`.
        similarity_fn: pairwise similarity handed to `score_fn`, e.g.
            `jaccard_similarity` or `rouge1_similarity`.
        NMTAttn: the model, forwarded to `generate_samples`.
        temperature: sampling temperature used for the candidates.
        vocab_file: vocabulary file name.
        vocab_dir: vocabulary directory.

    Returns:
        (translated_sentence, max_index, scores): the detokenized best
        candidate, its index among the samples, and the full score dict.
    """
    samples, log_probs = generate_samples(sentence, n_samples, NMTAttn, temperature, vocab_file, vocab_dir)
    # Use the scoring and similarity functions the caller supplied; they were
    # previously ignored in favour of a fixed weighted-Jaccard combination.
    scores = score_fn(similarity_fn, samples, log_probs)
    max_index = max(scores, key=scores.get)
    translated_sentence = detokenize(samples[max_index], vocab_file, vocab_dir)
    return (translated_sentence, max_index, scores)