# AAFormer: Adaptive Agent Transformer for Few-shot Segmentation

**Paper authors:** Yuan Wang, Rui Sun, Zhe Zhang, Tianzhu Zhang

**Reproduced by:** Yusuf Soydan, Bartu Akyürek

This paper addresses few-shot segmentation (FSS): segmenting the objects of a target class in a query image, given only a few support images of that class together with their ground-truth masks.

## Prepare Workspace


In [1]:
# General Dependencies

import os
import math
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
# Run from Colab or local: resolve ROOT_PATH / DATA_PATH for the current environment.
# - On Google Colab: mount Google Drive and work from the project folder on Drive.
# - Locally: use the current directory; the "Datasets" folder is expected one level up.
try:
    from google.colab import drive
except ImportError:
    # google.colab is unavailable, so we are not on Colab.
    # Note: If you are running this notebook on your local, please put the Datasets folder outside of current directory
    ROOT_PATH = os.curdir
    DATA_PATH = "../Datasets"
else:
    # Only a missing google.colab module means "local". A failure while mounting Drive
    # or changing directory on Colab is a real error and must not silently fall back
    # to the local paths above.
    drive.mount('/content/gdrive')
    ROOT_PATH = "/content/gdrive/MyDrive/AAFormer"
    DATA_PATH = os.path.join(ROOT_PATH, "Datasets")
    # Absolute path instead of `%cd ./gdrive/MyDrive/AAFormer`, which only works when the cwd is /content.
    os.chdir(ROOT_PATH)

In [3]:
# Source dependencies
from data.dataset import FSSDataset
from common.vis import Visualizer
from common.evaluation import Evaluator

In [4]:
# HYPERPARAMETERS
# -----------------------------------------------------------------
# Note: the parameters with * comment are NOT provided by the paper
# TODO: make hyperparameters UPPER_CASE
image_resolution = 128  # Input image side length (the paper uses 473)
reduce_dim = 256    # * Reduced feature dimension (passed to AAFormer as `reduce_dim`)
c = 2048            # * Hidden dimensions
hw = 256            # * Dimensionality of Feature maps (depends on ResNet choice)
N = 16              # * Number of layers in Encoder/Decoder
heads = 8           # * Number of attention heads
num_tokens = 15     # * Number of Agent Tokens
sinkhorn_reg = 5e-1 # * Regularization term of Optimal Transport
max_iter_ot = 100   # * Maximum iterations of Optimal Transport algorithm
seed = 0            # * Random seed for the torch / NumPy RNGs

batch_size = 4    # According to paper: 4
adam_lr = 1e-4    # According to paper: 0.0001
adam_decay = 1e-2 # According to paper: 0.01
num_epoch = 400   # According to paper: 400

# Additional parameters to control the flow
# ----------------------------------------------------------------
bypass_ot = False                         # To turn off OT module if it causes any problem (e.g. RAM failure)
use_dice_loss = False                     # True: Dice loss; False: the pixel-wise classification loss set up in the training cell
cuda = torch.cuda.is_available()          # Whether a CUDA GPU is available
device = 'cuda' if cuda else 'cpu'        # Device that tensors are moved to with `.to(device)`
checkpoint = 100                          # Save the model every {checkpoint} iterations

# Seed the torch and NumPy RNGs so that runs are more repeatable
# (GPU kernels may still be non-deterministic).
torch.manual_seed(seed)
np.random.seed(seed)

## Prepare PASCAL-$5^i$ Dataset

In [5]:
# STEP 1: Download PASCAL VOC2012 devkit (train/val data): (uncomment lines below)
# ------------------------------------------------------------------------------
#!wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
#!tar -xvf 'VOCtrainval_11-May-2012.tar' -C ./Datasets/ 

# (Alternatively, instead of wget, download the archive directly from the link in a browser.)
# STEP 2: Place "VOC2012" folder from downloaded "VOCdevkit" folder under a "Datasets" folder.
# ------------------------------------------------------------------------------
#!mv Datasets/VOCdevkit/VOC2012 Datasets/VOC2012

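# STEP 3: Download the extended annotations from the link below.
# Note: this is a Google Drive "view" link; running wget on it downloads the HTML page, not the archive,
# so open it in a browser (or use a Drive download tool) instead.
#!wget https://drive.google.com/file/d/10zxG2VExoEZUeyQl_uXga2OWHjGeZaf2/view

# STEP 4: Extract the downloaded annotations and put them under "Datasets/VOC2012/"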


In [6]:
# Dataset initialization
# NOTE: the paper trains at 473x473; here the size comes from `image_resolution`.
# TODO: the paper also uses data augmentation, which is not implemented here.
# TODO: add an attribution/disclaimer to the dataset files, which appear to be based on
#       HSNet (https://github.com/juhongm999/hsnet/blob/main/train.py).
FSSDataset.initialize(img_size=image_resolution, datapath=DATA_PATH, use_original_imgsize=False)

fold = 0     # PASCAL-5^i fold used for both training and validation
nworker = 1  # Number of data-loading workers
# Use the `batch_size` hyperparameter instead of a hard-coded value so both stay in sync.
dataloader_trn = FSSDataset.build_dataloader(benchmark='pascal', bsz=batch_size, nworker=nworker, fold=fold, split='trn')
dataloader_val = FSSDataset.build_dataloader(benchmark='pascal', bsz=batch_size, nworker=nworker, fold=fold, split='val')

Total (trn) images are : 11394
Total (val) images are : 346


## Initialize Models

In [7]:
from model.aaformer import AAFormer
from common.dice_loss import BinaryDiceLoss  # used to build the Dice loss in the training cell

# Build AAFormer from the hyperparameters defined in the configuration cell.
# `.to(device)` places the parameters on the same device that the training loop
# moves its input batches to; nn.Module.to leaves them in place if they are already there.
model = AAFormer(cuda = cuda,
                 c = c,
                 hw = hw,
                 N = N,
                 heads = heads,
                 num_tokens = num_tokens,
                 im_res=image_resolution,
                 reduce_dim=reduce_dim,
                 bypass_ot=bypass_ot,
                 sinkhorn_reg=sinkhorn_reg,
                 max_iter_ot=max_iter_ot).to(device)

## Training Loop

In [None]:
# See Supplementary Material: "Dice loss is adopted to train our model", and the AdamW optimizer is used for transformer blocks.
# Although BCE loss is not mentioned in the paper, many segmentation papers use BCE to compare binary masks.
# The Dice loss function we used returns a very low value (about 0.8), which may result in tiny gradients.
dice_loss = BinaryDiceLoss()
# BCE on raw logits. nn.CrossEntropyLoss was used before; with a float target of the same
# shape as `preds.squeeze(1)` it treats dim 1 (image rows) as classes, which is not a
# per-pixel binary loss. NOTE(review): assumes the model outputs one logit channel per pixel.
bce_loss = nn.BCEWithLogitsLoss()

optimizer = torch.optim.AdamW(model.parameters(), lr=adam_lr, weight_decay=adam_decay)
# TODO: how to separate optimizers for the backbone and the rest?

checkpoint_dir = os.path.join(ROOT_PATH, "checkpoints")


def _batch_loss(batch):
    """Run the model on one batch of few-shot episodes and return the segmentation loss.

    The model predicts the mask of the *query* image, so the loss is computed against
    batch['query_mask']; the support masks are inputs to the model, not targets.
    """
    query_img = batch['query_img'].to(device)
    query_mask = batch['query_mask'].to(device)
    supp_imgs = batch['support_imgs'].to(device)
    supp_masks = batch['support_masks'].to(device)

    # normalize is only enabled for the Dice loss; the BCE loss expects raw logits.
    preds = model(query_img, supp_imgs, supp_masks, normalize=use_dice_loss)

    if use_dice_loss:
        return dice_loss(preds.squeeze(1), query_mask)
    return bce_loss(preds.squeeze(1), query_mask)


step = 0
train_losses = []
val_losses = []
for epoch in range(num_epoch):

    # --- Training ---
    model.train()
    train_loss = 0.0
    for batch in dataloader_trn:
        loss = _batch_loss(batch)

        # Clear the previous step's gradients; otherwise they accumulate across steps.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step += 1
        loss_value = loss.item()
        train_loss += loss_value
        print("Step ", step, " Loss: ", loss_value)

        # Save the model weights every `checkpoint` optimization steps.
        if step % checkpoint == 0:
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"aaformer_step{step}.pt"))

    # Append the average training loss of this epoch
    train_losses.append(train_loss / len(dataloader_trn))
    print("Epoch ", epoch, " Loss: ", train_losses[-1])

    # --- Validation (no gradients, eval-mode layers) ---
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in dataloader_val:
            val_loss += _batch_loss(batch).item()
    val_losses.append(val_loss / len(dataloader_val))
    print("Epoch ", epoch, " Val Loss: ", val_losses[-1])
    

 71%|█████████████████████████████████████████████████████████▏                      | 183/256 [00:06<00:02, 28.65it/s]

## Visualization Test

In [None]:
# Visualize predictions on the first batch (to be used in testing).
# Disabled by default; set RUN_VISUALIZATION = True to run it.
RUN_VISUALIZATION = False

if RUN_VISUALIZATION:
    Visualizer.initialize(True)
    Evaluator.initialize()

    for idx, batch in enumerate(dataloader_trn):
        # 1. Prediction
        # TODO: replace this placeholder with the model's predicted query mask. For now the
        #       ground-truth query mask is used, so the "prediction" shown is the ground truth itself.
        pred_mask = batch['query_mask']
        assert pred_mask.size() == batch['query_mask'].size()

        # 2. Evaluate prediction
        area_inter, area_union = Evaluator.classify_prediction(pred_mask.clone(), batch)

        # 3. Visualize predictions
        if Visualizer.visualize:
            Visualizer.visualize_prediction_batch(batch['support_imgs'], batch['support_masks'],
                                                  batch['query_img'], batch['query_mask'],
                                                  pred_mask, batch['class_id'], idx,
                                                  area_inter[1].float() / area_union[1].float())

        break   # TODO: delete this break to run visualization for full dataset

In [None]:
# TODO: declare numerical results