In [None]:
import torch

import composer
from composer.datasets import coco_mmdet
from composer.models import composer_yolox
from torch.utils.data import DataLoader
from composer.datasets.coco_mmdet import mmdet_collate, mmdet_get_num_samples
from composer.core.data_spec import DataSpec
from composer.loggers import InMemoryLogger, LogLevel, WandBLogger



import logging
import random
import sys

import numpy as np

# Disable all logging calls (every level up to sys.maxsize) to keep notebook output readable.
logging.disable(sys.maxsize)

# Seed every RNG the training stack may draw from (Python, NumPy, PyTorch) for replicability.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);  # trailing ';' suppresses the returned Generator's repr in the cell output

In [None]:
# Both COCO splits live under one root, resolved relative to the kernel's working directory.
COCO_ROOT = '../../data/coco'

train_dataset = coco_mmdet(path=COCO_ROOT, split='train')
val_dataset = coco_mmdet(path=COCO_ROOT, split='val')

In [None]:
# Composer-wrapped YOLOX model; the variant name selects the architecture size.
YOLOX_VARIANT = 'yolox-s'
model = composer_yolox(model_name=YOLOX_VARIANT)

In [None]:
batch_size = 32

# Settings shared by the training and validation loaders.
common_loader_kwargs = dict(batch_size=batch_size, collate_fn=mmdet_collate, num_workers=8)

# Only the training loader shuffles and drops the final incomplete batch.
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **common_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **common_loader_kwargs)

In [None]:
# Diagnostic: are worker processes kept alive between epochs? This reports False
# unless the DataLoader was constructed with persistent_workers=True.
workers_persist = train_loader.persistent_workers
workers_persist

In [None]:
optimizer = composer.optim.DecoupledSGDW(
    model.parameters(),  # Model parameters to update
    lr=0.01,  # Peak learning rate
    momentum=0.9,
    # Decoupled weight decay is not multiplied by the LR the way coupled (torch.optim.SGD)
    # weight decay is, so this value is not directly comparable to a coupled-SGD recipe.
    # NOTE(review): 5e-4 is the usual coupled-SGD YOLOX setting; if it was copied from such a
    # recipe, the decoupled equivalent is roughly lr * 5e-4 — confirm the intended strength.
    weight_decay=5e-4,
    nesterov=True,  # Use Nesterov momentum
)

In [None]:
# Warm up over the first 30 epochs, then cosine-anneal the learning rate
# (t_max and alpha_f are left at their library defaults).
lr_scheduler = composer.optim.CosineAnnealingWithWarmupScheduler(t_warmup="30ep")

In [None]:
train_epochs = "300ep"  # Total training duration: 300 epochs

trainer = composer.trainer.Trainer(
    model=model,
    # Batches come from mmdet_collate, so wrap each loader in a DataSpec that tells the
    # Trainer to count samples with mmdet_get_num_samples.
    train_dataloader=DataSpec(train_loader, get_num_samples_in_batch=mmdet_get_num_samples),
    eval_dataloader=DataSpec(val_loader, get_num_samples_in_batch=mmdet_get_num_samples),
    max_duration=train_epochs,
    optimizers=optimizer,
    schedulers=lr_scheduler,
    # For a quick smoke run, add `train_subset_num_batches=10` to cap batches per epoch.
    device="gpu" if torch.cuda.is_available() else "cpu",
    precision='fp32',  # currently, simOTA matcher will not run with AMP
    grad_accum=1,
    loggers=[InMemoryLogger(log_level=LogLevel.BATCH), WandBLogger(project='yolox-test')],
)


In [None]:
# Run training for `max_duration`; the Trainer also evaluates on the eval_dataloader
# it was constructed with (at its configured/default eval interval).
trainer.fit()