In [2]:
from fastai.data.external import untar_data, URLs
from fastai.vision.data import imagenet_stats
from fastcore.xtras import Path

# Fetch and extract the PETS archive with fastai's `untar_data`, then list what it contains.
dataset_path = untar_data(url=URLs.PETS)
dataset_path.ls()

In [3]:
# (mean, std) statistics, unpacked into `Normalize` by both transform pipelines.
imagenet_stats

In [4]:
from torch import nn
from torchvision.transforms import CenterCrop, RandomResizedCrop, Resize, ToTensor, Normalize

# Training: random crop resized to 224x224, then normalization with `imagenet_stats`.
# Both pipelines operate on tensors (PetsDataset applies ToTensor before these).
train_transforms = nn.Sequential(
    RandomResizedCrop((224,224)),
    Normalize(*imagenet_stats)
)

# Validation: shrink the shorter side to 256 first so the 224x224 center crop keeps
# most of the image, instead of cutting a small patch out of a full-resolution photo
# (which would look very different from the resized crops seen during training).
valid_transforms = nn.Sequential(
    Resize(256),
    CenterCrop((224,224)),
    Normalize(*imagenet_stats)
)

In [5]:
import re
from PIL import Image
from torch.utils.data import Dataset

# This example is heavily based on the work of Sylvain Gugger
# for the Accelerate notebook example, which can be found here:
# https://github.com/huggingface/notebooks/blob/main/examples/accelerate_examples/simple_cv_example.ipynb
class PetsDataset(Dataset):
    """A basic dataset that returns ``(image_tensor, label_index)`` tuples.

    Args:
        filenames: Indexable sequence of path objects (anything with a ``.name``)
            whose names look like ``<label>_<digits>.jpg``.
        transforms: Module applied to each image after ``ToTensor`` conversion.
        label_to_int: Mapping from lower-cased label string to integer class index.
    """
    # Captures the label: everything before the trailing ``_<digits>.jpg``.
    _label_re = re.compile(r"^(.*)_\d+\.jpg$")

    def __init__(self, filenames:list, transforms:nn.Sequential, label_to_int:dict):
        self.filenames = filenames
        self.transforms = transforms
        self.label_to_int = label_to_int
        self.to_tensor = ToTensor()

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.filenames)

    def apply_x_transforms(self, filename):
        """Load ``filename`` as RGB, convert it to a tensor and apply ``self.transforms``."""
        # `convert` returns a loaded copy, so the file handle can be closed right away;
        # forcing RGB also turns grayscale/RGBA inputs into 3-channel images.
        with Image.open(filename) as image:
            rgb_image = image.convert("RGB")
        return self.transforms(self.to_tensor(rgb_image))

    def apply_y_transforms(self, filename):
        """Extract the label from ``filename.name`` and map it to its integer index.

        Raises:
            ValueError: if the name does not match ``<label>_<digits>.jpg``.
            KeyError: if the extracted label is not in ``label_to_int``.
        """
        match = self._label_re.match(filename.name)
        if match is None:
            raise ValueError(f"Cannot extract a label from file name {filename.name!r}")
        return self.label_to_int[match.group(1).lower()]

    def __getitem__(self, index):
        """Return the ``(transformed_image, label_index)`` pair at ``index``."""
        filename = self.filenames[index]
        x = self.apply_x_transforms(filename)
        y = self.apply_y_transforms(filename)
        return (x,y)

In [6]:
# Captures the label: everything before the trailing `_<digits>.jpg`.
label_pat = r"^(.*)_\d+\.jpg$"
# `ls()` follows directory iteration order, which is filesystem-dependent; sorting makes
# every step that indexes into `filenames` (e.g. a train/valid split) see a stable order.
filenames = (dataset_path/'images').ls(file_exts=".jpg").sorted()

In [7]:
# Unique, lower-cased labels extracted from the file names. Sorting them makes the
# label -> index assignment independent of the order the files were listed in.
labels = filenames.map(
    lambda x: re.findall(label_pat, x.name)[0].lower()
).unique().sorted()

In [8]:
# Unique labels extracted from the file names in the previous cell
labels

In [9]:
# Map each label string to its position in `labels`
label_to_int = {label: idx for idx, label in enumerate(labels)}
label_to_int.keys(), label_to_int["siamese"]

In [10]:
import numpy as np

# Fixed seed so the train/valid split is reproducible (given the same `filenames` order).
SEED = 42
rng = np.random.default_rng(SEED)
shuffled_indexes = rng.permutation(len(filenames))
# 80% of the files go to training, the remaining 20% to validation
split = int(0.8 * len(filenames))
train_indexes, valid_indexes = (
    shuffled_indexes[:split], shuffled_indexes[split:]
)

