In [1]:
import os

import pandas as pd
import torch

from bioemu.observables.folding_stability import compute_folded_proportion_from_dG
from datasets import DatasetDict, load_dataset

# Random seed shared across the notebook for reproducibility; it is passed as
# `random_state` when rows are sampled from the DataFrames further down.
seed = 42

In [None]:
# 1. Load the MegaScale dataset configuration from the Hugging Face Hub.
dataset_tag = "dataset2"
dataset2 = load_dataset(
    path="RosettaCommons/MegaScale", name=dataset_tag, data_dir=dataset_tag
)

# 2. First split: 80% train / 20% held out (val + test).
#    Use the notebook-level `seed` instead of a repeated literal so that the
#    split and the later sampling are controlled from a single place.
train_testvalid = dataset2["train"].train_test_split(test_size=0.2, seed=seed)

# 3. Second split: halve the held-out 20% into two equal parts (10% each).
test_valid = train_testvalid["test"].train_test_split(test_size=0.5, seed=seed)

# 4. Build the final DatasetDict with explicit split names.
dataset2_splits = DatasetDict(
    {
        "train": train_testvalid["train"],  # 80%
        "val": test_valid["train"],  # 10%
        "test": test_valid["test"],  # 10%
    }
)

In [None]:
# Write each split of `dataset2_splits` to its own CSV file so the following
# cells can reload them with pandas.
split_csv_paths = {
    "train": "datasets/megascale/train.csv",
    "val": "datasets/megascale/val.csv",
    "test": "datasets/megascale/test.csv",
}

# Create the output directory (no error if it already exists).
os.makedirs("datasets/megascale", exist_ok=True)

for split_name, csv_path in split_csv_paths.items():
    dataset2_splits[split_name].to_csv(csv_path)

In [8]:
def load_split_with_p_folded(csv_path: str) -> pd.DataFrame:
    """Load one split from CSV and attach a ``p_folded`` column.

    Steps:
      1. Read the CSV at ``csv_path``.
      2. Convert ``dG_ML`` to a numeric dtype; values that cannot be parsed
         become NaN (``errors="coerce"``).
      3. Drop the rows whose ``dG_ML`` is NaN.
      4. Compute ``p_folded`` from ``dG_ML`` with
         ``compute_folded_proportion_from_dG``.

    Args:
        csv_path: Path of the CSV file to load. It must contain a ``dG_ML``
            column.

    Returns:
        A new DataFrame containing only rows with a numeric ``dG_ML``, plus
        the added ``p_folded`` column. The original row index is kept, so it
        may contain gaps where rows were dropped.
    """
    frame = pd.read_csv(csv_path)
    frame["dG_ML"] = pd.to_numeric(frame["dG_ML"], errors="coerce")
    # Take an explicit copy after filtering so the column assignment below
    # writes to an independent frame (avoids SettingWithCopyWarning on
    # pandas versions where the filtered result is tracked as a slice).
    frame = frame.dropna(subset=["dG_ML"]).copy()
    # NOTE(review): dG_ML is negated before the conversion; presumably the
    # helper uses the opposite sign convention from the dataset -- confirm.
    frame["p_folded"] = compute_folded_proportion_from_dG(
        torch.tensor(-frame["dG_ML"].to_numpy())
    ).numpy()
    return frame


# Apply the same cleaning + p_folded computation to every split.
train = load_split_with_p_folded("datasets/megascale/train.csv")
val = load_split_with_p_folded("datasets/megascale/val.csv")
test = load_split_with_p_folded("datasets/megascale/test.csv")

In [9]:
# Inspect the processed training split, including the new `p_folded` column.
train

