In [1]:
import h5py
from datetime import datetime
import os
import pickle
import argparse
import itertools
import pandas as pd

import torch
from torch.autograd import Variable
from torch.nn.parameter import Parameter
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

from transformer import Models
from transformer import Beam
from transformer import Translator
from transformer.Optim import ScheduledOptim

from tools import CharacterTable
from translator import SignalTranslator
# Touch the CUDA runtime up front.
# NOTE(review): this call is unguarded, so the notebook presumably cannot run on a
# machine without a usable CUDA device — confirm whether that is intentional.
torch.cuda.synchronize()

# Load a pickled character table into `ctable`.
# NOTE(review): `ctable` is loaded again from
# '../data/ctable_copies/ctable_token_master.pkl' in a later cell, which replaces
# the object loaded here — confirm which file is the intended one.
# SECURITY: pickle.load can execute arbitrary code; only load trusted files.
with open('../outputs/ctable_token.pkl', 'rb') as f:
    ctable = pickle.load(f)

  from ._conv import register_converters as _register_converters
  (fname, cnt))
  (fname, cnt))


In [2]:
# Load the pickled generated test set and display the length of its first element.
# NOTE(review): `test` is not referenced again in the visible cells — confirm it is still needed.
# SECURITY: pickle.load can execute arbitrary code; only load trusted files.
with open('../data/gen_test.pkl', 'rb') as f:
    test = pickle.load(f)
len(test[0])

100

In [3]:
# Helper functions for tokenizing new inputs
# Symbols used when tokenizing; encode() below adds '$' at the start and '.' at the
# end of each sequence and pads with ' '.
alphabet = ' .$ACDEFGHIKLMNPQRSTUVWXYZ'
# NOTE(review): a later cell truncates inputs to 100 residues, not 105 — confirm the intended limit.
max_len_in = 107 # max length of prot seq (105 aa) + 2 for tokens
max_len_out = 72
# NOTE(review): this evaluates to 26, while get_criterion is called with a vocabulary
# size of 27 further down — confirm the one-hot width is correct.
n_chars = len(alphabet)

# Reload the character table; this replaces the `ctable` loaded in the first cell.
# SECURITY: pickle.load can execute arbitrary code; only load trusted files.
with open('../data/ctable_copies/ctable_token_master.pkl', 'rb') as f:
    ctable = pickle.load(f)

def encode(seqs, max_len, ctable):
    """Tokenize a collection of sequences into a single numpy array.

    Each sequence is wrapped as '$' + seq + '.', right-padded with spaces up to
    `max_len` characters, and passed to `ctable.encode(padded, max_len)`.

    Args:
        seqs: sized, iterable collection of sequence strings.
        max_len: padded length handed to `ctable.encode`.
        ctable: character table exposing `one_hot` and `encode(seq, max_len)`.

    Returns:
        Array of shape (len(seqs), max_len, n_chars) when `ctable.one_hot` is
        truthy, otherwise (len(seqs), max_len).
    """
    n_seqs = len(seqs)
    out_shape = (n_seqs, max_len, n_chars) if ctable.one_hot else (n_seqs, max_len)
    X = np.zeros(out_shape)
    for row, raw in enumerate(seqs):
        wrapped = '$' + raw + '.'
        # A non-positive repeat count yields '', so over-long inputs are left unpadded.
        padded = wrapped + ' ' * (max_len - len(wrapped))
        X[row] = ctable.encode(padded, max_len)
    return X

