In [1]:
import pandas as pd
from Bio import SeqIO
import re
import json

def extract_pf_codes(s):
    """Collect every Pfam-style accession ("PF" followed by digits) found in ``s``.

    Matches are returned in order of appearance and joined with ';'. Any other
    text (separators, whitespace, a trailing ';') is discarded. Duplicates are
    kept. Returns '' when ``s`` contains no match.
    """
    found = [m.group(0) for m in re.finditer(r'PF\d+', s)]
    return ';'.join(found)

# Temporal split cutoff. Proteins created ON this date go to the "by" side (note the
# `<=` below), matching the "_by2022-05-25" file names used throughout the notebook.
# The "_2022April" variable names are historical; the actual cutoff is 2022-05-25.
CUTOFF_DATE = '2022-05-25'

records = SeqIO.parse('../data/swissprot/uniprot_sprot_10_1022_cleaned.fasta', 'fasta')
sprot_10_1022_pids = [record.id for record in records]

sp_df = pd.read_csv('../data/swissprot/swissprot.tsv', sep='\t')
print(len(sp_df))
pfam_data = sp_df[['Entry', 'Date of creation', 'Pfam']].dropna().reset_index(drop=True)
print(len(pfam_data))
# .copy(): the column assignments below must act on an independent frame rather than
# a boolean-filtered slice of the previous one (avoids SettingWithCopyWarning).
pfam_data = pfam_data[pfam_data['Entry'].isin(sprot_10_1022_pids)].copy()
print(len(pfam_data))
pfam_data['Pfam'] = pfam_data['Pfam'].apply(extract_pf_codes)
# A row with no PF accession ends up as ''. to_csv writes that as an empty field,
# read_csv reads it back as NaN, and the later `.split(';')` calls would crash on it.
pfam_data = pfam_data[pfam_data['Pfam'] != ''].copy()
display(pfam_data)

# Split on Date of creation: on or before CUTOFF_DATE (cutoff day INCLUDED) vs. strictly after.
pfam_data['Date of creation'] = pd.to_datetime(pfam_data['Date of creation'])
pfam_data_before_2022April = pfam_data[pfam_data['Date of creation'] <= CUTOFF_DATE].reset_index(drop=True)
pfam_data_after_2022April = pfam_data[pfam_data['Date of creation'] > CUTOFF_DATE].reset_index(drop=True)
display(pfam_data_before_2022April)
display(pfam_data_after_2022April)
pfam_data_before_2022April.to_csv('../data/pfam_new/swissprot_pfam_by2022-05-25.csv', index=False)
pfam_data_after_2022April.to_csv('../data/pfam_new/swissprot_pfam_after_2022-05-25.csv', index=False)

# Unique Pfam labels seen on or before the cutoff. sorted() makes the saved JSON identical
# across runs (iteration order of a set of str depends on the interpreter's hash seed).
pfam_labels_before_2022April = sorted({
    label
    for label_string in pfam_data_before_2022April['Pfam']
    for label in label_string.split(';')
    if label
})
print(f'Number of unique pfam labels by 2022-05-25: {len(pfam_labels_before_2022April)}')
with open('../data/pfam_new/pfam_labels_by2022-05-25.json', 'w') as f:
    json.dump(pfam_labels_before_2022April, f)

570830
540554
523142


Unnamed: 0,Entry,Date of creation,Pfam
0,A0A009IHW8,2020-02-26,PF13676
1,A0A023I7E1,2022-12-14,PF17652;PF03639
3,A0A024SC78,2022-05-25,PF01083
4,A0A024SH76,2017-08-30,PF00734;PF01341
5,A0A026W182,2017-10-25,PF02949
...,...,...,...
540549,Q9ZVR3,2007-05-01,PF14299
540550,Q9ZW38,2007-04-03,PF01344
540551,Q9ZWC6,2007-01-23,PF00646;PF13516
540552,Q9ZWX5,2000-12-01,PF05322


