<a href="https://colab.research.google.com/github/muyeby/MlInAction/blob/master/%E2%80%9CTest_pytorch1_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
from google.colab import drive
drive.mount('/content/drive/')

%cd /content/drive/My Drive/
%mkdir mywork
%cd /content/drive/My Drive/mywork/
%mkdir data

In [0]:
!curl -Lo data/wiki.en.vec https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.en.vec
!curl -Lo data/wiki.es.vec https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.es.vec

In [0]:
%mkdir data/dictionaries
!curl -Lo data/dictionaries/en-es.txt https://dl.fbaipublicfiles.com/arrival/dictionaries/en-es.txt
!curl -Lo data/dictionaries/en-es.0-5000.txt https://dl.fbaipublicfiles.com/arrival/dictionaries/en-es.0-5000.txt
!curl -Lo data/dictionaries/en-es.5000-6500.txt https://dl.fbaipublicfiles.com/arrival/dictionaries/en-es.5000-6500.txt

In [0]:
import io
import numpy as np
def load_txt_embeddings(emb_path, full_vocab=True):
    """
    Reload pretrained embeddings from a text file.
    """
    word2id = {}
    vectors = []
    _emb_dim_file = 300
    max_vocab=200000
    
    # load pretrained embeddings
    with io.open(emb_path, 'r', encoding='utf-8', newline='\n', errors='ignore') as f:
        for i, line in enumerate(f):
            if i == 0:
                split = line.split()
                assert len(split) == 2
            else:
                word, vect = line.rstrip().split(' ', 1)
                word = word.lower()
                vect = np.fromstring(vect, sep=' ')
                if np.linalg.norm(vect) == 0:  # avoid to have null embeddings
                    vect[0] = 0.01
                if word in word2id:
                    print("Word {} found twice in embedding file".format(word.encode('utf-8')))
                else:
                    if not vect.shape == (_emb_dim_file,):
                        print("Invalid dimension (%i) for word '%s' in line %i."
                                       % (vect.shape[0], word, i))
                        continue
                    assert vect.shape == (_emb_dim_file,), i
                    word2id[word] = len(word2id)
                    vectors.append(vect[None])
            if max_vocab > 0 and len(word2id) >= max_vocab and not full_vocab:
                break

    assert len(word2id) == len(vectors)
    print("Loaded %i pre-trained word embeddings." % len(vectors))

    # compute new vocabulary / embeddings
    id2word = {v: k for k, v in word2id.items()}
    dico = (id2word, word2id)
    embeddings = np.concatenate(vectors, 0)

    #assert embeddings.size() == (len(dico), params.emb_dim)
    return dico, embeddings


In [0]:
def initialize_exp(seed):
    if seed >= 0:
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

def center_embeddings(emb):
    print('Centering the embeddings')
    mean = emb.mean(0)
    emb = emb-mean
    return emb
def norm_embeddings(emb):
    print('Normalizing the embeddings')
    norms = np.linalg.norm(emb,axis=1,keepdims=True)
    norms[norms == 0] = 1
    emb = emb / norms
    return emb

In [26]:
data_dir='/content/drive/My Drive/mywork/data'
src_file='{}/embeddings/wiki.en.vec'.format(data_dir)
tgt_file='{}/embeddings/wiki.es.vec'.format(data_dir)
src_dico, src_emb = load_txt_embeddings(src_file,full_vocab=False)
tgt_dico, tgt_emb = load_txt_embeddings(tgt_file, full_vocab=False)

print("Centering the word embeddings...")
src_emb = center_embeddings(src_emb)
tgt_emb = center_embeddings(tgt_emb)

src_emb = torch.from_numpy(src_emb).float()
tgt_emb = torch.from_numpy(tgt_emb).float()

if torch.cuda.is_available():
    cuda = torch.device('cuda')
    src_emb = src_emb.to(cuda)
    tgt_emb = tgt_emb.to(cuda)


Loaded 200000 pre-trained word embeddings.
Loaded 200000 pre-trained word embeddings.
Centering the word embeddings...
Centering the embeddings
Centering the embeddings


