In [1]:
# Move the working directory up one level; the relative "data/..." paths used
# later in the notebook are resolved against it.
# NOTE(review): this is not idempotent - re-running the cell climbs one more
# directory each time, so run it once per kernel session.
%cd ..
# Load IPython's autoreload extension; mode 2 reloads all modules before each
# cell executes (presumably so edits to the local package are picked up
# without restarting the kernel).
%load_ext autoreload
%autoreload 2


/home/samuel.assis/MatchImm/MatchImmNet


In [2]:
from pyGeoMatchImm.data.datasets import TCRpMHCDataset, ChannelsPairDataset, ChannelsGraph
from pyGeoMatchImm.train.train import CrossValidator
from pyGeoMatchImm.models.models import (CrossAttentionGCN, 
                                         CrossAttentionGAT, 
                                         CrossAttentionGIN,
                                            MultiGCN,   
                                            MultiGAT,
                                            MultiGIN) 
from pyGeoMatchImm.utils.utils import validate_data
from pyGeoMatchImm.metrics.train_metrics import (plot_precision_recall_curve,
                                                 plot_roc_curve,
                                                 plot_loss_curve_logocv)

import numpy as np
import pandas as pd
import torch
from torch_geometric.loader import DataLoader
from torch.utils.data import TensorDataset
from torch.nn.functional import binary_cross_entropy_with_logits
from torch_geometric.data import Data, Batch
import os
import json
import logging as log

# Logging: install the root configuration at INFO, then pin the level of the
# individual package loggers from a small (name, level) table.
log.basicConfig(level=log.INFO)
for _logger_name, _logger_level in (
    ("pyGeoMatchImm", log.INFO),
    ("MDAnalysis", log.WARNING),
):
    log.getLogger(_logger_name).setLevel(_logger_level)

In [22]:

# Use the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Raw column name -> column name used by the rest of the notebook.
map_cols = {
    'TCR_ID': 'id',
    'TRA_ref': 'TRA',
    'TRB_ref': 'TRB',
    'MHCseq_ref': 'MHCseq',
    'assigned_allele': 'mhc_allele',
    'peptide': 'epitope'
}

# Load Experimental Data
tcr3d_data_path = "data/02-processed/tcr3d_20251004_renamed.csv"
tcr3d_data = (
    pd.read_csv(tcr3d_data_path, index_col=0)
    # Drop the existing TRA/TRB/MHCseq columns first: the rename below gives
    # the *_ref columns those same names.
    .drop(columns=['TRA', 'TRB', 'MHCseq'])
    .rename(columns=map_cols)
    .reset_index(drop=True)
)
print(tcr3d_data.head())

# Load AF Data
af_data_path = "data/01-raw/AF_vdjdb_score3_20251209.csv"
af_data = (
    pd.read_csv(af_data_path)
    .rename(columns=map_cols)
    .reset_index(drop=True)
)
print(af_data.head())

# Remove overlapping epitopes: keep only experimental rows whose epitope does
# not also occur in the AF table.
epitope_in_af = tcr3d_data['epitope'].isin(af_data['epitope'])
tcr3d_data = tcr3d_data[~epitope_in_af]

select_columns = ['id', 'TRA', 'TRB', 'CDR1A', 'CDR2A', 'CDR3A', 'CDR1B', 'CDR2B', 'CDR3B', 'TRA_num', 'TRB_num', 'epitope', 'MHCseq', 'mhc_allele', 'filepath_a', 'filepath_b', 'label', 'source']

# Stack both sources, require both structure file paths, and keep one row per
# (TRA, TRB, epitope, MHCseq) combination.
combined_data = (
    pd.concat([tcr3d_data, af_data], ignore_index=True)
    .dropna(subset=['filepath_a', 'filepath_b'], how='any')
    .drop_duplicates(subset=["TRA", "TRB", "epitope", "MHCseq"])
)
# NOTE(review): only the first 30 rows are kept and written out - this looks
# like a debugging subsample; confirm before a full training run.
train_data = combined_data[select_columns].head(30).copy()
train_data.to_csv("data/02-processed/tcrpMHC_combined_train_data.csv", index=False)


# Hyper-parameters handed to the model (nested into `config` below).
model_params = dict(n_output=1, dropout=0.3, n_layers=2)

