### IMPORTS

In [1]:

import pickle
import glob
import random

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


In [2]:
%matplotlib inline
import seaborn as sns

In [3]:
# Fix the random state: seed the NumPy generator used by sample_temperature
# and Python's `random` module (used to pick random start letters).
random_state = np.random.RandomState(42)
# random.seed must be *called*; assigning `random.seed = 42` replaced the
# function with an int and left the generator unseeded.
random.seed(42)

In [None]:
## One-hot Functions 

In [None]:
# Secondary-structure classes (presumably DSSP codes grouped into helix,
# strand and coil). categoryTensor one-hot encodes a tuple's position here.
all_categories = [('H', 'G', 'I'),
                  ('E', 'B'),
                  ('S', 'T', 'C', '-')]
# Derived rather than hard-coded so it cannot drift from all_categories.
n_categories = len(all_categories)

# Vocabulary: index 0 ('_') is the EOS symbol appended by targetTensor,
# followed by the 20 standard amino-acid one-letter codes.
all_letters = ['_', 'C', 'V', 'T', 'F', 'Y', 'A', 'P', 'W', 'I', 'M', 'L',
               'S', 'G', 'H', 'D', 'E', 'N', 'Q', 'R', 'K']
n_letters = len(all_letters)



# One-hot vector for category
def categoryTensor(category):
    """Return a (1, n_categories) float tensor one-hot encoding ``category``.

    ``category`` must be one of the tuples in ``all_categories``; otherwise
    ``list.index`` raises ValueError.
    """
    position = all_categories.index(category)
    onehot = torch.zeros(1, n_categories)
    onehot[0, position] = 1
    return onehot

# One-hot matrix of first to last letters (not including EOS) for input
def inputTensor(line):
    """Encode ``line`` as a (len(line), 1, n_letters) one-hot float tensor.

    Each character must appear in ``all_letters``; otherwise ``list.index``
    raises ValueError.
    """
    onehot = torch.zeros(len(line), 1, n_letters)
    for pos, ch in enumerate(line):
        onehot[pos, 0, all_letters.index(ch)] = 1
    return onehot

# LongTensor of second letter to end (EOS) for target
def targetTensor(line):
    """Return the indices of ``line[1:]`` followed by the EOS index of '_'.

    The result has the same length as ``line`` (at least 1), as a LongTensor.
    """
    shifted = [all_letters.index(ch) for ch in line[1:]]
    eos = all_letters.index('_')
    return torch.LongTensor(shifted + [eos])

In [None]:
def from_vec2char(tensor):
    """Decode the first one-hot hit in ``tensor[0]`` to a character.

    Scans each row of ``tensor[0]`` in order and returns ``int_to_char[col]``
    for the first element equal to 1. Returns None when no element is 1.

    NOTE(review): ``int_to_char`` is not defined in this part of the file;
    presumably an index -> letter mapping defined in another cell (compare
    ``all_letters``). Confirm.
    """
    for row in tensor[0]:
        for col, cell in enumerate(row):
            if cell.data.cpu().numpy() == 1:
                return int_to_char[col]
    return None

            
            
def from_code2chars(tensor):
    """Map a 1-D tensor of letter indices to a string, dropping the last code.

    The final code is discarded; presumably it is the EOS index appended by
    targetTensor. ``int_to_char`` must be defined elsewhere in the notebook.
    """
    codes = tensor.data.cpu().numpy()
    decoded = [int_to_char[code] for code in codes]
    return ''.join(decoded[:-1])


def decode_library(lib):
    """Rebuild sequences from (…, one-hot first letter, target codes) items.

    For each item, item[1] is decoded with from_vec2char and item[2] with
    from_code2chars. The two strings are concatenated in that order.
    """
    return [from_vec2char(entry[1]) + from_code2chars(entry[2]) for entry in lib]



In [1]:
# Model helper functions: temperature sampling and sequence generation



