In [1]:
import os,sys
b_directory = os.path.abspath(os.path.join(os.getcwd(), '..', '..', 'model'))
sys.path.insert(0, b_directory)
from pan_epitope_double import *
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import (
    average_precision_score,
    roc_auc_score
)

In [2]:
class sx_Dataset(Dataset):
    def __init__(self,data1,data2,data3,data4):
        self.x1 = data1
        self.x2 = data2
        self.x3 = data3
        self.x4 = data4
        self.len = data1.shape[0]
 
    def __getitem__(self, index):
        return self.x1[index],self.x2[index],self.x3[index],self.x4[index]
 
    def __len__(self):
        return self.len

In [3]:
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) 

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [4]:
def cal_(beta_train_emb,alpha_train_emb,ep_train_emb,train_labels,seed=1,lr=0.0001,EPOCH=100,BATCH_SIZE=16, device='cuda:1',save_path='./model.pt'):
    
    set_seed(seed)
    dataset = sx_Dataset(beta_train_emb,alpha_train_emb,ep_train_emb,train_labels)
    train_dataloader=DataLoader(dataset=dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=4,drop_last=True)

    model=classification_model(tcr_dim=beta_train_emb.shape[-1], pep_dim=ep_train_emb.shape[-1], 
                               )
    model=model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    for epoch in range(EPOCH):
        model.train()       

        for tra_step, (btr,atr,pep,tl) in enumerate(train_dataloader):   
            pep=torch.tensor(pep,dtype=torch.float32).to(device) 
            btr=torch.tensor(btr,dtype=torch.float32).to(device)    
            atr=torch.tensor(atr,dtype=torch.float32).to(device) 
            tl=torch.tensor(tl,dtype=torch.float32).to(device)

            pred = model(btr,atr,pep).flatten()  
            loss = F.binary_cross_entropy(pred,tl)

            optimizer.zero_grad()
            loss.requires_grad_(True)
            loss.backward()
            optimizer.step()
            
    torch.save(model,save_path)

In [5]:
beta_train_emb=np.load('../../tmp_data/3/beta_train_emb.npy')
alpha_train_emb=np.load('../../tmp_data/3/alpha_train_emb.npy')
ep_train_emb=np.load('../../tmp_data/3/ep_train_emb.npy')
train_labels=pd.read_csv('../../tmp_data/3/train.csv')['Target'].to_numpy()
model_save_path='./model.pt'
cuda='cuda:0'

In [6]:
cal_(beta_train_emb,alpha_train_emb,ep_train_emb,train_labels,device=cuda,save_path=model_save_path)

  pep=torch.tensor(pep,dtype=torch.float32).to(device)
  btr=torch.tensor(btr,dtype=torch.float32).to(device)
  atr=torch.tensor(atr,dtype=torch.float32).to(device)
  tl=torch.tensor(tl,dtype=torch.float32).to(device)


In [7]:
beta_test_emb=np.load('../../tmp_data/3/beta_test_emb.npy')
alpha_test_emb=np.load('../../tmp_data/3/alpha_test_emb.npy')
ep_test_emb=np.load('../../tmp_data/3/ep_test_emb.npy')
test_labels=pd.read_csv('../../tmp_data/3/test.csv')['Target'].to_numpy()

In [8]:
model=torch.load(model_save_path)

  model=torch.load(model_save_path)


In [9]:
model.eval()

classification_model(
  (cdr3_beta_linear): Linear(in_features=1024, out_features=256, bias=True)
  (cdr3_alpha_linear): Linear(in_features=1024, out_features=256, bias=True)
  (pep_linear): Linear(in_features=1024, out_features=256, bias=True)
  (gate_conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
  (sigmoid): Sigmoid()
  (inter_layers): ModuleList(
    (0-1): 2 x Sequential(
      (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (seqlevel_outlyer): Sequential(
    (0): AdaptiveMaxPool2d(output_size=1)
    (1): Flatten(start_dim=1, end_dim=-1)
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=256, out_features=1, bias=True)
    (4): Sigmoid()
  )
)

In [10]:
preds=model(torch.tensor(beta_test_emb,dtype=torch.float32).to(cuda),
            torch.tensor(alpha_test_emb,dtype=torch.float32).to(cuda),
            torch.tensor(ep_test_emb,dtype=torch.float32).to(cuda),
           )

In [11]:
roc_auc_score(test_labels,preds.detach().cpu().numpy())

0.96