In [None]:
import os
import pickle
import torch
import random
import sklearn
import warnings
import argparse

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from tqdm import tqdm
from openTSNE import TSNE
from collections import defaultdict
from torch.utils.data import DataLoader
from sklearn.cluster import AffinityPropagation, SpectralClustering

from NeuralCDE_utils import *

# Plot / repr cosmetics first: white seaborn theme, compact sklearn estimator reprs.
sns.set_style(style="white")
sklearn.set_config(print_changed_only=True)

# Silence warnings for the rest of the session. The environment variable is set
# as well so that child Python processes started later inherit the UserWarning filter.
os.environ.update(PYTHONWARNINGS="ignore::UserWarning")
warnings.filterwarnings(action='ignore')
warnings.filterwarnings('ignore', category=UserWarning)

In [None]:
# Single source of truth for randomness: used below to seed `random`, NumPy and
# PyTorch, and as the default of the `--seed` argument.
SEED = 42

In [None]:
# Seed every random number generator the notebook relies on so that
# "Restart Kernel -> Run All" reproduces the same numbers.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Ask cuDNN to pick deterministic kernels.
torch.backends.cudnn.deterministic = True

# `torch.use_deterministic_algorithms` is a function and must be *called*.
# The previous code assigned `True` to the attribute, which replaced the
# function with a bool and never enabled deterministic mode.
# `warn_only=True` keeps operations without a deterministic implementation
# usable (they emit a warning instead of raising a RuntimeError).
torch.use_deterministic_algorithms(True, warn_only=True)

