## Installation

In [None]:
# Install Mantisshrimp package
!pip install -q git+git://github.com/lgvaz/mantisshrimp.git

  Building wheel for mantisshrimp (setup.py) ... done


In [None]:
# Install the COCO Python API (it lives in the repo's PythonAPI subdirectory)
# and upgrade albumentations, which provides the augmentations used below.
!pip install -q 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
!pip install -q albumentations --upgrade

# Install fastai and/or PyTorch Lightning.
# NOTE(review): the training code visible in this notebook uses only the
# Lightning engine; confirm fastai2 is still needed.
!pip install -q fastai2
!pip install -q pytorch-lightning

In [None]:
# Weights & Biases client: used below for the Lightning logger and the hyperparameter sweep.
!pip install -q wandb

In [None]:
# Authenticate with Weights & Biases interactively — keep the API key out of the notebook.
! wandb login

In [None]:
# Show the attached GPU(s); the Trainer below is configured with gpus=1.
!nvidia-smi

## Load the dataset

In [None]:
from mantisshrimp.imports import *
from mantisshrimp import *
import albumentations as A

In [None]:
import wandb

In [None]:
# Fetch the pets dataset via mantisshrimp's built-in helper (downloads on first run);
# `path` is handed to the parser below.
path = datasets.pets.load()

HBox(children=(FloatProgress(value=0.0, max=791918971.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=19173078.0), HTML(value='')))




In [None]:
# Build the pets annotation parser for the files under `path`.
parser = datasets.pets.parser(path)

In [None]:
# Class names for the pets dataset; len(CLASSES) sets the model's num_classes below.
CLASSES = datasets.pets.CLASSES

In [None]:
# Random 80% / 20% split of the parsed records into train / valid.
# NOTE(review): no seed is passed, so the split changes on every fresh run —
# pass a seed if RandomSplitter supports one, to make results comparable.
data_splitter = RandomSplitter([.8, .2])

In [None]:
# Parse the annotations into records and split them with data_splitter.
train_records, valid_records = parser.parse(data_splitter)

HBox(children=(FloatProgress(value=0.0, max=3686.0), HTML(value='')))




In [None]:
# Normalization statistics (mean, std) unpacked from mantisshrimp's IMAGENET_STATS;
# the train and valid pipelines both normalize with them.
imagenet_mean, imagenet_std = IMAGENET_STATS

# Training pipeline: resize, random crop / flip / affine, colour jitter, blur, normalize.
train_tfms = AlbuTransform([
    A.LongestMaxSize(384),
    A.RandomSizedBBoxSafeCrop(320, 320, p=0.3),  # p=0.3: crop only part of the time
    A.HorizontalFlip(),
    A.ShiftScaleRotate(rotate_limit=20),
    A.RGBShift(always_apply=True),
    A.RandomBrightnessContrast(),
    A.Blur(blur_limit=(1, 3)),
    A.Normalize(mean=imagenet_mean, std=imagenet_std),
])

In [None]:
# Validation pipeline: same resize and normalization as training, no random augmentation.
valid_tfms = AlbuTransform([
    A.LongestMaxSize(384),
    A.Normalize(mean=imagenet_mean, std=imagenet_std),
])

In [None]:
# Pair each record split with its own transform pipeline.
valid_ds = Dataset(valid_records, valid_tfms)
train_ds = Dataset(train_records, train_tfms)

In [None]:
# Faster R-CNN sized to CLASSES; the first run downloads the COCO-pretrained
# ResNet-50 FPN weights (see the output below).
model = MantisFasterRCNN(num_classes=len(CLASSES))

Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /root/.cache/torch/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth


HBox(children=(FloatProgress(value=0.0, max=167502836.0), HTML(value='')))




In [None]:
# Shuffle only the training loader; validation order stays fixed.
train_dl = model.dataloader(train_ds, shuffle=True, batch_size=16, num_workers=4)
valid_dl = model.dataloader(valid_ds, shuffle=False, batch_size=16, num_workers=4)

## Lightning Code

In [None]:
# import lightning engine provided by the mantisshrimp modules
from mantisshrimp.engines.lightning import *
from pytorch_lightning.loggers import WandbLogger

