In [2]:
import pandas as pd

all_data = pd.read_csv('../data/swissprot/swissprot.tsv', sep='\t')
# display(all_data)
pfam_data = all_data[['Entry', 'Pfam']]
pfam_data = pfam_data.dropna()
display(pfam_data)

Unnamed: 0,Entry,Pfam
0,A0A009IHW8,PF13676; TIR_2; 1.;
1,A0A023I7E1,"PF17652; Glyco_hydro81C; 1.;""PF03639; Glyco_hy..."
2,A0A024B7W1,"PF20907; Flav_NS3-hel_C; 1.;""PF01003; Flavi_ca..."
3,A0A024SC78,PF01083; Cutinase; 1.;
4,A0A024SH76,"PF00734; CBM_1; 1.;""PF01341; Glyco_hydro_6; 1."";"
...,...,...
570807,Q9ZVR3,PF14299; PP2; 1.;
570808,Q9ZW38,PF01344; Kelch_1; 2.;
570809,Q9ZWC6,"PF00646; F-box; 1.;""PF13516; LRR_6; 1."";"
570810,Q9ZWX5,PF05322; NinE; 1.;


In [6]:
import re

def find_pfam(s):
    # Define the pattern: 'PF' followed by one or more digits
    pattern = r'PF\d+'
    # Find all non-overlapping matches of the pattern in the string
    pfam_list = re.findall(pattern, s)
    return ';'.join(pfam_list)

pfam_data['Pfam'] = pfam_data['Pfam'].apply(find_pfam)
display(pfam_data)

Unnamed: 0,Entry,Pfam
0,A0A009IHW8,PF13676
1,A0A023I7E1,PF17652;PF03639
2,A0A024B7W1,PF20907;PF01003;PF07652;PF21659;PF02832;PF0086...
3,A0A024SC78,PF01083
4,A0A024SH76,PF00734;PF01341
...,...,...
570807,Q9ZVR3,PF14299
570808,Q9ZW38,PF01344
570809,Q9ZWC6,PF00646;PF13516
570810,Q9ZWX5,PF05322


In [13]:
pfam_data_single_label = pfam_data[~pfam_data['Pfam'].str.contains(';')]
display(pfam_data_single_label)
pfam_data_single_label.to_csv('../data/pfam/sprot_pfam_single_label.csv', index=False)
pfam_labels = pfam_data_single_label['Pfam'].tolist()
pfam_list = []
for label in pfam_labels:
    pfam_list.extend(label.split(';'))
pfam_list = list(set(pfam_list))
print(len(pfam_list))
pfam2occurence = {}
for labels in pfam_labels:
    for label in labels.split(';'):
        if label in pfam2occurence:
            pfam2occurence[label] += 1
        else:
            pfam2occurence[label] = 1
pfam_data_single_label_above10 = pfam_data_single_label[pfam_data_single_label['Pfam'].apply(lambda x: pfam2occurence[x] > 10)]
display(pfam_data_single_label_above10)

Unnamed: 0,Entry,Pfam
0,A0A009IHW8,PF13676
3,A0A024SC78,PF01083
5,A0A026W182,PF02949
7,A0A059TC02,PF01370
8,A0A060A682,PF10699
...,...,...
570806,Q9ZVR1,PF14299
570807,Q9ZVR3,PF14299
570808,Q9ZW38,PF01344
570810,Q9ZWX5,PF05322


10409


Unnamed: 0,Entry,Pfam
0,A0A009IHW8,PF13676
3,A0A024SC78,PF01083
5,A0A026W182,PF02949
7,A0A059TC02,PF01370
11,A0A061I403,PF02661
...,...,...
570800,Q9ZVI6,PF04525
570801,Q9ZVJ2,PF01871
570806,Q9ZVR1,PF14299
570807,Q9ZVR3,PF14299


In [1]:
import pandas as pd
import torch
from tqdm.auto import tqdm

pfam_data = pd.read_csv('../data/pfam/sprot_pfam_single_label.csv')
pfam2occurence = {}
pids, labels = pfam_data['Entry'].tolist(), pfam_data['Pfam'].tolist()
for pid, label in zip(pids, labels):
    if label not in pfam2occurence:
        pfam2occurence[label] = 0
    pfam2occurence[label] += 1
# filter out rows with more than 10 occurence
pfam_data_filtered = pfam_data[pfam_data['Pfam'].apply(lambda x: pfam2occurence[x] > 10)]

pid2embeddings = torch.load('../data/embeddings/sprot_10_1022_esm2_t33.pt')
print(len(pid2embeddings))
pid_with_emb = list(pid2embeddings.keys())
pid2pfam = {pid: pfam for pid, pfam in zip(pfam_data_filtered['Entry'].tolist(), pfam_data_filtered['Pfam'].tolist())}
data = {}
for pid, pfam in tqdm(pid2pfam.items()):
    try:
        data[pid] = {'embedding': pid2embeddings[pid], 'pfam': [pfam]}
    except:
        continue
print(len(data))
torch.save(data, '../data/pfam/sprot_10_1022_esm2_t33_pfam_above_10_single_label.pt')

551965


  0%|          | 0/322789 [00:00<?, ?it/s]

320616


In [2]:
# split train test val
import torch
import random

random.seed(42)
pfam_data = torch.load('../data/pfam/sprot_10_1022_esm2_t33_pfam_above_10_single_label.pt')
pids = list(pfam_data.keys())
random.shuffle(pids)
train_data = {pid: pfam_data[pid] for pid in pids[:int(0.8 * len(pids))]}
val_data = {pid: pfam_data[pid] for pid in pids[int(0.8 * len(pids)):int(0.9 * len(pids))]}
test_data = {pid: pfam_data[pid] for pid in pids[int(0.9 * len(pids)):]}
print(len(train_data), len(val_data), len(test_data))
torch.save(train_data, '../data/pfam/sprot_10_1022_esm2_t33_pfam_above_10_single_label_train.pt')
torch.save(val_data, '../data/pfam/sprot_10_1022_esm2_t33_pfam_above_10_single_label_val.pt')
torch.save(test_data, '../data/pfam/sprot_10_1022_esm2_t33_pfam_above_10_single_label_test.pt')

256492 32062 32062


In [3]:
import torch
import json

pfam_data = torch.load('../data/pfam/sprot_10_1022_esm2_t33_pfam_above_10_single_label.pt')
labels = [entry['pfam'][0] for entry in pfam_data.values()]
labels = list(set(labels))
print(len(labels))
with open('../data/pfam/pfam_single_label_list.json', 'w') as f:
    json.dump(labels, f)

3222
