# Build an efficient unsupervised word translator

Based on : "Word Translation Without Parallel Data" by Alexis Conneau, Guillaume Lample, Marc Aurelio Ranzato, Ludovic Denoyer & Hervé Jégou (2017)

In this notebook we explore Generative Adversarial Networks (GANs) for word translation. GANs belong to the family of generative models. A generative model learns the underlying distribution p(x) of its input data (here, word embeddings), which allows it to generate synthetic inputs (sources) x′ as well as outputs (targets) y′.

## Data pre-processing

In [1]:
import io

import numpy as np
import matplotlib.pyplot as plt
import math

import torch
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
#from torch.autograd.variable import Variable

from scipy.stats import special_ortho_group

#from sklearn.metrics.pairwise import cosine_similarity

In [2]:
# load function for pretrained versions of word embeddings
def load_embeddings(emb_path, nmax=50000):
    vectors = []
    word2id = {}
    with io.open(emb_path, 'r', encoding='utf-8', newline='\n', errors='ignore') as f:
        next(f)
        for i, line in enumerate(f):
            word, vect = line.rstrip().split(' ', 1)
            vect = np.fromstring(vect, sep=' ')
            assert word not in word2id, 'word found twice'
            vectors.append(vect)
            word2id[word] = len(word2id)
            if len(word2id) == nmax:
                break
    id2word = {v: k for k, v in word2id.items()}
    embeddings = np.vstack(vectors)
    return embeddings, id2word, word2id

# load ground-truth bilingual dictionaries function
def load_dic(path):
    dico_full = {}
    vectors_src=[]
    vectors_tgt = []
    with io.open(path,'r',encoding='utf_8') as f:
        for i,line in enumerate(f):
            word_src, word_tgt = line.rstrip().split(' ',1)
            if word_tgt in tgt_word2id :
                dico_full[word_src]=word_tgt
    for key in dico_full.keys():
        vectors_src.append(src_embeddings[src_word2id[key]])
        vectors_tgt.append(tgt_embeddings[tgt_word2id[dico_full[key]]])
    X = np.vstack(vectors_src)
    Z = np.vstack (vectors_tgt)
    return dico_full,X,Z
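
The loader above expects the fastText `.vec` text layout: a header line with the vocabulary size and the dimension, then one word per line followed by its vector components. The sketch below parses a made-up three-word sample held in memory (the words and values are invented for illustration) using the same splitting logic, so the format can be checked without downloading the real files.

```python
import io
import numpy as np

# Invented sample in the .vec layout: "<n_words> <dim>" then "<word> <v1> ... <vd>"
sample = "3 4\nthe 0.1 0.2 0.3 0.4\ncat 0.5 0.6 0.7 0.8\nsat 0.9 1.0 1.1 1.2\n"

vectors, word2id = [], {}
with io.StringIO(sample) as f:
    n_words, dim = map(int, next(f).split())  # header line
    for line in f:
        word, vect = line.rstrip().split(' ', 1)
        vectors.append(np.array([float(v) for v in vect.split()]))
        word2id[word] = len(word2id)          # ids follow file order
embeddings = np.vstack(vectors)

print(embeddings.shape, word2id)
```

Because the files are sorted by frequency, stopping after `nmax` lines keeps the most frequent words.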

In [4]:
nmax = 50000

eng_path = '/Users/minh/Downloads/CSE293_NLP-master/wiki.en.vec'
fr_path = '/Users/minh/Downloads/CSE293_NLP-master/wiki.fr.vec'

# load monolingual word embeddings 
src_embeddings, src_id2word, src_word2id = load_embeddings(fr_path, nmax) # source = french 
tgt_embeddings, tgt_id2word, tgt_word2id = load_embeddings(eng_path, nmax) # target = english

In [5]:
# train & test bilingual dictionaries

path_train = '/Users/minh/Downloads/CSE293_NLP-master/fr-en.0-5000.txt' 
path_test = '/Users/minh/Downloads/CSE293_NLP-master/fr-en.5000-6500.txt'

dico_train, X_train, Z_train = load_dic(path_train)
dico_test, X_test, Z_test = load_dic(path_test)

# convert embeddings vectors into torch tensors 
print(type(X_train[0]))
X_train, Z_train, X_test, Z_test = map(torch.tensor, (X_train, Z_train, X_test, Z_test)) 
print(type(X_train[0]))

<class 'numpy.ndarray'>
<class 'torch.Tensor'>


In [6]:
print(X_train.shape[0], "training samples")
print(X_test.shape[0], "test samples")
dim = X_train.shape[1]
print("Vectors dimension :", dim)

4971 training samples
1483 test samples
Vectors dimension : 300


# Coding GANs

## Build the discriminator 

The discriminator's goal is to recognize whether an input vector is "real", i.e. a target embedding y that belongs to the original dataset, or "fake", i.e. a mapped source embedding Wx produced by the generator (the trained translation matrix W applied to a source embedding x).

In [7]:
class Discriminator(nn.Module):
    def __init__(self, dim):
        super(Discriminator,self).__init__()
        self.h1 = nn.Linear(dim, 2048,bias=True) # 1st hidden layer
        self.h2 = nn.Linear(2048,2048,bias=True) # 2nd hidden layer
        self.out = nn.Linear(2048,1,bias=True) # output layer
        
    def forward(self, x):
        x = F.dropout(x, p=0.1, training=self.training) # dropout adds noise to the input; disabled by discrim.eval()
        x = F.leaky_relu(self.h1(x), negative_slope=0.2)
        x = F.leaky_relu(self.h2(x), negative_slope=0.2)
        y = torch.sigmoid(self.out(x)) # output = probability
        return y

## Build the generator 

The generator aims to produce new data similar to the expected data. Here it maps a source embedding x to Wx, which should look like a target embedding.

In [8]:
# simple linear function
# can be seen as a neural network whose weights are the elements of W
class Generator(nn.Module):
    def __init__(self, dim):
        super(Generator, self).__init__()
        self.l1 = nn.Linear(dim, dim, bias=False) # no bias: the mapping is the matrix W alone

    def forward(self,x):
        y = self.l1(x)
        return y

In [9]:
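# we could put this inside the class...

# to ensure that the matrix stays close to the manifold of orthogonal matrices after each update:
# W <- (1 + beta) * W - beta * (W W^T) W
def ortho_update(W, beta):
    # update W in place; rebinding the local name W would leave the caller's tensor unchanged
    with torch.no_grad():
        W.copy_((1+beta)*W - beta*torch.mm(torch.mm(W, W.t()), W))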
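
To see what this update does, the following sketch applies it repeatedly to a slightly perturbed orthogonal matrix and measures how far W^T W is from the identity before and after. The helper names `ortho_step` and `ortho_error`, the 5×5 size, the perturbation scale, and the 200 steps are assumptions chosen for illustration; only the update formula comes from the cell above.

```python
import torch

def ortho_step(W, beta=0.01):
    """Return (1 + beta) * W - beta * (W W^T) W without modifying W."""
    return (1 + beta) * W - beta * W.mm(W.t()).mm(W)

def ortho_error(W):
    """Frobenius norm of W^T W - I (zero for an orthogonal matrix)."""
    eye = torch.eye(W.shape[0], dtype=W.dtype)
    return (W.t().mm(W) - eye).norm().item()

torch.manual_seed(0)
# Start from an exactly orthogonal matrix and perturb it slightly.
Q, _ = torch.linalg.qr(torch.randn(5, 5, dtype=torch.float64))
W = Q + 0.05 * torch.randn(5, 5, dtype=torch.float64)

err_before = ortho_error(W)
for _ in range(200):
    W = ortho_step(W)
err_after = ortho_error(W)

print(err_before, err_after)
```

Each step moves every singular value σ of W to σ + βσ(1 − σ²), so values near 1 are pulled toward 1. With a small β the correction is gentle, which is why it is applied after every generator update rather than once.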

## Optimization

Here we’ll use SGD as the optimization algorithm for both neural networks, with a learning rate of 0.1. The proposed learning rate was obtained after testing with several values, though it isn’t necessarily the optimal value for this task. 

The loss function we'll be using for this task is Binary Cross-Entropy Loss (BCE Loss):

<img src='https://miro.medium.com/max/5728/1*IcuF1_TXjngF2VHQjdwzjg.png'>
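
As a quick check of the formula, this sketch computes BCE by hand and compares it with `nn.BCELoss`. The prediction and target values are made up for illustration; the targets use the smoothed 0.8 / 0.2 labels that appear later in the training loop.

```python
import torch
from torch import nn

# Hypothetical discriminator outputs (probabilities) and smoothed targets
pred = torch.tensor([0.9, 0.7, 0.2, 0.4])
target = torch.tensor([0.8, 0.8, 0.2, 0.2])

# BCE written out: -[t * log(p) + (1 - t) * log(1 - p)], averaged over the batch
manual = -(target * pred.log() + (1 - target) * (1 - pred).log()).mean()
builtin = nn.BCELoss()(pred, target)

print(manual.item(), builtin.item())
```

Note that `nn.BCELoss` expects probabilities, which is why the discriminator ends with a sigmoid.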

In [10]:
discrim = Discriminator(dim) # define the discriminator and the generator
gen = Generator(dim)

optimD = optim.SGD(discrim.parameters(), lr=0.1) # define the optimizers
optimG = optim.SGD(gen.parameters(), lr=0.1)

In [16]:
LossD = nn.BCELoss() # define the loss functions
LossG = nn.BCELoss()

# Training GANs

Now that we’ve defined the dataset, networks, optimization and learning algorithms we can train our GAN. 

<img src='https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2019/05/Summary-of-the-Generative-Adversarial-Network-Training-Algorithm.png'>
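
The training loop further down labels each batch with smoothed targets and trains the mapping against the flipped labels. The sketch below builds those label vectors on their own; the names `batch_size`, `smooth`, `disc_targets` and `gen_targets` are illustrative and do not appear in the notebook, while the values 32 and 0.2 are taken from the loop.

```python
import torch

batch_size = 32  # words sampled per language at each step
smooth = 0.2     # label smoothing amount

# First half of the batch: mapped source vectors, label 1 smoothed to 0.8.
# Second half: target vectors, label 0 smoothed to 0.2.
disc_targets = torch.zeros(2 * batch_size)
disc_targets[:batch_size] = 1 - smooth
disc_targets[batch_size:] = smooth

# The mapping is trained on the flipped labels, so it is rewarded
# when the discriminator confuses the two halves.
gen_targets = 1 - disc_targets

print(disc_targets[:2], gen_targets[:2])
```

Smoothing keeps the discriminator from becoming overconfident, which would otherwise leave the mapping with very small gradients.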

In [69]:
src_embedding_learnable = nn.Embedding(nmax,dim)
src_embedding_learnable.weight = nn.Parameter(torch.from_numpy(src_embeddings).float())
src_embeddings_norm = src_embedding_learnable.weight/src_embedding_learnable.weight.norm(2, 1, keepdim=True).expand_as(src_embedding_learnable.weight)
src_embedding_learnable.weight.data.copy_(src_embeddings_norm)
target_embedding_learnable = nn.Embedding(nmax,dim)
target_embedding_learnable.weight = nn.Parameter(torch.from_numpy(tgt_embeddings).float())
target_embeddings_norm = target_embedding_learnable.weight/target_embedding_learnable.weight.norm(2, 1, keepdim=True).expand_as(target_embedding_learnable.weight)
target_embedding_learnable.weight.data.copy_(target_embeddings_norm)

tensor([[-0.0113, -0.0021, -0.0515,  ...,  0.0436, -0.0077,  0.0724],
        [-0.0469, -0.0006, -0.0751,  ...,  0.0268, -0.0514,  0.0166],
        [-0.0324, -0.0462, -0.0087,  ...,  0.0827, -0.0650,  0.0176],
        ...,
        [ 0.0869, -0.0376, -0.1196,  ...,  0.0690,  0.0346, -0.0111],
        [ 0.0167, -0.0221, -0.0283,  ...,  0.0110,  0.1053, -0.0534],
        [-0.0230,  0.0129, -0.0647,  ...,  0.0033,  0.0202,  0.0013]],
       grad_fn=<CopyBackwards>)

In [83]:
def top_words(emb1, emb2):
    #top translation pairs
    size = 64

    ranked_scores = []
    ranked_targets = []
    num = 2000

    # average similarity of each word to its 10 nearest neighbors in the other language
    average_dist_src = avg_10_distance(emb2, emb1)     # one value per row of emb1
    average_dist_target = avg_10_distance(emb1, emb2)  # one value per row of emb2
    average_dist_src = average_dist_src.type_as(emb1)
    average_dist_target = average_dist_target.type_as(emb2)

    for i in range(0, num, size):
        scores = emb2.mm(emb1[i:min(num, i + size)].t()).t()
        scores.mul_(2)
        scores.sub_(average_dist_src[i:min(num, i + size)][:, None] + average_dist_target[None, :])
        best_scores, best_targets = scores.topk(2, dim=1, largest=True, sorted=True)

        # update scores / potential targets
        ranked_scores.append(best_scores.cpu())
        ranked_targets.append(best_targets.cpu())

    ranked_scores = torch.cat(ranked_scores, 0)
    ranked_targets = torch.cat(ranked_targets, 0)

    ranked_pairs = torch.cat([torch.arange(0, ranked_targets.size(0)).long().unsqueeze(1),ranked_targets[:, 0].unsqueeze(1)], 1)

    #Reordering them
    diff = ranked_scores[:, 0] - ranked_scores[:, 1]
    reordered = diff.sort(0, descending=True)[1]
    ranked_scores = ranked_scores[reordered]
    ranked_pairs = ranked_pairs[reordered]

    selected = ranked_pairs.max(1)[0] <= num
    mask = selected.unsqueeze(1).expand_as(ranked_scores).clone()
    ranked_scores = ranked_scores.masked_select(mask).view(-1, 2)
    ranked_pairs = ranked_pairs.masked_select(mask).view(-1, 2)

    return ranked_pairs

def avg_10_distance(emb2,emb1):
    size = 128
    all_distances = []
    emb2 = emb2.t().contiguous()
    for i in range(0, emb1.shape[0], size):
        distances = emb1[i:i + size].mm(emb2)
        best_distances, _ = distances.topk(10, dim=1, largest=True, sorted=True)
        all_distances.append(best_distances.mean(1))
    all_distances = torch.cat(all_distances)
    return all_distances

def proxy_construct_dictionary(src_emb_map_validation,target_emb_map_validation,src_to_target_dictionary,target_to_src_dictionary):    
    # swap the columns so that both dictionaries hold (source id, target id) pairs
    target_to_src_dictionary = torch.cat([target_to_src_dictionary[:, 1:], target_to_src_dictionary[:, :1]], 1)
    # convert to plain Python ints: tensors hash by identity, so sets of tensor pairs would never intersect
    src_to_target_dictionary = set((a, b) for a, b in src_to_target_dictionary.tolist())
    target_to_src_dictionary = set((a, b) for a, b in target_to_src_dictionary.tolist())
    final_pairs = src_to_target_dictionary.intersection(target_to_src_dictionary)
    if len(final_pairs) == 0:
        return None
    dictionary = torch.tensor([[a, b] for (a, b) in final_pairs], dtype=torch.long)
    return dictionary
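
The scoring inside `top_words` follows the CSLS criterion, CSLS(x, y) = 2·cos(x, y) − r_T(x) − r_S(y), where r_T(x) is the mean similarity of x to its k nearest target neighbors and r_S(y) is the same quantity for y among the sources. Below is a minimal sketch on toy data, assuming unit-norm rows; the function name `csls_scores`, the toy sizes, and k = 2 are illustrative choices (the notebook uses 10 neighbors).

```python
import torch
import torch.nn.functional as F

def csls_scores(src, tgt, k=2):
    """CSLS score matrix for unit-norm rows, shape (n_src, n_tgt)."""
    sims = src.mm(tgt.t())                  # cosine similarities
    r_src = sims.topk(k, dim=1)[0].mean(1)  # each source word vs. its k nearest targets
    r_tgt = sims.topk(k, dim=0)[0].mean(0)  # each target word vs. its k nearest sources
    return 2 * sims - r_src[:, None] - r_tgt[None, :]

torch.manual_seed(0)
tgt = F.normalize(torch.randn(6, 4), dim=1)
# Toy "mapped source" vectors: a shuffled, slightly noisy copy of the targets
perm = torch.tensor([3, 0, 5, 1, 4, 2])
src = F.normalize(tgt[perm] + 0.01 * torch.randn(6, 4), dim=1)

best = csls_scores(src, tgt).argmax(dim=1)  # best target index for each source row
print(best.tolist())
```

Subtracting the two neighborhood terms lowers the score of "hub" vectors that are close to many words, which plain nearest-neighbor search tends to return too often.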

In [147]:
for epoch in range(3): #3 Epochs 
    for iteration in range(10000):
        if iteration % 10 == 0 :
            print("epoch = %d, iteration = %d"%(epoch,iteration))
        # discrim trained 3 times for every mapping training
        for i in range(3):
            discrim.train()
            #Set gradient to zero before computation at every step
            optimD.zero_grad()
            rand_src_word_id = torch.Tensor(32).random_(50000).long()
            src_word_emb = src_embedding_learnable(rand_src_word_id)
            rand_tgt_word_id = torch.Tensor(32).random_(50000).long()
            tgt_word_emb = target_embedding_learnable(rand_tgt_word_id)
            wsrc_gen = gen(src_word_emb)
            input_tensor = torch.cat([wsrc_gen,tgt_word_emb],0)
            output_tensor = torch.Tensor(64).zero_().float()
            output_tensor[:32] = 1 -0.2 #Smoothing
            output_tensor[32:] = 0.2
            prediction = discrim(input_tensor).view(-1) # flatten (64, 1) to (64,) to match the targets
            #Compute loss and propagate backward
            loss = LossD(prediction,output_tensor)
            loss.backward()
            optimD.step()

        # mapping training 
        discrim.eval()
        #Set gradient to zero before computation at every step
        optimG.zero_grad()
        rand_src_word_id = torch.Tensor(32).random_(50000).long()
        src_word_emb = src_embedding_learnable(rand_src_word_id)
        rand_tgt_word_id = torch.Tensor(32).random_(50000).long()
        tgt_word_emb = target_embedding_learnable(rand_tgt_word_id)
        wsrc_gen = gen(src_word_emb)
        input_tensor = torch.cat([wsrc_gen,tgt_word_emb],0)
        output_tensor = torch.Tensor(64).zero_().float()
        output_tensor[:32] = 1 -0.2 #Smoothing
        output_tensor[32:] = 0.2
        prediction = discrim(input_tensor).view(-1) # flatten (64, 1) to (64,) to match the targets
        loss = LossG(prediction,1-output_tensor)
        loss.backward()
        optimG.step()
        mapping_tensor = gen.l1.weight.data
        mapping_tensor.copy_((1.01) * mapping_tensor - 0.01 * mapping_tensor.mm(mapping_tensor.t().mm(mapping_tensor)))

        
    #Validation through proxy parallel dictionary construction (both directions) and CSLS
    src_emb_map_validation = gen(src_embedding_learnable.weight)
    target_emb_map_validation = target_embedding_learnable.weight
    src_emb_map_validation = src_emb_map_validation/src_emb_map_validation.norm(2, 1, keepdim=True).expand_as(src_emb_map_validation)
    target_emb_map_validation = target_emb_map_validation/target_emb_map_validation.norm(2, 1, keepdim=True).expand_as(target_emb_map_validation)
    src_to_target_dictionary = top_words(src_emb_map_validation,target_emb_map_validation)
    target_to_src_dictionary = top_words(target_emb_map_validation,src_emb_map_validation)
    dictionary = proxy_construct_dictionary(src_emb_map_validation,target_emb_map_validation,src_to_target_dictionary,target_to_src_dictionary)
    if dictionary is None:
        mean_cosine = -1e9
    else:
        mean_cosine = (src_emb_map_validation[dictionary[:, 0]] * target_emb_map_validation[dictionary[:, 1]]).sum(1).mean()

    # Decay the learning rates by a factor of 0.95 after each epoch
    optimD.param_groups[0]['lr'] = 0.95*optimD.param_groups[0]['lr']
    optimG.param_groups[0]['lr'] = 0.95*optimG.param_groups[0]['lr']

epoch = 0, iteration = 0
epoch = 0, iteration = 10
epoch = 0, iteration = 20
...
epoch = 2, iteration = 9740
epoch = 2, iteration = 9750
epoch = 2, iteration = 9760
...

In [149]:
for i in range(100):
    index = src_to_target_dictionary.data.tolist()[i]
    print(src_id2word[index[0]], '--', tgt_id2word[index[1]])

red -- it
label -- as
sportif -- mass
femme -- male
européen -- mass
article -- blood
situation -- total
zones -- body
français -- individuals
terrain -- produce
till -- death
blue -- it
grandes -- unit
autre -- use
protection -- body
black -- it
ferroviaire -- division
francis -- him
gentilé -- mass
deux -- count
fiche -- trial
langues -- body
fond -- production
jeunesse -- body
communication -- body
avenue -- rail
agit -- males
culte -- entire
non -- body
secteur -- mm
espagne -- produce
tombe -- line
} -- claim
populaire -- mark
santé -- log
solution -- body
police -- it
manière -- but
culturel -- every
éd -- body
entrée -- body
comprend -- goals
parties -- it
sans -- males
nombreux -- have
continue -- years
effets -- count
québec -- gas
moyens -- it
gouvernement -- produce
large -- produce
séries -- html
plein -- body
photo -- males
réponse -- standard
mondial -- body
évolution -- females
code -- it
limite -- it
peu -- now
x -- love
peter -- production
créé -- worth
date -- rate
va

In [None]:
beta = 0.01 # assumed: same value as the 0.01 used in the training loop above
W_trained = gen.l1.weight.data # get the weights of the generator, which are the elements of W

# to ensure that the matrix stays close to the manifold of orthogonal matrices after each update
W_ortho = (1+beta)*W_trained - beta*torch.mm(torch.mm(W_trained, W_trained.t()), W_trained)
gen.l1.weight.data = W_ortho
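
The test dictionary loaded at the start is not used in the cells above. One common way to score a mapping is precision@1: the fraction of source words whose nearest target vector is the reference translation. The sketch below shows the idea on toy data under stated assumptions: `precision_at_1` is a hypothetical helper, retrieval uses plain cosine nearest neighbors rather than CSLS, and the "source" space is simply the target space rotated by a random orthogonal matrix.

```python
import torch
import torch.nn.functional as F

def precision_at_1(mapped_src, gold_tgt, tgt_vocab):
    """Fraction of rows whose nearest target (by cosine) is the gold one.

    mapped_src: (n, d) mapped source vectors
    gold_tgt:   (n,)   index in tgt_vocab of each reference translation
    tgt_vocab:  (m, d) all candidate target vectors
    """
    sims = F.normalize(mapped_src, dim=1).mm(F.normalize(tgt_vocab, dim=1).t())
    predicted = sims.argmax(dim=1)
    return (predicted == gold_tgt).float().mean().item()

torch.manual_seed(0)
d = 8
tgt_vocab = torch.randn(20, d)
gold = torch.arange(10)
# Toy setup: source vectors are the gold targets rotated by an orthogonal Q
Q, _ = torch.linalg.qr(torch.randn(d, d))
src = tgt_vocab[gold].mm(Q)

score_good = precision_at_1(src.mm(Q.t()), gold, tgt_vocab)  # mapping undoes the rotation
score_bad = precision_at_1(src, gold, tgt_vocab)             # no mapping applied

print(score_good, score_bad)
```

With the notebook's data, the gold indices could be looked up as `tgt_word2id[dico_test[word]]` for each source word in `dico_test`, and the mapped vectors obtained by passing the corresponding source embeddings through `gen`.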