# Evaluation of CatVAE discretization performance on the preprocessed tank dataset
Visual evaluation of how precisely the CatVAE discretizes the data and how meaningful the resulting categories are. <br>


In [None]:
# Change to the parent directory (IPython `cd` magic). The data and checkpoint
# paths used below are relative -- presumably to that directory; confirm.
cd ..

In [None]:
import os
import pandas as pd
import torch
import plotly.graph_objects as go
import yaml
from plotly.subplots import make_subplots
from ipywidgets import interact
import numpy as np

from utils import standardize_data
from datamodule import Dataset
from catvae import CategoricalVAE

# Fix the NumPy and PyTorch random seeds so repeated runs draw the same
# random numbers.
np.random.seed(123)
torch.manual_seed(123)

In [None]:
def _select_device():
    """Return 'cuda' when a GPU is usable, otherwise fall back to 'cpu'."""
    return 'cuda' if torch.cuda.is_available() else 'cpu'


def _load_model(idx, device):
    """Load the trained CatVAE of training run ``version_{idx}`` onto ``device``.

    Raises:
        FileNotFoundError: if the run's checkpoint directory is empty.
    """
    version_dir = f'./VAE_training_hparams/tank/catvae/version_{idx}'
    ckpt_dir = f'{version_dir}/checkpoints'
    # os.listdir returns entries in arbitrary order; sort so the same
    # checkpoint (the lexicographically last one) is picked on every run.
    ckpt_names = sorted(os.listdir(ckpt_dir))
    if not ckpt_names:
        raise FileNotFoundError(f'no checkpoint found in {ckpt_dir}')
    ckpt_file_path = f'{ckpt_dir}/{ckpt_names[-1]}'
    with open(f'{version_dir}/hparams.yaml') as f:
        hparam = yaml.safe_load(f)
    model = CategoricalVAE.load_from_checkpoint(ckpt_file_path, hparams=hparam["hparams"])
    return model.to(device)


def _load_tank_data(start_row):
    """Read the tank CSV once and split it into model inputs and labels.

    Returns:
        ``(features, labels)``: the first three columns and the fourth column
        of the file, both restricted to the rows from ``start_row`` onwards
        and still carrying the original row index.
    """
    df_csv = pd.read_csv('preprocessed_data/tank_simulation/norm_long.csv')
    return df_csv.iloc[start_row:, :3], df_csv.iloc[start_row:, 3]


def _discretize(model, batch):
    """Return, for every row of ``batch``, the index of its most likely category."""
    # get_states returns several tensors; only the category logits (the first
    # one) are needed, and their row-wise argmax is the discretized category.
    pzx_logits, *_ = model.get_states(batch)
    return torch.argmax(pzx_logits, dim=1).detach().cpu().numpy()


def cluster_purity(cluster_assignments, class_assignments):
    """Compute the purity of a clustering with respect to reference classes.

    Each cluster is credited with the size of the class it overlaps most;
    purity is the sum of these overlaps divided by the number of samples.

    Args:
        cluster_assignments: per-sample cluster label (any hashable values).
        class_assignments: per-sample reference class label, same length.

    Returns:
        Purity as a float in (0, 1].

    Raises:
        ValueError: if the inputs differ in length or are empty.
    """
    clusters = list(cluster_assignments)
    classes = list(class_assignments)
    if len(clusters) != len(classes):
        raise ValueError(
            f'got {len(clusters)} cluster labels but {len(classes)} class labels')
    if not clusters:
        raise ValueError('purity is undefined for an empty sample set')

    # Count how often every (cluster, class) combination occurs ...
    pair_counts = {}
    for pair in zip(clusters, classes):
        pair_counts[pair] = pair_counts.get(pair, 0) + 1
    # ... and keep, per cluster, the count of its dominant class.
    dominant = {}
    for (cluster_, _), count in pair_counts.items():
        dominant[cluster_] = max(dominant.get(cluster_, 0), count)
    return sum(dominant.values()) / len(clusters)


