<a href="https://colab.research.google.com/github/francescopatane96/protein-xAI/blob/main/xAI_captum_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
import torch.nn.utils.rnn as rnn_utils
import os
import time
from sklearn.metrics import auc, roc_curve, average_precision_score, precision_recall_curve
from termcolor import colored
import pdb


os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [3]:
def generate_data(file):
    # Amino acid dictionary
    aa_dict = {'A': 1, 'R': 2, 'N': 3, 'D': 4, 'C': 5, 'Q': 6, 'E': 7, 'G': 8, 'H': 9, 'I': 10,
               'L': 11, 'K': 12, 'M': 13, 'F': 14, 'P': 15, 'O': 16, 'S': 17, 'U': 18, 'T': 19,
               'W': 20, 'Y': 21, 'V': 22, 'X': 23}
    # open csv file 
    with open(file, 'r') as inf:
        lines = inf.read().splitlines()

    pep_codes = []
    labels = []
    peps = []
    
    for pep in lines:                           # for every row
        pep, label = pep.split(",")             # sequence and label split
        peps.append(pep)
        labels.append(int(label))
        current_pep = []
        for aa in pep:
            current_pep.append(aa_dict[aa])
        pep_codes.append(torch.tensor(current_pep))

        

    data = rnn_utils.pad_sequence(pep_codes, batch_first=True)  # Fill the sequence to the same length
  

    return data, torch.tensor(labels)

In [4]:
data, label = generate_data("./SSP_dataset.csv")

train_data, train_label= data[:1894], label[:1894]   #I primi 1894 sono usati per il training
test_data, test_label = data[1894:], label[1894:]

train_dataset = Data.TensorDataset(train_data, train_label)
test_dataset = Data.TensorDataset(test_data, test_label)

batch_size = 64
train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_iter = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [7]:
# lunghezza delle proteine x fase di inferenza

len(data[45])


299

In [207]:
class xAInet(nn.Module):
  def __init__(self):
    super().__init__()
    
    self.hidden_dim = 25
    self.batch_size = 32
    self.embedding_dim = 512

    self.embedding_layer = nn.Embedding(24, self.embedding_dim, padding_idx=0)
    self.encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)

    self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=1)
    self.gru = nn.GRU(self.embedding_dim, self.hidden_dim, num_layers=2,
                      bidirectional=True, dropout=.2)
    
    self.block_seq = nn.Sequential(nn.Linear(15050, 2048),
                                   nn.BatchNorm1d(2048),
                                   nn.LeakyReLU(),
                                   nn.Linear(2048, 1024),
                                   nn.BatchNorm1d(1024),
                                   nn.LeakyReLU(),
                                   nn.Linear(1024, 256),
                                   nn.BatchNorm1d(256),
                                   nn.ReLU(),
                                   nn.Linear(256, 8),
                                   nn.Linear(8, 2))
    
    
  def forward(self, seq):
        seq = seq.long()
        embeddings = self.embedding_layer(seq)
        output = self.transformer_encoder(embeddings).permute(1, 0, 2)
        output, hn = self.gru(output)
        output = output.permute(1, 0, 2)
        hn = hn.permute(1, 0, 2)
       
        output = output.reshape(output.shape[0], -1)
        hn = hn.reshape(output.shape[0], -1)
        
        output = torch.cat([output, hn], 1)
        output = self.block_seq(output)
        
        #output = F.softmax(output, dim=0)

        return output

  def train_model(self, seq):
    #with torch.no_grad():
        output = self.forward(seq)

        return output

In [176]:
class ContrastiveLoss(torch.nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +
                                      label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))

        return loss_contrastive

