In [None]:
import pytorch_lightning as pl
import pandas as pd
from tqdm import tqdm
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.module import T
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from sklearn.feature_selection import mutual_info_classif
# Fix the global random seeds once, before any data loading or model construction.
pl.seed_everything(seed=42)
# Lightning accelerator name passed to pl.Trainer(accelerator=...) below;
# note this is 'gpu' (Lightning's spelling), not a torch.device string like 'cuda'.
device = 'cpu'
if torch.cuda.is_available():
    device = 'gpu'


In [None]:
class Teacher(pl.LightningModule):
    """Fully connected MNIST classifier: 784 -> 1200 -> 1200 -> 10, with dropout(0.5).

    ``forward`` returns class probabilities (output of ``self.softmax``), which is what
    the forward hooks attached later in the notebook record for the ``softmax`` module
    and the root module. The training/validation loss is computed on the raw logits
    instead, because ``nn.CrossEntropyLoss`` expects unnormalized scores; feeding it
    softmax probabilities would apply softmax twice and badly flatten the gradients.

    Args:
        lr: Learning rate for the Adam optimizer built in ``configure_optimizers``.
    """

    def __init__(self, lr=1e-3):
        super().__init__()
        self.layer_1 = nn.Linear(28*28, 1200)
        self.layer_2 = nn.Linear(1200, 1200)
        self.layer_3 = nn.Linear(1200, 10)
        self.softmax = nn.Softmax(dim=1)
        self.dropout = nn.Dropout(0.5)
        self.criterion = nn.CrossEntropyLoss()
        # Records `lr` into self.hparams (and checkpoints); read back in configure_optimizers.
        self.save_hyperparameters()
        self.hooks = {}

    def _logits(self, x):
        """Return unnormalized class scores for a batch of flattened 28*28 inputs.

        The same ``self.dropout`` module is applied after both hidden layers.
        """
        x = F.relu(self.dropout(self.layer_1(x)))
        x = F.relu(self.dropout(self.layer_2(x)))
        return self.layer_3(x)

    def forward(self, x):
        """Return class probabilities (softmax over dim 1 of the logits)."""
        return self.softmax(self._logits(x))

    def training_step(self, batch, batch_idx):
        """Compute and log ``train_loss`` for one batch."""
        return self.evaluate(batch, stage='train')

    def validation_step(self, batch, batch_idx):
        """Compute and log ``val_loss`` for one batch."""
        return self.evaluate(batch, stage='val')

    def evaluate(self, batch, stage=None):
        """Cross-entropy loss on the logits of ``batch``; logged as ``f'{stage}_loss'``.

        Args:
            batch: ``(x, y)`` pair of inputs and integer class targets.
            stage: Prefix for the logged metric name.

        Returns:
            The scalar loss tensor.
        """
        x, y = batch
        # Loss on logits, not on forward()'s probabilities — CrossEntropyLoss applies
        # log-softmax itself.
        loss = self.criterion(self._logits(x), y)
        self.log(f'{stage}_loss', loss)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters, using the learning rate given to ``__init__``."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
    
        

In [None]:
# Convert each MNIST image to a float tensor and flatten it to a 784-vector for the MLP.
# nn.Flatten(start_dim=0) replaces the original lambda: a lambda cannot be pickled,
# which breaks DataLoader(num_workers>0) when workers use the "spawn" start method
# (the default on macOS and Windows).
transform = transforms.Compose(
    [transforms.ToTensor(),
     nn.Flatten(start_dim=0)]
)
# download=True fetches the data on first run instead of failing when ./data is empty.
dataset = MNIST(root='./data', train=True, transform=transform, download=True)
val_dataset = MNIST(root='./data', train=False, transform=transform, download=True)

In [None]:
# Shuffle the training set every epoch; without it each epoch replays the same batch order.
# The validation order does not affect the metric, so it is left unshuffled.
train_loader = DataLoader(dataset, batch_size=320, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=320, num_workers=4)
teacher = Teacher()
trainer = pl.Trainer(accelerator=device, max_epochs=30)
trainer.fit(teacher, train_loader, val_loader)

