# Objective: fine-tuning DETR

-   This notebook can be found on my Github profile: https://github.com/woctezuma/finetune-detr
-   Official DETR repository: https://github.com/facebookresearch/detr
-   Discussion about fine-tuning in [a Github issue](https://github.com/facebookresearch/detr/issues/9).
-   A nice blog post about another approach (Mask R-CNN) and the balloon dataset (used by the original version of this notebook; this version fine-tunes on a custom construction/engineering dataset in COCO panoptic format instead): [here](https://engineering.matterport.com/splash-of-color-instance-segmentation-with-mask-r-cnn-and-tensorflow-7c761e238b46).

## Define useful boilerplate functions

Adapted from:
-   https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_attention.ipynb

In [1]:
from google.colab import drive

# Everything this notebook reads (dataset archives, panoptic JSON annotations,
# patched DETR sources) lives on Google Drive, exposed under this mount point.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [2]:
!pip install -U torch==1.7.0
!pip install -U torchvision==0.8.1 
!pip install -U torchtext==0.8.0

Collecting torch==1.7.0
  Downloading torch-1.7.0-cp37-cp37m-manylinux1_x86_64.whl (776.7 MB)
[K     |████████████████████████████████| 776.7 MB 3.1 kB/s 
Collecting dataclasses
  Downloading dataclasses-0.6-py3-none-any.whl (14 kB)
Installing collected packages: dataclasses, torch
  Attempting uninstall: torch
    Found existing installation: torch 1.9.0+cu102
    Uninstalling torch-1.9.0+cu102:
      Successfully uninstalled torch-1.9.0+cu102
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchvision 0.10.0+cu102 requires torch==1.9.0, but you have torch 1.7.0 which is incompatible.
torchtext 0.10.0 requires torch==1.9.0, but you have torch 1.7.0 which is incompatible.[0m
Successfully installed dataclasses-0.6 torch-1.7.0
Collecting torchvision==0.8.1
  Downloading torchvision-0.8.1-cp37-cp37m-manylinux1_x86_64.whl (12.7 MB)
[K     |█████████████████

In [3]:
# Sanity check: list /content (the mounted Drive appears as `drive`).
! ls /content

drive  sample_data


In [4]:
import torch
import torchvision

# Record the library versions actually loaded and whether a GPU is visible.
print(torch.__version__, torchvision.__version__, torch.cuda.is_available())

# This notebook's own cells do not backpropagate (training runs in a separate
# `python main.py` process), so autograd is switched off globally.
torch.set_grad_enabled(False);

1.7.0 0.8.1 True


In [5]:
# Copy the dataset archives, the panoptic annotation files and the label list
# from Google Drive into the working directory.
# shutil raises on a missing or unreadable file. A failing `!cp` would only
# print an error and let the notebook keep running on missing inputs.
import shutil
from pathlib import Path

CAPSTONE_DRIVE_DIR = Path('/content/drive/My Drive/eva6_capstone_final')

for file_name in (
    'engineering_processed_dataset.tar.gz',
    'rgba_train_images.tar.gz',
    'panoptic_augmented_train_dataset_summarized.tar.gz',
    'panoptic_train2017.json',
    'panoptic_val2017.json',
    'new_dataset_labels.txt',
):
    shutil.copyfile(CAPSTONE_DRIVE_DIR / file_name, Path('.') / file_name)


In [6]:
! tar -zxvf engineering_processed_dataset.tar.gz  > /dev/null
! rm engineering_processed_dataset.tar.gz
! tar -zxvf rgba_train_images.tar.gz  > /dev/null
! rm rgba_train_images.tar.gz
! tar -zxvf panoptic_augmented_train_dataset_summarized.tar.gz  > /dev/null
! rm panoptic_augmented_train_dataset_summarized.tar.gz

In [7]:
# Copy additional training-annotation JSON files (presumably variants of
# panoptic_train2017.json) from Google Drive into the working directory.
# shutil raises on a missing file instead of continuing like a failed `!cp`.
import shutil
from pathlib import Path

CAPSTONE_DRIVE_DIR = Path('/content/drive/My Drive/eva6_capstone_final')

for json_name in ('rev_panoptic_train2017.json',
                  'random1_panoptic_train2017.json',
                  'random2_panoptic_train2017.json'):
    shutil.copyfile(CAPSTONE_DRIVE_DIR / json_name, Path('.') / json_name)

In [8]:
# Confirm the extracted folders and the copied files are in the working directory.
! ls

annotations			     random1_panoptic_train2017.json
dataset_labels_super_categories.txt  random2_panoptic_train2017.json
drive				     rev_panoptic_train2017.json
new_dataset_labels.txt		     sample_data
panoptic_train2017.json		     train2017
panoptic_val2017.json		     val2017


In [9]:
# Arrange the extracted data into the layout passed to main.py via --coco_path /
# --coco_panoptic_path: /content/data/custom/{annotations,train2017,val2017},
# plus the label list and the panoptic JSONs inside annotations/.
# NOTE(review): the `mv` lines are not idempotent. Re-running this cell fails
# unless the archives were re-extracted, because the sources under /content
# are already gone.
!mkdir -p /content/data/custom/
!mv /content/annotations /content/data/custom/
!mv /content/train2017 /content/data/custom/
!mv /content/val2017 /content/data/custom/
!cp /content/new_dataset_labels.txt /content/data/custom/
! cp '/content/panoptic_train2017.json' /content/data/custom/annotations
! cp '/content/panoptic_val2017.json' /content/data/custom/annotations

In [10]:
# Check that all annotation files ended up in the annotations directory.
! ls -l /content/data/custom/annotations

total 164424
-rw-r--r-- 1 root root  15819788 Sep  8 01:11 custom_train.json
-rw-r--r-- 1 root root   1182186 Sep  8 01:11 custom_val.json
-rw------- 1 root root 135946179 Oct  2 11:25 panoptic_train2017.json
-rw------- 1 root root  15415494 Oct  2 11:25 panoptic_val2017.json


In [11]:
# Optional alternative to Google Drive: upload files manually from the local
# machine. Kept commented out; not used in this run.
#from google.colab import files
#uploaded = files.upload()


In [12]:
import json
from PIL import Image


def _read_json(path):
    """Parse and return the JSON document stored at ``path``."""
    with open(path) as handle:
        return json.load(handle)


# Panoptic annotations of the custom dataset (only parsed here, not validated).
train_json_content = _read_json('/content/data/custom/annotations/panoptic_train2017.json')
valid_json_content = _read_json('/content/data/custom/annotations/panoptic_val2017.json')

In [13]:
import torchvision.transforms as T

# Standard ImageNet per-channel statistics (the usual PyTorch mean-std
# normalization for pretrained backbones).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Inference preprocessing: resize (800 passed to T.Resize), convert the PIL
# image to a tensor, then normalize each channel.
transform = T.Compose([
    T.Resize(800),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to corner form.

    Args:
        x: tensor whose dimension 1 holds the four box coordinates
            (cx, cy, w, h), one box per row.

    Returns:
        Tensor of the same layout holding (x0, y0, x1, y1).
    """
    cx, cy, w, h = x.unbind(1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    corners = (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
    return torch.stack(corners, dim=1)

def rescale_bboxes(out_bbox, size):
    """Convert (cx, cy, w, h) boxes to (x0, y0, x1, y1) pixel coordinates.

    Args:
        out_bbox: tensor of boxes in (cx, cy, w, h) format, one box per row;
            presumably normalized to [0, 1] as DETR predicts them.
        size: ``(width, height)`` of the target image, e.g. ``PIL.Image.size``.

    Returns:
        Float32 tensor of (x0, y0, x1, y1) boxes scaled to the image size, on
        the same device as ``out_bbox``.
    """
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    # Create the scale on the boxes' device: a CPU tensor of shape (4,) cannot
    # be multiplied with CUDA boxes, which fails whenever the model runs on GPU.
    scale = torch.tensor([img_w, img_h, img_w, img_h],
                         dtype=torch.float32, device=b.device)
    return b * scale

In [14]:
def filter_bboxes_from_outputs(outputs,
                               threshold=0.7,
                               image_size=None):
    """Keep the query slots of the first image whose best class score exceeds ``threshold``.

    Args:
        outputs: DETR output dict holding ``'pred_logits'`` and
            ``'pred_boxes'``; only batch element 0 is inspected.
        threshold: minimum value of the highest class probability (the
            trailing "no object" column excluded) for a slot to be kept.
        image_size: ``(width, height)`` used to rescale boxes to pixels.
            When omitted, falls back to ``im.size`` of a notebook-global
            ``im`` (presumably a PIL image), which is what the original
            version always used. Pass it explicitly to avoid that hidden
            dependency.

    Returns:
        ``(probas_to_keep, bboxes_scaled)``: class probabilities of the kept
        slots and their boxes as (x0, y0, x1, y1) pixel coordinates.
    """
    if image_size is None:
        # Backward-compatible fallback on the notebook-global image.
        image_size = im.size

    # Softmax over classes, dropping the last column (DETR's "no object" class).
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > threshold

    probas_to_keep = probas[keep]

    # convert boxes from [0; 1] to image scales
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], image_size)

    return probas_to_keep, bboxes_scaled

In [15]:
# COCO classes
CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]

# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

In [16]:
import matplotlib.pyplot as plt

def plot_results(pil_img, prob=None, boxes=None, classes=None):
    """Show an image with its predicted boxes labelled by their best class.

    Args:
        pil_img: image accepted by ``imshow`` (e.g. a PIL image).
        prob: per-box class probabilities, one row per box (as returned by
            ``filter_bboxes_from_outputs``). Boxes are drawn only when both
            ``prob`` and ``boxes`` are given.
        boxes: tensor of (xmin, ymin, xmax, ymax) pixel boxes, one row per box.
        classes: class names indexed by class id. Defaults to the COCO
            ``CLASSES`` list; pass e.g. ``finetuned_classes`` for the
            fine-tuned model so the labels match its class ids.
    """
    if classes is None:
        classes = CLASSES

    _, ax = plt.subplots(figsize=(16, 10))
    ax.imshow(pil_img)
    if prob is not None and boxes is not None:
        for i, (p, (xmin, ymin, xmax, ymax)) in enumerate(zip(prob, boxes.tolist())):
            # Cycle through the palette so every box gets a colour, however many.
            color = COLORS[i % len(COLORS)]
            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                       fill=False, color=color, linewidth=3))
            cl = p.argmax()
            text = f'{classes[cl]}: {p[cl]:0.2f}'
            ax.text(xmin, ymin, text, fontsize=15,
                    bbox=dict(facecolor='yellow', alpha=0.5))
    ax.axis('off')
    plt.show()

## Load an image for a demo

NB: if the threshold is equal to zero, then you will see all of the 100 query slots. The zero-threshold is only used for illustration. In usual cases, most query slots have a low confidence score, so that irrelevant query slots would be pruned with a higher threshold, such as 0.7 or 0.9.

Reference: https://github.com/facebookresearch/detr/issues/9#issuecomment-635357693

NB²: For fine-tuning purposes, we cannot change the number of query slots.

> If you're fine-tuning, I don't recommend changing the number of queries on the fly, it is extremely unlikely to work out of the box. In this case you're probably better off retraining from scratch (you can change the --num_queries arg from our training script).

Reference: https://github.com/facebookresearch/detr/issues/9#issuecomment-636407752

## Clone my custom code of DETR

Clone [my fork](https://github.com/woctezuma/detr/tree/finetune) tailored for a custom dataset:
-   called `custom`,
-   with `max_class_id = 2` ([explanation](https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223)).

**Caveat**: if you later use `first_class_index = 1` for your dataset, then it is fine. However, if you later use `first_class_index = 0`, then you will have to override the value of `max_class_id` (so that it is equal to 1) when calling `!python main.py`. My fork can do that.


In [17]:
# Fetch a fresh clone of the DETR fork and switch to its `finetune` branch.
# The old clone is deleted first, so this cell can be re-run.
# NOTE(review): the branch head is not pinned to a commit, so results may
# change if the fork is updated. Consider `git checkout <sha>` for
# reproducibility.
%cd /content/

!rm -rf detr
!git clone https://github.com/woctezuma/detr.git

%cd detr/

!git checkout finetune

/content
Cloning into 'detr'...
remote: Enumerating objects: 239, done.[K
remote: Total 239 (delta 0), reused 0 (delta 0), pack-reused 239[K
Receiving objects: 100% (239/239), 284.61 KiB | 9.81 MiB/s, done.
Resolving deltas: 100% (131/131), done.
/content/detr
Branch 'finetune' set up to track remote branch 'finetune' from 'origin'.
Switched to a new branch 'finetune'


In [18]:
# panopticapi is presumably needed by DETR's `coco_panoptic` dataset/evaluation
# (used with --dataset_file "coco_panoptic" below).
# NOTE(review): installs the repository's current HEAD (unpinned).
! pip install git+https://github.com/cocodataset/panopticapi.git

Collecting git+https://github.com/cocodataset/panopticapi.git
  Cloning https://github.com/cocodataset/panopticapi.git to /tmp/pip-req-build-yt89gahf
  Running command git clone -q https://github.com/cocodataset/panopticapi.git /tmp/pip-req-build-yt89gahf
Building wheels for collected packages: panopticapi
  Building wheel for panopticapi (setup.py) ... [?25l[?25hdone
  Created wheel for panopticapi: filename=panopticapi-0.1-py3-none-any.whl size=8306 sha256=4a4cdd915bd9c7804cc25b4782dfa9280c25a8034147776f55fd0b603167fd7c
  Stored in directory: /tmp/pip-ephem-wheel-cache-r8rwbh26/wheels/ad/89/b8/b66cce9246af3d71d65d72c85ab993fd28e7578e1b0ed197f1
Successfully built panopticapi
Installing collected packages: panopticapi
Successfully installed panopticapi-0.1


## Prepare the dataset for fine-tuning

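The original version of this notebook used the `balloon` dataset, which is featured here and uses the VIA format:
-   https://github.com/matterport/Mask_RCNN/tree/master/samples/balloon

This version instead uses a custom construction/engineering dataset in COCO panoptic format (`panoptic_train2017.json` / `panoptic_val2017.json`), copied from Google Drive above.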


You can choose whether to start indexing categories with 0 or with 1.

This is a matter of taste, and it should not impact the performance of the algorithm.

We expect the directory structure to be the following:
```
path/to/coco/           # here: /content/data/custom/
├ annotations/  # JSON annotations
│  ├ custom_train.json
│  ├ custom_val.json
│  ├ panoptic_train2017.json
│  └ panoptic_val2017.json
├ train2017/    # training images
└ val2017/      # validation images
```

## Check the dataset after it was pre-processed for fine-tuning

To verify the data loading is correct, let's visualize the annotations of randomly selected samples in the training set:
-   Demo of COCO API: https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoDemo.ipynb

In [19]:
# Tools for inspecting the pre-processed COCO-format annotations (COCO API demo).
%matplotlib inline
import pycocotools.coco as coco
from pycocotools.coco import COCO
import numpy as np
import skimage.io as io
import matplotlib.pyplot as plt
import pylab
# Default size for subsequent figures. pylab.rcParams is matplotlib's global
# rcParams, so this also applies to plt figures (plot_results sets its own size).
pylab.rcParams['figure.figsize'] = (10.0, 8.0)

## Fine-tuning

-   Instructions appear in [a Github Gist](https://gist.github.com/woctezuma/e9f8f9fe1737987351582e9441c46b5d).

NB: There is a `--frozen_weights` argument. However,
i) I have yet to figure out how it is used,
ii) it is of no use for box detection. Indeed, "frozen training is meant for segmentation only" (as mentioned at this [line](https://github.com/facebookresearch/detr/blob/f4cdc542de34de771da8b9189742e5465f5220cd/main.py#L110) of the source-code).

### Boilerplate variables

**Caveat**: the parameter name `num_classes` is misleading. It is actually the ID which DETR will reserve for **its own** `no_object` class.

It should be set to one plus the highest class ID in your dataset.

For instance, if you have one class (balloon):
- if you used the index n°0 for this class, then `max_id = 0` and `num_classes = max_id+1 = 1`
- if you used the index n°1 for this class, then `max_id = 1` and `num_classes = max_id+1 = 2`

Reference: https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223

In [20]:
# Index of the first dataset category (must be 0 or 1). It controls how
# `num_classes` and `finetuned_classes` are derived below.
first_class_index = 0

In [21]:
# Check what remains in /content after moving the dataset into /content/data/custom.
!ls /content

data				     panoptic_val2017.json
dataset_labels_super_categories.txt  random1_panoptic_train2017.json
detr				     random2_panoptic_train2017.json
drive				     rev_panoptic_train2017.json
new_dataset_labels.txt		     sample_data
panoptic_train2017.json


In [22]:
# Build the category list from the label file. Each non-blank line is expected
# to look like "<id>. <name>".
categoryObjList = []
categoryNameList = []

with open('/content/data/custom/new_dataset_labels.txt') as f:
    lines = f.readlines()

for line in lines:
    if not line.strip():
        # Skip blank lines (e.g. a trailing empty line) instead of crashing on int('').
        continue
    # Split only on the first '.', so a name that contains '.' is kept whole
    # instead of being truncated at its first period.
    id_text, name_text = line.split('.', 1)
    category_id = int(id_text.strip())
    categoryName = name_text.strip()
    categoryNameList.append(categoryName)
    categoryObjList.append({
        "id": category_id,
        "name": categoryName,
        # The label file carries no supercategory information.
        "supercategory": 'N/A',
    })

In [23]:
assert first_class_index in (0, 1), 'first_class_index must be 0 or 1'

if first_class_index == 0:
    # Dataset classes use IDs 0..N-1, so the class names map directly to IDs.
    finetuned_classes = list(categoryNameList)
    # num_classes = (highest class ID) + 1. It is derived from the label file
    # instead of being hard-coded, so it cannot drift out of sync with it.
    num_classes = len(finetuned_classes)
    # The "max_id + 1" rule only equals the class count when the label IDs are
    # exactly 0..num_classes-1, so check that here.
    label_ids = sorted(obj['id'] for obj in categoryObjList)
    assert label_ids == list(range(num_classes)), (
        'label IDs must be contiguous and start at 0 when first_class_index == 0')
    # The `no_object` class will be automatically reserved by DETR with ID
    # equal to `num_classes`.
else:
    # Dataset classes use IDs 1..N. DETR indexes from 0, so a dummy 'N/A'
    # class occupies ID 0.
    # Caveat: this dummy class is not the `no_object` class reserved by DETR.
    finetuned_classes = ['N/A'] + categoryNameList
    num_classes = len(finetuned_classes)
    # The `no_object` class will be automatically reserved by DETR with ID
    # equal to `num_classes`.

print('First class index: {}'.format(first_class_index))
print('Parameter num_classes: {}'.format(num_classes))
print('Fine-tuned classes: {}'.format(finetuned_classes))

First class index: 0
Parameter num_classes: 64
Fine-tuned classes: ['misc_stuff', 'aac_blocks', 'adhesives', 'ahus', 'aluminium_frames_for_false_ceiling', 'chiller', 'concrete_mixer_machine', 'concrete_pump_(50%)', 'control_panel', 'cu_piping', 'distribution_transformer', 'dump_truck___tipper_truck', 'emulsion_paint', 'enamel_paint', 'fine_aggregate', 'fire_buckets', 'fire_extinguishers', 'glass_wool', 'grader', 'hoist', 'hollow_concrete_blocks', 'hot_mix_plant', 'hydra_crane', 'interlocked_switched_socket', 'junction_box', 'lime', 'marble', 'metal_primer', 'pipe_fittings', 'rcc_hume_pipes', 'refrigerant_gas', 'river_sand', 'rmc_batching_plant', 'rmu_units', 'sanitary_fixtures', 'skid_steer_loader_(bobcat)', 'smoke_detectors', 'split_units', 'structural_steel_-_channel', 'switch_boards_and_switches', 'texture_paint', 'threaded_rod', 'transit_mixer', 'vcb_panel', 'vitrified_tiles', 'vrf_units', 'water_tank', 'wheel_loader', 'wood_primer', 'building', 'ceiling', 'floor', 'food', 'furnitu

In [24]:
# Work from the DETR clone: later relative paths (models/, datasets/, main.py) refer to it.
%cd /content/detr/

/content/detr


In [25]:
# List the DETR repository files that are about to be overwritten with patched versions.
! ls

d2	    engine.py	main.py    requirements.txt	 tox.ini
datasets    hubconf.py	models	   run_with_submitit.py  util
Dockerfile  LICENSE	README.md  test_all.py


**Caveat**: below, we override the value of `num_classes` (hard-coded to 2 for the `custom` dataset in my `finetune` branch of DETR) in case `first_class_index = 0` instead of `first_class_index = 1` (default value).

In [26]:
# Confirm the working directory before copying files with relative destinations.
! pwd

/content/detr


In [27]:
# Overwrite selected files of the cloned DETR repository with the patched
# versions kept on Google Drive. Destination paths are relative to the current
# working directory, /content/detr (set by the `%cd` cell above).
import shutil
from pathlib import Path

PATCHED_SOURCES_DIR = Path('/content/drive/My Drive/eva6_source_code')
PATCHED_SOURCES = {
    'detr.py': 'models/detr.py',
    'coco.py': 'datasets/coco.py',
    'coco_panoptic.py': 'datasets/coco_panoptic.py',
    'engine.py': 'engine.py',
    'main.py': 'main.py',
}

# shutil raises on a missing source file. A failing `!cp` would only print an
# error, and training would then silently run the unpatched upstream code.
for src_name, dst_path in PATCHED_SOURCES.items():
    shutil.copyfile(PATCHED_SOURCES_DIR / src_name, dst_path)

In [None]:
# Train DETR with masks (panoptic) on the custom dataset. `--frozen_weights`
# points to the box-detection checkpoint trained earlier; with `--masks`,
# presumably only the segmentation head is trained on top of it (see the note
# on frozen training above). IPython substitutes `$num_classes` with the Python
# variable computed earlier. `--num_queries` must match the checkpoint.
# NOTE(review): `--lr_drop 1000` exceeds `--epochs 50`. If lr_drop is the step
# (in epochs) of the LR schedule, the learning rate is never decayed in this
# run; confirm this is intended.
! python main.py \
  --masks \
  --dataset_file "coco_panoptic" \
  --coco_path "/content/data/custom/" \
  --coco_panoptic_path "/content/data/custom/" \
  --output_dir "/content/drive/My Drive/eva6_capstone_final_summarized_result/outputs_panoptic2" \
  --frozen_weights "/content/drive/My Drive/eva6_capstone_final_summarized_result/outputs_bbox1/checkpoint_49_epoch.pth" \
  --num_classes $num_classes \
  --epochs 50 \
  --lr 1e-5 \
  --lr_drop 1000 \
  --num_queries 30 \
  --batch_size 1


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch: [9]  [11090/15234]  eta: 0:13:31  lr: 0.000010  class_error: 0.00  loss: 15.3800 (12.3283)  loss_ce: 0.0365 (0.1450)  loss_bbox: 0.2262 (0.2196)  loss_giou: 1.6551 (1.4279)  loss_mask: 0.2368 (0.3571)  loss_dice: 1.4067 (1.2126)  loss_ce_0: 0.0551 (0.1795)  loss_bbox_0: 0.2047 (0.2160)  loss_giou_0: 1.6719 (1.3897)  loss_ce_1: 0.0470 (0.1630)  loss_bbox_1: 0.2260 (0.2181)  loss_giou_1: 1.6606 (1.4118)  loss_ce_2: 0.0379 (0.1522)  loss_bbox_2: 0.2261 (0.2190)  loss_giou_2: 1.6540 (1.4256)  loss_ce_3: 0.0362 (0.1470)  loss_bbox_3: 0.2261 (0.2192)  loss_giou_3: 1.6433 (1.4316)  loss_ce_4: 0.0371 (0.1459)  loss_bbox_4: 0.2261 (0.2192)  loss_giou_4: 1.6455 (1.4282)  loss_ce_unscaled: 0.0365 (0.1450)  class_error_unscaled: 0.0000 (25.6222)  loss_bbox_unscaled: 0.9046 (0.8783)  loss_giou_unscaled: 0.8276 (0.7139)  cardinality_error_unscaled: 0.0000 (0.6138)  loss_mask_unscaled: 0.0395 (0.0595)  loss_dice_unscaled: 0.7034 