In [None]:
# Move the kernel's working directory up one level; the cells below use paths relative
# to it (e.g. './VAE_training_hparams/...') — presumably the project root, TODO confirm.
# NOTE(review): this is not idempotent — re-running the cell climbs one more level each time.
cd ..

In [None]:
import pandas as pd
import numpy as np
import os
import yaml
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from ipywidgets import interact
import torch

from seqcat_datamodule import Dataset
from seqcat_catvae import seqcat_vae
from utils import preprocess_data, standardize_data


In [None]:
# ---------------------------------------------------------------------------
# Shared configuration for plot_seqcat / calc_purity.
# ---------------------------------------------------------------------------
_MODEL_ROOT = './VAE_training_hparams/BeRfiPl/rnn_catvae'
_LABEL_DATASET = 'BeRfiPl_labels'
_SCALER_FILE = 'scaler_BeRfiPl.pkl'

# Rows [start, stop) of the preprocessed data that plot_seqcat visualises.
_PLOT_SLICE = (2000, 3000)

# The ground-truth labels are trimmed by these amounts so that they line up
# one-to-one with the windows produced by Dataset.
# NOTE(review): 3 + 7 == 10 looks tied to NUMBER_TIMESTEPS == 10 — TODO confirm
# and derive from the hyper-parameter instead of hard-coding.
_LABEL_TRIM_HEAD = 3
_LABEL_TRIM_TAIL = 7

# Every _WINDOW_STRIDE-th window is kept when flattening the per-window means
# back into one continuous series (presumably == window length — TODO confirm).
_WINDOW_STRIDE = 10

# Horizontal shift applied to the predicted-category trace in the plot.
_PLOT_CLUSTER_X_SHIFT = 10

# Maximum number of input channels drawn in the first subplot.
_N_PLOTTED_CHANNELS = 3


def _load_model(idx):
    """Load the checkpointed model of ``version_{idx}`` ready for inference.

    Returns:
        (model, hparams, device): the model moved to ``device`` and switched to
        eval mode, the ``"hparams"`` mapping read from ``hparams.yaml``, and
        the torch device the input windows must be moved to.

    Raises:
        FileNotFoundError: if the version's ``checkpoints`` directory is empty.
    """
    version_dir = os.path.join(_MODEL_ROOT, f'version_{idx}')
    ckpt_dir = os.path.join(version_dir, 'checkpoints')

    # os.listdir returns entries in arbitrary order, so sort before taking the
    # last one; otherwise the chosen checkpoint can differ between runs/hosts.
    # Note: the order is lexicographic (e.g. 'epoch=9' sorts after 'epoch=10').
    ckpt_names = sorted(os.listdir(ckpt_dir))
    if not ckpt_names:
        raise FileNotFoundError(f'no checkpoint found in {ckpt_dir}')
    ckpt_path = os.path.join(ckpt_dir, ckpt_names[-1])

    with open(os.path.join(version_dir, 'hparams.yaml')) as f:
        hparams = yaml.safe_load(f)["hparams"]

    model = seqcat_vae.load_from_checkpoint(ckpt_path, hparams=hparams)

    # Use the GPU when there is one, but keep the notebook runnable on
    # CPU-only machines; model and inputs are always placed on the same device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    # Inference only: disable train-time behaviour such as dropout.
    model.eval()
    return model, hparams, device


def _load_labelled_slice(start, stop):
    """Return ``(df_sc, labels)`` for rows ``[start, stop)`` of the labelled data.

    ``df_sc`` is the standardised input frame. ``labels`` is a Series (index
    reset to 0..n-1) holding one integer class id per row, obtained by reading
    each row of ``states`` as a binary number and re-indexing the distinct
    values densely over the *whole* data set before slicing.
    """
    _, data, states, _, _ = preprocess_data(_LABEL_DATASET)

    # Most-significant bit first: [2^(k-1), ..., 2, 1].
    powers_of_two = 2 ** torch.arange(states.size(1) - 1, -1, -1).float()
    state_codes = torch.matmul(states, powers_of_two)
    _, dense_ids = torch.unique(state_codes, return_inverse=True)

    labels = pd.Series(dense_ids.cpu().numpy()).iloc[start:stop].reset_index(drop=True)
    df_sc = standardize_data(pd.DataFrame(data[start:stop]), _SCALER_FILE)
    return df_sc, labels


def _encode_windows(model, dataset, device, collect_extras=False):
    """Run ``model`` over every window of ``dataset``.

    Returns:
        (cluster_ids, likelihoods, means)
        * cluster_ids: 1-D integer array, one predicted category per window.
          ``z`` is cast to int and the position of its largest entry is used
          (assumes ``z`` is a hard one-hot sample — TODO confirm).
        * likelihoods: list with the mean of ``model.function_likelihood`` per
          window; empty unless ``collect_extras`` is true.
        * means: list with the ``mu`` array per window; empty unless
          ``collect_extras`` is true.
    """
    one_hot_rows, likelihoods, means = [], [], []
    with torch.no_grad():
        for window in dataset:
            batch = window.unsqueeze(0).to(device)
            _, _, mu, _, _, z = model.get_states(batch)
            one_hot_rows.append(z.detach().cpu().numpy().astype(int))
            if collect_extras:
                likelihoods.append(model.function_likelihood(x=batch).mean().item())
                means.append(mu.detach().cpu().numpy())

    cluster_ids = np.vstack(one_hot_rows).argmax(axis=1)
    return cluster_ids, likelihoods, means


