In [14]:
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import ttest_rel,wilcoxon,ranksums,ttest_ind
from statsmodels.regression.linear_model import OLS
from statsmodels.stats.multitest import multipletests

In [2]:
# Load the three tab-separated phrase/read-count matrices: the two 'ovrexpr'
# files (rp2, rp1) and the main read matrix.
dfP2 = pd.read_csv('data/phrases-reads-matrix-ovrexpr-rp2.txt', sep='\t')
dfP1 = pd.read_csv('data/phrases-reads-matrix-ovrexpr-rp1.txt', sep='\t')
df = pd.read_csv('data/phrases-reads-matrix_24042019.txt', sep='\t')
def rename_column(c):
    """Normalise a read-count column name to the ``<experiment>-<digest>`` form.

    Only names containing ``'dpn'`` are rewritten:

    * hyphenated ``a-b-c`` becomes ``a.c-b`` (the middle field moves to the end);
    * dotted ``<digest>.<experiment>`` (no hyphen) becomes
      ``<experiment>-<digest>``, where the digest is the text before the first
      ``'.'`` and the experiment is the text following the first ``'i.'``.

    Any other name is returned unchanged. This includes 'dpn' names that match
    neither pattern (for example an already-normalised ``es1.1-dpni``), so the
    function is safe to apply twice and never returns ``None``.

    Parameters
    ----------
    c : str
        Original column name.

    Returns
    -------
    str
        The normalised column name.
    """
    if 'dpn' not in c:
        return c
    if '-' in c:
        splitc = c.split('-')
        # Fewer than three fields means there is nothing to reorder
        # (e.g. the name is already in '<experiment>-<digest>' form).
        if len(splitc) < 3:
            return c
        return splitc[0] + '.' + splitc[2] + '-' + splitc[1]
    if '.' in c:
        after_digest = c.split('i.')
        if len(after_digest) < 2:
            return c
        expt = after_digest[1]
        digest = c.split('.')[0]
        return expt + '-' + digest
    # A 'dpn' name with neither separator: keep it rather than returning None,
    # which would break later "'dpn' in column" checks.
    return c
# Apply the column-name normalisation to all three loaded matrices.
df.columns = list(map(rename_column, df.columns))
dfP2.columns = list(map(rename_column, dfP2.columns))
dfP1.columns = list(map(rename_column, dfP1.columns))

In [3]:
def normalize(df):
    """Depth-normalise the dpn read counts and derive per-experiment ratios.

    Steps:

    1. Drop every row that has a zero in *any* column whose name contains
       ``'dpn'``.
    2. For each experiment (the text before ``'-'`` in a dpn column name),
       rescale its ``-dpnii`` and ``-dpni`` columns so each sums to 100000.
    3. Add a column named after the experiment holding
       ``dpnii / (dpni + dpnii)``.
    4. Add ``'ed.average'`` (mean of the four ``ed*`` ratios) and
       ``'es.average'`` (mean of the eight ``es*`` ratios).

    Parameters
    ----------
    df : pandas.DataFrame
        Read matrix with ``<experiment>-dpni`` / ``<experiment>-dpnii`` columns
        for ed1.1..ed2.2 and es1.1..es4.2. It is not modified.

    Returns
    -------
    pandas.DataFrame
        A new, filtered frame with the normalised and derived columns.
    """
    dpn_cols = [c for c in df.columns if 'dpn' in c]
    expts = set(c.split('-')[0] for c in dpn_cols)
    # Keep only phrases with a non-zero count in every dpn column. The explicit
    # copy avoids assigning into a filtered slice (chained-assignment warning).
    nonzero_everywhere = (df[dpn_cols] != 0).all(axis=1)
    norm = df.loc[nonzero_everywhere].copy()
    for expt in expts:
        dpnii_col = expt + '-dpnii'
        dpni_col = expt + '-dpni'
        norm[dpnii_col] = norm[dpnii_col] / (norm[dpnii_col].sum() / 100000)
        norm[dpni_col] = norm[dpni_col] / (norm[dpni_col].sum() / 100000)
        norm[expt] = norm[dpnii_col] / (norm[dpni_col] + norm[dpnii_col])
    ed_reps = ['ed1.1', 'ed1.2', 'ed2.1', 'ed2.2']
    es_reps = ['es1.1', 'es1.2', 'es2.1', 'es2.2',
               'es3.1', 'es3.2', 'es4.1', 'es4.2']
    norm['ed.average'] = sum(norm[rep] for rep in ed_reps) / 4.0
    norm['es.average'] = sum(norm[rep] for rep in es_reps) / 8.0
    return norm