Unnamed: 0,name,dna_seq,log10_K50_t,log10_K50_t_95CI_high,log10_K50_t_95CI_low,log10_K50_t_95CI,fitting_error_t,log10_K50unfolded_t,deltaG_t,deltaG_t_95CI_high,...,mut_type,WT_name,WT_cluster,log10_K50_trypsin_ML,log10_K50_chymotrypsin_ML,dG_ML,ddG_ML,Stabilizing_mut,pair_name,p_folded
0,2LVN.pdb_Q15R,TCTGCGGGCGGTTCCGCTGGCGGCTCCGCGGGTGGCTCCCAGCTCA...,0.199068,0.231871,0.151547,0.080324,0.124635,-1.165985,1.800355,1.846291,...,Q15R,2LVN.pdb,59,0.1990680627201658,-0.4049197314594209,2.036008,-0.21448136260606754,False,,0.968875
1,1PGA.pdb_L5A_insA14,TCTGCTGGTGGTTCTGCGGGTATGACCTACAAAGCGATCCTGAACG...,-0.509430,-0.480570,-0.534830,0.054260,0.027942,-1.030330,0.488222,0.542670,...,insA14,1PGA.pdb_L5A,4,-0.5094296205602438,-1.872472744400257,0.728225,-1.712453515137192,False,,0.773770
2,EA|run7_1572_0002.pdb_delT11,TCTGCTGGTGGCTCCGCTGGCGGCTCTGCGGGCGGTTCTGACCGCC...,0.824038,0.882067,0.735686,0.146382,0.189363,-0.908498,2.315325,2.395979,...,delT11,EA|run7_1572_0002.pdb,EEHH,0.82403780132907,0.9018172923669228,2.369427,-0.4640036557474181,False,,0.982034
4,2K52.pdb_F21S_K17I,GACGTTGAACCGGGTAAATTCTACAAAGGTGTTGTTACCCGTATCG...,0.309887,0.344282,0.253533,0.090750,0.076017,-1.349622,2.207185,2.254419,...,K17I,2K52.pdb_F21S,225,0.3098870789814789,-0.7155926441499197,2.390553,-0.018486938460182678,False,,0.982653
5,1H8K.pdb_W36A_K34F,TCCGCTGGTGGTTCCGCGGGTGGCAAAGAACTGGTGCTGGTTCTCT...,0.030287,0.054779,-0.005703,0.060482,0.069081,-1.103431,1.471772,1.507113,...,K34F,1H8K.pdb_W36A,15,0.030287198573502,-1.0469122931853083,1.777257,-0.46307179973400503,False,,0.952627
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
621031,2B89.pdb_Q37D,TCTGCGGGCGGTTCTGCTGGCGGCTCCAAATTCAACAAAGAACGTG...,0.272856,0.292874,0.246556,0.046317,0.051666,-1.084301,1.789869,1.817965,...,Q37D,2B89.pdb,71,0.2728559888754384,-1.1234614014309183,1.638959,-0.09911526329965858,False,,0.940902
621032,r10_572_TrROS_Hall.pdb_K36G,TCTGCTGGCGGCTCCGCGGGCGGTTCTGCGGGTGGTATGCTGTACG...,0.693889,0.787869,0.658143,0.129726,0.088670,-0.589571,1.692183,1.826088,...,K36G,r10_572_TrROS_Hall.pdb,hall,0.6938887931626745,-0.2406505604834428,2.001203,-0.8692827089334663,False,,0.967053
621033,2LHR.pdb_hnet0_E24L:R46Y,TCCGCTGGCGGCTCTGCGGGCGGTTACAACCTGCAAAAGCTGCTCG...,-0.376246,-0.360968,-0.407199,0.046231,0.091335,-1.212921,1.026847,1.050687,...,E24L:R46Y,2LHR.pdb,56,-0.3762461465931201,-1.4736302219235724,1.389321,-0.6431683859494268,False,2LHR.pdb_hnet0,0.912623
621034,2MA4.pdb_hnet1_D36W:K40K,TCTGCGGAAATCATGAAAAAGACCGACTTCGACAAAGTTGCGTCTG...,1.360870,1.433721,1.320826,0.112895,0.125512,-0.833307,2.970308,3.076528,...,D36W:K40K,2MA4.pdb,107,1.3608699982293584,0.1152622368850471,3.465694,0.9111786556181292,-,2MA4.pdb_hnet1,0.997135


In [10]:
# Randomly sample 1 sequence from train and 1 sequence from val to use as
# small fixtures. Candidates must be shorter than 50 amino acids and have
# dG_ML below 1.
MAX_AA_SEQ_LEN = 50  # exclusive upper bound on len(aa_seq)
MAX_DG_ML = 1  # exclusive upper bound on dG_ML
N_SAMPLES_PER_SPLIT = 1


def sample_short_low_dg(frame: pd.DataFrame, n: int, random_state: int) -> pd.DataFrame:
    """Sample ``n`` rows with a short ``aa_seq`` and a low ``dG_ML``.

    Args:
        frame: DataFrame with ``aa_seq`` (string) and ``dG_ML`` (numeric)
            columns.
        n: Number of rows to draw.
        random_state: Seed forwarded to ``DataFrame.sample``.

    Returns:
        A DataFrame with ``n`` rows drawn from the rows that satisfy
        ``len(aa_seq) < MAX_AA_SEQ_LEN`` and ``dG_ML < MAX_DG_ML``.
    """
    is_candidate = (frame["aa_seq"].str.len() < MAX_AA_SEQ_LEN) & (
        frame["dG_ML"] < MAX_DG_ML
    )
    return frame[is_candidate].sample(n=n, random_state=random_state)


train_sample = sample_short_low_dg(train, n=N_SAMPLES_PER_SPLIT, random_state=seed)
val_sample = sample_short_low_dg(val, n=N_SAMPLES_PER_SPLIT, random_state=seed)

# Save the samples (without the DataFrame index).
os.makedirs("test/megascale", exist_ok=True)
train_sample.to_csv("test/megascale/train_sample.csv", index=False)
val_sample.to_csv("test/megascale/val_sample.csv", index=False)