In [242]:
import torch, os
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import random
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.optim as optim
import os
import io
from PIL import Image
import glob

In [243]:
class Net(nn.Module):
    def __init__(self, num_deep_layers=3):
        super(Net, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=48, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=48, out_channels=64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=256, kernel_size=5, padding=1)
        self.conv4 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=5, padding=1)
        
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.fc1 = nn.Linear(in_features=256, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=256)
        
        self.fc_final = nn.Linear(in_features=256, out_features=33)

    def forward(self, x): 

        x = F.relu(self.conv1(x))
        x = self.pool(x)
        
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        
        x = F.relu(self.conv3(x))
        x = self.pool(x)
        
        x = F.relu(self.conv4(x))
        x = self.pool(x)
        
        x = F.avg_pool2d(x, kernel_size=x.shape[2:])
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc_final(x)
        
        return x

In [244]:
net = Net()
net.load_state_dict(torch.load(f"models_13/model38.pth"))
if torch.cuda.is_available():
    net = net.cuda()

In [245]:
net_altered = Net()
net_altered.load_state_dict(torch.load(f"models_13/model38.pth"))
if torch.cuda.is_available():
    net_altered = net_altered.cuda()

## Set the weights of a filter to 0 here:
with torch.no_grad():
    net_altered.conv1.weight[0].copy_(torch.zeros_like(net_altered.conv1.weight[0]))
    net_altered.conv1.weight[4].copy_(torch.zeros_like(net_altered.conv1.weight[4]))
    net_altered.conv2.weight[13].copy_(torch.zeros_like(net_altered.conv2.weight[13]))
    net_altered.conv2.weight[28].copy_(torch.zeros_like(net_altered.conv2.weight[28]))
    net_altered.conv3.weight[3].copy_(torch.zeros_like(net_altered.conv3.weight[3]))
    net_altered.conv3.weight[19].copy_(torch.zeros_like(net_altered.conv3.weight[19]))
    net_altered.conv4.weight[14].copy_(torch.zeros_like(net_altered.conv4.weight[14]))    
    net_altered.conv4.weight[21].copy_(torch.zeros_like(net_altered.conv4.weight[21]))

In [246]:
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5, 0.5, 0.5])])

test_data_dir = '4/test'

testset = torchvision.datasets.ImageFolder(root= test_data_dir, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=16,
                                         shuffle=False, num_workers=2)

In [247]:
class ImageFolderWithPaths(torchvision.datasets.ImageFolder):
    """Custom dataset that includes image file paths. Extends
    torchvision.datasets.ImageFolder
    """

    # override the __getitem__ method. this is the method that dataloader calls
    def __getitem__(self, index):
        # this is what ImageFolder normally returns 
        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
        # the image file path
        path = self.imgs[index][0]
        # make a new tuple that includes original and the path
        tuple_with_path = (original_tuple + (path,))
        return tuple_with_path

testset_new = ImageFolderWithPaths(root= test_data_dir, transform=transform)
testloader_new = torch.utils.data.DataLoader(testset_new, batch_size=16,
                                         shuffle=False, num_workers=2)

In [248]:
########################################################################
# Let us look at how the network performs on the test dataset.

def test(testloader, model):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in (testloader):
            images, labels = data
            if torch.cuda.is_available():
                images, labels = images.cuda(), labels.cuda()        
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return 100*correct/total

#########################################################################
# get details of classes and class to index mapping in a directory
def find_classes(dir):
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx


def classwise_test(testloader, model):
########################################################################
# class-wise accuracy

    classes, _ = find_classes(test_data_dir)
    n_class = len(classes) # number of classes

    class_correct = list(0. for i in range(n_class))
    class_total = list(0. for i in range(n_class))
    with torch.no_grad():
        for data in tqdm(testloader):
            images, labels = data
            if torch.cuda.is_available():
                images, labels = images.cuda(), labels.cuda()        
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels).squeeze()
            for i in range(4):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1
    res = {}
    for i in range(n_class):
        res[classes[i]] = 100 * class_correct[i] / class_total[i]
    return res

In [249]:
orig_acc = test(testloader,net)
altered_acc = test(testloader,net_altered)
print(f"Accuracy on original model: {orig_acc}%")
print(f"Accuracy on changed model: {altered_acc}%")

Accuracy on original model: 45.666666666666664%
Accuracy on changed model: 42.333333333333336%


In [250]:
orig_classwise_acc = classwise_test(testloader, net)
altered_classwise_acc = classwise_test(testloader, net_altered)

