In [None]:
%matplotlib inline
#%matplotlib notebook

import pickle
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from collections import defaultdict
from scipy import spatial
from Bio import pairwise2
from tqdm import tqdm_notebook
from glob import glob

# Put the parent directory (relative to the current working directory) on the
# import path so the `utils` module imported below can be resolved.
# The membership test keeps re-running this cell from appending duplicates.
if '../' not in sys.path:
    sys.path.append('../')

# Project-local helpers used further down:
#   number2patten  - called as number2patten(kmer_number, kmer_size); presumably
#                    returns the k-mer sequence string for that number (TODO confirm).
#   slice_out_kmer - called as slice_out_kmer(embeddings, k, min_k, max_k); presumably
#                    returns the embedding rows belonging to k-mers of size k (TODO confirm).
from utils import number2patten
from utils import slice_out_kmer

# Use seaborn's 'ticks' style for every figure in this notebook.
sns.set(style='ticks')

### Load Embeddings

In [None]:
# k-mer sizes covered by the embedding file loaded below.
kmer_sizes = np.array([3, 4, 5])

glob_str = '../max5_min3_mers_10padding_64embedding_epoch1_batch*.pickle'

# Map batch number -> checkpoint path.
# os.path.splitext removes the extension as a suffix; str.rstrip('.pickle')
# strips a *set of characters* and only worked because the stem ends in a digit.
emb_files = {int(os.path.splitext(f)[0].split('batch')[1]): f for f in glob(glob_str)}
if not emb_files:
    # Fail with an actionable message instead of an IndexError on an empty list.
    raise FileNotFoundError('No embedding files match {!r}'.format(glob_str))

# The checkpoint with the highest batch number is the most recent one.
latest_emb_file = emb_files[max(emb_files)]
print(latest_emb_file)

# NOTE: pickle.load can execute arbitrary code -- only load checkpoints you trust.
# The context manager closes the file handle once the embeddings are read.
with open(latest_emb_file, 'rb') as emb_handle:
    kmer_emb = pickle.load(emb_handle)
print('All:', kmer_emb.shape)

### Computing Needleman-Wunsch Score and Cosine Similarity

In [None]:
def get_comb(kmer_size):
    """Lazily enumerate every unordered pair of distinct k-mer numbers.

    K-mers of length ``kmer_size`` are numbered ``0 .. 4**kmer_size - 1``.
    Pairs are yielded as ``(low, high)`` with ``low < high``, ordered by
    ``low`` first and then by ``high``.
    """
    total = 4 ** kmer_size
    for low in range(total):
        # Starting at low + 1 skips self-pairs and mirrored duplicates.
        yield from ((low, high) for high in range(low + 1, total))


