In [1]:
import os,sys
b_directory = os.path.abspath(os.path.join(os.getcwd(), '..', '..', 'model'))
sys.path.insert(0, b_directory)
from repertoire_cls_mul import *
import numpy as np
import torch
from torch import nn
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from torch.utils.data import WeightedRandomSampler

In [2]:
class sx_Dataset(Dataset):
    def __init__(self,data1,data2):
        self.x1 = data1
        self.x2 = data2
        self.len = data1.shape[0]
 
    def __getitem__(self, index):
        return self.x1[index],self.x2[index]
 
    def __len__(self):
        return self.len


def calculate_metrics(label, pred):
    return [accuracy_score(label, pred),precision_score(label, pred, average='weighted', zero_division=0),
           recall_score(label, pred, average='weighted', zero_division=0),f1_score(label, pred, average='weighted', zero_division=0),
           precision_score(label, pred, average='macro', zero_division=0),recall_score(label, pred, average='macro', zero_division=0),
           f1_score(label, pred, average='macro', zero_division=0)]

In [3]:
def Focal_loss(logits, targets,alpha=1, gamma=2):
    loss = torch.nn.functional.cross_entropy(logits, targets, reduction='none')
    pt = torch.exp(-loss)
    loss = (alpha * (1-pt)**gamma * loss).mean()

    return loss

In [4]:
def cal_(train_array,train_labels,seed=3,lr=0.0001,EPOCH=130,BATCH_SIZE=64, device='cuda:0', save_path='./model.pt'):
    torch.manual_seed(seed)
    model=classification_model(tcr_dim=train_array.shape[-1],nums=train_array.shape[-2],class_nums=len(set(train_labels)))
    
    model=model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    labels = np.array(train_labels)
    class_sample_count = np.bincount(labels)  
    weights_per_class = 1. / class_sample_count
    weights = weights_per_class[labels] 
    sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
    
    for epoch in range(EPOCH):
        model.train()
        train_dataloader= DataLoader(dataset=sx_Dataset(train_array,train_labels),batch_size=BATCH_SIZE,sampler=sampler,num_workers=4,drop_last=True)

        for tra_step, (cdr3,label) in enumerate(train_dataloader):
            cdr3=torch.tensor(cdr3,dtype=torch.float32)
            cdr3=cdr3.to(device)
            
            label=torch.tensor(label,dtype=torch.float32)
            label=label.to(device)
            
            pred = model(cdr3)

            
            loss = Focal_loss(pred,label.long())
            
            optimizer.zero_grad()
            loss.requires_grad_(True)
            loss.backward()
            optimizer.step()

    torch.save(model,'./model.pt')

In [5]:
train_array=np.load('../../tmp_data/5/train_emb.npy')
train_labels=np.load('../../tmp_data/5/train_label.npy')
model_save_path='./model.pt'
cuda='cuda:0'

In [6]:
cal_(train_array,train_labels,device=cuda,save_path=model_save_path)

  cdr3=torch.tensor(cdr3,dtype=torch.float32)
  label=torch.tensor(label,dtype=torch.float32)


In [7]:
test_array=np.load('../../tmp_data/5/test_emb.npy')
test_labels=np.load('../../tmp_data/5/test_label.npy')

In [8]:
model=torch.load(model_save_path)

  model=torch.load(model_save_path)


In [9]:
model.eval()

classification_model(
  (get_beta): get_msg(
    (dropout): Dropout(p=0.5, inplace=False)
    (itm_head_1): Linear(in_features=30720, out_features=1024, bias=True)
    (itm_head_2): Linear(in_features=1024, out_features=256, bias=True)
  )
  (self_attention): MultiLayerSelfAttention(
    (attention_layers): ModuleList(
      (0-1): 2 x SelfAttention(
        (query): Linear(in_features=256, out_features=256, bias=True)
        (key): Linear(in_features=256, out_features=256, bias=True)
        (value): Linear(in_features=256, out_features=256, bias=True)
        (fc_out): Linear(in_features=256, out_features=256, bias=True)
      )
    )
  )
  (dense): Sequential(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=256, out_features=20, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=20, out_features=8, bias=True)
  )
)

In [10]:
preds=model(torch.tensor(test_array,dtype=torch.float32).to(cuda))

In [11]:
calculate_metrics(test_labels,np.argmax(preds.detach().cpu().numpy(),axis=1))

[0.5047619047619047,
 0.5444081698121447,
 0.5047619047619047,
 0.4785977160800954,
 0.46846038461688116,
 0.5184157112128982,
 0.46463154558137265]