In [None]:
import numpy as np
import pandas as pd
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from tqdm import tqdm
from sklearn.metrics import average_precision_score, f1_score
# from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, auc
import warnings
import argparse
warnings.filterwarnings('ignore')
import random
import numpy as np
import math
import os


## Model

In [None]:
def set_seed(seed):
    torch.manual_seed(seed)              
    torch.cuda.manual_seed(seed)          
    torch.cuda.manual_seed_all(seed)       
    np.random.seed(seed)                   
    random.seed(seed)                      
    torch.backends.cudnn.deterministic = True  
    torch.backends.cudnn.benchmark = False    

class GOrna(Dataset):
    def __init__(self, data, label_map, baseline_index, model_type):
        common_gene = np.intersect1d(list(data.keys()),list(baseline_index))
        if model_type == 'cellfm':
            self.geneset = {gene:index+1 for index,gene in enumerate(baseline_index)}
        else:
            self.geneset = {gene:index for index,gene in enumerate(baseline_index)}
        self.gene=np.array([self.geneset[gene] for gene in common_gene]).astype(np.int32)
        self.label_size = len(label_map.keys())
        label =[set(data[gene]) for gene in data.keys()]
        self.label = [[label_map[item] for item in sublist] for sublist in label]
    
    def __len__(self):
        return len(self.gene)
    
    def __getitem__(self,idx):
        label_idx = torch.tensor(self.label[idx])
        label = torch.zeros(self.label_size, dtype=torch.int)
        label[label_idx] = 1
        return torch.tensor(self.gene[idx]), label
    
class MLP_GO(nn.Module):
    def __init__(self, gene_emb, label_size, hidden_dim=1028, num_emb_layers=2, dropout=0.2):
        super(MLP_GO, self).__init__()
        
        self.gene_emb = gene_emb
        feature_dim = gene_emb.shape[-1]
        
        # Embedding layers
        self.input_block = nn.Sequential(
                                         nn.LayerNorm(feature_dim, eps=1e-6)
                                        ,nn.Linear(feature_dim, hidden_dim)
                                        ,nn.SiLU()
                                        )

        self.hidden_block = []
        for i in range(num_emb_layers - 1):
            self.hidden_block.extend([
                                      nn.LayerNorm(hidden_dim, eps=1e-6)
                                     ,nn.Dropout(dropout)
                                     ,nn.Linear(hidden_dim, hidden_dim)
                                     ,nn.ReLU()
                                     ])
            if i == num_emb_layers - 2:
                self.hidden_block.extend([nn.LayerNorm(hidden_dim, eps=1e-6)])

        self.hidden_block = nn.Sequential(*self.hidden_block)

        # Output layer
        self.label_size = label_size
        self.output_block = nn.Sequential(
                                         nn.Linear(hidden_dim, 512)
                                         ,nn.Dropout(p=0.2)
                                         ,nn.SiLU()
                                         ,nn.Linear(512, 256)
                                         ,nn.Dropout(p=0.2)
                                         ,nn.SiLU()
                                         ,nn.Linear(256, self.label_size)
                                         )
        # Initialization
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, gene_id):
        h_V = self.gene_emb[gene_id].to(torch.float32)
        h_V = self.input_block(h_V)
        # h_V = self.hidden_block(h_V)
        h_V = self.output_block(h_V) 
        return h_V


## Config

In [None]:
args = {}
args['seed'] = 5
args['task'] = 'MF'
args['epoch'] = 5
args['batch'] = 4
args['interval'] = 0.01
args['model'] = 'cellfm'
args['top'] = 10

class Config:
    def __init__(self, args):
        self.seed = args['seed']
        self.task = args['task']
        self.epoch = args['epoch']
        self.batch = args['batch']
        self.interval = args['interval']
        self.model = args['model']
        self.top = args['top']

args = Config(args)

## Loading Dataset

In [None]:
set_seed(args.seed)
device = 'cuda:0'
task = args.task
interval = args.interval
df_data_train =  pd.read_csv(f'../dataset/{task}/top{args.top}_data/processed_train.csv')
df_data_valid =  pd.read_csv(f'../dataset/{task}/top{args.top}_data/processed_valid.csv')
df_data_test =  pd.read_csv(f'../dataset/{task}/top{args.top}_data/processed_test.csv')

