In [None]:
%matplotlib inline

import numpy as np
import torch
from torch.nn import CrossEntropyLoss
import matplotlib.pyplot as plt
# Handle to matplotlib's "viridis" colormap.
# NOTE(review): `cmap` is not referenced by any other cell visible in this notebook — confirm it is still needed.
cmap = plt.get_cmap("viridis")
from tqdm import tqdm
from connectome.visualization.manifold_plots import plot_manifold_3d_mult_colours

from configs import config as u_config
from connectome import FullGraphModel
from connectome import store_intermediate_output

from connectome import DataProcessor
from connectome import (
    clean_model_outputs, 
    get_image_paths, 
    get_iteration_number, 
    initialize_results_df, 
    select_random_images, 
    update_results_df, 
    update_running_loss, 
)
from connectome.visualization.plots import plot_results

# Runtime settings shared by the cells below.
dtype = torch.float32
batch_size = u_config.batch_size

# The device is pinned to the CPU. The GPU-aware alternative would be:
#     "cuda" if torch.cuda.is_available() else "cpu"
device_type = "cpu"
device = torch.device(device_type)

In [None]:
# Batch size is (re)read from the config. Per the original note, it bounds how many
# connectome passes are processed at once so that memory is not exhausted.
batch_size = u_config.batch_size

# The loss has no dependencies, so it is created first; the data pipeline, the model
# built on top of it, and the optimizer over the model's parameters follow.
criterion = CrossEntropyLoss()
data_processor = DataProcessor(u_config)
model = FullGraphModel(data_processor, u_config)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=u_config.base_lr)

In [None]:
def correct_test_results(test_results):
    """Flip the "Is correct" column when the overall accuracy is below 0.5.

    There was a bug in how the classes were obtained, so the 0/1 labels can be
    swapped. When fewer than half of the rows are marked correct, every entry of
    "Is correct" is inverted (0 <-> 1).

    Args:
        test_results: results table with a 0/1 "Is correct" column. It is
            modified in place when a flip happens.

    Returns:
        A ``(test_results, flipped)`` tuple, where ``flipped`` is True only if
        the column was inverted.
    """
    n_rows = len(test_results)
    if n_rows == 0:
        # An empty table has no accuracy to judge; returning early also avoids
        # dividing by zero below.
        return test_results, False

    flipped = False
    if test_results["Is correct"].sum() / n_rows < 0.5:
        # |x - 1| maps 0 -> 1 and 1 -> 0.
        test_results["Is correct"] = np.abs(test_results["Is correct"] - 1)
        flipped = True

    return test_results, flipped

# test
def test(model):
    """Evaluate ``model`` and collect the outputs of its decision-making dropout layer.

    A forward hook is registered on ``model.decision_making_dropout`` for the
    duration of the evaluation and is always removed afterwards, even if an
    iteration raises.

    Args:
        model: model exposing a ``decision_making_dropout`` submodule and, after
            each forward pass with the hook installed, an ``intermediate_output``
            attribute.

    Returns:
        A tuple ``(test_results, final_plots, accuracy, flipped,
        all_intermediate_outputs, all_labels)``:

        - ``test_results``: per-image results table, after the label-flip correction.
        - ``final_plots``: whatever ``plot_results`` returns for an empty ``plot_types``.
        - ``accuracy``: ``total_correct / total`` from the raw per-batch counts.
          It is NOT adjusted when ``flipped`` is True.
        - ``flipped``: whether ``correct_test_results`` inverted "Is correct".
        - ``all_intermediate_outputs``: hook outputs concatenated along dim 0.
        - ``all_labels``: list with one label tensor per batch, in batch order.
    """
    # Presumably the hook stores the layer output on the model so that it can be
    # read back as `model.intermediate_output` — confirm in store_intermediate_output.
    hook = model.decision_making_dropout.register_forward_hook(
        store_intermediate_output
    )

    try:
        # NOTE(review): the evaluation images are read from TRAINING_DATA_DIR —
        # confirm that evaluating on the training set is intended here.
        testing_images = get_image_paths(u_config.TRAINING_DATA_DIR, u_config.small_length)
        already_selected_testing = []
        total_correct, total, running_loss = 0, 0, 0.0
        test_results = initialize_results_df()
        all_intermediate_outputs = []
        all_labels = []

        model.eval()
        iterations = get_iteration_number(len(testing_images), u_config)
        with torch.no_grad():
            for _ in tqdm(range(iterations)):
                batch_files, already_selected_testing = select_random_images(
                    testing_images, batch_size, already_selected_testing
                )
                images, labels = data_processor.get_data_from_paths(batch_files)
                inputs, labels = data_processor.process_batch(images, labels)
                inputs = inputs.to(device)
                labels = labels.to(device)

                out = model(inputs)
                # Keep activations and labels in the same batch order so that
                # rows can be matched up later.
                all_intermediate_outputs.append(model.intermediate_output)
                all_labels.append(labels)
                loss = criterion(out, labels)

                outputs, predictions, labels_cpu, correct = clean_model_outputs(out, labels)
                test_results = update_results_df(
                    test_results, batch_files, outputs, predictions, labels_cpu, correct
                )
                running_loss += update_running_loss(loss, inputs)
                # Count the samples actually seen; a short batch would otherwise
                # be over-counted as `batch_size`.
                total += len(labels)
                total_correct += correct.sum()
    finally:
        # Never leave the hook attached to the model, even on failure.
        hook.remove()

    # Apply the label-flip correction once, on the complete table. Doing it per
    # iteration on the growing table could flip earlier batches but not later ones.
    test_results, flipped = correct_test_results(test_results)

    plot_types = []
    final_plots = plot_results(
        test_results, plot_types=plot_types, classes=u_config.CLASSES
    )
    all_intermediate_outputs = torch.cat(all_intermediate_outputs, dim=0)

    print(
        f"Finished testing with loss {running_loss / total} and "
        f"accuracy {total_correct / total}."
    )
    return test_results, final_plots, total_correct / total, flipped, all_intermediate_outputs, all_labels

