In [1]:
import os, sys
import numpy as np
import json
import random
import torch
import torch.nn as nn

from captum.attr import (
    GradientShap,
    DeepLift,
    DeepLiftShap,
    IntegratedGradients,
    LayerConductance,
    NeuronConductance,
    NoiseTunnel
)

import torch.nn.functional as F
from torch.optim import SGD, Adam
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from Bio import pairwise2

## Setup

In [2]:
torch.manual_seed(12345)
k = 10000
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
na_list = ['A', 'C', 'G', 'T'] #nucleic acids
aa_list = ['R', 'L', 'S', 'A', 'G', 'P', 'T', 'V', 'N', 'D', 'C', 'Q', 'E', 'H', 'I', 'K', 'M', 'F', 'W', 'Y'] #amino acids
hydrophobicity = {'G': 0, 'A': 41, 'L':97, 'M': 74, 'F':100, 'W':97, 'K':-23, 'Q':-10, 'E':-31, 'S':-5, 'P':-46, 'V':76, 'I':99, 'C':49, 'Y':63, 'H':8, 'R':-14, 'N':-28, 'D':-55, 'T':13}
NNK_freq = [0.09375]*3 + [0.0625]*5 + [0.03125]*13 #freq of 21 NNK codons including the stop codon
sum_20 = 0.0625*5 + 0.09375*3 + 0.03125*12 #sum of freq without the stop codon
pvals = [0.09375/sum_20]*3 + [0.0625/sum_20]*5 + [0.03125/sum_20]*12 #normalize freq for 20 codons
pvals = [0.09375/sum_20]*3 + [0.0625/sum_20]*5 + [0.03125/sum_20]*11 + \
        [1- sum([0.09375/sum_20]*3 + [0.0625/sum_20]*5 + [0.03125/sum_20]*11)] 
        #adjust sum to 1 due to numerical issue
aa_dict = dict(zip(aa_list, pvals))

## Dataset & Sampling

In [3]:
def construct_dataset():
    with open(aptamer_dataset_file, 'r') as f:
        aptamer_data = json.load(f)
    ds = []
    for aptamer in aptamer_data:
        peptides = aptamer_data[aptamer]
        for peptide in peptides:
            ds.append((aptamer, peptide))
    ds = list(set(ds)) #removed duplicates
    return ds

# Sample x from P_X (assume apatamers follow uniform)
def get_x():
    x_idx = np.random.randint(0, 4, 40)
    x = ""
    for i in x_idx:
        x += na_list[i]
    return x

# Sample y from P_y (assume peptides follow NNK)
def get_y():
    y_idx = np.random.choice(20, 7, p=pvals)
    y = "M"
    for i in y_idx:
        y += aa_list[i]
    return y

# S'(train/test) contains S_train/S_test with double the size of S_train/S_test
def get_S_prime(kind="train"):
    if kind == "train":
        dset = S_train
    else:
        dset = S_test
    k = len(dset)
    S_prime_dict = dict.fromkeys(dset, 0) #indicator 0 means in S
    for _ in range(k):
        pair = (get_x(), get_y())
        S_prime_dict[pair] = 1 #indicator 1 means not in S
    S_prime = [[k,int(v)] for k,v in S_prime_dict.items()] 
    np.random.shuffle(S_prime)
    return S_prime

# S new contains unseen new examples
def get_S_new(k):
    S_new = []
    for i in range(k):
        pair = (get_x(), get_y())
        S_new.append(pair)
    np.random.shuffle(S_new)
    return S_new
    
# Returns pmf of an aptamer
def get_x_pmf():
    return 0.25**40

# Returns pmf of a peptide
def get_y_pmf(y):
    pmf = 1
    for char in y[1:]: #skips first char "M"
        pmf *= aa_dict[char]
    return pmf    

In [4]:
aptamer_dataset_file = "../../data/aptamer_dataset.json"
S = construct_dataset()
n = len(S)
m = int(0.8*n) #length of S_train
S_train = S[:m]
S_test = S[m:]
S_prime_train = get_S_prime("train") #use for sgd 
S_prime_test = get_S_prime("test") #use for sgd 
S_new = get_S_new(4000) #use for eval
#train_ds = np.hstack((S_train, S_prime_train[:len(S_prime_train)//2]))

## NN Model

## Helper methods

In [39]:
## Takes a peptide and aptamer sequence and converts to one-hot matrix
def translate(sequence, seq_type='peptide', single_alphabet=False):
    if single_alphabet:
        apt = sequence[0]
        pep = sequence[1]
        
        encoding = np.zeros(len(apt) + len(pep))
        
        # Encode the aptamer first
        for i in range(len(apt)):
            char = apt[i]
            idx = na_list.index(char)
            encoding[i] = idx
            
        # Encode the peptide second
        for i in range(len(pep)):
            char = pep[i]
            idx = aa_list.index(char)
            encoding[i+len(apt)] = idx
        return encoding     
    else:
        if seq_type == 'peptide':
            letters = aa_list
        else:
            letters = na_list
        
        encoding = np.zeros(len(sequence))
        for i in range(len(sequence)):
            char = sequence[i]
            idx = letters.index(char)
            encoding[i] = idx
        return encoding