In [11]:
# Select each split's files by indexing with the shuffled positions
train_fnames, valid_fnames = (filenames[train_indexes], filenames[valid_indexes])

In [12]:
# Same label mapping for both splits; only the files and transforms differ
train_dataset = PetsDataset(
    filenames=train_fnames, transforms=train_transforms, label_to_int=label_to_int
)
valid_dataset = PetsDataset(
    filenames=valid_fnames, transforms=valid_transforms, label_to_int=label_to_int
)

In [13]:
# Pull one training sample to sanity-check the tensor shape and integer label
x, y = train_dataset[0]
(x.shape, y)

In [14]:
from torch.utils.data import DataLoader

In [15]:
# Reshuffle every epoch and drop the incomplete final batch
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)

In [16]:
# Validation keeps file order (no shuffle) and uses a larger batch than training
valid_dataloader = DataLoader(valid_dataset, batch_size=128)

In [17]:
from fastai.data.core import DataLoaders

In [18]:
# Bundle the plain PyTorch loaders (train first, then valid) for fastai's Learner
dls = DataLoaders(train_dataloader, valid_dataloader)

In [19]:
from torchvision.models import resnet34

# Pretrained ResNet-34 from torchvision; its head is replaced below.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in favour of
# `weights=ResNet34_Weights.DEFAULT`; confirm against the installed version.
model = resnet34(pretrained=True)

In [20]:
# New classification head sized for our labels. Reading `in_features` from the existing
# head avoids hard-coding the backbone width (512 for resnet34), and `len(label_to_int)`
# keeps the output size in sync with the labels actually found in the data.
model.fc = nn.Linear(model.fc.in_features, len(label_to_int), bias=True)

In [21]:
# Inspect the replacement classification head
model.fc

In [22]:
# The last top-level child of the network is the `fc` head
[*model.children()][-1]

In [23]:
# Freeze every top-level child except the last (the `fc` head), so only the head's
# parameters receive gradients
children = list(model.children())
for layer in children[:len(children) - 1]:
    if hasattr(layer, "requires_grad_"):
        layer.requires_grad_(False)

In [24]:
from torch.optim import AdamW

In [25]:
from functools import partial
from fastai.optimizer import OptimWrapper

In [26]:
# fastai optimizer factory that wraps PyTorch's AdamW via OptimWrapper
opt_func = partial(OptimWrapper, opt=AdamW)

In [27]:
from fastai.losses import CrossEntropyLossFlat
from fastai.metrics import accuracy
from fastai.learner import Learner
# Imported for its side effect: this module adds the scheduling methods used below
# (`lr_find`, `fit_one_cycle`) to `Learner`.
import fastai.callback.schedule  # noqa: F401

In [28]:
import torch

# Use the GPU when one is available; fall back to CPU instead of crashing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device);

In [29]:
# fastai Learner over the plain-PyTorch loaders, model, loss, optimizer and metric
learn = Learner(
    dls,
    model,
    loss_func=CrossEntropyLossFlat(),
    opt_func=opt_func,
    metrics=accuracy,
)

In [30]:
# fastai learning-rate finder, used to pick the rate passed to `fit_one_cycle`
learn.lr_find()

In [31]:
# Train for 5 epochs with a one-cycle schedule peaking at 1e-3
learn.fit_one_cycle(n_epoch=5, lr_max=1e-3)

In [32]:
# Convert to RGB exactly as PetsDataset.apply_x_transforms does: images in the folder
# may be grayscale or RGBA, which would not match the 3-channel Normalize / model input.
im = Image.open(filenames[0]).convert("RGB")
im

In [33]:
# The (trained) model held by the Learner, used for manual inference below
net = learn.model

In [34]:
# Run the validation pipeline on the image and prepend a batch dimension of size 1
tfm_x = valid_transforms(ToTensor()(im)).unsqueeze(0)
tfm_x.shape

In [None]:
import torch

net.eval()
# Send the input to wherever the model's weights live instead of assuming CUDA.
device = next(net.parameters()).device
with torch.no_grad():
    preds = net(tfm_x.to(device))
# Plain Python int, so it can be used as a dict key regardless of device.
pred = preds.argmax(dim=-1)[0].item()
# Explicit inverse mapping rather than relying on the dict's key order.
int_to_label = {idx: label for label, idx in label_to_int.items()}
label = int_to_label[pred]
pred, label