In [1]:
import os
os.chdir('..')
import shutil

In [2]:
import detectron2_1
import time

from detectron2.config import get_cfg
from pathlib import Path
from detectron2.checkpoint import DetectionCheckpointer
import cv2
from itertools import product

from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import Visualizer
from PIL import Image
import random
import matplotlib.pyplot as plt
import detectron2.data.transforms as T

from detectron2.modeling import build_model
import numpy as np
from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputs
from detectron2.structures import Instances
from detectron2.structures.boxes import Boxes
from detectron2.utils.events import EventStorage

from detectron2.data.samplers import InferenceSampler, TrainingSampler
from detectron2.data.build import get_detection_dataset_dicts
from detectron2.data.common import DatasetFromList, MapDataset
from detectron2.data import build_batch_data_loader, DatasetMapper

import torch
import torch.nn.functional as F

In [3]:
model_dir = Path('outputs')/'coco-detection'
rcnn_weights_path = model_dir/'model_final_280758.pkl'
rcnn_cfg_path = Path('configs/COCO-Detection')/'faster_rcnn_R_50_FPN_3x_test.yaml'

In [4]:
from detectron2_1.configs import get_cfg

def setup(config_file, rcnn_weights_path):
    """
    Create configs and perform basic setups.
    """
    
    # Initialize the configurations
    cfg = get_cfg()
    cfg.merge_from_file(config_file)
    
    # Ensure it uses appropriate names and architecture  
    cfg.AL.OBJECT_SCORING = 'entropy'
    cfg.MODEL.ROI_HEADS.NAME = 'ROIHeadsAL'
    cfg.MODEL.META_ARCHITECTURE = 'ActiveLearningRCNN'
    cfg.MODEL.WEIGHTS = str(rcnn_weights_path)
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.05 # lower this threshold to get more boxes

    cfg.freeze()
    return cfg

In [5]:
cfg = setup(rcnn_cfg_path, rcnn_weights_path)

In [6]:
cfg.AL

CfgNode({'MODE': 'object', 'OBJECT_SCORING': 'entropy', 'IMAGE_SCORE_AGGREGATION': 'avg', 'PERTURBATION': CfgNode({'VERSION': 1, 'ALPHAS': [0.08, 0.12], 'BETAS': [0.04, 0.16], 'RANDOM': False, 'LAMBDA': 1.0}), 'DATASET': CfgNode({'NAME': '', 'IMG_ROOT': '', 'ANNO_PATH': '', 'CACHE_DIR': 'al_datasets', 'NAME_PREFIX': 'r', 'BUDGET_STYLE': 'object', 'IMAGE_BUDGET': 20, 'OBJECT_BUDGET': 2000, 'BUDGET_ALLOCATION': 'linear', 'SAMPLE_METHOD': 'top'}), 'OBJECT_FUSION': CfgNode({'OVERLAPPING_METRIC': 'iou', 'OVERLAPPING_TH': 0.25, 'SELECTION_METHOD': 'top', 'REMOVE_DUPLICATES': True, 'REMOVE_DUPLICATES_TH': 0.15, 'RECOVER_MISSING_OBJECTS': True, 'INITIAL_RATIO': 0.85, 'LAST_RATIO': 0.25, 'DECAY': 'linear', 'PRESELECTION_RAIO': 1.5, 'ENDSELECTION_RAIO': 1.25, 'SELECTION_RAIO_DECAY': 'linear', 'RECOVER_ALMOST_CORRECT_PRED': True, 'BUDGET_ETA': 0.2}), 'TRAINING': CfgNode({'ROUNDS': 5, 'EPOCHS_PER_ROUND_INITIAL': 500, 'EPOCHS_PER_ROUND_DECAY': 'linear', 'EPOCHS_PER_ROUND_LAST': 50})})

In [7]:

# cfg = get_cfg()
# add_al_config(cfg)
# cfg.merge_from_file(rcnn_cfg_path)
# cfg.MODEL.WEIGHTS = str(rcnn_weights_path)
# cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.05 # lower this threshold to get more boxes


In [8]:
from detectron2_1.modelling import *

In [9]:
model = build_model(cfg)
model.eval()

# Build_model will not load weights, have to load weights explicitly
checkpointer = DetectionCheckpointer(model)  
checkpointer.load(cfg.MODEL.WEIGHTS)