# Convert a pair to one-hot tensor
def convert(apt, pep, label, single_alphabet=False): 
    if single_alphabet:
        pair = translate([apt, pep], single_alphabet=True) #(48, )
        print(str(pair.shape))
        pair = torch.FloatTensor(np.reshape(pair, (1, pair.shape[0]))).to(device)
        label = torch.FloatTensor([label]).to(device)
        return pair, label
    else:
        apt = translate(apt, seq_type='aptamer') #(40, )
        pep = translate(pep, seq_type='peptide') #(8, )
        print("Apt shape: ", apt.shape)
        print("Pep shape: ", pep.shape)
        apt = torch.FloatTensor(np.reshape(apt, (-1, 1, apt.shape[0]))).to(device) #(1, 1, 40)
        pep = torch.FloatTensor(np.reshape(pep, (-1, 1, pep.shape[0]))).to(device) #(1, 1, 8)
        label = torch.FloatTensor([[label]]).to(device)
        return apt, pep, label

# Getting the output of the model for a pair (aptamer, peptide)
def update(x, y, p, single_alphabet=False):
    if single_alphabet:
        p.requires_grad=True
        p = p.to(device)
        out = model(p)
        return out
    else:
        x.requires_grad=True
        y.requires_grad=True
        x = x.to(device)
        y = y.to(device)
        out = model(x, y)
        return out

