In [1]:
import pandas as pd
# display dataframes
from IPython.display import display
import json
from Bio import SeqIO
def filter_ec(ec):
    """Drop incomplete EC numbers from one ';'-separated 'EC number' cell.

    Tokens containing '-' (partial EC numbers such as '3.2.-.-') are removed,
    surrounding whitespace is stripped, and empty tokens (e.g. produced by a
    trailing ';') are discarded so they never become spurious labels.

    Args:
        ec: the raw cell value; a string, or a null value (None/NaN).

    Returns:
        The input unchanged if it is null; ``None`` if no complete EC number
        remains (so a following ``dropna`` removes the row); otherwise the kept
        EC numbers joined with ';' and no spaces.
    """
    if pd.isnull(ec):
        return ec
    tokens = [token.strip() for token in ec.split(';')]
    kept = [token for token in tokens if token and '-' not in token]
    if not kept:
        return None
    return ';'.join(kept)

# Restrict the Swiss-Prot table to entries present in the length-filtered FASTA.
records = SeqIO.parse('../data/swissprot/uniprot_sprot_10_1022_cleaned.fasta', 'fasta')
sprot_10_1022_pids = [record.id for record in records]

sprot_data = pd.read_csv('../data/swissprot/swissprot.tsv', sep='\t')
print(f'before filter: {len(sprot_data)} rows')
sprot_data = sprot_data[sprot_data['Entry'].isin(sprot_10_1022_pids)]
print(f'after filter: {len(sprot_data)} rows')

# Keep only the columns used downstream; rows with any missing value are dropped.
ec_data = sprot_data[['Entry', 'EC number', 'Date of creation']]
ec_data = ec_data.dropna().reset_index(drop=True)
# Remove partial EC numbers; rows left with none become None and are dropped next.
ec_data['EC number'] = ec_data['EC number'].apply(filter_ec)
ec_data = ec_data.dropna().reset_index(drop=True)
# Later cells key dicts and embedding files by Entry, so duplicates would be merged silently.
assert ec_data['Entry'].is_unique, 'duplicate Entry accessions in the EC table'
display(ec_data)
ec_data.to_csv('../data/ec_new/swissprot_ec_complete.csv', index=False)

# Temporal split on 'Date of creation': the "before" set includes 2022-05-25 itself.
# NOTE: despite the "2022April" variable names, the cutoff date is 2022-05-25.
ec_data['Date of creation'] = pd.to_datetime(ec_data['Date of creation'])
creation_dates = ec_data['Date of creation']
is_before_cutoff = creation_dates <= '2022-05-25'
is_after_cutoff = creation_dates > '2022-05-25'
ec_data_before_2022April = ec_data.loc[is_before_cutoff].reset_index(drop=True)
ec_data_after_2022April = ec_data.loc[is_after_cutoff].reset_index(drop=True)
display(ec_data_before_2022April, ec_data_after_2022April)
ec_data_before_2022April.to_csv('../data/ec_new/swissprot_ec_complete_by2022-05-25.csv', index=False)
ec_data_after_2022April.to_csv('../data/ec_new/swissprot_ec_complete_after_2022-05-25.csv', index=False)

# Vocabulary of all EC labels on proteins created on or before the cutoff.
# Sorted so the JSON is byte-identical across runs: str hashes are randomized per
# interpreter process, so list(set(...)) order can change between runs, which
# matters if this list's order is ever used to index labels.
ec_nums_before_2022April = sorted(
    {ec for ecs in ec_data_before_2022April['EC number'] for ec in ecs.split(';')}
)
with open('../data/ec_new/ec_list_by2022-05-25.json', 'w') as f:
    json.dump(ec_nums_before_2022April, f)

before filter: (570830, 8) rows
after filter: (551965, 8) rows


Unnamed: 0,Entry,EC number,Date of creation
0,A0A009IHW8,3.2.2.6,2020-02-26
1,A0A023I7E1,3.2.1.39,2022-12-14
2,A0A024SC78,3.1.1.74,2022-05-25
3,A0A024SH76,3.2.1.91,2017-08-30
4,A0A044RE18,3.4.21.75,2017-05-10
...,...,...,...
228546,Q6HX62,3.5.4.2,2004-08-16
228547,Q6L032,3.5.4.2,2004-08-16
228548,Q85055,2.7.7.48,2011-01-11
228549,Q94MV8,3.6.1.12,2004-08-16


