### Data Prep

In [1]:
import os

# Root directory of the MS-COCO dataset handed to COCODataModule below.
# Set the COCO_DATA_DIR environment variable to point at a different location;
# when it is unset the original machine-specific path is used unchanged.
img_data_dir = os.environ.get(
    "COCO_DATA_DIR", "/media/curttigges/project-files/datasets/ms-coco/")

### Models

In [2]:
import torch
import math
import timm
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
import pytorch_lightning as pl
import wandb

from torch.optim.lr_scheduler import OneCycleLR
from torchmetrics.functional import accuracy, precision
import torchmetrics.functional as tf
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning.callbacks import StochasticWeightAveraging
from positional_encodings.torch_encodings import PositionalEncodingPermute2D, Summer

import data.coco_cat as cc
from data.coco_data_module import COCODataModule
from data.cutmix import CutMixCriterion
from q2l_labeller.models.simple_asymmetric_loss import AsymmetricLoss

In [3]:
class ResNetBackbone(nn.Module):
    """Wraps a ResNet-style model and exposes its final feature map.

    The wrapped model must provide ``conv1``, ``bn1``, ``relu``, ``maxpool``,
    ``layer1`` .. ``layer4`` and ``fc`` attributes. The ``fc`` attribute is
    deleted from the model that is passed in (the caller's object is mutated).
    """

    def __init__(self, model):
        super().__init__()
        self.resnet = model
        # The classification head is not used here, so drop it.
        delattr(self.resnet, "fc")

    def forward(self, x):
        """Run the stem and the four residual stages; return layer4's output."""
        net = self.resnet

        # Stem: conv -> batch norm -> activation -> max pooling.
        features = net.maxpool(net.relu(net.bn1(net.conv1(x))))

        # Residual stages, applied in order.
        for stage in (net.layer1, net.layer2, net.layer3, net.layer4):
            features = stage(features)

        return features

In [4]:
class TimmBackbone(nn.Module):
    """Feature extractor backed by a pretrained timm model.

    Args:
        model_name: name forwarded to ``timm.create_model``.
    """

    def __init__(self, model_name):
        super().__init__()

        # With no classifier (num_classes=0) and no global pooling
        # (global_pool=''), the model yields unpooled, unclassified features.
        create_kwargs = dict(pretrained=True, num_classes=0, global_pool='')
        self.model = timm.create_model(model_name, **create_kwargs)

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