Unnamed: 0,Entry,Date of creation,Pfam
0,A0A009IHW8,2020-02-26,PF13676
1,A0A024SC78,2022-05-25,PF01083
2,A0A024SH76,2017-08-30,PF00734;PF01341
3,A0A026W182,2017-10-25,PF02949
4,A0A044RE18,2017-05-10,PF01483;PF00082;PF16470
...,...,...,...
520631,Q9ZVR3,2007-05-01,PF14299
520632,Q9ZW38,2007-04-03,PF01344
520633,Q9ZWC6,2007-01-23,PF00646;PF13516
520634,Q9ZWX5,2000-12-01,PF05322


Unnamed: 0,Entry,Date of creation,Pfam
0,A0A023I7E1,2022-12-14,PF17652;PF03639
1,A0A061AE05,2023-02-22,PF01583;PF01747;PF14306
2,A0A072VDF2,2023-02-22,PF01370
3,A0A075D5I4,2023-06-28,PF13847
4,A0A075D657,2023-06-28,PF08241
...,...,...,...
2501,A0A8V8TPE2,2024-01-24,PF15288
2502,M1L9M3,2023-02-22,PF05796
2503,O53518,2022-10-12,PF02374;PF17886
2504,P71839,2022-08-03,PF13483


Number of unique pfam labels by 2022-05-25: 14723


In [2]:
import pandas as pd
import json

def all_in_list(label_string, valid_list):
    """Return True iff every ';'-separated label in ``label_string`` is in ``valid_list``.

    ``valid_list`` may be any container supporting ``in``; passing a set gives
    O(1) lookups. Note that '' splits to [''], so an empty string is accepted
    only if '' itself is a member of ``valid_list``.
    """
    for label in label_string.split(';'):
        if label not in valid_list:
            return False
    return True

test_df = pd.read_csv('../data/pfam_new/swissprot_pfam_after_2022-05-25.csv')
with open('../data/pfam_new/pfam_labels_by2022-05-25.json', 'r') as f:
    pfam_labels_before_2022April = json.load(f)
# The JSON stores a plain list; build a set once so each membership test is O(1)
# instead of a linear scan over ~15k labels.
known_pfam_labels = set(pfam_labels_before_2022April)

# Keep only post-cutoff proteins whose every Pfam label already occurs on or before the
# cutoff (presumably so every test label is a class seen in training — confirm).
keep_mask = test_df['Pfam'].apply(lambda x: all_in_list(x, known_pfam_labels))
test_df_filtered = test_df[keep_mask].reset_index(drop=True)
test_df_filtered.to_csv('../data/pfam_new/swissprot_pfam_after_2022-05-25_filtered.csv', index=False)
display(test_df_filtered)

# sorted() for a deterministic order (set iteration order of str depends on the hash seed).
test_labels = sorted({
    label
    for label_string in test_df_filtered['Pfam']
    for label in label_string.split(';')
})
diff = set(test_labels) - known_pfam_labels
print(f'Number of unique pfam labels after 2022-05-25: {len(test_labels)}')
print(f'Number of unique pfam labels after 2022-05-25 that are not in the labels before 2022-05-25: {len(diff)}')
# Empty by construction of keep_mask above; fail loudly if that invariant ever breaks.
assert not diff, f'{len(diff)} filtered test labels are missing from the pre-cutoff label set'

Unnamed: 0,Entry,Date of creation,Pfam
0,A0A023I7E1,2022-12-14,PF17652;PF03639
1,A0A061AE05,2023-02-22,PF01583;PF01747;PF14306
2,A0A072VDF2,2023-02-22,PF01370
3,A0A075D5I4,2023-06-28,PF13847
4,A0A075D657,2023-06-28,PF08241
...,...,...,...
2400,A0A8V8TPE2,2024-01-24,PF15288
2401,M1L9M3,2023-02-22,PF05796
2402,O53518,2022-10-12,PF02374;PF17886
2403,P71839,2022-08-03,PF13483


Number of unique pfam labels after 2022-05-25: 1446
Number of unique pfam labels after 2022-05-25 that are not in the labels before 2022-05-25: 0


