In [1]:
# Create dataset class

import torch
import json
from pathlib import Path
from PIL import Image

class CocoPanoptic(torch.utils.data.Dataset):
    def __init__(self, img_folder, ann_folder, ann_file, feature_extractor):
        with open(ann_file, 'r') as f:
            self.coco = json.load(f)

        # sort 'images' field so that they are aligned with 'annotations'
        # i.e., in alphabetical order
        self.coco['images'] = sorted(self.coco['images'], key=lambda x: x['id'])
        # sanity check
        if "annotations" in self.coco:
            for img, ann in zip(self.coco['images'], self.coco['annotations']):
                assert img['file_name'][:-4] == ann['file_name'][:-4]

        self.img_folder = img_folder
        self.ann_folder = Path(ann_folder)
        self.ann_file = ann_file
        self.feature_extractor = feature_extractor

    def __getitem__(self, idx):
        ann_info = self.coco['annotations'][idx] if "annotations" in self.coco else self.coco['images'][idx]
        img_path = Path(self.img_folder) / ann_info['file_name'].replace('.png', '.jpg')

        img = Image.open(img_path).convert('RGB')
        
        # preprocess image and target (converting target to DETR format, resizing + normalization of both image and target)
        encoding = self.feature_extractor(images=img, annotations=ann_info, masks_path=self.ann_folder, return_tensors="pt")
        pixel_values = encoding["pixel_values"].squeeze() # remove batch dimension
        target = encoding["labels"][0] # remove batch dimension

        return pixel_values, target

    def __len__(self):
        return len(self.coco['images'])

  from .autonotebook import tqdm as notebook_tqdm


In [48]:
# Create dataset class using the paths to the images and masks
from transformers import DetrFeatureExtractor
import numpy as np
import os

# we reduce the size and max_size to be able to fit the batches in GPU memory
feature_extractor = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50-panoptic", size=500, max_size=600)

def get_folder_paths(subset, device='HPC'):
    if device == 'HPC':
        root = fr"../DSAD4DeTr_multilabel/subset/images"
    else:
        root = fr"C:\Users\jayan\Documents\MECHATRONICS YR4\MECH5845M - Professional Project\DSAD4DeTr_multilabel"
    img_folder = os.path.join(root, subset, 'images')
    ann_folder = os.path.join(root, subset, 'masks')
    ann_file = os.path.join(root, subset, 'annotations', f"{subset}_panoptic_annotations.json")
    return [img_folder, ann_folder, ann_file]

train_paths = get_folder_paths('train', device='CPU')
test_paths = get_folder_paths('test', device='CPU')
val_paths = get_folder_paths('val', device='CPU')

train_dataset = CocoPanoptic(img_folder=train_paths[0], ann_folder=train_paths[1], ann_file=train_paths[2], feature_extractor=feature_extractor)
test_dataset = CocoPanoptic(img_folder=test_paths[0], ann_folder=test_paths[1], ann_file=test_paths[2], feature_extractor=feature_extractor)
val_dataset = CocoPanoptic(img_folder=val_paths[0], ann_folder=val_paths[1], ann_file=val_paths[2], feature_extractor=feature_extractor)

In [44]:
pixel_values, target = train_dataset[2]
print(pixel_values.shape)
print(target.keys())

torch.Size([3, 419, 599])
dict_keys(['size', 'image_id', 'orig_size', 'masks', 'boxes', 'class_labels', 'iscrowd', 'area'])




In [45]:
print("Number of training examples:", len(train_dataset))
print("Number of test examples:", len(test_dataset))
print("Number of validation examples:", len(val_dataset))

Number of training examples: 863
Number of test examples: 365
Number of validation examples: 202


In [25]:
# Define custom collate function to batch images and labels together

from torch.utils.data import DataLoader

def collate_fn(batch):
  pixel_values = [item[0] for item in batch]
  encoded_input = feature_extractor.pad_and_create_pixel_mask(pixel_values, return_tensors="pt")
  labels = [item[1] for item in batch]
  batch = {}
  batch['pixel_values'] = encoded_input['pixel_values']
  batch['pixel_mask'] = encoded_input['pixel_mask']
  batch['labels'] = labels
  return batch

train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=3, shuffle=True)
val_dataloader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=1)

In [26]:
# Define model configuration
# Define training configuration
n_devices      = 1
epochs         = 100 # 20000 is maximum value but will be prevented by early stopping
weight_decay   = 1e-4
learning_rate  = 1e-4
learning_rate_backbone = 1e-5
check_val_every_n_epoch = 5
load_from_checkpoint = False
checkpoint_path = None
last_manual_checkpoint = 1
last_epoch = 0
## LEARNING_SCHEDULER parameters for ReduceLROnPlateau from pytorch
factor          = 1e-1
lr_patience        = 10
lr_delta           = 1e-5
lr_monitored_var   = "training_loss"
min_lr          = 1e-8
cooldown        = 5
### Parameters for FixedStep from pytorch
fix_step   = False
step_size  = 60
## STOP_CRITERIA parameters for EarlyStopping from pytorch_lightning
stop_monitored_var    = "validation_loss"
stop_delta            = 1e-5
mode             = "min"
stop_patience         = 10   ## Real_patient = patience * check_val_every_n_epoch
## Custom loss function
loss_tags = None
loss_components = []
loss_weights = []

