In [1]:
import os

# Project root. Override with the MIRNA_ROOT environment variable so the notebook
# can run on machines other than the original author's; the default is the
# previously hard-coded location.
folder3 = os.environ.get("MIRNA_ROOT", "D:/users/marko/downloads/mirna/")
# Sub-folders that are put on sys.path for the project-local imports.
folder = os.path.join(folder3, "analysis")
folder2 = os.path.join(folder3, "models")

In [2]:
# W&B run id of the model under analysis (also names its latent-space cache folder).
model_code = '38tsal5j'

# library imports

In [39]:
import numpy as np
import pandas as pd
import sys 
import os 
import io
import wandb


# Make the project-local modules importable (folder2 ends up first on sys.path,
# exactly as with two successive insert(0, ...) calls), work from the analysis
# folder, and ensure the latent-space cache directory for this run exists.
for _module_dir in (folder, folder2):
    sys.path.insert(0, _module_dir)
os.chdir(folder)
os.makedirs(f"{folder}/latspaces/{model_code}/", exist_ok=True)


from concepts import *
from utils import *
from model import *
from dataset import *

from tqdm import tqdm

import plotly.graph_objs as go
from ipywidgets import HTML, Image, Layout, interactive, \
                       RadioButtons, HBox, VBox, interact
import PIL.Image
from graphviz import Digraph, Source

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

from sklearn.svm import LinearSVC, SVC
from sklearn.linear_model import LogisticRegression


In [4]:
# Connect to W&B and fetch the run whose config describes the model.
# The API key is read from the WANDB_API_KEY environment variable instead of being
# committed in the notebook (the previously hard-coded key should be revoked).
# When the variable is unset, api_key=None lets wandb fall back to its own
# credential lookup (e.g. the netrc written by `wandb login`).
api = wandb.Api(api_key=os.environ.get("WANDB_API_KEY"))
run = api.run(f"generativemirna/MIRGEN/{model_code}")

In [5]:

class Args:
    """Attribute-style wrapper around a configuration mapping.

    Each key of the mapping given to the constructor becomes an attribute of
    the instance, holding the mapping's value for that key.
    """

    def __init__(self, args):
        for name, value in ((key, args[key]) for key in args):
            setattr(self, name, value)

In [6]:
# Expose the run config as attributes, then post-process it with process_args.
args = process_args(Args(run.config))

In [7]:
# Instantiate the architecture recorded in the run config and move it to the GPU.
# NOTE(review): no checkpoint loading is visible in this notebook — confirm the
# constructors restore trained weights if `model` is used for inference.
if args.model_type == 'vae':
    model = MIRVAE(args).to('cuda')
elif args.model_type == 'diva':
    model = MIRDIVA(args).to('cuda')
else:
    # Fail loudly: otherwise `model` is left undefined (or stale from an earlier
    # execution of this cell) and the error surfaces far from its cause.
    raise ValueError(
        f"Unknown model_type {args.model_type!r}; expected 'vae' or 'diva'"
    )

In [8]:
# Build the train and test loaders (in that order) from the project root.
train_loader, test_loader = (
    get_data_loader(folder3, split, analysis=True) for split in ('train', 'test')
)

Loading Labels! (~10s)
Loading Names! (~5s)
Loading Labels! (~10s)
Loading Names! (~5s)


In [9]:
def clean_concepts(concepts_raw):
    """Return a cleaned copy of a concept-annotation table.

    - Missing ``loop_length`` / ``loop_width`` values become 0.
    - The strings ``'upper'`` / ``'lower'`` are encoded as 1 / -1.
    - Every other missing value becomes -1.

    The input frame is not modified. This replaces the former chained
    ``df[col].fillna(..., inplace=True)`` calls, which stop propagating to the
    parent frame under pandas Copy-on-Write. It also makes the cell idempotent.
    """
    return (
        concepts_raw
        .fillna({'loop_length': 0, 'loop_width': 0})
        .replace(['upper', 'lower'], [1, -1])
        .fillna(-1)
    )


concepts_train = clean_concepts(pd.read_csv(f'{folder3}/data/concepts_tr.csv'))
concepts_test = clean_concepts(pd.read_csv(f'{folder3}/data/concepts_te.csv'))

