In [7]:
import torch
import pandas as pd
from Bio import SeqIO

# Write a FASTA file with the sequences of every entry in the EC table that is
# not already a key of the cached .pt dictionary.
# NOTE(review): torch.load unpickles its input; only point it at trusted files.
data = torch.load('../data/ec/sprot_10_1022_esm1b_t33_ec_above_10_single_label.pt')
cached_pids = list(data.keys())
df = pd.read_csv('../data/ec/swissprot_complete_ec.csv')
all_pids = df['Entry'].tolist()
# Keep the CSV order (de-duplicated) rather than iterating a set of strings,
# whose order can change from one interpreter run to the next.
cached_pid_set = set(cached_pids)
uncached_pids = list(dict.fromkeys(pid for pid in all_pids if pid not in cached_pid_set))
print(f'Cached: {len(cached_pids)}, Uncached: {len(uncached_pids)}')
records = list(SeqIO.parse('../data/swissprot/uniprot_sprot_10_1022.fasta', 'fasta'))
# The lookup key is the second '|'-separated field of the record id
# (assumes ids shaped like 'db|ACCESSION|name' -- TODO confirm).
pid2seq = {record.id.split('|')[1]: str(record.seq) for record in records}
# Entries without a sequence in the FASTA file cannot be written; report how
# many are skipped instead of dropping them silently.
missing_pids = [pid for pid in uncached_pids if pid not in pid2seq]
if missing_pids:
    print(f'Warning: {len(missing_pids)} uncached entries have no sequence in the FASTA file and were skipped')
uncached_pid_seq = [f'>{pid}\n{pid2seq[pid]}\n' for pid in uncached_pids if pid in pid2seq]
with open('../data/ec/multilabel/sprot_10_1022_uncached.fasta', 'w') as f:
    f.write(''.join(uncached_pid_seq))

Cached: 208885, Uncached: 27219


In [1]:
import torch
import pandas as pd
from Bio import SeqIO

# Merge the newly computed embeddings (data2) into the previously cached
# dictionary (data1), attaching the EC labels from the CSV, and save the result.
df = pd.read_csv('../data/ec/swissprot_complete_ec.csv')
pids = df['Entry'].tolist()
ecs = df['EC number'].tolist()
# Split the ';'-separated EC field into a list of labels. Each piece is
# stripped and empty pieces are dropped, so a '; ' separator or a trailing ';'
# cannot produce labels such as ' 1.1.1.1' or ''.
pid2ecs = {
    pid: [part.strip() for part in ec.split(';') if part.strip()]
    for pid, ec in zip(pids, ecs)
}
# NOTE(review): torch.load unpickles its input; only point it at trusted files.
data1 = torch.load('../data/ec/sprot_10_1022_esm1b_t33_ec_above_10_single_label.pt')
print(len(data1))
data2 = torch.load('../data/ec/multilabel/esm1b_t33_sprot_10_1022_uncached.pt')
print(len(data2))
# data1 is extended in place; the next cell reads the merged dictionary under
# this same name.
for pid, emb in data2.items():
    data1[pid] = {'embedding': emb, 'ec': pid2ecs[pid]}
print(len(data1))
torch.save(data1, '../data/ec/multilabel/sprot_10_1022_esm1b_t33.pt')

208885
19666
228551


In [2]:
# Tally how many entries of data1 carry each EC label, then print the mean,
# maximum and minimum of those per-label counts.
ec2occurance = {}
for entry in data1.values():
    for ec_number in entry['ec']:
        if ec_number in ec2occurance:
            ec2occurance[ec_number] += 1
        else:
            ec2occurance[ec_number] = 1
occurances = list(ec2occurance.values())
mean_count = sum(occurances) / len(occurances)
max_count = max(occurances)
min_count = min(occurances)
print(f'Mean: {mean_count}, Max: {max_count}, Min: {min_count}')

Mean: 45.691654078549846, Max: 2176, Min: 1


In [4]:
import torch
import random