def normalizeP(df,rep):
    """Depth-normalise a perturbation read matrix and derive tagged ratios.

    Steps:

    1. Drop every row that has a zero in *any* column whose name contains
       ``'dpn'``.
    2. For each experiment (the text before ``'-'`` in a dpn column name),
       rescale its ``-dpnii`` and ``-dpni`` columns so each sums to 100000.
    3. Add ``'<experiment>_<rep>'`` holding ``dpnii / (dpni + dpnii)``.
    4. For each experiment name with its last dot-separated field removed,
       add ``'<name>_<rep>.average'`` as the mean of its ``.rep1`` and
       ``.rep2`` ratio columns (only those two replicates are used).

    Parameters
    ----------
    df : pandas.DataFrame
        Read matrix with ``<experiment>-dpni`` / ``<experiment>-dpnii`` columns.
        It is not modified.
    rep : str
        Tag appended to the derived column names (e.g. ``'August'``).

    Returns
    -------
    pandas.DataFrame
        A new, filtered frame with the normalised and derived columns.

    Raises
    ------
    KeyError
        If an experiment lacks a ``.rep1`` or ``.rep2`` ratio column.
    """
    dpn_cols = [c for c in df.columns if 'dpn' in c]
    expts = set(c.split('-')[0] for c in dpn_cols)
    # Explicit copy: avoids assigning into a filtered slice.
    nonzero_everywhere = (df[dpn_cols] != 0).all(axis=1)
    norm = df.loc[nonzero_everywhere].copy()
    for expt in expts:
        dpnii_col = expt + '-dpnii'
        dpni_col = expt + '-dpni'
        norm[dpnii_col] = norm[dpnii_col] / (norm[dpnii_col].sum() / 100000)
        norm[dpni_col] = norm[dpni_col] / (norm[dpni_col].sum() / 100000)
        norm[expt + '_' + rep] = norm[dpnii_col] / (norm[dpni_col] + norm[dpnii_col])
    # Strip the trailing replicate field; de-duplicate (keeping first-seen
    # order) so each average is computed once instead of once per replicate.
    norep_expts = list(dict.fromkeys('.'.join(expt.split('.')[:-1]) for expt in expts))
    for expt in norep_expts:
        rep1 = norm[expt + '.rep1' + '_' + rep]
        rep2 = norm[expt + '.rep2' + '_' + rep]
        norm[expt + '_' + rep + '.average'] = (rep1 + rep2) / 2.0
    return norm

In [4]:
# Normalise the main matrix and the two perturbation matrices; the second
# argument of normalizeP is the tag appended to its derived column names.
norm_df = normalize(df)
norm_perturb1 = normalizeP(df=dfP1, rep='August')
norm_perturb2 = normalizeP(df=dfP2, rep='Sept')

In [5]:
# Preview the first five normalised rows.
norm_df.head(5)

Unnamed: 0,seq,category,gc.content,description,es1.1-dpni,es1.2-dpni,es2.1-dpni,es2.2-dpni,es3.1-dpni,es3.2-dpni,...,es1.1,es4.1,es2.1,ed1.1,es1.2,ed1.2,es4.2,es2.2,ed.average,es.average
0,CCGCTGTCTCGCGCACCACTAGGTGGCGCCCTCTTGTGGGCGGAAA...,CAT1,0.69,16.3110504,9.3126,6.29672,14.506854,10.456088,3.918077,2.765532,...,0.67561,0.842191,0.719595,0.81911,0.802863,0.90222,0.806932,0.786676,0.870671,0.690074
1,GAAAACGTTCGCGCACCACTAGGTGGCGCCCTCTTGTGGGCGGAAA...,CAT1,0.66,15.53072216,11.167014,16.001391,7.340818,13.482061,2.210827,2.451542,...,0.703222,0.830446,0.740088,0.898104,0.616249,0.832615,0.923726,0.553754,0.812765,0.75031
2,TTCCGCTGTCTCGCGCACCACTAGGTGGCGCCCTCTTGTGGGCGGA...,CAT1,0.67,15.0963833,10.923812,10.596595,14.943808,17.676481,36.331256,38.367234,...,0.530243,0.577685,0.540406,0.867787,0.349599,0.795582,0.593482,0.492691,0.794212,0.520792
3,CCACTAGGTGGCGCTAGGCGGAAACGCCGCACTTCCGGCAACAACC...,CAT1,0.66,14.30055153,11.167014,7.787699,16.167277,14.890188,1.535296,2.185858,...,0.493161,0.729998,0.519955,0.833795,0.683469,0.897361,0.684443,0.60225,0.827346,0.581546
4,GCACCACTAGGTGGCGCTAGGCCACGCCCTCTTGTGGGCGGAAACG...,CAT1,0.68,13.70067341,17.439592,8.306878,18.002482,9.332583,2.530169,3.043293,...,0.723427,0.791749,0.576626,0.924453,0.837778,0.90536,0.784332,0.745218,0.897033,0.737664


