# Split the Sarkisyan dataset into three data distribution sets.

The rows are shuffled once and then partitioned into three non-overlapping splits of (nearly) equal size.

Last re-run top to bottom on Apr 30, 2019.

In [1]:
import sys
import os
import random

import numpy as np
import pandas as pd
from sklearn.manifold import MDS
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt

# Put the sibling `common` directory (relative to the working directory) on
# the import path so the local modules imported below can be found.
sys.path.append('../common')
import data_io_utils
import paths
import utils

# Load (or reload) IPython's autoreload extension.
# NOTE(review): `%autoreload` with no mode argument appears to trigger a single
# reload rather than enabling automatic reloading -- confirm whether
# `%autoreload 2` was intended.
%reload_ext autoreload
%autoreload 

## Set seeds for reproducibility

In [2]:
# Seed NumPy's global RNG and Python's `random` module with the same value so
# that later random operations in this notebook are repeatable.
SEED = 1
np.random.seed(SEED)
random.seed(SEED)

## Sync data

In [3]:
# Presumably fetches the Sarkisyan CSV from S3 to the local path (helper is
# defined in data_io_utils; verify there).
data_io_utils.sync_s3_path_to_local(paths.SARKISYAN_DATA_FILE, is_single_file=True)

# Read the file, shuffle all rows (frac=1 keeps every row) and renumber the
# index 0..n-1 so later positional splits follow the shuffled order.
sark_df = (
    pd.read_csv(paths.SARKISYAN_DATA_FILE)
    .sample(frac=1)
    .reset_index(drop=True)
)

print(sark_df.shape)
sark_df.head()

(51715, 2)


Unnamed: 0,seq,quantitative_function
0,MSKGEGLFTGVVPILVELDGDVNGHKFSVSGEGEGGATYGKLTLKY...,0.007245
1,MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKF...,0.162121
2,MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKF...,0.760652
3,MSKGEELFTGVVPILVELRGDVSGHKFSVSGEGEGDATSGKLTLKF...,0.991444
4,MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKF...,0.006828


## Split

In [4]:
# KFold yields (train_idx, test_idx) pairs; only the held-out indices of each
# fold are kept, giving one index array per split.
kf = KFold(n_splits=3)
split_indices = [held_out for _, held_out in kf.split(sark_df)]

In [5]:
# Materialise each split as its own DataFrame via positional row selection.
split_dfs = [sark_df.iloc[row_positions] for row_positions in split_indices]

Sanity checks: the split sizes add up to the full dataset, and no sequence appears in more than one split.

In [6]:
# The split sizes must add up to the number of rows in the full frame.
n_split_rows = sum(len(sdf) for sdf in split_dfs)
assert n_split_rows == sark_df.shape[0], (
    'splits contain {} rows but the dataset has {}'.format(n_split_rows, sark_df.shape[0])
)

# No sequence may appear in more than one split. Build each split's set of
# sequences once, then compare every unordered pair of splits.
seq_sets = [set(sdf['seq']) for sdf in split_dfs]
for i, left in enumerate(seq_sets):
    for j in range(i + 1, len(seq_sets)):
        # Compute the overlap once; it is both printed and asserted on.
        shared = left & seq_sets[j]
        print(shared)

        assert len(shared) == 0, (
            'splits {} and {} share {} sequence(s)'.format(i, j, len(shared))
        )

set()
set()
set()


## Export

In [7]:
output_file_prefix = 'sarkisyan_split_'

# All split files go into one sub-directory of the splits dir; create it if
# needed and tolerate it already existing.
output_dir = os.path.join(paths.TTS_SPLITS_DIR, 'data_distributions')
os.makedirs(output_dir, exist_ok=True)

for split_num, split_df in enumerate(split_dfs):
    # File name pattern: <prefix><split number>.csv
    ofile = os.path.join(output_dir, output_file_prefix + str(split_num) + '.csv')
    split_df.to_csv(ofile, index=False)

    # Report the written path and its MD5 so separate runs can be compared.
    print(ofile)
    print(data_io_utils.generate_md5_checksum(ofile))
    print()

/notebooks/analysis/common/../../data/s3/datasets/tts_splits/data_distributions/sarkisyan_split_0.csv
05923c69d7ecfff31e2f46d2fee52eb5

/notebooks/analysis/common/../../data/s3/datasets/tts_splits/data_distributions/sarkisyan_split_1.csv
a70b69b9077b103179fc245cbc99488d

/notebooks/analysis/common/../../data/s3/datasets/tts_splits/data_distributions/sarkisyan_split_2.csv
17c182531eca0d1e11e180cf9744e9b9



Manually verified that these results are reproducible: the notebook was run top to bottom twice, and the MD5 checksums of the exported files matched between runs.

## Sync back up to S3

In [8]:
# Post-publication note: the sync below is disabled because the bucket is read-only.
#data_io_utils.sync_local_path_to_s3(paths.TTS_SPLITS_DIR)