In [177]:
def collate(batch):
    seq1_ls = []
    seq2_ls = []
    label1_ls = []
    label2_ls = []
    label_ls = []


    batch_size = len(batch)
    for i in range(int(batch_size / 2)):
        seq1, label1= batch[i][0], batch[i][1]
        seq2, label2= batch[i + int(batch_size / 2)][0], \
                                       batch[i + int(batch_size / 2)][1], \
                                       
        label1_ls.append(label1.unsqueeze(0))
        label2_ls.append(label2.unsqueeze(0))
        label = (label1 ^ label2)
        seq1_ls.append(seq1.unsqueeze(0))
        seq2_ls.append(seq2.unsqueeze(0))
        label_ls.append(label.unsqueeze(0))

        

    seq1 = torch.cat(seq1_ls).to(device)
    seq2 = torch.cat(seq2_ls).to(device)

    

    label = torch.cat(label_ls).to(device)
    label1 = torch.cat(label1_ls).to(device)
    label2 = torch.cat(label2_ls).to(device)
    return seq1, seq2, label, label1, label2


train_iter_cont = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                              shuffle=True, collate_fn=collate)

device = torch.device("cuda", 0)


def evaluate(data_iter, net):
    pred_prob = []
    label_pred = []
    label_real = []
    for x, y in data_iter:
        x, y = x.to(device), y.to(device)
        outputs = net.train_model(x)
        outputs_cpu = outputs.cpu()
        y_cpu = y.cpu()
        pred_prob_positive = outputs_cpu[:, 1]
        pred_prob = pred_prob + pred_prob_positive.tolist()
        label_pred = label_pred + outputs.argmax(dim=1).tolist()
        label_real = label_real + y_cpu.tolist()
    performance, roc_data, prc_data = caculate_metric(pred_prob, label_pred, label_real)
    return performance, roc_data, prc_data


def caculate_metric(pred_prob, label_pred, label_real):
    test_num = len(label_real)
    tp = 0
    tn = 0
    fp = 0
    fn = 0
    for index in range(test_num):
        if label_real[index] == 1:
            if label_real[index] == label_pred[index]:
                tp = tp + 1
            else:
                fn = fn + 1
        else:
            if label_real[index] == label_pred[index]:
                tn = tn + 1
            else:
                fp = fp + 1

    # Accuracy
    ACC = float(tp + tn) / test_num

    # Sensitivity
    if tp + fn == 0:
        Recall = Sensitivity = 0
    else:
        Recall = Sensitivity = float(tp) / (tp + fn)

    # Specificity
    if tn + fp == 0:
        Specificity = 0
    else:
        Specificity = float(tn) / (tn + fp)

    # MCC
    if (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) == 0:
        MCC = 0
    else:
        MCC = float(tp * tn - fp * fn) / (np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))

    # ROC and AUC
    FPR, TPR, thresholds = roc_curve(label_real, pred_prob, pos_label=1)

    AUC = auc(FPR, TPR)

    # PRC and AP
    precision, recall, thresholds = precision_recall_curve(label_real, pred_prob, pos_label=1)
    AP = average_precision_score(label_real, pred_prob, average='macro', pos_label=1, sample_weight=None)

    performance = [ACC, Sensitivity, Specificity, AUC, MCC]
    roc_data = [FPR, TPR, AUC]
    prc_data = [recall, precision, AP]
    return performance, roc_data, prc_data


def to_log(log):
    with open("./results/ExamPle_Log.log", "a+") as f:
        f.write(log + '\n')

