In [89]:
import os
from typing import Tuple

import torch
from PIL import Image
import numpy as np
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt

In [30]:
def get_pretrained_model_with_trainable_last_layer(num_classes: int = 2) -> nn.Module:
    """Build an ImageNet-pretrained ResNet-34 whose only trainable layer is a fresh head.

    Every pretrained parameter is frozen (``requires_grad = False``). The final
    fully connected layer ``fc`` is then replaced by a newly constructed
    ``nn.Linear``, whose parameters require gradients by default, so it is the
    only part of the network that receives gradients.

    Args:
        num_classes: Number of output logits of the new head. Defaults to 2,
            the value this notebook originally hard-coded.

    Returns:
        The ResNet-34 model with its ``fc`` layer replaced.

    Raises:
        ValueError: If ``num_classes`` is smaller than 1.

    NOTE(review): ``requires_grad = False`` does not stop BatchNorm layers from
    updating their running statistics while the model is in ``train()`` mode.
    Keep the backbone in ``eval()`` during training if it must stay truly fixed.
    """
    # Validate before downloading/loading the pretrained weights so bad input fails fast.
    if num_classes < 1:
        raise ValueError(f'num_classes must be a positive integer, got {num_classes}')

    model = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)

    # Freeze the whole pretrained network first...
    for param in model.parameters():
        param.requires_grad = False

    # ...then swap in a new head. Its freshly created parameters have
    # requires_grad=True, so they are the only trainable ones.
    model.fc = nn.Linear(in_features=model.fc.in_features, out_features=num_classes)

    return model

In [92]:
def _read_annotations(annotations_path: str) -> list:
    """Parse an annotation list file into ``(image_name, species_id - 1)`` pairs."""
    records = []
    with open(annotations_path, encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            # Skip blank/whitespace-only lines (e.g. CRLF residue) and '#' comment lines.
            if not line or line.startswith('#'):
                continue
            fields = line.split()
            records.append((fields[0], int(fields[2]) - 1))
    return records


def load_dataset(root: str = '.') -> Tuple[torch.Tensor, torch.Tensor]:
    """Load every annotated image and its species label into memory.

    Reads ``<root>/annotations/list.txt``. Each non-empty line that does not start
    with ``#`` must have at least three whitespace-separated fields:
    - the first is the image name (without ``.jpg``);
    - the third is a species id, shifted down by one (presumably 1-based in the
      file, giving 0-based labels).

    The matching ``<root>/images/<name>.jpg`` is converted to RGB and passed
    through the preprocessing transform bundled with the ResNet-34
    ``IMAGENET1K_V1`` weights.

    Args:
        root: Directory containing the ``annotations`` and ``images`` folders.
            Defaults to the current working directory (the original behavior).

    Returns:
        ``(images, labels)``:
        - ``images`` stacks the transformed images along a new first dimension;
        - ``labels`` is an int64 tensor of the shifted species ids, in the same order.

    Raises:
        ValueError: If the annotation file contains no usable entries.
        FileNotFoundError: If the annotation file or a referenced image is missing.
    """
    transformation = torchvision.models.ResNet34_Weights.IMAGENET1K_V1.transforms()
    records = _read_annotations(os.path.join(root, 'annotations', 'list.txt'))
    if not records:
        raise ValueError(f'no annotated images found under {root!r}')

    images = None
    for index, (img_name, _) in enumerate(records):
        # The context manager closes the image file as soon as it has been decoded.
        with Image.open(os.path.join(root, 'images', f'{img_name}.jpg')) as img:
            transformed_img = transformation(img.convert('RGB'))
        if images is None:
            # Preallocate from the first sample's shape/dtype and fill in place.
            # Collecting a list and calling torch.stack would briefly hold two full
            # copies of the dataset, which is several GB for this data.
            images = torch.empty((len(records), *transformed_img.shape), dtype=transformed_img.dtype)
        images[index] = transformed_img

    labels = torch.tensor([species_id for _, species_id in records])
    return images, labels

In [93]:
# Pretrained ResNet-34 with a frozen backbone; only the new 2-output `fc` head is trainable.
model = get_pretrained_model_with_trainable_last_layer()

In [94]:
# Eagerly decode and preprocess the whole dataset into memory.
# X: the transformed images; y: int64 species labels (0-based after the -1 shift).
# The printed shapes show 7349 x 3 x 224 x 224 float32 images (~4.4 GB), so this cell is memory-heavy.
X, y = load_dataset()

In [95]:
# Sanity checks before training:
# - exactly one label per image;
# - every label is a valid class index for the model's output head.
assert X.shape[0] == y.shape[0], 'number of images and labels differ'
assert 0 <= y.min().item() and y.max().item() < model.fc.out_features, 'labels out of range for the model head'
X.dtype, X.shape, y.dtype, y.shape

(torch.float32,
 torch.Size([7349, 3, 224, 224]),
 torch.int64,
 torch.Size([7349]))