# Generate a Rosetta dataset

This notebook shows how to generate a Rosetta dataset, which can be used to pretrain METL models. It assumes you have already run Rosetta simulations using our [metl-sim](https://github.com/gitter-lab/metl-sim) repository. For demonstration purposes, this notebook uses a sample output from that repository ([avgfp_rosettafy_sample.db](../data/rosetta_data/avgfp_rosettafy_sample.db)).

1. Create a Rosetta dataset by using the script [parse_rosetta_data.py](../code/parse_rosetta_data.py) to parse and format the raw Rosetta data generated by [metl-sim](https://github.com/gitter-lab/metl-sim). This step removes duplicate variants, variants with NaN energy values, and outliers.
2. Create train, validation, and test splits for the Rosetta dataset using [split_dataset.py](../code/split_dataset.py).
3. Compute standardization parameters from the train set using [compute_rosetta_standardization.py](../code/compute_rosetta_standardization.py). These parameters are needed during training and evaluation to standardize each Rosetta energy term to mean 0 and standard deviation 1, which puts the terms on similar scales.
4. Make sure all PDB files from the Rosetta dataset are present in the [pdb_files](../data/pdb_files) directory and are listed in [pdb_index.csv](../data/rosetta_data/pdb_index.csv).
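Step 4 can be checked programmatically before training. The helper below is a hypothetical sketch (`check_pdb_index` is not part of the metl code); it assumes the index is a CSV with a `pdb_fn` column holding bare filenames, which matches the header written by the index-generation cell at the end of this notebook.

```python
import csv
import os


def check_pdb_index(required_pdbs, pdb_dir, index_fn):
    """Return (missing_files, missing_from_index) for the given PDB filenames.

    Hypothetical helper. Assumes index_fn is a CSV whose 'pdb_fn' column
    contains bare filenames such as '1gfl_cm.pdb'.
    """
    # filenames listed in the index file
    with open(index_fn, newline="") as f:
        indexed = {row["pdb_fn"] for row in csv.DictReader(f)}
    required = set(required_pdbs)
    # filenames actually present in the PDB directory
    on_disk = set(os.listdir(pdb_dir)) if os.path.isdir(pdb_dir) else set()
    missing_files = sorted(required - on_disk)
    missing_from_index = sorted(required - indexed)
    return missing_files, missing_from_index
```

For this notebook, a call might look like `check_pdb_index(["1gfl_cm.pdb"], "data/pdb_files", "data/rosetta_data/pdb_index.csv")`; two empty lists mean nothing is missing.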


---
**NOTE**

The Rosetta dataset generated by running this notebook is already saved in this repository in the [data/rosetta_data/avgfp](../data/rosetta_data/avgfp) directory. If you would like to run this notebook, first delete or rename that directory.

---

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# define the name of the project root directory
project_root_dir_name = "metl"

# find the project root by checking each parent directory
current_dir = os.getcwd()
while os.path.basename(current_dir) != project_root_dir_name and current_dir != os.path.dirname(current_dir):
    current_dir = os.path.dirname(current_dir)

# change the current working directory to the project root directory
if os.path.basename(current_dir) == project_root_dir_name:
    os.chdir(current_dir)
else:
    print("project root directory not found")
    
# add the project code folder to the system path so imports work
module_path = os.path.abspath("code")
if module_path not in sys.path:
    sys.path.append(module_path)

# Parse Rosetta data

Parsing Rosetta data is handled by the [parse_rosetta_data.py](../code/parse_rosetta_data.py) script.

In [3]:
%run code/parse_rosetta_data.py -h

usage: parse_rosetta_data.py [-h] [--ds_name DS_NAME] [--pdb_fns PDB_FNS]
                             [--db_fn DB_FN]
                             [--keep_num_muts KEEP_NUM_MUTS [KEEP_NUM_MUTS ...]]
                             [--remove_nan] [--no_remove_nan]
                             [--remove_duplicates] [--no_remove_duplicates]
                             [--remove_outliers] [--no_remove_outliers]
                             [--outlier_energy_term OUTLIER_ENERGY_TERM]
                             [--outlier_threshold OUTLIER_THRESHOLD]
                             [--replace_pdb_fn REPLACE_PDB_FN] [--ct_fn CT_FN]
                             {generate_dataset,generate_dms_coverage_dataset}

positional arguments:
  {generate_dataset,generate_dms_coverage_dataset}
                        mode to run

optional arguments:
  -h, --help            show this help message and exit
  --ds_name DS_NAME     name of the dataset to generate
  --pdb_fns PDB_FNS     either the path to a sin

## Define arguments and run this script on our sample database

In [4]:
# define the dataset name for our Rosetta dataset
# call this one "avgfp" because it's for pretraining a METL-Local avGFP model
ds_name = "avgfp"

# path to the metl-sim database containing the raw Rosetta data
# using the sample database provided in this repository
db_fn = "data/rosetta_data/avgfp_rosettafy_sample.db"

# which PDB files to include in this Rosetta dataset
# this is necessary if, for example, your raw metl-sim database contains multiple
# PDB files, but you only want to create a Rosetta dataset for one of them
# the avGFP PDB file is "1gfl_cm.pdb", so that's what we'll use
# note: to include multiple PDB files (e.g., for a METL-Global Rosetta dataset),
# you can instead specify the path to a text file containing a list of PDB files
pdb_fns = "1gfl_cm.pdb"

# outlier removal: the energy term used to identify outliers and the threshold applied to it
outlier_energy_term = "total_score"
outlier_threshold = 6.5

In [5]:
# run parse_rosetta_data.py with the arguments defined above
# additionally, we will set the flags:
# --remove_nan (removes variants with NaN values for any of the energy terms)
# --remove_duplicates (removes duplicate variants)
# --remove_outliers (removes outliers)
%run code/parse_rosetta_data.py generate_dataset --ds_name $ds_name --db_fn $db_fn --pdb_fns $pdb_fns --remove_nan --remove_duplicates --remove_outliers --outlier_energy_term $outlier_energy_term --outlier_threshold $outlier_threshold

INFO:METL.__main__:output data directory will be: data/rosetta_data/avgfp
INFO:METL.__main__:connecting to database at: data/rosetta_data/avgfp_rosettafy_sample.db


db_count: 10000
process_list: []
process_list_ds: []
initial data loaded into dataframe
Filtering variants with NaN values
Dropped 0 variants with nan values
Num variants after NaN filter: 10000
Removed 1 duplicates
Num variants after duplicate filter: 9999
Removed 87 outliers
Num variants after outlier removal: 9912
Saving dataset to CSV
Saving dataset to HDF, pandas fixed format
Saving dataset to SQL


100%|██████████████████████████████████████████████████████████████| 9912/9912 [00:00<00:00, 122878.77it/s]
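The log above reports three filters: NaN removal, duplicate removal, and outlier removal. The sketch below illustrates that sequence in pandas. It is not the logic in `parse_rosetta_data.py`; in particular, it assumes (1) duplicates are rows with identical values in a set of key columns and (2) an outlier is a variant whose `outlier_term` lies more than `threshold` standard deviations from the mean. The script's real outlier criterion may differ.

```python
import pandas as pd


def filter_variants(df, energy_cols, key_cols, outlier_term, threshold):
    """Illustrative NaN / duplicate / outlier filter (not the metl implementation).

    Assumptions: duplicates share identical values in key_cols (the first
    occurrence is kept), and an outlier is a variant whose outlier_term lies
    more than `threshold` sample standard deviations from that term's mean.
    """
    # 1) drop variants with a NaN in any energy term
    df = df.dropna(subset=energy_cols)
    # 2) drop repeated variants, keeping the first occurrence
    df = df.drop_duplicates(subset=key_cols, keep="first")
    # 3) drop variants whose z-score on outlier_term exceeds the threshold
    col = df[outlier_term]
    std = col.std()
    if pd.isna(std) or std == 0:
        return df.reset_index(drop=True)  # no spread, so no outliers to detect
    z = (col - col.mean()) / std
    return df[z.abs() <= threshold].reset_index(drop=True)
```

With a hypothetical raw DataFrame `raw_df`, a call could be `filter_variants(raw_df, energy_cols=["total_score"], key_cols=["pdb_fn", "mutations"], outlier_term="total_score", threshold=6.5)`; the key column names are placeholders.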


# Create train, validation, and test splits

Functions for creating the train, validation, and test splits are in the [split_dataset.py](../code/split_dataset.py) file.
This file does not have a command-line interface, so import it into this notebook and call its functions directly.

In [6]:
import random

import pandas as pd
import split_dataset as sd

In [7]:
# path to the .h5 dataset created above
ds_fn = "data/rosetta_data/avgfp/avgfp.h5"

# define an output directory for placing the splits
out_dir = "data/rosetta_data/avgfp/splits"

# load the dataset
ds = pd.read_hdf(ds_fn, key="variant")

We will first withhold 5% of the data as a "supertest" set using the `supertest()` function. These variants are excluded from the train, validation, and test splits created in the next step.

In [8]:
# for purposes of this demonstration, make random seed constant
# rseed1 = random.randint(1000, 9999)
rseed1 = 7808
_, withhold_fn = sd.supertest(ds, size=0.05, rseed=rseed1, out_dir=out_dir)

INFO:METL.split_dataset:saving supertest split to file data/rosetta_data/avgfp/splits/supertest_w1aea30517f4f_s0.05_r7808.txt
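The idea behind a reproducible holdout is to draw a fixed-size random subset of rows with a seeded random number generator and save the chosen indices. The sketch below illustrates that idea; `withhold_indices` is hypothetical, and the one-index-per-line file format is an assumption. The real `sd.supertest()` may sample, round, and name its output file differently.

```python
import os

import numpy as np


def withhold_indices(n_rows, size, rseed, out_fn=None):
    """Pick a reproducible random subset of row positions to hold out.

    Hypothetical helper illustrating the idea behind sd.supertest().
    """
    rng = np.random.default_rng(rseed)  # same seed -> same subset
    n_withhold = int(round(n_rows * size))
    idxs = np.sort(rng.choice(n_rows, size=n_withhold, replace=False))
    if out_fn is not None:
        os.makedirs(os.path.dirname(out_fn) or ".", exist_ok=True)
        np.savetxt(out_fn, idxs, fmt="%d")  # one integer index per line
    return idxs
```

For the 9,912 variants in this dataset, `withhold_indices(9912, 0.05, rseed=7808)` selects 496 row positions, and rerunning with the same seed returns the same subset.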


Now we create standard 80% train, 10% validation, and 10% test sets from the remaining (non-withheld) variants using the `train_val_test()` function.

In [9]:
# for purposes of this demonstration, make random seed constant
# rseed2 = random.randint(1000, 9999)
rseed2 = 4991
split, out_dir_split = sd.train_val_test(ds, train_size=0.8, val_size=0.1, test_size=0.1, withhold=withhold_fn, rseed=rseed2, out_dir=out_dir)

INFO:METL.split_dataset:saving train-val-test split to directory data/rosetta_data/avgfp/splits/standard_tr0.8_tu0.1_te0.1_w1aea30517f4f_r4991
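A train/validation/test split of this kind can be sketched as: remove the withheld rows, shuffle the rest with a fixed seed, and cut the shuffled list at the requested fractions. `train_val_test_indices` below is a hypothetical illustration, not `sd.train_val_test()`. Unlike the real function, it takes withheld row positions as an array rather than a filename, and its rounding may differ.

```python
import numpy as np


def train_val_test_indices(n_rows, train_size, val_size, test_size,
                           withhold=None, rseed=0):
    """Shuffle the non-withheld rows and cut them into train/val/test.

    Hypothetical sketch; `withhold` is an array of row positions to exclude.
    """
    if not np.isclose(train_size + val_size + test_size, 1.0):
        raise ValueError("split sizes must sum to 1")
    withheld = (np.array([], dtype=int) if withhold is None
                else np.asarray(withhold, dtype=int))
    # rows still available after removing the withheld (supertest) rows
    pool = np.setdiff1d(np.arange(n_rows), withheld)
    pool = np.random.default_rng(rseed).permutation(pool)
    n_train = int(round(len(pool) * train_size))
    n_val = int(round(len(pool) * val_size))
    # test receives the remainder so every available row is assigned exactly once
    return {
        "train": np.sort(pool[:n_train]),
        "val": np.sort(pool[n_train:n_train + n_val]),
        "test": np.sort(pool[n_train + n_val:]),
    }
```

Giving the test set the remainder, rather than rounding it separately, guarantees that the three sets exactly cover the non-withheld rows.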


# Compute standardization parameters
The Rosetta energy terms are on different scales, so to make training easier we standardize each term to mean 0 and standard deviation 1. The standardization parameters are computed on the train set only, so the validation and test sets do not influence them. The script [compute_rosetta_standardization.py](../code/compute_rosetta_standardization.py) computes these parameters and saves them in the given split directory.
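In outline, the standardization fits a per-term mean and standard deviation on the training rows and then applies `(x - mean) / std` to every row, including validation and test rows. The sketch below shows this with pandas. The function names are hypothetical. It uses pandas' default sample standard deviation (`ddof=1`), and it does not reproduce the file format the script writes to `standardization_params`. Both details may differ from the real script.

```python
import pandas as pd


def fit_standardization(df, energy_cols, train_idxs):
    """Compute per-term mean and std on the training rows only (hypothetical)."""
    train = df.iloc[train_idxs][energy_cols]
    return train.mean(), train.std()  # pandas std defaults to ddof=1


def apply_standardization(df, energy_cols, means, stds):
    """Return a copy with each energy column rescaled to (x - mean) / std."""
    out = df.copy()
    # the Series of means/stds align with the energy columns by name
    out[energy_cols] = (out[energy_cols] - means) / stds
    return out
```

Because the parameters come from the training rows only, standardized validation and test values are not guaranteed to have exactly mean 0 and standard deviation 1. That is expected.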

In [10]:
%run code/compute_rosetta_standardization.py -h

usage: compute_rosetta_standardization.py [-h] [--split_dir SPLIT_DIR]
                                          [--energies_start_col ENERGIES_START_COL]
                                          ds_fn_h5

positional arguments:
  ds_fn_h5              path to the rosetta dataset in hdf5 format

optional arguments:
  -h, --help            show this help message and exit
  --split_dir SPLIT_DIR
                        path to the split directory containing the
                        train/val/test split indices. if provided, the
                        standardization parameters will be computed on the
                        training set only. this is necessary for training a
                        source model.
  --energies_start_col ENERGIES_START_COL
                        the column name of the first energy term in the
                        dataset. default is 'total_score'. this is used to
                        determine which columns in the dataset are energy
             

In [11]:
# compute the standardization parameters for the dataset and train split we created above
ds_fn_h5 = "data/rosetta_data/avgfp/avgfp.h5"
split_dir = "data/rosetta_data/avgfp/splits/standard_tr0.8_tu0.1_te0.1_w1aea30517f4f_r4991"
%run code/compute_rosetta_standardization.py $ds_fn_h5 --split_dir $split_dir

INFO:METL.__main__:computing standardization params on training set only
INFO:METL.__main__:saving standardization params to: data/rosetta_data/avgfp/splits/standard_tr0.8_tu0.1_te0.1_w1aea30517f4f_r4991/standardization_params


# Ensure PDB files are listed in the index
PDB files in the Rosetta dataset must be present in the [pdb_files](../data/pdb_files) directory and listed in the [pdb_index.csv](../data/rosetta_data/pdb_index.csv) file. This is necessary because the 3D relative position embedding needs access to the PDB files, and some of the training code references the index file.

The following code generates pdb_index.csv from the PDB files in the pdb_files directory. Rerun it whenever you add PDB files to that directory.

In [12]:
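import os
from os.path import basename, join

import utils  # project helper module, importable because the code/ folder was added to sys.path above

pdb_dir = "data/pdb_files"
# sort so the index rows are in a deterministic order
pdb_fns = sorted(join(pdb_dir, x) for x in os.listdir(pdb_dir) if x.endswith(".pdb"))
with open("data/rosetta_data/pdb_index.csv", "w") as f:
    f.write("pdb_fn,aa_sequence,seq_len\n")
    for pdb_fn in pdb_fns:
        print("Processing {}".format(pdb_fn))
        seq = utils.extract_seq_from_pdb(pdb_fn)
        f.write("{},{},{}\n".format(basename(pdb_fn), seq, len(seq)))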

Processing data/pdb_files/1gfl_cm.pdb
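To show roughly what a sequence-extraction step does, here is a minimal sketch that reads the fixed-column `ATOM` records of a PDB file and maps each residue's alpha carbon (CA) to a one-letter code. `seq_from_pdb_lines` is hypothetical and is not `utils.extract_seq_from_pdb`. It only reads the first model, ignores `HETATM` records (such as nonstandard residues), and writes `X` for unrecognized residue names. The repository's function may handle these cases differently.

```python
THREE_TO_ONE = {
    "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
    "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I",
    "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P",
    "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V",
}


def seq_from_pdb_lines(lines):
    """Build a one-letter sequence from the CA atoms of the first model (hypothetical)."""
    seq = []
    seen = set()
    for line in lines:
        if line.startswith("ENDMDL"):
            break  # stop after the first model
        if not line.startswith("ATOM") or line[12:16].strip() != "CA":
            continue
        res_id = (line[21], line[22:27])  # chain ID, residue number + insertion code
        if res_id in seen:
            continue  # skip alternate-location copies of the same residue
        seen.add(res_id)
        seq.append(THREE_TO_ONE.get(line[17:20], "X"))  # residue name columns 18-20
    return "".join(seq)
```

Usage: `with open("data/pdb_files/1gfl_cm.pdb") as f: seq = seq_from_pdb_lines(f)`.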
