In [1]:
import torch
from tqdm.auto import tqdm
import os

# Rebuild every split of the ESM-2 dataset with embeddings read from `emb_dir`
# (presumably ESM-1b, going by the output file names -- TODO confirm), keeping
# the protein ids and EC labels of the ESM-2 files unchanged.
emb_dir = '/work/jiaqi/CLEAN/app/data/esm_data'
tags = ['train', 'val', 'test']
for tag in tags:
    esm2_data = torch.load(f'../data/ec/sprot_10_1022_esm2_t33_ec_above_10_single_label_{tag}.pt')
    esm1b_data = {
        pid: {
            # One .pt file per protein id; entry 33 of 'mean_representations' is used.
            'embedding': torch.load(os.path.join(emb_dir, f'{pid}.pt'))['mean_representations'][33],
            'ec': esm2_data[pid]['ec'],
        }
        for pid in tqdm(esm2_data)
    }
    torch.save(esm1b_data, f'../data/ec/sprot_10_1022_esm1b_t33_ec_above_10_single_label_{tag}.pt')

  0%|          | 0/167108 [00:00<?, ?it/s]

  0%|          | 0/20888 [00:00<?, ?it/s]

  0%|          | 0/20889 [00:00<?, ?it/s]

In [1]:
import torch
import pandas as pd
import os
from tqdm.auto import tqdm
import json
import numpy as np

# Build the single-label CLEAN split100 dataset: keep entries whose EC number
# occurs at least MIN_EC_OCCURRENCE times, attach the per-protein embedding
# stored in `emb_dir`, and save both the dataset and the list of kept labels.
MIN_EC_OCCURRENCE = 5  # single source of truth for the threshold and the file names

emb_dir = '/work/jiaqi/CLEAN/app/data/esm_data'
csv_path = '/work/jiaqi/CLEAN/app/data/split100.csv'
df = pd.read_csv(csv_path, sep='\t')
# Single-label rows are those whose EC number contains no ';'.
# na=True makes rows with a missing EC number count as "not single label",
# so they are dropped explicitly rather than through a NaN comparison.
df_single_label = df[~df['EC number'].str.contains(';', na=True)]
entries = df_single_label['Entry'].tolist()
ec_numbers = df_single_label['EC number'].tolist()

# Number of single-label entries per EC number.
ec2occurance = {}
for ec in ec_numbers:
    ec2occurance[ec] = ec2occurance.get(ec, 0) + 1

n = len(entries)
print(f'number of single label entries: {n}')
frequent_labels = [ec for ec, count in ec2occurance.items() if count >= MIN_EC_OCCURRENCE]

data = {}
labels = set()
for pid, ec in tqdm(zip(entries, ec_numbers), total=n):
    if ec2occurance[ec] < MIN_EC_OCCURRENCE:
        continue
    labels.add(ec)
    # One .pt file per protein id; entry 33 of 'mean_representations' is used.
    emb = torch.load(os.path.join(emb_dir, f'{pid}.pt'))['mean_representations'][33]
    data[pid] = {'embedding': emb, 'ec': [ec]}
print(f'data: {len(data)}')
torch.save(data, f'../data/ec/CLEAN_split100_esm1b_t33_ec_above_{MIN_EC_OCCURRENCE}_single_label.pt')

# Every label that passed the threshold must have been seen in the loop above.
assert labels == set(frequent_labels)
# Sort so the saved label order (and anything indexed by it) is identical on
# every run; iterating a set of strings gives no such guarantee.
label_list = sorted(labels)
print(f'number of unique labels: {len(label_list)}')
with open(f'../data/ec/CLEAN_split100_ec_above_{MIN_EC_OCCURRENCE}_single_label_list.json', 'w') as f:
    json.dump(label_list, f)

number of single label entries: 215439


  0%|          | 0/215439 [00:00<?, ?it/s]

data: 210734
number of unique labels: 2134


In [4]:
import torch
import pandas as pd
import json, os
from tqdm import tqdm
# Build the external test sets ('price', 'new', 'halogenase'), restricted to
# single-label entries whose EC number appears in the saved training label list.
test_sets = ['price', 'new', 'halogenase']
with open('../data/ec/CLEAN_split100_ec_above_5_single_label_list.json') as f:
    label_list = json.load(f)
