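### Dataset
This notebook uses the TrashNet image dataset, which has one sub-folder of images per class.
* Project repository: https://github.com/garythung/trashnet
* Google Drive folder with the data. Download `dataset-resized.zip` into `./data/` before running the next cell: https://drive.google.com/drive/folders/0B3P9oO5A3RvSUW9qTG11Ul83TEE?resourcekey=0-F-D8v2tnSfByG6ll3t9JxA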

In [None]:
import zipfile
from pathlib import Path

# Extract the TrashNet archive into ./data/trashnet.
# The cell is guarded so re-running it (e.g. Restart & Run All) is a no-op.
# The previous `!unzip` would stop at an interactive overwrite prompt, and
# `!mv` would nest `dataset-resized` inside an existing ./data/trashnet.
data_dir = Path("./data")
archive_path = data_dir / "dataset-resized.zip"
extracted_dir = data_dir / "dataset-resized"
dataset_dir = data_dir / "trashnet"

if not dataset_dir.exists():
    with zipfile.ZipFile(archive_path) as archive:
        # Skip macOS `__MACOSX/` resource-fork entries (same as `unzip -x "__MACOSX/*"`).
        members = [m for m in archive.namelist() if not m.startswith("__MACOSX/")]
        archive.extractall(data_dir, members=members)
    extracted_dir.rename(dataset_dir)

In [None]:
from pathlib import Path

# Location of the extracted TrashNet images (one sub-folder per class).
# `name` is also reused as the FiftyOne dataset name.
name = "trashnet"
path = Path("./data", name)

In [None]:
import fiftyone as fo

# Reuse the FiftyOne dataset if one with this name already exists.
# Otherwise import it from the class-per-folder directory tree under `path`.
if fo.dataset_exists(name):
    dataset = fo.load_dataset(name)
else:
    dataset = fo.Dataset.from_dir(
        dataset_dir=path,
        dataset_type=fo.types.ImageClassificationDirectoryTree,
        name=name,
    )

# Mark the dataset persistent so the `dataset_exists` branch can find it in later sessions.
dataset.persistent = True

session = fo.launch_app(dataset)

In [None]:
# Compute media metadata (e.g. image dimensions) for each sample.
# FiftyOne stores it on each sample's `metadata` field.
dataset.compute_metadata()
# Re-assign the App's view to the full dataset, presumably so the open App
# refreshes and shows the newly populated metadata.
session.view = dataset.view()

In [None]:
import fiftyone.brain as fob

# compute_uniqueness embeds every image with a model (by default), which is slow.
# The dataset is persistent, so scores from a previous run are still stored in
# the default `uniqueness` field; skip the recomputation in that case so
# Restart & Run All stays cheap.
# To force a recompute, delete the field first with
# `dataset.delete_sample_field("uniqueness")`.
if not dataset.has_sample_field("uniqueness"):
    fob.compute_uniqueness(dataset)

# Re-assign the App's view to the full dataset so the App shows the new field.
session.view = dataset.view()

In [None]:
from flash import Trainer
from flash.image import ImageClassificationData, ImageClassifier

# NOTE(review): TimmInputTransform is not used in this cell; the later
# show_train_batch cell imports it on its own.
from transforms import TimmInputTransform

# Number of fine-tuning epochs.
# With no explicit limit, Lightning's Trainer falls back to max_epochs=1000,
# far too long for this demo.
MAX_EPOCHS = 3

# The numbered, commented-out lines below look like step-by-step tweaks to
# enable during a walkthrough.
datamodule = ImageClassificationData.from_folders(
    train_folder=path,
    val_split=0.2,
    batch_size=32,
    num_workers=0,
    # 3
    # transform_kwargs={"image_size": 224},
)

model = ImageClassifier(
    num_classes=datamodule.num_classes,
    labels=datamodule.labels,
)

trainer = Trainer(
    max_epochs=MAX_EPOCHS,
    # 1
    # accelerator="auto",
)
trainer.finetune(
    model,
    datamodule=datamodule,
    # 2
    # strategy="freeze",
)


In [None]:
# ImageClassifier is imported here so that step 3 (below) works on a fresh
# kernel instead of relying on the import from the previous cell.
from flash.image import ImageClassificationData, ImageClassifier
from transforms import TimmInputTransform

datamodule = ImageClassificationData.from_folders(
    train_folder=path,
    val_split=0.2,
    batch_size=32,
    num_workers=0,
    # 2
    # transform=TimmInputTransform,
    transform_kwargs={"image_size": 224},
)

# Plot a few training samples to inspect what the input transforms produce.
datamodule.show_train_batch(
    limit_nb_samples=8,
    figsize=(16, 7),
    # 1
    # hooks_names=["load_sample", "per_sample_transform"],
)

# 3
# from flash import Trainer
#
# model = ImageClassifier(
#     num_classes=datamodule.num_classes,
#     labels=datamodule.labels,
# )
#
# # Bound training explicitly: Lightning defaults to 1000 epochs when unset.
# trainer = Trainer(
#     max_epochs=3,
#     accelerator="auto",
# )
# trainer.finetune(
#     model,
#     datamodule=datamodule,
#     strategy="freeze",
# )


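### Which model should I use?
* PyTorch Image Models (timm), a large collection of pretrained image backbones: https://github.com/rwightman/pytorch-image-models/
* "Which image models are best?", a Kaggle notebook comparing timm models on speed vs. accuracy: https://www.kaggle.com/code/jhoward/which-image-models-are-best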

In [None]:
from pathlib import Path

import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile
from flash.image import ImageClassificationData, ImageClassifier

# TODO: point this at a real checkpoint from fine-tuning.
# Lightning writes checkpoints under lightning_logs/version_*/checkpoints/ by
# default; verify the exact file for your run.
CHECKPOINT_PATH = "path/to/checkpoint"
# Output file for the PyTorch Lite Interpreter (mobile runtime).
MOBILE_MODEL_PATH = "model.ptl"

# Fail early with an actionable message rather than a deep loader traceback.
if not Path(CHECKPOINT_PATH).is_file():
    raise FileNotFoundError(
        f"Checkpoint not found: {CHECKPOINT_PATH!r}. "
        "Set CHECKPOINT_PATH to a checkpoint saved during fine-tuning."
    )

# Restore the fine-tuned classifier and compile it to TorchScript.
# Then apply mobile-specific graph optimisations and serialise the result for
# the lite interpreter.
model = ImageClassifier.load_from_checkpoint(CHECKPOINT_PATH)
scripted_model = model.to_torchscript()
optimized_model = optimize_for_mobile(scripted_model)
optimized_model._save_for_lite_interpreter(MOBILE_MODEL_PATH)