# Hyper-parameters handed to the training loop (nested into `config` below).
train_params = dict(
    learning_rate=0.0001,
    # milestones=[1, 5],
    # gamma=0.5,
    num_epochs=60,
    batch_size=16,
    pep_freq_range=[0.005, 0.1],
    k_top_peptides=10,
    weight_decay=0.01,
)

# Top-level run configuration. `train_params` and `model_params` are stored by
# reference, so later edits to either dict are visible through `config` too.
# `save_dir` starts out empty and is filled in once the architecture is chosen.
config = dict(
    source="pdb",
    channels=["TCR", "pMHC"],
    pairing_method="basic",
    embed_method=["atchley"],
    graph_method="graphein",
    negative_prop=1,
    edge_params=["distance_threshold"],
    node_params=["amino_acid_one_hot", "hbond_donors", "hbond_acceptors", "dssp_config"],
    graph_params=["rsa"],
    other_params={"granularity": "centroids"},
    dist_threshold=8.0,
    concat_embed="all",
    train_params=train_params,
    model_params=model_params,
    save_dir='',
)

# Build the dataset from the combined training table on disk.
dataset = TCRpMHCDataset("data/02-processed/tcrpMHC_combined_train_data.csv", config=config)

# Short architecture key -> model class.
models_dict = {
    "cagcn": CrossAttentionGCN,
    "cagat": CrossAttentionGAT,
    "cagin": CrossAttentionGIN,
    "multigcn": MultiGCN,
    "multigat": MultiGAT,
    "multigin": MultiGIN
}

# The first configured embedding method and the chosen architecture drive
# everything below.
embed = config['embed_method'][0]
arch = "cagin"
model_class = models_dict[arch]

log.info(f"Running architecture: {arch}")
log.info(f"Using embedding method: {embed}")

# Atchley embeddings train with norm=True and 16 output channels; any other
# embedding uses norm=False and 128 output channels.
is_atchley = embed == "atchley"
config["train_params"].update(
    norm=is_atchley,
    out_channels=16 if is_atchley else 128,
)

# Output directory name encodes architecture, negative proportion, embedding
# and dropout (as its first decimal digit).
dropout_tag = int(config['model_params']['dropout']*10)
save_dir = f"./{arch}_neg{config['negative_prop']}_{embed}_dp0{dropout_tag}"
config["save_dir"] = save_dir

channels = ChannelsGraph(dataset, config=config, embed_method=embed)
peptides = channels.get_seq_chain(name="pMHC", chain="C")


# Wrap the ids, both channels' graphs and the labels in a ChannelsPairDataset.
pair_dataset_kwargs = dict(
    ids=channels.ids,
    ch1_graphs=channels.ch1,
    ch2_graphs=channels.ch2,
    labels=channels.get_labels(),   # or channels.y_tensor
    ch1_name=channels.channel_names[0],
    ch2_name=channels.channel_names[1],
)
ds = ChannelsPairDataset(**pair_dataset_kwargs)

# Print the first item as a quick sanity check of what the dataset yields.
debug = ds[0]
print(debug)


# Values read back from the graph channels, keyed by the `train_data` column
# they are handed to `validate_data` under.
channel_values = {
    "id": channels.ids,
    "epitope": peptides,
    "label": channels.get_labels(),
    "TRA": channels.get_seq_chain(name="TCR", chain="D"),
    "TRB": channels.get_seq_chain(name="TCR", chain="E")
}
validate_data(
    train_data,
    dict_lists=channel_values,
    cols_to_check=["epitope", "label", "TRA", "TRB"],
)


  class PDBDataset(BaseDataset):
  class ModelDataset(BaseDataset):
INFO:pyGeoMatchImm.data.datasets:Creating Dataset...

INFO:pyGeoMatchImm.data.datasets:Number of samples: 30
INFO:pyGeoMatchImm.data.datasets:PDBDataset initialized
INFO:pyGeoMatchImm.data.datasets:
Creating dataset...

INFO:pyGeoMatchImm.data.datasets:Number of samples: 30
INFO:pyGeoMatchImm.data.datasets:Annotations:                                       {'TCR': {'TRA': 'D', 'TRB': 'E'}, 'pMHC': {'epitope': 'C', 'MHCseq': 'A'}}                      {'TCR': 'filepath_a', 'pMHC': 'filepath_b'}                           {'A': True, 'C': False, 'D': True, 'E': True}
INFO:pyGeoMatchImm.data.datasets:
Building Pairing samples based on channels...