# Set lookup is O(1); testing membership against the list scans it for every row.
known_labels = set(label_list)
emb_dir = '/work/jiaqi/CLEAN/app/data/esm_data'
for test_set in test_sets:
    csv_path = f'/work/jiaqi/CLEAN/app/data/datasets/{test_set}.csv'
    df = pd.read_csv(csv_path, sep='\t')
    # Single-label rows contain no ';'. na=True drops rows with a missing EC number.
    df_single_label = df[~df['EC number'].str.contains(';', na=True)]
    entries = df_single_label['Entry'].tolist()
    ec_numbers = df_single_label['EC number'].tolist()
    test_data = {}
    n = len(entries)
    print(f'{test_set} number of single label entries: {n}')
    for pid, ec in tqdm(zip(entries, ec_numbers), total=n):
        if ec not in known_labels:
            # EC number is not in the training label list: skip this entry.
            continue
        # One .pt file per protein id; entry 33 of 'mean_representations' is used.
        emb = torch.load(os.path.join(emb_dir, f'{pid}.pt'))['mean_representations'][33]
        test_data[pid] = {'embedding': emb, 'ec': [ec]}
    print(f'{test_set} data: {len(test_data)}')
    torch.save(test_data, f'../data/ec/{test_set}_esm1b_t33_single_label_ec_above_5.pt')
    

price number of single label entries: 146


100%|██████████| 146/146 [00:00<00:00, 7866.21it/s]


price data: 89
new number of single label entries: 334


100%|██████████| 334/334 [00:00<00:00, 9464.56it/s]


new data: 247
halogenase number of single label entries: 37


100%|██████████| 37/37 [00:00<00:00, 24286.27it/s]

halogenase data: 1





In [1]:
import torch
import random

# Random 90/10 train/val split of the single-label "above 5" dataset.
# Seed the RNG so the split is reproducible; without a seed every run of this
# cell writes a different train/val partition to disk.
random.seed(42)
data = torch.load('../data/ec/CLEAN_split100_esm1b_t33_ec_above_5_single_label.pt')
pids = list(data.keys())
random.shuffle(pids)
n = len(pids)
n_train = int(n * 0.9)  # first 90% of the shuffled ids go to train
train_pids = pids[:n_train]
val_pids = pids[n_train:]
train_data = {pid: data[pid] for pid in train_pids}
val_data = {pid: data[pid] for pid in val_pids}
print(f'train: {len(train_data)}, val: {len(val_data)}')
torch.save(train_data, '../data/ec/CLEAN_split100_esm1b_t33_ec_above_5_single_label_train.pt')
torch.save(val_data, '../data/ec/CLEAN_split100_esm1b_t33_ec_above_5_single_label_val.pt')

train: 189660, val: 21074


In [14]:
import torch
import pandas as pd
import os
from tqdm.auto import tqdm
import json
import numpy as np

# Write the split100 CSV rows that belong to the train + val splits to a new
# tab-separated file.
emb_dir = '/work/jiaqi/CLEAN/app/data/esm_data'
csv_path = '/work/jiaqi/CLEAN/app/data/split100.csv'
data_path = '../data/ec/CLEAN_split100_esm1b_t33_ec_above_10_single_label_train.pt'
val_data_path = '../data/ec/CLEAN_split100_esm1b_t33_ec_above_10_single_label_val.pt'

data = torch.load(data_path)
val_data = torch.load(val_data_path)
# Union of both splits; on a duplicate id the val record wins.
data = {**data, **val_data}
pids = list(data)

df = pd.read_csv(csv_path, sep='\t')
in_train_val = df['Entry'].isin(pids)
df = df.loc[in_train_val]
# Exactly one CSV row is expected per protein id.
assert len(df) == len(pids)
print(f'df: {len(df)}')
df.to_csv('/work/jiaqi/CLEAN/app/data/split100_ec_above_10_single_label_train_val.csv', sep='\t', index=False)

df: 186126


In [11]:
import torch
import random