In [None]:
# Restore the trained weights from disk onto the configured device.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint_path = "models/m_2024-06-25 22:11_uq9j7y8w.pth"
checkpoint = torch.load(checkpoint_path, map_location=device_type)
model.load_state_dict(checkpoint["model"])

# Inference only from here on: gradients are switched off globally and the model
# is put in evaluation mode.
torch.set_grad_enabled(False)
model.eval()

In [None]:
# Run the evaluation once. `intermediate` receives the concatenated hook outputs and
# `labels_orig` the list of per-batch label tensors returned by test().
test_results, final_plots, accuracy, flipped, intermediate, labels_orig = test(model)

In [None]:
# Rename "black_blue" to "red" inside the image paths.
test_results["Image"] = test_results["Image"].str.replace("black_blue", "red")

# File name = everything after the last slash; its first underscore-separated
# token is the colour of the image.
file_names = test_results["Image"].str.split("/").str[-1]

# Add two descriptive columns: `colour` (from the file name) and `shape`
# (from "True label": 0 -> "circle", 1 -> "star").
test_results = test_results.assign(
    colour=file_names.str.split("_").str[0],
    shape=test_results["True label"].map({0: "circle", 1: "star"}),
)

In [None]:
# Convert the intermediate outputs to a numpy array, but only if they are still a
# torch tensor: this keeps the cell safe to re-run (a numpy array has no `.cpu()`).
if isinstance(intermediate, torch.Tensor):
    intermediate = intermediate.detach().cpu().numpy()

In [None]:
%matplotlib inline

from sklearn.manifold import TSNE

# Perform t-SNE to reduce to 2D for visualization
tsne = TSNE(n_components=3, random_state=42)
reduced_data = tsne.fit_transform(intermediate)
test_results["tsne_Component_1"] = reduced_data[:, 0]
test_results["tsne_Component_2"] = reduced_data[:, 1]
test_results["tsne_Component_3"] = reduced_data[:, 2]

In [None]:
# 3D plot of the t-SNE components stored in test_results.
# painting_option="colour" presumably selects the "colour" column for the point
# colours — confirm against plot_manifold_3d_mult_colours.
plot_manifold_3d_mult_colours(test_results, algorithm="tsne", painting_option="colour")

In [None]:
# Same t-SNE plot, this time with painting_option="shape" (presumably painting the
# points by the "shape" column — confirm against plot_manifold_3d_mult_colours).
plot_manifold_3d_mult_colours(test_results, algorithm="tsne", painting_option="shape")

In [None]:
import umap

# Embed the intermediate outputs into three UMAP dimensions (fixed seed) and
# store each dimension as its own column of the results table.
reducer = umap.UMAP(n_components=3, random_state=42)
reduced_data = reducer.fit_transform(intermediate)