In [3]:
from tqdm.auto import tqdm
import os
import torch

def csv2pt(csv_path, pt_path, data_dir, layer=33):
    """Bundle per-protein embeddings and Pfam labels from a CSV split into one .pt file.

    Args:
        csv_path: CSV with at least an 'Entry' (protein id) column and a 'Pfam'
            column of ';'-separated accessions.
        pt_path: destination passed to ``torch.save``.
        data_dir: directory holding one ``{Entry}.pt`` file per protein. Each file
            is expected to load as a dict whose ``'mean_representations'`` entry is
            indexed by ``layer``.
        layer: key into ``'mean_representations'``. Defaults to 33, the value that
            was previously hard-coded (presumably the last ESM-1b layer, given the
            embedding directory name — confirm).

    Saves ``{pid: {'embedding': <loaded tensor>, 'pfam': [label, ...]}}`` to
    ``pt_path`` and prints the entry count and destination.

    Raises:
        ValueError: if a row's 'Pfam' field is missing or empty (read_csv gives NaN).
    """
    df = pd.read_csv(csv_path)
    data = {}
    # Iterate the two columns directly; `total` lets tqdm show progress without
    # materialising the zipped pairs into a list first.
    for pid, pfam_string in tqdm(zip(df['Entry'], df['Pfam']), total=len(df)):
        if not isinstance(pfam_string, str):
            raise ValueError(f'{pid}: missing Pfam labels in {csv_path}')
        emb = torch.load(os.path.join(data_dir, f'{pid}.pt'))['mean_representations'][layer]
        data[pid] = {'embedding': emb, 'pfam': pfam_string.split(';')}
    print(f'data: {len(data)}')
    torch.save(data, pt_path)
    print(f'saved to {pt_path}')

# Convert both temporal splits (the pre-cutoff train/val pool and the filtered
# post-cutoff set) using the same embedding directory.
data_dir = '../data/sprot_esm1b_emb_per_residue'
split_paths = [
    ('../data/pfam_new/swissprot_pfam_by2022-05-25.csv', '../data/pfam_new/swissprot_pfam_by2022-05-25.pt'),
    ('../data/pfam_new/swissprot_pfam_after_2022-05-25_filtered.csv', '../data/pfam_new/swissprot_pfam_after_2022-05-25_filtered.pt'),
]
for csv_path, pt_path in split_paths:
    csv2pt(csv_path, pt_path, data_dir)

  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 520636/520636 [13:52<00:00, 625.17it/s] 


data: 520636
saved to ../data/pfam_new/swissprot_pfam_by2022-05-25.pt


100%|██████████| 2405/2405 [00:01<00:00, 1223.74it/s]


data: 2405
saved to ../data/pfam_new/swissprot_pfam_after_2022-05-25_filtered.pt


In [4]:
import torch
import random

# Seeded 90/10 train/validation split of the pre-cutoff proteins. The shuffle starts
# from the dict's insertion order, so the split is reproducible for a given input file.
random.seed(42)
data = torch.load('../data/pfam_new/swissprot_pfam_by2022-05-25.pt')
pids = list(data)
random.shuffle(pids)
n_train = int(0.9 * len(pids))
train_pids, val_pids = pids[:n_train], pids[n_train:]
train_data = {pid: data[pid] for pid in train_pids}
val_data = {pid: data[pid] for pid in val_pids}
for split_data, out_path in (
    (train_data, '../data/pfam_new/swissprot_pfam_by2022-05-25_train.pt'),
    (val_data, '../data/pfam_new/swissprot_pfam_by2022-05-25_val.pt'),
):
    torch.save(split_data, out_path)

In [5]:
# NOTE(review): scratch arithmetic; the constants are not explained anywhere in the
# notebook (the expression evaluates to 109593240). 14717 is close to, but not equal to,
# the 14723 pre-cutoff labels printed by the first cell. TODO: document what these
# numbers represent, or delete this cell.
330362957 - 15001 * 14717

109593240