In [34]:
import pytorch_lightning as pl
import torch

class DetrPanoptic(pl.LightningModule):
    def __init__(self, model, lr, lr_backbone, weight_decay):
        super().__init__()
    
        self.model = model

        # see https://github.com/PyTorchLightning/pytorch-lightning/pull/1896
        self.lr = lr
        self.lr_backbone = lr_backbone
        self.weight_decay = weight_decay
    
    def forward(self, pixel_values, pixel_mask):
        outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask)

        return outputs

    def common_step(self, batch, batch_idx):
        pixel_values = batch["pixel_values"]
        pixel_mask = batch["pixel_mask"]
        labels = [{k: v.to(self.device) for k, v in t.items()} for t in batch["labels"]]

        outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)

        loss = outputs.loss
        loss_dict = outputs.loss_dict

        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        loss, loss_dict = self.common_step(batch, batch_idx)
        # logs metrics for each training_step,
        # and the average across the epoch
        self.log("training_loss", loss)
        for k,v in loss_dict.items():
            self.log("train_" + k, v.item())

        return loss

    def validation_step(self, batch, batch_idx):
        loss, loss_dict = self.common_step(batch, batch_idx)
        self.log("validation_loss", loss)
        for k,v in loss_dict.items():
            self.log("validation_" + k, v.item())

        return loss

    def configure_optimizers(self):
        param_dicts = [
            {"params": [p for n, p in self.named_parameters() if "backbone" not in n and p.requires_grad]},
            {
                "params": [p for n, p in self.named_parameters() if "backbone" in n and p.requires_grad],
                "lr": self.lr_backbone,
            },
        ]
        optimizer = torch.optim.AdamW(param_dicts, lr=self.lr,
                                    weight_decay=self.weight_decay)
        
        learning_rate_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                    factor=factor, patience=lr_patience, threshold=lr_delta,
                                    cooldown=cooldown, min_lr=min_lr, verbose=True)

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': learning_rate_scheduler,
                'monitor': lr_monitored_var
            }
        }

    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return val_dataloader

In [51]:
# Initialise model
from transformers import DetrConfig, DetrForSegmentation

model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic", num_labels=11,
                                            ignore_mismatched_sizes=True)
state_dict = model.state_dict()
# For visualisation of initialised layers
# for name, param in state_dict.items():
#     print(name, param.shape)

# Remove class weights
del state_dict["detr.class_labels_classifier.weight"]
del state_dict["detr.class_labels_classifier.bias"]
del state_dict["detr.bbox_predictor.layers.0.weight"]
del state_dict["detr.bbox_predictor.layers.0.bias"]
del state_dict["detr.bbox_predictor.layers.1.weight"]
del state_dict["detr.bbox_predictor.layers.1.bias"]
del state_dict["detr.bbox_predictor.layers.2.weight"]
del state_dict["detr.bbox_predictor.layers.2.bias"]
# define new model with custom class classifier
# config = DetrConfig.from_pretrained("facebook/detr-resnet-50-panoptic", num_labels=11)
model.load_state_dict(state_dict, strict=False)


Some weights of DetrForSegmentation were not initialized from the model checkpoint at facebook/detr-resnet-50-panoptic and are newly initialized because the shapes did not match:
- detr.class_labels_classifier.weight: found shape torch.Size([251, 256]) in the checkpoint and torch.Size([12, 256]) in the model instantiated
- detr.class_labels_classifier.bias: found shape torch.Size([251]) in the checkpoint and torch.Size([12]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


_IncompatibleKeys(missing_keys=['detr.class_labels_classifier.weight', 'detr.class_labels_classifier.bias', 'detr.bbox_predictor.layers.0.weight', 'detr.bbox_predictor.layers.0.bias', 'detr.bbox_predictor.layers.1.weight', 'detr.bbox_predictor.layers.1.bias', 'detr.bbox_predictor.layers.2.weight', 'detr.bbox_predictor.layers.2.bias'], unexpected_keys=[])

In [41]:
# Define LightningModule and verify outputs on single batch
model = DetrPanoptic(model=model, lr=1e-4, lr_backbone=1e-5, weight_decay=1e-4)

# pick the first training batch
batch = next(iter(train_dataloader))
# forward through the model
outputs = model(pixel_values=batch['pixel_values'], pixel_mask=batch['pixel_mask'])



In [42]:
print("Shape of pixel_values:", pixel_values.shape)
print("Shape of logits:", outputs.logits.shape)
print("Shape of predicted bounding boxes:", outputs.pred_boxes.shape)
print("Shape of predicted masks:", outputs.pred_masks.shape)

Shape of pixel_values: torch.Size([3, 419, 599])
Shape of logits: torch.Size([3, 100, 12])
Shape of predicted bounding boxes: torch.Size([3, 100, 4])
Shape of predicted masks: torch.Size([3, 100, 105, 150])