In [0]:
import torch.nn as nn
import torch.nn.functional as F

class AE(nn.Module):
    def __init__(self):
        super(AE, self).__init__()
        self.drop1 = nn.Dropout(0.1)
        self.map1 = nn.Linear(300, 300, bias=False)
        self.map2 = nn.Linear(300, 300, bias=False)
        # nn.init.eye(self.map1.weight)

    def encode(self, x):
        # x = self.drop1(x)
        encoded = self.map1(x)
        return encoded

    def decode(self, z):
        # decoded = F.linear(z, self.map1.weight.t(), bias=None)
        decoded = self.map2(z)
        return decoded

    def forward(self, x):
        encoded = self.encode(x)
        decoded = self.decode(encoded)
        return encoded, decoded

In [0]:
import torch.optim as optim
from timeit import default_timer as timer
from torch.autograd import Variable
import sys

class BiAAE(object):
    def __init__(self):
        self.X_AE = AE()
        self.Y_AE = AE()
        self.nets = [self.X_AE,self.Y_AE]
        self.loss_fn2 = torch.nn.CosineSimilarity(dim=1,eps=1e-7)
        
    def weights_init(self, m):          
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal(m.weight)
            if m.bias is not None:
                torch.nn.init.constant(m.bias, 0.01)

    def weights_init2(self, m):  # xavier_normal     
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.xavier_normal(m.weight)
            if m.bias is not None:
                torch.nn.init.constant(m.bias, 0.01)

    def weights_init3(self, m):  #
        if isinstance(m, torch.nn.Linear):
            m.weight.data.copy_(torch.diag(torch.ones(self.params.g_input_size)))
    
    def init_state(self,seed=-1):
        if torch.cuda.is_available():
            cuda = torch.device('cuda')
            for net in self.nets:
                net.to(cuda)
            self.loss_fn2 = self.loss_fn2.to(cuda)
            
        self.X_AE.apply(self.weights_init)  #      
        self.Y_AE.apply(self.weights_init)  #
        
    def CORAL(self,source,target):
        d = source.data.shape[1]
        # source covariance
        xm = torch.mean(source, 1, keepdim=True) - source
        xc = torch.matmul(torch.transpose(xm, 0, 1), xm)
        # target covariance
        xmt = torch.mean(target, 1, keepdim=True) - target
        xct = torch.matmul(torch.transpose(xmt, 0, 1), xmt)
        # frobenius norm between source and target
        loss = torch.mean(torch.mul((xc - xct), (xc - xct)))
        loss = loss/(4*d*d)
        return loss
    
    def get_batch_data_fast_new(self, emb_en, emb_it):
        cuda = torch.device('cuda')
        random_en_indices = torch.LongTensor(32).random_(75000)
        random_it_indices = torch.LongTensor(32).random_(75000)
        en_batch = emb_en[random_en_indices.to(cuda)]
        it_batch = emb_it[random_it_indices.to(cuda)]

        return en_batch, it_batch
    
    def train(self,src_dico,tgt_dico,src_emb,tgt_emb,seed):
        
        src_word2id = src_dico[1]
        tgt_word2id = tgt_dico[1]

        en = src_emb
        it = tgt_emb
        AE_optimizer = optim.SGD(filter(lambda p: p.requires_grad, list(self.X_AE.parameters()) + list(self.Y_AE.parameters())), lr=0.1)
        
        G_AB_recon_epochs = []
        G_BA_recon_epochs = []
        g_loss_epochs = []
        coral_loss_epoches = []
        
        try:
            for epoch in range(50):
                G_AB_recon = []
                G_BA_recon = []
                coral_losses = []
                g_losses = []
                start_time = timer()
                
                for mini_batch in range(0, 75000 // 32):
                    AE_optimizer.zero_grad()
                    view_X, view_Y = self.get_batch_data_fast_new(en, it)
                    X_Z = self.X_AE.encode(view_X)
                    X_recon = self.X_AE.decode(X_Z)
                    Y_fake = self.Y_AE.decode(X_Z)
                    L_recon_X = 1.0 - torch.mean(self.loss_fn2(view_X, X_recon))
                    
                    Y_Z = self.Y_AE.encode(view_Y)
                    Y_recon = self.Y_AE.decode(Y_Z)
                    X_fake = self.X_AE.decode(Y_Z)
                    L_recon_Y = 1.0 - torch.mean(self.loss_fn2(view_Y, Y_recon))
                    
                    L_coral = self.CORAL(X_Z,Y_Z)
                    G_loss = 1.0 * (L_recon_X+L_recon_Y) + 1.0 * L_coral
                    
                    G_loss.backward()
                    g_losses.append(G_loss.item())
                    G_AB_recon.append(L_recon_X.item())
                    G_BA_recon.append(L_recon_Y.item())
                    coral_losses.append(L_coral.item())
                    
                    AE_optimizer.step()  # Only optimizes G's parameters
                    
                    sys.stdout.write("[%d/%d] ::                                     Generator Loss: %.3f \r" % (
                                mini_batch, 75000 // 32, np.asscalar(np.mean(g_losses))))
                    sys.stdout.flush()
                    
                G_AB_recon_epochs.append(np.asscalar(np.mean(G_AB_recon)))
                G_BA_recon_epochs.append(np.asscalar(np.mean(G_BA_recon)))
                coral_loss_epoches.append(np.asscalar(np.mean(coral_losses)))
                g_loss_epochs.append(np.asscalar(np.mean(g_losses)))
                
                print( "Epoch {} : Generator Loss: {:.3f}, Coral Loss: {:.3f}, Time elapsed {:.2f} mins".
                        format(epoch, np.asscalar(np.mean(g_losses)),np.asscalar(np.mean(coral_losses)), (timer() - start_time) / 60))
            
            return [G_AB_recon_epochs,G_BA_recon_epochs,coral_loss_epoches,g_loss_epochs]
            
        except KeyboardInterrupt:
            print("Interrupted.. saving model !!!")
            torch.save(self.X_AE.state_dict(), 'X_model_interrupt.t7')
            torch.save(self.Y_AE.state_dict(), 'Y_model_interrupt.t7')
            log_file.close()
            exit()

In [0]:

init_seed = 430
t = BiAAE()
initialize_exp(init_seed)
t.init_state(seed=init_seed)
t.train(src_dico,tgt_dico,src_emb,tgt_emb,init_seed)



%matplotlib inline
import matplotlib.pyplot as plt

fig = plt.figure()
plt.plot(range(0, len(G_AB_recon_epochs)), G_AB_recon_epochs, color='b', label='G_AB')
plt.plot(range(0, len(G_BA_recon_epochs)), G_BA_recon_epochs, color='r', label='G_BA')
plt.ylabel('G_recon_loss')
plt.xlabel('epochs')
plt.legend()
fig.savefig('seed_{}_G_Recon.png'.format(seed))
            
fig = plt.figure()
plt.plot(range(0, len(g_loss_epochs)), g_loss_epochs, color='b', label='G_loss')
plt.ylabel('g_loss')
plt.xlabel('epochs')
plt.legend()
fig.savefig('seed_{}_g_loss.png'.format(seed))

fig = plt.figure()
plt.plot(range(0, len(coral_loss_epochs)), coral_loss_epochs, color='b', label='Coral loss')
plt.ylabel('Coral_loss')
plt.xlabel('epochs')
plt.legend()
fig.savefig(self.tune_dir + '/seed_{}_coral_loss.png'.format(seed))
plt.close('all')



  from ipykernel import kernelapp as app


Epoch 0 : Generator Loss: 0.829, Coral Loss: 0.000, Time elapsed 0.29 mins
Epoch 1 : Generator Loss: 0.312, Coral Loss: 0.000, Time elapsed 0.30 mins
Epoch 2 : Generator Loss: 0.166, Coral Loss: 0.000, Time elapsed 0.29 mins