# Metadata of dataset
metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])

# Testtime augmentation is random resize
aug = T.ResizeShortestEdge([cfg.INPUT.MIN_SIZE_TEST, 
                            cfg.INPUT.MIN_SIZE_TEST], 
                            cfg.INPUT.MAX_SIZE_TEST) 


- Rewrite data loader because we want data loader with batch size > 1 and no shuffle

In [10]:
# dataset_dicts = get_detection_dataset_dicts(
#     cfg.DATASETS.TRAIN,
#     filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
#     min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
#     if cfg.MODEL.KEYPOINT_ON
#     else 0,
#     proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
# )

# dataset = DatasetFromList(dataset_dicts)
# mapper = DatasetMapper(cfg, False) # set training = True? to enable gradients calculation
# dataset = MapDataset(dataset, mapper)

# # sampler = InferenceSampler(len(dataset))
# # batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, 4, drop_last=False) # reduce batch size to 4

# # def trivial_batch_collator(batch):
# #     """
# #     A batch collator that does nothing.
# #     """
# #     return batch

# # data_loader = torch.utils.data.DataLoader(
# #     dataset,
# #     num_workers=cfg.DATALOADER.NUM_WORKERS,
# #     batch_sampler=batch_sampler,
# #     collate_fn=trivial_batch_collator,
# # )

In [11]:
# #### Training data loader 

# sampler = TrainingSampler(len(dataset), shuffle=False) # no shuffle
# # this will be an infinite data loader 
# data_loader = build_batch_data_loader(
#         dataset,
#         sampler,
#         4, # batch size equals 1
#         aspect_ratio_grouping=True,
#         num_workers=cfg.DATALOADER.NUM_WORKERS)

In [10]:
from detectron2.data import build_detection_test_loader
dataset_mapper = DatasetMapper(cfg, False) # set training = True? to enable gradients calculation
data_loader= build_detection_test_loader(
    cfg, cfg.DATASETS.TEST[0], mapper=dataset_mapper
)

In [11]:
### Enter training mode ###
# model.train()
# with EventStorage() as storage: # during training, we have to add this to suppress error
#     loss_dict = model(data)
    
# model.zero_grad() # zero out gradients first
# losses = sum(loss_dict.values())

# losses.backward()
# gradients_last = model.roi_heads.box_predictor.cls_score.weight.grad # get gradients of last layer

### Enter evaluation mode ###
# model.eval()
# outputs = model(data)


In [12]:
for i, data in enumerate(data_loader):
    break
    if i >= len(dataset):
        print(i)
        break

- Scoring functions

In [17]:
def estimate_for_proposals(model, features, proposals, return_grad=False):
    
    box2boxtransform = Box2BoxTransform(weights=cfg.MODEL.ROI_BOX_HEAD.BBOX_REG_WEIGHTS)
    gradients_vec = []
    
#     if not return_grad:
    with torch.no_grad():
        features = [features[f] for f in model.roi_heads.box_in_features]
        box_features = model.roi_heads.box_pooler(features,
                            [x if isinstance(x, Boxes) \
                                else x.proposal_boxes for x in proposals])
        box_features = model.roi_heads.box_head(box_features)
        pred_class_logits, pred_proposal_deltas = model.roi_heads.box_predictor(box_features)
        del box_features

#     else:
#     with torch.enable_grad():
#         features = [features[f] for f in model.roi_heads.box_in_features]
#         box_features = model.roi_heads.box_pooler(features,
#                             [x if isinstance(x, Boxes) \
#                                 else x.proposal_boxes for x in proposals])
#         box_features = model.roi_heads.box_head(box_features)
#         pred_class_logits, pred_proposal_deltas = model.roi_heads.box_predictor(box_features)
#         _, pred_class_indices = torch.max(pred_class_logits, -1)        
#         print(pred_class_indices.shape)
        
#         m = torch.nn.LogSoftmax(dim=-1)
#         loss = torch.nn.NLLLoss()
#         target = pred_class_indices
#         tot_loss = loss(m(pred_class_logits), target)
#         print(tot_loss)
#         tot_loss.backward()
#         # only get gradients w.r.t. predicted class for computational feasibility
#         gradients = model.roi_heads.box_predictor.cls_score.weight.grad.detach().cpu()
        