with open(f'../dataset/{task}/top{args.top}_data/func_dict.json') as file:
    label_dict = json.load(file)
if args.model == 'uce':
    key = 'protein'
else:
    key = 'gene'
data_train = df_data_train.groupby(key)['go'].apply(list).to_dict()
data_valid = df_data_valid.groupby(key)['go'].apply(list).to_dict()
data_test = df_data_test.groupby(key)['go'].apply(list).to_dict()
label_size = len(label_dict.keys())

gene_emb = torch.load('../dataset/cellFM_embedding.pt').to(device) 
model_idx = pd.read_csv('../csv/gene_info.csv')['HGNC_gene']

train_set = GOrna(data=data_train, label_map=label_dict, baseline_index=model_idx, model_type=args.model)
valid_set = GOrna(data=data_valid, label_map=label_dict, baseline_index=model_idx, model_type=args.model)
test_set = GOrna(data=data_test, label_map=label_dict, baseline_index=model_idx, model_type=args.model)
print(len(train_set), len(valid_set), len(test_set))
train_loader = DataLoader(train_set, batch_size=1024, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=1024, shuffle=False)
test_loader = DataLoader(test_set, batch_size=1024, shuffle=False)


7749 63 24


## Training

In [None]:

model = MLP_GO(gene_emb, label_size).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=0)
num_epochs = args.epoch

best_valid_loss = float('inf')

best_model_path = f'weights/{task}' #/{args.model}_best_model_{args.seed}.pth
if not os.path.exists(best_model_path):
    os.makedirs(best_model_path)

best_model_path = f'weights/{task}/{args.model}_best_model_{args.seed}.pth'

best_epoch = 0
for epoch in range(num_epochs):
    model.train()
    total_train_loss = 0
    all_labels = []
    all_scores = []
    for batch_idx, (batch_ids, batch_labels) in enumerate(tqdm(train_loader, desc=f'Training Epoch {epoch+1}/{num_epochs}', leave=False)):
        batch_ids, batch_labels = batch_ids.to(device), batch_labels.to(device)
        optimizer.zero_grad()
        outputs = model(batch_ids)
        loss = criterion(outputs, batch_labels.float())
        loss = (loss*batch_labels).sum()/batch_labels.sum() + (loss*(1-batch_labels)).sum()/(1-batch_labels).sum()
        if math.isnan(loss.item()):
            exit()
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

        all_labels.append(batch_labels.cpu().numpy())
        all_scores.append(torch.sigmoid(outputs).cpu().detach().numpy())
        

    avg_train_loss = total_train_loss / len(train_loader)
    all_labels = np.concatenate(all_labels)
    all_scores = np.concatenate(all_scores)
    train_aupr = average_precision_score(all_labels, all_scores)
    
    print('Loss/train', avg_train_loss, epoch)
    print('AUPR/train', train_aupr, epoch)

    model.eval()
    valid_loss = 0
    y_true, y_scores = [], []
    with torch.no_grad():
        for batch_ids, batch_labels in tqdm(test_loader, desc='Validating', leave=False):
            batch_ids, batch_labels = batch_ids.to(device), batch_labels.to(device)
            outputs = model(batch_ids)
            loss = criterion(outputs, batch_labels.float())
            loss = (loss*batch_labels).sum()/batch_labels.sum() + (loss*(1-batch_labels)).sum()/(1-batch_labels).sum()
            valid_loss += loss.item()
            y_true.append(batch_labels.cpu().numpy())
            y_scores.append(torch.sigmoid(outputs).cpu().numpy())


    avg_valid_loss = valid_loss / len(test_loader)
    y_true = np.concatenate(y_true).reshape(-1)
    y_scores = np.concatenate(y_scores).reshape(-1)
    precision, recall, _ = precision_recall_curve(y_true, y_scores)
    aupr = auc(recall, precision)

    
    best_f1 = 0
    thresholds = np.arange(0.0, 1.0, interval)
    for threshold in thresholds:
        predictions = (y_scores >= threshold).astype(int)
        f1 = f1_score(y_true, predictions, average='macro')
        best_f1 = max(best_f1, f1)

    
    print('Loss/valid', avg_valid_loss, epoch)
    print('AUPR/valid', aupr, epoch)
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Avg Train Loss: {avg_train_loss:.4f}, Avg Valid Loss: {avg_valid_loss:.4f}, AUPR: {aupr:.4f}, F max: {best_f1:.4f}')

    if avg_valid_loss < best_valid_loss:
        best_epoch = epoch
        best_valid_loss = avg_valid_loss
        torch.save(model.state_dict(), best_model_path)
        print(f'Best model saved with validation loss: {best_valid_loss:.4f}')



                                                                 