Unnamed: 0,Entry,EC number,Date of creation
0,A0A009IHW8,3.2.2.6,2020-02-26
1,A0A024SC78,3.1.1.74,2022-05-25
2,A0A024SH76,3.2.1.91,2017-08-30
3,A0A044RE18,3.4.21.75,2017-05-10
4,A0A059TC02,1.2.1.44,2020-12-02
...,...,...,...
227834,Q6HX62,3.5.4.2,2004-08-16
227835,Q6L032,3.5.4.2,2004-08-16
227836,Q85055,2.7.7.48,2011-01-11
227837,Q94MV8,3.6.1.12,2004-08-16


Unnamed: 0,Entry,EC number,Date of creation
0,A0A023I7E1,3.2.1.39,2022-12-14
1,A0A061AE05,2.7.1.25;2.7.7.4,2023-02-22
2,A0A072VDF2,1.2.1.44,2023-02-22
3,A0A0D2Y5A7,2.3.1.12,2023-05-03
4,A0A0H2Z7X0,2.7.7.65,2023-06-28
...,...,...,...
707,Q2G5J4,1.14.14.181,2023-11-08
708,Q4J8K8,3.6.4.12,2024-01-24
709,Q5SKU3,4.2.1.17,2023-09-13
710,Q5VVH2,5.2.1.8,2023-02-22


In [2]:
import pandas as pd
import json

def all_in_list(ec_string, valid_list):
    """Return True iff every ';'-separated EC number in ec_string is in valid_list.

    Args:
        ec_string: EC numbers joined by ';' (e.g. '2.7.1.25;2.7.7.4').
        valid_list: any container supporting ``in`` (a list or a set).

    Returns:
        True when all tokens are found, False at the first missing one.
    """
    for token in ec_string.split(';'):
        if token not in valid_list:
            return False
    return True

with open('../data/ec_new/ec_list_by2022-05-25.json', 'r') as f:
    ec_nums_before_2022April = json.load(f)
ec_data_after_2022April = pd.read_csv('../data/ec_new/swissprot_ec_complete_after_2022-05-25.csv')
# Keep only post-cutoff proteins whose EC labels all occur in the pre-cutoff vocabulary,
# so the held-out set has no labels unseen before the cutoff.
# A set makes each membership test O(1) instead of a scan over ~5k labels.
known_ecs = set(ec_nums_before_2022April)
ec_data_after_2022April_filtered = ec_data_after_2022April[
    ec_data_after_2022April['EC number'].apply(all_in_list, valid_list=known_ecs)
].reset_index(drop=True)
display(ec_data_after_2022April_filtered)
ec_data_after_2022April_filtered.to_csv('../data/ec_new/swissprot_ec_complete_after_2022-05-25_filtered.csv', index=False)

Unnamed: 0,Entry,EC number,Date of creation
0,A0A023I7E1,3.2.1.39,2022-12-14
1,A0A061AE05,2.7.1.25;2.7.7.4,2023-02-22
2,A0A072VDF2,1.2.1.44,2023-02-22
3,A0A0D2Y5A7,2.3.1.12,2023-05-03
4,A0A0H2Z7X0,2.7.7.65,2023-06-28
...,...,...,...
566,P9WES7,4.4.1.17,2022-10-12
567,Q4J8K8,3.6.4.12,2024-01-24
568,Q5SKU3,4.2.1.17,2023-09-13
569,Q5VVH2,5.2.1.8,2023-02-22


In [3]:
from tqdm.auto import tqdm
import os
import torch

def csv2pt(csv_path, pt_path, data_dir, layer=33):
    """Bundle per-protein embeddings and EC labels listed in a CSV into one .pt file.

    For each row of ``csv_path``, loads ``<data_dir>/<Entry>.pt`` and keeps
    ``['mean_representations'][layer]`` as the embedding. The saved object is
    ``{Entry: {'embedding': <tensor>, 'ec': [ec, ...]}}``.

    Args:
        csv_path: CSV with an 'Entry' column and a ';'-separated 'EC number' column.
        pt_path: output path passed to ``torch.save``.
        data_dir: directory holding one ``<Entry>.pt`` file per protein.
        layer: key into 'mean_representations'. Defaults to 33, presumably the
            final layer of the ESM-1b model named in data_dir (TODO confirm).

    Raises:
        ValueError: if 'Entry' has duplicates, which would otherwise silently
            collapse into a single dict key.
    """
    df = pd.read_csv(csv_path)
    duplicated = df['Entry'].duplicated()
    if duplicated.any():
        dupes = df.loc[duplicated, 'Entry'].unique().tolist()
        raise ValueError(f'duplicate Entry values in {csv_path}: {dupes[:10]}')
    data = {}
    for pid, ec in tqdm(zip(df['Entry'], df['EC number']), total=len(df)):
        # map_location='cpu' keeps loading working regardless of the device the
        # tensors were saved from.
        saved = torch.load(os.path.join(data_dir, f'{pid}.pt'), map_location='cpu')
        data[pid] = {'embedding': saved['mean_representations'][layer], 'ec': ec.split(';')}
    print(f'data: {len(data)}')
    torch.save(data, pt_path)
    print(f'saved to {pt_path}')

