In [28]:
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
import os
import matplotlib.pyplot as plt
import ast

In [67]:
# Pretrained AlexNet from the torchvision hub repo pinned below.
# force_reload=True makes torch.hub fetch the repo again on every run.
HUB_REPO = 'pytorch/vision:v0.8.2'
model = torch.hub.load(HUB_REPO, 'alexnet', pretrained=True, force_reload=True)
# Switch to evaluation mode; as the last expression this also displays the architecture.
model.eval()

Downloading: "https://github.com/pytorch/vision/archive/v0.8.2.zip" to C:\Users\mgina/.cache\torch\hub\v0.8.2.zip


AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [87]:
# Model tensors keyed by name (e.g. 'features.0.weight'), kept for inspecting layer shapes.
all_params = model.state_dict()

In [92]:
# Shape of the first convolution's weight tensor
all_params['features.0.weight'].size()

torch.Size([64, 3, 11, 11])

In [68]:
# A simple hook class that records the output of a layer on every forward pass
class Hook():
    """Record the output of ``module`` each time it runs a forward pass.

    Parameters
    ----------
    module
        The layer to observe; must expose ``register_forward_hook``.
    backward
        Accepted for interface compatibility but not used: only a forward
        hook is registered.

    After construction, ``outputs`` holds one entry per forward pass (in call
    order) and ``hook`` holds the handle that ``close`` removes.
    """
    def __init__(self, module, backward=False):
        # Create the storage first so hook_fn can never run against a
        # half-initialised object.
        self.outputs = []
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Forward-hook callback: store a private snapshot of ``output``."""
        # Keep a detached copy rather than the live tensor. A following
        # in-place op (e.g. ReLU(inplace=True)) overwrites the live tensor,
        # which would silently turn a stored pre-activation into the
        # post-activation values.
        if torch.is_tensor(output):
            output = output.detach().clone()
        self.outputs.append(output)

    def close(self):
        """Unregister the hook; already recorded outputs are kept."""
        self.hook.remove()

In [69]:
# Attach a recording Hook to every layer: iterable children (containers)
# contribute each of their members, any other child is hooked directly.
hooks = []
for child in model._modules.values():
    members = child if hasattr(child, '__iter__') else [child]
    hooks.extend(Hook(member) for member in members)

In [70]:
# Load data: one sub-folder of `path` per category
path='../Image/All_Cropped'

directory=['Mountain', 'Beach', 
           'Mug', 'Banana', 
           'Car', 'Plane', 
           'Lighthouse', 'Church']

# category name -> list of JPEG file names found in that category's folder
images={}
for category in directory:
    # os.listdir returns entries in arbitrary order; sorting keeps the row
    # order of every downstream result reproducible across machines.
    images[category]=sorted(
        name for name in os.listdir(os.path.join(path, category))
        if name.endswith(('jpeg', 'jpg'))
    )

In [71]:
# Per-channel statistics used for normalisation
# (this should be the mean and std for alexnet training dataset)
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# No crop/resize step is applied (CenterCrop(224) is intentionally left out):
# images are only converted to tensors and normalised.
preprocess_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
]
preprocess = transforms.Compose(preprocess_steps)

In [72]:
# Readable class names: one stripped entry per line of the file
with open("../imagenet_classes.txt", "r") as label_file:
    categories = [line.strip() for line in label_file]

# Label -> WordNet synset table, transposed after loading
ltw = pd.read_json('../imagenet_label_to_wordnet_synset.json').T
# Labels selected for this study (first CSV column is the index)
selected_labels = pd.read_csv('../selected_labels.csv', index_col=0)

In [73]:
# Category type -> WordNet id of the label selected for it
id_mapping = dict(zip(selected_labels['type'], selected_labels['id']))
id_mapping

{'Banana': '07753592-n',
 'Beach': '09428293-n',
 'Car': '02814533-n',
 'Church': '03028079-n',
 'Lighthouse': '02814860-n',
 'Mountain': '09193705-n',
 'Mug': '03063599-n',
 'Plane': '02690373-n'}

In [74]:
# Category type -> human-readable text of the label selected for it
readable_mapping = dict(zip(selected_labels['type'], selected_labels['label']))
readable_mapping

{'Banana': 'banana',
 'Beach': 'seashore, coast, seacoast, sea-coast',
 'Car': 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
 'Church': 'church, church building',
 'Lighthouse': 'beacon, lighthouse, beacon light, pharos',
 'Mountain': 'alp',
 'Mug': 'coffee mug',
 'Plane': 'airliner'}

In [75]:
# Column-wise accumulator for the per-image results: one empty list per
# output column, filled in step by the scoring loop.
all_img = {
    column: []
    for column in ('type', 'dir', 'prob', 'id_label', 'readable_label', 'raw')
}

In [76]:
# Score every image against the single label selected for its category and
# append one row per image to `all_img`:
#   prob = softmax probability of that label, raw = its pre-softmax score.
# Each image goes through exactly one forward pass, so every registered
# forward hook records exactly one output per image.

# Pick the device and move the model once, not once per image.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# WordNet id -> row label of `ltw`, which is also the model's output index.
# Replaces building two 1000-entry dicts per image just to read one value.
# Assumes the ids in `ltw` are unique (true for ImageNet synsets) - TODO confirm.
class_index_by_id = {wn_id: int(idx) for idx, wn_id in ltw['id'].items()}

for cate, filenames in images.items():
    selected_id_label = id_mapping[cate]
    selected_readable_label = readable_mapping[cate]
    class_idx = class_index_by_id[selected_id_label]
    for f in filenames:
        img_path = os.path.join(path, cate, f)
        # Close the file promptly and force 3 channels: Normalize is given
        # three means/stds and fails on grayscale or RGBA inputs.
        with Image.open(img_path) as img:
            input_tensor = preprocess(img.convert('RGB'))
        input_batch = input_tensor.unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(input_batch)

        logits = output[0]
        probabilities = torch.nn.functional.softmax(logits, dim=0)

        # Exactly one append per column per image keeps all lists aligned.
        # The category name itself is the type; no substring matching needed.
        all_img['type'].append(cate)
        all_img['dir'].append(img_path)
        all_img['prob'].append(probabilities[class_idx].item())
        all_img['id_label'].append(selected_id_label)
        all_img['readable_label'].append(selected_readable_label)
        all_img['raw'].append(logits[class_idx].item())

In [83]:
# For each hook: how many outputs it recorded, then the shape of the first one
for hook in hooks:
    recorded = hook.outputs
    print(len(recorded), recorded[0].shape, sep='\n')

128
torch.Size([1, 64, 63, 63])
128
torch.Size([1, 64, 63, 63])
128
torch.Size([1, 64, 31, 31])
128
torch.Size([1, 192, 31, 31])
128
torch.Size([1, 192, 31, 31])
128
torch.Size([1, 192, 15, 15])
128
torch.Size([1, 384, 15, 15])
128
torch.Size([1, 384, 15, 15])
128
torch.Size([1, 256, 15, 15])
128
torch.Size([1, 256, 15, 15])
128
torch.Size([1, 256, 15, 15])
128
torch.Size([1, 256, 15, 15])
128
torch.Size([1, 256, 7, 7])
128
torch.Size([1, 256, 6, 6])
128
torch.Size([1, 9216])
128
torch.Size([1, 4096])
128
torch.Size([1, 4096])
128
torch.Size([1, 4096])
128
torch.Size([1, 4096])
128
torch.Size([1, 4096])
128
torch.Size([1, 1000])


In [43]:
# Turn the column-wise accumulator into a table and preview the first rows
all_img_df = pd.DataFrame(data=all_img)
all_img_df.head(5)

Unnamed: 0,type,dir,prob,id_label,readable_label,raw
0,Mountain,../Image/All_Cropped\Mountain\1.jpg,0.336247,09193705-n,alp,10.06924
1,Mountain,../Image/All_Cropped\Mountain\10.jpg,0.168704,09193705-n,alp,10.357352
2,Mountain,../Image/All_Cropped\Mountain\11.jpg,0.078611,09193705-n,alp,10.654603
3,Mountain,../Image/All_Cropped\Mountain\12.jpg,0.504526,09193705-n,alp,12.633675
4,Mountain,../Image/All_Cropped\Mountain\13.jpg,0.0062,09193705-n,alp,9.846738


In [44]:
# all_img_df.to_csv('alexnet_scores_wordnet_id_all_prob.csv')
# all_img_df.to_csv('alexnet_scores_final.csv')
# Write the per-image scores to a CSV file in the working directory
scores_csv = 'alexnet_scores_final_with_raw.csv'
all_img_df.to_csv(scores_csv)

In [45]:
# all_img_df['label'] = all_img_df['top5_labels'].apply(lambda x: list(x.keys())[0])

In [46]:
# all_img_df['7_score']=round(all_img_df['typicality_score']*7/100,1)
# all_img_df.head()

In [47]:
# score=[]
# for i in range(5):
#     dir=all_img_df['dir'][i]
#     score.append([all_img_df['typicality_score'][i], all_img_df['7_score'][i]])
#     img=Image.open(dir)
#     plt.imshow(img)
#     plt.title(score[i])
#     plt.show()

In [57]:
# all_img_df.to_csv('alexnet_scores_wordnet_id.csv')