In [7]:
# Join the two perturbation frames on every column name they share, except the
# read-count ('dpn') columns; same-named dpn columns therefore come out of the
# merge with pandas' _x/_y suffixes.
colshare = [
    c
    for c in set(norm_perturb1.columns) & set(norm_perturb2.columns)
    if 'dpn' not in c
]
norm_perturb = norm_perturb1.merge(norm_perturb2, on=colshare)
norm_perturb.head()

Unnamed: 0,seq,category,gc.content,description,brachyury.24h.rep1-dpni_x,brachyury.24h.rep2-dpni_x,foxa2.24h.rep1-dpni_x,foxa2.24h.rep2-dpni_x,brachyury.24h.rep3-dpni,brachyury.48h.rep1-dpni,...,brachyury.24h.rep1-dpnii_y,brachyury.24h.rep2-dpnii_y,foxa2.24h.rep1-dpnii_y,foxa2.24h.rep2-dpnii_y,foxa2.24h.rep1_Sept,brachyury.24h.rep2_Sept,foxa2.24h.rep2_Sept,brachyury.24h.rep1_Sept,foxa2.24h_Sept.average,brachyury.24h_Sept.average
0,CCGCTGTCTCGCGCACCACTAGGTGGCGCCCTCTTGTGGGCGGAAA...,CAT1,0.69,16.3110504,15.805721,11.226281,25.674685,20.288268,54.918665,3.951769,...,9.422685,3.584269,18.980642,9.595763,0.718863,0.413633,0.607323,0.638267,0.663093,0.52595
1,GAAAACGTTCGCGCACCACTAGGTGGCGCCCTCTTGTGGGCGGAAA...,CAT1,0.66,15.53072216,12.384197,4.90678,17.846861,11.558997,7.103587,3.913402,...,5.384391,2.560192,16.693231,2.515743,0.717928,0.665584,0.49463,0.60427,0.606279,0.634927
2,TTCCGCTGTCTCGCGCACCACTAGGTGGCGCCCTCTTGTGGGCGGA...,CAT1,0.67,15.0963833,9.941788,9.457683,20.93186,10.135559,6.073459,2.877502,...,9.118727,9.831137,8.711628,13.836587,0.482192,0.498132,0.615552,0.444986,0.548872,0.471559
3,CCACTAGGTGGCGCTAGGCGGAAACGCCGCACTTCCGGCAACAACC...,CAT1,0.66,14.30055153,20.174082,10.050811,12.887798,12.193541,7.725955,3.580891,...,3.734336,2.935687,2.725426,3.737676,0.455804,0.599442,0.510101,0.558213,0.482953,0.578828
4,GCACCACTAGGTGGCGCTAGGCCACGCCCTCTTGTGGGCGGAAACG...,CAT1,0.68,13.70067341,21.185476,12.207637,35.664891,24.129834,8.112253,41.295346,...,25.054788,21.676291,21.949409,7.834743,0.65045,0.908361,0.236417,0.737288,0.443433,0.822825


In [8]:
# Value columns: names containing one of the condition markers and none of the
# excluded substrings.
value_cols = [
    c
    for c in norm_perturb.columns
    if any(marker in c for marker in ('ed', 'es', 'brachyury.24h', 'foxa2.24h'))
    and not any(skip in c for skip in ('description', 'average', 'dpn', 'August'))
]
print(value_cols)

['foxa2.24h.rep1_Sept', 'brachyury.24h.rep2_Sept', 'foxa2.24h.rep2_Sept', 'brachyury.24h.rep1_Sept']


In [9]:
# Mean of the two 'Sept' replicates for each perturbation.
norm_perturb['brachyury.average'] = (
    norm_perturb['brachyury.24h.rep1_Sept']
    .add(norm_perturb['brachyury.24h.rep2_Sept'])
    .div(2.0)
)
norm_perturb['foxa2.average'] = (
    norm_perturb['foxa2.24h.rep1_Sept']
    .add(norm_perturb['foxa2.24h.rep2_Sept'])
    .div(2.0)
)

In [10]:
def get_kmer(d):
    """Return the third comma-separated field of the third '/'-separated part of ``d``.

    Returns the string ``'None'`` when ``d`` has fewer than three '/'-separated
    parts, and the object ``None`` when the third part has fewer than three
    comma-separated fields.
    """
    slash_parts = d.split('/')
    if len(slash_parts) <= 2:
        return 'None'
    fields = slash_parts[2].split(',')
    if len(fields) > 2:
        return fields[2]
    # Third part exists but is too short: the missing value is the object None,
    # not the string 'None'.
    return None
def get_type(d):
    """Return ``'scrambled'`` when ``d`` contains an uppercase 'S', otherwise ``'motif'``."""
    return 'scrambled' if 'S' in d else 'motif'