In [None]:
def _str_to_bool(value):
    """Convert a command-line token to a bool.

    `type=bool` would treat every non-empty string (including "False") as True,
    so boolean options are parsed explicitly instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# All tunable settings of the notebook live on `args`; defaults are used as-is
# because the parser is fed an empty argument list at the end of this cell.
parser = argparse.ArgumentParser()

data_type = 'Co_Po'
parser.add_argument('--seed', type=int, default=SEED)

# Paths are assembled with os.path.join so they resolve on every OS. On Windows
# the resulting strings are identical to the previous hard-coded backslash paths,
# without relying on invalid escape sequences such as '\F' or '\m'.
parser.add_argument('--video_img_dir', type=str,
                    default=os.path.join('Data', f'Video_Images_{data_type}'))
parser.add_argument('--figure_save_dir', type=str,
                    default=os.path.join('Figures', f'NeuralCDE_{data_type}'))
parser.add_argument('--backbone_path', type=str,
                    default=os.path.join('Model', f'Feature_Extractor_{data_type}', 'model.pkl'))
parser.add_argument('--model_path', type=str,
                    default=os.path.join('Model', f'NeuralCDE_Naive_{data_type}', 'model.pkl'))
parser.add_argument('--pred_figure_save_dir', type=str, default='Prediction_Visualization')
parser.add_argument('--pred_save_path', type=str,
                    default=os.path.join('Prediction_Visualization', f'NeuralCDE_Naive_{data_type}.pkl'))
parser.add_argument('--tsne_save_path', type=str,
                    default=os.path.join('Prediction_Visualization', f'TSNE_{data_type}.pkl'))

# Model / clustering hyper-parameters.
parser.add_argument('--num_clusters', type=int, default=4)
parser.add_argument('--adjoint', type=_str_to_bool, default=True)
parser.add_argument('--img_input_size', type=int, default=128)
parser.add_argument('--img_output_size', type=int, default=32)
parser.add_argument('--hidden_size', type=int, default=16)
parser.add_argument('--output_size', type=int, default=2)

# Data-loading and device settings.
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--workers', type=int, default=0)
parser.add_argument('--device', default=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'))

# An explicit empty list keeps argparse from reading the Jupyter kernel's own argv.
args = parser.parse_args([])

In [None]:
# Build the dataset from the 'Train' sub-folder and wrap it in a loader that
# keeps the on-disk order (shuffle=False), so collected results line up with names.
print('Loading Training Set')
datasets = KvasirVideoDataset(
    os.path.join(args.video_img_dir, 'Train'),
    args.img_input_size,
    strong_transform=False,
    visualize=True,
)
data_loader = DataLoader(
    dataset=datasets,
    batch_size=args.batch_size,
    num_workers=args.workers,
    shuffle=False,
)

In [None]:
# Instantiate the network and copy in the trained weights.
model = NeuralCDEVisual(args).to(args.device)
# The file at `model_path` holds a whole pickled module (its `.state_dict()` is
# taken below). `map_location` remaps its tensors onto the selected device so a
# checkpoint saved on GPU also loads on a CPU-only machine.
# NOTE(review): recent PyTorch releases default `torch.load` to weights-only
# loading, which rejects pickled modules; pass `weights_only=False` there if
# this call starts failing. Only do that for checkpoints you trust.
model.load_state_dict(torch.load(args.model_path, map_location=args.device).state_dict())
# Inference mode (e.g. disables dropout / batch-norm updates, if the model has any).
model.eval()

In [None]:
# Helper network that maps hidden states to per-step predictions.
z_pred_model = HiddenStatePred(args).to(args.device)
# Bug fix: this cell previously re-loaded the checkpoint into `model` and called
# `model.eval()`, so `z_pred_model` never received weights here and stayed in
# training mode.
# NOTE(review): `strict=False` is used because HiddenStatePred presumably owns
# only a subset of the full model's parameters -- confirm against NeuralCDE_utils.
_z_pred_load_result = z_pred_model.load_state_dict(
    torch.load(args.model_path, map_location=args.device).state_dict(),
    strict=False,
)
if _z_pred_load_result.missing_keys:
    # Surface parameters that the checkpoint did not provide instead of failing silently.
    print('HiddenStatePred parameters missing from checkpoint:', _z_pred_load_result.missing_keys)
z_pred_model.eval()

zs.shape == (batch_size, n_frame, hidden_size) \
pred_label.shape == (batch_size, 1) \
true_y.shape == (batch_size, 1) \
step_p.shape == (batch_size, n_frame, 1) \
all_zs.shape == (batch_size * n_frame, hidden_size) \
all_ys.shape == (batch_size * n_frame, 1)

In [None]:
# Create the output folder for figures/CSVs up front; `exist_ok=True` makes the
# cell safe to re-run.
os.makedirs(args.pred_figure_save_dir, exist_ok=True)

In [None]:
# Compute (or reload) the per-video hidden states and predictions.
# Results are cached at `args.pred_save_path`; delete that file to force a recompute.
if not os.path.exists(args.pred_save_path):
    zs = []          # one hidden-state array per video
    pred_label = []  # predicted label per video
    true_y = []      # ground-truth label per video
    step_p = []      # per-step output of z_pred_model(..., p=True) per video
    names = []       # name batch yielded by the loader for each video

    # Pure inference: no_grad avoids building autograd graphs for every video.
    with torch.no_grad():
        for (b_x, b_y, b_name) in tqdm(data_loader):
            # int(...) only works for single-element tensors, i.e. batch_size == 1.
            true_y.append(int(b_y))
            names.append(b_name)
            z_T, label = model(b_x.to(args.device))
            zs.append(z_T.detach().cpu().numpy())
            pred_label.append(int(label))
            step_p.append(z_pred_model(z_T, p=True).detach().cpu().numpy())

    # Stack all hidden states and repeat each video's label once per row of its
    # hidden-state array so that all_ys lines up with all_zs.
    all_zs = np.concatenate(zs)
    all_ys = np.repeat(true_y, [z.shape[0] for z in zs])

    with open(args.pred_save_path, 'wb') as f:
        pickle.dump((zs, pred_label, true_y, step_p, all_zs, all_ys, names), f)

else:
    # NOTE: pickle can execute arbitrary code on load; only use cache files
    # produced by this notebook.
    with open(args.pred_save_path, 'rb') as f:
        print('Loading data from', args.pred_save_path)
        zs, pred_label, true_y, step_p, all_zs, all_ys, names = pickle.load(f)

In [None]:
# Fit a 2-D t-SNE on every collected hidden state; the fitted embedding is kept
# so later cells can plot it and project individual videos into the same space.
tsne_2d = TSNE(
    n_components=2,
    perplexity=30,
    metric="euclidean",
    random_state=args.seed,
    n_jobs=8,
    verbose=False,
)
embedding_train = tsne_2d.fit(all_zs)

In [None]:
# t-SNE map of all hidden states, coloured by the ground-truth label of the
# video each point belongs to.
gt_axes = plt.figure(figsize=(10, 10), dpi=80).gca()
sns.scatterplot(x=embedding_train[:, 0], y=embedding_train[:, 1], hue=all_ys, ax=gt_axes)
# t-SNE coordinates carry no units, so the ticks are hidden.
gt_axes.set_xticks([])
gt_axes.set_yticks([])
gt_axes.set_title('Hidden State Space Distribution w.r.t Ground Truth')

In [None]:
# Same t-SNE map, this time coloured by the per-step values stored in `step_p`.
prob_axes = plt.figure(figsize=(10, 10), dpi=80).gca()
sns.scatterplot(
    x=embedding_train[:, 0],
    y=embedding_train[:, 1],
    hue=np.concatenate(step_p),
    ax=prob_axes,
)
# t-SNE coordinates carry no units, so the ticks are hidden.
prob_axes.set_xticks([])
prob_axes.set_yticks([])
prob_axes.set_title('Hidden State Space Distribution w.r.t Predicted Probability')

In [None]:
def _finish_panel(xs, ys, title):
    """Overlay the thin line joining consecutive points, hide the ticks and set the title."""
    # `sort=False` keeps the points in their original (temporal) order.
    sns.lineplot(x=xs, y=ys, sort=False, linewidth=0.3)
    plt.xticks([])
    plt.yticks([])
    plt.title(title)


def _scatter_clusters(xs, ys, labels):
    """Draw one scatter layer per cluster label so each cluster gets its own colour."""
    for cluster in np.unique(labels):
        # A boolean mask always yields 1-D arrays, even for single-member clusters.
        members = labels == cluster
        sns.scatterplot(x=xs[members], y=ys[members], s=30, alpha=0.8)


# For every video: project its hidden states into the fitted t-SNE space and
# save a 3x3 figure (row 1: time / prediction views, row 2: affinity
# propagation, row 3: spectral clustering).
palette = sns.color_palette("flare", as_cmap=True)
for ith, z_T in enumerate(tqdm(zs)):
    video_name = names[ith][0]
    embedding_test = np.array(embedding_train.transform(z_T))
    emb_x, emb_y = embedding_test[:, 0], embedding_test[:, 1]
    n_steps = z_T.shape[0]

    # Clustering features: hidden state plus a scaled step index as last column.
    # NOTE(review): the divisor 5 is presumably the number of frames per time unit -- confirm.
    z_T_with_t = np.column_stack([z_T, np.arange(n_steps) / 5])

    # Run the prediction head once per output kind instead of inside each plot call.
    with torch.no_grad():
        z_tensor = torch.from_numpy(z_T).to(args.device)
        step_prob = z_pred_model(z_tensor, p=True).cpu().numpy()
        step_label = z_pred_model(z_tensor).cpu().numpy()

    fig = plt.figure(figsize=(20, 20), dpi=300)

    plt.subplot(3, 3, 1)
    sns.scatterplot(x=emb_x, y=emb_y, hue=np.arange(n_steps), s=30, alpha=0.8, palette=palette)
    _finish_panel(emb_x, emb_y, 'Hidden State Dynamics with Time')

    plt.subplot(3, 3, 2)
    sns.scatterplot(x=emb_x, y=emb_y, hue=step_prob, s=30, alpha=0.8, palette=palette)
    _finish_panel(emb_x, emb_y, 'Predicted Probability Dynamics')

    # This panel used to be drawn twice on the same axes (copy-pasted block);
    # it is drawn once now, keeping the legend that the final figure showed.
    plt.subplot(3, 3, 3)
    sns.scatterplot(x=emb_x, y=emb_y, hue=step_label, s=30, alpha=0.8)
    _finish_panel(emb_x, emb_y, 'Predicted Label Dynamics')

    # The estimators are bound to `clusterer` (not `model`) so the trained torch
    # network held in `model` is not overwritten by this cell.
    for ith_rate, damping_rate in enumerate([0.55, 0.8, 0.95]):
        plt.subplot(3, 3, ith_rate + 4)
        clusterer = AffinityPropagation(damping=damping_rate, random_state=args.seed)
        clusterer.fit(z_T_with_t)
        _scatter_clusters(emb_x, emb_y, clusterer.predict(z_T_with_t))
        _finish_panel(emb_x, emb_y, f'Affinity Propagation Clustering \n with damping = {damping_rate}')

    for ith_clusters, num_clusters in enumerate([4, 6, 8]):
        plt.subplot(3, 3, ith_clusters + 7)
        clusterer = SpectralClustering(n_clusters=num_clusters, random_state=args.seed)
        _scatter_clusters(emb_x, emb_y, clusterer.fit_predict(z_T_with_t))
        _finish_panel(emb_x, emb_y, f'Spectral Clustering \n with n_clusters = {num_clusters}')

    plt.suptitle(video_name, fontsize=50, y=0.95)
    plt.savefig(os.path.join(args.pred_figure_save_dir, f'{video_name}.eps'), format='eps')
    plt.savefig(os.path.join(args.pred_figure_save_dir, f'{video_name}.png'))

    # Render the figure now, then release it: one 20x20 in, 300 dpi figure per
    # video would otherwise pile up in memory until the cell finishes.
    plt.show()
    plt.close(fig)

In [None]:
# For three cluster counts (num_clusters, +2, +4): cluster every video's hidden
# states and write, per video, the sorted times at which each cluster first
# appears to a CSV file.
for i in range(3):
    n_clusters = args.num_clusters + 2 * i
    # One record (row) per video. Building rows instead of parallel column
    # lists keeps the table well-formed even if a video yields fewer distinct
    # cluster labels than requested (missing entries become NaN instead of
    # making the DataFrame constructor fail).
    records = []
    for ith, z_T in enumerate(tqdm(zs)):
        # Clustering features: hidden state plus a scaled step index as last column.
        # NOTE(review): the divisor 5 is presumably the number of frames per time unit -- confirm.
        features = np.column_stack([z_T, np.arange(z_T.shape[0]) / 5])

        # Bound to `clusterer` (not `model`) so the trained torch network held
        # in `model` is not overwritten by this cell.
        clusterer = SpectralClustering(n_clusters=n_clusters, random_state=args.seed)
        yhat = clusterer.fit_predict(features)

        # Index of the first step assigned to each cluster, converted with `// 5`.
        first_times = sorted(
            int(np.argmax(yhat == cluster)) // 5 for cluster in np.unique(yhat)
        )

        record = {'img': names[ith][0]}
        record.update({'time' + str(ith_time): t for ith_time, t in enumerate(first_times)})
        records.append(record)

    pd.DataFrame(records).to_csv(
        os.path.join(args.pred_figure_save_dir, f'Clipping Time with k = {n_clusters}.csv')
    )