# Shuffle all entries with a fixed seed, split them 80/20 into train_val and
# test, save both, then save a filtered test set that keeps only entries whose
# every EC label also occurs somewhere in train_val.
random.seed(0)
# NOTE(review): torch.load unpickles its input; only point it at trusted files.
data = torch.load('../data/ec/multilabel/sprot_10_1022_esm1b_t33.pt')
print(len(data))
pids = list(data.keys())
random.shuffle(pids)
# Compute the cut point once so both slices are guaranteed to agree.
split_idx = int(0.8 * len(pids))
train_val_pids = pids[:split_idx]
test_pids = pids[split_idx:]
train_val_data = {pid: data[pid] for pid in train_val_pids}
test_data = {pid: data[pid] for pid in test_pids}
print(len(train_val_data), len(test_data))
torch.save(train_val_data, '../data/ec/multilabel/sprot_10_1022_esm1b_t33_train_val.pt')
torch.save(test_data, '../data/ec/multilabel/sprot_10_1022_esm1b_t33_test.pt')
# Collect the train_val labels in a set: the membership test below runs once
# per test label, and a list lookup would scan every label each time.
train_val_ec_set = set()
for v in train_val_data.values():
    train_val_ec_set.update(v['ec'])
# Kept as a list under the original name for any later cell that expects it.
train_val_ecs = list(train_val_ec_set)
filtered_test_data = {}
for pid, v in test_data.items():
    ecs = v['ec']
    if all(ec in train_val_ec_set for ec in ecs):
        filtered_test_data[pid] = v
print(len(filtered_test_data))
torch.save(filtered_test_data, '../data/ec/multilabel/sprot_10_1022_esm1b_t33_test_filtered.pt')

228551
182840 45711
45362


In [1]:
import torch
import random

# Shuffle the train_val entries with a fixed seed and cut them at 87.5% into a
# training dictionary and a validation dictionary, then save both.
random.seed(42)
data = torch.load('../data/ec/multilabel/sprot_10_1022_esm1b_t33_train_val.pt')
pids = list(data)
random.shuffle(pids)
n_train = int(0.875 * len(pids))
train_pids, val_pids = pids[:n_train], pids[n_train:]
train_data = {}
for pid in train_pids:
    train_data[pid] = data[pid]
val_data = {}
for pid in val_pids:
    val_data[pid] = data[pid]
print(len(train_data), len(val_data))
torch.save(train_data, '../data/ec/multilabel/sprot_10_1022_esm1b_t33_train.pt')
torch.save(val_data, '../data/ec/multilabel/sprot_10_1022_esm1b_t33_val.pt')


159985 22855


In [1]:
import torch
import random,json

# Build the label vocabulary from every EC label found in train_val, check that
# the filtered test set uses no label outside it, and save it as JSON.
random.seed(42)  # kept from the original cell; nothing below draws random numbers
# NOTE(review): torch.load unpickles its input; only point it at trusted files.
data = torch.load('../data/ec/multilabel/sprot_10_1022_esm1b_t33_train_val.pt')
label_set = set()
for k, v in data.items():
    label_set.update(v['ec'])
# Sort the labels: the iteration order of a set of strings can differ between
# interpreter runs, which would make the saved label order (and any
# label-to-index mapping built from it) irreproducible.
# NOTE(review): this reorders the list relative to JSON files written by the
# previous version of this cell; regenerate anything indexed by the old order.
label_list = sorted(label_set)
print(f'num_labels: {len(label_list)}')
test_data = torch.load('../data/ec/multilabel/sprot_10_1022_esm1b_t33_test_filtered.pt')
for k, v in test_data.items():
    for ec in v['ec']:
        # Checked against the set so each lookup does not scan the whole list.
        assert ec in label_set, f'test label {ec!r} of entry {k!r} is missing from the train_val labels'
with open('../data/ec/multilabel/label_list_train_val.json', 'w') as f:
    json.dump(label_list, f)

num_labels: 4966


In [2]:
# Gather the distinct EC labels used by the entries of test_data and print how
# many there are.
test_labels = list({ec for entry in test_data.values() for ec in entry['ec']})
n_test_labels = len(test_labels)
print(f'test labels: {n_test_labels}')

test labels: 2880


In [1]:
import pandas as pd

# Count the distinct EC numbers in the tab-separated split100 table. The
# 'EC number' column is treated as ';'-separated values: all rows are joined
# into one string and split again before de-duplicating.
# NOTE(review): the path is an absolute, machine-specific location.
df = pd.read_csv('/work/jiaqi/CLEAN_original/app/data/split100.csv', sep='\t')
joined_ecs = ';'.join(df['EC number'].tolist())
distinct_ecs = set(joined_ecs.split(';'))
ec_nums = list(distinct_ecs)
print(f'num_ecs: {len(ec_nums)}')

num_ecs: 5242
