### Data Prep

In [1]:
img_data_dir = "/media/curttigges/project-files/datasets/ms-coco/"

#### Data Preparation

In [2]:
from data.coco_data_module import COCODataModule



### Models

In [3]:
import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
import pytorch_lightning as pl
import wandb

from torch.optim.lr_scheduler import OneCycleLR
from torchmetrics.functional import accuracy, precision
import torchmetrics.functional as tf
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import TQDMProgressBar

import data.coco_cat as cc

In [4]:
class ResNetBackbone(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.resnet = model
        del self.resnet.fc

    def forward(self, x):
        
        out = self.resnet.conv1(x)
        out = self.resnet.bn1(out)
        out = self.resnet.relu(out)
        out = self.resnet.maxpool(out)

        out = self.resnet.layer1(out)
        out = self.resnet.layer2(out)
        out = self.resnet.layer3(out)
        out = self.resnet.layer4(out)
        
        return out

In [5]:
class TransformerDecoder(nn.Module):
    def __init__(self, embed_dims):
        super().__init__()

        self.label_embed = nn.Embedding()

        self.self_attn = nn.MultiheadAttention(embed_dims, )
        self.cross_attn = nn.MultiheadAttention()
        
        self.norm1 = nn.LayerNorm()
        self.norm2 = nn.LayerNorm()
        self.norm3 = nn.LayerNorm()
        
        self.fc1 = nn.Linear()
        self.fc2 = nn.Linear()

        self.act = nn.ReLU()

    def forward(self, x):
        q = self.label_embed()
        #self attention
        label_out = self.self_attn(q, q, q)

        out, _ = self.cross_attn(x, x, label_out)
        out = self.norm2(out)
        out = self.ff(out)

        return out

In [6]:
class Query2Label(nn.Module):
    def __init__(self, model, num_classes, hidden_dim=256, nheads=8, encoder_layers=6, decoder_layers=6):
        super().__init__()

        self.num_classes = num_classes
        self.hidden_dim = hidden_dim

        self.backbone = ResNetBackbone(model) # outputs HW x 
        self.conv = nn.Conv2d(2048, hidden_dim, 1)
        self.transformer = nn.Transformer(
            hidden_dim, nheads, encoder_layers, decoder_layers)

        # prediction heads, one extra class for predicting non-empty slots
        # note that in baseline DETR linear_bbox layer is 3-layer MLP
        #self.linear_class = nn.Linear(hidden_dim, num_classes)
        self.fc1 = nn.Linear(num_classes * hidden_dim, num_classes)

        # output positional encodings (object queries)
        self.query_pos = nn.Parameter(torch.rand(1, num_classes, hidden_dim))

        # spatial positional encodings
        # note that in baseline DETR we use sine positional encodings
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

    def forward(self, x):
        out = self.backbone(x)
        # TODO: Add 2D sine and cosine embeddings

        B, _, _, _ = x.shape
        
        # convert from 2048 to 256 feature planes for the transformer
        h = self.conv(out)

        # convert h from [N x C x H x W] to [H*W x N x C] (N=batch size)
        # this corresponds to the [SIZE x BATCH_SIZE x EMBED_DIM] dimensions 
        # that the transformer expects
        h = h.flatten(2).permute(2, 0, 1)
        '''
        # construct positional encodings
        H, W = h.shape[-2:]
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)

        # propagate through the transformer
        h = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1),
                             self.query_pos.unsqueeze(1)).transpose(0, 1)
        '''
        # image feature vector "h" is sent in after transformation above; we 
        # also convert query_pos from [1 x TARGET x (hidden)EMBED_SIZE] to 
        # [TARGET x BATCH_SIZE x (hidden)EMBED_SIZE]
        query_pos = self.query_pos.repeat(B, 1, 1)
        query_pos = query_pos.transpose(0, 1)
        h = self.transformer(h, query_pos).transpose(0, 1)
        
        # output from transformer is of dim [TARGET x BATCH_SIZE x EMBED_SIZE];
        # however, we transpose it to [BATCH_SIZE x TARGET x EMBED_SIZE] above
        # and then take an average along the TARGET dimension.
        #
        # next, we project transformer outputs to class labels
        #h = h.mean(1)
        h = torch.reshape(h,(B, self.num_classes * self.hidden_dim))
        return self.fc1(h)
        
        #out = self.decoder(out)

    #def predict(self, x):
        

In [7]:
class Query2LabelTrainModule(pl.LightningModule):
    def __init__(
        self, model, data, backbone_desc, num_encoders, num_decoders, num_heads,
        batch_size, image_dim, learning_rate, momentum, n_classes, thresh=0.5):
        super().__init__()

        # Key parameters
        self.save_hyperparameters()
        self.data = data
        self.model = Query2Label(
            model=model, num_classes=n_classes, nheads=num_heads, 
            encoder_layers=num_encoders, decoder_layers=num_decoders)

    def forward(self, x):
        x = self.model(x)        
        return x

    def evaluate(self, batch, stage=None):
        x, y = batch
        y_hat = self(x)
        loss = F.binary_cross_entropy_with_logits(y_hat, y.type(torch.float))
        acc = accuracy(y_hat, y.type(torch.int), threshold=self.hparams.thresh)
        prec = precision(y_hat, y.type(torch.int), threshold=self.hparams.thresh)
        recall = tf.recall(y_hat, y.type(torch.int), threshold=self.hparams.thresh)
        f1_score = tf.f1_score(y_hat, y.type(torch.int), threshold=self.hparams.thresh)
        rmap = tf.retrieval_average_precision(y_hat, y.type(torch.int))

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)
            self.log(f"{stage}_prec", prec, prog_bar=True)
            self.log(f"{stage}_recall", recall, prog_bar=True)
            self.log(f"{stage}_f1_score", f1_score, prog_bar=True)
            self.log(f"{stage}_rmap", rmap, prog_bar=True)

            # log prediction examples to wandb
            
            pred = self.model(x)
            pred_keys = pred[0].sigmoid().tolist()
            pred_keys = [0 if p < self.hparams.thresh else 1 for p in pred_keys]

            mapper = cc.COCOCategorizer()
            pred_lbl = mapper.get_labels(pred_keys)
            
            try:
                self.logger.experiment.log({"val_pred_examples": [wandb.Image(x[0], caption=pred_lbl)]})
            except AttributeError:
                pass
            
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.binary_cross_entropy_with_logits(y_hat, y.type(torch.float))
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(), 
            lr=self.hparams.learning_rate,
            betas=(0.9,0.999))
        
        lr_scheduler_dict = {
            "scheduler":OneCycleLR(
                optimizer,
                self.hparams.learning_rate,
                epochs=self.trainer.max_epochs,
                steps_per_epoch=len(self.data.train_dataloader()),
                anneal_strategy='cos'
            ),
            "interval":"step",
        }
        return {"optimizer":optimizer, "lr_scheduler":lr_scheduler_dict}
        #return optimizer