def get_bg(d):
    """Return the first two comma-separated fields of the third '/'-part of ``d``, joined by '/'.

    Returns the string ``'None'`` when ``d`` has fewer than three '/'-separated
    parts, and the object ``None`` when the third part has fewer than three
    comma-separated fields.
    """
    slash_parts = d.split('/')
    if len(slash_parts) <= 2:
        return 'None'
    fields = slash_parts[2].split(',')
    if len(fields) > 2:
        return '/'.join(fields[:2])
    # Third part exists but is too short: the missing value is the object None,
    # not the string 'None'.
    return None
# Columns of norm_df whose name starts with 'e' and does not contain 'average'.
replicates = [c for c in norm_df.columns if c[0] == 'e' and 'average' not in c]
# Attach those columns plus the ES/ED averages to the perturbation frame, by 'seq'.
norm_perturb = norm_perturb.merge(
    norm_df[['seq', *replicates, 'es.average', 'ed.average']],
    on='seq',
)
# Per-phrase annotations parsed from the description / category strings.
norm_perturb['kmer'] = list(map(get_kmer, norm_perturb['description']))
norm_perturb['control'] = list(map(get_type, norm_perturb['category']))
norm_perturb['background'] = list(map(get_bg, norm_perturb['description']))

In [None]:
# Recompute the value columns now that the ES/ED replicate columns are merged in:
# names containing a condition marker and none of the excluded substrings.
value_cols = [
    c
    for c in norm_perturb.columns
    if any(marker in c for marker in ('ed', 'es', 'brachyury.24h', 'foxa2.24h'))
    and not any(skip in c for skip in ('description', 'average', 'dpn', 'August'))
]

In [11]:
# Human-readable description for each phrase category code.
cat2desc = {
    "CAT1": "universally opening phrases with GC content between 60 - 70 %",
    "CAT2": "universally opening phrases with GC content between 30 - 50 %",
    "CAT3": "universally closing phrases with GC content between 60 - 70 %",
    "CAT4": "universally closing phrases with GC content between 30 - 50 %",
    "CAT5": "opening ES cells using one or more occurrences of one k-mer per phrase",
    "CAT6": "opening ES cells using combinations of k-mers per phrase",
    "CAT7": "phrases opening ED cells using one or more occurrences of one k-mer per phrase",
    "CAT8": "phrases opening ED cells using combinations of k-mers per phrase",
    "CAT9": "ES-Salient-TF",
    "CAT10": "ES-Salient-Top",
    "CAT11": "ES-Native",
    "CAT12": "ED-Salient-TF",
    "CAT13": "ED-Salient-Top",
    "CAT14": "ED-Native",
    "CAT15": "SLOT-CNN",
    "CAT16": "background",
}
# Any leading/trailing 'S' is stripped from the category code before the lookup.
norm_perturb['catdesc'] = [
    cat2desc[category.strip('S')] for category in norm_perturb['category']
]

In [12]:
# Preview the first five annotated rows.
norm_perturb.head(5)

Unnamed: 0,seq,category,gc.content,description,brachyury.24h.rep1-dpni_x,brachyury.24h.rep2-dpni_x,foxa2.24h.rep1-dpni_x,foxa2.24h.rep2-dpni_x,brachyury.24h.rep3-dpni,brachyury.48h.rep1-dpni,...,es1.2,ed1.2,es4.2,es2.2,es.average,ed.average,kmer,control,background,catdesc
0,CCGCTGTCTCGCGCACCACTAGGTGGCGCCCTCTTGTGGGCGGAAA...,CAT1,0.69,16.3110504,15.805721,11.226281,25.674685,20.288268,54.918665,3.951769,...,0.802863,0.90222,0.806932,0.786676,0.690074,0.870671,,motif,,universally opening phrases with GC content be...
1,GAAAACGTTCGCGCACCACTAGGTGGCGCCCTCTTGTGGGCGGAAA...,CAT1,0.66,15.53072216,12.384197,4.90678,17.846861,11.558997,7.103587,3.913402,...,0.616249,0.832615,0.923726,0.553754,0.75031,0.812765,,motif,,universally opening phrases with GC content be...
2,TTCCGCTGTCTCGCGCACCACTAGGTGGCGCCCTCTTGTGGGCGGA...,CAT1,0.67,15.0963833,9.941788,9.457683,20.93186,10.135559,6.073459,2.877502,...,0.349599,0.795582,0.593482,0.492691,0.520792,0.794212,,motif,,universally opening phrases with GC content be...
3,CCACTAGGTGGCGCTAGGCGGAAACGCCGCACTTCCGGCAACAACC...,CAT1,0.66,14.30055153,20.174082,10.050811,12.887798,12.193541,7.725955,3.580891,...,0.683469,0.897361,0.684443,0.60225,0.581546,0.827346,,motif,,universally opening phrases with GC content be...
4,GCACCACTAGGTGGCGCTAGGCCACGCCCTCTTGTGGGCGGAAACG...,CAT1,0.68,13.70067341,21.185476,12.207637,35.664891,24.129834,8.112253,41.295346,...,0.837778,0.90536,0.784332,0.745218,0.737664,0.897033,,motif,,universally opening phrases with GC content be...