# The CSVs act as a cache. To regenerate them:
# concepts_train = create_annotated_df(train_loader.dataset.images*255, train_loader.dataset.labels.argmax(1))
# concepts_test = create_annotated_df(test_loader.dataset.images*255, test_loader.dataset.labels.argmax(1))
# concepts_train.to_csv(f'{folder3}/data/concepts_tr.csv')
# concepts_test.to_csv(f'{folder3}/data/concepts_te.csv')
# NOTE(review): to_csv writes the index by default, so frames read back presumably
# carry an extra leading column. The column slicing in later cells appears to
# account for it — check this before using create_annotated_df output directly.

In [10]:
# Display the training loader's batch size as a quick sanity check.
train_loader.batch_size

128

In [11]:
def load_latents(path):
    """Load cached latent codes and outputs for one data split.

    Parameters
    ----------
    path : str
        ``.npz`` archive containing arrays ``zx``, ``zy``, ``zm`` (presumably the
        latent codes of the x / y / m subspaces) and ``x`` (stored as x_hat).

    Returns
    -------
    (dict, np.ndarray)
        ``({'x': zx, 'y': zy, 'm': zm}, x)``. This is the same ``(z, x_hat)`` pair
        that ``model_analysis`` returns in the regeneration code below.
    """
    # The `with` block closes the archive's file handle. Indexing an NpzFile reads
    # the array fully into memory, so the returned arrays remain valid afterwards.
    with np.load(path) as archive:
        z = {key: archive[f'z{key}'] for key in ('x', 'y', 'm')}
        x_hat = archive['x']
    return z, x_hat


z_tr, x_hat_tr = load_latents(f'{folder}/latspaces/{model_code}/train.npz')
z_te, x_hat_te = load_latents(f'{folder}/latspaces/{model_code}/test.npz')

# To regenerate the cache from `model`:
# z_tr, x_hat_tr = model_analysis(model, args, train_loader)
# z_te, x_hat_te = model_analysis(model, args, test_loader)
# np.savez_compressed(f'{folder}/latspaces/{model_code}/train.npz',
#                     zx=z_tr['x'], zy=z_tr['y'], zm=z_tr['m'], x=x_hat_tr)
# np.savez_compressed(f'{folder}/latspaces/{model_code}/test.npz',
#                     zx=z_te['x'], zy=z_te['y'], zm=z_te['m'], x=x_hat_te)

In [12]:
# t-SNE is stochastic. Fix the seed so the 2-D embedding, and the interactive
# figure built from it, are reproducible across Restart & Run All.
TSNE_SEED = 0
latent_space_m = TSNE(random_state=TSNE_SEED).fit_transform(z_te['m'])

In [13]:
# Interactive scatter of the 2-D t-SNE embedding of the test m-latents, initially
# coloured by the argmax of each test label vector.
marker_classes = np.argmax(test_loader.dataset.labels, 1)
embedding_trace = go.Scattergl(
    x=latent_space_m[:, 0],
    y=latent_space_m[:, 1],
    mode='markers',
    marker=dict(color=marker_classes, size=4),
)
fig = go.FigureWidget(data=[embedding_trace], layout=dict(height=500, width=750))

In [14]:
def set_color(column):
    """Recolour the embedding scatter using ``concepts_test[column]``.

    Multiplying by 1 turns boolean concept flags into 0/1 integers. Numeric
    columns are unchanged.
    """
    values = concepts_test[column] * 1
    fig.data[0].marker.color = values

In [15]:
# Offer each concept column from position 2 onward as a colouring choice.
# The first two columns are skipped; presumably these are index/identifier columns.
concept_options = concepts_test.columns[2:]
radio = RadioButtons(options=concept_options)
radiowidget = interact(set_color, column=radio)

