# In [None]:
import argparse
import logging
import os
import sys
import json
import pandas as pd
from Bio.PDB import PDBParser, is_aa
from tqdm import tqdm

sys.path.append(os.path.abspath('../'))

from network import MMPeptide, SEQPeptide, VoxPeptide, MMFPeptide
from swinunetr import SwinUNETR
from ThermoGNN.model import GraphGNN
from dataset import pdb_parser, AMAs
from dataset_graph import PairData
from torchmetrics import F1Score, Accuracy, AveragePrecision, AUROC
import torch
import networkx as nx
from Bio.PDB.Polypeptide import three_to_one
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx
from torch.utils.data import DataLoader, Dataset
from torch_geometric.loader import DataLoader as GDataLoader
import numpy as np
from utils import set_seed
import warnings

# Silence third-party warnings so the evaluation log stays readable.
warnings.filterwarnings('ignore')

# Root directory holding one checkpoint sub-directory per trained configuration.
basedir = os.path.abspath(os.path.join('..', 'checkpoints'))

# Checkpoint sub-directories (relative to ``basedir``) evaluated by ``__main__``,
# in the order their rows are written. The first four differ only in the GNN
# type embedded in the name.
wdirs = [f'anti-{gnn}-mlce1280.00252-10-50' for gnn in ('gat', 'gcn', 'gin', 'graphsage')] + [
    'anti-mm-ce1280.00250',
    'anti-mm-mlce1280.00250',
    'anti-seq-ce1280.00250',
    'anti-voxel-ce1280.00250',
    'anti-voxel-tr-ce160.00230',
]

# Selects the exclusion list ``./simi/<simi>.csv`` that advs24 reads; the
# ``__main__`` loop rebinds it for each value it sweeps (presumably a
# sequence-similarity threshold -- confirm).
simi = 64

def main(weight_dir, perform):
    if not os.path.exists(weight_dir):
        raise ValueError
    args = load_args(weight_dir)
    
    set_seed(args.seed)

    logging.basicConfig(handlers=[
        logging.StreamHandler()],
        format="%(asctime)s: %(message)s", datefmt="%F %T", level=logging.INFO)

    logging.info(f'saving_dir: {weight_dir}')

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    voxel_flag = True
    if args.model == 'seq':
        model = SEQPeptide(classes=args.classes, q_encoder='mlp')
    elif args.model == 'voxel':
        model = VoxPeptide(classes=args.classes)
    elif args.model == 'mm':
        model = MMPeptide(classes=args.classes, q_encoder='mlp')
    elif args.model == 'mmf':
        model = MMFPeptide(classes=args.classes, q_encoder='mlp')
    elif args.model == 'voxel-tr':
        model = SwinUNETR(img_size=(64, 64, 64), in_channels=3, classes=args.classes)
    else:
        model = GraphGNN(num_layer=args.num_layer, input_dim=20, emb_dim=args.emb_dim, out_dim=args.classes, JK="last",
                         drop_ratio=args.dropout_ratio, graph_pooling=args.graph_pooling, gnn_type=args.gnn_type)
        voxel_flag = False
    model.to(device)
        
    logging.info('Loading Test Dataset')
    qlx_set = advs24(voxel_flag=voxel_flag, max_length=50)
    if voxel_flag:
        qlx_loader = DataLoader(qlx_set, batch_size=1, shuffle=False)
    else:
        qlx_loader = GDataLoader(qlx_set, batch_size=1, follow_batch=['x_s'], shuffle=False)

    logging.info(f'Test set:advs2024 {len(qlx_set)}')
    
    class_performs = [[] for _ in range(5)]
    for fold in range(5):
        weights_path = f"{weight_dir}/model_{fold + 1}.pth"
        model.load_state_dict(torch.load(weights_path))
        logging.info(f'Running Cross Validation {fold + 1}')

        class_performs[fold] = eval(args, model, qlx_loader, device, voxel_flag)

    logging.info(f'Cross Validation Finished!')

    qlx_perform_list = np.asarray(class_performs).reshape(5, 24)
    taskname = os.path.basename(weight_dir)
    perform.write(','.join(sum([[taskname] + [str(x) for x in np.mean(qlx_perform_list, 0)[i:i+4]] for i in range(0, 24, 4)], []))+'\n')
    perform.write(','.join(sum([['std'] + [str(x) for x in np.std(qlx_perform_list, 0)[i:i+4]] for i in range(0, 24, 4)], []))+'\n')