def plot_like(idx):
    """Plot signals, categories and likelihood for trained model ``version_{idx}``.

    Prints the purity of the discretized categories against the real
    categories and returns a plotly figure with four stacked rows sharing the
    x axis: (1) the standardized input signals, (2) the discretized category,
    (3) the real category, (4) the likelihood smoothed with a rolling median
    over 10 samples.

    Args:
        idx: number of the training run to load.

    Returns:
        The plotly figure.
    """
    device = _select_device()
    model = _load_model(idx, device)
    # read normal data
    # NOTE(review): this function evaluates rows 2000+, calc_purity rows 1500+;
    # confirm whether the different start rows are intended.
    features, df_csv_realcat = _load_tank_data(start_row=2000)
    features = features.reset_index(drop=True)
    df_csv_realcat = df_csv_realcat.reset_index(drop=True)
    df_sc = standardize_data(features, 'scaler_tank.pkl')
    df_ = df_sc.reset_index(drop=True)
    # NOTE(review): calc_purity indexes the Dataset ([:][0:]) before building
    # the tensor, this function does not; confirm both yield the same input.
    batch = torch.tensor(Dataset(dataframe=df_sc)).to(device=device)
    # compute discretized categories and likelihoods
    likelihood = (
        pd.DataFrame(model.function_likelihood(batch).detach().cpu().numpy())
        .rolling(10)
        .median()
        .bfill()  # the first 9 rolling values are NaN; fill them backwards
    )
    cats = pd.Series(_discretize(model, batch), index=df_.index)
    # compute purity measure
    purity = cluster_purity(cats, df_csv_realcat)
    print(purity)

    fig = make_subplots(rows=4, cols=1, shared_xaxes=True)
    for column in df_.columns[:3]:
        fig.add_trace(go.Scatter(x=df_.index, y=df_[column], name=column, mode='markers'),
                      row=1, col=1)
    fig.add_trace(go.Scatter(x=df_.index, y=cats, name='discretized category', mode='lines'),
                  row=2, col=1)
    fig.add_trace(go.Scatter(x=df_.index, y=df_csv_realcat.values, name='real category'),
                  row=3, col=1)
    # The likelihood used to be computed but never shown, leaving row 4 empty.
    for column in likelihood.columns:
        trace_name = 'likelihood' if len(likelihood.columns) == 1 else f'likelihood {column}'
        fig.add_trace(go.Scatter(x=likelihood.index, y=likelihood[column], name=trace_name, mode='lines'),
                      row=4, col=1)
    return fig


# computation of total purity
def calc_purity(versions=range(10)):
    """Compute mean and standard deviation of the purity over trained models.

    Args:
        versions: numbers of the training runs to evaluate (default 0..9).

    Returns:
        Dict with the keys ``'mean '`` and ``'std '`` (population standard
        deviation) of the per-model purity values.

    Raises:
        ValueError: if ``versions`` is empty.
    """
    versions = list(versions)
    if not versions:
        raise ValueError('at least one model version is required')

    device = _select_device()
    # read normal data -- identical for every model, so prepare it only once
    features, df_csv_realcat = _load_tank_data(start_row=1500)
    df_sc = standardize_data(features, 'scaler_tank.pkl')
    df = Dataset(dataframe=df_sc)[:][0:]
    batch = torch.tensor(df).to(device=device)

    total_purity = []
    for idx in versions:
        model = _load_model(idx, device)
        total_purity.append(cluster_purity(_discretize(model, batch), df_csv_realcat))

    mean = sum(total_purity) / len(total_purity)
    variance = sum((x - mean) ** 2 for x in total_purity) / len(total_purity)
    std = variance ** 0.5
    return {'mean ': mean, 'std ': std}

In [None]:
# Interactive widget: choose a trained model version (0-9) and call plot_like
# for it, which prints the purity and returns the evaluation figure.
interact(plot_like, idx=range(10))

In [None]:
# Compute the mean and standard deviation of the purity over the trained
# model versions.
calc_purity()