In [10]:
import torch 
import pandas as pd
import numpy as np
from dataset import LuciferaseDataset
from models import fc_encoder, fc_decoder, MSA_VAE
from utils import seq_to_ohe, ohe_to_seq

In [6]:
# sample 3000
# pca encoded

In [8]:
ENCODER_KWARGS = {'latent_dim' : 10,
        'seq_len' : 360,
        'encoder_hidden' : [256, 256],
        'encoder_dropout' : [0, 0],
        }

DECODER_KWARGS = {'latent_dim' : 10,
        'seq_len' : 360,
        'decoder_hidden' : [256, 256],
        'decoder_dropout' : [0, 0]}

luxA = 'MKFGNFLLTYQPPQFSQTEVMKRLVKLGRISEECGFDTVWLLEHHFTEFGLLGNPYVAAAYLLGATKKLNVGTAAIVLPTAHPVRQLEDVNLLDQMSKGRFRFGICRGLYNKDFRVFGTDMNNSRALAECWYGLIKNGMTEGYMEADNEHIKFHKVKVNPAAYSRGGAPVYVVAESASTTEWAAQFGLPMILSWIINTNEKKAQLELYNEVAQEYGHDIHNIDHCLSYITSVDHDSIKAKEICRKFLGHWYDSYVNATTIFDDSDQTRGYDFNKGQWRDFVLKGHKDTNRRIDYSYEINPVGTPQECIDIIQKDIDATGISNICCGFEANGTVDEIIASMKLFQSDVMPFLKEKQRSLLY'

In [9]:
print(len(luxA))

360


In [108]:
def sample(model_checkpoint, n_samples, stdev=0.01):
    
    # load model
    encoder = fc_encoder(**ENCODER_KWARGS)
    decoder = fc_decoder(**DECODER_KWARGS)
    model = MSA_VAE(encoder, decoder)
    checkpoint = torch.load(model_checkpoint, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    
    # encode template luxA
    luxa_ohe = torch.tensor(seq_to_ohe(luxA)).unsqueeze(0).reshape(1, 21, 360).float()
    luxa_ohe = luxa_ohe.repeat(n_samples, 1, 1)
    
    mu, logvar = model.encoder(luxa_ohe)
    z = model.reparameterize(mu, logvar)
    
    # add noise
    if stdev > 0.0:
        noise = torch.normal(0, stdev, z.shape)
        z += noise
    
    # decode
    decoded = model.decoder(z).detach().permute(0, 2, 1).numpy()
    
    # ohe to sequence
    seqs = []
    for i in range(decoded.shape[0]):
        seq = ohe_to_seq(decoded[i, ...])
        seqs.append(seq)
        
    df = pd.DataFrame(seqs, columns='generated')
    df['sample'] = []
    
    return df
    
    

In [109]:
decoded = sample('../checkpoints/luciferase_14.pt', 3000, 0)

In [110]:
decoded

Unnamed: 0,0
0,EAKRAAGAHQAAACVGWEADFCAIAQAARAGHIMSAAND-AAAREA...
1,EAKRAAGAHQAAACVGWEADFCAIAQAARAGHIMSAAND-AAAREA...
2,EAKRAAGAHQAAACVGWEADFCAIAQAARAGHIMSAAND-AAAREA...
3,EAKRAAGAHQAAACVPWEADFCAIAQSARAGHIMSAAND-AAKREA...
4,EAKAAAGAHQAAACVPWEADFCAIAQSARAGHAMSAAND-AAAREA...
...,...
2995,EAKRAAAAHQAAACVGWEADFCAIAQAARAGHIMSAAND-AAKREA...
2996,EAKRAAGAHQAAACVGWEADFCAIAQSARAGHAMSAAND-AAKREA...
2997,EAKRAAGAHQAAACVGWEADFCAIAQAARAGHIMSAAND-AAAREA...
2998,EAKRAAAAHQAAACVGWEADFCAIAQAARAGHIMSAAND-AAAREA...
