In [1]:
import json
from glob import glob

import numpy as np
import pandas as pd
import torch
from torchinfo import summary
from torchvision.io import ImageReadMode, read_image
from torchvision.models import (resnet50,
                                ResNet50_Weights)
from torchvision.transforms import (Resize,
                                    Normalize,
                                    CenterCrop)


In [2]:
# Build one normalized image batch from every JPEG in ex8_data/:
# resize to 232x232, center-crop to 224x224, scale uint8 pixels to [0, 1],
# then standardize each channel with the given per-channel mean/std.
# NOTE(review): Resize((232, 232)) forces a square, aspect-distorting resize;
# torchvision's reference transform for these weights resizes the shorter
# side to 232 instead (Resize(232)) — confirm the square resize is intended.
resize = Resize((232, 232))
crop = CenterCrop(224)
normalize = Normalize([0.485, 0.456, 0.406],
                      [0.229, 0.224, 0.225])
imgfiles = sorted(glob("ex8_data/*.jpg"))
if not imgfiles:
    # torch.stack([]) would fail with an unhelpful message; say what is wrong.
    raise FileNotFoundError("no .jpg files found in ex8_data/")
# Decode as 3-channel RGB so a grayscale JPEG (1 channel) cannot break
# torch.stack or the 3-channel Normalize below.
imgs = torch.stack([torch.div(crop(resize(read_image(f, mode=ImageReadMode.RGB))), 255)
                    for f in imgfiles])

imgs = normalize(imgs)
imgs.size()

torch.Size([10, 3, 224, 224])

In [3]:
# Load ResNet-50 with torchvision's default pretrained weights and show a
# per-layer table of input shape, output shape and parameter count for `imgs`.
resnet_model = resnet50(weights=ResNet50_Weights.DEFAULT)
# summary() performs a real forward pass on `imgs`. A freshly constructed
# nn.Module is in training mode, and torchinfo releases that do not switch to
# eval themselves would then update BatchNorm running statistics from this
# small batch, altering the pretrained model. Switch to eval mode first.
resnet_model.eval()
summary(resnet_model, input_data=imgs,
        col_names=['input_size', 'output_size', 'num_params'])

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
ResNet                                   [10, 3, 224, 224]         [10, 1000]                --
├─Conv2d: 1-1                            [10, 3, 224, 224]         [10, 64, 112, 112]        9,408
├─BatchNorm2d: 1-2                       [10, 64, 112, 112]        [10, 64, 112, 112]        128
├─ReLU: 1-3                              [10, 64, 112, 112]        [10, 64, 112, 112]        --
├─MaxPool2d: 1-4                         [10, 64, 112, 112]        [10, 64, 56, 56]          --
├─Sequential: 1-5                        [10, 64, 56, 56]          [10, 256, 56, 56]         --
│    └─Bottleneck: 2-1                   [10, 64, 56, 56]          [10, 256, 56, 56]         --
│    │    └─Conv2d: 3-1                  [10, 64, 56, 56]          [10, 64, 56, 56]          4,096
│    │    └─BatchNorm2d: 3-2             [10, 64, 56, 56]          [10, 64, 56, 56]          128
│    │    └─ReLU: 3-3      

In [4]:
# Put the network in inference mode (affects layers such as BatchNorm that
# behave differently during training). train(False) is exactly what eval()
# does and likewise returns the module, so the cell still displays its repr.
resnet_model.train(False)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [5]:
# Forward the whole batch through the network to get raw class scores
# (converted to probabilities in the next cell). No gradients are needed,
# so disable autograd to avoid recording a computation graph and save memory.
with torch.no_grad():
    img_preds = resnet_model(imgs)

In [6]:
# Turn each image's raw scores into class probabilities with a row-wise
# softmax (axis 1 = classes). Subtracting each row's maximum before
# exponentiating leaves the result mathematically unchanged but keeps
# np.exp from overflowing to inf on large scores.
img_logits = np.asarray(img_preds.detach())
img_probs = np.exp(img_logits - img_logits.max(axis=1, keepdims=True))
img_probs /= img_probs.sum(axis=1, keepdims=True)

In [7]:
# Load the class-index JSON and build a DataFrame mapping integer class
# index -> human-readable label, sorted by index so row i matches column i
# of img_probs. Keys are string indices; the label is element [1] of each
# value (presumably element [0] is a synset id — not used here).
# Use a context manager so the file handle is closed deterministically.
with open('ex8_data/imagenet_class_index.json') as labels_file:
    labs = json.load(labels_file)
class_labels = (pd.DataFrame([(int(k), v[1]) for k, v in labs.items()],
                             columns=['idx', 'label'])
                .set_index('idx')
                .sort_index())

In [8]:
# For each image, attach its probability vector to the label table, keep
# the five most probable classes, and print them without the index column.
for i, imgfile in enumerate(imgfiles):
    img_df = (class_labels
              .assign(prob=img_probs[i])
              .sort_values(by='prob', ascending=False)
              .head(5))
    print(f'Image: {imgfile}')
    print(img_df.reset_index().drop(columns=['idx']))

Image: ex8_data/bird.jpg
       label      prob
0   lorikeet  0.364869
1  bee_eater  0.064377
2  goldfinch  0.005395
3     toucan  0.003424
4      macaw  0.002890
Image: ex8_data/chameleon.jpg
               label      prob
0  African_chameleon  0.317990
1       green_lizard  0.005252
2      common_iguana  0.002816
3              agama  0.002479
4     frilled_lizard  0.002262
Image: ex8_data/deer.jpg
        label      prob
0     gazelle  0.329120
1      impala  0.222555
2  hartebeest  0.013309
3        ibex  0.008839
4       llama  0.003618
Image: ex8_data/fox.jpg
      label      prob
0   red_fox  0.386165
1   kit_fox  0.110565
2  grey_fox  0.013141
3     dhole  0.002593
4    coyote  0.002030
Image: ex8_data/giraffe.jpg
     label      prob
0  cheetah  0.387773
1  leopard  0.077544
2  gazelle  0.015461
3    hyena  0.011720
4    zebra  0.007514
Image: ex8_data/horse.jpg
                 label      prob
0               sorrel  0.352431
1           hartebeest  0.061019
2                