In [None]:
# Hyperparameter search space for the W&B sweep.
# 'method' is the search strategy ('grid' or 'random'); each parameter lists its candidate values.
# NOTE(review): 'program' names a script for the `wandb agent` CLI; the in-notebook
# wandb.agent(...) call passes train() directly — confirm this key is still wanted.
sweep_config = dict(
    method='random',
    parameters=dict(
        learning_rate=dict(values=[0.002, 0.001]),
        optimizer=dict(values=['adam', 'sgd']),
    ),
    program='train',
)

In [None]:
# Start the W&B run first, then hand that same run object to Lightning's logger.
# Without `experiment=`, a WandbLogger() created with no project may start its own
# run when it first logs, so metrics could land outside "mantis_demo1".
wandb_run = wandb.init(project="mantis_demo1")
wandb_logger = WandbLogger(experiment=wandb_run)
wandb_run

W&B Run: https://app.wandb.ai/oke-aditya/mantis_demo1/runs/13caseo8

In [None]:
class LightModel(RCNNLightningAdapter):
    """Lightning adapter for the mantisshrimp R-CNN model, with a configurable optimizer.

    The optimizer and learning rate come from the active W&B run's config, so the
    `learning_rate` / `optimizer` values sampled by the sweep actually take effect.
    """

    def configure_optimizers(self):
        """Return the optimizer for `self.parameters()`.

        Reads `learning_rate` and `optimizer` ('adam' or 'sgd') from `wandb.config`.
        Missing keys, or no active W&B run, fall back to SGD(lr=2e-4, momentum=0.9).

        Raises:
            ValueError: if the config names an optimizer other than 'adam' or 'sgd'.
        """
        import wandb
        from torch.optim import SGD, Adam

        # Outside any W&B run there is no config to read; use the defaults.
        config = wandb.config if wandb.run is not None else {}
        lr = config.get("learning_rate", 2e-4)
        name = str(config.get("optimizer", "sgd")).lower()
        if name == "adam":
            return Adam(self.parameters(), lr=lr)
        if name == "sgd":
            return SGD(self.parameters(), lr=lr, momentum=0.9)
        raise ValueError(f"Unsupported optimizer {name!r}; expected 'adam' or 'sgd'")

In [None]:
# Lightning trainer: single GPU, 3 epochs, metrics sent to W&B through wandb_logger.
trainer = Trainer(gpus=1, max_epochs=3, logger=wandb_logger)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


In [None]:
# Wrap the Faster R-CNN model in the LightModel Lightning adapter defined above.
light_model = LightModel(model)

In [None]:
def train(max_epochs=3, gpus=1):
    """Run one training trial, logged to its own W&B run.

    Used for a single baseline run and as the function a W&B sweep agent calls once
    per trial. Each call:
      * starts a W&B run, so the agent's sampled hyperparameters are in `wandb.config`;
      * builds a fresh model, Lightning module and Trainer, so trials do not continue
        from each other's weights or reuse an already-finished trainer;
      * fits on the notebook-level `train_dl` / `valid_dl`.

    Args:
        max_epochs: epochs per trial (3 matches the notebook's `trainer` cell).
        gpus: number of GPUs passed to the Lightning Trainer.
    """
    run = wandb.init(project="mantis_demo1")
    trial_model = LightModel(MantisFasterRCNN(num_classes=len(CLASSES)))
    trial_trainer = Trainer(
        max_epochs=max_epochs,
        gpus=gpus,
        # Log into the run started above instead of letting the logger create one.
        logger=WandbLogger(experiment=run),
    )
    trial_trainer.fit(trial_model, train_dl, valid_dl)

In [None]:
# Run one baseline training when executed directly (a notebook kernel's module name
# is '__main__'; comparing against 'main' would never match and silently skip training).
# NOTE(review): this trains a full model before the sweep below — expensive cell.
if __name__ == '__main__':
    train()

In [None]:
# Register the sweep with W&B; the returned id is what the agent cell consumes.
sweep_id = wandb.sweep(sweep_config, project='mantis_demo1')

Create sweep with ID: wm11k1wa
Sweep URL: https://app.wandb.ai/oke-aditya/mantis_demo1/sweeps/wm11k1wa


In [None]:
# Launch a sweep agent in this kernel; it calls train() once per trial.
# Random search keeps sampling until stopped, so cap the number of trials.
# NOTE(review): 4 random draws need not cover every (learning_rate, optimizer)
# pair; use method='grid' in sweep_config for exhaustive coverage.
wandb.agent(sweep_id, function=train, count=4)