#         _, pred_class_indices = torch.max(pred_class_logits, -1)            

#         for j, ind in enumerate(pred_class_indices):
#             model.zero_grad()
#             pred_class_logits[j, ind].backward(retain_graph=True)
#             gradients_j = model.roi_heads.box_predictor.cls_score.weight.grad[ind].detach().cpu()
#             gradients_vec.append(gradients_j)
#         del box_features

    outputs = FastRCNNOutputs(
        box2boxtransform,
        pred_class_logits,
        pred_proposal_deltas,
        proposals,
        cfg.MODEL.ROI_BOX_HEAD.SMOOTH_L1_BETA)
    
    return outputs
#     return outputs, gradients

In [18]:
# outputs = model(data)
# outputs.backward()
# print(model.roi_heads.box_predictor.cls_score.weight.grad)

In [19]:
## Preprocess image
images = model.preprocess_image(data)

# Get features
features = model.backbone(images.tensor)

# Get bounding box proposals
proposals, _ = model.proposal_generator(images, features, None)

# Get predicted boxes and do roi_head postprocessing(filter by scores, NMS, select topk)
outputs = estimate_for_proposals(model, features, proposals, return_grad=True)
pred_boxes = outputs.predict_boxes()
pred_probs = outputs.predict_probs()

# Perform postprocessing
cur_detections, filtered_indices = outputs.inference(cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST,
                                                     cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST, 
                                                     cfg.TEST.DETECTIONS_PER_IMAGE)

In [21]:
# Avgpool over channels
features_heatmaps = [torch.mean(features[k], dim=1) for k in features.keys()]

# Adaptive pooling 2D --> Tensor(shape 5x8x8)]
pooled_heatmaps = torch.cat([F.adaptive_avg_pool2d(f_map[None, ...], output_size=(8, 8)) \
                             .view(1, 8, 8).detach().cpu() \
                             for f_map in features_heatmaps], dim=0)
pooled_heatmaps = pooled_heatmaps.view(-1)



In [23]:
pooled_heatmaps.shape

torch.Size([320])

In [13]:
feature_refs = torch.rand((5, 320))

In [14]:
outputs = model.forward_al(data, feature_refs)

torch.Size([90, 81])
tensor([0.0003, 0.0005, 0.0010, 0.0034, 0.0035, 0.0038, 0.0040, 0.0050, 0.0045,
        0.0048, 0.0053, 0.0053, 0.0056, 0.0071, 0.0069, 0.0090, 0.0086, 0.0124,
        0.0094, 0.0111, 0.0161, 0.0136, 0.0087, 0.0091, 0.0091, 0.0104, 0.0163,
        0.0076, 0.0122, 0.0077, 0.0163, 0.0091, 0.0072, 0.0134, 0.0069, 0.0076,
        0.0091, 0.0136, 0.0059, 0.0060, 0.0126, 0.0068, 0.0057, 0.0058, 0.0095,
        0.0051, 0.0127, 0.0051, 0.0050, 0.0056, 0.0145, 0.0095, 0.0048, 0.0058,
        0.0067, 0.0047, 0.0049, 0.0047, 0.0040, 0.0068, 0.0082, 0.0049, 0.0053,
        0.0044, 0.0065, 0.0043, 0.0117, 0.0052, 0.0041, 0.0039, 0.0048, 0.0079,
        0.0084, 0.0053, 0.0041, 0.0059, 0.0035, 0.0043, 0.0047, 0.0261, 0.0066,
        0.0041, 0.0063, 0.0029, 0.0051, 0.0041, 0.0090, 0.0090, 0.0031, 0.0053],
       device='cuda:0')
