# Prepare sort-seq dataset for use in MAVE-NN

In [1]:
# Standard imports
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

# Insert mavenn at beginning of path
import sys
path_to_mavenn_local = '../../../../'
sys.path.insert(0,path_to_mavenn_local)

#Load mavenn and check path
import mavenn
print(mavenn.__path__)

# For testing
from mavenn.src.utils import vec_data_to_mat_data

['../../../../mavenn']


In [2]:
# Load raw data file
raw_df = pd.read_csv(mavenn.__path__[0] +
    '/examples/datasets/sort_seq/full-wt/full-wt-sort_seq.csv',
    index_col=[0])
raw_df.head()

Unnamed: 0,seq,bin,ct
0,GGCTGTTCACTTTATGCTTCCGGCTTGTATTTTGTGTGC,4,23.0
1,GGTTTTACACATTATGCTTCCGGCTCGTCTCTTGTGTGG,2,12.0
2,GGCTTAACACTTAATGCTTCCGGCTCGTATGTTGTGTGG,1,11.0
3,GGTTTTACACTTTATGCTTCCCGCTCGTAAGGTGTGTCG,5,10.0
4,GGCTTTACACTTTATGCGTCCGGCTCGTATGTTGCGTGG,2,10.0


In [3]:
# Refine contents of raw data file
sequences = raw_df['seq'].values
raw_df.columns = ['x','y','ct']
raw_df['ct'] = raw_df['ct'].astype(int)
raw_df.head()

Unnamed: 0,x,y,ct
0,GGCTGTTCACTTTATGCTTCCGGCTTGTATTTTGTGTGC,4,23
1,GGTTTTACACATTATGCTTCCGGCTCGTCTCTTGTGTGG,2,12
2,GGCTTAACACTTAATGCTTCCGGCTCGTATGTTGTGTGG,1,11
3,GGTTTTACACTTTATGCTTCCCGCTCGTAAGGTGTGTCG,5,10
4,GGCTTTACACTTTATGCGTCCGGCTCGTATGTTGCGTGG,2,10


In [4]:
# Pivot and set training/test data
pivot_df = pd.pivot(raw_df, values='ct', index='x', columns='y').fillna(0).astype(int)
pivot_df.columns.name = None

# Do all columns still sum to > 0?
print('rows summing to 0:', (pivot_df.values.sum(axis=1)==0).sum())
pivot_df.head()

rows summing to 0: 0


Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9
x,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
AAATACACACTTGCTGCTTCCGGCTCGTATGTTGTGTGG,0,0,0,1,0,0,0,0,0,0
AAATTTACACTGTATGCTTCCGGCTCGCATGGCGTTTGC,0,0,1,0,0,0,0,0,0,0
AAATTTACACTTTATGCATCAGACTCGTATGTTGTGTGG,1,0,0,0,0,0,0,0,0,0
AAATTTACACTTTATGCTTCTGGCGCGTATGCGGCGTGG,0,0,0,1,0,0,0,0,0,0
AACATTACATTTTATGCTTCCGGCTCGTATGGTGTGTGG,0,1,0,0,0,0,0,0,0,0


In [5]:
N = len(pivot_df)
training_frac=.8
pivot_df['training_set'] = (np.random.rand(N) < training_frac)
pivot_df.reset_index(inplace=True)
pivot_df.head()

Unnamed: 0,x,0,1,2,3,4,5,6,7,8,9,training_set
0,AAATACACACTTGCTGCTTCCGGCTCGTATGTTGTGTGG,0,0,0,1,0,0,0,0,0,0,True
1,AAATTTACACTGTATGCTTCCGGCTCGCATGGCGTTTGC,0,0,1,0,0,0,0,0,0,0,True
2,AAATTTACACTTTATGCATCAGACTCGTATGTTGTGTGG,1,0,0,0,0,0,0,0,0,0,True
3,AAATTTACACTTTATGCTTCTGGCGCGTATGCGGCGTGG,0,0,0,1,0,0,0,0,0,0,True
4,AACATTACATTTTATGCTTCCGGCTCGTATGGTGTGTGG,0,1,0,0,0,0,0,0,0,0,True


In [6]:
# Melt dataframe and prepare for saving
data_df = pd.melt(pivot_df, id_vars=['x','training_set'], ignore_index=True)
data_df.columns = ['x','training_set','y','ct']
data_df = data_df[['training_set','ct','y','x']]
data_df.head()

Unnamed: 0,training_set,ct,y,x
0,True,0,0,AAATACACACTTGCTGCTTCCGGCTCGTATGTTGTGTGG
1,True,0,0,AAATTTACACTGTATGCTTCCGGCTCGCATGGCGTTTGC
2,True,1,0,AAATTTACACTTTATGCATCAGACTCGTATGTTGTGTGG
3,True,0,0,AAATTTACACTTTATGCTTCTGGCGCGTATGCGGCGTGG
4,True,0,0,AACATTACATTTTATGCTTCCGGCTCGTATGGTGTGTGG


In [7]:
# Remove entries where ct is 0
ix = data_df['ct'] > 0
data_df = data_df[ix].reset_index(drop=True)
data_df.head()

Unnamed: 0,training_set,ct,y,x
0,True,1,0,AAATTTACACTTTATGCATCAGACTCGTATGTTGTGTGG
1,False,1,0,AACTTAACAATTTATGCTTCCGACTCGTATATTCTGTGG
2,True,1,0,AACTTTACACTATATGCGTCAGGCTCGTATGTTGTGTGG
3,False,1,0,AACTTTACACTGTATGCTTCCGTCTCCTATGTTGTGTGG
4,False,2,0,AACTTTACACTTGATGCTTCCGGCTCGTATGTTGTGTAG


In [8]:
ct_my, x_m = vec_data_to_mat_data(y_n=data_df['y'], 
                                  ct_n=data_df['ct'], 
                                  x_n=data_df['x'])
print(f'Number of rows in data_df: {len(x_m)}')
print(f'Number of rows with no counts: {sum(ct_my.sum(axis=1)==0)}')

Number of rows in data_df: 45778
Number of rows with no counts: 0


In [9]:
ix = data_df['training_set']
training_df = data_df[ix].copy()
ct_my, x_m = vec_data_to_mat_data(y_n=training_df['y'], 
                                  ct_n=training_df['ct'], 
                                  x_n=training_df['x'])
print(f'Number of rows in training_df: {len(x_m)}')
print(f'Number of rows with no counts: {sum(ct_my.sum(axis=1)==0)}')

Number of rows in training_df: 36517
Number of rows with no counts: 0


In [10]:
# Show size of compressed dataset file
file_name = 'sortseq_data.csv.gz'
data_df.to_csv(file_name, compression='gzip', index=False)
print('df (zipped):')
!du -mh $file_name
!mv $file_name ../.

df (zipped):
292K	sortseq_data.csv.gz


In [11]:
# Test loading
loaded_df = mavenn.load_example_dataset('sortseq')
loaded_df.head()

Unnamed: 0,training_set,ct,y,x
0,True,1,0,AAATTTACACTTTATGCATCAGACTCGTATGTTGTGTGG
1,False,1,0,AACTTAACAATTTATGCTTCCGACTCGTATATTCTGTGG
2,True,1,0,AACTTTACACTATATGCGTCAGGCTCGTATGTTGTGTGG
3,False,1,0,AACTTTACACTGTATGCTTCCGTCTCCTATGTTGTGTGG
4,False,2,0,AACTTTACACTTGATGCTTCCGGCTCGTATGTTGTGTAG