INFO:pyGeoMatchImm.data.datasets:
Parsing structures...

INFO:pyGeoMatchImm.utils.utils:
Run function: <pyGeoMatchImm.data.parser.StructureParser object at 0x78eac077c490>
INFO:pyGeoMatchImm.utils.utils:Using 25 CPU cores for multiprocessing.



Using device: cuda
        id      CDR1A          CDR2A           CDR3A    CDR1B  \
0  PDB8shi    YKTSINN   IRSNEREKHSGR    DALYSGGGADGL  QDMNHEY   
1  PDB4n5e   YSYSATPY  YYSGDPVVQGVNG       SAKGTGSKL  QTNSHNY   
2  PDB3kxf  YETRDTTYY  RNSFDEQNEISGR       SGFYNTDKL  QDMNHNS   
3  PDB8gvg   YSYGATPY  YFSGDTLVQGIKG       GFTGGGNKL  PISEHNR   
4  PDB4qrr  YETSWWSYY   QGSDEQNAKSGR  GELAGAGGTSYGKL  PISGHRS   

              CDR2B       CDR3B  \
0   SVGAGITDQGEVPNG   SYSEGEDEA   
1   SYGAGNLQIGDVPDG     SDAPGQL   
2   SASEGTTDKGEVPNG   PGLAGEYEQ   
3  FQNEAQLEKSRLLSDR  SDRDRVPETQ   
4   YFSETQRNKGNFPGR  SLEGGYYNEQ   

                                                 TRA  \
0  SQQGEEDPQALSIQEGENATMNCSYKTSINNLQWYRQNSGRGLVHL...   
1  AQSVTQPDARVTVSEGASLQLRCKYSYSATPYLFWYVQYPRQGLQM...   
2  QKVTQAQTEISVVEDEDVTLDCVYETRDTTYYLFWYKQPPSGELVF...   
3  AQSVTQPDIHITVSEGASLELRCNYSYGATPYLFWYVQSPGQGLQL...   
4  KVTQAQSSVSMPVRKAVTLNCLYETSWWSYYIFWYKQLPSKEMIFL...   

                                          

INFO:pyGeoMatchImm.data.datasets:
Generating graphs using graphein...

INFO:pyGeoMatchImm.utils.utils:
Run function: <pyGeoMatchImm.data.generators.graphein.GrapheinGeneratorRes object at 0x78eac077d030>
INFO:pyGeoMatchImm.utils.utils:Using 25 CPU cores for multiprocessing.

INFO:pyGeoMatchImm.data.datasets:
Generate negatives

INFO:pyGeoMatchImm.utils.negatives:Preparating data for chains Peptide chain C and  MHC chain A
INFO:pyGeoMatchImm.utils.negatives:Data: 30 positive samples.
INFO:pyGeoMatchImm.data.datasets:
Combining positives and negatives...