tensor([0.0003, 0.0005, 0.0010, 0.0034, 0.0035, 0.0038, 0.0040, 0.0050, 0.0045,
        0.0048, 0.0053, 0.0053, 0.0056, 0.0071, 0.0069, 0.0090, 0.0086, 0.

In [15]:
outputs[0]['instances'].scores_al

tensor([0.0003, 0.0005, 0.0010, 0.0034, 0.0035, 0.0038, 0.0040, 0.0050, 0.0045,
        0.0048, 0.0053, 0.0053, 0.0056, 0.0071, 0.0069, 0.0090, 0.0086, 0.0124,
        0.0094, 0.0111, 0.0161, 0.0136, 0.0087, 0.0091, 0.0091, 0.0104, 0.0163,
        0.0076, 0.0122, 0.0077, 0.0163, 0.0091, 0.0072, 0.0134, 0.0069, 0.0076,
        0.0091, 0.0136, 0.0059, 0.0060, 0.0126, 0.0068, 0.0057, 0.0058, 0.0095,
        0.0051, 0.0127, 0.0051, 0.0050, 0.0056, 0.0145, 0.0095, 0.0048, 0.0058,
        0.0067, 0.0047, 0.0049, 0.0047, 0.0040, 0.0068, 0.0082, 0.0049, 0.0053,
        0.0044, 0.0065, 0.0043, 0.0117, 0.0052, 0.0041, 0.0039, 0.0048, 0.0079,
        0.0084, 0.0053, 0.0041, 0.0059, 0.0035, 0.0043, 0.0047, 0.0261, 0.0066,
        0.0041, 0.0063, 0.0029, 0.0051, 0.0041, 0.0090, 0.0090, 0.0031, 0.0053])

In [14]:
len(gradients)

1000

In [58]:
features['p2'].shape

torch.Size([1, 256, 200, 272])

In [61]:
pred_probs[0].shape

torch.Size([1000, 81])

In [33]:
pred_probs = torch.stack(pred_probs)
filter_prob = [pred_probs[i, filtered_index, :] for i, filtered_index in enumerate(filtered_indices)]

# compute entropy 
entropy_scores = [torch.sum(-probs*torch.log2(probs), dim=-1) for idx, probs in enumerate(filter_prob)]

for cur_detection, object_score in zip(cur_detections, entropy_scores):
    cur_detection.scores_al = object_score

In [46]:
cfg.DATASETS

CfgNode({'TRAIN': ('coco_2017_train',), 'PROPOSAL_FILES_TRAIN': (), 'PRECOMPUTED_PROPOSAL_TOPK_TRAIN': 2000, 'TEST': ('coco_2017_val',), 'PROPOSAL_FILES_TEST': (), 'PRECOMPUTED_PROPOSAL_TOPK_TEST': 1000})

## Register datasets

In [93]:
from detectron2.data.datasets import register_coco_instances
# Define dataset paths
data_dir = Path("datasets/coco/train2017/")
train_subset1_coco = "datasets/coco/annotations/instances_train2017_subset1.json"
train_subset2_coco = "datasets/coco/annotations/instances_train2017_subset2.json"
register_coco_instances("coco_2017_train_subset1", {}, train_subset1_coco, data_dir)
register_coco_instances("coco_2017_train_subset2", {}, train_subset2_coco, data_dir)

AssertionError: Dataset 'coco_2017_train_subset1' is already registered!

In [94]:
model_dir = Path('outputs')/'coco-detection'
rcnn_weights_path = model_dir/'model_final_280758.pkl'
rcnn_cfg_path = Path('configs/COCO-Detection')/'faster_rcnn_R_50_FPN_3x.yaml'

cfg = get_cfg()
cfg.merge_from_file(rcnn_cfg_path)
cfg.MODEL.WEIGHTS = str(rcnn_weights_path)
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.05 # lower this threshold to get more boxes
cfg.DATASETS.TRAIN = ('coco_2017_train_subset1',)
cfg.DATASETS.TEST = ('coco_2017_train_subset2',)

In [5]:

def calculate_entropy_scores(p):
    # use entropy
    entropy = torch.sum(-(p * torch.log(p)), dim=-1)
    return entropy


######### For perturbed box #######################################################
def iou(area1, area2, inter_area):
    return inter_area / (area1 + area2 - inter_area)

def elementwise_intersect_area(boxes1, boxes2):

    # Modified based on 
    # https://detectron2.readthedocs.io/_modules/detectron2/structures/boxes.html#pairwise_iou
    boxes1, boxes2 = boxes1.tensor, boxes2.tensor

    width_height = torch.min(boxes1[:, 2:], boxes2[:, 2:]) - torch.max(boxes1[:, :2], boxes2[:, :2])  # [N,M,2]

    width_height.clamp_(min=0)  # [N,2]
    inter = width_height.prod(dim=-1)  # [N]
    del width_height
    return inter

def elementwise_iou(boxes1, boxes2):
    
    area1, area2 = boxes1.area(), boxes2.area()
    inter_area = elementwise_intersect_area(boxes1, boxes2)

    scores = torch.where(
            inter_area > 0,
            iou(area1, area2, inter_area),
            torch.zeros(1, dtype=inter_area.dtype, device=inter_area.device),
        )
    return scores

def calculate_iou_scores(perturbed_box, raw_det, num_shifts, num_bbox_reg_classes):
    reshaped_boxes = perturbed_box.reshape(-1, num_bbox_reg_classes, 4)
    cat_ids = raw_det.pred_classes.repeat_interleave(num_shifts, dim=0)
    perturbed_boxes = Boxes(torch.stack([reshaped_boxes[row_id, cat_id] for row_id, cat_id in enumerate(cat_ids)]))
    raw_boxes = Boxes(raw_det.pred_boxes.tensor.repeat_interleave(num_shifts, dim=0))
    ious = elementwise_iou(raw_boxes, perturbed_boxes)
    # aggregate the statistics for each prediction
    iou_scores = torch.Tensor([scores.mean() for scores in ious.split(num_shifts)])
    return iou_scores

######### For perturbed box #######################################################
def calculate_ce_scores(p, q, num_shifts):
    # use crossentropy for calculation diff
    diff = - (p * torch.log(q)).mean(dim=-1)
    # aggregate the statistics for each prediction
    diff = torch.Tensor([scores.mean() for scores in diff.split(num_shifts)]) 

    return diff

In [75]:
class ActiveLearning():
    
    def __init__(self, cfg, score_func):
        
        # Initialize model
        self.cfg = cfg
        self.model = build_model(cfg)
        self.model.eval()

        # Build_model will not load weights, have to load weights explicitly
        checkpointer = DetectionCheckpointer(self.model)  
        checkpointer.load(cfg.MODEL.WEIGHTS)

        # Metadata of dataset
        self.metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])
        
        # Testtime augmentation is random resize
        self.aug = T.ResizeShortestEdge([cfg.INPUT.MIN_SIZE_TEST, 
                                         cfg.INPUT.MIN_SIZE_TEST], 
                                         cfg.INPUT.MAX_SIZE_TEST) 
        # Perturbation matrix for xywh
        self.num_shifts, self.shift_matrix = self._create_translations()   
        
        # scoring function
        self.score_func = score_func
        
        # data loader
        self.data_loader, self.len_data = self._load_dataset()
    
        dataset_mapper = DatasetMapper(cfg, False) # set training = False
        self.data_loader= build_detection_test_loader(
            cfg, cfg.DATASETS.TEST[0], mapper=dataset_mapper
        )

    
    def _estimate_for_proposals(self, features, proposals):
    
        box2boxtransform = Box2BoxTransform(weights=self.cfg.MODEL.ROI_BOX_HEAD.BBOX_REG_WEIGHTS)

        with torch.no_grad():
            features = [features[f] for f in self.model.roi_heads.box_in_features]
            box_features = self.model.roi_heads.box_pooler(features,
                                [x if isinstance(x, Boxes) \
                                    else x.proposal_boxes for x in proposals])
            box_features = self.model.roi_heads.box_head(box_features)
            pred_class_logits, pred_proposal_deltas = self.model.roi_heads.box_predictor(box_features)
            del box_features

        outputs = FastRCNNOutputs(
            box2boxtransform,
            pred_class_logits,
            pred_proposal_deltas,
            proposals,
            self.cfg.MODEL.ROI_BOX_HEAD.SMOOTH_L1_BETA)

        return outputs
        
    @staticmethod
    def _create_translations():
        '''Function to generate perturbation matrix'''
        # https://github.com/lolipopshock/Detectron2_AL
        def _generate_individual_shift_matrix(alpha, beta):
            return torch.Tensor([
                    [(1-alpha), 0,       -alpha,    0],
                    [0,        (1-beta), 0,         -beta],
                    [alpha,     0,       (1+alpha), 0],
                    [0,         beta,    0,         (1+beta)],
                ])

        # Horizontal translation ratio
        alphas = [0.08, 0.12] 
        # Vertical translation ratio
        betas  = [0.04, 0.16] 

        derived_shift = [
            [
                [alpha, beta], 
                [alpha, -beta], 
                [-alpha, beta], 
                [-alpha, -beta]
            ] 
            for alpha, beta in product(alphas, betas)
        ]

        matrices = [_generate_individual_shift_matrix(*params) 
                        for params in sum(derived_shift, [])]
        return len(matrices), torch.stack(matrices, dim=-1)
    
    @staticmethod
    def _feature_embed(features):
        # Avgpool over channels
        features_heatmaps = [torch.mean(features[k], dim=1) for k in features.keys()]
        
        # Adaptive pooling 2D --> Tensor(shape 5x32x32)]
        pooled_heatmaps = torch.cat([F.adaptive_avg_pool2d(f_map[None, ...], output_size=(32, 32)) \
                                     .view(1, 32, 32).detach().cpu() \
                                     for f_map in features_heatmaps], dim=0)
        
        return pooled_heatmaps
    
    def _perturb_consistency(self, cur_detections, features, orig_prob):
        # Box Perturbations
        # Get new proposals
        orig_boxes = cur_detections[0].pred_boxes.tensor
        new_proposals = Instances(cur_detections[0].image_size,
                                  proposal_boxes=Boxes(
                                      torch.einsum('bi,ijc->bjc', 
                                      orig_boxes,
                                      self.shift_matrix.to(orig_boxes.device)).permute(0,2,1).reshape(-1,4))) 

        perturbed_outputs = self._estimate_for_proposals(features, [new_proposals]) 
        perturbed_probs = perturbed_outputs.predict_probs()[0]
        perturbed_boxes = perturbed_outputs.predict_boxes()[0]
        num_bbox_reg_classes = perturbed_boxes.shape[1] // 4

        p = orig_prob.repeat_interleave(self.num_shifts, dim=0)
        q = perturbed_probs

        # Compute difference in IoU and Cross-Entropy
        diff1 = calculate_iou_scores(perturbed_boxes, cur_detections[0], 
                                     self.num_shifts, num_bbox_reg_classes)
        diff2 = calculate_ce_scores(p, q, self.num_shifts)
        box_diff = diff1 + diff2
        return box_diff

    
    def run(self):
        
        for i, data in enumerate(self.data_loader):
                
            file_name = data[0]['file_name']
            image_id = data[0]['image_id']
            
            with torch.no_grad():
                # Preprocess images 
                images = self.model.preprocess_image(data)

                # Get features
                features = self.model.backbone(images.tensor)

                # Get bounding box proposals
                proposals, _ = self.model.proposal_generator(images, features, None)

                # Get predicted boxes and do roi_head postprocessing(filter by scores, NMS, select topk)
                outputs = self._estimate_for_proposals(features, proposals)
                pred_boxes = outputs.predict_boxes()[0] # batch size is 1
                pred_probs = outputs.predict_probs()[0]
                cur_detections, filtered_indices = outputs.inference(self.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST,
                                               self.cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST, 
                                               self.cfg.TEST.DETECTIONS_PER_IMAGE)
                filter_prob = pred_probs[filtered_indices[0], :]
                
                if self.score_func == 1:
                    ## Entropy
                    scores = calculate_entropy_scores(filter_prob)
                    cur_detections[0].scores_al = scores  # batch size is 1
#                     return cur_detections
                    
                elif self.score_func == 2:
                    ## Perturbed boxes consistency
                    scores = self._perturb_consistency(cur_detections, features, filter_prob)
                    cur_detections[0].scores_al = scores  # batch size is 1
#                     return cur_detections
                    
                elif self.score_func == 3:
                    ## Get embedding of FPN 
                    embed = self._feature_embed(features).detach().cpu() # per-image basis
#                     return cur_detections, embed                
            
                del images, features, proposals, outputs
        
    def save_res(self, detections, file_name, image_id, write_path):
        pass

In [76]:
al = ActiveLearning(cfg, score_func=3)

In [77]:
haha = al.run()

In [81]:
haha[1].shape

torch.Size([5, 16, 16])

In [50]:
a = torch.arange(10).reshape(5,2)

In [52]:
a

tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])

In [56]:
a.split(1)

(tensor([[0, 1]]),
 tensor([[2, 3]]),
 tensor([[4, 5]]),
 tensor([[6, 7]]),
 tensor([[8, 9]]))