<a href="https://colab.research.google.com/github/hereagain-Y/TCR_VAE/blob/main/Attention_cpu_simulation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import random
import itertools
import numpy as np

In [2]:
def getSampleIndex(sample):
    """Group row positions by sample (bag) name.

    Parameters
    ----------
    sample : sequence of hashable
        Bag name for each row, e.g. ``['a', 'a', 'b', ...]``.

    Returns
    -------
    sample_list : list
        Distinct names, in order of first appearance in ``sample``.
    sample_index : list of numpy.ndarray
        ``sample_index[k]`` holds the row positions (ascending) whose name
        is ``sample_list[k]``.
    """
    # dict.fromkeys keeps first-appearance order. list(set(...)) did not, and
    # for str names its order changes between interpreter runs (hash
    # randomization), so the bag order could not be matched to a label list.
    sample_list = list(dict.fromkeys(sample))
    sample_array = np.array(sample)
    sample_index = [np.where(sample_array == s)[0] for s in sample_list]
    return sample_list, sample_index

In [4]:
class Attention(nn.Module):
    """Attention-pooling classifier over bags of instance vectors.

    Each instance row gets a scalar score from a linear layer followed by
    tanh. The score is multiplied by the row's count, soft-maxed within the
    row's bag, and used to take a weighted sum of that bag's rows. A linear
    layer followed by a sigmoid turns each pooled bag vector into a
    probability.
    """

    def __init__(self):
        super().__init__()
        self.L = 16  # width of each instance vector (input size of both linear layers)
        self.D = 8   # NOTE(review): not used anywhere in this class
        self.K = 1   # number of attention scores produced per instance

        # Instance scorer: L features -> K score(s), squashed into (-1, 1).
        self.attention = nn.Sequential(
            nn.Linear(self.L, self.K),
            nn.Tanh(),
        )

        # Bag classifier: pooled vector (L * K inputs; K == 1 here) -> probability.
        self.classifier = nn.Sequential(
            nn.Linear(self.L * self.K, 1),
            nn.Sigmoid(),
        )

    def forward(self, sequence, sample, count):
        """Score instances, pool them per bag, and classify each bag.

        Args:
            sequence: float tensor of instance rows whose last dim is
                ``self.L`` (presumably shape (N, L) -- TODO confirm).
            sample: per-row bag names, grouped with ``getSampleIndex``.
            count: per-row multipliers; reshaped to a column and multiplied
                into the scores, so it must broadcast against (N, 1).

        Returns:
            predictions: (n_bags, 1) sigmoid probabilities, one row per bag
                in ``sample_list`` order.
            Y_hat: ``predictions >= 0.5`` as 0./1. floats.
            sample_list: bag names, in the order used for the rows above.
            newA: copy of the count-scaled scores in which every row that
                belongs to a bag is replaced by its within-bag softmax weight.
        """
        # Raw per-instance scores, weighted by each instance's count.
        scores = self.attention(sequence) * count.reshape(-1, 1)
        normalized = torch.clone(scores)

        bag_names, bag_rows = getSampleIndex(sample)
        pooled = []
        for rows in bag_rows:
            # Softmax is taken only over the rows of this bag.
            weights = torch.softmax(scores[rows], dim=0)
            pooled.append((sequence[rows] * weights).sum(dim=0))
            normalized[rows] = weights

        bag_features = torch.stack(pooled, dim=0)
        probs = self.classifier(bag_features)
        hard_labels = torch.ge(probs, 0.5).float()
        return probs, hard_labels, bag_names, normalized

In [15]:
def loss_function(label, prediction, reduction='sum'):
    """Binary cross-entropy between predicted probabilities and targets.

    NOTE: the parameter names do not match their roles. ``label`` is passed
    as BCE's ``input`` (predicted probabilities in [0, 1]) and ``prediction``
    as its ``target``. The training loop calls
    ``loss_function(predictions, label_true)``, which is the correct order
    for that mapping. The names are kept for backward compatibility with
    keyword callers.

    Args:
        label: predicted probabilities (same shape as ``prediction``).
        prediction: target values in [0, 1].
        reduction: ``'sum'`` (default, previous behavior), ``'mean'`` or
            ``'none'``; forwarded to ``binary_cross_entropy``.

    Returns:
        The BCE loss tensor (a scalar unless ``reduction='none'``).
    """
    return nn.functional.binary_cross_entropy(label, prediction, reduction=reduction)
       
def calculate_classification_error(Y_hat,Y):
    """Return the fraction of mismatched predictions, plus ``Y_hat`` itself.

    Args:
        Y_hat: predicted 0./1. labels.
        Y: true labels (any numeric dtype; cast to float before comparing).

    Returns:
        (error, Y_hat): ``error`` is a 0-dim CPU float tensor equal to
        1 - accuracy, detached from autograd; ``Y_hat`` is returned unchanged.
    """
    matches = Y_hat.eq(Y.float())
    accuracy = matches.cpu().float().mean()
    return 1. - accuracy.detach(), Y_hat

In [6]:
# Device selection. NOTE(review): neither the model nor the data tensors
# are ever moved to DEVICE in this notebook.
cuda = False
DEVICE = torch.device("cuda") if cuda else torch.device("cpu")
n_feature = 16  # NOTE(review): unused here; Attention hard-codes L = 16
model = Attention()

