In [1]:
import pandas as pd
from Bio import SeqIO
import re
import json

def extract_four_digits(s):
    """Pull every dotted four-field numeric ID out of ``s``.

    An ID is four runs of digits separated by periods (for example
    ``3.40.50.10140``), bounded on both sides by a word boundary.

    Args:
        s: Text to scan.

    Returns:
        The matched IDs in order of appearance (duplicates kept), joined
        with ``;``. An input containing no ID yields the empty string.
    """
    # Exactly four digit groups, i.e. three literal periods between them.
    four_field_id = re.compile(r'\b\d+\.\d+\.\d+\.\d+\b')
    return ';'.join(hit.group(0) for hit in four_field_id.finditer(s))

# IDs of the records in the cleaned Swiss-Prot FASTA; only entries whose
# 'Entry' value appears here are kept below.
records = SeqIO.parse('../data/swissprot/uniprot_sprot_10_1022_cleaned.fasta', 'fasta')
sprot_10_1022_pids = [record.id for record in records]

sp_df = pd.read_csv('../data/swissprot/swissprot.tsv', sep='\t')
print(len(sp_df))
# Rows missing any of the three selected columns are discarded.
df_gene3D = sp_df[['Entry', 'Date of creation', 'Gene3D']].dropna().reset_index(drop=True)
print(len(df_gene3D))
df_gene3D = df_gene3D[df_gene3D['Entry'].isin(sprot_10_1022_pids)].reset_index(drop=True)
# Reduce the raw Gene3D column text to ';'-joined dotted four-field IDs.
df_gene3D['Gene3D'] = df_gene3D['Gene3D'].apply(extract_four_digits)
# A row with no extractable ID would be written as an empty CSV field, which
# pd.read_csv reads back as NaN and which then breaks the `.split(';')` calls
# made on this column when the CSVs are reloaded. Drop such rows up front.
df_gene3D = df_gene3D[df_gene3D['Gene3D'] != ''].reset_index(drop=True)
display(df_gene3D)

# Time-based split: entries created on or before 2022-05-25 vs. strictly after.
# NOTE: the "2022April" suffix in the variable names does not match the cutoff
# date; the names are kept unchanged so other cells that use them still work.
df_gene3D['Date of creation'] = pd.to_datetime(df_gene3D['Date of creation'])
gene3D_data_before_2022April = df_gene3D[df_gene3D['Date of creation'] <= '2022-05-25'].reset_index(drop=True)
gene3D_data_after_2022April = df_gene3D[df_gene3D['Date of creation'] > '2022-05-25'].reset_index(drop=True)
display(gene3D_data_before_2022April)
display(gene3D_data_after_2022April)
gene3D_data_before_2022April.to_csv('../data/gene3D_new/swissprot_gene3D_by2022-05-25.csv', index=False)
gene3D_data_after_2022April.to_csv('../data/gene3D_new/swissprot_gene3D_after_2022-05-25.csv', index=False)

# Unique labels seen on or before the cutoff. Sorting makes the order of the
# JSON list reproducible; iterating a set of strings directly gives an order
# that can change from one interpreter run to the next.
gene3D_labels_before_2022April = sorted({
    label
    for label_string in gene3D_data_before_2022April['Gene3D']
    for label in label_string.split(';')
})
print(f'Number of unique Gene3D labels before 2022-05-25: {len(gene3D_labels_before_2022April)}')
with open('../data/gene3D_new/gene3D_labels_by2022-05-25.json', 'w') as f:
    json.dump(gene3D_labels_before_2022April, f)

570830
458752


Unnamed: 0,Entry,Date of creation,Gene3D
0,A0A009IHW8,2020-02-26,3.40.50.10140
1,A0A023I7E1,2022-12-14,1.10.287.1170;1.20.5.420
2,A0A024SC78,2022-05-25,3.40.50.1820
3,A0A024SH76,2017-08-30,3.20.20.40
4,A0A044RE18,2017-05-10,2.60.120.260;3.30.70.850;3.40.50.200
...,...,...,...
442833,Q9ZVR0,2007-01-23,1.20.1280.50
442834,Q9ZW38,2007-04-03,2.120.10.80
442835,Q9ZWC6,2007-01-23,3.80.10.10
442836,U3H0A9,2014-05-14,3.90.1840.10


Unnamed: 0,Entry,Date of creation,Gene3D
0,A0A009IHW8,2020-02-26,3.40.50.10140
1,A0A024SC78,2022-05-25,3.40.50.1820
2,A0A024SH76,2017-08-30,3.20.20.40
3,A0A044RE18,2017-05-10,2.60.120.260;3.30.70.850;3.40.50.200
4,A0A059TC02,2020-12-02,3.40.50.720
...,...,...,...
440738,Q9ZVR0,2007-01-23,1.20.1280.50
440739,Q9ZW38,2007-04-03,2.120.10.80
440740,Q9ZWC6,2007-01-23,3.80.10.10
440741,U3H0A9,2014-05-14,3.90.1840.10


Unnamed: 0,Entry,Date of creation,Gene3D
0,A0A023I7E1,2022-12-14,1.10.287.1170;1.20.5.420
1,A0A061AE05,2023-02-22,3.40.50.620;3.40.50.300;3.10.400.10
2,A0A068Q5Q5,2023-09-13,2.160.20.10
3,A0A072VDF2,2023-02-22,3.40.50.720
4,A0A075D5I4,2023-06-28,3.40.50.150
...,...,...,...
2090,A0A7H0DNG7,2023-02-22,1.10.437.20
2091,A0A7H0DNG9,2023-02-22,1.25.40.20
2092,A0A8V8TMC4,2023-09-13,1.10.472.10
2093,O53518,2022-10-12,2.60.40.790;3.40.50.300


