In [1]:
import pandas as pd
from Bio import SeqIO
import re
import json

def extract_go_codes(go_string):
    """Collect every GO identifier found in ``go_string``.

    An identifier is any occurrence of ``GO:`` followed by one or more digits.
    Matches are kept in the order they appear (duplicates included) and are
    returned as a single ';'-separated string; the result is '' when nothing
    matches.
    """
    found = (match.group(0) for match in re.finditer(r'GO:\d+', go_string))
    return ';'.join(found)

# Record ids of the cleaned FASTA; the annotation table is restricted to these entries.
records = SeqIO.parse('../data/swissprot/uniprot_sprot_10_1022_cleaned.fasta', 'fasta')
sprot_10_1022_pids = [record.id for record in records]

sp_df = pd.read_csv('../data/swissprot/swissprot.tsv', sep='\t')
print(len(sp_df))
# Keep only the columns needed here and drop rows missing any of them.
GO_data = sp_df[['Entry', 'Date of creation', 'Gene Ontology IDs']].dropna().reset_index(drop=True)
print(len(GO_data))
# .copy() gives an independent frame, so the column assignments below modify it
# directly instead of a slice (avoids pandas' SettingWithCopyWarning).
GO_data = GO_data[GO_data['Entry'].isin(sprot_10_1022_pids)].copy()
print(len(GO_data))
GO_data['Gene Ontology IDs'] = GO_data['Gene Ontology IDs'].apply(extract_go_codes)
# A row in which no GO code was found would be written as an empty CSV field,
# read back as NaN, and break the later `.split(';')` calls; it would also add
# an empty-string label to the label set. Drop such rows here.
GO_data = GO_data[GO_data['Gene Ontology IDs'] != ''].copy()
display(GO_data)

# Split on the creation date: strictly before 2022-04-01 vs. on/after 2022-04-01.
GO_data['Date of creation'] = pd.to_datetime(GO_data['Date of creation'])
split_date = pd.Timestamp('2022-04-01')
GO_data_before_2022April = GO_data[GO_data['Date of creation'] < split_date].reset_index(drop=True)
GO_data_after_2022April = GO_data[GO_data['Date of creation'] >= split_date].reset_index(drop=True)
display(GO_data_before_2022April)
display(GO_data_after_2022April)
GO_data_before_2022April.to_csv('../data/GO_new/swissprot_GO_before_2022April.csv', index=False)
GO_data_after_2022April.to_csv('../data/GO_new/swissprot_GO_after_2022April.csv', index=False)

# Unique GO labels seen before the cutoff. Sorting makes the saved JSON list
# identical across runs; plain set iteration order for strings varies between
# interpreter sessions.
GO_labels_before_2022April = sorted({
    label
    for labels in GO_data_before_2022April['Gene Ontology IDs']
    for label in labels.split(';')
})
print(f'Number of unique GO labels before 2022-04-01: {len(GO_labels_before_2022April)}')
with open('../data/GO_new/GO_labels_before_2022April.json', 'w') as f:
    json.dump(GO_labels_before_2022April, f)

570830
547026
528507


Unnamed: 0,Entry,Date of creation,Gene Ontology IDs
0,A0A009IHW8,2020-02-26,GO:0003953;GO:0007165;GO:0019677;GO:0050135;GO...
1,A0A023I7E1,2022-12-14,GO:0000272;GO:0005576;GO:0042973;GO:0052861;GO...
3,A0A024SC78,2022-05-25,GO:0005576;GO:0050525
4,A0A024SH76,2017-08-30,GO:0005576;GO:0016162;GO:0030245;GO:0030248
5,A0A026W182,2017-10-25,GO:0004984;GO:0005549;GO:0005886;GO:0007165;GO...
...,...,...,...
547021,Q9ZZX9,2007-09-11,GO:0005739
547022,T2KN80,2019-10-16,GO:0042597
547023,V5XVW4,2017-03-15,GO:0019028
547024,V5XWI9,2017-03-15,GO:0019028


