In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('..')
from utils.models import get_backbone
from utils.isomaxplus import IsoMaxPlusLossFirstPart, IsoMaxPlusLossSecondPart

In [3]:
import os
import warnings
from glob import glob
from typing import Optional, Dict, Any

import numpy as np
import torch
import torchvision
from torch import nn


def load_model(_ckpt_path, num_classes=2):
    """Build a ResNet-50 backbone + IsoMax+ head and load weights from a checkpoint.

    The model is ``nn.Sequential(backbone, nn.Flatten(), head)``, so ``model[-1]``
    is the ``IsoMaxPlusLossFirstPart`` head applied to the flattened embedding.

    Args:
        _ckpt_path: Path to a checkpoint readable by ``torch.load``. Its content is
            passed straight to ``load_state_dict``, so it is assumed to be a plain
            state dict whose keys match this Sequential layout (TODO confirm
            against the script that writes it).
        num_classes: Number of classes (prototypes) in the head.

    Returns:
        The assembled model, on CPU.

    Warns:
        UserWarning: if the checkpoint has missing or unexpected keys. Loading is
            non-strict, so without this warning mismatched keys would be dropped
            silently and the corresponding parameters left at random init.
    """
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    ckpt = torch.load(_ckpt_path, map_location="cpu")

    resnet = torchvision.models.resnet50()
    # Drop the final `fc` classifier; the head below replaces it.
    backbone = nn.Sequential(*list(resnet.children())[:-1])
    emb_dim = resnet.fc.in_features
    head = IsoMaxPlusLossFirstPart(emb_dim, num_classes)

    model = nn.Sequential(backbone, nn.Flatten(), head)
    result = model.load_state_dict(ckpt, strict=False)
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"load_model({_ckpt_path!r}): {len(result.missing_keys)} missing key(s) "
            f"{result.missing_keys[:5]}, {len(result.unexpected_keys)} unexpected key(s) "
            f"{result.unexpected_keys[:5]}; missing parameters keep their random init.",
            stacklevel=2,
        )
    return model


def get_pre_extracted_features(_ckpt_dir: str, _set_name: str) -> np.ndarray:
    """Load ``feats_<set_name>.npy`` from a directory and standardize each row.

    Every row is shifted to zero mean and scaled to unit standard deviation along
    axis 1. Rows with zero standard deviation (constant rows) are only
    mean-centred, so they come out as all zeros instead of NaN/inf.

    Args:
        _ckpt_dir: Directory containing the feature file (a trailing separator
            is optional).
        _set_name: Split name used in the file name, e.g. ``'val'`` or ``'test'``.

    Returns:
        A new in-memory array with the same shape as the stored features.
    """
    path = os.path.join(_ckpt_dir, f'feats_{_set_name}.npy')
    feats = np.load(path, mmap_mode='r')
    mean = feats.mean(axis=1, keepdims=True)
    std = feats.std(axis=1, keepdims=True)
    # Avoid 0/0 -> NaN for constant rows; dividing by 1 leaves them at zero.
    std = np.where(std == 0, 1, std)
    return (feats - mean) / std

In [4]:
from utils import datasets as dsets
from torch.utils.data import DataLoader

# --- Data configuration ---------------------------------------------------
stage = 1
trn_split = 'va'  # split code the "train" set is built from in this stage
train_attr = 'yes'
worst_metric = 'wga_val'
subsample_type = 'group'
dataset_name = 'Waterbirds'
# Machine-specific absolute paths -- adjust for your environment.
data_dir = '/scratch/ssd004/scratch/minht/datasets/'
ckpt_path = '/scratch/ssd004/scratch/minht/checkpoints/sd0/Waterbirds/13574640/'

workers = 6
batch_size_train, batch_size_eval = 64, 128

# Split codes accepted by the dataset class -> names used for dict keys and
# for the `feats_<name>.npy` files read by get_pre_extracted_features.
split_to_set = {'va': 'val', 'te': 'test'}

dataset_cls = getattr(dsets, dataset_name)
datasets, dataloaders = dict(), dict()
for split, set_name in split_to_set.items():
    datasets[set_name] = dataset_cls(data_dir, split, None)
    datasets[set_name].feats = get_pre_extracted_features(ckpt_path, set_name)
    dataloaders[set_name] = DataLoader(dataset=datasets[set_name], num_workers=workers, pin_memory=False,
                                       batch_size=batch_size_eval, shuffle=False)

