In [1]:
import sys
sys.path.append('..')
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from utils import commons
from models.mlp import MLPModel
from datasets.sequence_dataset import SingleLabelSequenceDataset
import matplotlib.pyplot as plt
import numpy as np
import os, json
import seaborn as sns
from tqdm.auto import tqdm
from sklearn.manifold import TSNE

# Use the first GPU when one is available; otherwise fall back to the CPU so
# the notebook still runs on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

def generate_embeddings(model, dataloader, label_list):
    """Run ``model`` over ``dataloader`` and gather features, labels and predictions.

    Args:
        model: callable returning an ``(output, feature)`` pair for a batch.
        dataloader: iterable yielding ``(data, label)`` batches; ``label``
            holds class indices into ``label_list``.
        label_list: sequence mapping a class index to its label name.

    Returns:
        A tuple ``(embeddings, labels, preds)``: the per-batch features
        concatenated along dim 0, the ground-truth label names, and the
        predicted label names (argmax of ``output`` over dim 1).
    """
    feature_chunks = []
    true_names = []
    pred_names = []
    with torch.no_grad():
        for batch, target in tqdm(dataloader):
            logits, feats = model(batch.to(device))
            # commons.toCPU is a project helper; presumably it moves the
            # tensor to host memory -- confirm against utils/commons.
            feature_chunks.append(commons.toCPU(feats))
            true_names += [label_list[idx] for idx in target.tolist()]
            pred_names += [label_list[idx] for idx in logits.argmax(dim=1).tolist()]
    embeddings = torch.cat(feature_chunks, dim=0)
    assert len(embeddings) == len(true_names)
    return embeddings, true_names, pred_names

def get_ec2occurance(data_file, label_file):
    """Count how often each EC label appears in a saved dataset.

    Args:
        data_file: path to a ``torch.save``-d mapping whose values each carry
            an ``'ec'`` entry listing that record's labels.
        label_file: path to a JSON file containing the list of labels.

    Returns:
        ``(ec2occurance, label_list)``: a dict with one entry per label in
        ``label_list`` (0 when the label never occurs) and the label list
        itself.

    Raises:
        KeyError: if a record carries a label missing from ``label_file``.
    """
    records = torch.load(data_file)
    with open(label_file, 'r') as handle:
        label_list = json.load(handle)

    # Start every known label at zero so labels that never occur are still reported.
    counts = dict.fromkeys(label_list, 0)
    for record in records.values():
        for ec_number in record['ec']:
            counts[ec_number] += 1

    return counts, label_list

def get_label2mean(embeddings, labels):
    """Average the embeddings that share a label.

    Args:
        embeddings: indexable collection of tensors (e.g. a 2-D tensor with
            one embedding per row), aligned with ``labels``.
        labels: one hashable label per embedding.

    Returns:
        Dict mapping each distinct label to the element-wise mean of its
        embeddings, keyed in order of first appearance.

    Raises:
        AssertionError: if ``embeddings`` and ``labels`` differ in length.
    """
    grouped = {name: [] for name in labels}
    assert len(embeddings) == len(labels), f'{len(embeddings)} != {len(labels)}'
    for vector, name in zip(embeddings, labels):
        grouped[name].append(vector)

    label2mean = {}
    for name, members in grouped.items():
        label2mean[name] = torch.stack(members).mean(dim=0)
    return label2mean

def get_pairwise_angle(means):
    """Pairwise angles (in radians) between globally-centred class means.

    The mean over all rows is subtracted from every row, the centred rows are
    L2-normalised, and the angle between every pair of rows is returned.

    Args:
        means: 2-D float tensor with one class mean per row.

    Returns:
        Square tensor whose entry ``(i, j)`` is the angle between centred
        rows ``i`` and ``j``; the diagonal is set to 0.
    """
    centered = means - means.mean(dim=0)
    unit = F.normalize(centered, p=2, dim=1)
    cosine = torch.matmul(unit, unit.t())
    # Floating-point rounding can push a cosine slightly outside [-1, 1],
    # where acos returns NaN; clamp so near-(anti)parallel rows stay finite.
    angles = torch.acos(cosine.clamp(-1.0, 1.0))
    # The angle of a row with itself is 0 by definition; overwrite whatever
    # rounding produced on the diagonal.
    angles.fill_diagonal_(0)

    return angles

# Run directory of the trained model to analyse (alternative run kept for reference).
# model_dir = '../logs/train_mlp_single_label_CE_2024_02_13__09_48_52'
model_dir = '../logs_nc2_v2/train_mlp_single_label_NC_2024_02_26__15_41_46_0.0005_0.01'
test_data_file = '../data/ec/sprot_10_1022_esm2_t33_ec_above_10_single_label_test.pt'
train_data_file = '../data/ec/sprot_10_1022_esm2_t33_ec_above_10_single_label_train.pt'
label_file = '../data/ec/swissprot_ec_list_above_10.json'

# dataset
trainset = SingleLabelSequenceDataset(train_data_file, label_file)
testset = SingleLabelSequenceDataset(test_data_file, label_file)
# shuffle=False keeps the loaders iterating in a fixed order.
train_loader = DataLoader(trainset, batch_size=512, shuffle=False)
test_loader = DataLoader(testset, batch_size=512, shuffle=False)
# get_ec2occurance reads label_file itself and returns the label list, so the
# file does not need to be opened a second time here.
ec2occurance, label_list = get_ec2occurance(train_data_file, label_file)
label2idx = {label: i for i, label in enumerate(label_list)}