100%|██████████| 207/207 [00:02<00:00, 92.24it/s]
100%|██████████| 207/207 [00:02<00:00, 93.91it/s] 


In [251]:
def get_change_in_acc():
    classes, _ = find_classes(test_data_dir)
    class_list = []
    for c in classes:
        orig = orig_classwise_acc[c]
        altered = altered_classwise_acc[c]
        if altered < orig:
            print(f"Class: {c}")
            print(f"Old accuracy: {orig}")
            print(f"New accuracy: {altered}")
            print()
            class_list.append(c)
    return ", ".join(class_list)

classes_changed = get_change_in_acc()

Class: Ibizan_hound
Old accuracy: 17.857142857142858
New accuracy: 14.285714285714286

Class: beer_bottle
Old accuracy: 54.166666666666664
New accuracy: 45.833333333333336

Class: bolete
Old accuracy: 41.666666666666664
New accuracy: 37.5

Class: boxer
Old accuracy: 42.857142857142854
New accuracy: 17.857142857142858

Class: electric_guitar
Old accuracy: 12.5
New accuracy: 4.166666666666667

Class: file
Old accuracy: 42.857142857142854
New accuracy: 35.714285714285715

Class: garbage_truck
Old accuracy: 37.5
New accuracy: 33.333333333333336

Class: gordon_setter
Old accuracy: 60.714285714285715
New accuracy: 57.142857142857146

Class: hair_slide
Old accuracy: 41.666666666666664
New accuracy: 20.833333333333332

Class: house_finch
Old accuracy: 50.0
New accuracy: 33.333333333333336

Class: komondor
Old accuracy: 21.428571428571427
New accuracy: 17.857142857142858

Class: malamute
Old accuracy: 62.5
New accuracy: 54.166666666666664

Class: pencil_box
Old accuracy: 37.5
New accuracy: 33.3

In [252]:
classes_changed

'Ibizan_hound, beer_bottle, bolete, boxer, electric_guitar, file, garbage_truck, gordon_setter, hair_slide, house_finch, komondor, malamute, pencil_box, prayer_rug, reel, stage, tile_roof, tobacco_shop, trifle'

In [253]:
files = glob.glob("4/test/*/*")

In [254]:
def read_img(img_path):
    with open(img_path, 'rb') as f:
        image_bytes = f.read()
        image = Image.open(io.BytesIO(image_bytes))
    return np.asarray(image)

In [255]:
def transform_image(image):
    transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5, 0.5, 0.5])])
    img_copy = image.copy()
    return transform(img_copy).unsqueeze(0)

In [256]:
def get_prediction(model, image):
    tensor = transform_image(image)
    tensor = tensor.cuda()
    output = model.forward(tensor)
    _, predicted = torch.max(output.data, 1)
    return predicted

In [257]:
def find_classes(dir):
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx

classes, class_to_idx = find_classes(test_data_dir)

def get_differences(testloader, model, model_altered):
    res = []
    with torch.no_grad():
        for data in (testloader):
            images, labels, paths = data
            if torch.cuda.is_available():
                images, labels = images.cuda(), labels.cuda()        
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            outputs = model_altered(images)
            _, predicted_altered = torch.max(outputs.data, 1)
            for i in range(len(labels)):
                label = labels[i]
                orig_pred = predicted[i]
                new_pred = predicted_altered[i]
                if not classes[orig_pred.cpu().item()] in classes_changed:
                    continue
                if label==orig_pred and new_pred != orig_pred:
                    res.append({
                        "fpath": paths[i],
                        "old_pred": classes[orig_pred.cpu().item()],
                        "new_pred": classes[new_pred.cpu().item()],
                    })
    return res

In [258]:
res = get_differences(testloader_new, net, net_altered)

In [259]:
import random
random.shuffle(res)

In [None]:
if len(res)>=4:
    f, axarr = plt.subplots(2,2, figsize=(10,10))
    for i in range(2):
        for j in range(2):
            axarr[i,j].imshow(read_img(res[j*2 + i]["fpath"]))
            axarr[i,j].set_xlabel(f"{res[j*2 + i]['old_pred']} -> {res[j*2 + i]['new_pred']}")
    plt.show()
else:
    f, axarr = plt.subplots(1,2, figsize=(10,10))
    for i in range(2):
        axarr[i].imshow(read_img(res[i]["fpath"]))
        axarr[i].set_xlabel(f"{res[i]['old_pred']} -> {res[i]['new_pred']}")
    plt.show()