# Number of motifs with a significant change from ES to ES + FoxA2/Brachyury

In [28]:
# Paired per-kmer comparison of brachyury.average against es.average.
# Restrict to rows whose 'catdesc' is one of the five listed category labels.
keep  = norm_perturb[norm_perturb['catdesc'].isin(['ES-Salient-TF','ES-Salient-Top',
                                      'ED-Salient-TF','ED-Salient-Top',
                                      'SLOT-CNN'])]

# Raw p-values keyed by kmer, one dict per test.
wilcoxon_tests = {}
ttest_tests = {}
all_cats = list(set(keep['kmer']))
# kmers whose paired t statistic is positive (brachyury.average above es.average).
pos_cats = []
for cat in all_cats:
    keep_cat = keep[keep['kmer']==cat]
    # One row per background, columns indexed by (value, control); pivot_table
    # aggregates duplicates with its default aggfunc, and dropna() removes
    # backgrounds with any missing cell.
    keep_pd = keep_cat.pivot_table(index=['background'],columns=['control'],
                                   values=['brachyury.average','es.average']).dropna()
    # Both tests pair the 'motif' columns of the two values across backgrounds.
    es_scram = wilcoxon(keep_pd['brachyury.average'].motif,keep_pd['es.average'].motif)
    # Index 1 of the test result is the p-value.
    wilcoxon_tests[cat] = es_scram[1]
    # es_scram is reused: from here on it holds the paired t-test result.
    es_scram = ttest_rel(keep_pd['brachyury.average'].motif,keep_pd['es.average'].motif)
    ttest_tests[cat] = es_scram[1]
    # Index 0 is the t statistic; its sign gives the direction of the change.
    if es_scram[0] > 0:
        pos_cats.append(cat)
# p-value arrays in all_cats order, so mask index i below refers to all_cats[i].
wilcoxon_pvals = np.array([wilcoxon_tests[cat] for cat in all_cats])
ttest_pvals = np.array([ttest_tests[cat] for cat in all_cats])
# Benjamini-Hochberg FDR correction; only the boolean reject mask is kept.
wilcoxon_true,_,_,_ = multipletests(wilcoxon_pvals,method='fdr_bh')
ttest_true,_,_,_ = multipletests(ttest_pvals,method='fdr_bh')
# A kmer counts as significant only when both corrected tests reject.
diff_sig_cats = [cat for i,cat in enumerate(all_cats) if ttest_true[i] and wilcoxon_true[i]]
# ...and, additionally, as a positive change when its t statistic was > 0.
diff_sig_pos_cats = [cat for i,cat in enumerate(all_cats) if ttest_true[i] and wilcoxon_true[i] and cat in pos_cats]

In [30]:
# Motif-vs-scrambled follow-up for the kmers in diff_sig_pos_cats.
# Restrict to rows whose 'catdesc' is one of the five listed category labels.
keep  = norm_perturb[norm_perturb['catdesc'].isin(['ES-Salient-TF','ES-Salient-Top',
                                      'ED-Salient-TF','ED-Salient-Top',
                                      'SLOT-CNN'])]

# Raw p-values keyed by kmer, one dict per test (these names are re-bound here).
wilcoxon_tests = {}
ttest_tests = {}
# all_cats is not used by the loop below; it is only recomputed here.
all_cats = list(set(keep['kmer']))
for cat in diff_sig_pos_cats:
    keep_cat = keep[keep['kmer']==cat]
    # One row per background, columns indexed by (value, control); backgrounds
    # with any missing cell are dropped.
    keep_pd = keep_cat.pivot_table(index=['background'],columns=['control'],
                                   values=['brachyury.average','es.average']).dropna()
    try:
        # Paired motif-vs-scrambled tests, once on es.average and once on
        # brachyury.average; the smaller of the two p-values is stored.
        es_scram = wilcoxon(keep_pd['es.average'].motif,keep_pd['es.average'].scrambled)
        ot_scram = wilcoxon(keep_pd['brachyury.average'].motif,keep_pd['brachyury.average'].scrambled)
        wilcoxon_tests[cat] = min(es_scram[1],ot_scram[1])
        es_scram = ttest_rel(keep_pd['es.average'].motif,keep_pd['es.average'].scrambled)
        ot_scram = ttest_rel(keep_pd['brachyury.average'].motif,keep_pd['brachyury.average'].scrambled)
        ttest_tests[cat] = min(es_scram[1],ot_scram[1])
    except AttributeError:
        # Raised when the pivot has no 'motif' or 'scrambled' column to access;
        # such kmers are silently left out of the tested set.
        pass
