In [1]:
import argparse
import datetime
import json
import random
import time
from pathlib import Path
import os
import numpy as np
import torch
from torch.utils.data import DataLoader
import datasets
import util.misc as utils
import datasets.samplers as samplers
from datasets import build_dataset, get_coco_api_from_dataset
from datasets.coco import make_coco_transforms
from datasets.torchvision_datasets.open_world import OWDetection
# from engine import evaluate, train_one_epoch, get_exemplar_replay
from models import build_model
import wandb

  from .autonotebook import tqdm as notebook_tqdm


{'OWDETR': ('aeroplane', 'bicycle', 'bird', 'boat', 'bus', 'car', 'cat', 'cow', 'dog', 'horse', 'motorbike', 'sheep', 'train', 'elephant', 'bear', 'zebra', 'giraffe', 'truck', 'person', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'chair', 'diningtable', 'pottedplant', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'bed', 'toilet', 'sofa', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'tvmonitor', 'bottle', 'unknown'), 'TOWOD': ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',

## Arguments setting

In [28]:
from dotmap import DotMap

args = DotMap(
    # main code
    frozen_weights=None,
    seed=0,
    device='cuda',
    data_root='./data/OWOD',
    batch_size=2,
    num_workers=8,
    lr_linear_proj_names=['reference_points', 'sampling_offsets'],
    lr=2e-4,
    lr_backbone_names=["backbone.0"],
    lr_linear_proj_mult=2e-5,
    sgd=False,
    weight_decay=1e-4,    
    pretrain=None,
    freeze_prob_model=False,
    
    # build_model
    num_classes=81,
    num_queries=100,
    num_feature_levels=4,
    aux_loss=True,
    with_box_refine=False,
    two_stage=False,
    masks=False,
    cls_loss_coef=2,
    bbox_loss_coef=5,
    giou_loss_coef=2,
    # obj_loss_coef=1, # duplicated
    # mask_loss_coef=,
    # dice_loss_coef=,
    dec_layers=6,
    hidden_dim=256,
    focal_alpha=0.25,
    # obj_temp=1, # duplicated
    # dataset_file=,
    
    # build_backbone
    lr_backbone=2e-5,
    backbone='dino_resnet50',
    dilation=False,
    position_embedding='sine',
    
    # build_deforamble_transformer
    nheads=8,
    enc_layers=6,
    dim_feedforward=1024,
    dropout=0.1,
    dec_n_points=4,
    enc_n_points=4,
    
    # build_matcher
    set_cost_class=2,
    set_cost_bbox=5,
    set_cost_giou=2,
    
    # bash script
    output_dir='output',
    dataset='TOWOD',
    PREV_INTRODUCED_CLS=0,
    CUR_INTRODUCED_CLS=5,
    train_set='owod_t1_toy_5classes_train', 
    test_set='owod_all_task_test',
    epochs=5,
    model_type='prob',
    obj_loss_coef=8e-4,
    obj_temp=1.3,
    exemplar_replay_selection=True,
    exemplar_replay_max_length=850,
    exemplar_replay_dir='',
    exemplar_replay_cur_file='',
)

## Before Training

In [29]:
# functions from main_open_world.py
def get_datasets(args):
    print(args.dataset)

    train_set = args.train_set
    test_set = args.test_set
    dataset_train = OWDetection(args, args.data_root, image_set=args.train_set, transforms=make_coco_transforms(args.train_set), dataset = args.dataset)
    dataset_val = OWDetection(args, args.data_root, image_set=args.test_set, dataset = args.dataset, transforms=make_coco_transforms(args.test_set))

    print(args.train_set)
    print(args.test_set)
    print(dataset_train)
    print(dataset_val)

    return dataset_train, dataset_val

In [30]:
# before train: experimental setting (no ddp setting)
# main() in main_open_world.py

if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"
print(args)

device = torch.device(args.device)

# fix the seed for reproducibility
seed = args.seed
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

model, criterion, postprocessors, exemplar_selection = build_model(args, mode = args.model_type)
model.to(device)

model_without_ddp = model
print(model_without_ddp)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)

dataset_train, dataset_val = get_datasets(args)

sampler_train = torch.utils.data.RandomSampler(dataset_train)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)

batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True)
data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
                                collate_fn=utils.collate_fn, num_workers=args.num_workers,
                                pin_memory=True)