In [5]:
class Query2Label(nn.Module):
    """Multi-label classifier: timm backbone + transformer with per-class queries.

    The backbone's feature map is projected to ``hidden_dim`` channels with a
    1x1 convolution, flattened into a sequence and fed to an ``nn.Transformer``
    whose decoder input is one learned query vector per class. The decoder
    outputs for all classes are concatenated and mapped to ``num_classes``
    logits by a single linear layer.

    Args:
        model: timm model name, forwarded to ``TimmBackbone``.
        conv_out: number of channels in the backbone's output feature map.
        num_classes: number of labels (and of learned label queries).
        hidden_dim: channel / embedding size used by the transformer.
        nheads: number of attention heads.
        encoder_layers: number of transformer encoder layers.
        decoder_layers: number of transformer decoder layers.
        use_pos_encoding: when True, 2D positional encodings are added to the
            projected feature map before it enters the transformer.
    """

    def __init__(
        self, model, conv_out, num_classes, hidden_dim=256, nheads=8, 
        encoder_layers=6, decoder_layers=6, use_pos_encoding=False):
        
        super().__init__()

        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        self.use_pos_encoding = use_pos_encoding

        # backbone output is expected to be a [N x conv_out x H x W] feature map
        self.backbone = TimmBackbone(model)
        # 1x1 convolution: conv_out channels -> hidden_dim channels
        self.conv = nn.Conv2d(conv_out, hidden_dim, 1)
        self.transformer = nn.Transformer(
            hidden_dim, nheads, encoder_layers, decoder_layers)

        # prediction head: consumes the concatenated per-class decoder outputs
        self.classifier = nn.Linear(num_classes * hidden_dim, num_classes)

        # label parameters (one learned query vector per class)
        # TODO: Rename 
        self.query_pos = nn.Parameter(torch.rand(1, num_classes, hidden_dim))

    def forward(self, x):
        """Return class logits of shape [N x num_classes] for an image batch ``x``."""
        # produces output of shape [N x C x H x W]
        out = self.backbone(x)
        
        # project from conv_out to hidden_dim feature planes for the transformer
        h = self.conv(out)
        B, C, _, _ = h.shape

        # add position encodings
        if self.use_pos_encoding:
            # returns the encoding object
            pos_encoder = PositionalEncodingPermute2D(C)

            # returns the summing object
            encoding_adder = Summer(pos_encoder)

            # Encode the projected feature map. (Previously the raw input
            # image `x` was passed here, which discarded the backbone features
            # and produced a channel mismatch in the transformer.)
            h = encoding_adder(h)

        # convert h from [N x C x H x W] to [H*W x N x C] (N=batch size)
        # this corresponds to the [SIZE x BATCH_SIZE x EMBED_DIM] dimensions 
        # that the transformer expects
        h = h.flatten(2).permute(2, 0, 1)
        
        # image feature vector "h" is sent in after transformation above; we 
        # also convert query_pos from [1 x TARGET x (hidden)EMBED_SIZE] to 
        # [TARGET x BATCH_SIZE x (hidden)EMBED_SIZE]
        query_pos = self.query_pos.repeat(B, 1, 1)
        query_pos = query_pos.transpose(0, 1)
        h = self.transformer(h, query_pos).transpose(0, 1)
        
        # output from transformer is of dim [TARGET x BATCH_SIZE x EMBED_SIZE];
        # it is transposed to [BATCH_SIZE x TARGET x EMBED_SIZE] above and then
        # the per-class embeddings are concatenated into a single vector of
        # length TARGET * EMBED_SIZE (they are NOT averaged), which the linear
        # head projects to class logits.
        h = torch.reshape(h,(B, self.num_classes * self.hidden_dim))

        return self.classifier(h)