def sample_temperature(x, temperature=1.0, rng=None):
    """Draw one index from softmax(x / temperature).

    Args:
        x: array-like of scores; it is flattened, so a (1, n) network output
            works. The input is not modified.
        temperature: divides the scores before the softmax. Lower values
            sharpen the distribution; higher values flatten it.
        rng: a ``numpy.random.RandomState`` to draw from. Defaults to the
            module-level ``random_state``, so existing calls stay seeded.

    Returns:
        The sampled index as ``numpy.int64``.
    """
    if rng is None:
        rng = random_state
    # np.float was removed in NumPy 1.24; use an explicit float64 dtype.
    # Division creates a new array, so the caller's data is never mutated.
    logits = np.asarray(x, dtype=np.float64).reshape(-1) / temperature
    # Subtracting the max leaves the softmax unchanged but prevents exp()
    # from overflowing to inf (which would turn the probabilities into NaN).
    logits -= logits.max()
    probs = np.exp(logits)
    probs /= probs.sum()
    draw = rng.multinomial(1, probs)
    return np.int64(np.argmax(draw))



def sample(category, start_letter='A', temp=.8, max_length=9):
    """Generate a sequence from the global ``rnn`` conditioned on ``category``.

    Args:
        category: one of the tuples in ``all_categories``
            (e.g. ``('E', 'B')``).
        start_letter: first residue fed to the network. It is also the prefix
            of the returned string.
        temp: temperature passed to ``sample_temperature``.
        max_length: number of letters sampled after ``start_letter``.
            Defaults to 9, the previously hard-coded value.

    Returns:
        ``start_letter`` followed by ``max_length`` sampled letters. Sampling
        does not stop at the EOS symbol ``'_'``, so the result may contain it
        (library_generation discards such sequences).

    NOTE(review): ``rnn`` is a notebook global not defined in this part of
    the file. It is assumed to return ``(output, hidden)``. Confirm.
    """
    with torch.no_grad():  # no need to track history in sampling
        category_tensor = categoryTensor(category)
        letter_tensor = inputTensor(start_letter)
        hidden = rnn.initHidden()

        output_name = start_letter
        for _ in range(max_length):
            output, hidden = rnn(category_tensor, letter_tensor[0], hidden)
            topi = sample_temperature(output.cpu().data.numpy(), temperature=temp)
            letter = all_letters[topi]
            output_name += letter
            # feed the sampled letter back in as the next input
            letter_tensor = inputTensor(letter)

        return output_name


def library_generation(name, n=100, kind='H', t=.65):
    """Sample ``n`` sequences of class ``kind`` and write the EOS-free ones.

    Each kept sequence is written to ``'{name}_{kind}_{seq}_t{t}.fasta'`` as
    a one-record FASTA file. Identical sequences overwrite the same file.

    Args:
        name: filename prefix.
        n: number of sampling attempts.
        kind: 'H', 'E' or 'C'. Selects the category tuple passed to
            ``sample``. Any other value raises KeyError.
        t: sampling temperature.

    Returns:
        The list of sequences that were written. The function previously
        returned None, so existing callers are unaffected.
    """
    labels = {'H': ('H', 'G', 'I'),
              'E': ('E', 'B'),
              'C': ('S', 'T', 'C', '-')}
    # '_' is the EOS symbol. A sequence started with it always contains '_'
    # and is discarded below, so drawing it as a start letter wasted a sample.
    start_letters = [letter for letter in all_letters if letter != '_']

    kept = []
    for _ in range(n):
        seq = sample(labels[kind], random.choice(start_letters), temp=t)
        if '_' not in seq:
            kept.append(seq)

    for seq in kept:
        # `with` guarantees the file is closed even if a write fails
        with open('{}_{}_{}_t{}.fasta'.format(name, kind, seq, t), 'w') as fh:
            print('>' + seq, file=fh)
            print(seq, file=fh)
    return kept

        