# Seeded random 80/10/10 train/val/test split of the "above 10" dataset.
random.seed(42)
data = torch.load('../data/ec/CLEAN_split100_esm1b_t33_ec_above_10_single_label.pt')
pids = list(data)
random.shuffle(pids)
n = len(pids)

# Cut points into the shuffled id list.
train_end = int(n * 0.8)
val_end = int(n * 0.9)
train_pids, val_pids, test_pids = pids[:train_end], pids[train_end:val_end], pids[val_end:]

train_data = {pid: data[pid] for pid in train_pids}
val_data = {pid: data[pid] for pid in val_pids}
test_data = {pid: data[pid] for pid in test_pids}
print(f'train: {len(train_data)}, val: {len(val_data)}, test: {len(test_data)}')

torch.save(train_data, '../data/ec/CLEAN_split100_esm1b_t33_ec_above_10_single_label_train.pt')
torch.save(val_data, '../data/ec/CLEAN_split100_esm1b_t33_ec_above_10_single_label_val.pt')
torch.save(test_data, '../data/ec/CLEAN_split100_esm1b_t33_ec_above_10_single_label_test.pt')


train: 165445, val: 20681, test: 20681


In [2]:
import torch

# How many protein ids do the CLEAN split100 dataset and the SwissProt dataset share?
clean_split100_data = torch.load('../data/ec/CLEAN_split100_esm1b_t33_ec_above_10_single_label.pt')
sprot_data = torch.load('../data/ec/sprot_10_1022_esm2_t33_ec_above_10.pt')
clean_pids = set(clean_split100_data)
sprot_pids = set(sprot_data)
print(f'clean: {len(clean_pids)}, sprot: {len(sprot_pids)}')
intersection = clean_pids & sprot_pids
n_shared = len(intersection)
print(f'intersection: {n_shared}')
# Share of each dataset that is covered by the common ids.
print(f'percentage: {n_shared / len(clean_pids)} in clean; {n_shared / len(sprot_pids)} in sprot')

clean: 206807, sprot: 228551
intersection: 205914
percentage: 0.9956819643435667 in clean; 0.9009542727881305 in sprot


In [None]:
import torch
import pandas 

# Write the split100 CSV rows that belong to the test split to a new
# tab-separated file.
test_data = torch.load('../data/ec/CLEAN_split100_esm1b_t33_ec_above_10_single_label_test.pt')
test_pids = list(test_data)
df = pandas.read_csv('/work/jiaqi/CLEAN/app/data/split100.csv', sep='\t')
in_test = df['Entry'].isin(test_pids)
df = df.loc[in_test]
print(f'test: {len(df)}')
df.to_csv('/work/jiaqi/CLEAN/app/data/split100_ec_above_10_single_label_test.csv', sep='\t', index=False)

In [1]:
import torch

# Merge the train and val splits into one train_val file.
train_data = torch.load('../data/ec/CLEAN_split100_esm1b_t33_ec_above_10_single_label_train.pt')
val_data = torch.load('../data/ec/CLEAN_split100_esm1b_t33_ec_above_10_single_label_val.pt')
# After this line `train_data` holds train + val; on a duplicate id the val record wins.
train_data = {**train_data, **val_data}
torch.save(train_data, '../data/ec/CLEAN_split100_esm1b_t33_ec_above_10_single_label_train_val.pt')
# python scripts/eval_CLEAN.py configs/eval_long_tail.yml --clean_pred_file /work/jiaqi/CLEAN/app/results/inputs/CLEAN_split100_ec_above_10_single_label_test_maxsep_train.csv --train_data_file data/ec/CLEAN_split100_esm1b_t33_ec_above_10_single_label_train.pt --test_data_file data/ec/CLEAN_split100_esm1b_t33_ec_above_10_single_label_test.pt --label_file data/ec/CLEAN_split100_ec_above_10_single_label_list.json

In [1]:
import torch
import random

# Pool the three SwissProt splits into one dataset, save it, then draw a
# seeded 80/20 (train_val / test) split for the ensemble experiments.
random.seed(42)
all_data = {}
for tag in ['train', 'val', 'test']:
    data = torch.load(f'../data/ec/sprot_10_1022_esm1b_t33_ec_above_10_single_label_{tag}.pt')
    all_data.update(data)
