In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print("Using device:", device)

# Mini-batch size shared by the data loaders in this notebook.
batchnum = 64

# Convert each image to a tensor, then shift/scale it with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST training split (downloaded into ./data on first use), served in shuffled batches.
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batchnum,
    shuffle=True,
)

In [None]:
class MNISTNet(nn.Module):
    """Fully connected classifier: 784 -> 128 -> 64 -> 10.

    The three linear layers live in ``self.layers`` (an ``nn.ModuleDict``)
    under the keys ``'fc1'``, ``'fc2'`` and ``'fc3'`` so they can be looked
    up by name.
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleDict({
            'fc1': nn.Linear(784, 128),
            'fc2': nn.Linear(128, 64),
            'fc3': nn.Linear(64, 10),  # one output per class
        })

    def forward(self, x):
        """Flatten ``x`` per sample and run it through fc1 -> ReLU -> fc2 -> ReLU -> fc3.

        The output of ``fc3`` is returned as-is (no activation applied).
        """
        flat = x.view(x.size(0), -1)
        hidden1 = F.relu(self.layers['fc1'](flat))
        hidden2 = F.relu(self.layers['fc2'](hidden1))
        return self.layers['fc3'](hidden2)

    def get_layer_name(self, module):
        """Return the key under which ``module`` is stored in ``self.layers``, or ``None``.

        Matching is by object identity, not by equality.
        """
        return next(
            (name for name, candidate in self.layers.items() if candidate is module),
            None,
        )


class ExplainableNetwork:
    """Collects activations and gradients from a model through module hooks.

    The wrapped ``model`` must expose ``model.layers`` (a mapping of layer
    name -> module) and ``model.get_layer_name(module)``.  The hook methods
    are meant to be registered on those layers by the caller; this class
    does not register anything itself.

    Activations are recorded only while ``current_epoch`` equals
    ``num_epochs - 1`` (the value passed to ``input_hook``), and are keyed
    by the order in which samples were seen, not by dataset position.
    """

    def __init__(self, model):
        self.model = model  # model whose layers are being observed
        self.gradients = {}  # module -> latest grad_output[0] as nested lists
        self.current_epoch = 0
        self.node_inputs = []  # (layer_name, row_in_batch, 28x28 image) tuples
        self.activations_stats = {}  # layer_name -> {'mean': ..., 'std': ...}
        self.activations_per_input = {}  # sample index -> {layer_name: activations}
        self.highest_activations = {}
        self.lowest_activations = {}
        self.current_expected_output = None
        # Number of completed batches; callers reset this to 0 to start a new pass.
        self.batch_counter = 0
        self.last_layer_key = list(model.layers.keys())[-1]
        # Number of samples in the batches already counted by batch_counter.
        self._sample_offset = 0

    def input_hook(self, module, num_epochs, input, output):
        """Forward-hook body: store per-sample activations of ``module``.

        Does nothing unless ``self.current_epoch == num_epochs - 1``.

        Each row of ``output`` is stored in
        ``activations_per_input[sample_index][layer_name]``.  The sample
        index is a running count of the samples seen since
        ``batch_counter`` was last 0, so a batch that is smaller than the
        others (typically the last one) no longer overwrites earlier
        entries.  The counters advance once per batch, when the hook fires
        for the last layer of the model.

        When the layer's input has 784 features per sample, every input
        row is also appended to ``node_inputs`` as a 28x28 array.
        """
        if self.current_epoch != num_epochs - 1:
            return

        layer_name = self.model.get_layer_name(module)

        # A batch_counter of 0 marks the start of a new pass (it is reset
        # from outside), so the running sample offset restarts with it.
        if self.batch_counter == 0:
            self._sample_offset = 0

        # One device-to-host transfer for the whole batch instead of one per row.
        batch_activations = output.detach().cpu().numpy()
        for row, sample_activations in enumerate(batch_activations):
            sample_index = self._sample_offset + row
            self.activations_per_input.setdefault(sample_index, {})[layer_name] = sample_activations

        batch_inputs = input[0]
        if batch_inputs.dim() == 2 and batch_inputs.size(1) == 784:
            images = batch_inputs.detach().reshape(-1, 28, 28).cpu().numpy()
            for row, image in enumerate(images):
                self.node_inputs.append((layer_name, row, image))

        # Advance only after the final layer so every layer of the same
        # batch is filed under the same sample indices.
        if layer_name == self.last_layer_key:
            self.batch_counter += 1
            self._sample_offset += output.size(0)

    def backward_hook(self, module, grad_input, grad_output):
        """Backward-hook body: keep the most recent output gradient of ``module``.

        ``grad_output[0]`` is the gradient of the loss with respect to the
        module's output; it is stored as nested Python lists and replaces
        whatever was stored for ``module`` before.
        """
        self.gradients[module] = grad_output[0].detach().tolist()

    def set_epoch(self, epoch):
        """Record the epoch currently running (compared against in ``input_hook``)."""
        self.current_epoch = epoch

    def get_extreme_activations(self, dev=0):
        """Recompute the per-layer statistics, then collect extreme activations.

        ``dev`` is the number of standard deviations used for the
        thresholds.  Note that the default here (0) differs from the
        default of ``find_extreme_activations`` (1).
        """
        self.calculate_activations_stats()
        self.find_extreme_activations(dev)

    def calculate_activations_stats(self):
        """Compute per-neuron mean and standard deviation for every layer.

        The statistics are taken across all recorded samples and stored in
        ``activations_stats[layer_name]`` under ``'mean'`` and ``'std'``.
        """
        per_layer = {}
        for layers_activations in self.activations_per_input.values():
            for layer_name, activations in layers_activations.items():
                per_layer.setdefault(layer_name, []).append(activations)

        self.activations_stats = {
            layer_name: {
                'mean': np.mean(stacked, axis=0),
                'std': np.std(stacked, axis=0),
            }
            for layer_name, stacked in per_layer.items()
        }

    def find_extreme_activations(self, dev=1):
        """Collect activations lying more than ``dev`` std-devs from the mean.

        For every recorded sample and layer, values that are both positive
        and above ``mean + dev * std`` go to ``highest_activations``;
        values that are both negative and below ``mean - dev * std`` go to
        ``lowest_activations``.  One array is appended per sample that has
        at least one such value.

        Both result dictionaries are rebuilt on every call, so repeated
        calls do not accumulate duplicates.  Requires
        ``calculate_activations_stats`` to have been run for the recorded
        layers (a missing layer raises ``KeyError``).
        """
        self.highest_activations = {}
        self.lowest_activations = {}

        for layers_activations in self.activations_per_input.values():
            for layer_name, activations in layers_activations.items():
                highs = self.highest_activations.setdefault(layer_name, [])
                lows = self.lowest_activations.setdefault(layer_name, [])

                stats = self.activations_stats[layer_name]
                high_threshold = stats['mean'] + dev * stats['std']
                low_threshold = stats['mean'] - dev * stats['std']

                high_activations = activations[(activations > high_threshold) & (activations > 0)]
                low_activations = activations[(activations < low_threshold) & (activations < 0)]
                if high_activations.size > 0:
                    highs.append(high_activations)
                if low_activations.size > 0:
                    lows.append(low_activations)

    def visualize_activations_for_input(self, image_index, layers=None, node_index=None, color='ro'):
        """Plot recorded activations.

        Args:
            image_index: recorded sample index to plot (ignored when
                ``node_index`` is given).
            layers: layer names to plot; defaults to every layer of the
                wrapped model, in order.
            node_index: when given, plot this neuron of ``layers[0]``
                across all recorded samples instead of one sample.
            color: matplotlib format string for the per-neuron plot.
        """
        if layers is None:
            layers = list(self.model.layers.keys())

        if node_index is not None:
            target_layer = layers[0]
            activations = [
                recorded[target_layer][node_index]
                for recorded in self.activations_per_input.values()
                if target_layer in recorded
            ]
            plt.figure(figsize=(10, 6))
            plt.plot(activations, color)
            plt.title(f'Activations of Neuron {node_index} in {target_layer} Across Images')
            plt.xlabel('Image Index')
            plt.ylabel('Activation')
            plt.show()
            return

        individual_activations = self.activations_per_input.get(image_index, {})
        for layer in layers:
            layer_activations = individual_activations.get(layer, None)
            if layer_activations is None:
                continue  # nothing recorded for this layer / sample
            plt.figure(figsize=(10, 6))
            plt.scatter(range(len(layer_activations)), layer_activations, label=layer)
            plt.title(f'Activations in Layer {layer} for Image {image_index}')
            plt.xlabel('Neuron Index')
            plt.ylabel('Activation')
            plt.legend()
            plt.show()

    def print_all_activations(self, max_images=None):
        """Print shape and first five values of every recorded activation.

        Args:
            max_images: if given, stop after this many samples; ``None``
                (the default) prints every recorded sample.
        """
        for position, (image_index, layers_dict) in enumerate(self.activations_per_input.items()):
            if max_images is not None and position >= max_images:
                break
            print(f"Image Index: {image_index}")
            for layer_name, activations in layers_dict.items():
                print(f"  Layer Name: {layer_name}")
                print(f"    Shape of Activations: {activations.shape}")
                print(f"    Sample Activations: {activations[:5]}")  # first 5 values only

    def _print_extremes(self, extremes):
        """Print a ``{layer_name: [array, ...]}`` mapping of extreme activations.

        NOTE: each list entry holds the extreme values found for one
        recorded sample, so the number printed after "Node" is the entry's
        position in that list, not a neuron index.
        """
        for layer_name, activations in extremes.items():
            print(f"Layer: {layer_name}")
            for node_index, node_activations in enumerate(activations):
                print(f"\tNode {node_index}:")
                for activation in node_activations:
                    print(f"\t\t{activation}")
            print()

    def print_highest_activations(self):
        """Print the contents of ``highest_activations``, grouped by layer."""
        self._print_extremes(self.highest_activations)

    def print_lowest_activations(self):
        """Print the contents of ``lowest_activations``, grouped by layer."""
        self._print_extremes(self.lowest_activations)


In [None]:
# Instantiate the model on the selected device and wrap it for activation recording
model = MNISTNet().to(device)
xNetwork = ExplainableNetwork(model)


def _record_forward(module, inputs, output, x_network=xNetwork):
    """Forward hook: pass the layer's inputs/outputs to the recorder.

    ``x_network`` is bound when this cell runs, so the hook keeps talking
    to the recorder created alongside this model.  ``num_epochs`` is
    deliberately looked up when the hook fires: it is assigned in the
    training cell, which must run before the first forward pass.
    Returns None so the layer's output is left unchanged.
    """
    x_network.input_hook(module, num_epochs, inputs, output)


# Register forward and backward hooks for each layer
for layer_name, layer in model.layers.items():
    layer.register_forward_hook(_record_forward)

    # register_backward_hook is deprecated; the full backward hook passes the
    # same (module, grad_input, grad_output) arguments to the callback.
    layer.register_full_backward_hook(xNetwork.backward_hook)


In [None]:


# Loss and optimizer: cross-entropy on the class scores, SGD with momentum.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Training schedule
num_epochs = 2
total_batches = len(trainloader)  # number of batches per epoch

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}")

    # Tell the recorder which epoch is running and restart its batch count.
    xNetwork.set_epoch(epoch)
    xNetwork.batch_counter = 0

    # batch_num counts from 1
    for batch_num, (inputs, labels) in enumerate(trainloader, start=1):
        # Move the batch to the same device as the model.
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

    # Blank line between epochs for readability.
    print()



In [None]:
# Compute per-layer mean/std over the recorded activations, then collect the
# values lying more than 2 standard deviations from the mean.
xNetwork.get_extreme_activations(2)

In [None]:
# Print the collected above-threshold (positive) activations, grouped by layer.
xNetwork.print_highest_activations()


In [None]:
# Print the collected below-threshold (negative) activations, grouped by layer.
xNetwork.print_lowest_activations()

In [None]:
# Plot the activations of every layer for one recorded sample. The index is the
# order in which the sample was seen (the training loader shuffles), not its
# position in the dataset.
xNetwork.visualize_activations_for_input(5)
# Alternative: restrict the plot to a single layer for a given sample index, e.g.
# xNetwork.visualize_activations_for_input(0, layers=['fc3'])

In [None]:
# Switch the model to evaluation mode. The hooks registered on its layers stay attached.
model.eval()

In [None]:
# Print, for every recorded sample and layer, the activation shape and its first
# five values. NOTE(review): this prints one section per recorded sample, so the
# output can be very long.
xNetwork.print_all_activations()

In [None]:
# Sanity checks: first few recorded sample indices, how many samples were
# recorded, the model's layer names, and the size of the training set.
print(list(xNetwork.activations_per_input)[:5])
print(len(xNetwork.activations_per_input))
print(model.layers.keys())
total_images = len(trainloader.dataset)
print(f"Total images in the dataset: {total_images}")

In [None]:
# MNIST test split with the same preprocessing as training; order is kept fixed.
testset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batchnum,
    shuffle=False,
)


def test_accuracy(model, testloader):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    return accuracy

# Evaluate the trained model on the MNIST test set and report its accuracy.
accuracy = test_accuracy(model, testloader)
print(f'Accuracy of the model on the test images: {accuracy}%')