def to_h5py(seqs, fname, ctable, max_len=None, chunksize=500):
    """Encode `seqs` and write them to dataset 'X' of a newly created HDF5 file.

    The sequences are encoded with `encode` and written in slices of
    `chunksize` rows so that only one chunk is held in memory at a time.

    Args:
        seqs: sized, sliceable collection of sequence strings.
        fname: path of the HDF5 file; it is opened in 'w' mode, so an existing
            file at that path is overwritten.
        ctable: character table exposing `one_hot` and `encode(seq, max_len)`.
        max_len: padded sequence length; defaults to the module-level
            `max_len_in` when None.
        chunksize: number of sequences encoded and written per step.

    Raises:
        ValueError: if `chunksize` is smaller than 1.
    """
    if max_len is None:
        max_len = max_len_in
    if chunksize < 1:
        raise ValueError('chunksize must be a positive integer')

    n_seqs = len(seqs)
    shape = (n_seqs, max_len, n_chars) if ctable.one_hot else (n_seqs, max_len)
    with h5py.File(fname, 'w') as f:
        X = f.create_dataset('X', shape)
        for start in range(0, n_seqs, chunksize):
            # The final chunk may be shorter than `chunksize`; clamping `stop`
            # covers it here, so no separate pass over the remainder is needed.
            stop = min(start + chunksize, n_seqs)
            X[start:stop] = encode(seqs[start:stop], max_len, ctable)

In [4]:
# Load sample input, convert to h5py file for generator
df = pd.read_excel('../data/example_test_input.xlsx')
test_seqs = df['protein_sequence'].values
# Keep only the first 100 characters of each sequence.
# NOTE(review): max_len_in is documented as 105 aa + 2 tokens, so 100 looks like an
# unexplained magic number — confirm. Non-string cells (e.g. missing values) would
# presumably fail on slicing here — verify the input column is clean.
test_seqs = [s[:100] for s in test_seqs]
test_filename = ('../data/example_test_tokens.hdf5')
# Writes the encoded sequences to `test_filename`, overwriting any existing file.
to_h5py(test_seqs, test_filename, ctable)

In [5]:
# Load a Model Checkpoint
# The checkpoint file name is assembled from a fixed directory and the run name below.
# NOTE(review): the run name presumably encodes training hyper-parameters — confirm
# against the training script before relying on that.
chkpt_name = 'SIM99_550_12500_64_6_5_0.1_64_100_0.0001_-0.03_99'
chkpt = "../outputs/models/model_checkpoints/" + chkpt_name + ".chkpt"
clf = SignalTranslator.load_model(chkpt)

Namespace(cuda=True, d_inner_hid=1100, d_k=64, d_model=550, d_v=64, d_word_vec=550, dropout=0.1, embs_share_weight=True, max_token_seq_len=107, n_head=5, n_layers=6, proj_share_weight=True, src_vocab_size=27, tgt_vocab_size=27) Namespace(beam_size=1, ctable=<tools.CharacterTable object at 0x7ff7a23905c0>, max_trans_length=72, n_best=1) Namespace(d_model=None, decay_power=-0.03, lr_max=0.0001, n_warmup_steps=12500, optim=<class 'torch.optim.adam.Adam'>)
position_encoding
position_encoding
Initiated Transformer with 27403200 parameters.


In [6]:
# Generate SPs for Proteins
# Open the token file read-only and let the context manager close it, so the
# handle is released even if fetching the batch raises.
with h5py.File(test_filename, 'r') as file:
    training_data = SignalTranslator.generator_from_h5_noy(file, 64, shuffle=False, use_cuda=True)
    # Only the first item yielded by the generator is translated.
    # NOTE(review): 64 is presumably the batch size, so inputs beyond the first
    # 64 sequences would be ignored — confirm.
    src = next(training_data) # src is prot sequence, tgt is signal pep
clf_outputs = clf.translate_batch(src, 5)
decoded, all_hyp, all_scores, enc_outputs, dec_outputs,  \
    enc_slf_attns, dec_slf_attns, dec_enc_attn = clf_outputs

# Pair every decoded input with the model's prediction. The loop variable is
# deliberately not called `src`, so the batch stays available after the loop.
generated_pairs = []
for src_row, dec in zip(src[0], decoded):
    input_seq = ctable.decode(src_row.data.cpu().numpy())  # input protein sequence
    output_seq = dec                                        # model's prediction
    generated_pairs.append((input_seq, output_seq))

  result = self.forward(*input, **kwargs)
  out = self.model.prob_projection(dec_output)


In [7]:
# Reading in Model criterion

In [8]:
import transformer.Constants as Constants