print(f'all: {len(all_data)}')
torch.save(all_data, '../data/ec/sprot_10_1022_esm1b_t33_ec_above_10_single_label.pt')

all_pids = list(all_data)
random.shuffle(all_pids)
n = len(all_pids)
cut = int(n * 0.8)  # first 80% of the shuffled ids go to train_val
train_val_pids, test_pids = all_pids[:cut], all_pids[cut:]

train_val_data = {pid: all_data[pid] for pid in train_val_pids}
test_data = {pid: all_data[pid] for pid in test_pids}
print(f'train_val: {len(train_val_data)}, test: {len(test_data)}')
torch.save(train_val_data, '../data/ec/ensemble/sprot_10_1022_esm1b_t33_ec_above_10_single_label_train_val.pt')
torch.save(test_data, '../data/ec/ensemble/sprot_10_1022_esm1b_t33_ec_above_10_single_label_test.pt')

all: 208885
train_val: 167108, test: 41777


In [4]:
import torch
import pandas 

# Write the CSV rows of the ensemble test split to their own tab-separated
# file; the two printed counts should agree.
test_data = torch.load('../data/ec/ensemble/sprot_10_1022_esm1b_t33_ec_above_10_single_label_test.pt')
test_pids = list(test_data)
print(f'test: {len(test_pids)}')
df = pandas.read_csv('/work/jiaqi/CLEAN/app/data/sprot_10_1022_ec_above_10_single_label_all.csv', sep='\t')
in_test = df['Entry'].isin(test_pids)
df = df.loc[in_test]
print(f'test: {len(df)}')
df.to_csv('/work/jiaqi/CLEAN/app/data/sprot_10_1022_ec_above_10_single_label_test_for_ensemble.csv', sep='\t', index=False)

test: 41777
test: 41777


In [5]:
import torch
import pandas 

# Write the CSV rows of the ensemble train_val split to their own
# tab-separated file; the two printed counts should agree.
train_val_data = torch.load('../data/ec/ensemble/sprot_10_1022_esm1b_t33_ec_above_10_single_label_train_val.pt')
train_val_pids = list(train_val_data)
print(f'train_val: {len(train_val_pids)}')
df = pandas.read_csv('/work/jiaqi/CLEAN/app/data/sprot_10_1022_ec_above_10_single_label_all.csv', sep='\t')
in_train_val = df['Entry'].isin(train_val_pids)
df = df.loc[in_train_val]
print(f'train_val: {len(df)}')
df.to_csv('/work/jiaqi/CLEAN/app/data/sprot_10_1022_ec_above_10_single_label_train_val_for_ensemble.csv', sep='\t', index=False)

train_val: 167108
train_val: 167108


In [6]:
# Ad-hoc check: is the id 'P25152_8' among train_val_pids?
# NOTE(review): relies on `train_val_pids` left in the kernel by an earlier cell.
'P25152_8' in train_val_pids

False

In [2]:
import pandas as pd

# Convert the ensemble test CSV into a FASTA file: a '>Entry' header line
# followed by the sequence line for every row.
df = pd.read_csv('/work/jiaqi/CLEAN/app/data/inputs/sprot_10_1022_ec_above_10_single_label_test_for_ensemble.csv', sep='\t')
entries = df['Entry'].tolist()
seqs = df['Sequence'].tolist()
entry_seq_list = [f'>{entry}\n{seq}\n' for entry, seq in zip(entries, seqs)]
with open('/work/jiaqi/CLEAN/app/data/inputs/sprot_10_1022_ec_above_10_single_label_test_for_ensemble.fasta', 'w') as f:
    f.writelines(entry_seq_list)

In [1]:
import torch
import random

# Seeded random 87.5/12.5 split of the ensemble train_val set into train and val.
random.seed(0)
train_val_data = torch.load(f'../data/ec/ensemble/sprot_10_1022_esm1b_t33_ec_above_10_single_label_train_val.pt')
print(f'train_val: {len(train_val_data)}')
train_val_pids = list(train_val_data)
random.shuffle(train_val_pids)

