In [None]:
# Settings shared by the data-loading, model and trainer cells below.
image_size, max_epochs = 1024, 50  # input image size for data + model; upper bound on training epochs

In [None]:
import flash
from flash.core.data.utils import download_data
from flash.image import ObjectDetectionData, ObjectDetector

In [None]:
# Optional: uncomment to list the backbones ObjectDetector supports.
# ObjectDetector.available_backbones()

In [None]:
# 1. Load the single-class "pothole" dataset in Pascal VOC layout. For each
# split, the image folder and the annotation folder are the same directory.
datamodule = ObjectDetectionData.from_voc(
    ["pothole"],
    train_folder="./dashcam_public-5/train",
    train_ann_folder="./dashcam_public-5/train",
    val_folder="./dashcam_public-5/valid",
    val_ann_folder="./dashcam_public-5/valid",
    transform_kwargs={"image_size": image_size},
    batch_size=2,
    num_workers=5,
)

In [None]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# Stop fine-tuning once the monitored validation loss has failed to decrease
# for 3 consecutive checks; with min_delta=0.0 any decrease counts as progress.
# NOTE(review): the doubled "val_" prefix presumably matches the metric name
# logged by Flash's detection task -- confirm against the progress-bar names.
early_stop_callback = EarlyStopping(
    monitor="val_val_loss",
    mode="min",
    patience=3,
    min_delta=0.00,
    verbose=False,
)

In [None]:
import torch

# 2. Build the task: a YOLOv5 detector with the "medium" backbone, sized for
# the classes in `datamodule` and the same `image_size` the data was built with.
model = ObjectDetector(
    head="yolov5",
    backbone="medium",
    num_classes=datamodule.num_classes,
    image_size=image_size,
    optimizer=torch.optim.Adam,
    # Previously 1e-10: Adam's updates were so tiny that the weights effectively
    # never changed, so fine-tuning did nothing. 1e-3 is torch.optim.Adam's
    # default lr -- tune for this dataset if needed.
    learning_rate=1e-3,
)

# 3. Create the trainer and fine-tune the model. `early_stop_callback` (built in
# the previous cell) ends training early once "val_val_loss" stops improving;
# `max_epochs` remains the upper bound.
trainer = flash.Trainer(
    max_epochs=max_epochs,
    log_every_n_steps=10,
    callbacks=[early_stop_callback],
    gpus=1,
)
# "freeze": per Flash's fine-tuning strategies, the pretrained backbone stays
# frozen for the whole run.
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

In [None]:
# 4. Detect objects in a few images!
from glob import glob

# glob() returns paths in arbitrary, filesystem-dependent order; sort them so
# the order of `predictions` is reproducible across runs and machines.
test_files = sorted(glob("./dashcam_public-5/test/*.jpg"))
# Fail loudly here rather than with an obscure error inside the datamodule.
assert test_files, "no .jpg files found in ./dashcam_public-5/test"

datamodule_test = ObjectDetectionData.from_files(
    predict_files=test_files,
    transform_kwargs={"image_size": image_size},
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule_test)
predictions