In [9]:
def get_criterion(vocab_size):
    """Build a cross-entropy criterion in which the PAD class has zero weight.

    Args:
        vocab_size: number of classes; one weight is created per class.

    Returns:
        An `nn.CrossEntropyLoss` built with `size_average=False`, with all class
        weights set to one except `Constants.PAD`, which is set to zero.
    """
    class_weights = torch.ones(vocab_size)
    # Zero weight for the padding class.
    class_weights[Constants.PAD] = 0
    criterion = nn.CrossEntropyLoss(class_weights, size_average=False)
    return criterion

# Criterion over 27 classes.
# NOTE(review): 27 is hard-coded; presumably it must match the model's vocabulary
# size — confirm, and note that n_chars above is len(alphabet), a different value.
crit = get_criterion(27)

# Move the criterion to the GPU when the loaded model is flagged for CUDA.
# NOTE(review): assumes `clf.cuda` is a boolean attribute; if it were a bound method
# this test would always be truthy — verify against SignalTranslator.
if clf.cuda:
    crit = crit.cuda()

In [10]:
# Accuracy function
def num_matches(tgt, dec):
    """Count the positions at which the decoded sequence equals the true target.

    The target is decoded with the module-level `ctable`, stripped of
    surrounding whitespace, and its first character is dropped (presumably the
    start token — confirm). The decoded sequence is stripped and, when shorter
    than the target, padded on the right with '_'. Only positions covered by
    the target are compared.

    Args:
        tgt: target token tensor; read via `tgt.data.cpu().numpy()`.
        dec: decoded sequence string produced by the model.

    Returns:
        Number of positions where target and decoded characters are equal.
    """
    reference = ctable.decode(tgt.data.cpu().numpy()).strip()[1:]
    # ljust leaves `dec` untouched when it is already at least as long as the target.
    candidate = dec.strip().ljust(len(reference), '_')
    return sum(1 for want, got in zip(reference, candidate) if want == got)

In [13]:
# Evaluate the model on the validation set.
# Alternative (unfiltered) dataset: '../data/validate_tokens.hdf5'
clf.model.eval()  # Evaluation Mode

log_lks = []
accs = []
# Open the validation file read-only; the context manager closes it once the
# generator has been fully consumed (or if an error interrupts the loop).
with h5py.File('../data/filtered_datasets/validate_tokens_99.hdf5', 'r') as val_file:
    val_dataloader = SignalTranslator.generator_from_h5(val_file, batch_size=1, shuffle=False, use_cuda=True)
    for src, tgt in val_dataloader:
        trans_outs = clf.translate_batch(src, 5) # predict signal pep from src (prot seq)
        decoded, all_hyp, all_scores, enc_outputs, dec_outputs, \
            enc_slf_attns, dec_slf_attns, dec_enc_attn = trans_outs

        log_lks += [score.cpu().numpy()[0] for score in all_scores]

        # The per-sequence names are distinct from `tgt` so the batch is not shadowed.
        for tgt_row, dec in zip(tgt[0], decoded):
            matches = num_matches(tgt_row, dec)
            # NOTE(review): the denominator counts every non-zero target token,
            # while num_matches drops the first decoded character before
            # comparing — confirm the two are meant to differ.
            accs.append(matches / len(tgt_row[tgt_row != 0]))

print('accuracy mean/std:', np.average(accs), np.std(accs))
print('Log likelihood mean/std:', np.average(log_lks), np.std(log_lks))
# print(len(log_lks), np.average(log_lks), np.std(log_lks))

  result = self.forward(*input, **kwargs)
  out = self.model.prob_projection(dec_output)


accuracy mean/std: 0.34501638552782715 0.3010739483939744
Log likelihood mean/std: -3.0666218 1.9270155


In [14]:
# Re-print the summary statistics from the validation cell without re-running it.
# NOTE(review): these two lines duplicate the prints at the end of the validation
# cell — consider keeping only one copy.
print('accuracy mean/std:', np.average(accs), np.std(accs))
print('Log likelihood mean/std:', np.average(log_lks), np.std(log_lks)) 

accuracy mean/std: 0.34501638552782715 0.3010739483939744
Log likelihood mean/std: -3.0666218 1.9270155
