In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.models import resnet18

import matplotlib.pyplot as plt
import numpy as np
import shap
from PIL import Image

from lime import lime_image
from skimage.segmentation import mark_boundaries

from copy import deepcopy 


%load_ext autoreload
%autoreload 2

In [None]:
# Run on a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Report which device models and batches will be placed on.
print(f"{device}")

In [None]:
class LeNet(nn.Module):
    """LeNet-5 style CNN: two (5x5 conv -> ReLU -> 2x2 max-pool) stages, then three FC layers.

    Args:
        input_shape: side length of the square input images (e.g. 224).
        num_classes: number of output logits.
        in_channels: channels of the input images; defaults to 3 (RGB), the value
            that used to be hard-coded.
    """
    def __init__(self, input_shape, num_classes, in_channels=3):
        super(LeNet, self).__init__()
        # in_channels input channels -> 6 output channels, 5x5 square conv kernel
        self.conv1 = nn.Conv2d(in_channels, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)

        # Spatial side after both stages: each unpadded 5x5 conv trims 4 pixels
        # and each 2x2 max-pool halves the side (flooring), e.g. 32 -> 5, 224 -> 53.
        sh = ((input_shape - 4) // 2 - 4) // 2
        self.fc1 = nn.Linear(16 * sh * sh, 120)  # 16 feature maps of sh x sh
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten everything except the batch dimension (also safe for an empty batch).
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

## Baseline training

FashionMNIST has 10 classes, with 60k training samples and 10k test samples.

The Oxford-IIIT Pet dataset used below has 37 classes, with 3,680 training (`trainval`) images and 3,669 test images.

In [None]:
# Define transformations: PIL image -> C x H x W float tensor, then resize every
# image to a fixed 224 x 224 so they can be batched together.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
])

# Load the Oxford-IIIT Pet dataset ('trainval' is its training split)
train_dataset = datasets.OxfordIIITPet(root="./data", split='trainval', transform=transform, download=True)
test_dataset = datasets.OxfordIIITPet(root="./data", split='test', transform=transform, download=True)

# Create data loaders
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True, num_workers=2)
test_loader = DataLoader(dataset=test_dataset, batch_size=8, shuffle=False, num_workers=2)

# Initialize the loss function
criterion = nn.CrossEntropyLoss()

# Class names indexed by integer label, taken from the dataset itself so the order is
# guaranteed to match the targets. (A hand-written list only lines up if Python's
# case-sensitive sort happens to reproduce the dataset's label numbering.)
classes = list(train_dataset.classes)
print(classes)
print(len(classes))

In [None]:
print(f"Size of test set: {len(test_dataset)}")
print(f"Size of train set: {len(train_dataset)}")

# Show one training sample. The tensor is channel-first (C x H x W), while
# matplotlib wants channel-last, so reorder the axes to H x W x C.
sample_image, sample_label = train_dataset[100]
plt.imshow(sample_image.permute(1, 2, 0))
print(classes[sample_label])

In [None]:
def train(model, num_epochs, lr, loader=None, loss_fn=None):
    """Train `model` in place with Adam and print the mean batch loss of every epoch.

    Args:
        model: network to train; batches are moved to the device of its parameters.
        num_epochs: number of passes over `loader`.
        lr: Adam learning rate.
        loader: iterable of (images, labels) batches; defaults to the global `train_loader`.
        loss_fn: loss callable(outputs, labels); defaults to the global `criterion`.

    Returns:
        The same `model` object (trained in place), for chaining.
    """
    if loader is None:
        loader = train_loader
    if loss_fn is None:
        loss_fn = criterion

    print(f"Number of batches per epoch: {len(loader)}.")
    optimizer = optim.Adam(model.parameters(), lr)
    model_device = next(model.parameters()).device

    for epoch in range(num_epochs):
        # re-enable training mode every epoch in case something switched to eval()
        model.train()
        running_loss = 0.0

        for images, labels in loader:
            optimizer.zero_grad()
            outputs = model(images.to(model_device))

            loss = loss_fn(outputs, labels.to(model_device))
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        # max(..., 1) keeps an empty loader from raising ZeroDivisionError
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {running_loss / max(len(loader), 1):.3f}')

    return model

In [None]:
def evaluate(model, loader=None, split="Train"):
    """Print and return the top-1 accuracy of `model` over `loader`.

    By default this scores the *training* set (`train_loader`), as before. Pass
    ``loader=test_loader, split="Test"`` for held-out accuracy.

    Args:
        model: classifier whose output's argmax along dim 1 is the predicted class.
        loader: iterable of (images, labels) batches; defaults to the global `train_loader`.
        split: name used in the printed message ("<split> Accuracy: ...").

    Returns:
        Accuracy as a float in [0, 1]; 0.0 if the loader yields no samples.
    """
    if loader is None:
        loader = train_loader

    model.eval()
    # Put batches wherever the model's weights live (CPU if it has no parameters).
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(model_device)
            labels = labels.to(model_device)

            predicted = model(images).argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total if total else 0.0
    print(f'{split} Accuracy: {accuracy * 100:.2f}%')
    return accuracy

In [None]:
def checkpoint(model, name):
    """Write `model`'s state_dict to the file "<name>.pt" (relative to the working directory)."""
    target_path = f"{name}.pt"
    weights = model.state_dict()
    torch.save(weights, target_path)

In [None]:
# Train the baseline LeNet from scratch and save its weights as "original.pt".
SEED = 0
# Reproducible weight init (DataLoader shuffling also draws from torch's global RNG).
torch.manual_seed(SEED)

# One output per class in `classes` (37 breeds for Oxford-IIIT Pet).
model = LeNet(input_shape=224, num_classes=len(classes)).to(device=device)
# A stronger alternative baseline: resnet18(num_classes=len(classes)).to(device)

lr = 0.001  # also read by `tune` below, which fine-tunes at lr / 10
NUM_EPOCHS = 6

model = train(model, NUM_EPOCHS, lr=lr)

evaluate(model)

checkpoint(model, "original")

### SHAP

In [None]:
class evaluator_factory:
    """Wraps a PyTorch classifier as a numpy-in / numpy-out prediction function.

    The SHAP and LIME helpers in this notebook pass perturbed images as
    channel-last numpy arrays and consume the scores as a numpy array.
    `get_evaluator` returns a closure that performs that conversion.
    """

    def __init__(self, model, size=(224, 224)):
        """
        Args:
            model: the classifier to wrap (read at call time via ``self.model``).
            size: (height, width) tuple every image is resized to before it reaches
                the model; defaults to the 224 x 224 the notebook trains on.
        """
        self.model = model
        self.size = tuple(size)

    def get_evaluator(self,):
        """Return ``evaluator(x) -> np.ndarray`` of model outputs for image(s) ``x``.

        ``x`` is one image shaped (H, W, C) or a batch shaped (N, H, W, C).
        """
        # Built once per evaluator instead of once per image.
        to_tensor = transforms.ToTensor()
        resize = transforms.Resize(self.size)

        def evaluator(x):

            # if there's no batch dim, add one
            if len(x.shape) == 3:
                x = np.expand_dims(x, 0)

            num_images, _, _, ch = x.shape

            # The output of the transforms is channel-first. Allocate the batch at the
            # *resized* size: allocating it at the input size made the assignment
            # below fail with a shape mismatch for any input not already `self.size`.
            batch = torch.zeros((num_images, ch, *self.size))
            for i in range(num_images):
                batch[i] = resize(to_tensor(x[i]))

            # Run on the model's own device, without building an autograd graph.
            first_param = next(self.model.parameters(), None)
            model_device = first_param.device if first_param is not None else torch.device("cpu")
            with torch.no_grad():
                return self.model(batch.to(model_device)).cpu().numpy()

        return evaluator


# Pick one training image to explain. The evaluator (and the SHAP/LIME helpers
# built on it) take channel-last numpy arrays, so reorder C x H x W -> H x W x C.
item_idx = 900
x, label = train_dataset[item_idx]
x = np.transpose(x.numpy(), (1, 2, 0))

factory_obj = evaluator_factory(model)

# Ground-truth class name, then the class the model currently predicts.
print(classes[label])
print(classes[factory_obj.get_evaluator()(x).argmax()])

In [None]:
def explain_shap(x, model, max_evals):
    """Plot SHAP attributions for the model's 5 highest-scoring classes on one image.

    Args:
        x: a single channel-last image array (the evaluator built by
            `evaluator_factory` unpacks its trailing dims as height, width, channels).
        model: classifier to explain; it is switched to eval mode.
        max_evals: budget of model evaluations given to the SHAP explainer.
    """
    predict_fn = evaluator_factory(model).get_evaluator()
    # Masked regions are replaced by a heavily blurred copy of the image.
    blur_masker = shap.maskers.Image("blur(128,128)", x.shape)
    shap_explainer = shap.Explainer(predict_fn, blur_masker, output_names=classes)

    model.eval()
    one_image_batch = np.array([x])
    top5_outputs = shap.Explanation.argsort.flip[:5]
    explanation = shap_explainer(
        one_image_batch, max_evals=max_evals, batch_size=50, outputs=top5_outputs
    )

    shap.image_plot(explanation)

### LIME

In [None]:
def explain_lime(x, model, num_samples=1000, top_labels=5, num_features=15):
    """Show LIME superpixel attributions for the model's top predicted class on one image.

    Args:
        x: a single channel-last image array (the evaluator built by
            `evaluator_factory` unpacks its trailing dims as height, width, channels).
        model: classifier to explain; it is switched to eval mode.
        num_samples: number of perturbed images LIME evaluates (1000, as before).
        top_labels: how many of the highest-scoring classes LIME builds explanations for.
        num_features: number of superpixels outlined in the plot.
    """
    predict_fn = evaluator_factory(model).get_evaluator()
    explainer = lime_image.LimeImageExplainer()

    model.eval()
    explanation = explainer.explain_instance(np.array(x).astype(np.float64),
                                             predict_fn,
                                             top_labels=top_labels,
                                             hide_color=0,  # switched-off superpixels are filled with 0
                                             num_samples=num_samples)

    top_label = explanation.top_labels[0]
    # positive_only=False: also include superpixels that count against the class.
    temp, mask = explanation.get_image_and_mask(top_label,
                                                positive_only=False,
                                                num_features=num_features,
                                                hide_rest=False)
    img_boundary = mark_boundaries(temp, mask)

    fig, ax = plt.subplots()
    ax.imshow(img_boundary)
    ax.set_title(f"LIME: {classes[top_label]}")
    ax.axis("off")
    plt.show()

In [None]:
# Baseline explanations of the unpruned model for the sample image `x`.
explain_shap(x, model, max_evals=800)
explain_lime(x, model)

## Prune Once

In [None]:
def _prune_alive_slices(module, amt, dim, n=2):
    """Permanently zero a fraction `amt` of the still-non-zero slices of `module.weight` along `dim`.

    Slices that are already entirely zero (e.g. from an earlier call) are first
    re-registered as a pruning mask, so `ln_structured` ranks only the surviving
    slices and `amt` is taken relative to them. Otherwise the zeroed slices, having
    norm 0, would simply be selected again and a repeated call would prune nothing new.
    """
    weight = module.weight.detach()
    other_dims = [d for d in range(weight.dim()) if d != dim]
    # one boolean per slice along `dim`, broadcastable back to the weight's shape
    alive = weight.abs().sum(dim=other_dims, keepdim=True) != 0
    if not bool(alive.all()):
        prune.custom_from_mask(module, name="weight", mask=alive.expand_as(weight))
    # n: the norm used to score slice importance
    prune.ln_structured(module, name="weight", amount=amt, n=n, dim=dim)
    # fold the mask into the weight and drop weight_orig / weight_mask / the hook
    prune.remove(module, "weight")


def prune_once(model, amt=0.2):
    """Apply one round of L2-norm structured pruning to LeNet's layers, in place.

    - conv1, conv2: prune whole output channels (dim=0 of the conv weight).
    - fc1, fc2: prune whole input features (dim=1, i.e. columns of the weight).
    - fc3, the classifier head, is left untouched.

    `amt` is the fraction of the *still non-zero* slices removed in this round, so two
    calls with 0.35 remove ~35% and then ~35% of what remained. The pruning is made
    permanent (`prune.remove`) before returning, so the model carries no masks.
    Only `weight` is pruned. A pruned conv channel keeps its bias, so in LeNet it still
    emits a constant relu(bias) feature map.

    Returns:
        The same `model` object.
    """
    convs = [model.conv1, model.conv2]
    fcs = [model.fc1, model.fc2, model.fc3]

    for module in convs:
        # dim=0: the output-channel axis of a conv weight
        _prune_alive_slices(module, amt, dim=0)

    for module in fcs[:-1]:
        # For a fully connected layer, dim=1 means input neurons
        # dim=0 means output neurons
        _prune_alive_slices(module, amt, dim=1)

    return model

In [None]:
# Fine-tune a pruned model for a few more epochs at 1/10th of the learning rate.
# `prune_once` calls `prune.remove` right after pruning, so the model it returns
# has no forward_pre_hooks / masks: the pruned weights are just zeros inside an
# ordinary parameter. Training it naively would let the optimizer move those
# zeros and silently undo the pruning. `tune` therefore re-registers each weight's
# current zero pattern as a mask while training (masked entries get zero gradient
# and are zeroed in every forward pass), then folds it back in with `prune.remove`.

def tune(model, num_epochs):
    """Fine-tune `model` in place for `num_epochs` at lr/10, keeping zeroed weights at zero.

    Returns:
        The same `model` object, as returned by `train`.
    """
    masked_modules = []
    for module in model.modules():
        weight = getattr(module, "weight", None)
        # Only plain parameters that contain zeros; modules already under a pruning
        # reparametrization expose `weight` as a plain tensor and are skipped.
        if isinstance(weight, nn.Parameter) and not bool((weight != 0).all()):
            prune.custom_from_mask(module, name="weight", mask=(weight.detach() != 0))
            masked_modules.append(module)
    try:
        # masks must exist before `train` builds its optimizer over model.parameters()
        return train(model, num_epochs, lr=lr/10)
    finally:
        for module in masked_modules:
            prune.remove(module, "weight")

In [None]:
# Two rounds of 35% structured pruning, i.e. 1 - (1 - 0.35)**2 ~ 57.75% of the
# slices. A single call with the compounded fraction gives that total without
# relying on how `prune_once` treats slices that are already zero: a plain
# `ln_structured` pass would pick those zero-norm slices again and prune nothing new.
PRUNE_AMOUNT_PER_ROUND = 0.35
PRUNE_ROUNDS = 2
total_prune_amount = 1 - (1 - PRUNE_AMOUNT_PER_ROUND) ** PRUNE_ROUNDS

# deepcopy already copies the weights, so no extra load_state_dict is needed.
model_pruned = prune_once(deepcopy(model), amt=total_prune_amount)

# Fine-tune a separate copy so `model_pruned` stays the un-tuned reference.
model_tuned = tune(deepcopy(model_pruned), 5)

In [None]:
# Compare the original, pruned, and pruned-then-fine-tuned models: accuracy first,
# then SHAP attributions for the same sample image `x`. Each block is labelled so
# the otherwise identical "Train Accuracy" lines can be told apart.
models_to_compare = (
    ("original", model),
    ("pruned", model_pruned),
    ("pruned + fine-tuned", model_tuned),
)
for model_label, candidate in models_to_compare:
    print(f"=== {model_label} ===")
    evaluate(candidate)
    explain_shap(x, candidate, 800)

In [None]:
# LIME explanations for the same three models on `x`. All three use
# `explain_lime`'s default sampling budget, so the plots are comparable.
# (Passing 800 positionally here used to raise a TypeError.)
for model_label, candidate in (("original", model), ("pruned", model_pruned), ("pruned + fine-tuned", model_tuned)):
    print(f"=== {model_label} ===")
    explain_lime(x, candidate)

## Make the pruning stick

In [None]:
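# # Saving a model while it is still under a pruning reparametrization (i.e. before
# # `prune.remove` is called) gives a larger checkpoint than the original: the weights
# # are not actually removed, only masked in the forward pass, so the state_dict holds
# # both `weight_orig` and `weight_mask` for every pruned tensor. Note that `prune_once`
# # above already calls `prune.remove`, so the models it returns carry no masks.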

# checkpoint(model, "pruned")

# print(model.state_dict())

In [None]:
# `prune_once` already makes the pruning stick: it calls `prune.remove` after each
# pruning step, which folds the mask into the weight tensor and deletes
# `weight_orig` / `weight_mask`. Calling `prune.remove` again here would raise, because
# the layers no longer carry a mask. Instead, verify that no pruning
# reparametrization is left and report how many weights each layer has zeroed.
for model_label, candidate in (("pruned", model_pruned), ("pruned + fine-tuned", model_tuned)):
    assert not prune.is_pruned(candidate), f"{model_label} model still carries pruning masks"
    for layer_name in ("conv1", "conv2", "fc1", "fc2", "fc3"):
        layer = getattr(candidate, layer_name)
        param_names = [param_name for param_name, _ in layer.named_parameters()]
        zero_frac = (layer.weight == 0).float().mean().item()
        print(f"{model_label}.{layer_name}: params={param_names}, zeroed weights={zero_frac:.1%}")

In [None]:
# Save the final pruned + fine-tuned model (previously this saved the *unpruned*
# `model` under this name). No pruning masks are attached, so the checkpoint holds
# only the plain weights and is the same size as "original.pt": the zeros are still
# stored densely.
checkpoint(model_tuned, "pruned_final")