# Build one .pt bundle per split (pre-cutoff training data, filtered post-cutoff test data).
data_dir = '../data/sprot_esm1b_emb_per_residue'
split_paths = [
    ('../data/ec_new/swissprot_ec_complete_by2022-05-25.csv',
     '../data/ec_new/swissprot_ec_complete_by2022-05-25.pt'),
    ('../data/ec_new/swissprot_ec_complete_after_2022-05-25_filtered.csv',
     '../data/ec_new/swissprot_ec_complete_after_2022-05-25_filtered.pt'),
]
for split_csv, split_pt in split_paths:
    csv2pt(split_csv, split_pt, data_dir)

  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 227839/227839 [08:07<00:00, 466.98it/s]


data: 227839
saved to ../data/ec_new/swissprot_ec_complete_by2022-05-25.pt


100%|██████████| 571/571 [00:02<00:00, 269.99it/s]


data: 571
saved to ../data/ec_new/swissprot_ec_complete_after_2022-05-25_filtered.pt


In [2]:
# Compare EC vocabularies before and after the 2022-05-25 cutoff (post-cutoff data
# here is unfiltered) and list the EC numbers that appear only after the cutoff.
ec_nums_before_2022April = sorted(
    {ec for ecs in ec_data_before_2022April['EC number'] for ec in ecs.split(';')}
)
print(f'num_ecs_before_2022April: {len(ec_nums_before_2022April)}')

ec_nums_after_2022April = sorted(
    {ec for ecs in ec_data_after_2022April['EC number'] for ec in ecs.split(';')}
)
print(f'num_ecs_after_2022April: {len(ec_nums_after_2022April)}')

# Sorted (lexicographically) so the printed list is stable across runs; raw set
# order of str is not.
diff_ecs = sorted(set(ec_nums_after_2022April) - set(ec_nums_before_2022April))
print(f'num_diff_ecs: {len(diff_ecs)}')
print(diff_ecs)