# The train set reuses the pre-extracted features of the split it is drawn from
# (presumably they must correspond row-for-row -- here 'va' <-> val features).
if trn_split not in split_to_set:
    raise ValueError(f"No pre-extracted features loaded for trn_split={trn_split!r}; "
                     f"expected one of {sorted(split_to_set)}")
datasets['train'] = dataset_cls(
    data_dir, trn_split, None, train_attr=train_attr, subsample_type=subsample_type, stage=stage,
    pre_extracted_feats=datasets[split_to_set[trn_split]].feats)
dataloaders['train'] = DataLoader(datasets['train'], batch_size=batch_size_train, drop_last=True,
                                  shuffle=True, num_workers=workers, pin_memory=False)

In [5]:
from tqdm.auto import tqdm

from utils.misc import get_scheduler_func

# --- Training configuration -------------------------------------------------
seed = 0
epochs = 20
lr = 1e-3
cov_reg = 5e5
wd_weight = 10
# Optional extra prototypes to decorrelate from. Must be rank 3 and concatenable
# with head.prototypes[:, None] along dim 1 -- presumably
# (num_classes, n_ensemble, emb_dim); TODO confirm. None disables the cov term.
prototypes_ensemble = None
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed model construction and DataLoader shuffling so runs are reproducible.
torch.manual_seed(seed)

model = load_model(os.path.join(ckpt_path, 'ckpt_last.pt'))
model.to(device)
# Only the head is trained: inputs are pre-extracted features, so the backbone is
# never run and never receives gradients.
head = model[-1]
criterion = IsoMaxPlusLossSecondPart(entropic_scale=30, reduction='mean')
optimizer = torch.optim.SGD(head.parameters(), lr=lr, momentum=0.9)
lr_scheduler = get_scheduler_func('onecycle', lr, epochs, len(dataloaders['train']))(optimizer)

for epoch in range(epochs):
    model.train()
    running_loss, running_clf, running_cov = 0.0, 0.0, 0.0
    running_correct, total, n_batches = 0, 0, 0
    # Keep only the final epoch's progress bar on screen.
    pbar = tqdm(dataloaders['train'], desc=f"Epoch {epoch + 1}", leave=(epoch == epochs - 1))

    for _, _, labels, _, _, feats in pbar:
        feats = feats.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = head(feats)
        clf_loss = criterion(outputs, labels)

        # Prototype-norm penalty: mean over prototypes of their squared L2 norm.
        wd = head.prototypes.pow(2).sum(dim=1).mean() * wd_weight
        loss = clf_loss + wd

        cov_loss = torch.tensor(0.0, device=device)
        if prototypes_ensemble is not None and cov_reg > 0:
            _prototypes = torch.cat([head.prototypes[:, None], prototypes_ensemble], dim=1)
            n_pro, n_dim = _prototypes.shape[1:]
            # Uncentered Gram matrix per class, scaled by 1 / (n_dim - 1).
            cov = torch.einsum('ijk,ilk->ijl', [_prototypes, _prototypes]) / (n_dim - 1)
            # NOTE(review): sums n_pro - 1 off-diagonal terms but divides by n_pro -- confirm intended.
            cov_loss = torch.abs(cov[:, 0, 1:].sum(1).div(n_pro).mean())
            loss = loss + cov_loss * cov_reg

        loss.backward()
        optimizer.step()
        if lr_scheduler is not None:
            lr_scheduler.step()

        n_batches += 1
        total += labels.size(0)
        running_correct += (outputs.argmax(dim=1) == labels).sum().item()
        running_loss += loss.item()
        running_clf += clf_loss.item()
        running_cov += cov_loss.item()

        # Per-batch averages use an explicit batch count, so they stay correct
        # even if batch sizes vary (e.g. drop_last=False).
        pbar.set_postfix({
            'loss': running_loss / n_batches,
            'clf': running_clf / n_batches,
            'cov': running_cov / n_batches,
            'acc': f"{running_correct / total:.2%}",
        })


Epoch 20: 100%|██████████| 8/8 [00:00<00:00, 25.36it/s, loss=0.0279, clf=0.0081, cov=0, acc=100.00%] 
