In [None]:
# List the working directory's contents.
# NOTE(review): leftover debugging shell command; consider removing before sharing.
!ls

In [None]:
# Interactive matplotlib backend; requires the ipympl package in the kernel environment.
%matplotlib widget

In [None]:
import matplotlib.pyplot as plt


In [None]:
# Smoke test for the widget backend: draw a trivial, labeled line plot.
# The trailing ';' suppresses the return-value repr so only the figure is shown.
fig, ax = plt.subplots()
ax.plot([0, 1, 2])
ax.set(title="matplotlib widget smoke test", xlabel="index", ylabel="value");

In [None]:
from itertools import chain

import fiftyone as fo
import torch

import flash
from flash.core.classification import FiftyOneLabels, Labels
from flash.core.data.utils import download_data
from flash.core.finetuning import FreezeUnfreeze
from flash.image import ImageClassificationData, ImageClassifier
import torchvision

In [None]:

# 1 Download data
# Fetch the hymenoptera archive; the next cell expects its contents under
# data/hymenoptera_data/ (presumably where download_data extracts it by default).
HYMENOPTERA_URL = "https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip"
download_data(HYMENOPTERA_URL)


In [None]:

# 2 Load data into FiftyOne
# Every split is stored as a class-per-subdirectory tree under the same root,
# so one helper replaces three copy-pasted from_dir calls.
HYMENOPTERA_ROOT = "data/hymenoptera_data"


def load_split(split: str) -> fo.Dataset:
    """Load one split (e.g. "train") of the hymenoptera data as a FiftyOne dataset.

    Args:
        split: Name of the subdirectory of HYMENOPTERA_ROOT holding the split.

    Returns:
        A FiftyOne dataset built with the ImageClassificationDirectoryTree importer.
    """
    return fo.Dataset.from_dir(
        dataset_dir=f"{HYMENOPTERA_ROOT}/{split}/",
        dataset_type=fo.types.ImageClassificationDirectoryTree,
    )


train_dataset = load_split("train")
val_dataset = load_split("val")
test_dataset = load_split("test")


In [None]:
# Removed an incomplete `ImageClassificationData.` expression (presumably left over
# from tab-completion exploration): it raised a SyntaxError and stopped
# "Restart & Run All". The loader used in this notebook is
# ImageClassificationData.from_fiftyone.

In [None]:
from torchvision import transforms as T


In [None]:
# 3 Load data into Flash
# Wrap the three FiftyOne splits in a single Flash data module.
datamodule = ImageClassificationData.from_fiftyone(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    test_dataset=test_dataset,
)

In [None]:

# 4 Fine tune a model
# Seed torch so re-running this cell starts from the same random state
# (initialization and, presumably, data shuffling).
SEED = 42
torch.manual_seed(SEED)

model = ImageClassifier(
    backbone="resnet18",
    num_classes=datamodule.num_classes,
    serializer=Labels(),
)
# max_epochs=1 and limit_*_batches=1 keep this a quick smoke-test run;
# raise them for real training.
trainer = flash.Trainer(
    max_epochs=1,
    gpus=torch.cuda.device_count(),  # 0 on CPU-only machines
    limit_train_batches=1,
    limit_val_batches=1,
)
# NOTE(review): with max_epochs=1 only the first epoch runs, so
# FreezeUnfreeze(unfreeze_epoch=1) presumably never unfreezes the backbone —
# confirm this is intended for the demo.
trainer.finetune(
    model,
    datamodule=datamodule,
    strategy=FreezeUnfreeze(unfreeze_epoch=1),
)
trainer.save_checkpoint("image_classification_model.pt")


In [None]:
# 5 Predict from checkpoint
# Reload the weights saved by the fine-tuning cell, replacing the in-memory `model`.
checkpoint_path = "image_classification_model.pt"
model = ImageClassifier.load_from_checkpoint(checkpoint_path)

In [None]:
# The original `model;` was a no-op (bare expression with its output suppressed).
# Show a compact summary instead: the model class and its total parameter count.
type(model).__name__, sum(p.numel() for p in model.parameters())

In [None]:
# Display the test split's FiftyOne summary before running predictions on it.
test_dataset

In [None]:
# Removed a duplicate assignment: the prediction cell already sets
# model.serializer = FiftyOneLabels(return_filepath=False) immediately before predicting.


In [None]:
# Inspect the training data module built from the FiftyOne splits.
# NOTE(review): looks like leftover exploration; confirm its repr is informative or remove.
datamodule

In [None]:
# Emit FiftyOne-format labels (without file paths) so they can be stored on the dataset.
model.serializer = FiftyOneLabels(return_filepath=False)
datamodule_predict = ImageClassificationData.from_fiftyone(predict_dataset=test_dataset)
# trainer.predict returns one collection per batch; flatten into a per-sample list.
predictions = [
    label
    for batch in trainer.predict(model, datamodule=datamodule_predict)
    for label in batch
]

In [None]:
# NOTE(review): repeats the earlier `datamodule` display; perhaps `datamodule_predict`
# was intended here — confirm or remove.
datamodule

In [None]:
# Show the number of predictions and a small sample rather than dumping the whole list.
len(predictions), predictions[:5]

In [None]:
# 6 Add predictions to dataset
# set_values receives one value per sample, so the counts must line up.
# NOTE(review): this also assumes trainer.predict preserves test_dataset's sample order.
assert len(predictions) == len(test_dataset), (
    f"got {len(predictions)} predictions for {len(test_dataset)} samples"
)
test_dataset.set_values("predictions", predictions)



In [None]:
# 7 Evaluate your model
# Compare the "predictions" field with "ground_truth"; eval_key names this evaluation run.
results = test_dataset.evaluate_classifications(
    "predictions",
    gt_field="ground_truth",
    eval_key="eval",
)
results.print_report()



In [None]:
# Plot the confusion matrix for the evaluation computed in the previous cell.
plot = results.plot_confusion_matrix()
plot.show()


In [None]:
# 8 Visualize results in the App
# auto=False: don't display the App automatically here; it is opened explicitly later.
session = fo.launch_app(test_dataset, auto=False)



In [None]:
# Open the App session launched above in a browser tab.
session.open_tab()


In [None]:
# Optional: block execution until the App is closed.
# session.wait() blocks the kernel until the App closes, which would stall
# "Restart & Run All"; set WAIT_FOR_APP = True only when that is wanted.
WAIT_FOR_APP = False
if WAIT_FOR_APP:
    session.wait()

In [None]:
# Show the host's network interfaces and addresses (presumably to reach the App,
# launched with auto=False, from another machine).
# NOTE(review): leftover debugging; the output exposes local network details —
# clear it before sharing the notebook.
!ip a