data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
                                drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers,
                                pin_memory=True)

DotMap(frozen_weights=None, seed=0, device='cuda', data_root='./data/OWOD', batch_size=2, num_workers=8, lr_linear_proj_names=['reference_points', 'sampling_offsets'], lr=0.0002, lr_backbone_names=['backbone.0'], lr_linear_proj_mult=2e-05, sgd=False, weight_decay=0.0001, pretrain=None, freeze_prob_model=False, num_classes=81, num_queries=100, num_feature_levels=4, aux_loss=True, with_box_refine=False, two_stage=False, masks=False, cls_loss_coef=2, bbox_loss_coef=5, giou_loss_coef=2, dec_layers=6, hidden_dim=256, focal_alpha=0.25, lr_backbone=2e-05, backbone='dino_resnet50', dilation=False, position_embedding='sine', nheads=8, enc_layers=6, dim_feedforward=1024, dropout=0.1, dec_n_points=4, enc_n_points=4, set_cost_class=2, set_cost_bbox=5, set_cost_giou=2, output_dir='output', dataset='TOWOD', PREV_INTRODUCED_CLS=0, CUR_INTRODUCED_CLS=5, train_set='owod_t1_toy_5classes_train', test_set='owod_all_task_test', epochs=5, model_type='prob', obj_loss_coef=0.0008, obj_temp=1.3, exemplar_repla



running with exemplar_replay_selection
DeformableDETR(
  (transformer): DeformableTransformer(
    (encoder): DeformableTransformerEncoder(
      (layers): ModuleList(
        (0): DeformableTransformerEncoderLayer(
          (self_attn): MSDeformAttn(
            (sampling_offsets): Linear(in_features=256, out_features=256, bias=True)
            (attention_weights): Linear(in_features=256, out_features=128, bias=True)
            (value_proj): Linear(in_features=256, out_features=256, bias=True)
            (output_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (dropout1): Dropout(p=0.1, inplace=False)
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (linear1): Linear(in_features=256, out_features=1024, bias=True)
          (dropout2): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=1024, out_features=256, bias=True)
          (dropout3): Dropout(p=0.1, inplace=False)
          (norm2): LayerNo

In [31]:
# lr_backbone_names = ["backbone.0", "backbone.neck", "input_proj", "transformer.encoder"]
def match_name_keywords(n, name_keywords):
    out = False
    for b in name_keywords:
        if b in n:
            out = True
            break
    return out

param_dicts = [
    {
        "params":
            [p for n, p in model_without_ddp.named_parameters()
                if not match_name_keywords(n, args.lr_backbone_names) and not match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad],
        "lr": args.lr,
    },
    {
        "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, args.lr_backbone_names) and p.requires_grad],
        "lr": args.lr_backbone,
    },
    {
        "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad],
        "lr": args.lr * args.lr_linear_proj_mult,
    }
]
if args.sgd:
    optimizer = torch.optim.SGD(param_dicts, lr=args.lr, momentum=0.9,
                                weight_decay=args.weight_decay)
else:
    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                                    weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

# removed "coco_panoptic" and "coco" branch
base_ds = dataset_val

if args.frozen_weights is not None:
    checkpoint = torch.load(args.frozen_weights, map_location='cpu')
    model_without_ddp.detr.load_state_dict(checkpoint['model'])
    
output_dir = Path(args.output_dir)

if args.pretrain:
    print('Initialized from the pre-training model')
    checkpoint = torch.load(args.pretrain, map_location='cpu')
    state_dict = checkpoint['model']
    msg = model_without_ddp.load_state_dict(state_dict, strict=False)
    print(msg)
    args.start_epoch = checkpoint['epoch'] + 1
    if args.eval: 
        test_stats, coco_evaluator = evaluate(model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir, args)
        # return
        
if args.freeze_prob_model:           
    if isinstance(model_without_ddp.prob_obj_head, torch.nn.ModuleList):
        for obj_head in model_without_ddp.prob_obj_head:
            obj_head.freeze_prob_model()
    else:
        model_without_ddp.prob_obj_head.freeze_prob_model()
        
    obj_bn_mean_before=model_without_ddp.prob_obj_head[0].objectness_bn.running_mean

## Training

In [12]:
# training

print(f'Start training from epoch {args.start_epoch} to {args.epochs}')


DotMap(_ipython_display_=DotMap(), _repr_mimebundle_=DotMap())