In [178]:
net = xAInet().to(device)
lr = 0.0001
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
criterion = ContrastiveLoss()
criterion_model = nn.CrossEntropyLoss(reduction='sum')
best_acc = 0
EPOCH = 5
for epoch in range(EPOCH):
    loss_ls = []
    loss1_ls = []
    loss2_3_ls = []
    t0 = time.time()
    net.train()
    for seq1, seq2, label, label1, label2 in train_iter_cont:
        output1 = net(seq1)
        output2 = net(seq2)
         
        #pdb.set_trace()
        output3 = net.train_model(seq1)

        output4 = net.train_model(seq2)
        loss1 = criterion(output1, output2, label)
        loss2 = criterion_model(output3, label1)
        loss3 = criterion_model(output4, label2)
        loss = loss1 + loss2 + loss3
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_ls.append(loss.item())
        loss1_ls.append(loss1.item())
        loss2_3_ls.append((loss2 + loss3).item())

    net.eval()
    with torch.no_grad():
        train_performance, train_roc_data, train_prc_data = evaluate(train_iter, net)
        test_performance, test_roc_data, test_prc_data = evaluate(test_iter, net)

    results = f"\nepoch: {epoch + 1}, loss: {np.mean(loss_ls):.5f}, loss1: {np.mean(loss1_ls):.5f}, loss2_3: {np.mean(loss2_3_ls):.5f}\n"
    results += f'train_acc: {train_performance[0]:.4f}, time: {time.time() - t0:.2f}'
    results += '\n' + '=' * 16 + ' Test Performance. Epoch[{}] '.format(epoch + 1) + '=' * 16 \
               + '\n[ACC,\tSE,\t\tSP,\t\tAUC,\tMCC]\n' + '{:.4f},\t{:.4f},\t{:.4f},\t{:.4f},\t{:.4f}'.format(
        test_performance[0], test_performance[1], test_performance[2], test_performance[3],
        test_performance[4]) + '\n' + '=' * 60
    print(results)
    # to_log(results)
    test_acc = test_performance[0]  # test_performance: [ACC, Sensitivity, Specificity, AUC, MCC]
    if test_acc > best_acc:
        best_acc = test_acc
        best_performance = test_performance
        filename = '{}, {}[{:.3f}].pt'.format('ExamPle' + ', epoch[{}]'.format(epoch + 1), 'ACC', best_acc)
        save_path_pt = os.path.join('./Model', filename)
        # torch.save(net.state_dict(), save_path_pt, _use_new_zipfile_serialization=False)
        best_results = '\n' + '=' * 16 + colored(' Best Performance. Epoch[{}] ', 'red').format(epoch + 1) + '=' * 16 \
                       + '\n[ACC,\tSE,\t\tSP,\t\tAUC,\tMCC]\n' + '{:.4f},\t{:.4f},\t{:.4f},\t{:.4f},\t{:.4f}'.format(
            best_performance[0], best_performance[1], best_performance[2], best_performance[3],
            best_performance[4]) + '\n' + '=' * 60
        print(best_results)
        best_ROC = test_roc_data
        best_PRC = test_prc_data


epoch: 1, loss: 20.05409, loss1: 0.83335, loss2_3: 19.22074
train_acc: 0.9757, time: 9.27
[ACC,	SE,		SP,		AUC,	MCC]
0.9430,	0.9617,	0.9247,	0.9855,	0.8867

[ACC,	SE,		SP,		AUC,	MCC]
0.9430,	0.9617,	0.9247,	0.9855,	0.8867

epoch: 2, loss: 10.09637, loss1: 0.53051, loss2_3: 9.56586
train_acc: 0.9852, time: 9.29
[ACC,	SE,		SP,		AUC,	MCC]
0.9726,	0.9660,	0.9791,	0.9903,	0.9452

[ACC,	SE,		SP,		AUC,	MCC]
0.9726,	0.9660,	0.9791,	0.9903,	0.9452

epoch: 3, loss: 7.25422, loss1: 0.57480, loss2_3: 6.67942
train_acc: 0.9979, time: 9.36
[ACC,	SE,		SP,		AUC,	MCC]
0.9599,	0.9787,	0.9414,	0.9935,	0.9205

epoch: 4, loss: 4.99697, loss1: 0.55507, loss2_3: 4.44190
train_acc: 0.9984, time: 9.44
[ACC,	SE,		SP,		AUC,	MCC]
0.9641,	0.9830,	0.9456,	0.9938,	0.9290

epoch: 5, loss: 4.09475, loss1: 0.65421, loss2_3: 3.44053
train_acc: 0.9958, time: 9.50
[ACC,	SE,		SP,		AUC,	MCC]
0.9557,	0.9830,	0.9289,	0.9947,	0.9128


In [179]:
torch.save(net, './model.pt')

- Load model

In [208]:
model = torch.load('model.pt', map_location=torch.device('cpu'))

