In [4]:
import argparse
import re
from collections import OrderedDict
import json
import math
import os
import sys
import time
# import subprocess

try:
    import wandb
except ImportError:
    wandb = None

import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms

from datasets import get_dataset
import models
from tokenizer import SimpleTokenizer
from utils import AverageMeter, ProgressMeter, accuracy
import utils
from torchvision.datasets import ImageFolder
from utils import GaussianBlur, Solarize
from losses import DetailCLIPLoss, get_metric_names
import torch.distributed as dist

  from .autonotebook import tqdm as notebook_tqdm


In [10]:

def get_args_parser():
    """Build the command-line parser for DetailCLIP pre-training and evaluation.

    The parser is created with ``add_help=False`` (no automatic ``-h/--help``
    option), which allows it to be passed as a ``parents=[...]`` entry of
    another ``ArgumentParser``.

    Returns:
        argparse.ArgumentParser: parser with the data, augmentation, model,
        training, system and loss option groups registered below.
    """
    parser = argparse.ArgumentParser(description='DetailCLIP pre-training and evaluation', add_help=False)
    # Data
    parser.add_argument('--dataset', default='yfcc15m', type=str, choices=['yfcc15m', 'cc3m', 'cc12m', 'coco', 'redcaps'])
    parser.add_argument('--metadata', default='yfcc15m.pkl', type=str,
                        help='path to metadata file (see README for details)')
    parser.add_argument('--root', default='', type=str,
                        help='path to dataset root')
    parser.add_argument('--output-dir', default='./', type=str, help='path where to save, empty for no saving')

    # Data Augmentation
    parser.add_argument('--global_crops_scale', type=float, nargs='+', default=(0.14, 1.),
        help="""Scale range of the cropped image before resizing, relatively to the origin image.
        Used for large global view cropping. When disabling multi-crop (--local_crops_number 0), we
        recommand using a wider range of scale ("--global_crops_scale 0.14 1." for example)""")
    # Model
    parser.add_argument('--model', default='DetailCLIP_VITB16', type=str)
    parser.add_argument('--mask-ratio', default=0.5, type=float)
    parser.add_argument('--ssl-mlp-dim', default=4096, type=int,
                        help='hidden dim of SimCLR mlp projection head')
    parser.add_argument('--ssl-emb-dim', default=256, type=int,
                        help='output embed dim of SimCLR mlp projection head')
    parser.add_argument('--ssl-scale', default=1.0, type=float,
                        help='loss scale for SimCLR objective')
    parser.add_argument('--ssl-temp', default=0.1, type=float,
                        help='softmax temperature for SimCLR objective')
    parser.add_argument('--resume', default='', type=str, help='path to resume from')
    # Training
    parser.add_argument('--momentum-ema', default=0.996, type=float, help="""Base EMA
    parameter. The value is increased to 1 during training with cosine schedule.""")
    parser.add_argument('--epochs', default=25, type=int)
    parser.add_argument('--warmup-epochs', default=1, type=int)
    parser.add_argument('--start-epoch', default=0, type=int)
    parser.add_argument('--batch-size', default=5, type=int,
                        help='number of samples per-device/per-gpu')
    parser.add_argument('--lr', default=5e-4, type=float)
    parser.add_argument('--base-lr', default=3e-3, type=float)
    parser.add_argument('--lr-start', default=1e-6, type=float,
                        help='initial warmup lr')
    parser.add_argument('--lr-end', default=1e-5, type=float,
                        help='minimum final lr')
    parser.add_argument('--update-freq', default=1, type=int,
                        help='optimizer update frequency (i.e. gradient accumulation steps)')
    parser.add_argument('--wd', default=0.5, type=float)
    parser.add_argument('--betas', default=(0.9, 0.98), nargs=2, type=float)
    parser.add_argument('--eps', default=1e-8, type=float)
    parser.add_argument('--eval-freq', default=1, type=int)
    parser.add_argument('--disable-amp', action='store_true',
                        help='disable mixed-precision training (requires more memory and compute)')
    # System
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('-j', '--workers', default=1, type=int, metavar='N',
                        help='number of data loading workers per process')
    parser.add_argument('--evaluate', action='store_true', help='eval only')
    parser.add_argument('--world-size', default=1, type=int,
                        help='number of nodes for distributed training')
    parser.add_argument('--rank', default=0, type=int,
                        help='node rank for distributed training')
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument('--dist-url', default='env://', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str)
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.')
    parser.add_argument('--wandb', action='store_true', help='Enable WandB logging')
    parser.add_argument('--descriptions', default='training', type=str)
    # type=int keeps a user-supplied port the same type as the integer default;
    # without it a command-line value would be stored as a string.
    parser.add_argument('--port', default=29500, type=int, help='port of master addr')
    # Loss
    parser.add_argument('--disable-norm-pix-loss', action='store_true',
                        help='disable normalization of pixel loss for reconstruction')
    parser.add_argument('--clip_loss_weight', default=1.0, type=float, help='weight of clip loss')
    parser.add_argument('--ibot_patch_loss_weight', default=1.0, type=float, help='weight of ibot patch loss')
    parser.add_argument('--ibot_cls_loss_weight', default=1.0, type=float, help='weight of ibot classification loss')
    parser.add_argument('--reconst_loss_weight', default=1.0, type=float, help='weight of reconstruction loss')
    return parser