# Only kmers that were actually tested; mask index i below refers to test_cats[i].
test_cats = list(wilcoxon_tests.keys())
wilcoxon_pvals = np.array([wilcoxon_tests[cat] for cat in test_cats])
ttest_pvals = np.array([ttest_tests[cat] for cat in test_cats])
# Benjamini-Hochberg FDR correction; only the boolean reject mask is kept.
wilcoxon_true,_,_,_ = multipletests(wilcoxon_pvals,method='fdr_bh')
ttest_true,_,_,_ = multipletests(ttest_pvals,method='fdr_bh')
# A kmer counts as significant only when both corrected tests reject.
diff_scram_sig_cats = [cat for i,cat in enumerate(test_cats) if ttest_true[i] and wilcoxon_true[i]]

In [31]:
# Counts of significant kmers out of all kmers considered.
print(f"{len(diff_sig_cats)} / {len(all_cats)}")
print(f"{len(diff_sig_pos_cats)} / {len(all_cats)}")
print(f"{len(diff_scram_sig_cats)} / {len(all_cats)}")

13 / 76
7 / 76
6 / 76


In [24]:
# Paired per-kmer comparison of foxa2.average against es.average.
# Restrict to rows whose 'catdesc' is one of the five listed category labels.
keep  = norm_perturb[norm_perturb['catdesc'].isin(['ES-Salient-TF','ES-Salient-Top',
                                      'ED-Salient-TF','ED-Salient-Top',
                                      'SLOT-CNN'])]

# Raw p-values keyed by kmer, one dict per test.
wilcoxon_tests = {}
ttest_tests = {}
all_cats = list(set(keep['kmer']))
for cat in all_cats:
    keep_cat = keep[keep['kmer']==cat]
    # One row per background, columns indexed by (value, control); backgrounds
    # with any missing cell are dropped.
    keep_pd = keep_cat.pivot_table(index=['background'],columns=['control'],
                                   values=['foxa2.average','es.average']).dropna()
    # Both tests pair the 'motif' columns of the two values across backgrounds.
    es_scram = wilcoxon(keep_pd['foxa2.average'].motif,keep_pd['es.average'].motif)
    # Index 1 of the test result is the p-value.
    wilcoxon_tests[cat] = es_scram[1]
    es_scram = ttest_rel(keep_pd['foxa2.average'].motif,keep_pd['es.average'].motif)
    ttest_tests[cat] = es_scram[1]
# p-value arrays in all_cats order, so mask index i below refers to all_cats[i].
wilcoxon_pvals = np.array([wilcoxon_tests[cat] for cat in all_cats])
ttest_pvals = np.array([ttest_tests[cat] for cat in all_cats])
# Benjamini-Hochberg FDR correction; only the boolean reject mask is kept.
wilcoxon_true,_,_,_ = multipletests(wilcoxon_pvals,method='fdr_bh')
ttest_true,_,_,_ = multipletests(ttest_pvals,method='fdr_bh')
# A kmer counts as significant only when both corrected tests reject.
# NOTE(review): this re-binds diff_sig_cats, and no direction (sign) filter is
# applied in this cell.
diff_sig_cats = [cat for i,cat in enumerate(all_cats) if ttest_true[i] and wilcoxon_true[i]]

In [25]:
# Motif-vs-scrambled follow-up for the kmers in diff_sig_cats.
# Restrict to rows whose 'catdesc' is one of the five listed category labels.
keep  = norm_perturb[norm_perturb['catdesc'].isin(['ES-Salient-TF','ES-Salient-Top',
                                      'ED-Salient-TF','ED-Salient-Top',
                                      'SLOT-CNN'])]

# Raw p-values keyed by kmer, one dict per test (these names are re-bound here).
wilcoxon_tests = {}
ttest_tests = {}
# all_cats is not used by the loop below; it is only recomputed here.
all_cats = list(set(keep['kmer']))
for cat in diff_sig_cats:
    keep_cat = keep[keep['kmer']==cat]
    # One row per background, columns indexed by (value, control); backgrounds
    # with any missing cell are dropped.
    keep_pd = keep_cat.pivot_table(index=['background'],columns=['control'],
                                   values=['foxa2.average','es.average']).dropna()
    try:
        # Paired motif-vs-scrambled tests, once on es.average and once on
        # foxa2.average; the smaller of the two p-values is stored.
        es_scram = wilcoxon(keep_pd['es.average'].motif,keep_pd['es.average'].scrambled)
        ot_scram = wilcoxon(keep_pd['foxa2.average'].motif,keep_pd['foxa2.average'].scrambled)
        wilcoxon_tests[cat] = min(es_scram[1],ot_scram[1])
        es_scram = ttest_rel(keep_pd['es.average'].motif,keep_pd['es.average'].scrambled)
        ot_scram = ttest_rel(keep_pd['foxa2.average'].motif,keep_pd['foxa2.average'].scrambled)
        ttest_tests[cat] = min(es_scram[1],ot_scram[1])
    except AttributeError:
        # Raised when the pivot has no 'motif' or 'scrambled' column to access;
        # such kmers are silently left out of the tested set.
        pass