n_train = int(len(train_val_pids) * 0.875)  # cut point in the shuffled id list
train_pids, val_pids = train_val_pids[:n_train], train_val_pids[n_train:]

train_data = {pid: train_val_data[pid] for pid in train_pids}
val_data = {pid: train_val_data[pid] for pid in val_pids}
print(f'train: {len(train_data)}, val: {len(val_data)}')
torch.save(train_data, f'../data/ec/ensemble/sprot_10_1022_esm1b_t33_ec_above_10_single_label_train.pt')
torch.save(val_data, f'../data/ec/ensemble/sprot_10_1022_esm1b_t33_ec_above_10_single_label_val.pt')

train_val: 167108
train: 146219, val: 20889


In [2]:
import torch

# Peek at the first ten protein ids of the ensemble test split.
data = torch.load('../data/ec/ensemble/sprot_10_1022_esm1b_t33_ec_above_10_single_label_test.pt')
list(data)[:10]

['B9JG24',
 'Q49YU4',
 'C4ZW44',
 'C3LKW3',
 'Q8ZWV9',
 'B1KJM7',
 'Q2SZS9',
 'A1KWN7',
 'O83433',
 'B3Q9V6']

In [3]:
# Display the record stored under one id.
# NOTE(review): relies on `data` left in the kernel by the preceding cell.
data['B9JG24']

{'embedding': tensor([-0.0049,  0.1555, -0.0897,  ..., -0.0696, -0.0307,  0.0684]),
 'ec': ['1.1.1.25']}

In [7]:
import torch

# Extend the train_val set with the extra embeddings in `pid2emb`
# (presumably mutated sequences, going by the file name -- TODO confirm).
# Each added id takes the EC label of the id before its first '_'.
pid2emb = torch.load('../data/ec/ensemble/esm1b_t33_single_EC_mutations_train_val.pt')
train_data = torch.load('../data/ec/ensemble/sprot_10_1022_esm1b_t33_ec_above_10_single_label_train_val.pt')
print(f'train data: {len(train_data)}')
for pid in pid2emb:
    parent_pid = pid.split('_')[0]
    train_data[pid] = {'embedding': pid2emb[pid], 'ec': train_data[parent_pid]['ec']}
torch.save(train_data, '../data/ec/ensemble/sprot_10_1022_esm1b_t33_ec_above_10_single_label_train_val_with_single_EC_mutations.pt')
print(f'train data: {len(train_data)}')

train data: 167108
train data: 167458


In [4]:
import random

# Print twenty random integers from 0..4 (both ends included) on one line.
for _ in range(20):
    draw = random.randint(0, 4)
    print(draw, end=' ')

3 2 2 3 2 0 4 3 4 0 4 3 3 1 3 2 0 3 4 0 

In [2]:
import sys
sys.path.append('..')
from scripts.train_mlp import get_ec2occurance
import pandas as pd

# Drop the CSV rows whose EC number has fewer than 10 occurrences according
# to `ec2occurance`, and save the remaining rows.
data_file = '../data/ec/ensemble/sprot_10_1022_esm1b_t33_ec_above_10_single_label_train_val.pt'
label_file  = '../data/ec/swissprot_ec_list_above_10.json'
ec2occurance, _ = get_ec2occurance(data_file, label_file, 'ec', 4)
df = pd.read_csv('/work/jiaqi/CLEAN/app/data/sprot_10_1022_ec_above_10_single_label_train_val_for_ensemble.csv', sep='\t')
print(len(df))


def _has_enough_occurrences(ec):
    # True when this EC number reaches the 10-occurrence threshold.
    return ec2occurance[ec] >= 10


keep_mask = df['EC number'].apply(_has_enough_occurrences)
df_filtered = df[keep_mask]
print(len(df_filtered))
df_filtered.to_csv('/work/jiaqi/CLEAN/app/data/sprot_10_1022_ec_above_10_single_label_train_val_for_ensemble_remove_minor.csv', sep='\t', index=False)

Loading ../data/ec/ensemble/sprot_10_1022_esm1b_t33_ec_above_10_single_label_train_val.pt for occurance statistics...
167108
164856


