In [1]:
from fastai.vision.all import *

In [None]:
from collections import Counter
from itertools import compress
from pathlib import Path

import pandas as pd
import torch
from PIL import Image
from torch import tensor
from torchvision.models.resnet import resnet34

from fastcore.xtras import Path

from fastai.data.core import show_at, Datasets
from fastai.data.external import URLs, untar_data
from fastai.data.transforms import (
    ColReader,
    IntToFloatTensor, 
    MultiCategorize, 
    Normalize,
    OneHotEncode, 
    RandomSplitter,
)

from fastai.metrics import accuracy_multi

from fastai.vision.augment import aug_transforms
from fastai.vision.core import PILImage
from fastai.vision.learner import vision_learner
from fastai.learner import Learner
from fastai.callback.schedule import Learner

In [2]:
# Fetch the Planet sample dataset and read its image -> tags table.
src = untar_data(URLs.PLANET_SAMPLE)
df = pd.read_csv(src.joinpath('labels.csv'))

In [3]:
df.head(n=5)

In [4]:
# Flatten every space-separated tag string into one list of individual labels.
all_tags = df["tags"].values
all_labels = [label for tags in all_tags for label in tags.split(" ")]
len(all_labels)

In [5]:
# Distinct label names (note: a set, so its iteration order is not fixed across sessions).
different_labels = {*all_labels}
len(different_labels)

In [6]:
# Label frequencies, most common first.
# `Counter` counts in a single pass. Calling `all_labels.count` once per label
# rescanned the whole list for every label. `most_common()` orders ties by
# first occurrence, so the ordering no longer depends on the (hash-randomized)
# iteration order of the `different_labels` set.
counts = dict(Counter(all_labels).most_common())

In [7]:
# Label -> number of occurrences, most common first.
dict(counts)

In [8]:
# Number of images before rare labels are pruned.
df.shape[0]

In [9]:
# Drop every image that carries a label seen fewer than MIN_LABEL_COUNT times.
# Tags are compared as whole, space-separated tokens. The previous
# `str.contains(key)` matched substrings ("cloudy" also hit "partly_cloudy")
# and interpreted each key as a regex.
MIN_LABEL_COUNT = 10
rare_labels = {label for label, count in counts.items() if count < MIN_LABEL_COUNT}
has_rare_label = df["tags"].str.split(" ").map(lambda tags: not rare_labels.isdisjoint(tags))
df = df[~has_rare_label.astype(bool)]

In [10]:
# Number of images that survive the rare-label pruning.
df.shape[0]

In [11]:
(df.image_name.head(), src.ls())

In [12]:
src.joinpath('train').ls()[:3]

In [13]:
# Open one training image to see what the inputs look like.
PILImage.create(src/'train'/'train_2407.jpg')

In [14]:
def get_x(row:pd.Series) -> Path:
    """Map a labels-table row to its image file: <src>/train/<image_name>.jpg."""
    image_path = src.joinpath('train', row.image_name)
    return image_path.with_suffix(".jpg")

In [15]:
def get_y(row:pd.Series) -> list:
    """Split a row's space-separated ``tags`` string into a list of labels.

    The return annotation was ``L`` (fastcore's list type, available only via a
    star import). ``str.split`` actually returns a builtin ``list``.
    """
    return row.tags.split(" ")

In [16]:
# Sanity-check the hand-written getters on the first row.
row = df.iloc[0]
(get_x(row), get_y(row))

In [17]:
# Rebind the getters to fastai's ColReader. get_x now yields a path *string*
# "<src>/train/<image_name>.jpg", and get_y splits `tags` on spaces.
# Columns are selected by name. With integer keys, ColReader does `row[0]` on a
# pandas Series, which relies on positional fallback that pandas deprecated.
get_x = ColReader('image_name', pref=f'{src}/train/', suff=".jpg")
get_y = ColReader('tags', label_delim=" ")

In [18]:
# Item-level pipelines: the first list builds the input, the second the target.
# Target: tag list -> vocab ids -> one-hot vector over len(different_labels).
# NOTE(review): different_labels was computed before the rare-label filter, so
# the vocab still includes labels that no longer have any examples.
tfms = [
    [get_x, PILImage.create],
    [get_y, MultiCategorize(vocab=different_labels), OneHotEncode(len(different_labels))],
]

In [19]:
# Hold out 20% of rows for validation; the fixed seed keeps the split reproducible.
train_idxs, valid_idxs = RandomSplitter(valid_pct=0.2, seed=42)(df)

In [20]:
(train_idxs, valid_idxs)

In [21]:
# Pair each row with its (image, one-hot target) pipelines, split train/valid.
dsets = Datasets(items=df, tfms=tfms, splits=[train_idxs, valid_idxs])

In [22]:
# First (image, one-hot target) pair of the training split.
dsets.train[0]

In [None]:
show_at(dsets.train, idx=0);

In [None]:
# Batch transforms: uint8 -> float, light augmentation (vertical flips on,
# warping off), then ImageNet normalization to match the pretrained backbone.
batch_tfms = [IntToFloatTensor()]
batch_tfms += aug_transforms(flip_vert=True, max_lighting=0.1, max_zoom=1.05, max_warp=0.)
batch_tfms.append(Normalize.from_stats(*imagenet_stats))

In [None]:
# Build train/valid DataLoaders: items become tensors, batches get batch_tfms.
dls = dsets.dataloaders(after_item=[ToTensor()], after_batch=batch_tfms)

In [None]:
# Device the DataLoaders move batches to (e.g. cuda vs cpu).
dls.device

In [None]:
# Visual sanity check of a batch and its decoded labels (presumably from the train split).
dls.show_batch()