# Only kmers that were actually tested; mask index i below refers to test_cats[i].
test_cats = list(wilcoxon_tests.keys())
wilcoxon_pvals = np.array([wilcoxon_tests[cat] for cat in test_cats])
ttest_pvals = np.array([ttest_tests[cat] for cat in test_cats])
# Benjamini-Hochberg FDR correction; only the boolean reject mask is kept.
wilcoxon_true,_,_,_ = multipletests(wilcoxon_pvals,method='fdr_bh')
ttest_true,_,_,_ = multipletests(ttest_pvals,method='fdr_bh')
# A kmer counts as significant only when both corrected tests reject.
diff_scram_sig_cats = [cat for i,cat in enumerate(test_cats) if ttest_true[i] and wilcoxon_true[i]]

In [26]:
# Counts of significant kmers out of all kmers considered.
print(f"{len(diff_sig_cats)} / {len(all_cats)}")
print(f"{len(diff_scram_sig_cats)} / {len(all_cats)}")

44 / 76
35 / 76


In [None]:
# Keep only the identifier/annotation columns and the value columns.
# 'catdesc' is retained because the following pd.melt call lists it in
# id_vars and raises a KeyError if the column is missing.
norm_perturb = norm_perturb[['seq','category','catdesc','gc.content','kmer','control','background']+value_cols]
norm_perturb.head()

In [None]:
# Reshape to long format: one row per (phrase, value column), with the column
# name in 'variable' and the measurement in 'value'.
normmelt = norm_perturb.melt(
    id_vars=['seq', 'catdesc', 'gc.content', 'control', 'kmer', 'background'],
    value_vars=value_cols,
)
normmelt.head()

In [None]:
def get_condition(var):
    """Return the first of 'foxa2', 'brachyury', 'es', 'ed' that occurs in ``var``.

    The candidates are tried in that order; ``None`` is returned when none of
    them is a substring of ``var``.
    """
    candidates = ('foxa2', 'brachyury', 'es', 'ed')
    return next((name for name in candidates if name in var), None)
# Label every melted row with the condition parsed from its source column name.
normmelt['condition'] = list(map(get_condition, normmelt['variable']))

In [None]:
# Distinct background labels present in the melted frame.
{background for background in normmelt['background']}

In [None]:
# Fit an ordinary-least-squares model of the melted values on the training
# backgrounds. OLS is imported in the notebook's top import cell.
#
# Design-matrix columns, left to right:
#   [ one-hot indicator per 'variable' (replicate column) |
#     kmer x condition indicators, grouped by condition (motif rows only) |
#     gc.content ]
#
# NOTE(review): `conditions`, `kmers` and `replicates` come from set(), so
# their order (and therefore the column order) is not guaranteed to be the
# same after a kernel restart.
conditions = list(set(normmelt['condition']))
train_backgrounds=['1308/-',
 '1308/scrambled-1',
 '1308/scrambled-2',
 '1308/scrambled-5',
 '1343/-',
 '1343/scrambled-1',
 '1343/scrambled-2',
 '1343/scrambled-5',
 '1383/-',
 '1383/scrambled-1',
 '1383/scrambled-2',
 '1383/scrambled-5',
 '1389/-',
 '1389/scrambled-1',
 '1389/scrambled-2',
 '1389/scrambled-5',
 '1470/-',
 '1470/scrambled-1',
 '1470/scrambled-2',
 '1470/scrambled-5']
normmelt_es_cond = normmelt[normmelt['background'].isin(train_backgrounds)]
endog = np.array(normmelt_es_cond['value'])
# Drop both missing-value sentinels: the object None and the string 'None'.
kmers = [k for k in list(set(normmelt_es_cond['kmer'])) if k is not None and k != 'None']

# Indicator block: column (j*len(kmers) + i) is 1 for motif rows carrying
# kmers[i] that were measured under conditions[j].
exog_kmer_cond = np.zeros((len(endog), len(kmers) * len(conditions)))
# The motif mask does not depend on the loop variables; compute it once.
is_motif = normmelt_es_cond['control'] == 'motif'
for j, condition in enumerate(conditions):
    # The condition mask depends only on the outer loop.
    in_condition = normmelt_es_cond['condition'] == condition
    for i, kmer in enumerate(kmers):
        contains_kmer = np.logical_and(normmelt_es_cond['kmer'] == kmer, is_motif)
        exog_kmer_cond[:, (j * len(kmers)) + i] = np.logical_and(contains_kmer, in_condition)

# One-hot block with one column per distinct 'variable' value.
replicates = list(set(normmelt_es_cond['variable']))
exog_rep = np.zeros((len(endog), len(replicates)))
for i, rep in enumerate(replicates):
    exog_rep[:, i] = (normmelt_es_cond['variable'] == rep)
