In [1]:
import pickle
import numpy as np
import json
import torch
import itertools

In [2]:
# Upper bound on how many examples are taken from each split when building the
# loss-evaluation samples (it is the first argument of loss_samples and
# prime_loss_samples below).
k = 10000 # number of samples used to calculate loss
# Largest n-gram length included in the n-gram vocabularies and binary encodings.
N_GRAM = 3
# Prefix for every .pkl file this notebook writes. It is joined to file names by
# plain string concatenation, so the trailing slash is required.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
FILE_PATH = '/media/scratch/yuhaowan/data/'

In [3]:
# Target device for tensors: the first CUDA GPU when one is available, else the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

# Alphabets. The ORDER of aa_list matters: it is zipped with `pvals` below, so
# the first 3 residues get the 0.09375 weight, the next 5 get 0.0625 and the
# remaining 12 get 0.03125.
na_list = ['A', 'C', 'G', 'T'] #nucleic acids for aptamer
aa_list = ['R', 'L', 'S', 'A', 'G', 'P', 'T', 'V', 'N', 'D', 'C', 'Q', 'E', 'H', 'I', 'K', 'M', 'F', 'W', 'Y'] #amino acids for peptide

# Frequencies of the 21 NNK outcomes (20 amino acids + the stop codon); sums to 1.
NNK_freq = [0.09375]*3 + [0.0625]*5 + [0.03125]*13
# Total mass of the 20 amino acids once the stop codon is excluded.
sum_20 = 0.0625*5 + 0.09375*3 + 0.03125*12

# Renormalised probabilities of the 20 amino acids. The first 19 entries are
# computed directly; the last one is set to the remainder so the list sums to 1
# despite floating-point rounding.
pvals = [0.09375/sum_20]*3 + [0.0625/sum_20]*5 + [0.03125/sum_20]*11
pvals.append(1 - sum(pvals))

# Amino-acid letter -> sampling probability.
aa_dict = dict(zip(aa_list, pvals))

In [4]:
# Get dictionary of {token: index} for all possible ngrams with size <= N_GRAM
def get_vocab():
    """Build the n-gram vocabularies for aptamers and peptides.

    For each alphabet (na_list, aa_list) every string of length 1..N_GRAM over
    that alphabet is enumerated, shorter n-grams first and in
    itertools.product order within each length, and mapped to its position in
    that enumeration.

    Returns:
        (vocab_apt, vocab_pep): two dicts mapping n-gram string -> integer index.
    """
    def enumerate_ngrams(alphabet):
        symbols = "".join(alphabet)
        return [
            "".join(combo)
            for size in range(1, N_GRAM + 1)
            for combo in itertools.product(symbols, repeat=size)
        ]

    apt_tokens = enumerate_ngrams(na_list)
    pep_tokens = enumerate_ngrams(aa_list)
    vocab_apt = {token: index for index, token in enumerate(apt_tokens)}
    vocab_pep = {token: index for index, token in enumerate(pep_tokens)}
    return vocab_apt, vocab_pep

In [5]:
# Build both vocabularies once; their sizes are the lengths of the vectors
# produced by binary_encoding.
vocab_apt, vocab_pep = get_vocab()
VOCAB_SIZE_APT = len(vocab_apt) #84 when N_GRAM == 3 (4 + 4**2 + 4**3)
VOCAB_SIZE_PEP = len(vocab_pep) #8420 when N_GRAM == 3 (20 + 20**2 + 20**3)

In [6]:
# Iterates the broken-down tokens of the given sequence with N_GRAM
def ngrams_iterator(seq):
    """Yield every contiguous n-gram of `seq` for n = 1..N_GRAM.

    All single characters come first (left to right), then all 2-grams, and so
    on up to N_GRAM. Lengths longer than the sequence contribute nothing.
    """
    yield from seq
    for width in range(2, N_GRAM + 1):
        for start in range(len(seq) - width + 1):
            yield ''.join(seq[start:start + width])

            
