In [9]:
%pip install plotly

In [2]:
# %%
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import wandb
from tqdm import tqdm
from einops import rearrange, repeat
import matplotlib.pyplot as plt
import os
from torchinfo import summary
import utils, arch
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go

MAIN = __name__ == "__main__"

# Prefer the GPU when one is available.
# NOTE(review): nothing in this cell moves the model or the data onto `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load model
model = arch.MNIST_Net()

# Pass the path straight to torch.load so it opens *and closes* the file itself
# (the previous `open(...)` handle was never closed). `map_location="cpu"` lets a
# checkpoint that was saved from a GPU run load on a CPU-only machine; the model
# is kept on the CPU here, so that is also where the weights are needed.
# The result is bound to `state_dict` rather than `dict` so the builtin `dict`
# stays usable in every later cell.
state_dict = torch.load("models/clean_0004149_4.pt", map_location="cpu")
model.load_state_dict(state_dict)

# load data
# One preprocessing pipeline shared by both splits so they cannot drift apart.
# NOTE(review): (0.1307, 0.3081) are presumably the MNIST pixel mean/std - confirm.
mnist_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

train_data = datasets.MNIST(
    "./data",
    train=True,
    download=True,
    transform=mnist_transform,
)
test_data = datasets.MNIST(
    "./data",
    train=False,
    transform=mnist_transform,
)

# %%




In [3]:
# Set up hook

# Most recent output of every hooked layer, keyed by the name given to make_hook.
activations = {}


def make_hook(module, name):
    """Record ``module``'s forward output in ``activations[name]`` on every call.

    Args:
        module: The ``torch.nn.Module`` to observe.
        name: Key under which the latest output is stored in the global
            ``activations`` dict. Reusing a name overwrites the previous entry.

    Returns:
        The handle returned by ``register_forward_hook``. Call ``.remove()``
        on it to detach the hook again, e.g. before re-running this cell, so
        hooks do not pile up on the same module.
    """

    def fwd_hook(mod, inputs, output):
        # Store a detached tensor: the recorded activation is for inspection
        # only, so it should not keep the autograd graph alive, and a tensor
        # that requires grad cannot be converted to NumPy for plotting.
        # Non-tensor outputs (e.g. tuples) are stored unchanged.
        if isinstance(output, torch.Tensor):
            output = output.detach()
        activations[name] = output

    return module.register_forward_hook(fwd_hook)


# Attach a recording hook to each layer of interest. Each pair is
# (position inside model.net, key used in `activations`).
# NOTE(review): the positions are assumed to line up with the layer names - confirm
# against the architecture definition.
for layer_index, hook_name in (
    (1, "conv1"),
    (4, "conv2"),
    (8, "linear1"),
    (9, "linear2"),
):
    make_hook(model.net[layer_index], hook_name)

# %%

In [4]:
# run model
# Push the first training image through the network. Any forward hooks
# registered on the model's layers fire during this call.
model.eval()  # switch the module to evaluation mode
with torch.inference_mode():
    first_image, first_label = train_data[0]
    d = first_image.unsqueeze(0)  # add a leading batch dimension of size 1
    logits = model(d)
    print("logits=", logits)

# %%

logits= tensor([[-9.4744, -7.2802, -8.9255,  4.5815, -8.0659, 11.5457, -4.2038, -4.5093,
         -0.9732, -1.2266]])


In [43]:
# Visualize the recorded activations, one heatmap per channel


def all_channels(module, height=10000, width=300):
    """Show every channel of a recorded activation as a vertical stack of heatmaps.

    Only the first item of the recorded batch is plotted. The figure is
    displayed with ``fig.show()`` and nothing is returned.

    Args:
        module: Key into the global ``activations`` dict, e.g. ``"conv1"``.
        height: Total figure height in pixels (shared by all channel rows).
        width: Figure width in pixels.

    Raises:
        KeyError: If nothing has been recorded under ``module``.
        ValueError: If the first batch item is not 3-dimensional. A flat
            activation has no per-channel 2-D maps to draw.
    """
    if module not in activations:
        raise KeyError(
            f"no activation recorded for {module!r}; "
            f"available: {sorted(activations)}"
        )

    # Hand plotly a plain host-side array: a CUDA tensor, or one that still
    # tracks gradients, cannot be converted implicitly.
    acts = activations[module][0].detach().cpu().numpy()
    if acts.ndim != 3:
        raise ValueError(
            f"expected a 3-D activation (channels plus a 2-D map) for "
            f"{module!r}, got shape {acts.shape}"
        )
    num_channels = acts.shape[0]

    fig = make_subplots(rows=num_channels, cols=1)

    # One heatmap per channel; subplot rows are 1-indexed.
    for row, img in enumerate(acts, start=1):
        fig.add_trace(go.Heatmap(z=img, showscale=False), row=row, col=1)

    # Heatmap puts z[0] at the bottom by default; reverse the y axes so each
    # map is drawn in image orientation (first row at the top).
    fig.update_yaxes(autorange="reversed")
    fig.update_layout(height=height, width=width, title_text="Activations plot")
    fig.show()

# %%

In [45]:
all_channels("conv1")