for component_idx in range(3):
    test_results[f"umap_Component_{component_idx + 1}"] = reduced_data[:, component_idx]

In [None]:
# 3D plot of the UMAP components stored in test_results.
# No painting_option is passed, so the function's default applies (the t-SNE cell
# passes "colour" explicitly) — TODO confirm the default is the intended one.
plot_manifold_3d_mult_colours(test_results, algorithm="umap")

In [None]:
# Same UMAP plot, this time with painting_option="shape" (presumably painting the
# points by the "shape" column — confirm against plot_manifold_3d_mult_colours).
plot_manifold_3d_mult_colours(test_results, algorithm="umap", painting_option="shape")

In [None]:
from sklearn.decomposition import PCA

# Project the intermediate outputs onto three principal components and store each
# one as its own column of the results table. Despite the `_2d` suffix in the
# variable names, three components are kept.
pca_2d = PCA(n_components=3)
pca_2d_result = pca_2d.fit_transform(intermediate)

for component_idx in range(3):
    test_results[f"pca_Component_{component_idx + 1}"] = pca_2d_result[:, component_idx]

In [None]:
# 3D plot of the PCA components stored in test_results.
# No painting_option is passed, so the function's default applies (the t-SNE cell
# passes "colour" explicitly) — TODO confirm the default is the intended one.
plot_manifold_3d_mult_colours(test_results, algorithm="pca")

In [None]:
# Same PCA plot, this time with painting_option="shape" (presumably painting the
# points by the "shape" column — confirm against plot_manifold_3d_mult_colours).
plot_manifold_3d_mult_colours(test_results, algorithm="pca", painting_option="shape")

In [None]:
import plotly.express as px
import pandas as pd

# Per-sample labels in the same order as the rows of `intermediate`: test() returns
# them as one tensor per batch, appended alongside the intermediate outputs.
# Assumes each batch tensor is 1-D (one class index per sample) — TODO confirm.
labels = torch.cat(labels_orig).cpu().numpy()
assert len(labels) == len(intermediate), "labels and intermediate outputs are misaligned"

# Perform t-SNE to reduce to 3D for visualization
tsne = TSNE(n_components=3, random_state=42)
tsne_3d = tsne.fit_transform(intermediate)

# Perform UMAP to reduce to 3D for visualization
reducer = umap.UMAP(n_components=3, random_state=42)
umap_3d = reducer.fit_transform(intermediate)

tsne_df = pd.DataFrame(tsne_3d, columns=["Component 1", "Component 2", "Component 3"])
tsne_df["Label"] = labels

umap_df = pd.DataFrame(umap_3d, columns=["Component 1", "Component 2", "Component 3"])
umap_df["Label"] = labels

# Plot t-SNE in 3D using plotly
fig_tsne = px.scatter_3d(
    tsne_df,
    x="Component 1",
    y="Component 2",
    z="Component 3",
    color="Label",
    title="t-SNE 3D of Intermediate Outputs",
)
fig_tsne.show()

In [None]:

# Interactive 3D scatter of the UMAP embedding, coloured by the "Label" column.
umap_axes = dict(x="Component 1", y="Component 2", z="Component 3")
fig_umap = px.scatter_3d(
    umap_df,
    color="Label",
    title="UMAP 3D of Intermediate Outputs",
    **umap_axes,
)
fig_umap.show()

In [None]:
# Per-sample labels in the same order as the rows of `intermediate`: test() returns
# them as one tensor per batch, appended alongside the intermediate outputs.
# Assumes each batch tensor is 1-D (one class index per sample) — TODO confirm.
labels = torch.cat(labels_orig).cpu().numpy()
assert len(labels) == len(intermediate), "labels and intermediate outputs are misaligned"

# Project the intermediate outputs onto three principal components.
pca = PCA(n_components=3)
pca_3d = pca.fit_transform(intermediate)
pca_df = pd.DataFrame(pca_3d, columns=["Component 1", "Component 2", "Component 3"])
pca_df["Label"] = labels

# Interactive 3D scatter of the PCA projection, coloured by label.
fig_pca = px.scatter_3d(
    pca_df,
    x="Component 1",
    y="Component 2",
    z="Component 3",
    color="Label",
    title="PCA 3D of Intermediate Outputs",
)
fig_pca.show()