In [None]:
# Pretrained resnet34 with a new head sized from the DataLoaders' targets.
learn = vision_learner(dls, arch=resnet34, metrics=[accuracy_multi])

In [None]:
# Second child of the model, presumably the classifier head vision_learner adds on the resnet body (TODO confirm).
learn.model[1]

In [None]:
# Loss inferred from the target transforms; for one-hot multi-label targets expect a BCE-with-logits variant (confirm from output).
learn.loss_func

In [None]:
# Sigmoid squashes each logit independently into (0, 1): one probability per label.
t = tensor([[0.1, 0.5, 0.3, 0.7, 0.2]])
t.sigmoid()

In [None]:
# Threshold the loss presumably applies when decoding sigmoid outputs; the manual decoding later hard-codes 0.5.
learn.loss_func.thresh

In [None]:
# Learning-rate range test; presumably where the 2e-3 used below was read off.
learn.lr_find()

In [None]:
# One epoch with the pretrained body frozen (only the head trains).
learn.fit_one_cycle(n_epoch=1, lr_max=slice(2e-3))

In [None]:
# Fine-tune the whole network with discriminative learning rates:
# lr_max=slice(lo, hi) spreads rates from lo (earliest layers) up to hi (head).
learn.unfreeze()
learn.fit_one_cycle(n_epoch=5, lr_max=slice(2e-3 / 2.6**4, 2e-3))

In [None]:
# Predictions vs. ground truth on a few samples (presumably from the validation set).
learn.show_results(figsize=(15,15))

In [24]:
# Pick an image and grab the bare PyTorch model for a manual inference pass.
fname = get_x(df.iloc[0])
model = learn.model

In [2]:
# Same sample image as before, located relative to the downloaded dataset
# instead of a hard-coded absolute home-directory path.
fname = str(src / 'train' / 'train_21983.jpg')

In [5]:
from torchvision.transforms import PILToTensor

In [7]:
# Load as 3-channel RGB and convert to a (C, H, W) uint8 tensor.
im = Image.open(fname).convert("RGB")
t_im = PILToTensor()(im)

In [9]:
# Add a leading batch dimension and scale 0-255 bytes to floats in [0, 1].
t_im = t_im[None].float() / 255.

In [None]:
# ImageNet channel statistics, taken from the same `imagenet_stats` used by
# Normalize in the training pipeline rather than a duplicated copy of the numbers.
# view(1, -1, 1, 1) makes them broadcast over a (batch, channel, H, W) tensor.
mean, std = (tensor(stat).view(1, -1, 1, 1) for stat in imagenet_stats)

In [None]:
mean.size(), std.size()

In [None]:
# Per-channel normalization (not idempotent: run this cell only once).
t_im = t_im.sub(mean).div(std)

In [None]:
t_im.size()

In [None]:
# Forward pass in eval mode, without autograd bookkeeping.
# The input is sent to whatever device the model's parameters live on instead
# of a hard-coded `.cuda()`, so this also runs on CPU-only machines.
model.eval()
with torch.inference_mode():
    preds = model(t_im.to(next(model.parameters()).device))

In [None]:
preds.size()

In [None]:
# Independent per-label decision: probability above 0.5 means "label present".
decoded_preds = preds.sigmoid().gt(0.5)

In [None]:
# Boolean mask over the vocab: True where sigmoid(logit) > 0.5.
decoded_preds

In [None]:
from itertools import compress

In [None]:
# Keep the label names whose mask entry is True.
# Use the vocab the model was trained against (the MultiCategorize CategoryMap
# exposed as dls.vocab) instead of re-deriving the order from the
# `different_labels` set. `.tolist()` turns the (possibly CUDA) bool row into
# plain Python bools for `compress`.
present_labels = list(compress(learn.dls.vocab, decoded_preds[0].tolist()))

In [None]:
# Labels predicted for fname by the manual pipeline; compare with learn.predict in the next cell.
present_labels

In [None]:
# Reference: fastai's own pipeline; element 0 is the decoded label list.
learn.predict(item=fname)[0]

In [None]:
def predict_labels(fname, model, vocab, thresh=0.5):
    """Run the manual (non-fastai) inference pipeline on one image file.

    Mirrors the training pipeline: RGB PIL image -> uint8 tensor -> float in
    [0, 1] with a batch dim -> ImageNet normalization -> model -> sigmoid > thresh.

    Args:
        fname: path to the image file.
        model: the trained network (e.g. ``learn.model``).
        vocab: label names in the order of the model's output columns.
        thresh: probability above which a label counts as present.

    Returns:
        ``(present_labels, preds, decoded_preds)``: the predicted label names,
        the raw model outputs for the single-image batch, and the boolean mask.
    """
    im = Image.open(fname).convert("RGB")
    # PILToTensor yields 0-255 uint8 values. Scale them to [0, 1] and add the
    # batch dim, as IntToFloatTensor + collation do in the DataLoaders. Skipping
    # the /255 would normalize raw byte values and make predictions meaningless.
    x = PILToTensor()(im).unsqueeze(0).float().div(255.)
    mean, std = (tensor(stat).view(1, -1, 1, 1) for stat in imagenet_stats)
    x = (x - mean) / std
    model.eval()
    with torch.inference_mode():
        # Follow the model's device rather than assuming CUDA.
        preds = model(x.to(next(model.parameters()).device))
    decoded_preds = torch.sigmoid(preds) > thresh
    present_labels = list(compress(vocab, decoded_preds[0].tolist()))
    return present_labels, preds, decoded_preds


present_labels, preds, decoded_preds = predict_labels(fname, model, learn.dls.vocab)
present_labels