# Setup and Introduction
PANet is a few-shot semantic image segmentation model introduced in 2019 that uses prototype alignment regularization (PAR), alongside other novel approaches, to outperform predecessor networks by a large margin.

This notebook is a supplement that gives insight into the network training pipeline.
We first initialize all files and values, then perform a single training run on a single split of the data.
Finally, we use the trained model to segment the test images and obtain qualitative results.

In [1]:
# @title Init Project Code, Imports, and Constants
!pip install torch-summary
!pip install torchinfo

!git init
!git remote add origin https://github.com/bryjen/PANet
!git fetch
!git reset --hard origin/master

!pip install sacred

import os
import tqdm
import torch
import shutil
import zipfile
import kagglehub
import torch.optim

import numpy as np
import torch.nn as nn
import torch.backends.cudnn as cudnn

from google.colab import drive
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

from models.fewshot import FewShotSeg
from dataloaders.customized import voc_fewshot, coco_fewshot
from dataloaders.transforms import ToTensorNormalize
from dataloaders.transforms import Resize, DilateScribble
from util.metric import Metric
from util.utils import set_seed, CLASS_LABELS, get_bbox
from config import ex

DRIVE_MOUNTED = False

Collecting torch-summary
  Downloading torch_summary-1.4.5-py3-none-any.whl.metadata (18 kB)
Downloading torch_summary-1.4.5-py3-none-any.whl (16 kB)
Installing collected packages: torch-summary
Successfully installed torch-summary-1.4.5
Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0
hint: Using 'master' as the name for the initial branch. This default branch name
hint: is subject to change. To configure the initial branch name to use in all
hint:
hint: 	git config --global init.defaultBranch <name>
hint:
hint: Names commonly chosen instead of 'master' are 'main', 'trunk' and
hint: 'development'. The just-created branch can be renamed via this command:
hint:
hint: 	git branch -m <name>
Initialized empty Git repository in /content/.git/
remote: 

In [2]:
# @title Pre-trained Weights & Dataset Augmentation Files
# Here we download the files required to set up the pipeline.
# We download pre-trained weights so that we can fine-tune the model.
# We also download additional annotations (scribbles + bounding boxes)
# and modify the downloaded dataset with these files.

# pre-trained weights
os.makedirs("/content/pretrained_model/", exist_ok=True)
!gdown 1TJZy_YkYwNMkdtECAlkWy8VuorPFjP1t --output "/content/pretrained_model/"

# dataset modification files
os.makedirs("/content/dataset_augmentations/", exist_ok=True)
!gdown 1ZP6FHiclSNk1nH0WxhCH-W_fwd3GnEZa --output "/content/dataset_augmentations/"
!gdown 1oGy-tg_Tv_-ZUQeiNtM5i4sK0lYB0MeL --output "/content/dataset_augmentations/"
!gdown 1R5G3BfQTAY4zWkazTsgFc3kJNQVjwi4M --output "/content/dataset_augmentations/"
!gdown 1D4oX0Ub6ObFO81yonWY4ehZ-tr5Hn5UK --output "/content/dataset_augmentations/"


path = kagglehub.dataset_download("ngan2710/voc-devkit-2007-and-2012")

# extra annotation archives that are unpacked into the VOC2012 directory
augmentation_zip_paths = [
    "/content/dataset_augmentations/ScribbleAugAuto.zip",
    "/content/dataset_augmentations/SegmentationClassAug.zip",
    "/content/dataset_augmentations/SegmentationObjectAug.zip",
]

extract_to = "/root/.cache/kagglehub/datasets/ngan2710/voc-devkit-2007-and-2012/versions/1/VOCdevkit/VOC2012/"

for zip_path in augmentation_zip_paths:
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)


segmentation_replacement_zip_path = "/content/dataset_augmentations/Segmentation.zip"
imagesets_dir = "/root/.cache/kagglehub/datasets/ngan2710/voc-devkit-2007-and-2012/versions/1/VOCdevkit/VOC2012/ImageSets/"

shutil.rmtree(f"{imagesets_dir}Segmentation", ignore_errors=True)

with zipfile.ZipFile(segmentation_replacement_zip_path, 'r') as zip_ref:
    zip_ref.extractall(imagesets_dir)

Downloading...
From (original): https://drive.google.com/uc?id=1TJZy_YkYwNMkdtECAlkWy8VuorPFjP1t
From (redirected): https://drive.google.com/uc?id=1TJZy_YkYwNMkdtECAlkWy8VuorPFjP1t&confirm=t&uuid=46efc206-4ef4-46c7-a936-2eb3db5134b7
To: /content/pretrained_model/vgg16-397923af.pth
100% 553M/553M [00:09<00:00, 60.0MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1ZP6FHiclSNk1nH0WxhCH-W_fwd3GnEZa
From (redirected): https://drive.google.com/uc?id=1ZP6FHiclSNk1nH0WxhCH-W_fwd3GnEZa&confirm=t&uuid=fd29483a-8768-4c58-a63a-69150a3d0232
To: /content/dataset_augmentations/ScribbleAugAuto.zip
100% 40.5M/40.5M [00:01<00:00, 36.6MB/s]
Downloading...
From: https://drive.google.com/uc?id=1oGy-tg_Tv_-ZUQeiNtM5i4sK0lYB0MeL
To: /content/dataset_augmentations/Segmentation.zip
100% 137k/137k [00:00<00:00, 4.10MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1R5G3BfQTAY4zWkazTsgFc3kJNQVjwi4M
From (redirected): https://drive.google.com/uc?id=1R5G3BfQTAY4zWkazTsgFc3kJN

# Training and Testing

Recall that the dataset is divided into 4 splits. Below, we train a model on split 0 using 1-way, 5-shot learning, with PAR enabled and dense (strong) annotations.
Afterwards, we evaluate the model on the remainder of the data to obtain binary-IoU and mean-IoU metrics.
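
For readers unfamiliar with the two metrics, the sketch below shows one common way to compute them from flattened pixel labels. It is an illustration only: the function names are hypothetical, and the repository's `util.metric.Metric` class may aggregate the counts differently (for example, across all episodes before dividing).

```python
from typing import Sequence


def iou(pred: Sequence[int], target: Sequence[int], cls: int) -> float:
    """Intersection-over-union of one class over flattened pixel labels."""
    tp = sum(p == cls and t == cls for p, t in zip(pred, target))
    fp = sum(p == cls and t != cls for p, t in zip(pred, target))
    fn = sum(p != cls and t == cls for p, t in zip(pred, target))
    union = tp + fp + fn
    # a class that appears in neither prediction nor target has no defined IoU
    return tp / union if union else float("nan")


def mean_iou(pred, target, classes):
    """Average of the per-class IoU over the given foreground classes."""
    scores = [iou(pred, target, c) for c in classes]
    return sum(scores) / len(scores)


def binary_iou(pred, target):
    """Collapse every foreground class to 1, then average the
    foreground and background IoU."""
    pred_b = [int(p > 0) for p in pred]
    target_b = [int(t > 0) for t in target]
    return (iou(pred_b, target_b, 0) + iou(pred_b, target_b, 1)) / 2


pred = [0, 1, 1, 2, 0, 2]
target = [0, 1, 2, 2, 0, 0]
print(mean_iou(pred, target, classes=[1, 2]), binary_iou(pred, target))
```

Mean-IoU rewards getting each class right, whereas binary-IoU only asks whether foreground was separated from background, so the two can differ noticeably on the same predictions.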

In the following sections, we use data from both training and testing to replicate the figures in the original paper.



****
**NOTE:** below, we perform only one 'run'.
From experience, a single run (training and testing) takes roughly 1.5 to 2 hours on an A100 GPU, and the full experiment comprises 72 such runs.

For simplicity, and to demonstrate that the pipeline works, only one run is performed here. If you wish to perform the **full** experiment, another `.ipynb` in the repository contains the code to do so.
****
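
To put the note above in perspective, the total compute time follows directly from the two figures it quotes. The snippet only restates that arithmetic and assumes the runs are executed one after another on a single GPU.

```python
N_RUNS = 72                 # total number of runs in the full experiment
HOURS_PER_RUN = (1.5, 2.0)  # observed range for one run on an A100

# lower and upper bound on the total sequential run time
low, high = (N_RUNS * hours for hours in HOURS_PER_RUN)

print(f"Full experiment: about {low:.0f} to {high:.0f} GPU-hours "
      f"({low / 24:.1f} to {high / 24:.1f} days)")
```

That is roughly 108 to 144 GPU-hours, which is why this notebook restricts itself to a single run.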

In [3]:
os.chdir("/content/")
!python train.py with mode='train' dataset='VOC' label_sets=0 model.align=True task.n_ways=1 task.n_shots=5

INFO - PANet - Running command 'main'
INFO - PANet - Started run with ID "1"
INFO - main - ###### Create model ######
INFO - main - ###### Load data ######
INFO - main - 
INFO - main - ###### Training Config ######
INFO - main - n-ways:		1
INFO - main - n-shots:		5
INFO - main - n-queries:	1
INFO - main - PAR?:		True
INFO - main - 
INFO - main - ###### Set optimizer ######
INFO - main - ###### Training ######
0it [00:00, ?it/s]
ERROR - PANet - Failed after 0:00:02!
Traceback (most recent calls WITHOUT Sacred internals):
  File "/content/train.py", line 116, in main
    for i_iter, sample_batched in tqdm.tqdm(enumerate(trainloader)):
  File "/usr/local/lib/python3.11/dist-packages/tqdm/std.py", line 1181, in __iter__
    for obj in iterable:
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 708, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 14

In [4]:
os.chdir("/content/")
!python test.py with mode='test' snapshot='./runs/PANet_VOC_align_sets_0_1way_5shot_[train]/1/snapshots/30000.pth'

object address  : 0x7b2834117dc0
object refcount : 2
object type     : 0x9d5ea0
object type name: KeyboardInterrupt
object repr     : KeyboardInterrupt()
lost sys.stderr
^C


In [5]:
import os
import shutil
from google.colab import drive


# mount Google Drive once so the results survive the Colab session
if not DRIVE_MOUNTED:
  drive.mount('/content/drive')
  DRIVE_MOUNTED = True

# local location of the final snapshot
final_model_weights = "/content/runs/PANet_VOC_align_sets_0_1way_5shot_[train]/1/snapshots/30000.pth"

# copy all run outputs to Drive and point to the copied snapshot instead
if DRIVE_MOUNTED:
  src = "/content/runs/"
  dst = "/content/drive/MyDrive/PANet/Results/"
  shutil.copytree(src, dst, dirs_exist_ok=True)
  final_model_weights = "/content/drive/MyDrive/PANet/Results/PANet_VOC_align_sets_0_1way_5shot_[train]/1/snapshots/30000.pth"

Mounted at /content/drive
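
The snapshot paths used above follow a visible pattern (`PANet_<dataset>_align_sets_<split>_<n>way_<k>shot_[train]`). As a convenience, a small helper could assemble the path from those values instead of hard-coding it in several places. The helper name and its parameters are hypothetical, and the pattern is inferred only from the paths that appear in this notebook; check the repository's `config.py` for the authoritative naming scheme.

```python
from pathlib import Path


def snapshot_path(root, dataset="VOC", label_set=0, n_ways=1, n_shots=5,
                  run_id=1, step=30000, align=True):
    """Build the path of a training snapshot (hypothetical helper)."""
    # the '_align' tag is assumed to appear only when PAR is enabled
    align_tag = "_align" if align else ""
    run_name = (f"PANet_{dataset}{align_tag}_sets_{label_set}"
                f"_{n_ways}way_{n_shots}shot_[train]")
    return Path(root) / run_name / str(run_id) / "snapshots" / f"{step}.pth"


print(snapshot_path("/content/runs").as_posix())
```

Passing `"/content/drive/MyDrive/PANet/Results"` as `root` would give the corresponding location on Google Drive.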


# Visualizing Evaluation Results
So far, we have completed only one training run.
That alone is sufficient to obtain a model that can perform semantic segmentation.


However, the report also presents the following:
- Mean-IoU metrics for 1-shot and 5-shot learning
- Binary-IoU metrics for 1-shot and 5-shot learning
- Loss per iteration for 1-shot/5-shot learning with and without PAR
- Mean-IoU metrics for models with and without PAR
- Weak annotation metrics for 1-shot/5-shot models

To obtain these, you would need to complete the remaining 71 runs. The snapshots generated by training and testing contain loss and performance metrics, so analyzing the results from these files is simple enough.

For more details, check the aforementioned `full` notebook.
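
As a rough illustration of such an analysis, the sketch below reads one logged metric from a run directory. It assumes that the experiment uses Sacred's `FileStorageObserver`, which writes a `metrics.json` file mapping each logged metric name to parallel `steps` and `values` lists, and that the training script logs a scalar named `loss`. Both the metric name and the example path are assumptions to verify against your own run directory.

```python
import json
from pathlib import Path


def load_metric(run_dir, name):
    """Return (steps, values) for one metric found in a run's metrics.json."""
    with open(Path(run_dir) / "metrics.json") as f:
        metrics = json.load(f)
    if name not in metrics:
        raise KeyError(f"{name!r} not logged; available: {sorted(metrics)}")
    return metrics[name]["steps"], metrics[name]["values"]


# Example usage (the path and the metric name are assumptions):
# steps, loss = load_metric(
#     "/content/runs/PANet_VOC_align_sets_0_1way_5shot_[train]/1", "loss")
# plt.plot(steps, loss)
```

Listing the available metric names in the error message makes it easy to discover what a given run actually logged.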

In [6]:
# @title Helper Functions

import matplotlib.pyplot as plt
import matplotlib.patches as patches


def show_segmentation(idx, sample, query_pred):
  """Plots the support image, the query ground truth, and the predicted query mask."""
  fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12,6), gridspec_kw={'wspace':0.3, 'hspace':0.2})

  # display configs
  do_threshold = True
  threshold = 0.7
  mask_alpha = 0.5
  query_cmap = plt.get_cmap('bwr').reversed()
  support_cmap = plt.get_cmap('bwr')

  def process_mask(mask, do_threshold=False, threshold=0.7):
    """Converts a raw score map into a masked array for overlaying."""
    if do_threshold:
      # min-max normalize to [0, 1], then binarize; the epsilon avoids a
      # division by zero when the score map is constant
      norm_x = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
      mask = (norm_x > threshold).float()
    else:
      mask = torch.sigmoid(mask)

    mask_np = mask.detach().cpu().numpy()
    # hide pixels equal to 1 so that the underlying image shows through
    return np.ma.masked_where(mask_np == 1, mask_np)


  # support image with ground truth segmentation mask
  plt.subplot(1, 3, 1)
  img = sample["support_images_t"][0][0].squeeze(0)
  img_np = img.permute(1, 2, 0).cpu().numpy()

  mask = sample["support_mask"][0][0]["fg_mask"].squeeze(0).cpu().numpy()
  mask = np.ma.masked_where(mask == 0, mask)

  plt.imshow(img_np)
  plt.imshow(mask, cmap=support_cmap, alpha=mask_alpha)
  plt.axis('off')


  # ground truth segmentation mask
  plt.subplot(1, 3, 2)
  img = sample["query_images_t"][0][0].squeeze(0)
  img_np = img.permute(1, 2, 0).cpu().numpy()

  mask = sample["query_masks"][0][0].squeeze(0).squeeze(0).cpu().numpy()
  mask = np.ma.masked_where(mask == 1, mask)

  plt.imshow(img_np)
  plt.imshow(mask, cmap=query_cmap, alpha=mask_alpha)
  plt.axis('off')


  # predicted segmentation mask
  plt.subplot(1, 3, 3)
  mask = process_mask(query_pred[0][0], do_threshold=do_threshold, threshold=threshold)
  plt.imshow(img_np)
  plt.imshow(mask, cmap=query_cmap, alpha=mask_alpha)
  plt.axis('off')


  # adds fancy stuff to make it look more like the report
  pos = [ax.get_position() for ax in axes]
  xmin = min(p.x0 for p in pos)
  ymin = min(p.y0 for p in pos)
  xmax = max(p.x1 for p in pos)
  ymax = max(p.y1 for p in pos)

  box = patches.FancyBboxPatch(
      (xmin, ymin), xmax - xmin, ymax - ymin,
      boxstyle="round,pad=0.011",
      transform=fig.transFigure,
      fill=False,
      edgecolor='black',
      linewidth=1
  )
  fig.add_artist(box)

  fig.text(0.515, 0.73, 'ground truth', ha='center', va='bottom')
  fig.text(0.8, 0.73, 'prediction', ha='center', va='bottom')

  fig.text(0.1, 0.73, f'{idx}', ha='center', va='bottom')

  label = VOC_CLASS_STRS[int(sample["class_ids"][0][0])]
  fig.text(0.1, 0.5, label, rotation='vertical', ha='center', va='center')

  return plt.gcf()
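
The thresholding step inside `process_mask` can be tried in isolation. The sketch below mirrors it with NumPy on a toy score map: scores are min-max normalized to [0, 1], binarized at the threshold, and the pixels equal to 1 are masked out so that `imshow` draws only the remaining region. The function name and the toy values are illustrative, and the small epsilon guards against a constant score map.

```python
import numpy as np


def threshold_mask(scores, threshold=0.7, eps=1e-8):
    """Min-max normalize, binarize, and hide the pixels that equal 1."""
    norm = (scores - scores.min()) / (scores.max() - scores.min() + eps)
    binary = (norm > threshold).astype(np.float32)
    return np.ma.masked_where(binary == 1, binary)


scores = np.array([[0.0, 2.0], [8.0, 10.0]])
overlay = threshold_mask(scores)
print(overlay.mask)  # True where the normalized score exceeds the threshold
```

Raising `threshold` masks out fewer pixels, so a larger part of the image is covered by the colored overlay.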

In [None]:
# @title Qualitative Examples
# @markdown This block will perform segmentation on 'IMAGES_TO_DISPLAY' images.

IMAGES_TO_DISPLAY = 5 # @param {type:"integer"}
VOC_CLASS_STRS = [
    "background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car",
    "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor"
]

# manually init some of the config variables
# sometimes they don't get loaded properly
_config = {}
_config['optim'] = {
  'lr': 1e-3,
  'momentum': 0.9,
  'weight_decay': 0.0005,
}
_config['model'] = {
  'align': True,
}
_config['task'] = {
  'n_ways': 1,
  'n_shots': 1,
  'n_queries': 1,
}
_config['n_steps'] = 30000
_config['label_sets'] = 0
_config['batch_size'] = 1
_config['lr_milestones'] = [10000, 20000, 30000]
_config['align_loss_scaler'] = 1
_config['ignore_label'] = 255
_config['print_interval'] = 100
_config['save_pred_every'] = 1_000  # was 10_000
_config["input_size"] = (417, 417)
_config["seed"] = 1234
_config["cuda_visable"] = '0, 1, 2, 3, 4, 5, 6, 7'
_config["gpu_id"] = 0
_config["mode"] = 'test' # 'train' or 'test'


# init torch, model, seeds, etc.
set_seed(_config['seed'])
cudnn.enabled = True
cudnn.benchmark = True
torch.cuda.set_device(device=_config['gpu_id'])
torch.set_num_threads(1)

model = FewShotSeg(pretrained_path=final_model_weights, cfg=_config['model'])
model = nn.DataParallel(model.cuda(), device_ids=[_config['gpu_id'],])
model.load_state_dict(torch.load(final_model_weights, map_location='cpu'))
model.eval()


# init dataset
data_name = "VOC"
if data_name == 'VOC':
    make_data = voc_fewshot
    max_label = 20
elif data_name == 'COCO':
    make_data = coco_fewshot
    max_label = 80
else:
    raise ValueError('Wrong config for dataset!')
labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][_config['label_sets']]
transforms = [Resize(size=_config['input_size'])]
transforms = Compose(transforms)

dataset = make_data(
    base_dir="/root/.cache/kagglehub/datasets/ngan2710/voc-devkit-2007-and-2012/versions/1/VOCdevkit/VOC2012/",
    split="trainaug",
    transforms=transforms,
    to_tensor=ToTensorNormalize(),
    labels=labels,
    max_iters=1_000 * 1,
    n_ways=1,
    n_shots=1,
    n_queries=1
)

testloader = DataLoader(dataset, batch_size=_config['batch_size'], shuffle=False, num_workers=1, pin_memory=True, drop_last=False)


for idx, sample in enumerate(testloader):
  # Pass the input through the model.
  # Note that we can't simply call 'model(sample)' because we have to account
  # for the different types of annotations.
  _config['bbox'] = False
  _config['scribble'] = False
  label_ids = list(sample['class_ids'])
  support_images = [[shot.cuda() for shot in way] for way in sample['support_images']]
  suffix = 'scribble' if _config['scribble'] else 'mask'

  if _config['bbox']:
      support_fg_mask = []
      support_bg_mask = []
      for i, way in enumerate(sample['support_mask']):
          fg_masks = []
          bg_masks = []
          for j, shot in enumerate(way):
              fg_mask, bg_mask = get_bbox(shot['fg_mask'], sample['support_inst'][i][j])
              fg_masks.append(fg_mask.float().cuda())
              bg_masks.append(bg_mask.float().cuda())
          support_fg_mask.append(fg_masks)
          support_bg_mask.append(bg_masks)
  else:
      support_fg_mask = [[shot[f'fg_{suffix}'].float().cuda() for shot in way] for way in sample['support_mask']]
      support_bg_mask = [[shot[f'bg_{suffix}'].float().cuda() for shot in way] for way in sample['support_mask']]

  query_images = [query_image.cuda() for query_image in sample['query_images']]
  query_labels = torch.cat([query_label.cuda() for query_label in sample['query_labels']], dim=0)
  with torch.no_grad():  # inference only, so gradients are not needed
    query_pred, _ = model(support_images, support_fg_mask, support_bg_mask, query_images)

  figure = show_segmentation(idx, sample, query_pred)
  plt.show()

  # stop once 'IMAGES_TO_DISPLAY' figures have been shown
  if idx + 1 >= IMAGES_TO_DISPLAY:
    break