# Encodes aptamer/peptide to a binary vector, 1 being the corresponding ngram is present
def binary_encoding(sequence, seq_type='peptide'):
    """Encode a sequence as a multi-hot vector over its n-gram vocabulary.

    Args:
        sequence: aptamer or peptide string.
        seq_type: 'peptide' selects the peptide vocabulary; any other value
            selects the aptamer vocabulary.

    Returns:
        1-D float tensor of length VOCAB_SIZE_PEP / VOCAB_SIZE_APT, placed on
        the notebook-level `device`, with 1 at the index of every n-gram
        (n <= N_GRAM) that occurs in `sequence` and 0 elsewhere.

    Raises:
        KeyError: if the sequence contains an n-gram absent from the vocabulary
            (i.e. a character outside the corresponding alphabet).
    """
    if seq_type == 'peptide':
        vocab_size = VOCAB_SIZE_PEP
        vocab = vocab_pep
    else:
        vocab_size = VOCAB_SIZE_APT
        vocab = vocab_apt
    # Fill the vector on the CPU, then move it once.
    x = torch.zeros(vocab_size)
    for token in ngrams_iterator(sequence):
        x[vocab[token]] = 1
    # Use the configured device instead of an unconditional .cuda(), so the
    # CPU fallback chosen when `device` was created is honoured.
    return x.to(device)


## Takes a peptide or aptamer sequence and converts it to a binary matrix
def one_hot(sequence, seq_type='peptide'):
    """One-hot encode a sequence position by position.

    Args:
        sequence: aptamer or peptide string.
        seq_type: 'peptide' uses aa_list as the alphabet; any other value uses
            na_list.

    Returns:
        numpy float array of shape (len(sequence), len(alphabet)); row i has a
        single 1 in the column of sequence[i]'s position in the alphabet.

    Raises:
        ValueError: if a character is not in the selected alphabet.
    """
    letters = aa_list if seq_type == 'peptide' else na_list
    encoded = np.zeros((len(sequence), len(letters)))
    # One lookup per position is enough; the previous inner loop over the
    # alphabet repeated the identical assignment len(letters) times.
    for row, char in enumerate(sequence):
        encoded[row][letters.index(char)] = 1
    return encoded


# Convert a pair to one-hot tensor
def convert(apt, pep):
    """Turn a pair of 2-D one-hot matrices into batched float tensors.

    Args:
        apt: 2-D numpy array for the aptamer (e.g. (40, 4) from one_hot).
        pep: 2-D numpy array for the peptide (e.g. (8, 20) from one_hot).

    Returns:
        (apt_tensor, pep_tensor): float32 tensors with a leading batch
        dimension of 1, i.e. shapes (1, rows, cols), placed on the
        notebook-level `device`.
    """
    def to_batch_tensor(matrix):
        # Prepend the batch axis, then move to the configured device rather
        # than calling .cuda() unconditionally (which fails without a GPU).
        batched = np.reshape(matrix, (1, matrix.shape[0], matrix.shape[1]))
        return torch.FloatTensor(batched).to(device)

    return to_batch_tensor(apt), to_batch_tensor(pep)


def construct_dataset():
    """Load the observed (aptamer, peptide) pairs and attach peptide pmfs.

    Reads the JSON file named by the module-level `dataset_file` (assigned in
    a later cell, before this function is called).
    Assumes the JSON is an object mapping each aptamer string to an iterable of
    peptide strings -- TODO confirm against the data file.

    Returns:
        List of unique (aptamer, peptide, pep_pmf) tuples.

    NOTE: duplicates are removed through a set, so the order of the returned
    list is whatever the set yields; with string hashing randomised per
    interpreter run, that order (and any slice-based split made from it) is
    not reproducible unless PYTHONHASHSEED is fixed.
    """
    with open(dataset_file, 'r') as handle:
        peptides_by_aptamer = json.load(handle)
    triples = [
        (apt, pep, get_y_pmf(pep))
        for apt, peps in peptides_by_aptamer.items()
        for pep in peps
    ]
    return list(set(triples))  # removes duplicates


# Sample x from P_X (assume aptamers follow uniform)
def get_x():
    """Draw a random aptamer: 40 bases, each chosen uniformly from na_list."""
    base_indices = np.random.randint(0, 4, 40)
    return "".join(na_list[idx] for idx in base_indices)


# Sample y from P_y (assume peptides follow NNK)
def get_y():
    """Draw a random peptide and its pmf.

    The peptide is the fixed letter "M" followed by 7 residues sampled
    independently from aa_list with probabilities `pvals`.

    Returns:
        (y, y_pmf): the 8-character peptide and get_y_pmf(y).
    """
    residue_indices = np.random.choice(20, 7, p=pvals)
    y = "M" + "".join(aa_list[idx] for idx in residue_indices)
    return y, get_y_pmf(y)