INFO:pyGeoMatchImm.data.datasets:Embedding sequences using Atchley factors...
INFO:pyGeoMatchImm.data.datasets:IDs: ['PDB4qrr_PDB8i5d', 'PDB6am5', 'PDB7jwi_PDB6rpa', 'PDB5nme', 'PDB4g8g', 'PDB3tjh', 'PDB7n5p_PDB7nmf', 'PDB6rpa_PDB3kxf', 'PDB5til_PDB6vmx', 'PDB7jwi', 'PDB5c0b', 'PDB4jfe_PDB7n4k', 'PDB2nx5_PDB8ryn', 'PDB8rym_PDB2e7l', 'PDB8i5d_PDB8rym', 'PDB8rym_PDB5nme', 'PDB8gvg', 'PDB4jfd', 'PDB2e7l_PDB6vrm', 'PDB4qrr', 'PDB4mxq', 'PDB

HeteroData(
  id='PDB4qrr_PDB8i5d',
  y=[1, 1],
  TCR={
    x=[33, 5],
    resid=[33],
    resname=[33],
    attmask=[33, 46],
  },
  pMHC={
    x=[46, 5],
    resid=[46],
    resname=[46],
    attmask=[46, 33],
  },
  (TCR, intra, TCR)={ edge_index=[2, 132] },
  (pMHC, intra, pMHC)={ edge_index=[2, 202] }
)


'Validation passed!'

In [None]:
# List all sample ids, then hand the first one to `dataset.show_graph`.
print(channels.ids)
first_sample_id = channels.ids[0]
dataset.show_graph(first_sample_id)

In [None]:

# Cross-validate the selected architecture, persist the run configuration and
# the diagnostic plots, then release the large objects.
cv = CrossValidator(model_class=model_class,
                    dataset=ds,
                    device=device,
                    peptides=peptides,
                    configs=config)

cv.run()

# Make sure the output directory exists (no-op when it is already there) and
# write the config through a context manager so the file handle is closed
# deterministically instead of being left to the garbage collector.
os.makedirs(save_dir, exist_ok=True)
with open(os.path.join(save_dir, "config.json"), "w") as config_file:
    json.dump(config, config_file, indent=4)

plot_precision_recall_curve(cv.raw_results['labels'],
                            cv.raw_results['predictions'],
                            save_dir)

plot_roc_curve(cv.raw_results['labels'],
                cv.raw_results['predictions'],
                save_dir)

plot_loss_curve_logocv(cv.losses, save_dir=save_dir)

# Clean cache, memory release.
# Drop the references first: the caching allocator can only hand back blocks
# that no live tensor is using, so emptying the cache before the `del`
# statements would not free the memory held by these objects.
# NOTE: after this point `channels`, `ds` and `cv` no longer exist; re-run the
# cells that create them before running this cell again.
del channels
del ds
del cv

import gc  # local import: only needed for this cleanup step
gc.collect()  # also collect objects kept alive only by reference cycles

# Reset pytorch.
torch.cuda.empty_cache()
# Peak-memory statistics only exist on a CUDA device; skip the reset on the
# CPU fallback, where the call would raise.
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()


In [29]:
import torch
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import torch.nn as nn

# Toy cross-attention example: a batch of two variable-length target sequences
# attends over a batch of two variable-length source sequences.
embed_size = 4
seqs_tgt = [torch.randn(6, embed_size), torch.randn(3, embed_size)]  # lengths: 6, 3
seqs_src = [torch.randn(8, embed_size), torch.randn(5, embed_size)]  # lengths: 8, 5

# Zero-pad each batch to the length of its longest sequence.
target_seq = pad_sequence(seqs_tgt, batch_first=True)   # (N, tgt_len, embed_size)
source_seq = pad_sequence(seqs_src, batch_first=True)   # (N, src_len, embed_size)

# Key-padding mask, shape (N, src_len): True marks padded source positions.
# Position j of sequence i is padding exactly when j >= src_lens[i]; src_len is
# the padded maximum, so every length is <= src_len and no guard is needed.
src_lens = [s.size(0) for s in seqs_src]
src_len = source_seq.size(1)
key_padding_mask = torch.arange(src_len).unsqueeze(0) >= torch.tensor(src_lens).unsqueeze(1)
print(key_padding_mask)

# Attention mask, shape (tgt_len, src_len): True means "may not attend".
attention_mask = torch.zeros(target_seq.size(1), source_seq.size(1), dtype=torch.bool)
# e.g., mask out even positions in source sequence
attention_mask[:, 0:source_seq.size(1):2] = True

print("Target sequence shape:", target_seq.shape)  # (N, tgt_len, embed_size)
print("Source sequence shape:", source_seq.shape)  # (N, src_len, embed_size)

# Define Q, K, V: queries come from the target, keys and values from the source.
Q = nn.Linear(embed_size, embed_size)(target_seq)
K = nn.Linear(embed_size, embed_size)(source_seq)
V = nn.Linear(embed_size, embed_size)(source_seq)

# Cross-attention
print(Q)
cross_attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=2, batch_first=True)
output, attention_weights = cross_attention(Q, K, V, key_padding_mask=key_padding_mask, attn_mask=attention_mask)
print("Output shape:", output.shape)  # Should be (N, tgt_len, embed_size)
print(output)
print(attention_weights)



tensor([[False, False, False, False, False, False, False, False],
        [False, False, False, False, False,  True,  True,  True]])
