In [12]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torchvision import models

### Data Prep

In [13]:
# Root directory of the MS-COCO dataset. Override with the COCO_DATA_DIR
# environment variable; the default is the original author's local path.
# The trailing slash is kept in case COCODataModule builds paths by plain
# string concatenation (NOTE(review): confirm against the data module).
img_data_dir = os.environ.get(
    "COCO_DATA_DIR", "/media/curttigges/project-files/datasets/ms-coco/"
)

#### Load the COCO data module

In [14]:
from data.coco_data_module import COCODataModule

BATCH_SIZE = 64

# Data module over the COCO images under `img_data_dir`; this same instance is
# later passed to the training module and to `trainer.fit`.
dm = COCODataModule(img_data_dir, batch_size=BATCH_SIZE, num_workers=12)

In [15]:
# Prepare the data module. Judging by the output below, this loads two
# annotation files (presumably train and val) and exposes `dm.train_set`.
# NOTE(review): `trainer.fit(pl_model, dm)` appears to run setup again
# (the annotations are reloaded in its output), so this call mainly serves
# the interactive inspection in the next cells.
dm.setup()

loading annotations into memory...
Done (t=9.26s)
creating index...
index created!
loading annotations into memory...
Done (t=6.35s)
creating index...
index created!


In [16]:
# Number of examples in the training split.
TRAIN_SIZE = len(dm.train_set)

In [17]:
# Display the training-set size (length of dm.train_set).
TRAIN_SIZE

82783

### Models

In [18]:
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

from torch.optim.lr_scheduler import OneCycleLR
from torchmetrics.functional import accuracy, precision
import torchmetrics.functional as tf
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import TQDMProgressBar

In [19]:
class ResNetMultiLabel(nn.Module):
    """Multi-label classifier built on a ResNet-style backbone.

    The backbone's final ``fc`` layer is replaced, in place on the object that
    is passed in, by ``Dropout`` followed by a ``Linear`` layer with one output
    per class. ``forward`` returns raw logits (no sigmoid is applied).

    Args:
        model: Backbone with an ``fc`` attribute exposing ``in_features``
            (e.g. a torchvision ResNet). It is mutated: ``model.fc`` is
            overwritten.
        n_classes: Number of independent labels to predict.
        dropout: Dropout probability applied before the new linear head.
            Defaults to 0.2, the value previously hard-coded.
    """

    def __init__(self, model, n_classes, dropout=0.2):
        super().__init__()
        resnet = model
        # Read in_features from the original head before overwriting it.
        resnet.fc = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(in_features=resnet.fc.in_features, out_features=n_classes),
        )
        self.backbone = resnet

    def forward(self, x):
        """Run the backbone and return one raw logit per class."""
        return self.backbone(x)