Unnamed: 0,Entry,Date of creation,Gene Ontology IDs
0,A0A009IHW8,2020-02-26,GO:0003953;GO:0007165;GO:0019677;GO:0050135;GO...
1,A0A024SH76,2017-08-30,GO:0005576;GO:0016162;GO:0030245;GO:0030248
2,A0A026W182,2017-10-25,GO:0004984;GO:0005549;GO:0005886;GO:0007165;GO...
3,A0A044RE18,2017-05-10,GO:0004252;GO:0005576;GO:0031638;GO:0046872;GO...
4,A0A059TC02,2020-12-02,GO:0000166;GO:0005737;GO:0007623;GO:0009699;GO...
...,...,...,...
525118,Q9ZZX9,2007-09-11,GO:0005739
525119,T2KN80,2019-10-16,GO:0042597
525120,V5XVW4,2017-03-15,GO:0019028
525121,V5XWI9,2017-03-15,GO:0019028


Unnamed: 0,Entry,Date of creation,Gene Ontology IDs
0,A0A023I7E1,2022-12-14,GO:0000272;GO:0005576;GO:0042973;GO:0052861;GO...
1,A0A024SC78,2022-05-25,GO:0005576;GO:0050525
2,A0A061AE05,2023-02-22,GO:0000103;GO:0004020;GO:0004781;GO:0005524;GO...
3,A0A068Q5Q5,2023-09-13,GO:0016829;GO:0098015;GO:0098671;GO:0098994;GO...
4,A0A072VDF2,2023-02-22,GO:0005737;GO:0009699;GO:0009809;GO:0016616;GO...
...,...,...,...
3379,P0DQY5,2023-05-03,GO:0005576
3380,P0DV55,2022-05-25,GO:0005737
3381,P0DW86,2022-10-12,GO:0016020
3382,P0DW90,2022-10-12,GO:0016020


Number of unique GO labels before 2022-04-01: 28130


In [2]:
import pandas as pd
import json

def all_in_list(label_string, valid_list):
    """Return True when every ';'-separated label of ``label_string`` is in ``valid_list``.

    ``valid_list`` may be any container that supports the ``in`` operator.
    An empty ``label_string`` is treated as the single label ''.
    """
    for candidate in label_string.split(';'):
        if candidate not in valid_list:
            return False
    return True

# Keep only post-cutoff proteins whose GO labels were all seen before the cutoff.
test_df = pd.read_csv('../data/GO_new/swissprot_GO_after_2022April.csv')
with open('../data/GO_new/GO_labels_before_2022April.json', 'r') as f:
    GO_labels_before_2022April = json.load(f)
# Membership is checked once per label of every row; a set makes each check a
# hash lookup instead of a linear scan over the full label list.
valid_labels = set(GO_labels_before_2022April)
keep_mask = test_df['Gene Ontology IDs'].apply(lambda x: all_in_list(x, valid_labels))
test_df_filtered = test_df[keep_mask].reset_index(drop=True)
test_df_filtered.to_csv('../data/GO_new/swissprot_GO_after_2022April_filtered.csv', index=False)
display(test_df_filtered)

# Unique labels that remain in the filtered test set.
test_labels = test_df_filtered['Gene Ontology IDs'].tolist()
test_labels = ';'.join(test_labels).split(';')
test_labels = list(set(test_labels))
# Sanity check: after filtering, no test label should be missing from the pre-cutoff labels.
diff = set(test_labels) - valid_labels
print(f'Number of unique GO labels after 2022-04-01: {len(test_labels)}')
print(f'Number of unique GO labels after 2022-04-01 that are not in the labels before 2022-04-01: {len(diff)}')

Unnamed: 0,Entry,Date of creation,Gene Ontology IDs
0,A0A023I7E1,2022-12-14,GO:0000272;GO:0005576;GO:0042973;GO:0052861;GO...
1,A0A024SC78,2022-05-25,GO:0005576;GO:0050525
2,A0A061AE05,2023-02-22,GO:0000103;GO:0004020;GO:0004781;GO:0005524;GO...
3,A0A068Q5Q5,2023-09-13,GO:0016829;GO:0098015;GO:0098671;GO:0098994;GO...
4,A0A072VDF2,2023-02-22,GO:0005737;GO:0009699;GO:0009809;GO:0016616;GO...
...,...,...,...
3283,P0DQY5,2023-05-03,GO:0005576
3284,P0DV55,2022-05-25,GO:0005737
3285,P0DW86,2022-10-12,GO:0016020
3286,P0DW90,2022-10-12,GO:0016020


