In [51]:
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [None]:
# Load the MNIST dataset through the `datasets` library.
ds = load_dataset("mnist")
# Have indexed examples come back as torch tensors.
ds = ds.with_format("torch")
# Display the first training example.
ds['train'][0]

In [5]:
# Shape of a single training image tensor.
ds['train'][0]['image'].shape

torch.Size([1, 28, 28])

In [414]:

   

class Net(nn.Module):
    """Fully connected classifier for 28x28 single-channel images.

    The input is flattened from dimension 1 onwards and passed through three
    linear layers (784 -> 1024 -> 128 -> 10) with a ReLU after the first two.
    The forward pass returns the raw 10-way scores (no softmax).
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept so that parameter
        # ordering and state-dict keys stay the same.
        self.fc1 = nn.Linear(28 * 28 * 1, 1024)
        self.fc2 = nn.Linear(1024, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        """Return the class scores for a batch ``x`` of shape (batch, ...)."""
        hidden = torch.relu(self.fc1(x.flatten(1)))
        hidden = torch.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [593]:
from tqdm.notebook import tqdm

# Train the classifier on MNIST and report test accuracy after every epoch.

# Use the GPU when one is available, otherwise fall back to the CPU so the
# cell still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_loader = DataLoader(ds['train'], batch_size=128, shuffle=True)
test_loader = DataLoader(ds['test'], batch_size=1000)

model = Net().to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(4):
    # --- training pass ---
    model.train()
    for batch in tqdm(train_loader, desc=f'Epoch {epoch}', total=len(train_loader)):
        optimizer.zero_grad()
        x, y = batch['image'].to(device), batch['label'].to(device)
        # Cast the image tensor to float; no rescaling is applied here.
        x = x.float()
        output = model(x)
        loss = nn.functional.cross_entropy(output, y)
        loss.backward()
        optimizer.step()

    # --- evaluation pass on the test split ---
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f'Epoch {epoch}', total=len(test_loader)):
            x, y = batch['image'].to(device), batch['label'].to(device)
            x = x.float()
            output = model(x)
            # Predicted class = index of the largest score.
            _, predicted = torch.max(output, 1)
            total += y.size(0)
            correct += (predicted == y).sum().item()

    print(f'Epoch {epoch}: Accuracy: {100 * correct / total}')

Epoch 0:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 0:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 0: Accuracy: 95.82


Epoch 1:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1: Accuracy: 96.76


Epoch 2:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 2:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 2: Accuracy: 96.81


Epoch 3:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 3:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 3: Accuracy: 96.83


In [594]:

def activation_maximization(model, target_class, num_iterations=100, learning_rate=0.1):
    """Synthesize an input that maximizes the model's score for one class.

    Starting from Gaussian noise, the input is updated with Adam so that
    ``model(input)[0, target_class]`` grows. The model itself is not
    modified: gradients are taken with respect to the input only, so the
    parameters' ``.grad`` fields are left untouched.

    Args:
        model: Callable mapping a (1, 1, 28, 28) tensor to a (1, num_classes)
            score tensor. If it exposes ``parameters()``, the input is
            created on the device of its first parameter.
        target_class: Index of the output unit to maximize.
        num_iterations: Number of Adam steps to take.
        learning_rate: Adam learning rate for the input.

    Returns:
        The optimized input as a detached tensor of shape (1, 1, 28, 28).
    """
    # Put the image on the same device as the model's weights so the function
    # also works for a model living on the GPU.
    params = model.parameters() if hasattr(model, 'parameters') else iter(())
    first_param = next(params, None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Start with random noise
    input_tensor = torch.randn(1, 1, 28, 28, device=device, requires_grad=True)

    optimizer = optim.Adam([input_tensor], lr=learning_rate)

    for _ in range(num_iterations):
        optimizer.zero_grad()

        # Forward pass
        output = model(input_tensor)

        # Define loss as negative activation of target class
        loss = -output[0, target_class]

        # Differentiate with respect to the input only; this avoids
        # accumulating gradients on the model parameters as a side effect.
        (input_tensor.grad,) = torch.autograd.grad(loss, input_tensor)

        # Update input
        optimizer.step()

        # A regularizer (e.g. clamping the input to a valid pixel range)
        # could be applied at this point; none is applied.

    # Hand back a plain tensor so callers do not drag the autograd state along.
    return input_tensor.detach()

# Generate a batch of synthetic inputs that the model scores as `target_class`.
target_class = 5
num_optimized_inputs = 64
model = model.to('cpu')
optimized_inputs = [
    activation_maximization(model, target_class, num_iterations=200, learning_rate=0.1)
    for _ in range(num_optimized_inputs)
]
# Stack the individual results into a single batch tensor. Each result already
# has a leading dimension of 1, so the stacked shape is (N, 1, 1, 28, 28); the
# model flattens everything after dimension 0, so this is accepted as-is.
optimized_input = torch.stack(optimized_inputs)
print(optimized_input.shape)
# Run the model once and reuse the predictions for both printouts.
with torch.no_grad():
    predicted_classes = torch.argmax(model(optimized_input), 1)
print(predicted_classes)
# Fraction of synthetic inputs classified as the target class.
accuracy = (predicted_classes == target_class).sum().item() / num_optimized_inputs
print(f'Accuracy: {accuracy}')

torch.Size([64, 1, 1, 28, 28])
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5])
Accuracy: 1.0


In [599]:
# freeze the model weights
# Freeze row 5 of the weight matrices at parameter positions 2 and 4
# (fc2.weight and fc3.weight in Net's registration order).
#
# `requires_grad` can only be switched for a whole leaf tensor. Indexing a
# parameter (`param[5]`) produces a temporary tensor, so setting the flag on
# it never freezes the parameter itself. Instead, a gradient hook zeroes the
# gradient of the frozen row on every backward pass.
# NOTE(review): a zero gradient keeps the row fixed only for optimizers whose
# update is zero for an always-zero gradient (e.g. a freshly created Adam
# without weight decay) -- confirm if the optimizer setup changes.
FROZEN_ROW = 5


def _zero_frozen_row_grad(grad):
    """Return a copy of ``grad`` with the frozen row set to zero."""
    masked = grad.clone()
    masked[FROZEN_ROW] = 0
    return masked


# Remove hooks left over from an earlier run of this cell so they do not pile up.
for handle in globals().get('freeze_handles', []):
    handle.remove()

freeze_handles = [
    param.register_hook(_zero_frozen_row_grad)
    for idx, param in enumerate(model.parameters())
    if idx == 2 or idx == 4
]

In [607]:
# train the model with the optimized inputs
# One gradient step that relabels every synthesized input as class 2, followed
# by a test-set evaluation to see how much ordinary accuracy is affected.
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()
for epoch in range(1):

    optimizer.zero_grad()
    # One target per synthesized input; sized from the batch instead of a
    # hard-coded count so it keeps working if the number of inputs changes.
    relabel_targets = torch.full((optimized_input.size(0),), 2, dtype=torch.long)
    # Only the model is being trained here, so cut the inputs off from autograd.
    output = model(optimized_input.detach())
    loss = nn.functional.cross_entropy(output, relabel_targets)
    loss.backward()
    optimizer.step()

    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f'Epoch {epoch}', total=len(test_loader)):
            x, y = batch['image'], batch['label']
            x = x.float()
            output = model(x)
            _, predicted = torch.max(output, 1)
            total += y.size(0)
            correct += (predicted == y).sum().item()

    print(f'Epoch {epoch}: Accuracy: {100 * correct / total}')

Epoch 0:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 0: Accuracy: 96.83


In [608]:
# get an image labeled 5 from the test set
# Scan from the end and stop at the first hit. This selects the same example
# as a full forward scan that keeps overwriting `image` (i.e. the last
# matching one), without visiting the whole split or indexing a match twice.
image = None
test_split = ds['test']
for i in reversed(range(len(test_split))):
    example = test_split[i]
    if example['label'] == 5:
        image = example
        break

image

{'image': tensor([[[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
          [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
          [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
          [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
          [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
          [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
          [  

In [610]:
# Class predicted for the selected test image; unsqueeze(0) adds the batch dimension.
torch.argmax(model(image['image'].unsqueeze(0).float()))

tensor(5)

In [316]:

# Analyze weight activations
# Record the output of every Conv2d / Linear layer for the optimized inputs.
with torch.no_grad():
    activations = []

    def hook(module, input, output):
        """Forward hook: store the layer output in call order."""
        activations.append(output)

    handles = [
        layer.register_forward_hook(hook)
        for layer in model.modules()
        if isinstance(layer, (nn.Conv2d, nn.Linear))
    ]

    try:
        _ = model(optimized_input)
    finally:
        # Always detach the hooks, even if the forward pass raises, so they
        # do not keep firing on later calls to the model.
        for h in handles:
            h.remove()

In [344]:
# Row 9 of the final layer's weight matrix (the weights feeding output index 9).
model.fc3.weight[9]

tensor([ 0.0601, -0.0703, -0.0121, -0.0233,  0.0022, -0.0336,  0.0276,  0.0091,
         0.0092, -0.0651,  0.0338, -0.0531,  0.0095,  0.0235, -0.0227, -0.0134,
         0.0125, -0.0599,  0.0084, -0.0638,  0.0020,  0.0477, -0.0269, -0.0369,
         0.0287,  0.0363,  0.0328, -0.0110, -0.0454, -0.0261, -0.0269, -0.0022,
        -0.0432, -0.1397,  0.0402,  0.0127,  0.0400,  0.0507,  0.0479, -0.0625,
        -0.0638,  0.0249,  0.0258,  0.0192,  0.0148, -0.0407, -0.0144, -0.0014,
        -0.0502,  0.0264,  0.0076,  0.0258, -0.0329, -0.0428,  0.0003,  0.0116,
        -0.0455, -0.0156,  0.0068,  0.0253,  0.0066,  0.0285,  0.0445,  0.0109,
         0.0374, -0.0188,  0.0094, -0.0143, -0.0316,  0.0492, -0.0227, -0.1114,
        -0.0043,  0.0495,  0.0018,  0.0041, -0.0307, -0.0042,  0.0187, -0.0408,
         0.0363,  0.0211, -0.0242,  0.0292, -0.0448, -0.0189, -0.0184, -0.0389,
         0.0064, -0.0005,  0.0299,  0.0192,  0.0176, -0.0173,  0.0165, -0.0337,
        -0.0124,  0.0345, -0.0330, -0.06

In [318]:
# First recorded layer output, first input in the batch, unit 1.
activations[0][0][1]

tensor(310.0391)

In [319]:
import numpy as np


def get_top_neurons(activations, top_k=20):
    """Return, per layer, the indices of the most strongly activated units.

    Args:
        activations: Iterable of tensors whose first dimension is the batch;
            all remaining dimensions are flattened into one unit axis.
        top_k: Percentage (0-100) of each layer's units to keep. The count is
            rounded down; values above 100 keep every unit.

    Returns:
        A list with one index tensor per layer, of shape (batch, k), ordered
        from the largest activation downwards.
    """
    top_neurons = []
    for layer_activation in activations:
        # Flatten the activation to handle both conv and linear layers.
        # reshape (rather than view) also accepts non-contiguous tensors.
        flat_activation = layer_activation.reshape(layer_activation.size(0), -1)
        num_units = flat_activation.size(1)

        # Multiply before dividing: top_k / 100 * n can land just below the
        # intended integer (e.g. 29 / 100 * 100 -> 28.999...) and lose a unit.
        k = int(top_k * num_units / 100)
        # topk cannot return more entries than there are units.
        k = min(k, num_units)

        # Get indices of top k activated neurons
        _, top_indices = torch.topk(flat_activation, k=k, dim=1)

        top_neurons.append(top_indices)

    return top_neurons


# `activations` is the list filled by the forward hooks in the earlier cell;
# keep the default top 20% of units for each recorded layer.
top_neurons = get_top_neurons(activations)

In [320]:
# Inspect the selected unit indices for each recorded layer.
top_neurons

[tensor([[  67,   51,  914,  687,  103,  379,  468, 1017,  344,  617,   43,  801,
          1010,  623,  292,   84,   54,  216,  817,  824,  128,  313,  869,  252,
           999,  433,  951,  353,  883,  449,  131,  629,  868,  473,   81,  611,
           907,   68,  395,  642,  638,  725,  742,  157,  119,  932,  787,  459,
            64,   17,  202,  697,  684,  714,  436,  517,  558,  893,   50,  892,
           678,  206,  691,  696,  302,  941,  630,  693,  192,  957,  182,  337,
           130,  540,  747,  297,  198,  873,  994,  250,  136,  719,  263,  398,
          1002,  143,  945,  152,  804,  773,  635,   99,  332,  634,  456,  763,
           720,  264,  964,  850,  711,  810,   13,  593,  622,  608,  457,  529,
           294,  389,  757,  830,  465,  290,  227,  180,  527,  113,  293,  399,
            85,  368,  920,   33,  110,  905,   36,  245,  967,   10,  256,  741,
           557,  703,  535,  570,  792,  562,  169,  369,  242,  969,  681,  744,
           282, 

In [321]:
def get_important_weights(model, top_neurons):
    """Collect the absolute incoming weights of the selected units.

    Args:
        model: Module whose ``nn.Conv2d`` / ``nn.Linear`` sub-modules are
            examined in ``model.modules()`` order.
        top_neurons: Per-layer index tensors, aligned with those layers.
            Only the first row (first input sample) of each tensor is used.

    Returns:
        A flat list of tensors: for every layer after the first, one
        ``weight[unit].abs()`` tensor per selected unit, in layer order.
    """
    layers = [m for m in model.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]

    important_weights = []
    # The first layer is skipped; Conv2d and Linear are handled identically
    # because indexing the weight by output unit works for both.
    for layer_idx, layer in enumerate(layers[1:], start=1):
        weights = layer.weight.data
        # Row 0 only: the ranking of the first input sample.
        for neuron_idx in top_neurons[layer_idx][0]:
            important_weights.append(weights[neuron_idx].abs())

    return important_weights


# One |weight| tensor per selected unit of every layer after the first
# (using the first input sample's ranking only).
important_weights = get_important_weights(model, top_neurons)

In [322]:
# Number of weights in the first collected tensor.
len(important_weights[0])

1024

In [323]:
def get_top_weights(important_weights, top_k=100):
    """Pick the positions of the largest entries in each weight tensor.

    Args:
        important_weights: Iterable of weight tensors.
        top_k: Maximum number of positions to keep per tensor; tensors with
            fewer elements contribute all of their positions.

    Returns:
        A list with one 1-D index tensor per input tensor. Indices refer to
        the flattened tensor and are ordered from largest value downwards.
    """

    def _largest_positions(tensor):
        flattened = tensor.view(-1)
        count = min(top_k, flattened.numel())
        return torch.topk(flattened, k=count).indices

    return [_largest_positions(tensor) for tensor in important_weights]


# For every collected weight tensor keep the positions of its largest entries
# (up to the default of 100 per tensor).
top_weights = get_top_weights(important_weights)

In [324]:
# Inspect the selected weight positions.
top_weights

[tensor([ 216,  916,   63, 1016,  994,  687,  791,  239,  419,  869,  402,  459,
          462,   54,  175,  533,  455,  528,  775,  914,  272,  517,  626,  466,
          101,  194,  292,  828,  888,  662,  508,   17,  369,  978,  608,  221,
          493,  592, 1019,  856,  229,  129,  635,  827,  724,  544,  252,  521,
          874,  133,  273,  122,  232,  606,   47,  324,  172,  399,  877,  667,
          317,  227,  717,  696,  825,  813,  829,  192,  469, 1007,  602,  852,
          546,   20,  500,  207,  416,  766,  758,  407,  423,  495,  583,  197,
          392,  157,  817,  182,  671,  629,   37,  967,  263,   82,  473,  127,
          941,  375,  386,  141]),
 tensor([  67,  103,  686,  804, 1010,   51,  282,  107,    5,   80,  553,  914,
          975,  379,  916,  333,  129,  187,  623,  273,  143,  104,  665,  542,
          837,  239,  677,  175,  468,  152,   46,  728,  687,  483,  969,  202,
          354,  877,   21,  978,  416,  313,  948,  133,  720, 1017,  890,

In [325]:
# Collect the Conv2d / Linear modules of the model in `model.modules()` order.
layers = [m for m in model.modules() if isinstance(
    m, (nn.Conv2d, nn.Linear))]

In [326]:
# Number of positions kept in the first entry of `top_weights`.
top_weights[0].size(0)

100

In [327]:
# Size of dimension 1 of the second collected layer's weight matrix.
layers[1].weight.size(1)

1024

In [328]:
def _describe_weight(layer, weight_idx, output_idx=None):
    """Format one report line for a selected weight of ``layer``.

    When ``output_idx`` is None the output unit / filter is derived from
    ``weight_idx`` by integer division; otherwise ``output_idx`` is reported.
    """
    weights = layer.weight.data
    if isinstance(layer, nn.Conv2d):
        per_channel = weights.size(2) * weights.size(3)
        per_filter = weights.size(1) * per_channel
        filter_idx = weight_idx // per_filter if output_idx is None else output_idx
        channel_idx = (weight_idx % per_filter) // per_channel
        return f"  Important weight in filter {filter_idx}, channel {channel_idx}"

    input_idx = weight_idx % weights.size(1)
    if output_idx is None:
        output_idx = weight_idx // weights.size(1)
    return f"  Important weight connecting input {input_idx} to output {output_idx}"


def visualize_important_weights(model, top_weights, top_neurons=None):
    """Print the selected weights of every Conv2d / Linear layer after the first.

    Args:
        model: Module whose ``nn.Conv2d`` / ``nn.Linear`` sub-modules are
            reported in ``model.modules()`` order; the first one is skipped.
        top_weights: List of 1-D index tensors.
        top_neurons: Optional per-layer unit indices (aligned with the model's
            Conv2d / Linear layers, first row used). When given,
            ``top_weights`` is read as one entry per selected unit, grouped
            layer by layer, and each index is a position inside that unit's
            own weight tensor, so every line names the real output unit.

    When ``top_neurons`` is omitted, the i-th entry of ``top_weights`` is
    paired with the i-th reported layer and its indices are treated as
    positions in that layer's whole flattened weight matrix.
    NOTE(review): with per-unit entries that pairing always reports output 0
    and reads later layers from the wrong entries; pass ``top_neurons`` to get
    the correct attribution.
    """
    layers = [m for m in model.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]

    if top_neurons is None:
        # Legacy pairing: one entry per layer.
        for layer_idx, (layer, layer_top_weights) in enumerate(zip(layers[1:], top_weights)):
            print(f"Layer {layer_idx + 1}:")
            for weight_idx in layer_top_weights:
                print(_describe_weight(layer, weight_idx))
            print()
        return

    # Per-unit entries: walk through top_weights, consuming as many entries
    # per layer as that layer has selected units.
    cursor = 0
    for layer_idx, layer in enumerate(layers[1:], start=1):
        neuron_ids = top_neurons[layer_idx][0]
        neuron_entries = top_weights[cursor:cursor + len(neuron_ids)]
        cursor += len(neuron_ids)

        print(f"Layer {layer_idx}:")
        for neuron_idx, neuron_top_weights in zip(neuron_ids, neuron_entries):
            for weight_idx in neuron_top_weights:
                print(_describe_weight(layer, weight_idx, output_idx=neuron_idx))
        print()


# Print the report for the current model.
# NOTE(review): `top_weights` holds one entry per selected unit (indices within
# that unit's own weight row), while this call pairs entries with layers, so
# the "output" index printed below is always 0 and later layers are read from
# entries that belong to the first reported layer -- interpret with care.
visualize_important_weights(model, top_weights)

Layer 1:
  Important weight connecting input 216 to output 0
  Important weight connecting input 916 to output 0
  Important weight connecting input 63 to output 0
  Important weight connecting input 1016 to output 0
  Important weight connecting input 994 to output 0
  Important weight connecting input 687 to output 0
  Important weight connecting input 791 to output 0
  Important weight connecting input 239 to output 0
  Important weight connecting input 419 to output 0
  Important weight connecting input 869 to output 0
  Important weight connecting input 402 to output 0
  Important weight connecting input 459 to output 0
  Important weight connecting input 462 to output 0
  Important weight connecting input 54 to output 0
  Important weight connecting input 175 to output 0
  Important weight connecting input 533 to output 0
  Important weight connecting input 455 to output 0
  Important weight connecting input 528 to output 0
  Important weight connecting input 775 to output 0
  Im