In [8]:
def get_model(args):
    """Instantiate the model named by ``args.model`` and place it on the GPU.

    The constructor is looked up by name on the ``models`` module and called
    with ``mask_ratio=args.mask_ratio`` (the argument parser in this file
    defaults ``--mask-ratio`` to 0.5). When ``args.distributed`` is truthy the
    model's BatchNorm layers are converted to SyncBatchNorm and the model is
    wrapped in ``DistributedDataParallel`` bound to ``args.gpu``.

    Args:
        args: namespace providing ``model``, ``mask_ratio``, ``gpu`` and
            ``distributed``.

    Returns:
        The constructed model, DDP-wrapped when running distributed.
    """
    print("=> creating model: {}".format(args.model))

    model_factory = getattr(models, args.model)
    net = model_factory(mask_ratio=args.mask_ratio)
    net.cuda(args.gpu)

    if not args.distributed:
        return net

    net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    return torch.nn.parallel.DistributedDataParallel(
        net,
        device_ids=[args.gpu],
        bucket_cap_mb=200,
        find_unused_parameters=False,
    )

In [9]:
def get_optim(args, model):
    """Create an AdamW optimizer with weight decay disabled for selected parameters.

    Trainable parameters are split into two groups, preserving the order of
    ``model.named_parameters()``:

    * group 0 uses ``weight_decay=args.wd``;
    * group 1 uses ``weight_decay=0`` and holds every parameter that has
      fewer than two dimensions or whose name contains ``'bias'``, ``'ln'``
      or ``'bn'``.

    Parameters with ``requires_grad=False`` are left out entirely.

    Args:
        args: namespace providing ``wd``, ``lr``, ``betas`` and ``eps``.
        model: module whose parameters are to be optimized.

    Returns:
        torch.optim.AdamW: optimizer over the two parameter groups.
    """
    def _skips_decay(name, param):
        # Low-rank tensors and anything named like a bias / norm layer.
        return param.ndim < 2 or any(tag in name for tag in ('bias', 'ln', 'bn'))

    # Frozen weights never reach the optimizer.
    trainable = [(name, param) for name, param in model.named_parameters()
                 if param.requires_grad]

    decayed = [param for name, param in trainable if not _skips_decay(name, param)]
    undecayed = [param for name, param in trainable if _skips_decay(name, param)]

    groups = [
        {"params": decayed, "weight_decay": args.wd},
        {"params": undecayed, "weight_decay": 0},
    ]

    return torch.optim.AdamW(groups, lr=args.lr, betas=args.betas,
                             eps=args.eps, weight_decay=args.wd)