# S'(train/test) contains S_train/S_test with double the size of S_train/S_test
def get_S_prime(kind="train"):
    """Mix a real split with an equal number of randomly generated pairs.

    Args:
        kind: "train" uses S_train; any other value uses S_test.

    Returns:
        Shuffled list of [triple, indicator] items, where triple is
        (aptamer, peptide, pep_pmf) and indicator is 0 for a triple taken from
        the split and 1 for a generated one. A generated triple that happens
        to equal an existing key overwrites that key's indicator with 1.
    """
    source = S_train if kind == "train" else S_test
    indicator_by_triple = {triple: 0 for triple in source}  # 0: member of S
    for _ in range(len(source)):
        apt = get_x()
        pep, pep_pmf = get_y()
        indicator_by_triple[(apt, pep, pep_pmf)] = 1  # 1: not in S
    S_prime = [[triple, int(flag)] for triple, flag in indicator_by_triple.items()]
    np.random.shuffle(S_prime)
    return S_prime


# S new contains unseen new examples
def get_S_new(k):
    """Generate `k` random (aptamer, peptide, pep_pmf) triples.

    Each triple comes from one get_x() call followed by one get_y() call; the
    resulting list is shuffled in place before being returned.
    """
    S_new = [(get_x(), *get_y()) for _ in range(k)]
    np.random.shuffle(S_new)
    return S_new
    
    
# Returns pmf of an aptamer
def get_x_pmf():
    """Probability of any particular aptamer under the uniform model: 0.25**40."""
    per_base_probability = 0.25  # 1 of 4 equally likely bases
    aptamer_length = 40          # matches the length drawn in get_x
    return per_base_probability ** aptamer_length


# Returns pmf of a peptide
def get_y_pmf(y):
    """Product of aa_dict probabilities over every character of `y` except the first.

    The first character (the fixed "M" prefix) is skipped. Returns the integer
    1 when `y` has at most one character. Raises KeyError for a character that
    is not a key of aa_dict.
    """
    probability = 1
    for residue in y[1:]:
        probability = probability * aa_dict[residue]
    return probability

In [7]:
# Load the observed triples and split them 80/20 by position.
dataset_file = "../data/aptamer_dataset.json"
S = construct_dataset()
n = len(S)
m = int(0.8*n) #length of S_train
S_train = S[:m]
S_test = S[m:]
# Each S_prime_* mixes the split with as many generated pairs; the first half
# feeds SGD (train_ds below), the second half is used by prime_loss_samples.
S_prime_train = get_S_prime("train") #use for sgd
S_prime_test = get_S_prime("test") #use for sgd
S_new = get_S_new(10*n) #use for eval