In [209]:
def generate_data(file):
    # Amino acid dictionary
    aa_dict = {'A': 1, 'R': 2, 'N': 3, 'D': 4, 'C': 5, 'Q': 6, 'E': 7, 'G': 8, 'H': 9, 'I': 10,
               'L': 11, 'K': 12, 'M': 13, 'F': 14, 'P': 15, 'O': 16, 'S': 17, 'U': 18, 'T': 19,
               'W': 20, 'Y': 21, 'V': 22, 'X': 23}
    
    with open(file, 'r') as inf:
        lines = inf.read().splitlines()

    pep_codes = []
    labels = []
    peps = []
    
    for pep in lines:
        pep, label = pep.split(",")
        peps.append(pep)
        labels.append(int(label))
        current_pep = []
        for aa in pep:
            current_pep.append(aa_dict[aa])
        pep_codes.append(torch.tensor(current_pep))

        
    desired_length = 299
    padded_sequences = [seq[:desired_length] if len(seq) >= desired_length else torch.cat((seq, torch.zeros(desired_length - len(seq)))) for seq in pep_codes]
# Apply pad_sequence on the padded_sequences
    #data = pad_sequence(padded_sequences, batch_first=True)
    
    data = rnn_utils.pad_sequence(padded_sequences, batch_first=True)  # Fill the sequence to the same length
  

    return data

In [210]:
data = generate_data('SSP_dataset.csv')

In [130]:
data = data.to('cpu')
model = model.to('cpu')

In [211]:
def model_output(inputs):

    #inputs = inputs[0].unsqueeze(0)
  
    out = model(inputs)
    # Apply softmax to convert prediction scores to probabilities
    probabilities = torch.softmax(out, dim=1)
    

    # Get the predicted classes by selecting the class with the highest probability
    predicted_classes = torch.argmax(probabilities, dim=1)  
    return predicted_classes


In [212]:
data = data.long()

In [213]:
model_output(data)

tensor([0, 1, 0, 1, 1, 0, 1])

xAI

In [None]:
pip install captum

In [214]:
import captum
from captum.attr import LayerIntegratedGradients

lig = LayerIntegratedGradients(model_output, model.embedding_layer)

In [215]:
def construct_input_and_baseline(text):

    max_length = 512
    #baseline_token_id = rnn_utils.pad_sequence()
    

    input_ids = []
    token_list = []
    
    aa_dict = {'A': 1, 'R': 2, 'N': 3, 'D': 4, 'C': 5, 'Q': 6, 'E': 7, 'G': 8, 'H': 9, 'I': 10,
               'L': 11, 'K': 12, 'M': 13, 'F': 14, 'P': 15, 'O': 16, 'S': 17, 'U': 18, 'T': 19,
               'W': 20, 'Y': 21, 'V': 22, 'X': 23}
    
    for char in text:
      if char in aa_dict:
        input_ids.append(aa_dict[char])
        token_list.append(char)

    baseline_token_id = 23
    baseline_input_ids = [baseline_token_id] * len(input_ids)

    input_ids_tensor = torch.tensor([input_ids], device='cpu')
    baseline_input_ids_tensor = torch.tensor([baseline_input_ids], device='cpu')

    return input_ids_tensor, baseline_input_ids_tensor, token_list

In [216]:
text = 'MSKSKMLVFKSKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKMSKSKMLVFKMSKSKMLVFKMSKSKMLVFKMSKSKMLVFK'

input_ids, baseline_input_ids, all_tokens = construct_input_and_baseline(text)

print(f'original text: {input_ids}')
print(f'baseline text: {baseline_input_ids}')
print(f'all tokens: {all_tokens}')


original text: tensor([[13, 17, 12, 17, 12, 13, 11, 22, 14, 12, 17, 12, 12, 12, 12, 12, 12, 12,
         12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
         12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
         12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, 17, 12, 17, 12, 13, 11, 22,
         14, 12, 13, 17, 12, 17, 12, 13, 11, 22, 14, 12, 13, 17, 12, 17, 12, 13,
         11, 22, 14, 12, 13, 17, 12, 17, 12, 13, 11, 22, 14, 12]])