In [None]:
def load_ckpt(args, model, optimizer, scaler):
    """Restore training state from a checkpoint, mutating the arguments in place.

    Two sources are considered, in this order:

    1. ``args.resume`` (explicit path). Only ``'state_dict'`` is required in
       the file; ``'epoch'``, ``'optimizer'``, ``'scaler'`` and ``'best_acc'``
       are each restored only when present, so a weights-only checkpoint can
       be used. Model weights are loaded with ``strict=False`` and the
       resulting key report is printed.
    2. ``<args.output_dir>/checkpoint.pt`` (auto-resume), used only when
       ``args.resume`` is empty. All keys are required and weights are loaded
       strictly.

    If the chosen file does not exist nothing is modified.

    Args:
        args: namespace providing ``resume`` and ``output_dir``; receives
            ``start_epoch`` and (when available) ``best_acc``.
        model: module whose ``load_state_dict`` is called.
        optimizer: optimizer to restore (untouched if the resume checkpoint
            has no ``'optimizer'`` entry).
        scaler: AMP grad scaler to restore (untouched if the resume checkpoint
            has no ``'scaler'`` entry).

    Returns:
        None
    """
    # optionally resume from a checkpoint (takes precedence over autoresume)
    if args.resume:
        if not os.path.isfile(args.resume):
            print("=> no checkpoint found at '{}'".format(args.resume))
            return

        print("=> loading resume checkpoint '{}'".format(args.resume))
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(args.resume, map_location='cpu')
        epoch = checkpoint.get('epoch', 0)
        args.start_epoch = epoch
        result = model.load_state_dict(checkpoint['state_dict'], strict=False)
        print(result)
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
        if 'scaler' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler'])
        # Treated as optional like the other entries, so a weights-only
        # checkpoint no longer raises KeyError; args.best_acc keeps whatever
        # value the caller had set.
        if 'best_acc' in checkpoint:
            args.best_acc = checkpoint['best_acc']
        print("=> loaded resume checkpoint '{}' (epoch {})"
              .format(args.resume, epoch))
        return

    # auto-resume from latest checkpoint in output directory
    latest = os.path.join(args.output_dir, 'checkpoint.pt')
    if os.path.isfile(latest):
        print("=> loading latest checkpoint '{}'".format(latest))
        latest_checkpoint = torch.load(latest, map_location='cpu')
        args.start_epoch = latest_checkpoint['epoch']
        model.load_state_dict(latest_checkpoint['state_dict'])
        optimizer.load_state_dict(latest_checkpoint['optimizer'])
        scaler.load_state_dict(latest_checkpoint['scaler'])
        args.best_acc = latest_checkpoint['best_acc']
        print("=> loaded latest checkpoint '{}' (epoch {})"
              .format(latest, latest_checkpoint['epoch']))

In [None]:
def get_loader(args, tokenizer):
    """Build the training and validation data loaders.

    The training set comes from ``get_dataset(train_transform, tokenizer, args)``.
    The validation set is an ``ImageFolder`` over the ``val`` sub-directory of
    the path stored under ``['imagenet']['path']`` in ``dataset_catalog.json``.
    That JSON file is looked up next to this source file, or in the current
    working directory when ``__file__`` is not defined (e.g. when this code is
    executed as a notebook cell).

    Args:
        args: namespace providing ``global_crops_scale``, ``distributed``,
            ``batch_size`` and ``workers`` (and whatever ``get_dataset`` reads).
        tokenizer: forwarded unchanged to ``get_dataset``.

    Returns:
        tuple: ``(train_loader, train_sampler, val_loader)``;
        ``train_sampler`` is ``None`` when ``args.distributed`` is falsy.
    """
    print("=> creating dataset")
    # presumably the standard ImageNet channel statistics -- confirm if the
    # pre-training data is normalized differently elsewhere
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    val_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=args.global_crops_scale, interpolation=3),  # 3 is bicubic
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.1),
        transforms.RandomApply([Solarize()], p=0.2),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = get_dataset(train_transform, tokenizer, args)

    # __file__ does not exist inside a notebook kernel; fall back to the
    # working directory instead of failing with NameError.
    try:
        cwd = os.path.dirname(os.path.realpath(__file__))
    except NameError:
        cwd = os.getcwd()
    with open(os.path.join(cwd, 'dataset_catalog.json')) as f:
        root = json.load(f)['imagenet']['path']
    # add val folder for imagenet 1k
    val_dataset = ImageFolder(os.path.join(root, 'val'), val_transform)

    # dist eval resamples data to pad uneven batch sizes
    # make sure num_samples = 0 mod num_gpus for exact acc
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    else:
        train_sampler = None
        val_sampler = None

    # DataLoader forbids shuffle=True together with a sampler, hence the
    # `sampler is None` conditions below.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=(val_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=val_sampler, drop_last=False)

    return train_loader, train_sampler, val_loader

