In [14]:
from simpsons_neural_network_1 import SimpsonsNet1

import torch

from PIL import Image
import torchvision.transforms as transforms

import torch.nn.functional as F

import os

# Select the compute device (CUDA, Apple MPS, or CPU)

In [15]:
# Pick the fastest available backend: NVIDIA CUDA first, then Apple-silicon MPS,
# and the CPU as the last resort.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# To force CPU execution, uncomment the override below. For this network the CPU
# has been observed to be faster than CUDA, possibly because the model is so small.
# device = "cpu"
print(f"Using {device} device")

Using cuda device


# Class labels and helper functions

In [16]:
# Character names indexed by class id: entry i is the label for the model's i-th
# output score. The names are in alphabetical order, which is how torchvision's
# ImageFolder numbers classes — presumably how the training labels were assigned;
# NOTE(review): confirm against the training script before reordering anything.
all_labels = """
    abraham_grampa_simpson
    agnes_skinner
    apu_nahasapeemapetilon
    barney_gumble
    bart_simpson
    carl_carlson
    charles_montgomery_burns
    chief_wiggum
    cletus_spuckler
    comic_book_guy
    disco_stu
    edna_krabappel
    fat_tony
    gil
    groundskeeper_willie
    homer_simpson
    kent_brockman
    krusty_the_clown
    lenny_leonard
    lionel_hutz
    lisa_simpson
    maggie_simpson
    marge_simpson
    martin_prince
    mayor_quimby
    milhouse_van_houten
    miss_hoover
    moe_szyslak
    ned_flanders
""".split()

def image_to_tensor(_image_path: str) -> torch.Tensor:
    """Load an image file and turn it into an input tensor on the selected device.

    The image is forced to 3-channel RGB, resized to 224x224 and converted with
    ``transforms.ToTensor``. No batch dimension is added, so the result has shape
    (3, 224, 224).

    Args:
        _image_path: Path of the image file to load.

    Returns:
        The image as a float tensor moved to the module-level ``device``.
    """
    # Random augmentations (flip / rotation) are deliberately not applied here:
    # inference must be deterministic.
    transform = transforms.Compose([transforms.Resize((224, 224)),
                                    transforms.ToTensor()])

    # Image.open keeps the file open lazily; the context manager releases it.
    with Image.open(_image_path) as image:
        # Grayscale, palette or RGBA files would otherwise give 1- or 4-channel
        # tensors. Converting to RGB mirrors torchvision's default ImageFolder loader.
        # NOTE(review): assumes the model was trained on RGB input — confirm.
        _image_tensor = transform(image.convert("RGB"))

    return _image_tensor.to(device)

def show_image_by_path(_image_path: str) -> None:
    """Display the image at ``_image_path`` using PIL's ``Image.show``.

    ``Image.show`` hands the image to an external viewer; it is not rendered
    inline in the notebook.

    Args:
        _image_path: Path of the image file to display.
    """
    # The context manager closes the file handle that Image.open keeps open.
    with Image.open(_image_path) as image:
        image.show()

# Load the model

In [18]:
run_id = "1c2ehxl7"
# NOTE(review): the trailing "_2" presumably selects a specific saved checkpoint
# (e.g. an epoch) of this run — confirm against the training script.
checkpoint_path = f"trained_models/simpsons_net_1_{run_id}_2.pth"

model = SimpsonsNet1()
model.to(device)
# map_location remaps tensors that were saved on another device (e.g. a CUDA
# checkpoint opened on a CPU/MPS-only machine) onto the selected device.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints;
# newer PyTorch versions accept weights_only=True to restrict this.
load_result = model.load_state_dict(torch.load(checkpoint_path, map_location=device))
# Inference mode: layers such as dropout / batch-norm (if SimpsonsNet1 has any)
# must not behave as in training, which is the default mode of a new module.
model.eval()
load_result

<All keys matched successfully>

# Test the model

## Single image

In [19]:
# Classify a single test image and report the most likely character.
path = "data/test/461179.jpg"

with torch.no_grad():  # pure inference: no autograd graph is needed
    input_tensor = image_to_tensor(path)
    # Expected output shape is [1, 29]: one row for this single image and one
    # score per entry of all_labels.
    output = model(input_tensor)

    probabilities = output.softmax(dim=1)
    # argmax over the flattened tensor equals the class index only because the
    # batch holds exactly one row.
    predicted_label_idx = probabilities.argmax().item()

    print(f"I guess this is {all_labels[predicted_label_idx]}")
    show_image_by_path(path)

I guess this is chief_wiggum


## All images

In [20]:
# Predict every .jpg in the test folder and write an "Id,Category" submission CSV.
root_dir = "data/test"
output_path = f"solutions/solution_{run_id}.csv"

# Create the output folder if needed so open() does not fail on a fresh checkout.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# os.listdir returns entries in arbitrary order; sorting makes the CSV row order
# reproducible between runs and machines.
image_filenames = sorted(name for name in os.listdir(root_dir) if name.endswith(".jpg"))

idx = 0

# `with` guarantees the CSV is flushed and closed even if a prediction fails
# midway; no_grad avoids building an autograd graph for every image.
with open(output_path, "w") as output_file, torch.no_grad():
    output_file.write("Id,Category\n")

    for filename in image_filenames:
        if idx % 10 == 0:
            print(f"Predicting file with index {idx}")
        idx += 1

        img_path = os.path.join(root_dir, filename)
        input_tensor = image_to_tensor(img_path)
        output = model(input_tensor)

        probabilities = F.softmax(output, dim=1)
        # Flattened argmax is the class index because each call scores one image.
        predicted_label_idx = torch.argmax(probabilities).item()

        output_file.write(f"{filename},{all_labels[predicted_label_idx]}\n")


Predicting file with index 0
Predicting file with index 10
Predicting file with index 20
Predicting file with index 30
Predicting file with index 40
Predicting file with index 50
Predicting file with index 60
Predicting file with index 70
Predicting file with index 80
Predicting file with index 90
Predicting file with index 100
Predicting file with index 110
Predicting file with index 120
Predicting file with index 130
Predicting file with index 140
Predicting file with index 150
Predicting file with index 160
Predicting file with index 170
Predicting file with index 180
Predicting file with index 190
Predicting file with index 200
Predicting file with index 210
Predicting file with index 220
Predicting file with index 230
Predicting file with index 240
Predicting file with index 250
Predicting file with index 260
Predicting file with index 270
Predicting file with index 280
Predicting file with index 290
Predicting file with index 300
Predicting file with index 310
Predicting file wit