Number of unique GO labels after 2022-04-01: 3493
Number of unique GO labels after 2022-04-01 that are not in the labels before 2022-04-01: 0


In [3]:
from tqdm.auto import tqdm
import os
import torch

def csv2pt(csv_path, pt_path, data_dir, repr_layer=33):
    """Bundle per-protein embeddings and GO labels from a CSV into one ``.pt`` file.

    Args:
        csv_path: CSV file with an 'Entry' column and a 'Gene Ontology IDs'
            column holding ';'-separated labels.
        pt_path: Destination passed to ``torch.save``.
        data_dir: Directory containing one ``<Entry>.pt`` file per protein.
            Each file must load to a mapping with a 'mean_representations'
            entry that can be indexed by ``repr_layer``.
        repr_layer: Key selecting which entry of 'mean_representations' to
            keep. Defaults to 33, the value that used to be hard-coded.

    The saved object is a dict mapping each entry id to
    ``{'embedding': <selected representation>, 'GO': [label, ...]}``.
    A missing embedding file raises from ``torch.load``.
    """
    df = pd.read_csv(csv_path)
    rows = zip(df['Entry'].tolist(), df['Gene Ontology IDs'].tolist())
    data = {}
    # Passing total= keeps the progress bar sized without materialising the pairs.
    for pid, go_string in tqdm(rows, total=len(df)):
        emb_path = os.path.join(data_dir, f'{pid}.pt')
        # NOTE(security): torch.load unpickles the file; only point data_dir at
        # embedding files from a trusted source.
        emb = torch.load(emb_path)['mean_representations'][repr_layer]
        data[pid] = {'embedding': emb, 'GO': go_string.split(';')}
    print(f'data: {len(data)}')
    torch.save(data, pt_path)
    print(f'saved to {pt_path}')

# Directory searched by csv2pt for one '<Entry>.pt' embedding file per protein.
data_dir = '../data/sprot_esm1b_emb_per_residue'
# Convert the pre-cutoff table and the filtered post-cutoff table, in that order.
for csv_path, pt_path in (
    ('../data/GO_new/swissprot_GO_before_2022April.csv', '../data/GO_new/swissprot_GO_before_2022April.pt'),
    ('../data/GO_new/swissprot_GO_after_2022April_filtered.csv', '../data/GO_new/swissprot_GO_after_2022April_filtered.pt'),
):
    csv2pt(csv_path, pt_path, data_dir)

  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 525123/525123 [18:09<00:00, 482.19it/s]


data: 525123
saved to ../data/GO_new/swissprot_GO_before_2022April.pt


100%|██████████| 3288/3288 [00:06<00:00, 500.61it/s]


data: 3288
saved to ../data/GO_new/swissprot_GO_after_2022April_filtered.pt


In [4]:
import torch
import random

# Seeded 90/10 shuffle-split of the pre-cutoff proteins into train and validation dicts.
random.seed(42)
data = torch.load('../data/GO_new/swissprot_GO_before_2022April.pt')
pids = list(data.keys())
random.shuffle(pids)

# The first 90% of the shuffled ids go to train, the remainder to validation.
split_at = int(0.9*len(pids))
train_pids, val_pids = pids[:split_at], pids[split_at:]
train_data = dict((pid, data[pid]) for pid in train_pids)
val_data = dict((pid, data[pid]) for pid in val_pids)

print(f'train_data: {len(train_data)}')
print(f'val_data: {len(val_data)}')
torch.save(train_data, '../data/GO_new/swissprot_GO_before_2022April_train.pt')
torch.save(val_data, '../data/GO_new/swissprot_GO_before_2022April_val.pt')

train_data: 472610
val_data: 52513