config = commons.load_config(os.path.join(model_dir, 'config.yml'))
# The model class is looked up by name in the notebook namespace, so the class
# named by config.model.model_type must already be imported (e.g. MLPModel).
model = globals()[config.model.model_type](config.model)
# map_location makes the checkpoint loadable regardless of the device it was
# saved from; the weights are placed directly on the analysis device.
ckpt = torch.load(os.path.join(model_dir, 'checkpoints/best_checkpoints.pt'), map_location=device)
model.load_state_dict(ckpt)
model.eval()
model.to(device)

# Embeddings, ground-truth label names and predicted label names for both
# splits, followed by the per-label mean embedding of each split.
train_embeddings, train_labels, train_preds = generate_embeddings(model, train_loader, label_list)
test_embeddings, test_labels, test_preds = generate_embeddings(model, test_loader, label_list)
train_label2mean = get_label2mean(train_embeddings, train_labels)
test_label2mean = get_label2mean(test_embeddings, test_labels)

[2024-02-27 10:06:49,024::SequenceDataset::INFO] Loaded 167108 sequences with ec labels
[2024-02-27 10:06:49,025::SequenceDataset::INFO] Label level: 4; Num of labels: 1920
[2024-02-27 10:06:49,670::SequenceDataset::INFO] Loaded 20889 sequences with ec labels
[2024-02-27 10:06:49,670::SequenceDataset::INFO] Loaded 20889 sequences with ec labels
[2024-02-27 10:06:49,671::SequenceDataset::INFO] Label level: 4; Num of labels: 1920
[2024-02-27 10:06:49,671::SequenceDataset::INFO] Label level: 4; Num of labels: 1920


  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

In [8]:
from collections import Counter

def get_angle(vec1, vec2):
    """Angle in radians between two 1-D tensors, rounded to 4 decimals.

    Args:
        vec1: 1-D tensor.
        vec2: 1-D tensor with the same number of elements as ``vec1``.

    Returns:
        A Python float in ``[0, pi]``. A zero-norm input still yields ``nan``
        because the cosine is undefined.
    """
    cosine = torch.dot(vec1, vec2) / (torch.norm(vec1) * torch.norm(vec2))
    # Rounding can leave the cosine of (anti)parallel vectors marginally
    # outside [-1, 1], which would make acos return NaN.
    cosine = torch.clamp(cosine, -1.0, 1.0)
    return round(torch.acos(cosine).item(), 4)

# For every ground-truth EC number with fewer than 10 training occurrences,
# collect the wrong predictions the model made on the test set.
n = len(test_preds)
k = 0  # not used below; retained so the notebook namespace is unchanged
gt2preds = {}
for predicted, actual in zip(test_preds, test_labels):
    if predicted == actual or ec2occurance[actual] >= 10:
        continue
    gt2preds.setdefault(actual, []).append(predicted)

# One line per (ground truth, wrong prediction) pair: training occurrences of
# both labels, how often the confusion happened, and the angle between the
# two training class means when both are available.
for rank, (truth, wrong_preds) in enumerate(gt2preds.items()):
    # most_common() yields (prediction, number of times it was predicted) pairs
    for wrong, times in Counter(wrong_preds).most_common():
        if truth in train_label2mean and wrong in train_label2mean:
            angle = get_angle(train_label2mean[truth], train_label2mean[wrong])
        else:
            angle = "N/A"
        print(f'{rank}: {truth} ({ec2occurance[truth]}) -> {wrong} ({ec2occurance[wrong]}) * {times};      \tangle: {angle}')


0: 7.6.2.7 (4) -> 7.6.2.14 (27) * 20;      	angle: 0.3929
1: 4.2.1.32 (2) -> 3.5.4.27 (14) * 5;      	angle: 0.5068
1: 4.2.1.32 (2) -> 2.1.1.195 (51) * 1;      	angle: 0.7849
1: 4.2.1.32 (2) -> 4.2.3.5 (690) * 1;      	angle: 0.8067
2: 1.3.1.12 (3) -> 1.1.1.22 (34) * 14;      	angle: 0.8896
3: 2.7.2.7 (8) -> 2.3.1.234 (750) * 1;      	angle: 0.6864
3: 2.7.2.7 (8) -> 2.7.2.15 (18) * 1;      	angle: 0.6684
4: 1.3.7.5 (8) -> 2.7.7.6 (1574) * 1;      	angle: 0.9182
5: 2.1.1.298 (1) -> 2.1.1.297 (77) * 16;      	angle: 0.7904
6: 1.3.7.3 (8) -> 1.3.3.3 (240) * 3;      	angle: 0.5613
6: 1.3.7.3 (8) -> 3.1.1.96 (567) * 1;      	angle: 0.7274
7: 1.3.7.2 (8) -> 1.3.3.3 (240) * 1;      	angle: 0.6807
7: 1.3.7.2 (8) -> 1.3.7.3 (8) * 1;      	angle: 0.3354
7: 1.3.7.2 (8) -> 1.3.7.5 (8) * 1;      	angle: 0.5069
8: 4.4.1.3 (9) -> 3.5.1.103 (23) * 1;      	angle: 0.8903
9: 1.4.1.13 (7) -> 4.1.99.17 (438) * 2;      	angle: 1.077
9: 1.4.1.13 (7) -> 1.1.1.205 (89) * 1;      	angle: 0.9968
10: 2.10.1.1 (2

In [3]:
# acos(-1 / (K - 1)) is the pairwise angle of K equiangular, maximally
# separated unit vectors (a simplex ETF); presumably the reference value for
# the class-mean angles reported in this notebook.
# K = 1920 matches the "Num of labels" logged when the datasets were loaded.
num_classes = 1920
torch.acos(torch.tensor(-1 / (num_classes - 1)))

tensor(1.5713)

In [4]:
# Number of test sequences whose ground-truth EC number occurs fewer than
# 10 times in the training data.
le10 = sum(1 for ec in test_labels if ec2occurance[ec] < 10)
le10

904