In [1]:
import pandas as pd
import numpy as np
import seaborn as sns
import random
import math

from Bio import SeqIO

In [2]:
# Directory the input CSV is read from and the per-split CSVs are written to.
data_dir = '/novo/projects/departments/cdd/molecular_ai/mlbp/data/static_input_data'

In [3]:
# Read the peptide table; the first CSV column is used as the row index.
path = f'{data_dir}/sbxw_fibrillation_peptide_waltzdb.csv'
waltz_df = pd.read_csv(path, index_col=0)

# Quick sanity report: table shape, row counts per label, row counts per
# (shipped split, label) pair, and the number of distinct sequences.
print(waltz_df.shape)
display(
    waltz_df.groupby('value_bool').size(),
    waltz_df.groupby(['data_split', 'value_bool']).size(),
)
print(waltz_df['sequence'].nunique())

# The split column that ships with the file is discarded; new splits are
# assigned further down.
waltz_df = waltz_df.drop(columns=['data_split'])
waltz_df.head()

(1399, 3)


value_bool
False    892
True     507
dtype: int64

data_split  value_bool
test        False         177
            True          100
train       False         715
            True          407
dtype: int64

1399


Unnamed: 0,sequence,value_bool
0,STVPIE,False
1,GVIWIA,True
2,LATVYA,False
3,NATAHQ,False
4,STVGIE,False


In [4]:
# shuffle and split as in the aggreprot preprint
# https://www.biorxiv.org/content/10.1101/2024.03.06.583680v1
# Section 2.1.1
# NOTE: this rebinds waltz_df, so re-running this cell without first
# re-running the load cell shuffles an already-shuffled frame and produces
# different splits than a fresh top-to-bottom run.
waltz_df = waltz_df.sample(frac=1,random_state=42).reset_index(drop=True)

# 90% train / 10% test
n_train = round(len(waltz_df)*.90)
n_test = len(waltz_df)-n_train
print(n_train,n_test)

# First n_train shuffled rows are the training pool, the remainder is the
# held-out test set.
train_df = waltz_df.iloc[:n_train].copy(deep=True)
train_df['data_split'] = 'train'  # placeholder, replaced by fold ids below
test_df = waltz_df.iloc[n_train:].copy(deep=True)
test_df['data_split'] = 'test'
assert len(train_df)+len(test_df)==len(waltz_df)

# split the training set into 5 contiguous batches of (at most) n_subset rows
n_batches = 5
n_subset = math.ceil(n_train/n_batches)
# Fold ids are derived from row position rather than index labels, so they
# stay correct even if the index is not a clean 0..n_train-1 range.
train_df['data_split'] = np.arange(len(train_df))//n_subset
fold_sizes = train_df.groupby('data_split').size()
# Ceil-based chunking can leave trailing folds empty for small inputs
# (e.g. 12 rows / 5 batches -> chunks of 3 -> only 4 folds); fail loudly
# instead of silently writing a split with an empty validation set later.
assert len(fold_sizes)==n_batches, (
    f'expected {n_batches} non-empty folds, got {len(fold_sizes)}'
)
display(fold_sizes)

1259 140


data_split
0    252
1    252
2    252
3    252
4    251
dtype: int64

In [5]:
def assign_data_split(x,i):
    """Map a row's split label to its role in cross-validation round ``i``.

    Args:
        x: Current label of the row; either the string ``'test'`` or a
            fold identifier.
        i: Fold identifier serving as the validation fold in this round.

    Returns:
        ``x`` unchanged when it is ``'test'``, ``'val'`` when ``x`` equals
        ``i``, and ``'train'`` otherwise.
    """
    # Held-out rows keep their label in every round.
    if x == 'test':
        return x
    return 'val' if x == i else 'train'

# Write one CSV per cross-validation round: fold i is relabelled 'val', the
# remaining folds 'train', and the held-out rows keep 'test'.
for i in range(n_batches):
    # Rebuild the combined frame every round so relabelling always starts
    # from the original fold ids.
    concat_df = pd.concat([train_df, test_df])
    concat_df['data_split'] = [
        assign_data_split(label, i) for label in concat_df['data_split']
    ]

    val_index = concat_df.loc[concat_df['data_split'] == 'val'].index
    print(f'Split: {i}')
    print('Validation indices:', val_index)

    out_path = f'{data_dir}/sbxw_fibrillation_peptide_waltz_aggreprot_split{i}.csv'
    concat_df.to_csv(out_path)

Split: 0
Validation indices: Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
       ...
       242, 243, 244, 245, 246, 247, 248, 249, 250, 251],
      dtype='int64', length=252)
Split: 1
Validation indices: Index([252, 253, 254, 255, 256, 257, 258, 259, 260, 261,
       ...
       494, 495, 496, 497, 498, 499, 500, 501, 502, 503],
      dtype='int64', length=252)
Split: 2
Validation indices: Index([504, 505, 506, 507, 508, 509, 510, 511, 512, 513,
       ...
       746, 747, 748, 749, 750, 751, 752, 753, 754, 755],
      dtype='int64', length=252)
Split: 3
Validation indices: Index([ 756,  757,  758,  759,  760,  761,  762,  763,  764,  765,
       ...
        998,  999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007],
      dtype='int64', length=252)
Split: 4
Validation indices: Index([1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017,
       ...
       1249, 1250, 1251, 1252, 1253, 1254, 1255, 1256, 1257, 1258],
      dtype='int64', length=251)