Target sequence shape: torch.Size([2, 6, 4])
Source sequence shape: torch.Size([2, 8, 4])
tensor([[[ 0.1018,  0.8960, -0.9340,  0.3432],
         [ 0.0162,  0.0803, -0.7547, -0.1137],
         [ 0.0714, -0.7860, -0.1212,  0.2724],
         [ 0.3702, -0.4750, -0.2246, -0.0122],
         [ 0.5270, -0.1421, -0.2989,  0.1006],
         [ 0.2423, -1.0428, -0.1593, -0.3020]],

        [[ 0.6596,  0.2517, -0.3194,  0.3779],
         [ 0.6394,  0.7453, -0.3424,  0.9749],
         [-0.3149, -1.8017, -0.0607, -0.4215],
         [ 0.1360, -0.4031, -0.3952,  0.0157],
         [ 0.1360, -0.4031, -0.3952,  0.0157],
         [ 0.1360, -0.4031, -0.3952,  0.0157]]], grad_fn=<ViewBackward0>)
Output shape: torch.Size([2, 6, 4])
tensor([[[ 0.0759,  0.0064,  0.0395, -0.0513],
         [ 0.0823,  0.0064,  0.0369, -0.0478],
         [ 0.1083,  0.0032,  0.0184, -0.0235],
       

In [25]:
import torch
import torch.nn as nn

class test(nn.Module):
    """Minimal wrapper around ``nn.TransformerEncoder`` for trying out masks.

    Args:
        d_model: Feature size of the encoder input/output. Defaults to 16.
        nhead: Number of attention heads per encoder layer. Defaults to 8.
        num_layers: Number of stacked encoder layers. Defaults to 1.

    The defaults reproduce the previously hard-coded configuration, so
    ``test()`` builds the same module as before. The encoder layer is created
    with ``batch_first=True``.
    """

    def __init__(self, d_model=16, nhead=8, num_layers=1):
        super().__init__()
        enc_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.layer = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

    def forward(self, x, src_mask, key_mask):
        """Run the encoder stack on ``x``.

        Args:
            x: Input tensor; with ``batch_first=True`` the batch dimension
                comes first.
            src_mask: Forwarded unchanged as the encoder's ``mask`` argument.
            key_mask: Forwarded unchanged as the encoder's
                ``src_key_padding_mask`` argument.

        Returns:
            Whatever the wrapped ``nn.TransformerEncoder`` returns for these
            inputs.
        """
        return self.layer(x, mask=src_mask, src_key_padding_mask=key_mask)

# Smoke-run the wrapper once on random input, in eval mode.
mod = test()
mod.eval()
# NOTE(review): both masks are float tensors of ones, not boolean masks. Per
# the PyTorch docs a float mask is added to the attention scores rather than
# excluding positions - confirm that this is what the experiment intends.
demo_inputs = dict(
    x=torch.randn(2, 22, 16),
    src_mask=torch.ones(8*2, 22, 22),
    key_mask=torch.ones(2, 22),
)
out = mod(**demo_inputs)
print(out)

tensor([[[-0.9416,  1.1450,  0.1306,  0.7909, -0.2337, -0.7684, -0.0434,
          -1.6090,  1.7747, -1.2908, -0.0066, -0.6836, -0.6529,  1.9001,
          -0.2477,  0.7365],
         [-1.5972, -0.7078,  1.1439,  0.2116, -0.4470, -0.3343, -1.7567,
          -0.4159, -0.5748, -0.6892,  0.8963, -0.0448,  1.9548,  0.7934,
           1.3895,  0.1782],
         [ 0.9891,  0.6050, -0.0239,  0.6277, -1.1375, -0.2669,  1.9786,
          -1.0510, -0.8519, -1.3660, -1.7038, -0.2443,  0.7395, -0.0208,
           0.4858,  1.2405],
         [-1.6902,  0.3960, -0.6675,  0.5224,  0.4621,  1.8230,  0.5850,
           0.4541,  0.8629, -0.0564, -0.8281, -0.4918, -2.0284, -0.4689,
           1.4402, -0.3143],
         [-0.0819,  1.9674, -0.2708, -0.0950, -0.1744,  0.8236, -0.8234,
           1.1195, -0.8682, -0.2539,  1.0151, -1.2534, -1.9518,  0.1913,
          -0.6517,  1.3076],
         [-1.1189, -0.5485,  0.5204, -1.7007,  0.1578, -0.0831,  0.9369,
           0.4599, -0.1658,  1.2799,  0.6909,  1.810