In [None]:
from fastcore.all import *
from fastai.vision.all import *
import timm
import numpy as np

Set up the data loader to read in the images.

The `images` folder should contain one subfolder per tag (or tag combination).
If an image has multiple tags, its folder name should list them separated by commas.

example:
- dog,foxhound/image.jpg
- dog,bulldog/image.jpg

In [None]:
def parent_labels(o):
    """Return the list of tags encoded in the parent folder name of `o`.

    The parent folder name is split on commas, so ``images/dog,foxhound/a.jpg``
    yields ``["dog", "foxhound"]``. Whitespace around each tag is stripped and
    empty fragments (e.g. from a trailing comma, or a file with no parent
    folder) are dropped, so a folder named ``"dog, foxhound"`` does not create
    a separate ``" foxhound"`` tag.

    Args:
        o: Path (or path string) of an image file.

    Returns:
        list[str]: The tags, in the order they appear in the folder name.
    """
    folder_name = Path(o).parent.name
    # Strip first, then filter, so whitespace-only fragments are discarded too.
    stripped = (fragment.strip() for fragment in folder_name.split(","))
    return [tag for tag in stripped if tag]

# Batch-level augmentations from `aug_transforms`, with `max_lighting` set to
# 0.1 and an output size of 224.
batch_tfms = aug_transforms(max_lighting=0.1, size=224)

# Image inputs with multi-label targets: each image is labelled by
# `parent_labels`, i.e. the comma-separated tags of its parent folder.
dls = DataBlock(
    blocks=(ImageBlock, MultiCategoryBlock),
    get_items=get_image_files,
    # A fixed seed keeps the train/validation split identical across kernel
    # restarts. Without it every run validates on a different 20% of the data,
    # so results are not comparable and a reloaded checkpoint may be validated
    # on images it was trained on.
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=parent_labels,
    # Every item is squished to 224x224 so images can be collated into batches.
    item_tfms=[Resize(224, method='squish')],
    batch_tfms=batch_tfms,
).dataloaders("images", bs=256)


Loss function, metrics, and optimizer settings that help with training.

In [6]:
class LabelSmoothingBCEWithLogitsLossFlat(BCEWithLogitsLossFlat):
    """`BCEWithLogitsLossFlat` with label smoothing applied to the targets.

    Before the parent loss is evaluated, every target value ``t`` is replaced
    by ``t * (1 - eps) + 0.5 * eps``; a 0 therefore becomes ``eps / 2`` and a 1
    becomes ``1 - eps / 2``.
    """

    def __init__(self, eps:float=0.1, thresh:float=0.2, **kwargs):
        """Store the smoothing strength and configure the parent loss.

        Args:
            eps: Smoothing strength in ``[0, 1]``; 0 leaves targets unchanged.
            thresh: Forwarded to `BCEWithLogitsLossFlat` as ``thresh``
                (presumably the cut-off used when decoding predictions --
                confirm against the fastai docs). Defaults to 0.2, the value
                this class previously hard-coded; exposing it lets callers
                override it instead of hitting a duplicate-keyword TypeError.
            **kwargs: Forwarded unchanged to `BCEWithLogitsLossFlat`.

        Raises:
            ValueError: If `eps` lies outside ``[0, 1]``, which would push the
                smoothed targets outside the ``[0, 1]`` range.
        """
        if not 0.0 <= eps <= 1.0:
            raise ValueError(f"eps must be in [0, 1], got {eps}")
        self.eps = eps
        super().__init__(thresh=thresh, **kwargs)

    def __call__(self, inp, targ, **kwargs):
        """Smooth `targ` and delegate to the parent loss.

        Args:
            inp: Model outputs, passed through to the parent loss untouched.
            targ: Targets; converted with ``.float()`` before smoothing.
            **kwargs: Forwarded unchanged to the parent ``__call__``.

        Returns:
            Whatever the parent ``__call__`` returns for the smoothed targets.
        """
        # https://www.kaggle.com/c/siim-isic-melanoma-classification/discussion/166833#929222
        targ_smooth = targ.float() * (1. - self.eps) + 0.5 * self.eps
        return super().__call__(inp, targ_smooth, **kwargs)

    def __repr__(self):
        """Return a fixed, human-readable name for this loss."""
        return "FlattenedLoss of LabelSmoothingBCEWithLogits()"
    
# Metrics handed to the learner. `FBetaMulti` receives 2.0 and 0.2
# positionally (presumably beta and threshold -- confirm against the fastai
# signature) with sample averaging; `accuracy_multi` is bound to thresh=0.95.
metrics = [
    FBetaMulti(2.0, 0.2, average='samples'),
    partial(accuracy_multi, thresh=0.95),
]

# Weight decay, bound into the `ranger` optimizer factory below.
wd = 5e-7
opt_func = partial(ranger, wd=wd)

Set up the vision learner.

In [7]:
# Build the learner from the data loaders, the architecture name
# (presumably resolved through timm -- confirm), and the loss/metrics/optimizer
# configured above, then switch it to fp16 via `to_fp16()`.
learn = vision_learner(
    dls,
    "vit_base_patch32_224",
    metrics=metrics,
    loss_func=LabelSmoothingBCEWithLogitsLossFlat(),
    opt_func=opt_func,
).to_fp16()

Train the model. This takes about 30 minutes on an RTX 3060 with 17,000 images.

In [5]:
# Fine-tune for 50 epochs (every other fine_tune setting is left at its default).
learn.fine_tune(epochs=50)

In [None]:
# Plot the losses tracked by the learner's recorder for the run above.
learn.recorder.plot_loss()

In [None]:
# Run the learner's validation pass; the cell output shows the returned values
# (presumably the validation loss followed by the configured metrics -- confirm).
learn.validate()

Save a snapshot in case we want to load it later.

In [None]:
# The checkpoint name mirrors the architecture passed to `vision_learner`
# ("vit_base_patch32_224") plus the epoch count, so the file is not mistaken
# for a patch16 model when it is loaded later.
learn.save('vit_base_patch32_224-50')

Export the model; the exported file can be used with https://huggingface.co/spaces/cc1234/stashtag/tree/main

In [37]:
# Export the learner to "models.pkl" (location presumably relative to the
# learner's path -- confirm) for use outside this notebook.
learn.export("models.pkl")

In [None]:
# Build an interpretation helper from the trained learner and plot the top
# six losses, laid out over two rows.
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_top_losses(6, nrows=2)
# Alternative views, currently disabled:
# interp.plot_confusion_matrix()
# interp.print_classification_report()

In [None]:
# Display a sample of results from the learner (presumably inputs with their
# predicted and actual labels -- confirm).
learn.show_results()

Test the result.
`!import` uses a Linux screenshot tool to take a screen grab; alternatively, just pass in the path of an existing image file.

In [None]:
# !import dd.jpg
image_path = "dd.jpg"
image = PILImage.create(image_path)

# `learn.predict` returns three values here: the decoded tags, a per-class
# selection mask, and the per-class scores (the loop below relies on the mask
# and scores being indexable by class position).
tags, mask, probs = learn.predict(image)
print(f"Tags are: {tags}.")

# Print each selected tag next to its score; a bare score on its own does not
# say which tag it belongs to.
vocab = learn.dls.vocab
for idx, selected in enumerate(mask):
    if selected:
        print(f"{vocab[idx]}: {float(probs[idx]):.4f}")

image.show()