def fillthegaps(kind, motifs='AXPXXXPXXXK', temp=.65):
    """Replace every 'X' in ``motifs`` with a residue sampled from ``rnn``.

    Non-'X' characters are copied unchanged and become the next network
    input. They are not run through ``rnn``, so the hidden state only
    advances at 'X' positions.

    The initial network input is ``motifs[0]``. If the motif is empty, or its
    first character is not in ``all_letters`` (e.g. it is 'X'), a random
    amino acid is used instead.

    Args:
        kind: 'H', 'E' or 'C'. Any other value raises KeyError.
        motifs: template string; 'X' marks positions to fill.
        temp: sampling temperature.

    Returns:
        The filled-in sequence, the same length as ``motifs``.
    """
    labels = {'H': ('H', 'G', 'I'),
              'E': ('E', 'B'),
              'C': ('S', 'T', 'C', '-')}

    with torch.no_grad():  # no need to track history in sampling
        category_tensor = categoryTensor(labels[kind])
        hidden = rnn.initHidden()
        try:
            letter_tensor = inputTensor(motifs[0])
        except (IndexError, ValueError):
            # IndexError: empty motif.
            # ValueError: first symbol not in all_letters (e.g. 'X').
            aa = random.choice(['C', 'V', 'T', 'F', 'Y', 'A', 'P', 'W', 'I', 'M',
                                'L', 'S', 'G', 'H', 'D', 'E', 'N', 'Q', 'R', 'K'])
            letter_tensor = inputTensor(aa)

        output_name = []
        for symbol in motifs:
            if symbol != 'X':
                # Fixed residue: copy it and use it as the next input.
                # Original author's open question: should the hidden state be
                # reset here, or is a separate context layer needed?
                output_name.append(symbol)
                letter_tensor = inputTensor(symbol)
            else:
                output, hidden = rnn(category_tensor, letter_tensor[0], hidden)
                topi = sample_temperature(output.cpu().data.numpy(), temperature=temp)
                letter = all_letters[topi]
                output_name.append(letter)
                letter_tensor = inputTensor(letter)

        return ''.join(output_name)

In [None]:
# Plotting & analysis helpers

def get_avgcontent(path, norm=10):
    """Collect per-folder secondary-structure content from PSIPRED .ss2 files.

    Args:
        path: glob pattern matching result folders (e.g. ``'runs/*'``).
        norm: divisor applied to each per-structure residue count. Defaults
            to 10, the previously hard-coded value (presumably the length of
            the sampled sequences; confirm). Pass ``None`` to divide by the
            number of residues in each file instead. That gives a true
            fraction for sequences of any length; an empty file yields 0.0.

    Returns:
        A dict mapping 'C', 'H' and 'E' to lists with one value per folder
        that contains at least one ``.ss2`` file. Folders without one are
        skipped. Only the first ``.ss2`` file of each folder (in sorted
        order) is read, and folders are visited in sorted order so results
        are reproducible.
    """
    data = {'C': [],
            'H': [],
            'E': []}

    for folder in sorted(glob.glob(path)):
        ss2_files = sorted(glob.glob(folder + '/*.ss2'))
        if not ss2_files:
            continue
        # sep=r'\s+' replaces delim_whitespace=True, deprecated in pandas 2.2
        dat = pd.read_csv(ss2_files[0], skip_blank_lines=True, skiprows=1,
                          names=['idx', 'res', 'SS', 'Prob_C', 'Prob_H', 'Prob_E'],
                          sep=r'\s+')
        denom = dat.shape[0] if norm is None else norm
        for struct, values in data.items():
            count = dat[dat['SS'] == struct].shape[0]
            values.append(count / denom if denom else 0.0)

    return data

In [17]:
import subprocess

In [3]:
def run_psipred(seq, db='/Users/ccorbi/Desktop/DEMO_insights/uniprot_sprot.fasta'):
    """Predict the secondary structure of ``seq`` with PSIPRED.

    Steps:
      1. Writes ``<seq>.fasta`` in the working directory.
      2. Runs ``run_psipred.pl`` with ``-d db`` (presumably the sequence
         database for the PSI-BLAST step) and output directory ``<seq>``.
      3. Reads ``./<seq>/<seq>.fasta.ss2``.

    Args:
        seq: amino-acid sequence; also used as the file and directory name.
        db: path passed via ``-d``. Defaults to the previously hard-coded
            path.

    Returns:
        The concatenated SS column of the .ss2 file (one letter per residue).
    """
    with open('{}.fasta'.format(seq), 'w') as fh:
        print('>' + seq, file=fh)
        print(seq, file=fh)

    # argv list (no shell), so seq is never interpreted by a shell.
    # As before, a non-zero exit status is not checked here; a failed run
    # surfaces when the .ss2 file is read below.
    subprocess.run(['run_psipred.pl',
                    seq + '.fasta',
                    '-d', db,
                    '-o', seq])

    # sep=r'\s+' replaces delim_whitespace=True (deprecated in pandas 2.2)
    dat = pd.read_csv('./{0}/{0}.fasta.ss2'.format(seq), skip_blank_lines=True,
                      skiprows=1,
                      names=['idx', 'res', 'SS', 'Prob_C', 'Prob_H', 'Prob_E'],
                      sep=r'\s+')
    # Series.get_values() was removed in pandas 1.0; join the values directly.
    return ''.join(dat['SS'].tolist())
    