Number of unique Gene3D labels before 2022-05-25: 4927


In [2]:
import pandas as pd
import json

def all_in_list(label_string, valid_list):
    """Tell whether every ';'-separated label in ``label_string`` is allowed.

    Args:
        label_string: Labels joined with ``;``.
        valid_list: Container of accepted labels (anything supporting ``in``).

    Returns:
        True when each piece of ``label_string`` is found in ``valid_list``,
        False as soon as one piece is not.
    """
    for candidate in label_string.split(';'):
        if candidate not in valid_list:
            return False
    return True

# Load the post-cutoff entries. An empty 'Gene3D' field in the CSV is read back
# as NaN, which has no `.split`, so such rows are dropped before filtering.
test_df = (
    pd.read_csv('../data/gene3D_new/swissprot_gene3D_after_2022-05-25.csv')
    .dropna(subset=['Gene3D'])
    .reset_index(drop=True)
)
with open('../data/gene3D_new/gene3D_labels_by2022-05-25.json', 'r') as f:
    gene3D_labels_before_2022April = json.load(f)
# Build the lookup set once: membership tests against the raw list would scan
# it for every label of every row.
known_labels = set(gene3D_labels_before_2022April)

# Keep only the rows whose labels were all seen on or before the cutoff.
test_df_filtered = test_df[test_df['Gene3D'].apply(lambda x: all_in_list(x, known_labels))].reset_index(drop=True)
test_df_filtered.to_csv('../data/gene3D_new/swissprot_gene3D_after_2022-05-25_filtered.csv', index=False)

# Unique labels remaining in the filtered rows, in a reproducible order.
test_labels = sorted({
    label
    for label_string in test_df_filtered['Gene3D']
    for label in label_string.split(';')
})
# Sanity check: this difference is empty by construction, because the filter
# above only keeps rows whose labels are all in `known_labels`.
diff = set(test_labels) - known_labels
print(f'Number of unique Gene3D labels after 2022-05-25: {len(test_labels)}')
print(f'Number of unique Gene3D labels after 2022-05-25 that are not in the labels before 2022-05-25: {len(diff)}')

Number of unique Gene3D labels after 2022-05-25: 750
Number of unique Gene3D labels after 2022-05-25 that are not in the labels before 2022-05-25: 0


In [3]:
from tqdm.auto import tqdm
import os
import torch

def csv2pt(csv_path, pt_path, data_dir, layer=33):
    """Bundle per-protein embeddings and Gene3D labels into one ``.pt`` file.

    For each row of the CSV, loads ``<data_dir>/<Entry>.pt``, takes
    ``['mean_representations'][layer]`` from it, and pairs that with the row's
    ``Gene3D`` value split on ``;``. The resulting dict is written with
    ``torch.save`` and two progress lines are printed.

    Args:
        csv_path: CSV file with at least the columns ``Entry`` and ``Gene3D``.
        pt_path: Destination path for the combined dict.
        data_dir: Directory holding one ``<Entry>.pt`` file per protein.
        layer: Key looked up inside ``mean_representations``. Defaults to 33,
            the value this function previously hard-coded.
            NOTE(review): presumably the final layer of the ESM-1b model that
            produced the files in ``data_dir`` -- confirm.

    Returns:
        None. The output is the file written to ``pt_path``, a dict of the
        form ``{entry: {'embedding': ..., 'gene3D': [label, ...]}}``.
    """
    df = pd.read_csv(csv_path)
    pids = df['Entry'].tolist()
    annotations = df['Gene3D'].tolist()
    data = {}
    # `total` lets tqdm show a full progress bar without materialising the zip.
    for pid, annotation in tqdm(zip(pids, annotations), total=len(pids)):
        emb = torch.load(os.path.join(data_dir, f'{pid}.pt'))['mean_representations'][layer]
        data[pid] = {'embedding': emb, 'gene3D': annotation.split(';')}
    print(f'data: {len(data)}')
    torch.save(data, pt_path)
    print(f'saved to {pt_path}')

# Directory with one '<Entry>.pt' embedding file per protein.
data_dir = '../data/sprot_esm1b_emb_per_residue'
# Convert the pre-cutoff split first, then the filtered post-cutoff split;
# each CSV is written next to itself with a '.pt' extension.
for stem in ('swissprot_gene3D_by2022-05-25', 'swissprot_gene3D_after_2022-05-25_filtered'):
    csv2pt(f'../data/gene3D_new/{stem}.csv', f'../data/gene3D_new/{stem}.pt', data_dir)

  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 440743/440743 [14:53<00:00, 493.48it/s]


data: 440743
saved to ../data/gene3D_new/swissprot_gene3D_by2022-05-25.pt


100%|██████████| 2058/2058 [00:04<00:00, 485.41it/s]

data: 2058
saved to ../data/gene3D_new/swissprot_gene3D_after_2022-05-25_filtered.pt





In [4]:
import torch
import random

# Seed the global RNG so the shuffle below, and therefore the split, is repeatable.
random.seed(42)

data = torch.load('../data/gene3D_new/swissprot_gene3D_by2022-05-25.pt')
pids = list(data.keys())
random.shuffle(pids)

# The first 90% of the shuffled IDs form the training set, the remainder validation.
n_train = int(0.9*len(pids))
train_pids, val_pids = pids[:n_train], pids[n_train:]

train_data = dict((pid, data[pid]) for pid in train_pids)
val_data = dict((pid, data[pid]) for pid in val_pids)

torch.save(train_data, '../data/gene3D_new/swissprot_gene3D_by2022-05-25_train.pt')
torch.save(val_data, '../data/gene3D_new/swissprot_gene3D_by2022-05-25_val.pt')