In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell execution and edits to project code are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os, argparse, yaml, wandb, datetime, logging, random
import numpy as np
import os.path as osp
import torch
import wandb
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from iamcl2r.params import ExperimentParams
from iamcl2r.logger import setup_logger
from iamcl2r.methods import set_method_configs, HocLoss, BCPLoss
from iamcl2r.dataset import create_data_and_transforms, BalancedBatchSampler
from iamcl2r.models import create_model
from iamcl2r.utils import check_params, save_checkpoint, init_distributed_device, is_master, broadcast_object
from iamcl2r.train import train_one_epoch, train_one_clr_epoch, retrieval_acc, train_one_bcp_epoch
from iamcl2r.eval import evaluate

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Inline YAML experiment configuration. It is parsed with `yaml.safe_load`
# further down in this cell and every top-level key is copied onto `args`.
# `pretrained_model_path` and `pretrained_backbones` are looked up by position
# in `replace_ids` in the task-setup cells, so the three lists must stay aligned
# (a later cell inserts task 0 into `replace_ids`).
config_str = """
root_folder: ./data
method: clr_bcp

# iamcl2r settings
pretrained_model_path: 
- ../pretrained_models_ckpts/learnable/ckpt_0.pt
- ../pretrained_models_ckpts/learnable/ckpt_2.pt
pretrained_backbones:
- resnet18
- resnet18
replace_ids: 
- 1

# training settings
scenario: 'class-incremental'
increment: 50
initial_increment: 50
train_dataset_name: 'cifar_clr'
batch_size: 512
lr: 0.003
momentum: 0.9
min_lr: 0.00001
weight_decay: 0.0005
epochs: 120
bc_lr: 0.03
bc_epochs: 100
rehearsal: 20
"""
# The configuration lives in `config_str`; `config_path` is never opened and is
# only used to derive the run's YAML name and the output sub-folder name.
config_path = '2tasks_clr_bcp.yaml'
loaded_params = yaml.safe_load(config_str)
args = ExperimentParams()

# Overlay every parsed key on top of the ExperimentParams defaults.
for key, value in loaded_params.items():
    setattr(args, key, value)

yaml_file_name = osp.basename(config_path)
args.yaml_name = yaml_file_name
base_config_name = osp.splitext(yaml_file_name)[0]

In [4]:
# Resolve the device for this process and record whether it is the main one.
device = init_distributed_device(args)
args.is_main_process = is_master(args)

# Only the main process touches the filesystem. `exist_ok=True` replaces the
# former check-then-create pattern, which could race with another process
# creating the same directory between the check and the call.
if args.is_main_process:
    os.makedirs(args.data_path, exist_ok=True)

# Each run gets its own folder: <output_folder>/<config name>/run-<YYYYmmdd-HHMM>.
run_stamp = datetime.datetime.now().strftime('%Y%m%d-%H%M')
checkpoint_path = osp.join(args.output_folder, base_config_name, f"run-{run_stamp}")
if args.distributed:
    # Presumably makes every rank use the same path even if their clocks
    # straddle a minute boundary — confirm against `broadcast_object`.
    checkpoint_path = broadcast_object(args, checkpoint_path)
args.checkpoint_path = checkpoint_path
if args.is_main_process:
    os.makedirs(args.checkpoint_path, exist_ok=True)
# NOTE(review): non-main ranks do not wait for this directory to exist before
# the next cell opens a log file inside it — confirm a barrier is not needed.

In [5]:
# Evaluation-only runs share a fixed log name; training runs get a file name
# stamped with the current time and the device index.
if args.eval_only:
    log_file = "eval.log"
else:
    log_stamp = datetime.datetime.now().strftime('%H%M%S')
    log_file = f"train-{log_stamp}-gpu{device.index}.log"

setup_logger(
    logfile=osp.join(args.checkpoint_path, log_file),
    console_log=args.is_main_process,
    file_log=True,
    log_level="INFO",
)

logger = logging.getLogger('IAM-CL2R-Eval')

# wandb is initialised in disabled mode for this notebook.
os.environ["WANDB_NOTEBOOK_NAME"] = ""
wandb.init(mode="disabled")

2024-07-22 21:50:43,242     wandb.jupyter[224]   ERROR Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.




In [6]:
set_method_configs(args, name=args.method)

# Reproducibility: deterministic cuDNN kernels and one seed for every RNG in use.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

# Task 0 always starts from a pretrained checkpoint, so it must be listed in
# `replace_ids`. Building the list from a set keeps this cell idempotent: the
# previous `+= [0]` appended another 0 on every re-run, which shifted
# `replace_ids.index(task_id)` and mis-indexed `pretrained_model_path`.
args.replace_ids = sorted(set(args.replace_ids) | {0})

