In [None]:
from fastai.vision.all import *

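Testing on an arbitrary task first (binary classification on the PETS images, labelled by whether the filename starts with an uppercase letter):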

In [None]:
# Seed the random number generators up front so repeated runs are comparable;
# the second argument asks fastai for reproducible behaviour.
set_seed(99, reproducible=True)

In [None]:
# Fetch the PETS archive via fastai and point `path` at its "images" sub-folder.
path = untar_data(URLs.PETS) / "images"

In [None]:
def is_cat(fname):
    """Return True when the file name starts with an uppercase letter.

    This is the binary label for the sanity-check task.
    NOTE(review): in the PETS dataset capitalised names are presumably the
    cat breeds - confirm against the dataset documentation.

    A named function is used instead of a lambda so the DataLoaders (and any
    Learner built on them) stay picklable, e.g. for `learn.export()` or for
    DataLoader worker processes that are started with "spawn".
    """
    return fname[0].isupper()


# Build train/valid loaders from every image under `path`: 20% of the files
# are held out for validation and each item is resized to 224.
dls = ImageDataLoaders.from_name_func(
    path,
    get_image_files(path),
    valid_pct=0.2,
    label_func=is_cat,
    item_tfms=Resize(224),
)

In [None]:
# Callbacks for the sanity check: MixUp, the same augmentation callback that is
# used for the real task further down.
cbs = [MixUp()]

# Attach `cbs` to the learner. The list used to be created but never passed
# on, so this run silently trained without MixUp.
learn = vision_learner(
    dls,
    resnet34,
    metrics=error_rate,
    loss_func=LabelSmoothingCrossEntropy(),
    cbs=cbs,
).to_fp16()  # train in mixed precision



In [None]:
# Fine-tune the pretrained model for 2 epochs (the recorded log below shows one
# preliminary epoch followed by these two).
learn.fine_tune(epochs=2)

epoch,train_loss,valid_loss,error_rate,time
0,0.45083,0.280304,0.009472,00:12


epoch,train_loss,valid_loss,error_rate,time
0,0.292623,0.229875,0.002706,00:14
1,0.246235,0.216064,0.002706,00:14


Testing on our task (classifying snake images by genus, using the SnakeCLEF 2021 metadata):

In [None]:
from functools import partial

import numpy as np
import albumentations as A
from fastai.data.transforms import Normalize
from fastai.vision.augment import (
    Resize,
    aug_transforms,
    imagenet_stats,
    RandomResizedCrop,
    RandTransform
)
from tsp_cls.dataloader.augment import AlbumentationsTransform
from fastai.vision.core import PILImage
from fastcore.basics import store_attr

from tsp_cls.utils.root import get_data_root
from tsp_cls.utils.data import (
    get_image_path,
    field_getter,
    read_dataframe,
    sample_dataframe,
)
from tsp_cls.dataloader.dataloader import get_dls

In [None]:
# Rebind `path` to the project's data root; from here on it no longer refers
# to the PETS images folder used in the first experiment.
path = get_data_root()

In [None]:
# Read the min-train metadata CSV from the data root and sample it on the
# "genus" column with an argument of 10.
# NOTE(review): presumably this keeps 10 genera (the vocab printed later lists
# 10 classes) - confirm against `sample_dataframe`.
df = sample_dataframe(
    read_dataframe(path, "SnakeCLEF2021_min-train_metadata_PROD.csv"),
    "genus",
    10,
)

In [None]:
# Report how many metadata rows remain after sampling.
print(f"Length of DF: {len(df)}")

Length of DF: 22396


In [None]:
# Target height and width handed to the Albumentations crop/resize transforms.
img_size = 224

def get_train_aug():
    """Build the Albumentations pipeline used for training images.

    The pipeline starts with a random resized crop to `img_size` x `img_size`
    (read from the module-level `img_size`), followed by a transpose, a
    vertical flip, a horizontal flip and a shift/scale/rotate, each of which
    is applied with probability 0.5.
    """
    steps = [
        A.RandomResizedCrop(img_size, img_size),
        A.Transpose(p=0.5),
        A.VerticalFlip(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(p=0.5),
    ]
    return A.Compose(steps)

def get_valid_aug():
    """Build the Albumentations pipeline used for validation images.

    The pipeline is deterministic (everything runs with p=1.0): a centre crop
    to `img_size` x `img_size` followed by a resize to that same size. Both
    sizes come from the module-level `img_size`.
    """
    center_crop = A.CenterCrop(img_size, img_size, p=1.0)
    resize = A.Resize(img_size, img_size)
    return A.Compose([center_crop, resize], p=1.0)

# Per-item transforms: fastai resize to 256 first, then the project's
# Albumentations wrapper built from the train and valid pipelines.
# NOTE(review): presumably the wrapper applies the first pipeline to training
# items and the second to validation items - confirm in `AlbumentationsTransform`.
# Baseline without Albumentations, for comparison: use `[Resize(img_size)]` instead.
item_tfms = [Resize(256), AlbumentationsTransform(get_train_aug(), get_valid_aug())]

# Per-batch transform: normalise with the ImageNet statistics.
batch_tfms = Normalize.from_stats(*imagenet_stats)

dls = get_dls(
    df,
    # One partial is enough. The previous code wrapped a second partial around
    # the first with the same `data_path` keyword, which merely overrode it
    # with an equal value (`path` is the result of `get_data_root()`).
    get_x=partial(get_image_path, data_path=path),
    get_y=partial(field_getter, field="genus"),
    item_tfms=item_tfms,
    batch_tfms=batch_tfms,
    bs=32,
)

# Quick sanity check of the loaders: number of training batches per epoch and
# the label vocabulary that was picked up.
print(f"Steps in train_dl: {len(dls.train)}")
print(f"Classes being trained on: {dls.vocab}")

Steps in train_dl: 559
Classes being trained on: ['Agkistrodon', 'Crotalus', 'Lampropeltis', 'Masticophis', 'Micrurus', 'Natrix', 'Nerodia', 'Pantherophis', 'Tantilla', 'Thamnophis']


In [None]:
# Learner for the real task: a "convnext_tiny" backbone trained with the same
# ingredients as the sanity check above (mixed precision, MixUp and
# label-smoothing cross-entropy), using Adam with weight decay disabled.
learn = vision_learner(
    dls,
    "convnext_tiny",
    # optimisation setup
    opt_func=Adam,
    wd=0.0,
    loss_func=LabelSmoothingCrossEntropy(),
    # training-time callbacks
    cbs=[MixedPrecision(), MixUp()],
    # reported after every epoch
    metrics=[error_rate, accuracy],
)

In [None]:
# Train for 2 epochs with the one-cycle schedule.
learn.fit_one_cycle(n_epoch=2)

epoch,train_loss,valid_loss,error_rate,accuracy,time
0,1.838709,1.101924,0.254744,0.745256,01:30
1,1.654815,1.011514,0.222371,0.777629,01:29