In [11]:
def train(train_loader, model, criterion, optimizer, scaler, epoch, lr_schedule, momentum_schedule, args):
    """Run one training epoch with gradient accumulation and AMP loss scaling.

    For every batch the learning rate of all parameter groups is set from
    ``lr_schedule`` and the momentum value from ``momentum_schedule`` is passed
    to the model as its last positional argument. Gradients are accumulated
    over ``args.update_freq`` batches before each optimizer step. After each
    step the model's ``logit_scale`` (and ``logit_scale_e`` if the model has
    one) is clamped in place.

    Args:
        train_loader: iterable of batches; each batch is indexed as
            ``inputs[0][0]``, ``inputs[0][1]`` and ``inputs[1]``.
        model: network being trained (possibly wrapped; unwrapped through
            ``utils.get_model`` when its logit scale is accessed).
        criterion: callable ``criterion(outputs, epoch)`` returning a dict of
            tensors that contains at least ``'loss'``.
        optimizer: optimizer stepped through ``scaler``.
        scaler: AMP gradient scaler.
        epoch: zero-based epoch index, used for schedule lookup and logging.
        lr_schedule: per-iteration learning rates, indexed by global step.
        momentum_schedule: per-iteration momentum values, indexed by global step.
        args: namespace providing ``update_freq``, ``gpu``, ``disable_amp``,
            ``batch_size``, ``print_freq`` and ``wandb``.

    Returns:
        dict: running average of every metric plus ``'lr'`` (learning rate of
        the first parameter group) and ``'logit_scale'`` (``exp`` of the
        model's logit scale after the last optimizer step, or its value at
        epoch start if no step was taken).
    """
    batch_time = AverageMeter('Time', ':6.2f')
    data_time = AverageMeter('Data', ':6.2f')
    mem = AverageMeter('Mem (GB)', ':6.3f')
    metric_names = get_metric_names()
    iters_per_epoch = len(train_loader) // args.update_freq
    metrics = OrderedDict([(name, AverageMeter(name, ':.2e')) for name in metric_names])
    progress = ProgressMeter(
        iters_per_epoch,
        [batch_time, data_time, mem, *metrics.values()],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    # Bind the reported value up front so the return statement cannot hit an
    # unbound local when the loop performs no optimizer step (empty loader or
    # fewer batches than update_freq).
    logit_scale = utils.get_model(model).logit_scale.exp().item()

    end = time.time()
    for data_iter, inputs in enumerate(train_loader):
        optim_iter = data_iter // args.update_freq
        data_time.update(time.time() - end)

        # update the learning rate according to its schedule
        it = iters_per_epoch * epoch + optim_iter  # global training iteration
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_schedule[it]

        online_inputs = [inputs[0][0], inputs[0][1], inputs[1]]
        online_inputs = [tensor.cuda(args.gpu, non_blocking=True) for tensor in online_inputs]

        m = momentum_schedule[it]  # momentum parameter
        # compute output
        with amp.autocast(enabled=not args.disable_amp):
            outputs = model(*online_inputs, m)
            loss_dict = criterion(outputs, epoch)

            loss = loss_dict['loss']
            # In-place division: loss_dict['loss'] refers to the same tensor,
            # so the value logged below is the loss divided by update_freq.
            loss /= args.update_freq

        if not math.isfinite(loss.item()):
            print("Loss is {}, stopping training".format(loss.item()))
            sys.exit(1)

        scaler.scale(loss).backward()

        # keep accumulating gradients until update_freq batches have been seen
        if (data_iter + 1) % args.update_freq != 0:
            continue

        # apply the accumulated gradients with an optimizer step
        scaler.step(optimizer)
        scaler.update()
        model.zero_grad(set_to_none=True)

        # clamp the log-scale to [0, 4.6052] so exp(logit_scale) stays in [1, ~100]
        logit_scale_e = 0

        utils.get_model(model).logit_scale.data.clamp_(0, 4.6052)   # since logit_scale is defined within the model class it is a learnable parameter that needs to be clamped
        if hasattr(utils.get_model(model), 'logit_scale_e'):
            utils.get_model(model).logit_scale_e.data.clamp_(0, 4.6052)
            logit_scale_e = utils.get_model(model).logit_scale_e.exp().item()

        logit_scale = utils.get_model(model).logit_scale.exp().item()

        for k in loss_dict:
            metrics[k].update(loss_dict[k].item(), args.batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)

        end = time.time()

        mem.update(torch.cuda.max_memory_allocated() / 1e9)
        if optim_iter % args.print_freq == 0:
            if utils.is_main_process() and args.wandb:
                wandb.log({**{k: v.item() for k, v in loss_dict.items()},
                        'scaler': scaler.get_scale(),
                        'logit': logit_scale,
                        'logit_e': logit_scale_e,
                        })
            progress.display(optim_iter)

    # NOTE(review): presumably aggregates the meters across processes --
    # confirm against ProgressMeter in utils.
    progress.synchronize()
    return {**{k: v.avg for k, v in metrics.items()},
            'lr': optimizer.param_groups[0]['lr'],
            'logit_scale': logit_scale}
