# Embeddings

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.functional as F
import torch.optim as optim
import os
import numpy as np
import utils
import bcolz
import data
import fastai
import pickle
from torch.utils.data import DataLoader
from tqdm import tqdm

Using TensorFlow backend.


In [3]:
data_path = os.curdir + '/data/'
pdb_path = data_path + 'pdb/structures/pdb/'
aa2vec = bcolz.open(data_path + 'c5_embs.bc')

In [4]:
aa2ix = {'G': 0,'P': 1,'A': 2,'V': 3,'L': 4,
         'I': 5,'M': 6,'C': 7,'F': 8,'Y': 9,
         'W': 10,'H': 11,'K': 12,'R': 13,'Q': 14,
         'N': 15,'E': 16,'D': 17,'S': 18,'T': 19}
ix2aa = {}
for k,v in aa2ix.iteritems():
    ix2aa[v] = k

In [5]:
context_sz = 2
emb_sz = 50

In [6]:
dataset = data.ProteinDataset(data_path, '1', encoding=None)
dloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=data.sequence_collate)

In [7]:
for i, sampled_batch in enumerate(dloader):
    sequence = np.array(sampled_batch['sequence'])[0][0]
    if i == 2: 
        break

In [8]:
def get_context(sequence, context_sz):
    data = []
    for i in range(context_sz, len(sequence)-context_sz):
        context = [sequence[i+pos] for pos in range(-context_sz, context_sz+1) if pos != 0]
        target = sequence[i]
        
        data.append((context,target))
        
    return data

In [9]:
def make_context_vector(context, aa2ix):
    idxs = [aa2ix[aa] for aa in context] 
    return torch.tensor(idxs, dtype=torch.long)

In [36]:
aa2vec = bcolz.open(data_path + 'c3_embs.bc')

def create_emb_layer(aa2vec, requires_grad=True):
        aa2vec = torch.tensor(aa2vec, requires_grad=requires_grad)
        vocab_sz, embed_dim = aa2vec.size()
        emb_layer = nn.Embedding(vocab_sz, embed_dim)
        emb_layer.load_state_dict({'weight': aa2vec})
        if requires_grad == False:
            emb_layer.weight.requires_grad=False
            
        return emb_layer, vocab_sz, embed_dim

In [24]:
class CBOW(nn.Module):
    def __init__(self, aa2vec, linear_units):
        super(CBOW, self).__init__()
        
        self.embeds, vocab_sz, embed_dim = create_emb_layer(aa2vec, requires_grad=True)
        #self.embeds = nn.Embedding(vocab_sz, embed_dim)
        
        self.linear1 = nn.Linear(embed_dim, linear_units)
        self.act1 = nn.ReLU()
        self.linear2 = nn.Linear(linear_units, vocab_sz)
        self.act2 = nn.LogSoftmax(dim=-1)
        
    def save_emb_vecs(self, save_path):
        aav = self.embeds.weight.detach().numpy()
        aav = bcolz.carray(aav, rootdir=save_path, mode='w')
        aav.flush()
        
    def forward(self, inputs):
        emb = sum(self.embeds(inputs)).view(1,-1)
        out = self.linear1(emb)
        out = self.act1(out)
        out = self.linear2(out)
        out = self.act2(out)
        
        return out

In [32]:
model = CBOW(aa2vec=aa2vec, linear_units=128)
loss = nn.NLLLoss()
opt = optim.SGD(model.parameters(), lr=1e-3)

In [34]:
for epoch in range(20):
    total_loss = 0
    for i, sampled_batch in tqdm(enumerate(dloader)):
        sequence = sampled_batch['sequence'][0][0]
        d = get_context(sequence, 3)
        for context,target in d:
            try:
                c_vec = make_context_vector(context, aa2ix)
                model.zero_grad()
                log_probs = model(c_vec)
                l = loss(log_probs, torch.tensor(aa2ix[target]).view(1))
                l.backward()
                opt.step()

                total_loss += l.data
            except KeyboardInterrupt:
                raise
            except:
                pass
        
    #save embedding weights
    model.save_emb_vecs(save_path=data_path + 'c3_embs.bc') 
        
    print(total_loss/i)

2047it [05:21,  6.37it/s]
1it [00:00,  8.05it/s]

tensor(741.3653)


2047it [05:17,  6.45it/s]
0it [00:00, ?it/s]

tensor(741.3047)


2047it [05:19,  6.41it/s]
0it [00:00, ?it/s]

tensor(741.3055)


2047it [05:18,  6.44it/s]
1it [00:00,  7.56it/s]

tensor(741.1518)


2047it [05:17,  6.45it/s]
1it [00:00,  5.26it/s]

tensor(741.0647)


2047it [05:18,  6.43it/s]
1it [00:00,  5.78it/s]

tensor(741.0394)


2047it [05:17,  6.44it/s]
1it [00:00,  6.71it/s]

tensor(741.0118)


2047it [05:18,  6.43it/s]
0it [00:00, ?it/s]

tensor(740.8817)


2047it [05:17,  6.45it/s]
2it [00:00, 12.71it/s]

tensor(740.8979)


2047it [05:18,  6.43it/s]
2it [00:00, 15.69it/s]

tensor(740.8441)


2047it [05:18,  6.43it/s]
1it [00:00,  9.83it/s]

tensor(740.8369)


2047it [05:19,  6.42it/s]
0it [00:00, ?it/s]

tensor(740.7706)


2047it [05:19,  6.40it/s]
0it [00:00, ?it/s]

tensor(740.6522)


2047it [05:21,  6.36it/s]
1it [00:00,  5.20it/s]

tensor(740.6934)


2047it [05:18,  6.42it/s]
1it [00:00,  5.33it/s]

tensor(740.6207)


978it [02:32,  6.40it/s]

KeyboardInterrupt: 

In [35]:
model.save_emb_vecs(save_path=data_path + 'c3_embs.bc') 

In [45]:
aa_vectors = {aa: aa2vec[aa2ix[aa]] for aa in aa2ix.keys()}