In [7]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from Bio import SeqIO

# Extract sequence IDs from the FASTA file. 

Some IDs have improper headers, since they contain spaces. Since we will be doing clustering, we will extract the ID of the sequence from the header (last value in the | brackets), and save the corrected fasta file.`

In [None]:
import re

seq_iterator = SeqIO.parse(open("../data/phosphosite_sequences/Phosphosite_seq.fasta"), 'fasta')
records = []
for seq in seq_iterator:
    # extract sequence id
    seq_id = seq.description.split('|')[-1]
    seq_desc = re.sub(' ', '_', seq.description)
    records.append(SeqIO.SeqRecord(seq.seq, seq_id, description=seq_desc))

62820

In [16]:
seq_iterator = SeqIO.parse(open("../data/phosphosite_sequences/Phosphosite_seq_fixed.fasta"), 'fasta')
i = 0
records = []
for seq in seq_iterator:
    if '|' in seq.id:
        print(seq.id)
        continue
    # extract sequence id
    seq_id = seq.description.split('|')[-1]
    seq_desc = re.sub(' ', '_', seq.description)
    records.append(SeqIO.SeqRecord(seq.seq, seq_id, description=seq_desc))
    i +=1

GN:Nsmce1|NSMCE1_iso4|mouse|
GN:orf3b|ORF3b_protein|SARSCoV2|
GN:X_IgG|IgG|mouse|
GN:USP17L9P|USP17L9P|human|
GN:TRAP|TRAP|human|
GN:MSX2P1|MSX2P1|human|
GN:ABCC13|ABCC13|human|
GN:Ly76|Ly76|mouse|
GN:DLEU2|DLEU2|human|
GN:SNHG29|SNHG29|human|
GN:PCA3|PCA3|human|
GN:SNHG3|SNHG3|human|
GN:EGFR-AS1|EGFR-AS1|human|


GN:Nsmce1|NSMCE1_iso4|mouse| is missing the uniprot id. It seems to be A0A6P5Q324. It is added to the fixed .fasta file. Other proteins in this list seem to have malformed entries in the original .fasta file, so they are removed from the fixed version.

In [17]:
SeqIO.write(records, '../data/phosphosite_sequences/Phosphosite_seq_fixed.fasta', 'fasta')

62807

In [2]:
seed = 42

In [18]:
prot_info = pd.read_json('../data/phosphosite_sequences/phosphosite_df.json')
prot_info

Unnamed: 0,id,sites,sequence
0,A0A024R4G9,"[14, 16, 20]",MTVLEAVLEIQAITGSRLLSMVPGPARPPGSCWDPTQCTRTWLLSH...
1,A0A075B759,"[40, 79, 93, 119]",MVNSVVFFEITRDGKPLGRISIKLFADKIPKTAENFRALSTGEKGF...
2,A0A087WP46,"[359, 972, 973, 974, 988, 997, 1000, 1005, 101...",MARDGAEQPDSGPLPRPSPCPQEDRASNLMPPKPPRTWGLQLQGPS...
3,A0A087WPF7,"[32, 43, 622, 626, 798, 941, 956, 1031, 1038, ...",MDGPTRGHGLRKKRRSRSQRDRERRSRAGLGTGAAGGIGAGRTRAP...
4,A0A087WQ53,[58],MGQNNNVTEFILLGLTQDPAGQKVLFVMFLLIYIVKIVGNLLIVGT...
...,...,...,...
42252,XP_997087,"[347, 907, 915, 918, 927]",MENFLALMNSISDTWMSPSCMDIAMDMGIAFVCGAGLFFLLLPFLK...
42253,YP_009725299,"[504, 660, 661, 794, 1826]",APTKVTFGDDTVIEVQGYKSVNITFELDERIDKVLNEKCSAYTVEL...
42254,YP_009725305,[5],NNELSPVALRQMSCAAGTTQTACTDDNALAYYNTTKGGRFVLALLS...
42255,YP_009725309,[56],AENVTGLFKDCSKVITGLHPTQAPTHLSVDTKFKTEGLCVDIPGIP...


In [19]:
clusters = pd.read_csv('../data/clustered30_new.tsv', sep='\t', names=['cluster_rep', 'cluster_mem'])
clusters

Unnamed: 0,cluster_rep,cluster_mem
0,Q8CH79,Q8CH79
1,Q8CH79,Q12797-3
2,Q9NQ86-3,Q9NQ86-3
3,O88656,O88656
4,O88656,Q9R0Q6
...,...,...
62803,Q9BUD6,Q9BUD6
62804,Q9BUD6,Q8BMS2
62805,Q9BUD6,Q9WV75
62806,A6NEK1,A6NEK1


In [32]:
joined = prot_info.join(clusters.set_index('cluster_mem'), on='id', how='left').drop_duplicates('id')

Keep only sequences with length < 1024, becasue ESM has max input size of 1024 and tokenizer will add the [CLS] token.

In [35]:
joined = joined[joined['sequence'].apply(lambda x: len(x) < 1024)]
joined

Unnamed: 0,id,sites,sequence,cluster_rep
0,A0A024R4G9,"[14, 16, 20]",MTVLEAVLEIQAITGSRLLSMVPGPARPPGSCWDPTQCTRTWLLSH...,A0A024R4G9
1,A0A075B759,"[40, 79, 93, 119]",MVNSVVFFEITRDGKPLGRISIKLFADKIPKTAENFRALSTGEKGF...,P62937
4,A0A087WQ53,[58],MGQNNNVTEFILLGLTQDPAGQKVLFVMFLLIYIVKIVGNLLIVGT...,Q8NGM1
5,A0A087WQ89,"[18, 26]",METPIQREIRRSCEREESLRRSRGLSPGRAGEELIELRVRPVLSRP...,Q96FF7
6,A0A087WQP5,"[98, 102, 107]",MGCCGCGGCGGCGGCGCGGCGCGGCGCGGCGCGGCGCGGCGCGGCG...,A0A087WQP5
...,...,...,...,...
42248,XP_988512,[156],MKHPRHVTGCGGGSRGHVRGVATRGSRALGCGLQRGAAGGAGVAAG...,XP_988512
42251,XP_996053,"[19, 399, 400, 401, 542, 545, 803, 804, 807]",MGKFLALMNSIIDSWMGPSSMDIAMDIGIAFMCGAGLFFLLQRFLK...,XP_997087
42254,YP_009725305,[5],NNELSPVALRQMSCAAGTTQTACTDDNALAYYNTTKGGRFVLALLS...,NP_828867
42255,YP_009725309,[56],AENVTGLFKDCSKVITGLHPTQAPTHLSVDTKFKTEGLCVDIPGIP...,YP_009725309


In [None]:
reps = set(joined.cluster_rep.unique())
members = set(joined.id.unique())

missing = set.difference(reps, members)
missing

{'Q2NKQ1-3',
 'Q9WVB7',
 'G3V6W5',
 'F1MN93',
 'Q5U2P0',
 'O35103',
 'Q8IUB9',
 'Q07646',
 'D3Z7Q2',
 'Q497B3',
 'Q7Z3D6-2',
 'Q9QZE3',
 'P35704',
 'EDL99276',
 'Q3TNH5',
 'Q8BLT8',
 'Q8IVV8',
 'Q99819',
 'Q8R5G4',
 'Q8NBJ9-2',
 'Q9BTE6-3',
 'P54619-3',
 'D4A646',
 'XP_343054',
 'P57087-3',
 'P0C0L6',
 'Q96DN5',
 'D3ZZK5',
 'P14432',
 'D3ZSK5',
 'D3ZCA1',
 'Q61711',
 'D3Z3C6',
 'D3TI84',
 'P08424',
 'Q8K126',
 'Q69AB2-2',
 'Q9QWY8',
 'Q58F21-2',
 'C0HK79',
 'Q91VC9',
 'Q60587',
 'Q8BPG6',
 'Q99J23',
 'F1LQF5',
 'Q5MAI5',
 'Q8R5K6',
 'A2AJD1',
 'Q8VDV0',
 'Q8C522',
 'Q3U4X8',
 'A0A494C1R9',
 'P0DOX4',
 'O70191',
 'Q810P2',
 'P04049-2',
 'P32927-2',
 'Q6PJQ5',
 'F1M3G9',
 'D3ZIE9',
 'B2RZD0',
 'Q62876',
 'Q9H0C1',
 'Q6P0N0',
 'P98173',
 'Q9UBK8-1',
 'Q8VIH3',
 'F1M3N2',
 'O08776',
 'B0BN89',
 'O55174',
 'P70597',
 'Q9CYK1',
 'Q5I0K1',
 'Q9JKJ9',
 'P04070-2',
 'P13945',
 'P0C0T2',
 'NP_001289873',
 'Q9NX14-2',
 'P58511',
 'Q9NRP0-2',
 'Q810F0',
 'Q86SQ9-2',
 'Q12982-2',
 'NP_001316994',
 

In [41]:
len(missing)

3080

The missing set is a set of cluster representatives with length >= 1024, thus not included in this filtered version. We need to choose a different representative for the members of these clusters. It will be the longest protein from the cluster that is still remaining in the set.

In [85]:
to_fix = joined[joined['cluster_rep'].apply(lambda x: x in missing)]
to_fix

Unnamed: 0,id,sites,sequence,cluster_rep
14,A0A096MJN4,"[170, 306]",MIKHFLEDNSDDAELSKFVKDFPGSEPCHPTESKTRVARPQILEPR...,O43236-4
21,A0A0B4J1G0,[235],MWQLLLPTALVLTAFSGIQAGLQKAVVNLDPKWVRVLEEDSVTLRC...,P27645
30,A0A0G2JTM7,"[279, 289, 483]",MGKKHKKHKSDRHFYEEYVEKPLKLVLKVGGSEVTELSTGSSGHDS...,Q9NPI1-2
34,A0A0G2JUG7,"[105, 107, 109, 110, 162, 180, 224, 342, 416, ...",MACRRRYLSSLETGSSLSTDRYSVEGEAPSSETGTSLDSPSAYHQG...,B4DGC5
60,A0A0G2KA14,[404],MFGRSRSWVGGGHGKSSRNIHSLDHLKYLYHVLTKNTTVTEQNRNL...,Q2KHT3
...,...,...,...,...
42222,XP_577119,[12],MEVFLMIRRKKTTIFIEAQELTTVLELKRIVQGILKRPPEEQLLFK...,P62870
42232,XP_578472,"[76, 78, 83]",MAMVCAFLTILVVMSHWSTCCLGCNLPRTHNLTVLKQMSRQSPVSC...,B1AYH7
42237,XP_621234,"[23, 293, 303]",MDHKLQNSDGRQPVLAIKNQTLTEFILLGLTDAPELQVCVFLFLLL...,P47883
42251,XP_996053,"[19, 399, 400, 401, 542, 545, 803, 804, 807]",MGKFLALMNSIIDSWMGPSSMDIAMDIGIAFMCGAGLFFLLQRFLK...,XP_997087


In [104]:
fix_grouped = to_fix.groupby('cluster_rep')

In [105]:
def collapse_fn(group : pd.DataFrame):
    #print(group['sequence'])
    #print(group)
    lens = group['sequence'].apply(lambda x: len(x))
    new_rep = lens.argmax()
    group['new_rep'] = group.iloc[new_rep].id
    return group

fixed = fix_grouped.apply(collapse_fn)

In [106]:
fixed

Unnamed: 0_level_0,Unnamed: 1_level_0,id,sites,sequence,cluster_rep,new_rep
cluster_rep,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
A0A024RBG1,7937,O95989,"[30, 131, 160]",MMKLKSNQTRTYDGDGYKKRAACLCFRSESEEEVLLVSSSRHPDRW...,A0A024RBG1,Q9NZJ9
A0A024RBG1,9042,P0C027,[150],MKCKPNQTRTYDPEGFKKRAACLCFRSEREDEVLLVSSSRYPDRWI...,A0A024RBG1,Q9NZJ9
A0A024RBG1,9043,P0C028,"[148, 150]",MKCKPNQTRTYDPEGFKKRAACLCFRSEREDEVLLVSSSRYPDRWI...,A0A024RBG1,Q9NZJ9
A0A024RBG1,29828,Q8NFP7,"[11, 67, 148, 150, 154, 158, 159, 162]",MKCKPNQTRTYDPEGFKKRAACLCFRSEREDEVLLVSSSRYPDRWI...,A0A024RBG1,Q9NZJ9
A0A024RBG1,32894,Q96G61,"[11, 67, 148, 150, 154, 158, 159, 162]",MKCKPNQTRTYDPEGFKKRAACLCFRSEREDEVLLVSSSRYPDRWI...,A0A024RBG1,Q9NZJ9
...,...,...,...,...,...,...
XP_913389,14595,P62851,"[26, 55, 65, 69, 74, 93, 109]",MPPKDDKKKKDAGKSAKKDKDPVNKSGGKAKKKKWSKGKVRDKLNN...,XP_913389,P62851
XP_913389,14596,P62852,"[26, 55, 69, 93, 113]",MPPKDDKKKKDAGKSAKKDKDPVNKSGGKAKKKKWSKGKVRDKLNN...,XP_913389,P62851
XP_997087,1577,B7ZWJ3,"[368, 489, 914, 918, 919, 920]",MENFLSLMNSIIDPWMSNSSMDIAMDMTIGFMCGVGLFFLLIPFLK...,XP_997087,B7ZWJ3
XP_997087,19226,Q3V0M1,"[141, 148, 368, 914, 918, 919, 920]",MENFLSLMNSIIDSWMSNSSMDIAMDMTIGFMCGVGLFFLLIPFLK...,XP_997087,B7ZWJ3


In [108]:
fixed = fixed.reset_index(level='cluster_rep', drop=True)
fixed

Unnamed: 0,id,sites,sequence,cluster_rep,new_rep
7937,O95989,"[30, 131, 160]",MMKLKSNQTRTYDGDGYKKRAACLCFRSESEEEVLLVSSSRHPDRW...,A0A024RBG1,Q9NZJ9
9042,P0C027,[150],MKCKPNQTRTYDPEGFKKRAACLCFRSEREDEVLLVSSSRYPDRWI...,A0A024RBG1,Q9NZJ9
9043,P0C028,"[148, 150]",MKCKPNQTRTYDPEGFKKRAACLCFRSEREDEVLLVSSSRYPDRWI...,A0A024RBG1,Q9NZJ9
29828,Q8NFP7,"[11, 67, 148, 150, 154, 158, 159, 162]",MKCKPNQTRTYDPEGFKKRAACLCFRSEREDEVLLVSSSRYPDRWI...,A0A024RBG1,Q9NZJ9
32894,Q96G61,"[11, 67, 148, 150, 154, 158, 159, 162]",MKCKPNQTRTYDPEGFKKRAACLCFRSEREDEVLLVSSSRYPDRWI...,A0A024RBG1,Q9NZJ9
...,...,...,...,...,...
14595,P62851,"[26, 55, 65, 69, 74, 93, 109]",MPPKDDKKKKDAGKSAKKDKDPVNKSGGKAKKKKWSKGKVRDKLNN...,XP_913389,P62851
14596,P62852,"[26, 55, 69, 93, 113]",MPPKDDKKKKDAGKSAKKDKDPVNKSGGKAKKKKWSKGKVRDKLNN...,XP_913389,P62851
1577,B7ZWJ3,"[368, 489, 914, 918, 919, 920]",MENFLSLMNSIIDPWMSNSSMDIAMDMTIGFMCGVGLFFLLIPFLK...,XP_997087,B7ZWJ3
19226,Q3V0M1,"[141, 148, 368, 914, 918, 919, 920]",MENFLSLMNSIIDSWMSNSSMDIAMDMTIGFMCGVGLFFLLIPFLK...,XP_997087,B7ZWJ3


In [113]:
clusters.set_index('cluster_mem').loc['XP_996053']

cluster_rep    XP_997087
Name: XP_996053, dtype: object

In [109]:
joined['cluster_rep'][fixed.index] = fixed['new_rep']

In [110]:
joined.loc[42251]

id                                                     XP_996053
sites               [19, 399, 400, 401, 542, 545, 803, 804, 807]
sequence       MGKFLALMNSIIDSWMGPSSMDIAMDIGIAFMCGAGLFFLLQRFLK...
cluster_rep                                               B7ZWJ3
Name: 42251, dtype: object

In [114]:
len(joined['cluster_rep'].unique())

13120

In [116]:
reps = joined['cluster_rep'].unique()
reps

array(['A0A024R4G9', 'P62937', 'Q8NGM1', ..., 'XP_988512', 'YP_009725305',
       'YP_009725309'], dtype=object)

In [129]:
rep_mask = joined['id'].apply(lambda x : x in reps)
rep_df = joined[rep_mask]
rep_df['length'] = rep_df['sequence'].apply(lambda x: len(x))

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  rep_df['length'] = rep_df['sequence'].apply(lambda x: len(x))


In [None]:
rep_df

Unnamed: 0,id,sites,sequence,cluster_rep,length
0,A0A024R4G9,"[14, 16, 20]",MTVLEAVLEIQAITGSRLLSMVPGPARPPGSCWDPTQCTRTWLLSH...,A0A024R4G9,117
6,A0A087WQP5,"[98, 102, 107]",MGCCGCGGCGGCGGCGCGGCGCGGCGCGGCGCGGCGCGGCGCGGCG...,A0A087WQP5,138
13,A0A096MJJ4,[501],MSEEKPDKIAPEETAFEEIEKDFQEVLSELSGDKSLEKFRVEYEKL...,A0A096MJJ4,873
19,A0A0A6YY25,"[101, 110, 414, 417, 443, 682, 683]",MCSPASSKILYRNPRFLRVAFLQLHHQQQSGVFCDALLQAEGEAVP...,A0A0A6YY25,723
20,A0A0B4J1F3,"[100, 106, 326, 330]",MSCTFTALLCLGLTLRLWIPVLTGSLPKPILRVQPDSVVQVWTKVT...,A0A0B4J1F3,663
...,...,...,...,...,...
42247,XP_987269,[172],MVSFAQWTVSKMRGKSWTENGERRSPLHNLVLKRFRGSTLKDQVLK...,XP_987269,241
42248,XP_988512,[156],MKHPRHVTGCGGGSRGHVRGVATRGSRALGCGLQRGAAGGAGVAAG...,XP_988512,206
42254,YP_009725305,[5],NNELSPVALRQMSCAAGTTQTACTDDNALAYYNTTKGGRFVLALLS...,YP_009725305,113
42255,YP_009725309,[56],AENVTGLFKDCSKVITGLHPTQAPTHLSVDTKFKTEGLCVDIPGIP...,YP_009725309,527


Split the proteins into classes based on their length, and then stratify the dataset.

In [135]:
_, length_bins = np.histogram(rep_df['length'], 10)
length_labels = np.digitize(rep_df['length'], length_bins)
rep_df['length_class'] = length_labels
rep_df

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  rep_df['length_class'] = length_labels


Unnamed: 0,id,sites,sequence,cluster_rep,length,length_class
0,A0A024R4G9,"[14, 16, 20]",MTVLEAVLEIQAITGSRLLSMVPGPARPPGSCWDPTQCTRTWLLSH...,A0A024R4G9,117,1
6,A0A087WQP5,"[98, 102, 107]",MGCCGCGGCGGCGGCGCGGCGCGGCGCGGCGCGGCGCGGCGCGGCG...,A0A087WQP5,138,2
13,A0A096MJJ4,[501],MSEEKPDKIAPEETAFEEIEKDFQEVLSELSGDKSLEKFRVEYEKL...,A0A096MJJ4,873,9
19,A0A0A6YY25,"[101, 110, 414, 417, 443, 682, 683]",MCSPASSKILYRNPRFLRVAFLQLHHQQQSGVFCDALLQAEGEAVP...,A0A0A6YY25,723,8
20,A0A0B4J1F3,"[100, 106, 326, 330]",MSCTFTALLCLGLTLRLWIPVLTGSLPKPILRVQPDSVVQVWTKVT...,A0A0B4J1F3,663,7
...,...,...,...,...,...,...
42247,XP_987269,[172],MVSFAQWTVSKMRGKSWTENGERRSPLHNLVLKRFRGSTLKDQVLK...,XP_987269,241,3
42248,XP_988512,[156],MKHPRHVTGCGGGSRGHVRGVATRGSRALGCGLQRGAAGGAGVAAG...,XP_988512,206,2
42254,YP_009725305,[5],NNELSPVALRQMSCAAGTTQTACTDDNALAYYNTTKGGRFVLALLS...,YP_009725305,113,1
42255,YP_009725309,[56],AENVTGLFKDCSKVITGLHPTQAPTHLSVDTKFKTEGLCVDIPGIP...,YP_009725309,527,6


In [153]:
train, test = train_test_split(rep_df.id, test_size=0.2, random_state=seed, stratify=rep_df['length_class'])

In [154]:
train

17812    Q16594
18339    Q3C2P9
503      A2RUU4
19332    Q400G9
40595    Q9Y274
          ...  
15765    Q01954
34303    Q9BQ61
37445    Q9JHI0
4415     M0R3X9
18874    Q3UGE6
Name: id, Length: 10496, dtype: object

In [145]:
rep_df

Unnamed: 0,id,sites,sequence,cluster_rep,length,length_class
0,A0A024R4G9,"[14, 16, 20]",MTVLEAVLEIQAITGSRLLSMVPGPARPPGSCWDPTQCTRTWLLSH...,A0A024R4G9,117,1
6,A0A087WQP5,"[98, 102, 107]",MGCCGCGGCGGCGGCGCGGCGCGGCGCGGCGCGGCGCGGCGCGGCG...,A0A087WQP5,138,2
13,A0A096MJJ4,[501],MSEEKPDKIAPEETAFEEIEKDFQEVLSELSGDKSLEKFRVEYEKL...,A0A096MJJ4,873,9
19,A0A0A6YY25,"[101, 110, 414, 417, 443, 682, 683]",MCSPASSKILYRNPRFLRVAFLQLHHQQQSGVFCDALLQAEGEAVP...,A0A0A6YY25,723,8
20,A0A0B4J1F3,"[100, 106, 326, 330]",MSCTFTALLCLGLTLRLWIPVLTGSLPKPILRVQPDSVVQVWTKVT...,A0A0B4J1F3,663,7
...,...,...,...,...,...,...
42247,XP_987269,[172],MVSFAQWTVSKMRGKSWTENGERRSPLHNLVLKRFRGSTLKDQVLK...,XP_987269,241,3
42248,XP_988512,[156],MKHPRHVTGCGGGSRGHVRGVATRGSRALGCGLQRGAAGGAGVAAG...,XP_988512,206,2
42254,YP_009725305,[5],NNELSPVALRQMSCAAGTTQTACTDDNALAYYNTTKGGRFVLALLS...,YP_009725305,113,1
42255,YP_009725309,[56],AENVTGLFKDCSKVITGLHPTQAPTHLSVDTKFKTEGLCVDIPGIP...,YP_009725309,527,6


In [155]:
train_labels = rep_df.set_index('id').loc[train]['length_class']
train_labels

id
Q16594     3
Q3C2P9     5
A2RUU4     1
Q400G9     5
Q9Y274     4
          ..
Q01954    10
Q9BQ61     2
Q9JHI0     6
M0R3X9     2
Q3UGE6     7
Name: length_class, Length: 10496, dtype: int64

In [156]:
from sklearn.model_selection import StratifiedKFold

cv = StratifiedKFold(random_state=seed, shuffle=True)
splits = list(cv.split(train, train_labels))

In [158]:
import json

final_dataset = {
    'prot_ids' : reps.tolist(),
    'test' : test.tolist(),
    'train' : train.tolist(),
    'splits' : []
}
for t, dev in splits:
    d = { 'train' : t.tolist(), 'dev' : dev.tolist() }
    final_dataset['splits'].append(d)

with open('../data/prepared_dataset.json', 'w') as f:
    json.dump(final_dataset, f, indent='\t')