In [7]:
# Build the datasets once, then pull out the pieces the later cells use.
data = create_data_and_transforms(args)
scenario_train, scenario_val, memory, target_transform = (
    data[key]
    for key in ("scenario_train", "scenario_val", "memory", "target_transform")
)

2024-07-22 21:50:43,323        Data-Utils[ 45]    INFO Loading Datasets


Files already downloaded and verified
Files already downloaded and verified


2024-07-22 21:50:45,403        Data-Utils[ 69]    INFO 

Training with 2 tasks.
In the first task there are 50 classes, while the other tasks have 50 classes each.




Subsampling dataset to 300 images per class.


In [8]:
# Per-task bookkeeping, filled in by the task-setup cells below. Each attribute
# gets its own fresh list.
args.classes_at_task, args.new_data_ids_at_task, args.seen_classes = [], [], []

In [13]:
def train(args, net, previous_net, train_loader, val_loader, scenario_train, memory, task_id, device, target_transform=None):
    """Train `net` for `args.epochs` epochs on one task, with periodic evaluation.

    Builds an SGD optimizer with two parameter groups (no weight decay for
    norm/bias-like parameters, `args.weight_decay` for the rest), a cosine LR
    scheduler and an AMP grad scaler, then runs `train_one_clr_epoch` once per
    epoch. Every `args.eval_period` epochs (and on the last epoch) retrieval
    accuracy is computed on `val_loader` and logged to wandb; checkpoints are
    written through `save_checkpoint` on the main process.

    Args:
        args: experiment parameters; the fields read here are `lr`, `momentum`,
            `weight_decay`, `min_lr`, `epochs`, `amp`, `eval_period`,
            `save_period`, `save_best` and `is_main_process`. `current_epoch`
            is written on every iteration.
        net: model being trained.
        previous_net: forwarded unchanged to `train_one_clr_epoch`.
        train_loader: loader handed to `train_one_clr_epoch`.
        val_loader: loader used for both arguments of `retrieval_acc`.
        scenario_train: not used in this function.
        memory: not used in this function.
        task_id: forwarded to `train_one_clr_epoch`.
        device: device for the loss module and the project training helpers.
        target_transform: forwarded to `train_one_clr_epoch`.

    Returns:
        None.
    """

    def _skips_weight_decay(name, param):
        # Low-rank tensors and anything whose name marks it as a norm layer,
        # a bias or the logit scale are kept out of weight decay.
        return (
            param.ndim < 2
            or "bn" in name
            or "ln" in name
            or "bias" in name
            or 'logit_scale' in name
        )

    # Split the trainable parameters into the two optimizer groups in one pass.
    no_decay_params, decay_params = [], []
    for name, param in net.named_parameters():
        is_no_decay = _skips_weight_decay(name, param)
        if not param.requires_grad:
            continue
        if is_no_decay:
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    param_groups = [
        {"params": no_decay_params, "weight_decay": 0.},
        {"params": decay_params, "weight_decay": args.weight_decay},
    ]
    optimizer = optim.SGD(param_groups, lr=args.lr, momentum=args.momentum)
    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
    scheduler_lr = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.epochs, eta_min=args.min_lr
    )
    criterion_cls = nn.CrossEntropyLoss().to(device)

    best_acc = 0
    for epoch in range(args.epochs):
        args.current_epoch = epoch
        train_one_clr_epoch(
            args,
            device,
            net,
            previous_net=previous_net,
            train_loader=train_loader,
            scaler=scaler,
            optimizer=optimizer,
            epoch=epoch,
            criterion_cls=criterion_cls,
            task_id=task_id,
            add_loss=None,
            target_transform=target_transform,
        )

        # The scheduler is not stepped while epoch <= 10 (originally described
        # as "warmup for the first 10 epochs").
        # NOTE(review): `epoch > 10` skips 11 epochs (0..10), not 10 — confirm
        # whether `epoch >= 10` was intended. Also T_max is args.epochs while
        # the scheduler is stepped fewer times than that, so the schedule
        # presumably never reaches `min_lr` — confirm this is deliberate.
        if epoch > 10:
            scheduler_lr.step()

        epoch_number = epoch + 1
        is_final_epoch = epoch_number == args.epochs

        if epoch_number % args.eval_period == 0 or is_final_epoch:
            # The same network and loader are passed for both roles.
            acc_val = retrieval_acc(
                args,
                device,
                net,
                net,
                val_loader,
                val_loader,
                n_backward_steps=0,
            )
            if args.is_main_process:
                wandb.log({'val/val_acc': acc_val})

            # With save_best: keep the checkpoint whenever accuracy does not
            # drop. Without it: only the final epoch's model is kept.
            improved = acc_val >= best_acc
            should_save = improved if args.save_best else is_final_epoch
            if should_save:
                best_acc = acc_val
                if args.is_main_process:
                    wandb.log({'val/best_acc': best_acc})
                    save_checkpoint(args, net, optimizer, best_acc, scheduler_lr, backup=False)

        # Periodic backup checkpoint, independent of validation accuracy.
        backup_due = epoch_number % args.save_period == 0 or is_final_epoch
        if backup_due and args.is_main_process:
            save_checkpoint(args, net, optimizer, best_acc, scheduler_lr, backup=True)