# Pair the i-th training triple with the i-th item of the first half of
# S_prime_train, giving rows of
#   (aptamer, peptide, pep_pmf, (apt_prime, pep_prime, pep_pmf_prime), indicator).
# The rows are built explicitly instead of with np.hstack: the [tuple, int]
# items are ragged, and NumPy >= 1.24 refuses to turn them into an array.
_sgd_half = S_prime_train[:len(S_prime_train)//2]
if len(_sgd_half) != len(S_train):
    # np.hstack also rejected mismatched row counts; keep failing loudly.
    raise ValueError(
        "S_train has %d rows but the first half of S_prime_train has %d"
        % (len(S_train), len(_sgd_half))
    )
train_ds = [
    (apt, pep, pmf, prime_triple, indicator)
    for (apt, pep, pmf), (prime_triple, indicator) in zip(S_train, _sgd_half)
]

In [8]:
# Generates the samples used to calculate loss
def loss_samples(k, ds='train'): # S_train/S_test
    """Return (aptamer, peptide) pairs for the first `k` triples of a split.

    Args:
        k: maximum number of pairs to return.
        ds: 'train' reads S_train; any other value reads S_test.
    """
    source = S_train if ds == 'train' else S_test
    return [(apt, pep) for apt, pep, _ in source[:k]]


# Generates the samples used to calculate loss from S_prime_train/S_prime_test
def prime_loss_samples(k, ds='train'):
    """Return up to `k` flattened items from the second half of an S' list.

    Args:
        k: maximum number of items to return.
        ds: "train" reads S_prime_train; any other value reads S_prime_test.

    Returns:
        List of (aptamer, peptide, pep_pmf, indicator) tuples.
    """
    source = S_prime_train if ds == "train" else S_prime_test
    second_half = source[len(source)//2:]
    return [
        (triple[0], triple[1], triple[2], ind)
        for triple, ind in second_half[:k]
    ]

In [9]:
# Entries are (apt, pep) pairs -- loss_samples drops the pmf.
train_loss_samples = loss_samples(k, 'train')
test_loss_samples = loss_samples(k, 'test')
# Entries are (apt, pep, pep_pmf, indicator), from the second half of S_prime_*.
prime_train_loss_samples = prime_loss_samples(k, 'train')
prime_test_loss_samples = prime_loss_samples(k, 'test')

In [10]:
# Output files for prime_train_loss_samples: n-gram encodings (_n) and one-hot
# tensors. Both handles are written to and closed by the following cell.
ptr_loss_n = open(FILE_PATH + 'ptr_loss_n.pkl', 'wb')
ptr_loss = open(FILE_PATH + 'ptr_loss.pkl', 'wb')

In [11]:
# Append one pickle record per sample to each file (the files are streams of
# consecutive pickles, so a reader must call pickle.load repeatedly).
try:
    for apt, pep, pmf, ind in prime_train_loss_samples:
        a_n = binary_encoding(apt, seq_type='apt')
        p_n = binary_encoding(pep)
        a = one_hot(apt, seq_type='apt')
        p = one_hot(pep)
        x , y = convert(a, p)
        pickle.dump((a_n, p_n, pmf, ind), ptr_loss_n)
        pickle.dump((x, y, pmf, ind), ptr_loss)
finally:
    # Close both handles even if encoding fails part-way, so buffered records
    # are flushed and the descriptors are not leaked.
    ptr_loss_n.close()
    ptr_loss.close()

In [12]:
# Output files for prime_test_loss_samples: n-gram encodings (_n) and one-hot
# tensors. Both handles are written to and closed by the following cell.
pte_loss_n = open(FILE_PATH + 'pte_loss_n.pkl', 'wb')
pte_loss = open(FILE_PATH + 'pte_loss.pkl', 'wb')

In [14]:
# Append one pickle record per sample to each file (the files are streams of
# consecutive pickles, so a reader must call pickle.load repeatedly).
try:
    for apt, pep, pmf, ind in prime_test_loss_samples:
        a_n = binary_encoding(apt, seq_type='apt')
        p_n = binary_encoding(pep)
        a = one_hot(apt, seq_type='apt')
        p = one_hot(pep)
        x, y = convert(a, p)
        pickle.dump((a_n, p_n, pmf, ind), pte_loss_n)
        pickle.dump((x, y, pmf, ind), pte_loss)
finally:
    # Close both handles even if encoding fails part-way, so buffered records
    # are flushed and the descriptors are not leaked.
    pte_loss_n.close()
    pte_loss.close()

In [15]:
# Output files for test_loss_samples: n-gram encodings (_n) and one-hot
# tensors. Both handles are written to and closed by the following cell.
te_loss_n = open(FILE_PATH + 'te_loss_n.pkl', 'wb')
te_loss = open(FILE_PATH + 'te_loss.pkl', 'wb')

In [16]:
# Append one pickle record per (apt, pep) pair to each file: the n-gram
# encodings go to te_loss_n, the batched one-hot tensors to te_loss.
try:
    for apt, pep in test_loss_samples:
        a_n = binary_encoding(apt, seq_type='apt')
        p_n = binary_encoding(pep)
        a = one_hot(apt, seq_type='apt')
        p = one_hot(pep)
        pickle.dump((a_n, p_n), te_loss_n)
        pickle.dump(convert(a, p), te_loss)
finally:
    # Close both handles even if encoding fails part-way, so buffered records
    # are flushed and the descriptors are not leaked.
    te_loss_n.close()
    te_loss.close()

In [17]:
# Output files for train_loss_samples: n-gram encodings (_n) and one-hot
# tensors. Both handles are written to and closed by the following cell.
tr_loss_n = open(FILE_PATH + 'tr_loss_n.pkl', 'wb')
tr_loss = open(FILE_PATH + 'tr_loss.pkl', 'wb')

In [18]:
# Append one pickle record per (apt, pep) pair to each file: the n-gram
# encodings go to tr_loss_n, the batched one-hot tensors to tr_loss.
try:
    for apt, pep in train_loss_samples:
        a_n = binary_encoding(apt, seq_type='apt')
        p_n = binary_encoding(pep)
        a = one_hot(apt, seq_type='apt')
        p = one_hot(pep)
        pickle.dump((a_n, p_n), tr_loss_n)
        pickle.dump(convert(a, p), tr_loss)
finally:
    # Close both handles even if encoding fails part-way, so buffered records
    # are flushed and the descriptors are not leaked.
    tr_loss_n.close()
    tr_loss.close()

In [19]:
# Output files for the SGD training rows in train_ds: n-gram encodings (_n)
# and one-hot tensors. Both handles are written to and closed by the following cell.
S_tr_n = open(FILE_PATH + 's_tr_n.pkl', 'wb')
S_tr = open(FILE_PATH + 's_tr.pkl', 'wb')

In [20]:
# Each train_ds row pairs a real (aptamer, peptide) with an S' item
# ((apt_prime, pep_prime, pep_pmf), indicator). Both pairs are encoded twice:
# as n-gram multi-hot vectors (-> S_tr_n) and as batched one-hot tensors (-> S_tr).
try:
    for aptamer, peptide, _, (apt_prime, pep_prime, pep_pmf), indicator in train_ds:
        a_n = binary_encoding(aptamer, seq_type='apt')
        a_p_n = binary_encoding(apt_prime, seq_type='apt')
        p_n = binary_encoding(peptide)
        p_p_n = binary_encoding(pep_prime)

        a = one_hot(aptamer, seq_type='apt')
        a_p = one_hot(apt_prime, seq_type='apt')
        p = one_hot(peptide)
        p_p = one_hot(pep_prime)

        x, y = convert(a, p)
        x_p, y_p = convert(a_p, p_p)

        pickle.dump((a_n, p_n, a_p_n, p_p_n, pep_pmf, indicator), S_tr_n)
        pickle.dump((x, y, x_p, y_p, pep_pmf, indicator), S_tr)
finally:
    # Close both handles even if encoding fails part-way, so buffered records
    # are flushed and the descriptors are not leaked.
    S_tr_n.close()
    S_tr.close()

In [21]:
# Output files for the full S_test split: n-gram encodings (_n) and one-hot
# tensors. Both handles are written to and closed by the following cell.
S_te_n = open(FILE_PATH + 's_te_n.pkl', 'wb')
S_te = open(FILE_PATH + 's_te.pkl', 'wb')

In [22]:
# Append one pickle record per test triple to each file; the stored pmf (third
# element of each triple) is not written.
try:
    for aptamer, peptide, _ in S_test:
        a_n = binary_encoding(aptamer, seq_type='apt')
        p_n = binary_encoding(peptide)
        a = one_hot(aptamer, seq_type='apt')
        p = one_hot(peptide)
        x, y = convert(a, p)
        pickle.dump((a_n, p_n), S_te_n)
        pickle.dump((x, y), S_te)
finally:
    # Close both handles even if encoding fails part-way, so buffered records
    # are flushed and the descriptors are not leaked.
    S_te_n.close()
    S_te.close()

In [23]:
# Output files for the generated S_new examples: S_n receives the one-hot
# tensors (s_new.pkl), S_new_n the n-gram encodings (s_new_n.pkl). Both handles
# are written to and closed by the following cell.
S_n = open(FILE_PATH + 's_new.pkl', 'wb')
S_new_n = open(FILE_PATH + 's_new_n.pkl', 'wb')

In [24]:
# Append one pickle record per generated triple to each file; the stored pmf
# (third element of each triple) is not written.
try:
    for aptamer, peptide, _ in S_new:
        a_n = binary_encoding(aptamer, seq_type='apt')
        p_n = binary_encoding(peptide)
        a = one_hot(aptamer, seq_type='apt')
        p = one_hot(peptide)
        x, y = convert(a, p)
        pickle.dump((a_n, p_n), S_new_n)
        pickle.dump((x, y), S_n)
finally:
    # Close both handles even if encoding fails part-way, so buffered records
    # are flushed and the descriptors are not leaked.
    S_new_n.close()
    S_n.close()