baseline text: tensor([[23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23,
         23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23,
         23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23,
         23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23,
         23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23,
         23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23]])
all tokens: ['M', 'S', 'K'

In [217]:
desired_length = 299
padded_sequences = [seq[:desired_length] if len(seq) >= desired_length else torch.cat((seq, torch.zeros(desired_length - len(seq)))) for seq in input_ids]
# Apply pad_sequence on the padded_sequences
    #data = pad_sequence(padded_sequences, batch_first=True)
    
input_ids = rnn_utils.pad_sequence(padded_sequences, batch_first=True)

In [218]:
desired_length = 299
padded_sequences = [seq[:desired_length] if len(seq) >= desired_length else torch.cat((seq, torch.zeros(desired_length - len(seq)))) for seq in baseline_input_ids]
# Apply pad_sequence on the padded_sequences
    #data = pad_sequence(padded_sequences, batch_first=True)
    
baseline_input_ids = rnn_utils.pad_sequence(padded_sequences, batch_first=True)

In [72]:
print(f'original text: {input_ids}')
print(f'baseline text: {baseline_input_ids}')
print(f'all tokens: {all_tokens}')


original text: tensor([[13., 17., 12., 17., 12., 13., 11., 22., 14., 12., 17., 12., 12., 12.,
         12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12.,
         12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12.,
         12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12.,
         12., 12., 12., 12., 12., 12., 12., 12., 13., 17., 12., 17., 12., 13.,
         11., 22., 14., 12., 13., 17., 12., 17., 12., 13., 11., 22., 14., 12.,
         13., 17., 12., 17., 12., 13., 11., 22., 14., 12., 13., 17., 12., 17.,
         12., 13., 11., 22., 14., 12.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.

In [99]:
input_ids = input_ids.long()
baseline_input_ids = baseline_input_ids.long()

In [100]:
with torch.no_grad():
  attributions, delta = lig.attribute(inputs= input_ids,
                                    baselines= baseline_input_ids,
                                    return_convergence_delta=True,
                                    internal_batch_size=1
                                    )

RuntimeError: ignored

In [101]:
from captum.attr import IntegratedGradients

input_ids = input_ids.long()
baseline_input_ids = baseline_input_ids.long()
target=0


ig = IntegratedGradients(model, model.embedding_layer)

# Calcola le attribuzioni degli input_ids utilizzando gli Integrated Gradients
attributions, delta = ig.attribute(input_id, targets)

# Stampa le attribuzioni ottenute
print("Attribuzioni degli input_ids:", attributions)

# Stampa la convergenza delta
print("Convergenza delta:", delta)


AssertionError: ignored

In [None]:
! pip install git+https://github.com/francescopatane96/captum.git

In [None]:
! pip install git+https://github.com/francescopatane96/pytorch.git

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/francescopatane96/pytorch.git
  Cloning https://github.com/francescopatane96/pytorch.git to /tmp/pip-req-build-jpn1_rsd
  Running command git clone --filter=blob:none --quiet https://github.com/francescopatane96/pytorch.git /tmp/pip-req-build-jpn1_rsd
  Resolved https://github.com/francescopatane96/pytorch.git to commit f19535a948fa711e9fea7b9da5cd04043074dfc3
  Running command git submodule update --init --recursive -q
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone


In [219]:
from captum.attr import IntegratedGradients

input_ids = input_ids.to(torch.long)
baseline_input_ids = baseline_input_ids.to(torch.long)



ig = IntegratedGradients(model_output, model.embedding_layer)

In [226]:
out = model(input_ids[0:3])

In [227]:
out

tensor([[ 3.4333, -0.5113]], grad_fn=<AddmmBackward0>)

In [186]:
input = torch.tensor(input_ids)
baseline = torch.tensor(baseline_input_ids)

  input = torch.tensor(input_ids)
  baseline = torch.tensor(baseline_input_ids)


In [None]:
ig = IntegratedGradients(model)


attribution = ig.attribute(inputs = input, baselines=baseline, target=0)           #(, baselines = baseline, target=0)