In [None]:
import os
import torch
from tqdm.notebook import tqdm
import plotly.graph_objects as go
import plotly.subplots as sp
from torch.nn import CrossEntropyLoss

import config as u_config
from graph_models import FullGraphModel
from data_processing import DataProcessor
from manifold_funcs import manifold_test, reduce_dimension

# --- Run configuration ---
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device_type = "cuda"
else:
    device_type = "cpu"
device = torch.device(device_type)
dtype = torch.float32

# Batch size is read from the project config module.
batch_size = u_config.batch_size

# Dimensionality-reduction method name passed to reduce_dimension.
algorithm = "tsne"

# Directory holding the per-epoch checkpoints, and the epochs to visualise.
checkpoints_dir = "models/epoch_checkpoints"
selected_epochs = [0, 2, 4, 5, 15, 50, 80, 99]

In [None]:
# Build the data pipeline first: the model constructor takes it as an argument.
data_processor = DataProcessor(u_config)

# Loss function handed to manifold_test in the plotting cell.
criterion = CrossEntropyLoss()

# Instantiate the network from the project config and place it on `device`.
model = FullGraphModel(data_processor, u_config).to(device)

In [None]:
%matplotlib inline

# Create a subplot grid
n_cols = 2
n_rows = (len(selected_epochs) + (n_cols - 1)) // n_cols

fig = sp.make_subplots(
    rows=n_rows, cols=n_cols, subplot_titles=[f"Epoch {e + 1}" for e in selected_epochs],
    specs=[[{'type': 'scatter3d'}]*n_cols]*n_rows  # Specify 3D plots
)

for i, epoch in tqdm(enumerate(selected_epochs)):
    checkpoint_path = os.path.join(
        checkpoints_dir, f"m_2024-07-26 06:40_6n2q7xq3_{epoch}.pth"
    )
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model"])
    torch.set_grad_enabled(False)
    model.eval()

    test_results, _, _, intermediate, _ = manifold_test(
        model, data_processor, criterion, device, u_config
    )

    shape_colors = {"circle": "blue", "star": "red"}
    test_results["colour"] = (
        test_results["Image"].str.split("/").str.get(-2).map(shape_colors)
    )
    intermediate = intermediate.cpu().numpy()
    test_results = reduce_dimension(test_results, intermediate, algorithm=algorithm, n_dimensions=3)

    # Extract row and column indices
    row, col = divmod(i, n_cols)
    row += 1  # Plotly uses 1-based indexing for rows

    # Determine if this is the first plot for the legend
    show_legend = (row == 1) and (col == 0)

    # Plot each subplot in 3D
    for color, color_value in shape_colors.items():
        fig.add_trace(
            go.Scatter3d(
                x=test_results[test_results['colour'] == color_value][f"{algorithm}_Component_1"],
                y=test_results[test_results['colour'] == color_value][f"{algorithm}_Component_2"],
                z=test_results[test_results['colour'] == color_value][f"{algorithm}_Component_3"],
                mode="markers",
                marker=dict(color=color_value, size=2),
                name=color,  # Use the shape label for the legend
                legendgroup=color,  # Group all traces of the same color together
                showlegend=show_legend  # Only show legend once per group
            ),
            row=row,
            col=col + 1
        )

# Adjust layout settings
fig.update_layout(
    height=1000, width=800, title_text="3D Evolution of Representations Over Epochs",
    paper_bgcolor='white',
    plot_bgcolor='white',
    scene=dict(
        xaxis=dict(showgrid=False, zeroline=False, showbackground=False, visible=False),
        yaxis=dict(showgrid=False, zeroline=False, showbackground=False, visible=False),
        zaxis=dict(showgrid=False, zeroline=False, showbackground=False, visible=False),
    ),
    margin=dict(l=0, r=0, t=30, b=0)  # Adjust margins as needed
)

# Adjust scene for each subplot
for i in range(n_rows):
    for j in range(n_cols):
        fig.update_scenes(
            xaxis=dict(showgrid=False, zeroline=False, showbackground=False, visible=False),
            yaxis=dict(showgrid=False, zeroline=False, showbackground=False, visible=False),
            zaxis=dict(showgrid=False, zeroline=False, showbackground=False, visible=False),
            row=i+1, col=j+1
        )

fig.show()