## Training on Task 1 (without alignment)

In [14]:
# Select the data for the first task; validation covers tasks 0..task_id.
task_id = 0
args.current_task_id = task_id
train_task_set = scenario_train[task_id]
new_data_ids = train_task_set.get_classes()
val_dataset = scenario_val[:task_id + 1]

# Classes seen so far: previous tasks plus the new ones (only this task at task 0).
if task_id > 0:
    class_in_step = scenario_train[:task_id].nb_classes + len(new_data_ids)
else:
    class_in_step = train_task_set.nb_classes

# Slice assignment instead of `append` keeps the cell idempotent: on a first
# run it appends exactly like before, on a re-run it overwrites this task's
# entry (dropping stale later ones) instead of adding a duplicate that would
# shift every later index.
args.classes_at_task[task_id:] = [class_in_step]
args.new_data_ids_at_task[task_id:] = [new_data_ids]

# Tasks listed in `replace_ids` start from a pretrained checkpoint; all others
# resume from the checkpoint written for the previous task.
if task_id in args.replace_ids:
    replace_pos = args.replace_ids.index(task_id)
    resume_path = args.pretrained_model_path[replace_pos]
    args.current_backbone = args.pretrained_backbones[replace_pos]
else:
    resume_path = osp.join(args.checkpoint_path, f"ckpt_{task_id-1}.pt")

In [15]:
# Build the model for the current task, loading weights from `resume_path`.
net = create_model(
    args,
    device=device,
    resume_path=resume_path,
    feat_size=args.feat_size,
    backbone=args.current_backbone,
    n_backward_vers=task_id,
)

2024-07-22 21:54:20,985             Model[243]    INFO Creating model with 50 classes and 512 features
2024-07-22 21:54:20,986             Model[251]    INFO Loading a model with config: {'num_classes': 50, 'feat_size': 512, 'backbone': 'resnet18', 'pretrained': False}


2024-07-22 21:54:21,160             Model[268]    INFO Resuming Weights from ../pretrained_models_ckpts/learnable/ckpt_0.pt


In [17]:
batchsampler = None
batch_size = args.batch_size
if task_id > 0:
    # Later tasks: optionally add the rehearsal memory to the training set and
    # draw batches through the project's BalancedBatchSampler.
    if args.rehearsal > 0:
        mem_x, mem_y, mem_t = memory.get()
        train_task_set.add_samples(mem_x, mem_y, mem_t)
    batchsampler = BalancedBatchSampler(train_task_set, n_classes=train_task_set.nb_classes,
                                        batch_size=args.batch_size, n_samples=len(train_task_set._x),
                                        seen_classes=args.seen_classes, rehearsal=args.rehearsal)
    train_loader = DataLoader(train_task_set, batch_sampler=batchsampler, num_workers=args.num_workers)
else:
    # First task: plain shuffled loader. With a subsampled dataset the whole
    # task can be smaller than one batch; since drop_last=True would then yield
    # no batches at all, the batch size is capped at the dataset size.
    if (args.use_subsampled_dataset and args.img_per_class * args.classes_at_task[0] < args.batch_size):
        batch_size = args.img_per_class * args.classes_at_task[0]
        # Report the configured value here; the previous message printed the
        # already-reduced `batch_size` as the "original" one.
        logger.info(f"Original batch size of {args.batch_size} is too high.")
        logger.info(f"In current task there are {args.classes_at_task[0]} classes and the dataset has {args.img_per_class} img per class.")
        # `batch_size` is a total image count, not a per-class count.
        logger.info(f"Setting batch size to {batch_size} images.")
    train_loader = DataLoader(train_task_set,
                              batch_size=batch_size, shuffle=True,
                              drop_last=True, num_workers=args.num_workers)