In [8]:
param_dict = {
    "model":models.resnet50(pretrained=True),
    "backbone_desc":"resnet50",
    "num_encoders":6,
    "num_decoders":6,
    "num_heads":8,
    "batch_size":64,
    "image_dim":448,
    "learning_rate":0.0001, 
    "momentum":0.9, 
    "n_classes":80
}

In [9]:
coco = COCODataModule(
    img_data_dir,
    img_size=param_dict["image_dim"],
    batch_size=param_dict["batch_size"],
    num_workers=12)
param_dict["data"] = coco

In [10]:
pl_model = Query2LabelTrainModule(**param_dict)

  f"Attribute {k!r} is an instance of `nn.Module` and is already saved during checkpointing."


In [11]:
wandb_logger = WandbLogger(project="coco-labeling")
wandb_logger.watch(pl_model, log="all")

trainer = pl.Trainer(
    max_epochs=12,
    precision=16,
    accelerator='gpu', 
    devices=1,
    logger=wandb_logger,
    profiler="simple",
    callbacks=[TQDMProgressBar(refresh_rate=10)])
trainer.fit(pl_model, param_dict["data"])

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcurt-tigges[0m ([33mascendant[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")


loading annotations into memory...
Done (t=6.83s)
creating index...
index created!
loading annotations into memory...
Done (t=4.21s)
creating index...
index created!


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type        | Params
--------------------------------------
0 | model | Query2Label | 43.1 M
--------------------------------------
43.1 M    Trainable params
0         Non-trainable params
43.1 M    Total params
86.137    Total estimated model params size (MB)


Epoch 0:  39%|███▉      | 760/1927 [07:38<11:43,  1.66it/s, loss=0.0742, v_num=9zy7]

wandb: Network error (ReadTimeout), entering retry loop.


Epoch 11: 100%|██████████| 1927/1927 [3:30:39<00:00,  6.56s/it, loss=0.0216, v_num=9zy7, val_loss=0.0509, val_acc=0.986, val_prec=0.852, val_recall=0.737, val_f1_score=0.790, val_rmap=0.866]    


FIT Profiler Report

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                                                                    	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                       