In [6]:
class Query2LabelTrainModule(pl.LightningModule):
    """LightningModule that trains and evaluates a ``Query2Label`` model.

    Args:
        data: data module; its ``train_dataloader()`` length is used to size
            the one-cycle learning-rate schedule.
        backbone_desc: timm model name for the backbone.
        conv_out_dim: channel count of the backbone's output feature map.
        hidden_dim: transformer embedding size.
        num_encoders: number of transformer encoder layers.
        num_decoders: number of transformer decoder layers.
        num_heads: number of attention heads.
        batch_size: stored as a hyperparameter only; not read by this class.
        image_dim: stored as a hyperparameter only; not read by this class.
        learning_rate: AdamW learning rate and the peak LR of the schedule.
        momentum: stored as a hyperparameter only; the AdamW optimizer
            configured here does not read it.
        weight_decay: AdamW weight decay.
        n_classes: number of labels.
        thresh: probability cutoff used by the thresholded metrics.
        use_cutmix: when True, ``training_step`` uses the CutMix-wrapped
            criterion on the batch targets as delivered by the data module.
        use_pos_encoding: forwarded to ``Query2Label``.
        loss: ``"BCE"`` (``nn.BCEWithLogitsLoss``) or ``"ASL"``
            (``AsymmetricLoss``).

    Raises:
        ValueError: if ``loss`` is neither ``"BCE"`` nor ``"ASL"``.
    """

    def __init__(
        self, data, backbone_desc, conv_out_dim, hidden_dim, num_encoders, 
        num_decoders, num_heads, batch_size, image_dim, learning_rate, 
        momentum, weight_decay, n_classes, thresh=0.5, use_cutmix=False,
        use_pos_encoding=False, loss="BCE"):
        super().__init__()

        # Key parameters
        self.save_hyperparameters(ignore=['model','data'])
        self.data = data
        self.model = Query2Label(
            model=backbone_desc, conv_out=conv_out_dim, num_classes=n_classes, 
            hidden_dim=hidden_dim, nheads=num_heads, encoder_layers=num_encoders, 
            decoder_layers=num_decoders, use_pos_encoding=use_pos_encoding)
        if loss=="BCE":
            self.base_criterion = nn.BCEWithLogitsLoss()
        elif loss=="ASL":
            self.base_criterion = AsymmetricLoss(gamma_neg=1, gamma_pos=1) 
        else:
            # Fail loudly instead of hitting an AttributeError a few lines later.
            raise ValueError(
                f"Unsupported loss {loss!r}; expected 'BCE' or 'ASL'.")
        
        self.criterion = CutMixCriterion(self.base_criterion)

    def forward(self, x):
        """Return the model's class logits for the image batch ``x``."""
        return self.model(x)

    def evaluate(self, batch, stage=None):
        """Compute loss and metrics for one batch and log them when ``stage`` is set.

        Args:
            batch: ``(x, y)`` pair of inputs and multi-hot targets.
            stage: prefix for the logged metric names (e.g. ``"val"``); nothing
                is logged when it is falsy.
        """
        x, y = batch
        y_hat = self(x)
        loss = self.base_criterion(y_hat, y.type(torch.float))

        target = y.type(torch.int)

        # The model outputs raw logits, while `thresh` is a probability cutoff.
        # Convert to probabilities so that thresh=0.5 really means p >= 0.5
        # (thresholding logits at 0.5 corresponds to p >= ~0.62).
        probs = torch.sigmoid(y_hat.float())

        # Ranking-based metric: unaffected by the monotonic sigmoid, so the
        # logits are used directly.
        rmap = tf.retrieval_average_precision(y_hat, target)

        thresh_kwargs = dict(threshold=self.hparams.thresh, multiclass=False)
        macro_kwargs = dict(
            average='macro', num_classes=self.hparams.n_classes, **thresh_kwargs)

        category_prec = precision(probs, target, **macro_kwargs)
        category_recall = tf.recall(probs, target, **macro_kwargs)
        category_f1 = tf.f1_score(probs, target, **macro_kwargs)

        overall_prec = precision(probs, target, **thresh_kwargs)
        overall_recall = tf.recall(probs, target, **thresh_kwargs)
        overall_f1 = tf.f1_score(probs, target, **thresh_kwargs)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_rmap", rmap, prog_bar=True, on_step=False, on_epoch=True)

            self.log(f"{stage}_cat_prec", category_prec, prog_bar=True)
            self.log(f"{stage}_cat_recall", category_recall, prog_bar=True)
            self.log(f"{stage}_cat_f1", category_f1, prog_bar=True)

            self.log(f"{stage}_ovr_prec", overall_prec, prog_bar=True)
            self.log(f"{stage}_ovr_recall", overall_recall, prog_bar=True)
            self.log(f"{stage}_ovr_f1", overall_f1, prog_bar=True)

            # Disabled: log prediction examples to wandb.
            #
            # pred = self.model(x)
            # pred_keys = pred[0].sigmoid().tolist()
            # pred_keys = [0 if p < self.hparams.thresh else 1 for p in pred_keys]
            #
            # mapper = cc.COCOCategorizer()
            # pred_lbl = mapper.get_labels(pred_keys)
            #
            # try:
            #     self.logger.experiment.log({"val_pred_examples": [wandb.Image(x[0], caption=pred_lbl)]})
            # except AttributeError:
            #     pass
    
    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        x, y = batch
        y_hat = self(x)
        if self.hparams.use_cutmix:
            # The CutMix criterion receives the targets exactly as the data
            # module delivers them.
            loss = self.criterion(y_hat, y)
        else:
            loss = self.base_criterion(y_hat, y.type(torch.float))
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch; metrics are logged with a ``val_`` prefix."""
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch; metrics are logged with a ``test_`` prefix."""
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        """Return AdamW plus a per-step one-cycle (cosine) LR schedule."""
        optimizer = torch.optim.AdamW(
            self.parameters(), 
            lr=self.hparams.learning_rate,
            betas=(0.9,0.999),
            weight_decay=self.hparams.weight_decay)
        
        lr_scheduler_dict = {
            "scheduler":OneCycleLR(
                optimizer,
                self.hparams.learning_rate,
                epochs=self.trainer.max_epochs,
                # schedule length is derived from the training dataloader size
                steps_per_epoch=len(self.data.train_dataloader()),
                anneal_strategy='cos'
            ),
            # the scheduler is stepped after every optimizer step
            "interval":"step",
        }
        return {"optimizer":optimizer, "lr_scheduler":lr_scheduler_dict}