In [20]:
class ResNetMultiTrainModule(pl.LightningModule):
    """LightningModule that trains a ``ResNetMultiLabel`` for multi-label classification.

    Training uses ``binary_cross_entropy_with_logits`` on the raw model outputs.
    Validation/test additionally log thresholded accuracy, precision, recall, F1
    and a retrieval average precision.

    Args:
        model: Backbone handed to ``ResNetMultiLabel`` (its ``fc`` head is replaced).
        data: Data module. Only ``data.train_dataloader()`` is used here, to size
            the OneCycle LR schedule.
        model_desc: Descriptive tag stored in ``hparams``; not otherwise used here.
        batch_size: Stored in ``hparams``; not used for computation in this class.
        learning_rate: Adam learning rate and OneCycleLR peak learning rate.
        momentum: Stored in ``hparams``; unused by the Adam optimizer configured below.
        n_classes: Number of labels predicted by the head.
        thresh: Probability threshold used by the accuracy/precision/recall/F1 metrics.
    """

    def __init__(self, model, data, model_desc, batch_size, learning_rate, momentum, n_classes, thresh=0.5):
        super().__init__()

        # Key parameters. `model` is an nn.Module whose weights are already
        # checkpointed through `self.model`, and `data` is a data module; storing
        # either in hparams triggers Lightning warnings (an nn.Module warning and
        # a "cannot be pickled" warning appear in the training output). Exclude them.
        self.save_hyperparameters(ignore=["model", "data"])
        self.data = data
        self.model = ResNetMultiLabel(model, n_classes)

    def forward(self, x):
        """Return raw per-class logits from the wrapped model."""
        return self.model(x)

    def evaluate(self, batch, stage=None):
        """Compute the loss and multi-label metrics for one ``(x, y)`` batch.

        When ``stage`` is given, logs ``{stage}_loss``, ``{stage}_acc``,
        ``{stage}_prec``, ``{stage}_recall``, ``{stage}_f1_score`` and
        ``{stage}_rmap`` to the progress bar; otherwise nothing is logged.
        """
        x, y = batch
        logits = self(x)
        loss = F.binary_cross_entropy_with_logits(logits, y.type(torch.float))

        # The threshold metrics compare predictions against `thresh`, which is a
        # probability, while the model emits logits. Convert explicitly so the
        # result does not depend on whether torchmetrics auto-applies a sigmoid.
        probs = torch.sigmoid(logits)
        target = y.type(torch.int)
        thresh = self.hparams.thresh
        acc = accuracy(probs, target, threshold=thresh)
        prec = precision(probs, target, threshold=thresh)
        recall = tf.recall(probs, target, threshold=thresh)
        f1_score = tf.f1_score(probs, target, threshold=thresh)
        # Ranking-based, so logits are used directly (sigmoid is monotonic, and
        # under fp16 it can saturate and introduce ties).
        # NOTE(review): this presumably scores the whole flattened batch as one
        # query, i.e. it is not a per-class mAP — confirm against torchmetrics.
        rmap = tf.retrieval_average_precision(logits, target)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)
            self.log(f"{stage}_prec", prec, prog_bar=True)
            self.log(f"{stage}_recall", recall, prog_bar=True)
            self.log(f"{stage}_f1_score", f1_score, prog_bar=True)
            self.log(f"{stage}_rmap", rmap, prog_bar=True)

    def training_step(self, batch, batch_idx):
        """Return the BCE-with-logits loss for one batch and log it as ``train_loss``."""
        x, y = batch
        logits = self(x)
        loss = F.binary_cross_entropy_with_logits(logits, y.type(torch.float))
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        """Adam plus a OneCycleLR schedule (cosine annealing) stepped every batch."""
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.learning_rate,
            betas=(0.9, 0.999),
        )
        # OneCycleLR needs the total step count up front (epochs * steps_per_epoch)
        # and raises if stepped beyond it.
        scheduler = OneCycleLR(
            optimizer,
            self.hparams.learning_rate,
            epochs=self.trainer.max_epochs,
            steps_per_epoch=len(self.data.train_dataloader()),
            anneal_strategy="cos",
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

In [21]:
# ResNet-101 with torchvision's pretrained weights, fine-tuned for 80 labels.
backbone = models.resnet101(pretrained=True)
pl_model = ResNetMultiTrainModule(
    backbone,
    data=dm,
    model_desc="resnet101",
    batch_size=BATCH_SIZE,
    learning_rate=0.001,
    momentum=0.9,
    n_classes=80,
)

  f"Attribute {k!r} is an instance of `nn.Module` and is already saved during checkpointing."


In [11]:
# Track the run in Weights & Biases; log="all" records gradients and parameters.
wandb_logger = WandbLogger(project="resnet-coco")
wandb_logger.watch(pl_model, log="all")

# Single-GPU, 16-bit mixed-precision run with the simple profiler enabled.
progress_bar = TQDMProgressBar(refresh_rate=10)
trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    precision=16,
    max_epochs=45,
    logger=wandb_logger,
    callbacks=[progress_bar],
    profiler="simple",
)
trainer.fit(pl_model, dm)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcurt-tigges[0m ([33mascendant[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")


loading annotations into memory...
Done (t=6.60s)
creating index...
index created!
loading annotations into memory...
Done (t=1.71s)
creating index...
index created!


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params
-------------------------------------------
0 | model | ResNetMultiLabel | 42.7 M
-------------------------------------------
42.7 M    Trainable params
0         Non-trainable params
42.7 M    Total params
85.328    Total estimated model params size (MB)


Epoch 43:  49%|████▉     | 950/1927 [13:05:11<13:27:30, 49.59s/it, loss=0.00823, v_num=nare, val_loss=0.0735, val_acc=0.984, val_prec=0.866, val_recall=0.672, val_f1_score=0.756, val_rmap=0.832]   

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