def eval(args, model, valid_loader, device, voxel_flag):
    """Score ``model`` on ``valid_loader`` with multilabel AP / F1 / ACC / AUROC.

    NOTE: this name shadows the builtin ``eval``. It is kept because ``main``
    calls it by this name.

    Args:
        args: unused; kept for call compatibility.
        model: network to evaluate. It is called as ``model((voxel, seq))``
            when ``voxel_flag`` is True and as ``model(voxel)`` otherwise.
        valid_loader: iterable yielding ``(voxel_or_graph, seq, gt)`` batches.
        device: device that inputs, targets and metric states are moved to.
        voxel_flag: selects the model-call convention described above.

    Returns:
        ``np.ndarray`` of shape ``(6, 4)`` with columns AP, F1, ACC, AUC:
        - row 0 is the unweighted mean of the per-class scores;
        - row 1 holds the micro-averaged scores;
        - rows 2..5 hold one row per scored class.

    Raises:
        ValueError: if ``valid_loader`` yields no batches.
    """
    # Output/label columns that are scored. Presumably these map to the four
    # species named in the results CSV header -- confirm against the label file.
    label_cols = [1, 2, 4, 5]
    num_labels = len(label_cols)

    def _make_metrics(average):
        # Order must match the AP, F1, ACC, AUC column order of the result.
        return [
            metric_cls(average=average, task='multilabel', num_labels=num_labels).to(device)
            for metric_cls in (AveragePrecision, F1Score, Accuracy, AUROC)
        ]

    model.eval()
    preds, gts = [], []
    with torch.no_grad():
        for voxel, seq, gt in valid_loader:
            gts.append(gt.to(device))
            if voxel_flag:
                out = model((voxel.to(device), seq.to(device)))
            else:
                out = model(voxel.to(device))
            preds.append(out)

    if not preds:
        raise ValueError('valid_loader yielded no batches; cannot compute metrics')

    # Sigmoid maps the raw model outputs to per-label probabilities for the metrics.
    probs = torch.sigmoid(torch.cat(preds, dim=0))[:, label_cols]
    targets = torch.cat(gts, dim=0).int()[:, label_cols]

    # Shape (num_labels, 4): one row per class, one column per metric.
    per_class = np.stack(
        [metric(probs, targets).cpu().detach().numpy() for metric in _make_metrics('none')],
        axis=-1,
    )
    macros = per_class.mean(axis=0, keepdims=True)
    micros = np.asarray([[metric(probs, targets).item() for metric in _make_metrics('micro')]])
    return np.concatenate((macros, micros, per_class), axis=0)