In [7]:
# Fix the global random seed so repeated runs start from the same state.
pl.seed_everything(42)

Global seed set to 42


42

In [8]:
# Hyperparameters for the run. The keys match the keyword arguments of
# Query2LabelTrainModule.__init__ (the "data" entry is added once the data
# module exists).
param_dict = dict(
    # model architecture
    backbone_desc="tf_efficientnet_b7_ns",
    conv_out_dim=2560,
    hidden_dim=256,
    num_encoders=6,
    num_decoders=6,
    num_heads=8,
    # data
    batch_size=16,
    image_dim=448,
    # optimisation
    learning_rate=0.0001,
    momentum=0.9,
    weight_decay=0.01,
    # task / training options
    n_classes=80,
    thresh=0.5,
    use_cutmix=True,
    use_pos_encoding=False,
    loss="BCE",
)

In [9]:
# Build the COCO data module from the shared hyperparameters, then register it
# in param_dict so it is handed to the training module as its `data` argument.
datamodule_kwargs = {
    "img_size": param_dict["image_dim"],
    "batch_size": param_dict["batch_size"],
    "num_workers": 24,
    "use_cutmix": param_dict["use_cutmix"],
    "cutmix_alpha": 1.0,
}
coco = COCODataModule(img_data_dir, **datamodule_kwargs)
param_dict["data"] = coco

In [10]:
# Instantiate the training module; every key of param_dict (including the
# "data" entry holding the data module) is passed as a keyword argument.
pl_model = Query2LabelTrainModule(**param_dict)

In [11]:
# Weights & Biases logging; watch() is called with log="all" on the module.
wandb_logger = WandbLogger(project="coco-labeling")
wandb_logger.watch(pl_model, log="all")

# Progress bar refreshed every 10 batches.
progress_bar = TQDMProgressBar(refresh_rate=10)

# Single-GPU, 16-bit precision training for 24 epochs with the simple profiler.
trainer_kwargs = dict(
    max_epochs=24,
    precision=16,
    accelerator='gpu',
    devices=1,
    logger=wandb_logger,
    profiler="simple",
    callbacks=[progress_bar],
)
trainer = pl.Trainer(**trainer_kwargs)
trainer.fit(pl_model, param_dict["data"])

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcurt-tigges[0m ([33mascendant[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


loading annotations into memory...
Done (t=6.54s)
creating index...
index created!
loading annotations into memory...
Done (t=4.07s)
creating index...
index created!


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type              | Params
-----------------------------------------------------
0 | model          | Query2Label       | 83.5 M
1 | base_criterion | BCEWithLogitsLoss | 0     
-----------------------------------------------------
83.5 M    Trainable params
0         Non-trainable params
83.5 M    Total params
166.931   Total estimated model params size (MB)


Epoch 3:   2%|▏         | 170/7706 [2:28:09<109:27:38, 52.29s/it, loss=0.0909, v_num=cb04, val_loss=0.0503, val_rmap=0.825, val_cat_prec=0.551, val_cat_recall=0.502, val_cat_f1=0.516, val_ovr_prec=0.916, val_ovr_recall=0.560, val_ovr_f1=0.692] 

wandb: Network error (ReadTimeout), entering retry loop.


Epoch 17:   4%|▎         | 280/7706 [14:26:04<382:49:33, 185.59s/it, loss=0.0565, v_num=cb04, val_loss=0.0438, val_rmap=0.878, val_cat_prec=0.695, val_cat_recall=0.673, val_cat_f1=0.674, val_ovr_prec=0.865, val_ovr_recall=0.761, val_ovr_f1=0.808]  

wandb: Network error (ReadTimeout), entering retry loop.


Epoch 18:   9%|▉         | 700/7706 [15:22:24<153:52:02, 79.06s/it, loss=0.0683, v_num=cb04, val_loss=0.0434, val_rmap=0.879, val_cat_prec=0.696, val_cat_recall=0.673, val_cat_f1=0.674, val_ovr_prec=0.867, val_ovr_recall=0.760, val_ovr_f1=0.808]   

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