exog_gc = np.array(normmelt_es_cond['gc.content']).reshape((-1, 1))

# Sanity checks: rows per replicate, and rows per kmer/condition indicator.
print(np.sum(exog_rep, axis=0))
plt.plot(np.sum(exog_kmer_cond, axis=0))
plt.show()
exog = np.concatenate([exog_rep, exog_kmer_cond, exog_gc], axis=1)

model = OLS(endog, exog)
res = model.fit()


In [None]:
# Held-out rows: every background that is neither a training background nor
# one of the two missing-value sentinels ('None' string / None object).
test_data = normmelt[
    ~normmelt['background'].isin(train_backgrounds + ['None', None])
]
test_endog = np.array(test_data['value'])

In [None]:
# Build the held-out design matrix with the same column layout as the training
# one: [replicate one-hots | kmer x condition indicators | gc.content].
n_test_kmers = len(kmers)
test_kmer_cond = np.zeros((len(test_endog), n_test_kmers * len(conditions)))
test_is_motif = test_data['control'] == 'motif'
for j, condition in enumerate(conditions):
    test_in_condition = test_data['condition'] == condition
    for i, kmer in enumerate(kmers):
        contains_kmer = np.logical_and(test_data['kmer'] == kmer, test_is_motif)
        test_kmer_cond[:, j * n_test_kmers + i] = np.logical_and(contains_kmer, test_in_condition)

# One-hot block with one column per entry of `replicates`.
test_rep = np.zeros((len(test_endog), len(replicates)))
for i, rep in enumerate(replicates):
    test_rep[:, i] = (test_data['variable'] == rep)
test_gc = np.array(test_data['gc.content']).reshape((-1, 1))

# Sanity checks: rows per replicate, and rows per kmer/condition indicator.
print(np.sum(test_rep, axis=0))
plt.plot(np.sum(test_kmer_cond, axis=0))
plt.show()
test_X = np.concatenate([test_rep, test_kmer_cond, test_gc], axis=1)


In [None]:
#SAVE EVERYTHING
normmelt_es_cond.to_pickle('regression_results/training_dataframe.pkl')
test_data.to_pickle('regression_results/testing_dataframe.pkl')
np.save('regression_results/trainY.npy',endog)
np.save('regression_results/trainX.npy',exog)
np.save('regression_results/testY.npy',test_endog,)
np.save('regression_results/testX.npy',test_X)
res.save("regression_results/ols_trained_model.pkl")
with open('regression_results/parameters.pkl','wb') as f:
    pickle.dump({'conditions':conditions,'replicates':replicates,'kmers':kmers},f)

In [None]:
# Visualise the fitted kmer x condition coefficients: one "top 10" heatmap per
# condition, then a clustermap of all kmers.
sns.set_context('notebook',font_scale=1.5)
# Use the model fitted in this notebook (`res`); no `results` list is defined.
param = np.asarray(res.params)
# Skip the leading replicate one-hot coefficients; what remains is the
# kmer x condition block followed by the single gc.content coefficient.
param_kmers = param[len(replicates):]
cmap = {'es':'blue','brachyury':'purple','foxa2':'aqua','ed':'red'}

n_kmers = len(kmers)
# big_mat[i, j] is the coefficient of kmers[i] under conditions[j].
big_mat = param_kmers[:n_kmers * len(conditions)].reshape(len(conditions), n_kmers).T

# `conditions` comes from a set, so locate the ES block by name rather than
# assuming a fixed position. ES itself is ranked against the first non-ES
# condition (the same choice as before whenever 'es' happens to sit at index 0).
es_idx = conditions.index('es')
other_idx = next((idx for idx, cond in enumerate(conditions) if cond != 'es'), es_idx)

for j,condition in enumerate(conditions):
    cond_params = big_mat[:, j]
    baseline = big_mat[:, other_idx if condition == 'es' else es_idx]
    # Rank kmers by how far this condition's coefficient exceeds the baseline's.
    allsorted = np.argsort(cond_params - baseline)[::-1]
    # Keep at most ten, restricted to kmers with a positive coefficient here.
    top = allsorted[cond_params[allsorted] > 0][:10]
    if len(top) == 0:
        # An empty frame cannot be drawn as a heatmap.
        print('No kmers with a positive coefficient for ' + condition)
        continue
    data = big_mat[top, :]
    data_df = pd.DataFrame(data=data,columns=conditions,index=[kmers[t] for t in top])

    sns.heatmap(data_df,cmap='RdBu')
    plt.title('Top 10 '+condition)
    plt.show()

big_df = pd.DataFrame(data=big_mat,columns=conditions,index=kmers)
sns.clustermap(big_df,cmap='RdBu',figsize=(20,40))