In [None]:
!pip install torchsummary
!pip install torchmetrics
!pip install torch_lr_finder
!pip install pytorch_lightning

In [None]:
# Kaggle-specific absolute path: make the project directory the working directory.
# Presumably this is where the `Utilities` package imported below lives, so the
# imports resolve from here — adjust when running outside Kaggle.
%cd /kaggle/working/S13

In [None]:
# Automatically re-import edited modules (e.g. the project's Utilities package) before
# each cell runs, so changes to the .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.tuner import Tuner
from torchsummary import summary

from Utilities import config
from Utilities.callbacks import (
    CheckClassAccuracyCallback,
    MAPCallback,
    PlotTestExamplesCallback,
)
from Utilities.dataset import YOLODataModule
from Utilities.model import YOLOv3

In [None]:
# Seed Python/NumPy/torch RNGs (workers=True also seeds dataloader workers) so weight
# initialisation and data shuffling are reproducible across "Restart & Run All".
SEED = 42
pl.seed_everything(SEED, workers=True)

model = YOLOv3(num_classes=config.NUM_CLASSES)

# torchsummary creates its dummy input on `device` (default "cuda"), so pass the same
# device the model was moved to; otherwise a CPU run fails with a device mismatch.
summary(
    model.to(config.DEVICE),
    input_size=(3, config.IMAGE_SIZE, config.IMAGE_SIZE),
    device=config.DEVICE,
)

In [None]:
# Train/test CSV files live under the dataset root configured in Utilities.config.
# (YOLODataModule is imported in the top-level imports cell.)
TRAIN_CSV = config.DATASET + "/train.csv"
TEST_CSV = config.DATASET + "/test.csv"

data_module = YOLODataModule(
    train_csv_path=TRAIN_CSV,
    test_csv_path=TEST_CSV,
)

In [None]:
# Total training epochs. MAPCallback uses the same value, presumably so mAP is computed
# only once, at the end of training — confirm against Utilities.callbacks.MAPCallback.
MAX_EPOCHS = 40

trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator=config.DEVICE,
    callbacks=[
        # No `monitor` is given, so this keeps the checkpoint of the most recent epoch.
        ModelCheckpoint(
            dirpath=config.CHECKPOINT_PATH,
            verbose=True
        ),
        PlotTestExamplesCallback(every_n_epochs=5),
        CheckClassAccuracyCallback(train_every_n_epochs=8, test_every_n_epochs=4),
        MAPCallback(every_n_epochs=MAX_EPOCHS),
        LearningRateMonitor(logging_interval='step', log_momentum=True)
    ],
    default_root_dir='Store/',
    # NOTE(review): fp16 AMP is a GPU feature (Lightning offers 'bf16-mixed' on CPU);
    # this assumes config.DEVICE is a CUDA device.
    precision='16-mixed'
)

In [None]:
# Number of learning rates the finder tries (one training step each, possibly stopping
# early if the loss diverges). This is a step count, unrelated to the trainer's epoch
# budget, so it is not tied to `max_epochs`; 100 is Lightning's default.
LR_FINDER_STEPS = 100

tuner = Tuner(trainer=trainer)

# Sweep the learning rate from min_lr to max_lr and record the loss at each step.
lr_finder = tuner.lr_find(
    model,
    datamodule=data_module,
    min_lr=1e-4,
    max_lr=1,
    num_training=LR_FINDER_STEPS,
)

In [None]:
# Lightning's automatic learning-rate pick from the finder results
# (may be None if the finder could not compute one).
suggested_lr = lr_finder.suggestion()
print(f"{suggested_lr=}")

# Loss-vs-LR curve with the suggested point marked, to sanity-check the pick by eye.
fig = lr_finder.plot(suggest=True)
fig.show()

In [None]:
# Learning rate used for training. The value below (≈ 10**-2.4, a point on the finder's
# log-spaced grid) was presumably picked from an earlier LR-finder run; it is kept as an
# explicit override so training does not silently depend on this run's finder output.
# Set MANUAL_LR = None to train with this run's `suggested_lr` instead.
# NOTE(review): assumes YOLOv3 reads `best_lr` when configuring its optimizer/scheduler.
MANUAL_LR = 0.003981071705534973

chosen_lr = MANUAL_LR if MANUAL_LR is not None else suggested_lr
if chosen_lr is None:
    raise ValueError("No learning rate available: set MANUAL_LR or re-run the LR finder.")
model.best_lr = chosen_lr

In [None]:
# Train the model; the data module is passed by keyword, matching the lr_find call above.
trainer.fit(model, datamodule=data_module)