In [3]:
import csv

def get_ec_id_dict(csv_name: str) -> tuple:
    """Read a tab-separated file and build id -> ECs and EC -> ids lookups.

    The first row is treated as a header and skipped. In every other row,
    column 0 is the entry id and column 1 holds one or more EC numbers
    separated by ';'.

    Args:
        csv_name: Path of the tab-separated file to read.

    Returns:
        A tuple ``(id_ec, ec_id)`` where ``id_ec`` maps each entry id to its
        list of EC numbers and ``ec_id`` maps each EC number to the set of
        entry ids annotated with it.
    """
    id_ec = {}
    ec_id = {}

    # The context manager closes the file even if a row is malformed;
    # newline='' is what the csv module expects for files passed to a reader.
    with open(csv_name, newline='') as csv_file:
        csvreader = csv.reader(csv_file, delimiter='\t')
        next(csvreader, None)  # skip the header row (no-op on an empty file)
        for rows in csvreader:
            if not rows:
                # Blank lines (e.g. a trailing newline) carry no record.
                continue
            ecs = rows[1].split(';')
            id_ec[rows[0]] = ecs
            for ec in ecs:
                ec_id.setdefault(ec, set()).add(rows[0])
    return id_ec, ec_id

# Report how many EC numbers remain after filtering, and the largest and
# smallest number of entries any single EC number has.
id_ec, ec_id = get_ec_id_dict('/work/jiaqi/CLEAN/app/data/sprot_10_1022_ec_above_10_single_label_train_val_for_ensemble_remove_minor.csv')
ec_id_lens = [len(ids) for ids in ec_id.values()]
largest, smallest = max(ec_id_lens), min(ec_id_lens)
print(f'number of ECs: {len(ec_id)}, max: {largest}, min: {smallest}')

number of ECs: 1410, max: 1750, min: 10


In [None]:
# python scripts/eval_CLEAN.py configs/eval_long_tail.yml --clean_pred_file /work/jiaqi/CLEAN_original/app/results/inputs/sprot_10_1022_ec_above_10_single_label_test_for_ensemble_maxsep_single_mutate_10_seed_0.csv --train_data_file data/ec/ensemble/sprot_10_1022_esm1b_t33_ec_above_10_single_label_train_val.pt --test_data_file data/ec/ensemble/sprot_10_1022_esm1b_t33_ec_above_10_single_label_test.pt --label_file data/ec/swissprot_ec_list_above_10.json

In [1]:
import sys
sys.path.append('..')
from scripts.train_mlp import get_ec2occurance
import pandas as pd

# Bucket EC numbers by their occurrence count: fewer than 10, and 10 to 29.
data_file = '../data/ec/ensemble/sprot_10_1022_esm1b_t33_ec_above_10_single_label_train_val.pt'
label_file  = '../data/ec/swissprot_ec_list_above_10.json'
ec2occurance, _ = get_ec2occurance(data_file, label_file, 'ec', 4)

class0_10 = []
class10_30 = []
for ec in ec2occurance:
    count = ec2occurance[ec]
    if count < 10:
        class0_10.append(ec)
    elif count < 30:
        class10_30.append(ec)
print(f'class0_10: {len(class0_10)}, class10_30: {len(class10_30)}')

Loading ../data/ec/ensemble/sprot_10_1022_esm1b_t33_ec_above_10_single_label_train_val.pt for occurance statistics...
class0_10: 510, class10_30: 616


In [2]:
# Average of two per-group values weighted by group size (510 and 616, total 1126).
# NOTE(review): 510 / 616 match the class0_10 / class10_30 counts printed in the
# preceding cell; 0.8831 and 0.9706 are presumably per-group accuracies -- confirm.
(510 * 0.8831 + 616 * 0.9706) / 1126

0.9309685612788631

In [3]:
# Average of two per-group values weighted by group size (510 and 616, total 1126).
# NOTE(review): 0.8886 and 0.9601 are presumably per-group accuracies from a
# different run or model than the cell computing with 0.8831 / 0.9706 -- confirm.
(510 * 0.8886 + 616 * 0.9601) / 1126

0.9277154529307283