def compute_scores(kmer_embeddings, kmer_size,
                   new_global_align=False, new_cosine_similarity=False):
    """Pairwise alignment score and embedding cosine similarity for all k-mer pairs.

    One value of each metric is produced for every unordered pair of distinct
    k-mer numbers of length ``kmer_size``, in the order yielded by ``get_comb``.
    Each metric is cached in the current working directory as
    ``global_align_scores_<k>-mers.npy`` / ``cosine_similarity_<k>-mers.npy``.
    A metric is loaded from its cache when the file exists and recomputation
    was not requested; otherwise it is computed and the cache is (re)written.

    Parameters
    ----------
    kmer_embeddings : indexable
        Indexed by k-mer number; each item is passed to
        ``scipy.spatial.distance.cosine``. Only read when the cosine
        similarity is (re)computed.
    kmer_size : int
        Length k of the k-mers to compare.
    new_global_align : bool, default False
        Recompute the alignment scores even if a cache file exists.
    new_cosine_similarity : bool, default False
        Recompute the cosine similarities even if a cache file exists.
        Note the cache is keyed on ``kmer_size`` only, so pass ``True``
        whenever the embeddings have changed.

    Returns
    -------
    pandas.DataFrame
        Columns ``global_align_scores`` and ``cosine_similarity``, one row
        per k-mer pair.

    Raises
    ------
    ValueError
        If a cache file that is about to be reused does not hold exactly one
        value per k-mer pair (e.g. a stale or truncated file).
    """
    np_files = {
        'global_align': 'global_align_scores_{}-mers'.format(kmer_size),
        'cosine_similarity': 'cosine_similarity_{}-mers'.format(kmer_size)
    }

    # Number of unordered pairs of distinct k-mers: (n**2 - n) / 2 with n = 4**k.
    combinations = (((4**kmer_size)**2) - 4**kmer_size) // 2

    # Decide once, per metric, whether it has to be computed in this call.
    compute_align = new_global_align or not os.path.isfile(np_files['global_align'] + '.npy')
    compute_cosine = (new_cosine_similarity
                      or not os.path.isfile(np_files['cosine_similarity'] + '.npy'))

    if compute_align:
        # int8 is the storage dtype used for the alignment scores.
        global_align_scores = np.zeros(combinations, dtype=np.int8)
    else:
        global_align_scores = np.load(np_files['global_align'] + '.npy')

    if compute_cosine:
        # Sized from the expected pair count, not from a possibly stale cached array.
        cosine_similarity = np.zeros(combinations)
    else:
        cosine_similarity = np.load(np_files['cosine_similarity'] + '.npy')

    # A reused cache must line up one-to-one with the k-mer pairs; otherwise the
    # returned columns would be misaligned or of different lengths.
    for key, cached, recomputed in (('global_align', global_align_scores, compute_align),
                                    ('cosine_similarity', cosine_similarity, compute_cosine)):
        if not recomputed and cached.shape != (combinations,):
            raise ValueError(
                "Cached file '{}.npy' has shape {} but {} values are expected for "
                "{}-mers; delete it or request recomputation.".format(
                    np_files[key], cached.shape, combinations, kmer_size))

    if compute_align or compute_cosine:
        for i, (num_seq1, num_seq2) in tqdm_notebook(enumerate(get_comb(kmer_size)),
                                                     total=combinations,
                                                     desc='{}-mer'.format(kmer_size)):

            if compute_align:
                # The sequences are only needed for the alignment, so build them
                # only on this path.
                seq1 = number2patten(num_seq1, kmer_size)
                seq2 = number2patten(num_seq2, kmer_size)
                # score_only=True returns the best global alignment score directly
                # instead of materialising every optimal alignment.
                global_align_scores[i] = pairwise2.align.globalxx(seq1, seq2,
                                                                  score_only=True)

            if compute_cosine:
                # Cosine similarity = 1 - cosine distance.
                cosine_similarity[i] = 1 - spatial.distance.cosine(kmer_embeddings[num_seq1],
                                                                   kmer_embeddings[num_seq2])

        # Persist whatever was freshly computed (np.save appends '.npy').
        if compute_align:
            np.save(np_files['global_align'], global_align_scores)
        if compute_cosine:
            np.save(np_files['cosine_similarity'], cosine_similarity)

    return pd.DataFrame(data={'global_align_scores': global_align_scores,
                              'cosine_similarity': cosine_similarity})


# Look at all kmer sizes: score each k-mer size separately, tag the rows with
# their size, and stack everything into one long-form frame for plotting.
kmer_frames = []
for kmer_size in kmer_sizes:

    emb = slice_out_kmer(kmer_emb, kmer_size, min(kmer_sizes), max(kmer_sizes))
    # Alignment scores are reused from the cache when present; cosine
    # similarities are always recomputed from the embeddings just loaded.
    kmer_df = compute_scores(emb, kmer_size, new_global_align=False, new_cosine_similarity=True)
    kmer_df['kmer_size'] = kmer_size
    kmer_frames.append(kmer_df)

# DataFrame.append was removed in pandas 2.0 and re-copies the frame on every
# call; a single concat over the collected frames is the supported equivalent.
all_kmer_df = pd.concat(kmer_frames, ignore_index=True)

In [None]:
from collections import Counter

# Sanity check: number of scored k-mer pairs per k-mer size.
kmer_size_counts = Counter(all_kmer_df['kmer_size'])
kmer_size_counts

### Plotting Needleman-Wunsch Score vs. Emb. Cosine Similarity

In [None]:
#sns.violinplot(x='global_align_scores', y='cosine_similarity', hue='kmer_size', data=all_kmer_df)

# FacetGrid.map draws each facet from its own subset of the data. Without an
# explicit `order`, a categorical plot infers the category positions per facet,
# so the same alignment score could land at different x positions in different
# panels. Fix one shared order across all k-mer sizes.
score_order = sorted(all_kmer_df['global_align_scores'].unique())

# `height` is the current name of FacetGrid's former `size` keyword.
g = sns.FacetGrid(all_kmer_df, col='kmer_size', height=4, aspect=1)
g = g.map(sns.violinplot, 'global_align_scores', 'cosine_similarity', order=score_order)
g.set_axis_labels('Global alignment score', 'Embedding cosine similarity')

plt.show()