In [11]:

def gen_multi_list(amount, length, seed=None):
    """Return ``amount`` 1-D arrays of ``length`` uniform floats in [0, 1).

    Args:
        amount: number of arrays to generate.
        length: number of values in each array.
        seed: forwarded to ``numpy.random.default_rng``. ``None`` (default,
            previous behavior) draws fresh OS entropy, so results differ on
            every call; pass an int for a reproducible simulation.

    Returns:
        A list of ``amount`` float64 arrays of shape ``(length,)``.
    """
    rng = np.random.default_rng(seed)
    return [rng.random(length) for _ in range(amount)]

# Simulated data: len(A) bags ("samples") of n_per_sample instance rows each.
# The number of rows is derived from the bag layout so that every row in
# `sequences`/`counts` belongs to exactly one bag in `samples`. Previously
# 1000 rows were generated but only the first 100 were assigned to a bag.
A = ['a','b','c','d','e','f','g','h','i','j']
n_per_sample = 10
n_sequences = len(A) * n_per_sample

# n_sequences rows of 16 uniform [0, 1) features.
sequences = gen_multi_list(n_sequences, 16)
sequences = torch.from_numpy(np.array(sequences)).float()

# Per-row integer multiplicity in [1, 5], passed to the model as `count`.
counts = [random.randint(1, 5) for _ in range(n_sequences)]
counts = torch.from_numpy(np.array(counts)).float()

# Bag name per row: 'a' x n_per_sample, then 'b' x n_per_sample, ...
samples = list(itertools.chain.from_iterable(itertools.repeat(x, n_per_sample) for x in A))


In [12]:

# One random binary label per bag, in the order of A.
label = [random.randint(0, 1) for _ in range(len(A))]

# Bag name -> label. Built from A directly so it stays correct whatever the
# number of instance rows per bag.
sample_label_map = dict(zip(A, label))

# Per-row labels, aligned element-for-element with `samples`.
labels = np.array([sample_label_map[s] for s in samples])

# Bag labels as a (n_bags, 1) float tensor, in the order of A.
label = torch.from_numpy(np.array(label).reshape((len(label), 1))).float()


In [13]:

from torch.optim import Adam
from sklearn.metrics import accuracy_score

# Adam over every trainable parameter of the model, learning rate 1e-3.
optimizer = Adam(model.parameters(), lr=1e-3)

In [16]:
# Full-batch training: every epoch runs all bags through the model once.
epochs = 10000
for ite in range(epochs):
    overall_loss = 0
    train_error = 0

    # Clear gradients from the previous step. Without this, loss.backward()
    # adds into the existing .grad buffers, so every Adam step would use the
    # running sum of all earlier epochs' gradients.
    optimizer.zero_grad()

    predictions, Y_hat, sample_list, attention_weights = model(sequences, samples, counts)

    # Targets must follow the row order of `predictions`, which is the order
    # of `sample_list` returned by the model. Look each bag's label up by
    # name rather than assuming that order matches the tensor `label`.
    label_true = np.array([sample_label_map[s] for s in sample_list]).reshape(-1, 1)
    label_true = torch.from_numpy(label_true).float()

    # predictions, Y_hat and label_true are all (n_bags, 1).
    loss = loss_function(predictions, label_true)
    overall_loss += loss.item()

    error, predicted_label = calculate_classification_error(Y_hat, label_true)
    train_error += error

    loss.backward()
    optimizer.step()
    if ite % 1000 == 0:
        print('Train Set, Epoch: {}, Loss: {:.4f},Error: {:.4f}, Accuracy: {:.2f}%'.format(ite+1, overall_loss,train_error,accuracy_score(label_true, Y_hat)*100))



Train Set, Epoch: 1, Loss: 6.7557,Error: 0.3000, Accuracy: 70.00%
Train Set, Epoch: 1001, Loss: 1.6365,Error: 0.0000, Accuracy: 100.00%
Train Set, Epoch: 2001, Loss: 0.1769,Error: 0.0000, Accuracy: 100.00%
Train Set, Epoch: 3001, Loss: 0.0240,Error: 0.0000, Accuracy: 100.00%
Train Set, Epoch: 4001, Loss: 0.0066,Error: 0.0000, Accuracy: 100.00%
Train Set, Epoch: 5001, Loss: 0.0024,Error: 0.0000, Accuracy: 100.00%
Train Set, Epoch: 6001, Loss: 0.0010,Error: 0.0000, Accuracy: 100.00%
Train Set, Epoch: 7001, Loss: 0.0005,Error: 0.0000, Accuracy: 100.00%
Train Set, Epoch: 8001, Loss: 0.0003,Error: 0.0000, Accuracy: 100.00%
Train Set, Epoch: 9001, Loss: 0.0002,Error: 0.0000, Accuracy: 100.00%


In [18]:
# Inspect the weights returned by the final forward pass for rows 0-4.
print(attention_weights[0:5])

tensor([[0.0892],
        [0.0121],
        [0.0892],
        [0.0121],
        [0.0328]], grad_fn=<SliceBackward0>)