class advs24(Dataset):
    """advs2024 peptide test set, served as voxel or residue-graph samples.

    A row of ``advs2024b.csv`` is kept only when its upper-cased, stripped
    sequence (column 1) meets all of these conditions:
    - it is not listed in the exclusion CSV;
    - it contains none of ``X``/``U``/``O``;
    - its length lies in ``[6, max_length]``.

    Each kept sequence also needs a file ``./advs2024/<seq>.pdb``; rows
    without one are reported and skipped. Samples are stored in
    ``self.data_list`` as ``(structure, seq_emb, label)`` tuples, where
    ``label`` holds CSV columns 2, 3, 4, 5, 6 and 8 as float32.
    """

    def __init__(self, voxel_flag, max_length=50, filter_csv=None):
        """Load the CSV, filter rows and preprocess each matching PDB file.

        Args:
            voxel_flag: if True, structures are voxelised with ``pdb_parser``;
                otherwise they are converted to residue graphs.
            max_length: longest sequence kept. It is also the length every
                ``seq_emb`` is zero-padded to.
            filter_csv: CSV whose ``FileB_Sequence`` column lists sequences to
                exclude. Defaults to ``./simi/{simi}.csv``, using the
                module-level ``simi`` at construction time.
        """
        self.num_classes = 6
        # Must equal the length filter below so every seq_emb has the same length.
        self.max_length = max_length
        self.voxel_flag = voxel_flag
        self.p = PDBParser(QUIET=True)
        self._aa_features = None  # per-residue feature table, loaded on first graph sample
        process = self.voxelprocess if voxel_flag else self.graphprocess

        if filter_csv is None:
            filter_csv = f'./simi/{simi}.csv'
        all_data = pd.read_csv('advs2024b.csv', encoding="unicode_escape").values
        # A set gives O(1) membership tests in the row loop below.
        excluded = set(pd.read_csv(filter_csv)['FileB_Sequence'].str.upper().str.strip().unique().tolist())

        seq_list = all_data[:, 1]
        labels = all_data[:, [2, 3, 4, 5, 6, 8]].astype(np.float32)

        kept = []
        for row, raw_seq in enumerate(seq_list):
            seq = raw_seq.upper().strip()
            if seq in excluded:
                continue
            if any(c in seq for c in 'XUO') or not 6 <= len(seq) <= max_length:
                continue
            kept.append((row, seq, labels[row]))

        self.data_list = []
        for row, seq, label in tqdm(kept):
            pdb_path = './advs2024/' + seq + ".pdb"
            if not os.path.exists(pdb_path):
                print(f'lacking pdb file {seq}')
                continue
            process(row, pdb_path, seq, label)

    def _encode_seq(self, seq):
        """Map ``seq`` through ``AMAs`` and zero-pad it to ``self.max_length``."""
        return [AMAs[char] for char in seq] + [0] * (self.max_length - len(seq))

    def voxelprocess(self, idx, pdb_path, seq, label):
        """Voxelise ``pdb_path`` with ``pdb_parser``, then store and return the sample."""
        voxel, _ = pdb_parser(self.p, idx, pdb_path)
        sample = (voxel, self._encode_seq(seq), label)
        self.data_list.append(sample)
        return sample

    def graphprocess(self, idx, pdb_path, seq, label):
        """Convert ``pdb_path`` into a residue-contact ``PairData`` sample.

        Structures whose first model contains any chain other than ``'A'``
        are skipped, and the method returns None. Otherwise:
        - nodes are the standard amino-acid residues of chain A;
        - residues whose CA atoms are at most 5 apart are joined by an edge
          with weight ``5 / distance``;
        - node features come from ``../features.txt``, keyed by one-letter
          residue code.
        """
        structure = self.p.get_structure(idx, pdb_path)
        first_model = structure[0]
        if any(ch.id != 'A' for ch in first_model):
            return
        chain = first_model['A']

        G = nx.Graph()
        for res in chain:
            resname = res.get_resname()
            if is_aa(resname, standard=True):
                G.add_node(res.id[1], name=resname)

        nodes_list = list(G.nodes)
        for i, m in enumerate(nodes_list):
            for n in nodes_list[i + 1:]:
                distance = chain[m]["CA"] - chain[n]["CA"]
                if distance <= 5:
                    G.add_edge(m, n, weight=5 / distance)
        G = nx.convert_node_labels_to_integers(G)

        # Parse the feature table once per dataset instead of once per structure.
        if getattr(self, '_aa_features', None) is None:
            self._aa_features = self._load_aa_features('../features.txt')
        features = [self._aa_features[three_to_one(chain[int(node)].get_resname())] for node in nodes_list]
        # Relabelled nodes keep insertion order, so they line up with nodes_list.
        for (_, attrs), feat in zip(G.nodes(data=True), features):
            attrs['x'] = feat
        data_wt = from_networkx(G)

        data_graph = PairData(data_wt.edge_index, data_wt.x)
        sample = (data_graph, self._encode_seq(seq), label)
        self.data_list.append(sample)
        return sample

    @staticmethod
    def _load_aa_features(feature_path):
        """Parse whitespace-separated lines ``<aa> <f1> <f2> ...`` into ``{aa: [floats]}``.

        Blank lines are skipped.
        """
        aa_features = {}
        with open(feature_path) as fh:
            for line in fh:
                fields = line.split()
                if not fields:
                    continue
                aa_features[fields[0]] = [float(v) for v in fields[1:]]
        return aa_features

    def __getitem__(self, idx):
        structure, seq_emb, gt = self.data_list[idx]
        if self.voxel_flag:
            structure = torch.Tensor(structure).float()
        return structure, torch.Tensor(seq_emb), torch.Tensor(gt)

    def __len__(self):
        return len(self.data_list)

def load_args(weight_dir):
    """Load the training configuration stored in ``<weight_dir>/config.json``.

    Returns an ``argparse.Namespace`` whose attributes are the JSON keys, so
    the config can be used like parsed command-line arguments.
    """
    config_path = os.path.join(weight_dir, "config.json")
    with open(config_path, "r") as f:
        args_dict = json.load(f)
    # Build the namespace directly: registering each key with ArgumentParser
    # adds nothing, and raises for keys such as 'help' that clash with
    # argparse's built-in options.
    return argparse.Namespace(**args_dict)

if __name__ == "__main__":
    # Sweep the exclusion lists, writing one results CSV per value.
    # The loop variable rebinds the module-level `simi`, which advs24 reads to
    # choose ./simi/<simi>.csv.
    os.makedirs('./results', exist_ok=True)
    for simi in range(64, 81, 2):
        # `with` guarantees each CSV is flushed and closed even if a later run fails.
        with open(f'./results/{simi}.csv', 'w') as perform:
            perform.write('Macro Average,AP,F1,ACC,AUC,Micro Average,AP,F1,ACC,AUC,Pseudomonas aeruginosa,AP,F1,ACC,AUC,Staphylococcus aureus,AP,F1,ACC,AUC,Enterobacteriaceae,AP,F1,ACC,AUC,Salmonella species,AP,F1,ACC,AUC\n')
            for ckpt_name in wdirs:
                main(os.path.join(basedir, ckpt_name), perform)