Loss/train 1.285217523574829 0
AUPR/train 0.32630731048345446 0


                                                 

Loss/valid 1.022995948791504 0
AUPR/valid 0.7241943978711585 0
Epoch [1/5], Avg Train Loss: 1.2852, Avg Valid Loss: 1.0230, AUPR: 0.7242, F max: 0.7859
Best model saved with validation loss: 1.0230


                                                                 

Loss/train 1.1033946722745895 1
AUPR/train 0.32999786201010706 1


                                                 

Loss/valid 0.91304612159729 1
AUPR/valid 0.7299729673698993 1
Epoch [2/5], Avg Train Loss: 1.1034, Avg Valid Loss: 0.9130, AUPR: 0.7300, F max: 0.8022
Best model saved with validation loss: 0.9130


                                                                 

Loss/train 1.057364508509636 2
AUPR/train 0.3420036537020204 2


                                                 

Loss/valid 0.9065755009651184 2
AUPR/valid 0.7410746582493499 2
Epoch [3/5], Avg Train Loss: 1.0574, Avg Valid Loss: 0.9066, AUPR: 0.7411, F max: 0.8063
Best model saved with validation loss: 0.9066


                                                                 

Loss/train 1.032627671957016 3
AUPR/train 0.3607050540944967 3


                                                 

Loss/valid 0.901448130607605 3
AUPR/valid 0.7530088524643128 3
Epoch [4/5], Avg Train Loss: 1.0326, Avg Valid Loss: 0.9014, AUPR: 0.7530, F max: 0.8066
Best model saved with validation loss: 0.9014


                                                                 

Loss/train 1.0145306959748268 4
AUPR/train 0.38040314586530555 4


                                                 

Loss/valid 0.8872042894363403 4
AUPR/valid 0.7662916989731818 4
Epoch [5/5], Avg Train Loss: 1.0145, Avg Valid Loss: 0.8872, AUPR: 0.7663, F max: 0.8215
Best model saved with validation loss: 0.8872


In [None]:
model.load_state_dict(torch.load(best_model_path))
model.eval()
test_loss = 0
y_true, y_scores = [], []

with torch.no_grad():
    for batch_ids, batch_labels in tqdm(test_loader, desc='Testing', leave=False):
        batch_ids, batch_labels = batch_ids.to(device), batch_labels.to(device)
        outputs = model(batch_ids)
        loss = criterion(outputs, batch_labels.float())
        test_loss += loss.item()
        loss = (loss*batch_labels).sum()/batch_labels.sum() + (loss*(1-batch_labels)).sum()/(1-batch_labels).sum()
        y_true.append(batch_labels.cpu().numpy())
        y_scores.append(torch.sigmoid(outputs).cpu().numpy())

y_true = np.concatenate(y_true).reshape(-1)
y_scores = np.concatenate(y_scores).reshape(-1)

precision, recall, _ = precision_recall_curve(y_true, y_scores)
test_aupr = auc(recall, precision)

test_fmax = 0
thresholds = np.arange(0.0, 1.0, interval)
for threshold in thresholds:
    predictions = (y_scores >= threshold).astype(int)
    f1 = f1_score(y_true, predictions, average='macro')
    test_fmax = max(test_fmax, f1)
print(f'best epoch: {best_epoch}, best valid loss: {best_valid_loss:.4f}, Test Loss: {test_loss/len(test_loader):.4f}, Test AUPR: {test_aupr:.4f}, Test Fmax: {test_fmax:.4f}')


                                              



best epoch: 4, best valid loss: 0.8872, Test Loss: 0.4436, Test AUPR: 0.7663, Test Fmax: 0.8215