def _cluster_purity(cluster_ids, class_ids):
    """Return the purity of a clustering with respect to reference classes.

    Purity is the fraction of samples that belong to the majority class of
    their own cluster. The two sequences are paired by position.

    Raises:
        ValueError: if there are no samples.
    """
    # Plain arrays: pd.crosstab would otherwise align two Series on their
    # index labels instead of pairing them positionally.
    cluster_ids = np.asarray(cluster_ids)
    class_ids = np.asarray(class_ids)
    assert len(cluster_ids) == len(class_ids), (
        f'{len(cluster_ids)} cluster assignments vs {len(class_ids)} class labels')
    if len(cluster_ids) == 0:
        raise ValueError('cannot compute purity without samples')

    counts = pd.crosstab(cluster_ids, class_ids)
    return float(counts.max(axis=1).sum() / len(cluster_ids))


def plot_seqcat(idx):
    """Plot inputs, reconstructions and categories for model ``version_{idx}``.

    Evaluates the model on the rows given by ``_PLOT_SLICE``, prints the
    cluster purity followed by the number of distinct predicted categories,
    and returns a four-row plotly figure:

    1. standardised input channels and the model's reconstructed means,
    2. predicted category per window,
    3. ground-truth category per row,
    4. mean likelihood per window.

    Args:
        idx: version number of the trained model to load.

    Returns:
        plotly.graph_objects.Figure
    """
    model, hparams, device = _load_model(idx)
    df_sc, true_labels = _load_labelled_slice(*_PLOT_SLICE)
    dataset = Dataset(dataframe=df_sc, number_timesteps=hparams["NUMBER_TIMESTEPS"])

    cluster_ids, likelihoods, means = _encode_windows(
        model, dataset, device, collect_extras=True)

    class_ids = true_labels.iloc[_LABEL_TRIM_HEAD:-_LABEL_TRIM_TAIL]
    print(_cluster_purity(cluster_ids, class_ids))
    print(len(np.unique(cluster_ids)))

    # Stitch every _WINDOW_STRIDE-th window's means into one (time, channel) table.
    strided_means = np.vstack(means)[::_WINDOW_STRIDE]
    recon = pd.DataFrame(strided_means.reshape(-1, strided_means.shape[2]))

    fig = make_subplots(
        rows=4, cols=1, shared_xaxes=True,
        subplot_titles=["Input data", "VAE categories", "Real categories", "Likelihood"])

    # Clamp so that data with fewer channels does not raise an IndexError.
    n_channels = min(_N_PLOTTED_CHANNELS, df_sc.shape[1], recon.shape[1])
    for i in range(n_channels):
        data_col = df_sc.columns[i]
        recon_col = recon.columns[i]
        fig.add_trace(
            go.Scatter(x=df_sc.index, y=df_sc[data_col], name=str(data_col), mode='markers'),
            row=1, col=1)
        fig.add_trace(
            go.Scatter(x=df_sc.index, y=recon[recon_col], name=f'mu {recon_col}', mode='markers'),
            row=1, col=1)

    fig.add_trace(
        go.Scatter(x=df_sc.index, y=likelihoods, mode='lines', name='likelihood'),
        row=4, col=1)
    fig.add_trace(
        go.Scatter(x=df_sc.index + _PLOT_CLUSTER_X_SHIFT, y=cluster_ids,
                   mode='lines', name='vae cats'),
        row=2, col=1)
    fig.add_trace(
        go.Scatter(x=true_labels.index, y=true_labels, mode='markers', name='real_cat'),
        row=3, col=1)

    axis_style = dict(
        mirror=True,
        ticks='outside',
        showline=True,
        linecolor='black',
        gridcolor='lightgrey',
        zerolinecolor='grey',
    )
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)
    fig.update_layout(plot_bgcolor='white')

    return fig


def calc_purity(versions=range(0, 10), start=3000, stop=5000):
    """Mean and standard deviation of the cluster purity over model versions.

    Each ``version_{idx}`` model is evaluated on rows ``[start, stop)`` of the
    labelled data and its predicted categories are scored against the
    ground-truth classes.

    Args:
        versions: iterable of version numbers to evaluate (default 0..9).
        start, stop: row range of the preprocessed data to evaluate on.

    Returns:
        dict with keys ``'mean '`` and ``'std '`` (trailing spaces kept for
        compatibility); ``'std '`` is the population standard deviation.

    Raises:
        ValueError: if ``versions`` is empty.
    """
    # The data does not depend on the model version, so load it only once.
    df_sc, true_labels = _load_labelled_slice(start, stop)
    class_ids = true_labels.iloc[_LABEL_TRIM_HEAD:-_LABEL_TRIM_TAIL]

    purities = []
    for idx in versions:
        model, hparams, device = _load_model(idx)
        dataset = Dataset(dataframe=df_sc, number_timesteps=hparams["NUMBER_TIMESTEPS"])
        cluster_ids, _, _ = _encode_windows(model, dataset, device)
        purities.append(_cluster_purity(cluster_ids, class_ids))

    if not purities:
        raise ValueError('no model versions to evaluate')

    return {'mean ': float(np.mean(purities)), 'std ': float(np.std(purities))}


    


In [None]:
# Interactive viewer: ipywidgets builds a control for `idx` from range(10) and calls
# plot_seqcat with the chosen model version, displaying the returned figure.
interact(plot_seqcat, idx = range(10))


In [None]:
# Summarise cluster purity across the trained model versions; the returned dict
# (mean and standard deviation) is shown as this cell's output.
calc_purity()