In [40]:
class TranslateBatchNet(nn.Module):
    def __init__(self):
        super(TranslateBatchNet, self).__init__()
        self.name = "TranslateBatchNet"
        
        self.cnn_apt_1 = nn.Conv1d(1, 20, 3) 
        self.cnn_apt_2 = nn.Conv1d(20, 30, 3, padding=2) 
        self.cnn_apt_3 = nn.Conv1d(30, 20, 3, padding=2) 
        self.cnn_apt_4 = nn.Conv1d(20, 5, 1) 
        
        self.cnn_pep_1 = nn.Conv1d(1, 15, 3, padding=2)
        self.cnn_pep_2 = nn.Conv1d(15, 30, 3, padding=2)
        self.cnn_pep_3 = nn.Conv1d(30, 10, 3, padding=2)
        self.cnn_pep_4 = nn.Conv1d(10, 5, 2, padding=2)

        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool1d(2) 
        
        self.cnn_apt = nn.Sequential(self.cnn_apt_1, self.maxpool, self.relu, 
                                     self.cnn_apt_2, self.maxpool, self.relu,
                                     self.cnn_apt_3, self.maxpool, self.relu,
                                     self.cnn_apt_4, self.maxpool, self.relu)
        self.cnn_pep = nn.Sequential(self.cnn_pep_1, self.maxpool, self.relu,
                                     self.cnn_pep_2, self.maxpool, self.relu,
                                     self.cnn_pep_3, self.maxpool, self.relu,
                                     self.cnn_pep_4, self.maxpool, self.relu)
        
        self.fc1 = nn.Linear(25, 10)
        self.fc2 = nn.Linear(10, 1)
    
    def forward(self, apt, pep):
        apt = self.cnn_apt(apt)
        pep = self.cnn_pep(pep)
        
        apt = apt.view(-1, 1).T
        pep = pep.view(-1, 1).T
        x = torch.cat((apt, pep), 1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = torch.sigmoid(x)
        return x

## Captum

In [41]:
apt, pep = S_new[0]
baseline_pair, baseline_label = convert(apt, pep, 0, single_alphabet=True)
print(str(baseline_pair.shape))

(48,)
torch.Size([1, 48])


In [42]:
apt, pep = S_test[0]
input_pair, input_label = convert(apt, pep, 1, single_alphabet=True)
print(str(input_pair.shape))

(48,)
torch.Size([1, 48])


In [43]:
class TranslateSingleAlphabetBatchNet(nn.Module):
    def __init__(self):
        super(TranslateSingleAlphabetBatchNet, self).__init__()
        self.name = "TranslateSingleAlphabetBatchNet"
        
        self.cnn_1 = nn.Conv1d(1, 20, 3) 
        self.cnn_2 = nn.Conv1d(20, 30, 3, padding=2) 
        self.cnn_3 = nn.Conv1d(30, 20, 3, padding=2) 
        self.cnn_4 = nn.Conv1d(20, 5, 1) 

        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool1d(2) 
        
        self.cnns = nn.Sequential(self.cnn_1, self.maxpool, self.relu, 
                                     self.cnn_2, 
                                     self.cnn_3, 
                                     self.cnn_4)

        
        self.fc1 = nn.Linear(135, 10)
        self.fc2 = nn.Linear(10, 2)
    
    def forward(self, pair):
        x = self.cnns(pair)
        print(str(x.shape))
        x = x.view(1, -1)
        print(str(x.shape))
        #x = x.view(-1, 1).T
        
        x = self.fc1(x)
        x = self.fc2(x)
        x = torch.sigmoid(x)
        print("Output shape: ", x.shape)
        return x

In [84]:
# ToyModel is the one directly from Captum website
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(48, 24)
        self.relu = nn.ReLU()
        self.lin2 = nn.Linear(24, 2)

        # initialize weights and biases
#         self.lin1.weight = nn.Parameter(torch.arange(-4.0, 44.0).view(24, 48))
#         self.lin1.bias = nn.Parameter(torch.zeros(1,48))
#         self.lin2.weight = nn.Parameter(torch.arange(-3.0, 93.0).view(2, 48))
#         self.lin2.bias = nn.Parameter(torch.ones(1,2))

    def forward(self, input):
        x = self.relu(self.lin1(input))
        print(str(x.shape))
        x = self.lin2(x)
        print(str(x.shape))
        return x

In [85]:
### checkpoint = None #torch.load('../model_checkpoints/binary//06172020.pth')
model = ToyModel()
#optim = SGD(model.parameters(), lr=1e-2)
#model.load_state_dict(checkpoint['model_state_dict'])
#optim.load_state_dict(checkpoint['optimizer_state_dict'])
#epoch = checkpoint['epoch']
model.to(device)
model.eval()

ToyModel(
  (lin1): Linear(in_features=48, out_features=24, bias=True)
  (relu): ReLU()
  (lin2): Linear(in_features=24, out_features=2, bias=True)
)

In [86]:
print(str(input_pair.shape))
x = model(input_pair)
print(str(x))

torch.Size([1, 48])
torch.Size([1, 24])
torch.Size([1, 2])
tensor([[-0.8161,  0.6035]], device='cuda:0', grad_fn=<AddmmBackward>)


In [87]:
dl = DeepLift(model)
attributions, delta = dl.attribute(input_pair, baseline_pair, target=1, return_convergence_delta=True)
print('IG Attributions:', attributions)
print('Convergence Delta:', delta)

torch.Size([2, 24])
torch.Size([2, 2])
torch.Size([1, 24])
torch.Size([1, 2])
torch.Size([1, 24])
torch.Size([1, 2])
IG Attributions: tensor([[ 0.0269, -0.0290,  0.0370, -0.0000,  0.0240,  0.0000, -0.0000,  0.0453,
          0.0448,  0.0117,  0.0417,  0.0000, -0.0598,  0.0188,  0.0383,  0.0245,
          0.0000,  0.0172, -0.0638, -0.0000, -0.0133, -0.0093, -0.0136, -0.0427,
         -0.0000,  0.0532, -0.0212,  0.0000,  0.0000, -0.0000, -0.2063, -0.0064,
         -0.0480,  0.0069,  0.0000,  0.0000, -0.0157, -0.0385, -0.0000, -0.0506,
          0.0000, -0.0946,  0.0966, -0.0511, -0.0000, -0.2303,  0.2787,  0.1100]],
       device='cuda:0', grad_fn=<MulBackward0>)
Convergence Delta: tensor([7.4506e-08], device='cuda:0')


### Example from Captum's site

In [None]:
input = torch.rand(2, 3)
baseline = torch.zeros(2, 3)

In [None]:
# ToyModel is the one directly from Captum website
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(3, 3)
        self.relu = nn.ReLU()
        self.lin2 = nn.Linear(3, 2)

        # initialize weights and biases
        self.lin1.weight = nn.Parameter(torch.arange(-4.0, 5.0).view(3, 3))
        self.lin1.bias = nn.Parameter(torch.zeros(1,3))
        self.lin2.weight = nn.Parameter(torch.arange(-3.0, 3.0).view(2, 3))
        self.lin2.bias = nn.Parameter(torch.ones(1,2))

    def forward(self, input):
        x = self.lin2(self.relu(self.lin1(input)))
        print(str(x.shape))
        return x

In [None]:
model = ToyModel()
model.eval()
ig = IntegratedGradients(model)
attributions, delta = ig.attribute(input, baseline, target=0, return_convergence_delta=True)
print('IG Attributions:', attributions)
print('Convergence Delta:', delta)

In [None]:
dl = DeepLift(model)
attributions, delta = dl.attribute(input, baseline, target=0, return_convergence_delta=True)
print('DeepLift Attributions:', attributions)
print('Convergence Delta:', delta)

In [None]:
nc = NeuronConductance(model, model.lin1)
attributions = nc.attribute(input, neuron_index=2, target=0)
print('Neuron Attributions:', attributions)