interactive(children=(RadioButtons(description='column', options=('class_label', 'presence_terminal_loop', 'st…

In [21]:
def compress_to_bytes(data, fmt, scale=10):
    """Encode an image array as bytes in the given format via PIL/Pillow.

    Parameters
    ----------
    data : np.ndarray
        Image whose values are expected in [0, 1]; they are scaled by 255 for
        the uint8 conversion. Axes 0 and 1 are upscaled.
    fmt : str
        Pillow format name, e.g. ``'png'``.
    scale : int, optional
        Pixel-replication factor applied to axes 0 and 1 so that small images
        are visible in the widget. The default of 10 matches the previously
        hard-coded value.

    Returns
    -------
    bytes
        The encoded image.
    """
    upscaled = np.asarray(data).repeat(scale, axis=0).repeat(scale, axis=1)
    # Clip before the uint8 cast. Out-of-range floats would otherwise wrap around
    # (or be platform-dependent) instead of saturating at black/white.
    pixels = np.uint8(np.clip(upscaled, 0.0, 1.0) * 255)
    buff = io.BytesIO()
    PIL.Image.fromarray(pixels).save(buff, format=fmt)
    return buff.getvalue()

def _details_html(ind):
    """Render the name and class index of test sample ``ind`` as an HTML table."""
    dataset = test_loader.dataset
    return pd.DataFrame({
        'name': [dataset.names[ind]],
        # Argmax along the first axis of this sample's label array. This is
        # identical to labels.argmax(1)[ind] but avoids scanning every label on
        # each hover event.
        'class': [dataset.labels[ind].argmax(0)],
    }).to_html(col_space={'name': 250, 'class': 50})


def hover_fn(trace, points, state):
    """Plotly hover callback: show the hovered test sample's image and details.

    ``trace`` and ``state`` are unused but required by Plotly's callback signature.
    """
    if not points.point_inds:
        # Nothing under the cursor; keep the current image and details.
        return
    ind = points.point_inds[0]
    image_widget.value = compress_to_bytes(test_loader.dataset.images[ind], 'png')
    details.value = _details_html(ind)


# Details panel, initialised with the first test sample.
details = HTML()
details.value = _details_html(0)

In [22]:
# Attach hover_fn to the scatter trace so hovering a point updates the panels.
hover_trace = fig.data[0]
hover_trace.on_hover(hover_fn)

In [23]:
# Image panel, initialised with the first test sample before any hover event.
fmt = 'png'
im = compress_to_bytes(test_loader.dataset.images[0], fmt)
image_layout = Layout(width='1000px', height='250px')
image_widget = Image(value=im, layout=image_layout)

In [24]:
# Assemble the dashboard: scatter + colour selector, then the image, then details.
top_row = HBox([fig, radio])
image_row = HBox([image_widget])
VBox([top_row, image_row, details])

VBox(children=(HBox(children=(FigureWidget({
    'data': [{'marker': {'color': array([0, 0, 1, ..., 1, 0, 0], …

In [25]:
from decisiontree import Tree

In [26]:

# Candidate split thresholds per concept column, passed to Tree(...) as its first argument.
thresholds = dict(
    start_loop_upperhalf_col=[3, 7, 15, 30, 45],
    highest_point_loop_upperhalf_col=[3, 7, 9, 15, 30, 45],
    gap_start=[2, 8, 20, 40],
    palindrome_score=[0.25, 0.5, 0.6, 0.7, 0.8, 0.9],
    large_asymmetric_bulge=[2, 4, 8, 10, 20, 40, 60],
    largest_asym_bulge_sequence_location=[5, 20, 40, 60, 80],
    stem_begin=[15, 35, 45, 55, 60, 75],
    stem_end=[3, 5, 8, 12, 20, 30, 40],
    stem_length=[10, 20, 30, 35, 40, 45, 50, 55, 60, 70, 80, 90],
    total_length=[20, 30, 40, 50, 60, 70, 80, 90],
    base_pairs_in_stem=[0.1, 0.3, 0.5, 0.7, 0.9],
    base_pairs_wobbles_in_stem=[0.1, 0.3, 0.5, 0.7, 0.9],
    loop_width=[2, 8, 20, 45],
)

In [41]:
# Concept-threshold tree of depth at most 5. Presumably cls/cls_args configure the
# classifier fitted inside the tree (here a degree-2 polynomial-kernel SVC).
svc_params = dict(degree=2, kernel='poly')
dtc = Tree(thresholds, cls=SVC, cls_args=svc_params, max_depth=5)

In [42]:
# Fit the tree on the training m-latents and class indices. Concept columns from
# position 3 onward are passed as the concept table.
# NOTE: this cell was interrupted (KeyboardInterrupt) in the saved run; it is slow.
class_ids_tr = train_loader.dataset.labels.argmax(1).flatten()
concept_features_tr = concepts_train[concepts_train.columns[3:]]
dtc.fit(z_tr['m'], class_ids_tr, concept_features_tr, prune=True)

KeyboardInterrupt: 

In [None]:
# Score of the fitted tree on the training split.
class_ids_tr = train_loader.dataset.labels.argmax(axis=1).flatten()
dtc.score(z_tr['m'], class_ids_tr)

In [None]:
# Score of the fitted tree on the held-out test split.
class_ids_te = test_loader.dataset.labels.argmax(axis=1).flatten()
dtc.score(z_te['m'], class_ids_te)

In [None]:
# Draw the fitted tree and render it to disk as PNG under the name 'dtc_new'.
g = dtc.plot_tree()
output_name = 'dtc_new'
g.render(output_name, format='png')