In [1]:
# General Dependencies

import os
import math
import numpy as np
import torch


In [2]:
"""
Run from Colab or local
"""

try:
    from google.colab import drive
    drive.mount('/content/gdrive')
    ROOT_PATH = "/content/gdrive/MyDrive/AAFormer"
    DATA_PATH = os.path.join(ROOT_PATH, "Datasets")
    
    %cd ./gdrive/MyDrive/AAFormer

except:
    ROOT_PATH = os.curdir
    DATA_PATH = "../Datasets"




In [3]:
# Source dependencies
from data.dataset import FSSDataset
from common.vis import Visualizer
from common.evaluation import Evaluator

## Prepare PASCAL-$5^i$ Dataset

In [4]:
# STEP 1: Download PASCAL VOC2012 devkit (train/val data): (uncomment lines below)
# ------------------------------------------------------------------------------
#!wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
#!tar -xvf 'VOCtrainval_11-May-2012.tar' -C ./Datasets/ 

# (or instead of wget, use directly the link to download)
# STEP 2: Place "VOC2012" folder from downloaded "VOCdevkit" folder under a "Datasets" folder.
# ------------------------------------------------------------------------------
#!mv Datasets/VOCdevkit/VOC2012 Datasets/VOC2012

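# STEP 3: Download the extended annotations from the link below.
# NOTE: this is a Google Drive "view" page, so wget saves the HTML page rather
# than the archive; open the link in a browser and download the file manually.
#!wget https://drive.google.com/file/d/10zxG2VExoEZUeyQl_uXga2OWHjGeZaf2/view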

# STEP 4: Extract the downloaded extended annotations and place them under "Datasets/VOC2012/"


In [5]:
# Dataset initialization
# NOTE: according to the earlier TODO answer the paper uses an image size of
# 473; this notebook currently runs with 128.
# TODO: the paper also applies data augmentation.
# TODO: add disclaimer to dataset files https://github.com/juhongm999/hsnet/blob/main/train.py
FSSDataset.initialize(img_size=128, datapath=DATA_PATH, use_original_imgsize=False)

# Settings shared by the training and validation loaders; only the split differs.
loader_options = dict(benchmark='pascal', bsz=4, nworker=1, fold=0)
dataloader_trn = FSSDataset.build_dataloader(split='trn', **loader_options)
dataloader_val = FSSDataset.build_dataloader(split='val', **loader_options)

Total (trn) images are : 11394
Total (val) images are : 346


## Visualization trial

In [6]:
# Initialise the project's visualisation and evaluation helpers before they
# are used by later cells.
# NOTE(review): the `True` argument presumably switches visualisation on (the
# disabled snippet below checks `Visualizer.visualize`) — confirm in common/vis.py.
Visualizer.initialize(True)
Evaluator.initialize()

In [7]:
""" Code below visualizes prediction (to be used in testing)

for idx, batch in enumerate(dataloader_trn):
        # 1. Hypercorrelation Squeeze Networks forward pass
        #batch = utils.to_cuda(batch)
        #pred_mask = model.module.predict_mask_nshot(batch, nshot=nshot)

        #assert pred_mask.size() == batch['query_mask'].size()
        pred_mask = batch['query_mask']

        # 2. Evaluate prediction
        area_inter, area_union = Evaluator.classify_prediction(pred_mask.clone(), batch)
        
        #average_meter.update(area_inter, area_union, batch['class_id'], loss=None)
        #average_meter.write_process(idx, len(dataloader), epoch=-1, write_batch_idx=1)

        # Visualize predictions
        if Visualizer.visualize:
            Visualizer.visualize_prediction_batch(batch['support_imgs'], batch['support_masks'],
                                                  batch['query_img'], batch['query_mask'],
                                                  pred_mask, batch['class_id'], idx,
                                                  area_inter[1].float() / area_union[1].float())

        break   # TODO: delete this break to run visualization for full dataset
"""

" Code below visualizes prediction (to be used in testing)\n\nfor idx, batch in enumerate(dataloader_trn):\n        # 1. Hypercorrelation Squeeze Networks forward pass\n        #batch = utils.to_cuda(batch)\n        #pred_mask = model.module.predict_mask_nshot(batch, nshot=nshot)\n\n        #assert pred_mask.size() == batch['query_mask'].size()\n        pred_mask = batch['query_mask']\n\n        # 2. Evaluate prediction\n        area_inter, area_union = Evaluator.classify_prediction(pred_mask.clone(), batch)\n        \n        #average_meter.update(area_inter, area_union, batch['class_id'], loss=None)\n        #average_meter.write_process(idx, len(dataloader), epoch=-1, write_batch_idx=1)\n\n        # Visualize predictions\n        if Visualizer.visualize:\n            Visualizer.visualize_prediction_batch(batch['support_imgs'], batch['support_masks'],\n                                                  batch['query_img'], batch['query_mask'],\n                                          

## Decoder Test

In [10]:
from model.tokens import init_agent_tokens
from model.featureextractor import FeatureExtractor
from model.representationencoder import RepresentationEncoder

# Hyperparameters shared by the modules constructed below.
# TODO: c is hyperparameter, gather these params at the beginning
# `c` is passed to all three modules as their first / `c` argument.
c = 2048
# `hw` is passed as `reduce_dim` to FeatureExtractor and as the second
# positional argument of RepresentationEncoder.
hw = 256 
# `N` is the third positional argument of RepresentationEncoder and the
# `num_layers` argument of AgentLearningDecoder.
N = 16
# `heads` is only passed to RepresentationEncoder.
heads = 8
# `num_tokens` is used later when the initial agent tokens are created.
num_tokens = 15 # TODO: make it a hyperparameter

feature_extractor = FeatureExtractor(layers=50, reduce_dim=hw, c=c)
representation_encoder = RepresentationEncoder(c, hw, N, heads)
# NOTE(review): AgentLearningDecoder is defined in a later cell of this
# notebook, so that cell has to be executed before this line; on a fresh
# "Restart & Run All" this line raises NameError.
agent_learning_decoder = AgentLearningDecoder(c, N)


In [12]:
# Assumption
# --------------------------------------------------------------------------
# Authors do not specify if the attention mechanism of Agent Learning Decoder 
# is multi-head or not, so we assume it is not multi-head.
# For Representation Encoder and Agent Matching Decoder, they mention the 
# equations are implemented with multi-head mechanism.
# We implement the attention as if it is multi-head and setting num_heads=1
# will make it a single head.

import torch
import torch.nn as nn

# This attention mechanism is not the same as the original self-attention.
# Here we implement eqn.3 and eqn.4, i.e.
# S = softmax(QK^T / sqrt(d_k) + M)
# There will be another algorithm, Optimal Transport (eqn.5-6) before
# we scale the attention with Value.
class Attention_Eqn3(nn.Module):
    """Masked cross-attention weights between agent tokens and support features.

    Computes ``S = softmax(Q K^T / sqrt(d_k) + M)`` where the queries are a
    linear projection of the agent tokens and the keys are a linear
    projection of the support features. No value projection is applied here;
    only the attention weight matrix ``S`` is returned.

    Args:
        hidden_dims: Size of the feature dimension of both inputs; also used
            as ``d_k`` for the scaling factor.
    """

    def __init__(self, hidden_dims):
        super().__init__()

        self.d_k = hidden_dims
        self.W_a_Q = nn.Linear(hidden_dims, hidden_dims)
        self.W_s_K = nn.Linear(hidden_dims, hidden_dims)

    def forward(self, F_a, F_s, M):
        """Return the masked attention weight matrix S.

        Args:
            F_a: Agent tokens, shape (batchsize, num_tokens, d_k).
            F_s: Support features, shape (batchsize, h*w, d_k).
            M: Additive mask, shape (batchsize, num_tokens, h*w); 0 keeps a
                position and -inf removes it from the softmax.

        Returns:
            Tensor of shape (batchsize, num_tokens, h*w). The softmax is taken
            over the last (h*w) dimension, i.e. each token's weights over the
            support positions sum to 1 (assumed reading of eqn.3). A row of M
            that is entirely -inf yields NaN for that token.
        """
        Q_a = self.W_a_Q(F_a)  # Query, shape (batchsize, num_tokens, d_k)
        K_s = self.W_s_K(F_s)  # Key, shape (batchsize, h*w, d_k)

        # Batched matmul of (batchsize, num_tokens, d_k) with the transposed
        # key (batchsize, d_k, h*w) gives (batchsize, num_tokens, h*w).
        # Q_a must NOT be unsqueezed: an extra leading dimension makes
        # torch.matmul broadcast the batch dimensions against each other and
        # produce (batchsize, batchsize, num_tokens, h*w).
        QK = torch.matmul(Q_a, K_s.transpose(1, 2)) / math.sqrt(self.d_k)

        return torch.softmax(QK + M, dim=-1)

class AgentLearningDecoder(nn.Module):
    """Agent Learning Decoder (work in progress).

    Only the masked cross-attention step (eqn.3 and eqn.4) is wired up so far;
    the remaining steps are still TODO and ``forward`` returns nothing yet.

    Args:
        c: Channel size of the incoming features.
        num_layers: Accepted but not used by the current implementation.
        num_heads: Divides ``c`` to obtain the attention dimension ``d_k``.
    """

    def __init__(self, c, num_layers, num_heads=1):
        super().__init__()

        per_head_dim = c // num_heads
        self.d_k = per_head_dim
        self.attn = Attention_Eqn3(per_head_dim)

    def forward(self, F_a, F_s, M_s):
        """Run the implemented part of the decoder.

        Args:
            F_a: Agent tokens; the size of dimension 1 is taken as the number
                of tokens.
            F_s: Support features handed to the attention module.
            M_s: Support mask; flattened from dimension 2 onwards (the
                original comments give its shape as (batchsize, 1, h, w)).
        """
        # Step 1: Masked Cross Attention between (F_a, F_s_hat)
        # Implements eqn.3 and eqn.4 and yields the part mask S (Fig.3 (a)).
        # --------------------------------------------------------

        # Collapse the spatial dimensions: (batchsize, 1, h, w) -> (batchsize, 1, h*w)
        flat_mask = M_s.flatten(start_dim=2)

        # Page 8, first sentence: the mask is duplicated once per agent
        # token, giving (batchsize, numtokens, hw).
        token_count = F_a.size(1)
        tiled_mask = flat_mask.repeat(1, token_count, 1)

        # Additive mask: 0 where the support mask equals 1, -inf elsewhere.
        additive_mask = torch.where(tiled_mask == 1, 0, float('-inf'))

        # Masked attention weight matrix.
        part_mask = self.attn(F_a, F_s, additive_mask)

        # Step 2: TODO
        # This part is the implementation of eqn.5 and eqn.6
        # --------------------------------------------------------
        # Step 3: TODO
        # --------------------------------------------------------
        # Step 4: TODO
        # --------------------------------------------------------


In [13]:
# Smoke test of the pipeline on the first training batch only.
for idx, batch in enumerate(dataloader_trn):

    query_img = batch['query_img']
    supp_imgs = batch['support_imgs']
    supp_masks = batch['support_masks']

    # STEP 1: Extract Features from the backbone model (ResNet)
    F_Q, F_S, s_mask_list, f_s, M_s = feature_extractor(query_img, supp_imgs, supp_masks)

    # STEP 2.1: Pass the features from encoder
    F_S_hat = representation_encoder(F_S)

    # STEP 2.2: Get Initial Agent Tokens
    # TODO: can we get rid of for loop?
    # X collects foreground pixel coordinates and L background pixel
    # coordinates, one tensor per sample in the batch.
    X, L = [], []
    for m in M_s:  # M_s has shape (batchsize, 1, 16, 16)
        m = m.squeeze(0)  # Get a single mask, shape (16, 16)

        # torch.nonzero returns an int64 tensor of shape [num_pix, 2] whose
        # rows are the (row, col) indices of the matching pixels. This is the
        # same result as stacking the arrays from np.where, but it stays in
        # torch and does not require the mask to live on the CPU.
        X.append(torch.nonzero(m == 1.))  # foreground pixels
        L.append(torch.nonzero(m == 0.))  # background pixels

    tokens = init_agent_tokens(num_tokens, X, L, f_s) # every token has [K,c] dim for every sample in a batch

    # STEP 3: Pass initial agent tokens through Agent Learning Decoder
    # and obtain agent tokens.
    agent_learning_decoder(tokens, F_S_hat, M_s)

    # STEP 4: Pass agent tokens through Agent Matching Decoder
    break

TypeError: sqrt(): argument 'input' (position 1) must be Tensor, not int