num_ecs_before_2022April: 5416
num_ecs_after_2022April: 513
num_diff_ecs: 105
['1.14.13.248', '1.14.11.71', '1.14.19.58', '2.1.1.380', '1.1.1.429', '4.2.1.180', '1.4.2.2', '4.3.99.5', '3.5.3.13', '1.14.13.246', '2.1.1.82', '2.4.1.358', '3.1.1.120', '2.1.1.76', '5.3.1.37', '2.1.1.383', '1.14.13.251', '4.2.3.160', '1.5.3.5', '1.14.15.38', '2.1.1.381', '1.13.11.89', '3.5.1.20', '2.2.1.14', '1.3.3.8', '4.2.3.177', '1.1.1.417', '4.8.1.7', '1.12.1.4', '2.5.1.102', '3.4.21.121', '6.7.1.1', '2.4.1.341', '6.2.1.60', '1.20.99.1', '4.1.1.115', '3.1.1.110', '1.17.99.8', '1.1.1.425', '1.14.13.250', '1.14.19.60', '3.1.6.19', '1.14.14.174', '3.5.1.68', '1.4.1.19', '4.2.3.200', '1.14.13.187', '1.14.99.55', '1.14.14.181', '5.4.99.64', '4.2.1.177', '5.1.1.14', '1.5.7.3', '2.5.1.38', '1.13.11.88', '4.8.1.8', '4.4.1.41', '1.14.19.48', '3.5.1.84', '1.1.1.432', '1.13.11.59', '1.1.1.435', '2.2.1.15', '3.2.1.213', '3.1.3.33', '2.3.1.145', '1.14.14.171', '1.14.13.249', '4.1.99.27', '1.14.11.72', '2.1.2.7', '3.

In [16]:
# Compare our pre-cutoff EC vocabulary with the one in CLEAN's split100 data.
ec_nums = sorted({ec for ecs in ec_data_before_2022April['EC number'] for ec in ecs.split(';')})
print(f'num_ecs: {len(ec_nums)}')

# split100.csv is tab-separated despite its .csv extension.
df_CLEAN = pd.read_csv('/work/jiaqi/CLEAN_original/app/data/split100.csv', sep='\t')
ec_nums_CLEAN = sorted({ec for ecs in df_CLEAN['EC number'] for ec in ecs.split(';')})

# EC numbers present in our data but absent from split100; sorted for stable output.
diff_ecs = sorted(set(ec_nums) - set(ec_nums_CLEAN))
print(f'num_diff_ecs: {len(diff_ecs)}')
print(diff_ecs)

# Example annotation mismatch: P82604 is 4.8.1.4 (created 2003-09-19) in our data
# but 4.99.1.7 in CLEAN split100.
# TODO: filter out rows whose EC numbers include any of diff_ecs.

num_ecs: 5416
num_diff_ecs: 214
{'4.8.1.4', '5.1.1.11', '2.4.3.6', '2.3.1.237', '3.4.24.13', '2.3.1.306', '3.1.3.35', '6.2.1.68', '2.4.1.183', '1.14.14.22', '2.4.99.24', '3.6.1.75', '2.3.1.94', '7.2.2.18', '1.14.15.36', '2.1.1.386', '1.3.99.41', '2.3.1.38', '1.17.3.2', '2.4.3.9', '4.8.1.6', '1.14.13.200', '2.3.1.62', '3.4.21.122', '3.4.22.66', '4.98.1.1', '2.4.1.333', '4.3.2.11', '1.14.11.82', '1.1.1.78', '1.14.19.79', '7.2.2.22', '3.1.1.71', '4.1.1.93', '1.1.1.104', '3.4.23.47', '1.7.1.9', '1.14.15.26', '1.2.1.51', '2.7.7.108', '3.4.24.79', '1.1.1.71', '2.1.1.379', '6.1.1.24', '2.1.1.376', '2.3.3.21', '2.7.1.237', '3.2.1.97', '6.3.2.50', '1.5.3.24', '2.7.1.52', '2.4.99.23', '3.7.1.28', '3.2.1.187', '1.14.14.180', '2.6.1.28', '7.2.3.1', '1.14.14.184', '5.6.2.3', '2.7.1.93', '1.14.11.68', '6.2.1.67', '2.3.1.308', '3.1.26.13', '2.7.1.154', '3.1.3.52', '3.4.21.98', '3.4.22.47', '3.4.24.61', '6.3.2.40', '1.10.3.17', '3.2.2.7', '1.13.11.92', '2.3.1.165', '4.6.1.22', '3.4.19.6', '2.7.1.235',

In [7]:
# Sanity check: every EC label in the filtered post-cutoff set must already occur in
# the pre-cutoff vocabulary. Fail loudly instead of relying on eyeballing the output.
ec_nums_after_filtered = sorted(
    {ec for ecs in ec_data_after_2022April_filtered['EC number'] for ec in ecs.split(';')}
)

diff = set(ec_nums_after_filtered) - set(ec_nums_before_2022April)
assert not diff, f'{len(diff)} EC labels in the filtered set are unseen before the cutoff'
diff

set()

In [14]:
import pandas as pd

# Sequence-length range of CLEAN's split100 data, and how many sequences are shorter
# than 10 residues (our cleaned Swiss-Prot FASTA has a minimum length of 10).
df = pd.read_csv('/work/jiaqi/CLEAN_original/app/data/split100.csv', sep='\t')
seq_lens = df['Sequence'].str.len()
print(f'min seq len: {seq_lens.min()}, max seq len: {seq_lens.max()}')
# Strictly less than 10 (the old name "num_le10" suggested <= 10).
num_lt10 = int((seq_lens < 10).sum())
print(f'num seq len < 10: {num_lt10}')

min seq len: 4, max seq len: 1022
num seq len < 10: 43


In [13]:
# Sequence-length range of our cleaned Swiss-Prot FASTA, for comparison with split100.
recs = SeqIO.parse('/work/jiaqi/ProtRepr/data/swissprot/uniprot_sprot_10_1022_cleaned.fasta', 'fasta')
seq_lens = []
for rec in recs:
    seq_lens.append(len(rec.seq))
shortest, longest = min(seq_lens), max(seq_lens)
print(f'min seq len: {shortest}, max seq len: {longest}')

min seq len: 10, max seq len: 1022


In [18]:
# NOTE(review): this JSON is not written anywhere in this notebook. The vocabulary
# cell writes ec_list_by2022-05-25.json (5416 labels), while this file holds 5321.
# Confirm which label list training actually uses.
label_list_path = '../data/ec_new/ec_list_before_2022April.json'
with open(label_list_path, mode='r') as fh:
    label_list = json.load(fh)
len(label_list)

5321

In [4]:
import torch
import random

# Seeded random 90/10 train/validation split of the pre-cutoff proteins.
random.seed(42)
data = torch.load('../data/ec_new/swissprot_ec_complete_by2022-05-25.pt')
print(f'data: {len(data)}')
pids = list(data)
random.shuffle(pids)
n_train = int(0.9 * len(pids))
train_pids, val_pids = pids[:n_train], pids[n_train:]
# NOTE: the name train_data is reused for a DataFrame in the CLEAN export cell.
train_data = {pid: data[pid] for pid in train_pids}
val_data = {pid: data[pid] for pid in val_pids}
print(f'train: {len(train_data)}, val: {len(val_data)}')
torch.save(train_data, '../data/ec_new/swissprot_ec_complete_by2022-05-25_train.pt')
torch.save(val_data, '../data/ec_new/swissprot_ec_complete_by2022-05-25_val.pt')

data: 227839
train: 205055, val: 22784


In [5]:
import pandas as pd
from Bio import SeqIO

# Export train/test splits for CLEAN as tab-separated files with columns
# Entry, EC number, Sequence (the creation date is dropped).
train_data = pd.read_csv('../data/ec_new/swissprot_ec_complete_by2022-05-25.csv').drop(columns=['Date of creation'])
test_data = pd.read_csv('../data/ec_new/swissprot_ec_complete_after_2022-05-25_filtered.csv').drop(columns=['Date of creation'])
records = SeqIO.parse('../data/swissprot/uniprot_sprot_10_1022_cleaned.fasta', 'fasta')
pid2seq = {record.id: str(record.seq) for record in records}
# __getitem__ raises KeyError for an Entry missing from the FASTA rather than filling NaN.
train_data['Sequence'] = train_data['Entry'].map(pid2seq.__getitem__)
test_data['Sequence'] = test_data['Entry'].map(pid2seq.__getitem__)
train_data.to_csv('/work/jiaqi/CLEAN_original/app/data/swissprot_ec_complete_by2022-05-25.csv', index=False, sep='\t')
test_data.to_csv('/work/jiaqi/CLEAN_original/app/data/swissprot_ec_complete_after_2022-05-25_filtered.csv', index=False, sep='\t')


In [6]:
# Write the test split as FASTA for CLEAN inference. Every record (including the last)
# ends with a newline, so the file is newline-terminated and safe to concatenate.
test_pids = test_data['Entry'].tolist()
test_seqs = test_data['Sequence'].tolist()
with open('/work/jiaqi/CLEAN_original/app/data/inputs/swissprot_ec_complete_after_2022-05-25_filtered.fasta', 'w') as f:
    f.writelines(f'>{pid}\n{seq}\n' for pid, seq in zip(test_pids, test_seqs))


In [7]:
import shutil

# Copy embedding files that CLEAN's esm_data directory does not yet have.
# shutil.copy raises on failure (e.g. a missing source file), so num_cp counts only
# successful copies. It also avoids building a shell command from interpolated strings.
src_emb_dir = '/work/jiaqi/ProtRepr/data/sprot_esm1b_emb_per_residue'
dst_emb_dir = '/work/jiaqi/CLEAN_original/app/data/esm_data'
train_pids = train_data['Entry'].tolist()
test_pids = test_data['Entry'].tolist()
pids = train_pids + test_pids
num_cp = 0
for pid in pids:
    dst_path = os.path.join(dst_emb_dir, f'{pid}.pt')
    if not os.path.exists(dst_path):
        shutil.copy(os.path.join(src_emb_dir, f'{pid}.pt'), dst_path)
        num_cp += 1
print(f'num_cp: {num_cp}')

num_cp: 27