In [None]:
activations = {}


def get_activation(name):
    """Create a forward hook that saves a module's output under ``name``.

    The returned callable has the ``(module, inputs, output)`` signature used by
    ``nn.Module.register_forward_hook``. Each call stores ``output.detach()`` in the
    module-level ``activations`` dict, replacing any previous entry for ``name``.
    """
    def _store_output(_module, _inputs, output):
        # Detach so the stored tensor does not keep the autograd graph alive.
        activations[name] = output.detach()

    return _store_output


# Detach hooks left over from a previous execution of this cell; otherwise each re-run
# stacks another hook on every module.
for handle in globals().get('activation_hook_handles', []):
    handle.remove()

# Hook every named module. This includes the root module (name ''), which records
# forward()'s output. It also includes `dropout`, which forward() calls twice, so its
# entry ends up holding the second call's output. `criterion` is hooked too, but
# forward() never calls it.
activation_hook_handles = [
    module.register_forward_hook(get_activation(name))
    for name, module in teacher.named_modules()
]

single_batch = next(iter(train_loader))
teacher.eval()  # dropout off, so recorded activations are deterministic
with torch.no_grad():  # inference only; no autograd graph needed
    teacher_output = teacher(single_batch[0])

In [None]:
# One-hot labels so each class can be scored one-vs-rest against every neuron.
labels = F.one_hot(single_batch[1], num_classes=10).float()

# Mutual information between each recorded neuron and each one-vs-rest class label.
# Column names are "<layer>|label_<k>". The root module (name '') is reported as "out".
# Layers have different widths, so building the frame from a dict of Series
# NaN-pads the shorter columns (dropped later after melting).
# random_state pins the small noise mutual_info_classif adds to continuous
# features, so results do not depend on how much global RNG state was consumed earlier.
information_columns = {}
for layer, layer_activations in tqdm(activations.items(), desc='layers'):
    features = layer_activations.cpu().numpy()
    layer_name = layer if layer != '' else 'out'
    for label in tqdm(range(10), desc=layer_name, leave=False):
        information = mutual_info_classif(
            features, labels[:, label].numpy(), random_state=42
        )
        information_columns[f'{layer_name}|label_{label}'] = pd.Series(information)
information_per_neuron = pd.DataFrame(information_columns)

In [None]:
# Long format: one row per (layer, label, neuron) with its mutual information.
# "<layer>|<label>" column names are split into two columns. Rows that were only
# NaN padding in the wide frame (narrower layers) are dropped.
df = (
    information_per_neuron
    .melt(value_name='information')
    .pipe(lambda long: pd.concat(
        [
            long['variable'].str.split('|', expand=True).set_axis(['layer', 'label'], axis=1),
            long['information'],
        ],
        axis=1,
    ))
    .dropna()
)

In [None]:
# Preview of the long-format table with information rounded to two decimals.
df.round(decimals=2)

In [None]:
# Number of (neuron, label) pairs in each layer at each information value, rounded to 0.01.
px.line(
    df.round(2)
    .groupby('layer')['information']
    .value_counts()
    .sort_index()
    .reset_index(),
    x='information',
    y='count',
    color='layer',
    title='Information per neuron per layer',
)

In [None]:
# Same histogram normalized within each layer. Each value is the fraction of that layer's
# (neuron, label) pairs at a given rounded information value, so layers of very
# different width are directly comparable. value_counts(normalize=True) divides by the
# per-layer row count, which matches the original division by df.layer.value_counts().
px.line(
    df.round(2)
    .groupby('layer')['information']
    .value_counts(normalize=True)
    .sort_index()
    .reset_index(),
    x='information',
    y='proportion',
    color='layer',
    labels={'proportion': 'fraction of (neuron, label) pairs in layer'},
    title='Information per neuron per layer (normalized within layer)',
)