# Validation keeps every sample and uses the (possibly reduced) batch size.
val_loader = DataLoader(val_dataset,
                        batch_size=batch_size, shuffle=False,
                        drop_last=False, num_workers=args.num_workers)

best_acc = 0
logger.info("Starting training new network first..")
logger.info(f"Starting Epoch Loop at task {task_id+1}/{scenario_train.nb_tasks}")

2024-07-22 21:54:34,509     IAM-CL2R-Eval[ 26]    INFO Starting training new network first..
2024-07-22 21:54:34,510     IAM-CL2R-Eval[ 27]    INFO Starting Epoch Loop at task 1/2


In [18]:
# No previous model is passed for this first training run.
previous_net = None
train(
    args,
    net,
    previous_net,
    train_loader,
    val_loader,
    scenario_train,
    memory,
    task_id,
    device,
    target_transform=target_transform,
)

2024-07-22 21:54:48,159             Utils[ 47]    INFO Task 1 Epoch [1]/[120]	 Training Loss: 4.8386	 Training Accuracy: 25.08	 LR: 0.003000	 Time: 8.45
2024-07-22 21:54:56,983             Utils[ 47]    INFO Task 1 Epoch [2]/[120]	 Training Loss: 4.5489	 Training Accuracy: 28.63	 LR: 0.003000	 Time: 8.82
2024-07-22 21:55:05,461             Utils[ 47]    INFO Task 1 Epoch [3]/[120]	 Training Loss: 4.3461	 Training Accuracy: 30.88	 LR: 0.003000	 Time: 8.47
2024-07-22 21:55:13,750             Utils[ 47]    INFO Task 1 Epoch [4]/[120]	 Training Loss: 4.2144	 Training Accuracy: 33.32	 LR: 0.003000	 Time: 8.28
2024-07-22 21:55:22,130             Utils[ 47]    INFO Task 1 Epoch [5]/[120]	 Training Loss: 4.1205	 Training Accuracy: 33.78	 LR: 0.003000	 Time: 8.38
2024-07-22 21:55:24,051 Performance-Metrics[ 41]    INFO Finish Calculating 49 template features.
2024-07-22 21:55:24,052 Performance-Metrics[ 79]    INFO => calculate rank
2024-07-22 21:55:24,053 Performance-Metrics[ 48]    INFO query

## Training on Task 2

In [None]:
# Select the data for the second task; validation covers tasks 0..task_id.
task_id = 1
args.current_task_id = task_id
train_task_set = scenario_train[task_id]
new_data_ids = train_task_set.get_classes()
val_dataset = scenario_val[:task_id + 1]

# Classes seen so far: those of the previous tasks plus the newly introduced ones.
if task_id > 0:
    class_in_step = scenario_train[:task_id].nb_classes + len(new_data_ids)
else:
    class_in_step = train_task_set.nb_classes

# Slice assignment instead of `append` keeps the cell idempotent: on a first
# run it appends exactly like before, on a re-run it overwrites this task's
# entry instead of adding a duplicate that would shift every later index.
args.classes_at_task[task_id:] = [class_in_step]
args.new_data_ids_at_task[task_id:] = [new_data_ids]

# Tasks listed in `replace_ids` start from a pretrained checkpoint; all others
# resume from the checkpoint written for the previous task.
if task_id in args.replace_ids:
    replace_pos = args.replace_ids.index(task_id)
    resume_path = args.pretrained_model_path[replace_pos]
    args.current_backbone = args.pretrained_backbones[replace_pos]
else:
    resume_path = osp.join(args.checkpoint_path, f"ckpt_{task_id-1}.pt")

In [None]:
# Reference ("previous") model, created with `new_classes=0`.
# NOTE(review): both `previous_net` and `net` are loaded from the same
# `resume_path`; when this task is in `replace_ids` that is the pretrained
# checkpoint, not the checkpoint saved after training the previous task —
# confirm this is the intended source for the previous model.
previous_net = create_model(
    args,
    device=device,
    resume_path=resume_path,
    new_classes=0,   # not expanding classifier for old model
    feat_size=args.feat_size,
    backbone=args.current_backbone,
    n_backward_vers=task_id,
)
# Freeze the previous model: no gradients for any parameter, eval mode on.
for frozen_param in previous_net.parameters():
    frozen_param.requires_grad = False
previous_net.eval()

# Model that will be trained on the current task.
net = create_model(
    args,
    device=device,
    resume_path=resume_path,
    feat_size=args.feat_size,
    backbone=args.current_backbone,
    n_backward_vers=task_id,
)