In [1]:
#!/usr/bin/env python3

""" ImageNet Training Script

This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet
training results with some of the latest networks and training techniques. It favours canonical PyTorch
and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed
and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.

This script was started from an early version of the PyTorch ImageNet example
(https://github.com/pytorch/examples/blob/main/imagenet/README.md)
(https://github.com/pytorch/examples/tree/master/imagenet)

NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)

Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
# from setproctitle import setproctitle
# setproctitle("python3 utils/train.py")

" ImageNet Training Script\n\nThis is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet\ntraining results with some of the latest networks and training techniques. It favours canonical PyTorch\nand standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed\nand training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.\n\nThis script was started from an early version of the PyTorch ImageNet example\n(https://github.com/pytorch/examples/blob/main/imagenet/README.md)\n(https://github.com/pytorch/examples/tree/master/imagenet)\n\nNVIDIA CUDA specific speedups adopted from NVIDIA Apex examples\n(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)\n\nHacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)\n"

In [None]:
# pip install nvidia-pyindex
# pip install onnx-graphsurgeon

In [2]:
import warnings
#warnings.filterwarnings("ignore")
import cv2
import gc
import torch

import argparse
import logging
import math
import os
import time
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torchvision.utils
import yaml

# timm version 0.9.2
from timm import utils
from timm.data import (AugMixDataset, FastCollateMixup, Mixup, create_dataset,
                       create_loader, resolve_data_config)
from utils.build import get_exp_fn
from timm.layers import (convert_splitbn_model, convert_sync_batchnorm,
                         set_fast_norm)
from timm.loss import (BinaryCrossEntropy, JsdCrossEntropy,
                       LabelSmoothingCrossEntropy, SoftTargetCrossEntropy)
from timm.models import (create_model, load_checkpoint, model_parameters,
                         resume_checkpoint, safe_model_name)
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler_v2, scheduler_kwargs
from timm.utils import ApexScaler, NativeScaler
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from tqdm import tqdm

In [3]:
try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as ApexDDP
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    has_apex = False

has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass

try:
    import wandb
    has_wandb = True
except ImportError:
    has_wandb = False

try:
    from functorch.compile import memory_efficient_fusion
    has_functorch = True
except ImportError as e:
    has_functorch = False

has_compile = hasattr(torch, 'compile')

In [4]:
# OpenCV should automatically determine the optimal number of threads for parallel processing.
cv2.setNumThreads(0)
# disable the usage of OpenCL (Open Computing Language) for GPU acceleration in OpenCV. force OpenCL to not use GPU.
cv2.ocl.setUseOpenCL(False)
os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"]="1"


# PyTorch initializes the CuDNN library and prepares it for subsequent convolution operations. 
# This step is useful in cases where the first convolution operation can incur significant overhead due to the initialization of CuDNN. 
# By calling this function beforehand, the initialization is forced to happen upfront, potentially reducing the latency during subsequent convolution operations.
def force_cudnn_initialization():
    s = 32
    dev = torch.device('cuda')
    torch.nn.functional.conv2d(torch.zeros(s, s, s, s, device=dev), torch.zeros(s, s, s, s, device=dev))
    
_logger = logging.getLogger(__name__)

In [5]:
# rsna argument order for different folds.
# -expn reproduce_train_fold_2 --smoothing 0.1 (soft_pos = 0.9) --opt sgd --drop 0.5 
# -expn reproduce_train_fold_3 --smoothing 0.1 (soft_pos = 0.9) ...
# -expn stage1_reproduce_train_fold_0 --smoothing 0.2 (soft_pos = 0.8) ...
# -expn stage1_reproduce_train_fold_1 --smoothing 0.2 (soft_pos = 0.8) ...
# -expn stage1_reproduce_train_fold_0 --smoothing 0.1 --initial-checkpoint {finetune over stage1_reproduce_train_fold_0} ...
# -expn stage1_reproduce_train_fold_1 --smoothing 0.1 --initial-checkpoint {finetune over stage1_reproduce_train_fold_1} ...



In [6]:
# # The first arg parser parses out only the --config argument, this argument is used to
# # load a yaml file containing key-values that override the defaults for the main parser below
config_parser = argparse.ArgumentParser(description='Training Config', add_help=False)
config_parser.add_argument('-c', '--config', default='', type=str, metavar='FILE', help='YAML config file specifying default arguments')


parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

# rsna argument - another configuration file.
parser.add_argument('--exp', nargs='?', default='utils/train.py', type=str, metavar='FILE', help='Python exp config')
# rsna argument - exp meta keywords- {one_pos_mode i.e., at least one positive sample per iteration, ...}
parser.add_argument('--exp-kwargs', nargs='*', default={'fold_idx':2, 'num_sched_epochs':10, 'num_epochs':350,
    'start_ratio':0.1429, 'end_ratio':0.1429, 'one_pos_mode':True}, action=utils.ParseKwargs)

# Dataset parameters
group = parser.add_argument_group('Dataset parameters')
# Keep this argument outside the dataset group because it is positional.
group.add_argument('--data', nargs='?', metavar='DIR', const=None,
                    help='path to dataset (positional is *deprecated*, use --data-dir)')
group.add_argument('--data-dir', metavar='DIR',
                    help='path to dataset (root dir)')
group.add_argument('--dataset', metavar='NAME', default='',
                    help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
group.add_argument('--train-split', metavar='NAME', default='train',
                   help='dataset train split (default: train)')
group.add_argument('--val-split', metavar='NAME', default='validation',
                   help='dataset validation split (default: validation)')
group.add_argument('--dataset-download', action='store_true', default=False,
                   help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
group.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                   help='path to class to idx mapping file (default: "")')

_StoreAction(option_strings=['--class-map'], dest='class_map', nargs=None, const=None, default='', type=<class 'str'>, choices=None, help='path to class to idx mapping file (default: "")', metavar='FILENAME')

In [7]:
# Model parameters
group = parser.add_argument_group('Model parameters')
# rsna argument - model
group.add_argument('--model', default='convnext_small.fb_in22k_ft_in1k_384', type=str, metavar='MODEL',
                   help='Name of model to train (default: "resnet50")')
# rsna argument - pretrained
group.add_argument('--pretrained', action='store_true', default=True,
                   help='Start with pretrained version of specified network (if avail)')
group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                   help='Initialize model from this checkpoint (default: none)')
group.add_argument('--resume', default='', type=str, metavar='PATH',
                   help='Resume full model and optimizer state from checkpoint (default: none)')
group.add_argument('--no-resume-opt', action='store_true', default=False,
                   help='prevent resume of optimizer state when resuming model')
# rsna argument - number of output classes
group.add_argument('--num-classes', type=int, default=1, metavar='N',
                   help='number of label classes (Model default if None)')
# rsna argument - gp
group.add_argument('--gp', default='max', type=str, metavar='POOL',
                   help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
group.add_argument('--img-size', type=int, default=None, metavar='N',
                   help='Image size (default: None => model default)')
group.add_argument('--in-chans', type=int, default=None, metavar='N',
                   help='Image input channels (default: None => 3)')
# rsna argument - input size
group.add_argument('--input-size', default=[3,2048,1024], nargs=3, type=int,
                   metavar='N N N',
                   help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
# rsna argument - crop percentage
group.add_argument('--crop-pct', default=None, type=float,
                   metavar='N', help='Input image center crop percent (for validation only)')
group.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                   help='Override mean pixel value of dataset')
group.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                   help='Override std deviation of dataset')
group.add_argument('--interpolation', default='', type=str, metavar='NAME',
                   help='Image resize interpolation type (overrides model)')
# rsna argument - batch size
group.add_argument('-b', '--batch-size', type=int, default=1, metavar='N',
                   help='Input batch size for training (default: 128)')
# rsna argument - validation batch size
group.add_argument('-vb', '--validation-batch-size', type=int, default=2, metavar='N',
                   help='Validation batch size override (default: None)')
group.add_argument('--channels-last', action='store_true', default=False,
                   help='Use channels_last memory layout')
group.add_argument('--fuser', default='', type=str,
                   help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
group.add_argument('--grad-checkpointing', action='store_true', default=False,
                   help='Enable gradient checkpointing through model blocks/stages')
group.add_argument('--fast-norm', default=False, action='store_true',
                   help='enable experimental fast-norm')
group.add_argument('--model-kwargs', nargs='*', default={}, action=utils.ParseKwargs)


ParseKwargs(option_strings=['--model-kwargs'], dest='model_kwargs', nargs='*', const=None, default={}, type=None, choices=None, help=None, metavar=None)

In [8]:
scripting_group = group.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
                             help='torch.jit.script the full model')
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
                             help="Enable compilation w/ specified backend (default: inductor).")
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                             help="Enable AOT Autograd support.")

_StoreTrueAction(option_strings=['--aot-autograd'], dest='aot_autograd', nargs=0, const=True, default=False, type=None, choices=None, help='Enable AOT Autograd support.', metavar=None)

In [9]:
# Optimizer parameters
group = parser.add_argument_group('Optimizer parameters')
# rsna argument - optimizer
group.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                   help='Optimizer (default: "sgd")')
group.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
                   help='Optimizer Epsilon (default: None, use opt default)')
group.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                   help='Optimizer Betas (default: None, use opt default)')
group.add_argument('--momentum', type=float, default=0.9, metavar='M',
                   help='Optimizer momentum (default: 0.9)')
group.add_argument('--weight-decay', type=float, default=2e-5,
                   help='weight decay (default: 2e-5)')
group.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                   help='Clip gradient norm (default: None, no clipping)')
group.add_argument('--clip-mode', type=str, default='norm',
                   help='Gradient clipping mode. One of ("norm", "value", "agc")')
group.add_argument('--layer-decay', type=float, default=None,
                   help='layer-wise learning rate decay (default: None)')
group.add_argument('--opt-kwargs', nargs='*', default={}, action=utils.ParseKwargs)

ParseKwargs(option_strings=['--opt-kwargs'], dest='opt_kwargs', nargs='*', const=None, default={}, type=None, choices=None, help=None, metavar=None)

In [10]:
# Learning rate schedule parameters
group = parser.add_argument_group('Learning rate schedule parameters')
group.add_argument('--sched', type=str, default='cosine', metavar='SCHEDULER',
                   help='LR scheduler (default: "step"')
group.add_argument('--sched-on-updates', action='store_true', default=False,
                   help='Apply LR scheduler step on update instead of epoch end.')
# rsna argument - learning rate
group.add_argument('--lr', type=float, default=.009854, metavar='LR',
                   help='learning rate, overrides lr-base if set (default: None)')
group.add_argument('--lr-base', type=float, default=0.1, metavar='LR',
                   help='base learning rate: lr = lr_base * global_batch_size / base_size')
group.add_argument('--lr-base-size', type=int, default=256, metavar='DIV',
                   help='base learning rate batch size (divisor, default: 256).')
group.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE',
                   help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                   help='learning rate noise on/off epoch percentages')
group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                   help='learning rate noise limit percent (default: 0.67)')
group.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                   help='learning rate noise std-dev (default: 1.0)')
group.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
                   help='learning rate cycle len multiplier (default: 1.0)')
group.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
                   help='amount to decay each learning rate cycle (default: 0.5)')
group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
                   help='learning rate cycle limit, cycles enabled if > 1')
group.add_argument('--lr-k-decay', type=float, default=1.0,
                   help='learning rate k-decay for cosine/poly (default: 1.0)')
# rsna argument - warmup learning rate
group.add_argument('--warmup-lr', type=float, default=.009001, metavar='LR',
                   help='warmup learning rate (default: 1e-5)')
# rsna argument - minimum learning rate
group.add_argument('--min-lr', type=float, default=.009001, metavar='LR',
                   help='lower lr bound for cyclic schedulers that hit 0 (default: 0)')
# rsna argument - 
group.add_argument('--epochs', type=int, default=300, metavar='N',
                   help='number of epochs to train (default: 300)')
group.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
                   help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
group.add_argument('--start-epoch', default=None, type=int, metavar='N',
                   help='manual epoch number (useful on restarts)')
group.add_argument('--decay-milestones', default=[90, 180, 270], type=int, nargs='+', metavar="MILESTONES",
                   help='list of decay epoch indices for multistep lr. must be increasing')
group.add_argument('--decay-epochs', type=float, default=90, metavar='N',
                   help='epoch interval to decay LR')
# rsna argument - 
group.add_argument('--warmup-epochs', type=int, default=0, metavar='N',
                   help='epochs to warmup LR, if scheduler supports')# default=5
group.add_argument('--warmup-prefix', action='store_true', default=False,
                   help='Exclude warmup period from decay schedule.')
# rsna argument - 
group.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',
                   help='epochs to cooldown LR at min_lr, after cyclic schedule ends')# default=0
group.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                   help='patience epochs for Plateau LR scheduler (default: 10)')
group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                   help='LR decay rate (default: 0.1)')


_StoreAction(option_strings=['--decay-rate', '--dr'], dest='decay_rate', nargs=None, const=None, default=0.1, type=<class 'float'>, choices=None, help='LR decay rate (default: 0.1)', metavar='RATE')

In [11]:
# Augmentation & regularization parameters
group = parser.add_argument_group('Augmentation and regularization parameters')
# rsna argument -
group.add_argument('--no-aug', action='store_true', default=True,
                   help='Disable all training augmentation, override other train aug args')
group.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                   help='Random resize scale (default: 0.08 1.0)')
group.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                   help='Random resize aspect ratio (default: 0.75 1.33)')
group.add_argument('--hflip', type=float, default=0.5,
                   help='Horizontal flip training aug probability')
group.add_argument('--vflip', type=float, default=0.,
                   help='Vertical flip training aug probability')
group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                   help='Color jitter factor (default: 0.4)')
group.add_argument('--aa', type=str, default=None, metavar='NAME',
                   help='Use AutoAugment policy. "v0" or "original". (default: None)'),
group.add_argument('--aug-repeats', type=float, default=0,
                   help='Number of augmentation repetitions (distributed training only) (default: 0)')
group.add_argument('--aug-splits', type=int, default=0,
                   help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
group.add_argument('--jsd-loss', action='store_true', default=False,
                   help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
# rsna argument - 
group.add_argument('--bce-loss', action='store_true', default=True,
                   help='Enable BCE loss w/ Mixup/CutMix use.')
group.add_argument('--bce-target-thresh', type=float, default=None,
                   help='Threshold for binarizing softened BCE targets (default: None, disabled)')
group.add_argument('--reprob', type=float, default=0., metavar='PCT',
                   help='Random erase prob (default: 0.)')
group.add_argument('--remode', type=str, default='pixel',
                   help='Random erase mode (default: "pixel")')
group.add_argument('--recount', type=int, default=1,
                   help='Random erase count (default: 1)')
group.add_argument('--resplit', action='store_true', default=False,
                   help='Do not random erase first (clean) augmentation split')
group.add_argument('--mixup', type=float, default=0.0,
                   help='mixup alpha, mixup enabled if > 0. (default: 0.)')
group.add_argument('--cutmix', type=float, default=0.0,
                   help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
group.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                   help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
group.add_argument('--mixup-prob', type=float, default=1.0,
                   help='Probability of performing mixup or cutmix when either/both is enabled')
group.add_argument('--mixup-switch-prob', type=float, default=0.5,
                   help='Probability of switching to cutmix when both mixup and cutmix enabled')
group.add_argument('--mixup-mode', type=str, default='batch',
                   help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
group.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                   help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
# rsna argument - 
group.add_argument('--smoothing', type=float, default=0.1,
                   help='Label smoothing (default: 0.1)')
group.add_argument('--train-interpolation', type=str, default='random',
                   help='Training interpolation (random, bilinear, bicubic default: "random")')
# rsna argument - 
group.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                   help='Dropout rate (default: 0.)')
group.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
                   help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
# rsna argument - drop path rate
group.add_argument('--drop-path', type=float, default=None, metavar='PCT',
                   help='Drop path rate (default: None)')
group.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                   help='Drop block rate (default: None)')

_StoreAction(option_strings=['--drop-block'], dest='drop_block', nargs=None, const=None, default=None, type=<class 'float'>, choices=None, help='Drop block rate (default: None)', metavar='PCT')

In [12]:
# Batch norm parameters (only works with gen_efficientnet based models currently)
group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.')
group.add_argument('--bn-momentum', type=float, default=None,
                   help='BatchNorm momentum override (if not None)')
group.add_argument('--bn-eps', type=float, default=None,
                   help='BatchNorm epsilon override (if not None)')
group.add_argument('--sync-bn', action='store_true',
                   help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
group.add_argument('--dist-bn', type=str, default='reduce',
                   help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
group.add_argument('--split-bn', action='store_true',
                   help='Enable separate BN layers per augmentation split.')

_StoreTrueAction(option_strings=['--split-bn'], dest='split_bn', nargs=0, const=True, default=False, type=None, choices=None, help='Enable separate BN layers per augmentation split.', metavar=None)

In [13]:
# Model Exponential Moving Average
group = parser.add_argument_group('Model exponential moving average parameters')
# rsna argument - sets up an exponential moving average of model weights if specified.
group.add_argument('--model-ema', action='store_true', default=True,
                   help='Enable tracking moving average of model weights')
group.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                   help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
# rsna argument - 
group.add_argument('--model-ema-decay', type=float, default=0.9998,
                   help='decay factor for model weights moving average (default: 0.9998)')

_StoreAction(option_strings=['--model-ema-decay'], dest='model_ema_decay', nargs=None, const=None, default=0.9998, type=<class 'float'>, choices=None, help='decay factor for model weights moving average (default: 0.9998)', metavar=None)

In [14]:
# Misc
group = parser.add_argument_group('Miscellaneous parameters')
group.add_argument('--seed', type=int, default=42, metavar='S',
                   help='random seed (default: 42)')
group.add_argument('--worker-seeding', type=str, default='all',
                   help='worker seed mode (default: all)')
# rsna argument -
group.add_argument('--log-interval', type=int, default=500, metavar='N',
                   help='how many batches to wait before logging training status')
group.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                   help='how many batches to wait before writing recovery checkpoint')
# rsna argument -
group.add_argument('--checkpoint-hist', type=int, default=100, metavar='N',
                   help='number of checkpoints to keep (default: 10)')
# rsna argument -
group.add_argument('-j', '--workers', type=int, default=11, metavar='N',
                   help='how many training processes to use (default: 4)')
# rsna argument - save images for debugging - one can watch how augmentation affect images.
group.add_argument('--save-images', action='store_true', default=False,
                   help='save images of input batches every log interval for debugging')
# rsna argument - enable mixed precision training
group.add_argument('--amp', action='store_true', default=True,
                   help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
group.add_argument('--amp-dtype', default='float16', type=str,
                   help='lower precision AMP dtype (default: float16)')
# rsna argument
group.add_argument('--amp-impl', default='native', type=str,
                   help='AMP impl to use, "native" or "apex" (default: native)')
group.add_argument('--no-ddp-bb', action='store_true', default=False,
                   help='Force broadcast buffers for native DDP (Distributed Data Parallel) to off.')
group.add_argument('--pin-mem', action='store_true', default=False,
                   help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
group.add_argument('--no-prefetcher', action='store_true', default=False,
                   help='disable fast prefetcher')
group.add_argument('--output', default='', type=str, metavar='PATH',
                   help='path to output folder (default: none, current dir)')

# rsna argument - experiment name.
group.add_argument('-expn', '--experiment-name', default='fully_reproduce_train_fold_2', type=str, metavar='NAME',
                   help='name of train experiment, name of sub-folder for output')
# rsna argument - evaluation metric - groupby(['patient_id', 'laterality']).mean() and pfbeta
group.add_argument('--eval-metric', default='gbmean_best_pfbeta', type=str, metavar='EVAL_METRIC',
                   help='Best metric (default: "top1"')
group.add_argument('--tta', type=int, default=0, metavar='N',
                   help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
group.add_argument("--local_rank", default=0, type=int)
group.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                   help='use the multi-epochs-loader to save time at the beginning of every epoch')
# rsna argument - wandb
group.add_argument('--log-wandb', action='store_true', default=False,
                   help='log training and validation metrics to wandb')
group.add_argument('--pos-weight', type=float, default=-1, help='Positive weight used for loss computation.')
# rsna argument - 
group.add_argument('--dense-ckpt-epochs', default=[10, 18], nargs=2, type=int,
                   metavar='N N',
                   help='Validate and save ckpts more frequently (validate >1 times per epoch)')# default=[9999, -1]
# rsna argument - 
group.add_argument('--dense-ckpt-bins', type=int, default=2, metavar='N',
                   help='Validate and save ckpts more frequently (validate >1 times per epoch)') 

_StoreAction(option_strings=['--dense-ckpt-bins'], dest='dense_ckpt_bins', nargs=None, const=None, default=2, type=<class 'int'>, choices=None, help='Validate and save ckpts more frequently (validate >1 times per epoch)', metavar='N')

In [15]:
def _parse_args():
    # Do we have a config file to parse?
    args_config, remaining = config_parser.parse_known_args()
    
    if args_config.config:
        with open(args_config.config, 'r') as f:
            cfg = yaml.safe_load(f)
            parser.set_defaults(**cfg)

    # The main arg parser parses the rest of the args, the usual
    # defaults will have been overridden if config file specified.
    #args = parser.parse_args(remaining)
    args, _ = parser.parse_known_args()
    
    # Cache the args as a text string to save them in the output dir later
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
    return args, args_text


In [16]:
def train_one_epoch(
        exp,    # change
        eval_loader,    # change
        val_loss_fn,    # change
        best_metric,    # change
        best_epoch,     # change
        epoch,
        model,
        loader,
        optimizer,
        loss_fn,
        args,
        device=torch.device('cuda'),
        lr_scheduler=None,
        saver=None,
        output_dir=None,
        amp_autocast=suppress,
        loss_scaler=None,
        model_ema=None,
        mixup_fn=None,
        num_updates = None,     # change
):
    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
            loader.mixup_enabled = False
        elif mixup_fn is not None:
            mixup_fn.mixup_enabled = False

    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    batch_time_m = utils.AverageMeter()
    data_time_m = utils.AverageMeter()
    losses_m = utils.AverageMeter()
    
    model.train()

    end = time.time()
    num_batches_per_epoch = len(loader)
    last_idx = num_batches_per_epoch - 1
    if num_updates is None:
        num_updates = epoch * num_batches_per_epoch
    else:
        print('Current number of updates:', num_updates)
    epoch_num_updates = 0
    dense_ckpt_start_epoch, dense_ckpt_end_epoch = args.dense_ckpt_epochs
    dense_ckpt_interval = num_batches_per_epoch // args.dense_ckpt_bins + 1
    print(f'DENSE CKPT START/END: {args.dense_ckpt_epochs}, interval = {dense_ckpt_interval}')
    for batch_idx, (input, target) in tqdm(enumerate(loader)):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            input, target = input.to(device), target.to(device)
            if mixup_fn is not None:
                input, target = mixup_fn(input, target)
                
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)

        with amp_autocast():
            # draw_graph(model, input_data = input, expand_nested=True, save_graph=True, device='cuda').visual_graph
            output = model(input)
            loss = loss_fn(output, target)

        # break
        
        if not args.distributed:
            losses_m.update(loss.item(), input.size(0))

        optimizer.zero_grad()
        if loss_scaler is not None:
            loss_scaler(
                loss, optimizer,
                clip_grad=args.clip_grad,
                clip_mode=args.clip_mode,
                parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
                create_graph=second_order
            )
        else:
            loss.backward(create_graph=second_order)
            if args.clip_grad is not None:
                utils.dispatch_clip_grad(
                    model_parameters(model, exclude_head='agc' in args.clip_mode),
                    value=args.clip_grad,
                    mode=args.clip_mode
                )
            optimizer.step()

        if model_ema is not None:
            model_ema.update(model)

        torch.cuda.synchronize()

        num_updates += 1
        epoch_num_updates += 1
        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), input.size(0))

            if utils.is_primary(args):
                _logger.info(
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '
                    'Loss: {loss.val:#.4g} ({loss.avg:#.3g})  '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'LR: {lr:.3e}  '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        epoch,
                        batch_idx, len(loader),
                        100. * batch_idx / last_idx,
                        loss=losses_m,
                        batch_time=batch_time_m,
                        rate=input.size(0) * args.world_size / batch_time_m.val,
                        rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
                        lr=lr,
                        data_time=data_time_m)
                )

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        input,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True
                    )

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(epoch, batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()
        # end for


        ##########################################################
        # perform dense validate and save model checkpointing
        # START

        if epoch >= dense_ckpt_start_epoch and epoch <= dense_ckpt_end_epoch and epoch_num_updates % dense_ckpt_interval == 0:
            print('\n-----DENSE CKPT VALIDATION-----\n')
            
            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if utils.is_primary(args):
                    _logger.info("Distributing BatchNorm running means and vars")
                utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
            
            temp_epoch = round(epoch - 1 + epoch_num_updates / num_batches_per_epoch , 2) 
            if args.save_results_dir is not None:
                plot_save_path = os.path.join(args.save_results_dir, f'plot_{temp_epoch}.jpg')
                ema_plot_save_path = os.path.join(args.save_results_dir, f'ema_plot_{temp_epoch}.jpg')
                pred_save_path = os.path.join(args.save_results_dir, f'pred_{temp_epoch}.csv')
            else:
                plot_save_path = None
                ema_plot_save_path = None

            eval_metrics = None
            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                eval_metrics = validate(
                    exp,
                    model_ema.module,
                    eval_loader,
                    val_loss_fn,
                    args,
                    amp_autocast=amp_autocast,
                    log_suffix=' (EMA)',
                    plot_save_path = ema_plot_save_path,
                    pred_save_path = pred_save_path,
                )
            else:
                eval_metrics = validate(
                exp,
                model,
                eval_loader,
                val_loss_fn,
                args,
                amp_autocast=amp_autocast,
                plot_save_path = plot_save_path,
                pred_save_path = pred_save_path,
            )

            # primary process/rank only
            if output_dir is not None:
                lrs = [param_group['lr'] for param_group in optimizer.param_groups]
                utils.update_summary(
                    temp_epoch,
                    OrderedDict([('loss', losses_m.avg)]),
                    eval_metrics,
                    filename=os.path.join(output_dir, 'summary.csv'),
                    lr=sum(lrs) / len(lrs),
                    write_header=best_metric is None,
                    log_wandb=args.log_wandb and has_wandb,
                )

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[args.eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(temp_epoch, metric=save_metric)

            model.train()
            print('\n---------------------------------------\n')
        # END
        ##########################################################


    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg)]), num_updates, best_metric, best_epoch


In [17]:
def validate(
        exp,
        model,
        loader,
        loss_fn,
        args,
        device=torch.device('cuda'),
        amp_autocast=suppress,
        log_suffix='',
        plot_save_path = None,
        pred_save_path = None,
):
    batch_time_m = utils.AverageMeter()
    losses_m = utils.AverageMeter()
    top1_m = utils.AverageMeter()
    top5_m = utils.AverageMeter()

    model.eval()

    end = time.time()
    last_idx = len(loader) - 1

    # buffer
    # we use modified OrderedDistributedSampler, a little change from timm's
    val_len = len(loader.dataset)
    num_samples_per_rank = int(math.ceil(val_len / args.world_size))
    temp_total_size = num_samples_per_rank * args.world_size
    if utils.is_primary(args):
        _logger.info(
            f'Val samples: {val_len}, buffer size: {temp_total_size}')

    targets = torch.zeros((temp_total_size,), requires_grad=False)
    preds = torch.zeros_like(targets)
    sample_weights =  torch.zeros_like(targets)

    with torch.no_grad():
        cur_sample_idx = 0
        for batch_idx, (input, target) in tqdm(enumerate(loader)):
            last_batch = batch_idx == last_idx
            if not args.prefetcher:
                input = input.to(device)
                target = target.to(device)
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            with amp_autocast():
                output = model(input)
                if isinstance(output, (tuple, list)):
                    output = output[0]

                # augmentation reduction
                reduce_factor = args.tta
                if reduce_factor > 1:
                    output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                    target = target[0:target.size(0):reduce_factor]

                num_classes = output.shape[-1]
                if num_classes == 1:
                    loss = loss_fn(output, target.float().view(-1, 1))
                else:
                    loss = loss_fn(output, target)
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))

            if args.distributed:
                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
                acc1 = utils.reduce_tensor(acc1, args.world_size)
                acc5 = utils.reduce_tensor(acc5, args.world_size)
            else:
                reduced_loss = loss.data

            if device.type == 'cuda':
                torch.cuda.synchronize()

            losses_m.update(reduced_loss.item(), input.size(0))
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))

            # CUSTOM METRIC
            if args.distributed:
                # gather preds
                # @TODO: only available in torch > 1.13 ?
                # output = utils.all_gather_tensor_into_tensor(output, args.world_size)
                output = utils.all_gather_tensor(output, args.world_size)
                # is the list in ordered by rank?
                # https://discuss.pytorch.org/t/order-of-the-list-returned-by-torch-distributed-all-gather/125273
                # https://github.com/pytorch/pytorch/issues/23144
                output = torch.cat(output, dim = 0)
            
                # gather targets
                target = utils.all_gather_tensor(target, args.world_size)
                target = torch.cat(target, dim = 0)

            if output.shape[-1] == 2:
                prob = output.softmax(dim = 1)[:, 1]
            elif output.shape[-1] == 1:
                prob = output.sigmoid().view(-1, )
            else:
                raise AssertionError()
            
            nan_mask = torch.isnan(prob)
            if torch.sum(nan_mask.long()) > 0:
                print('CONTAIN NAN:', prob[nan_mask], output[nan_mask])
                prob[nan_mask] = 0.
            # prob = torch.nan_to_num(prob, nan = 0.0)
            
            output = prob
            _targets = target
            _preds = output
            _sample_weights = torch.ones_like(_preds)

            cur_bs = output.size(0)
            targets[cur_sample_idx:cur_sample_idx + cur_bs] = _targets
            preds[cur_sample_idx:cur_sample_idx + cur_bs] = _preds
            sample_weights[cur_sample_idx:cur_sample_idx + cur_bs] = _sample_weights
            cur_sample_idx += cur_bs

            batch_time_m.update(time.time() - end)
            end = time.time()
            if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
                log_name = 'Test' + log_suffix
                _logger.info(
                    '{0}: [{1:>4d}/{2}]  '
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '
                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                        log_name, batch_idx, last_idx,
                        batch_time=batch_time_m,
                        loss=losses_m,
                        top1=top1_m,
                        top5=top5_m)
                )

    # truncate duplicated samples (tail)
    targets = targets[:val_len].cpu().numpy()
    preds = preds[:val_len].cpu().numpy()
    # why nan = 1 ?
    preds = np.nan_to_num(preds, nan=1, posinf=1, neginf=0)
    sample_weights = sample_weights[:val_len].cpu().numpy()
    val_df = loader.dataset.get_df()
    val_df['preds'] = preds
    val_df['targets'] = targets
    val_df['sample_weights'] = sample_weights
    assert (val_df['targets'] == val_df['cancer']).all()

    if utils.is_primary(args):
        additional_info = True
        val_df.to_csv(pred_save_path, index = False)
    else:
        additional_info = False
    
    metric_results = exp.compute_metrics(
        val_df,
        plot_save_path,
        additional_info = additional_info
    )

    metrics =[('loss', losses_m.avg), ('top1', top1_m.avg)]         
    metrics.extend([(k, v) for k, v in metric_results.items()])
    metrics = OrderedDict(metrics)
    # print('---------------')
    # print(metrics)
    # print('---------------')
    return metrics


In [18]:
# import graphviz
# from torchview import draw_graph
# graphviz.set_jupyter_format('png')

'svg'

In [19]:
def main():
    utils.setup_default_logging()
    args, args_text = _parse_args()
    
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
    args.prefetcher = not args.no_prefetcher
    device = utils.init_distributed_device(args)
    # prevent OOM error while initializing CUDNN
    force_cudnn_initialization()
    torch.cuda.empty_cache()

    args.device = device
    
    # if distributed training is enabled, the code sets up training in a distributed mode with multiple processes.
    if args.distributed:
        _logger.info(
            'Training in distributed mode with multiple processes, 1 device per process.'
            f'Process {args.rank}, total {args.world_size}, device {args.device}.')
    else:
        _logger.info(f'Training with a single process on 1 device ({args.device}).')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    amp_dtype = torch.float16
    if args.amp:
        if args.amp_impl == 'apex':
            assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
            use_amp = 'apex'
            assert args.amp_dtype == 'float16'
        else:
            assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
            use_amp = 'native'
            assert args.amp_dtype in ('float16', 'bfloat16')
        if args.amp_dtype == 'bfloat16':
            amp_dtype = torch.bfloat16

    utils.random_seed(args.seed, args.rank)

    if args.fuser:
        utils.set_jit_fuser(args.fuser)
    if args.fast_norm:
        set_fast_norm()

    ### BUILD EXP
    exp_fn = get_exp_fn(args.exp)
    exp = exp_fn(args)

    ### BUILD MODEL
    model = exp.build_model()
    
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.grad_checkpointing:
        model.set_grad_checkpointing(enable=True)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits
    args.num_aug_splits = num_aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.to(device=device)
    if args.channels_last:
        model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        args.dist_bn = ''  # disable dist_bn when sync BN active
        assert not args.split_bn
        if has_apex and use_amp == 'apex':
            # Apex SyncBN used with Apex AMP
            # WARNING this won't currently work with models using BatchNormAct2d
            model = convert_syncbn_model(model)
        else:
            model = convert_sync_batchnorm(model)
        if utils.is_primary(args):
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)
    elif args.torchcompile:
        # FIXME dynamo might need move below DDP (Distributed Data Parallel) wrapping? TBD
        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
        torch._dynamo.reset()
        model = torch.compile(model, backend=args.torchcompile)
    elif args.aot_autograd:
        assert has_functorch, "functorch is needed for --aot-autograd"
        model = memory_efficient_fusion(model)

    if not args.lr:
        global_batch_size = args.batch_size * args.world_size
        batch_ratio = global_batch_size / args.lr_base_size
        if not args.lr_base_scale:
            on = args.opt.lower()
            args.lr_base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear'
        if args.lr_base_scale == 'sqrt':
            batch_ratio = batch_ratio ** 0.5
        args.lr = args.lr_base * batch_ratio
        if utils.is_primary(args):
            _logger.info(
                f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
                f'and global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')

    ### BUILD OPTIMIZER
    optimizer = exp.build_optimizer(model)

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        assert device.type == 'cuda'
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if utils.is_primary(args):
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
        if device.type == 'cuda':
            loss_scaler = NativeScaler()
        if utils.is_primary(args):
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if utils.is_primary(args):
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model,
            args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=utils.is_primary(args),
        )

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP (Distributed Data Parallel) wrapper.
        model_ema = utils.ModelEmaV2(
            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp == 'apex':
            # Apex DDP (Distributed Data Parallel) preferred unless native amp is activated
            if utils.is_primary(args):
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if utils.is_primary(args):
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
        # NOTE: EMA model does not need to be wrapped by DDP (Distributed Data Parallel).

    # create the train and eval datasets
    if args.data and not args.data_dir:
        args.data_dir = args.data

    ### BUILD LOADERS
    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup,
            cutmix_alpha=args.cutmix,
            cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob,
            switch_prob=args.mixup_switch_prob,
            mode=args.mixup_mode,
            label_smoothing=args.smoothing,
            num_classes=args.num_classes
        )
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)
    args.mixup_active = mixup_active

    train_loader = exp.build_train_loader(collate_fn)
    eval_loader = exp.build_val_loader()

    ### BUILD LOSSES
    train_loss_fn = exp.build_train_loss_fn()
    val_loss_fn = exp.build_val_loss_fn()
    print('TRAIN LOSS FUNCTION:', train_loss_fn)
    print('VAL LOSS FUNCTION:', val_loss_fn)

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = None
    if utils.is_primary(args):
        if args.experiment_name:
            exp_name = args.experiment_name
        else:
            exp_name = '-'.join([
                datetime.now().strftime("%Y%m%d-%H%M%S"),
                safe_model_name(args.model),
                f'fold{args.exp_kwargs["fold_idx"]}',
                'x'.join([str(e) for e in exp.data_config['input_size']])
            ])
        if args.output:
            output_dir = args.output
        else:
            try:
                output_dir = exp.output_dir
            except:
                output_dir = './output/train'
        output_dir = os.path.join(output_dir, exp_name)
        assert not os.path.exists(output_dir), f'Output directory {output_dir} exist !'
        os.makedirs(output_dir, exist_ok=False)

        decreasing = True if eval_metric == 'loss' else False
        saver = utils.CheckpointSaver(
            model=model,
            optimizer=optimizer,
            args=args,
            model_ema=model_ema,
            amp_scaler=loss_scaler,
            checkpoint_dir=output_dir,
            recovery_dir=output_dir,
            decreasing=decreasing,
            max_history=args.checkpoint_hist
        )
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    args.output_dir = output_dir
    if utils.is_primary(args) and args.log_wandb:
        if has_wandb:
            wandb.init(project=args.experiment_name, config=args)
        else:
            _logger.warning(
                "You've requested to log metrics to wandb but package not found. "
                "Metrics not being logged to wandb, try `pip install wandb`")

    # setup learning rate schedule and starting epoch
    updates_per_epoch = len(train_loader)
    args.updates_per_epoch = updates_per_epoch
    lr_scheduler, num_epochs = exp.build_lr_scheduler(optimizer)
    
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        if args.sched_on_updates:
            lr_scheduler.step_update(start_epoch * updates_per_epoch)
        else:
            lr_scheduler.step(start_epoch)

    if utils.is_primary(args):
        _logger.info(
            f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')

    torch.cuda.empty_cache()
    if output_dir is not None:
        save_results_dir = os.path.join(output_dir, 'results')
        os.makedirs(save_results_dir)
    else:
        save_results_dir = None
    # print('MODEL:\n', model)
    args.save_results_dir = save_results_dir

    assert train_loader.sampler.num_epochs > num_epochs

    num_updates = 0
    try:
        for epoch in range(start_epoch, num_epochs):
            print(f'START EPOCH {epoch}')
            if hasattr(train_loader.dataset, 'set_epoch'):
                train_loader.dataset.set_epoch(epoch)
            elif hasattr(train_loader.sampler, 'set_epoch'):
                train_loader.sampler.set_epoch(epoch)

            train_metrics, num_updates, best_metric, best_epoch = train_one_epoch(
                exp,
                eval_loader,
                val_loss_fn,
                best_metric,
                best_epoch,
                epoch,
                model,
                train_loader,
                optimizer,
                train_loss_fn,
                args,
                lr_scheduler=lr_scheduler,
                saver=saver,
                output_dir=output_dir,
                amp_autocast=amp_autocast,
                loss_scaler=loss_scaler,
                model_ema=model_ema,
                mixup_fn=mixup_fn,
                num_updates = num_updates,
            )
            
            # break

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if utils.is_primary(args):
                    _logger.info("Distributing BatchNorm running means and vars")
                utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            if save_results_dir is not None:
                plot_save_path = os.path.join(save_results_dir, f'plot_{epoch}.jpg')
                ema_plot_save_path = os.path.join(save_results_dir, f'ema_plot_{epoch}.jpg')
                pred_save_path = os.path.join(args.save_results_dir, f'pred_{epoch}.csv')
            else:
                plot_save_path = None
                ema_plot_save_path = None

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')

                eval_metrics = validate(
                    exp,
                    model_ema.module,
                    eval_loader,
                    val_loss_fn,
                    args,
                    amp_autocast=amp_autocast,
                    log_suffix=' (EMA)',
                    plot_save_path = ema_plot_save_path,
                    pred_save_path= pred_save_path,
                )
            else:
                eval_metrics = validate(
                    exp,
                    model,
                    eval_loader,
                    val_loss_fn,
                    args,
                    amp_autocast=amp_autocast,
                    plot_save_path = plot_save_path,
                    pred_save_path = pred_save_path,
                )

            # primary process/rank only
            if output_dir is not None:
                lrs = [param_group['lr'] for param_group in optimizer.param_groups]
                utils.update_summary(
                    epoch,
                    train_metrics,
                    eval_metrics,
                    filename=os.path.join(output_dir, 'summary.csv'),
                    lr=sum(lrs) / len(lrs),
                    write_header=best_metric is None,
                    log_wandb=args.log_wandb and has_wandb,
                )

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

    except KeyboardInterrupt:
        pass

    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))


In [20]:
# %load_ext autoreload
# %autoreload 2

In [21]:
if __name__ == '__main__':
    main()

Training with a single process on 1 device (cuda:0).


utils/train.py
Using global configuration (SETTINGS.json):
--------------------------------------------------------------------------------
ASSETS_DIR: ./assets/
MODEL_CHECKPOINT_DIR: ./checkpoints/
MODEL_FINAL_SELECTION_DIR: ./assets/reproduce/
PROCESSED_DATA_DIR: ./datasets/processed/
RAW_DATA_DIR: 
SUBMISSION_DIR: ./submissions/
TEMP_DIR: ./tmp/
__JSON_PATH__: /media/na/e0adac50-20ce-4eb4-9c9d-98faf82ddd46/rsna_breast/SETTINGS.json
--------------------------------------------------------------------------------





------
EXP METADATA:
 {'fold_idx': 2, 'num_sched_epochs': 10, 'num_epochs': 350, 'start_ratio': 0.1429, 'end_ratio': 0.1429, 'one_pos_mode': True}


Loading pretrained weights from Hugging Face hub (timm/convnext_small.fb_in22k_ft_in1k_384)
[timm/convnext_small.fb_in22k_ft_in1k_384] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
Data processing configuration for current model + dataset:
	input_size: (3, 2048, 1024)
	interpolation: bicubic
	mean: (0.485, 0.456, 0.406)
	std: (0.229, 0.224, 0.225)
	crop_pct: 1.0
	crop_mode: squash
Using native Torch AMP. Training in mixed precision.


TRAIN AUG:
 Compose([
  CustomRandomSizedCropNoResize(always_apply=False, p=0.4, scale=(0.5, 1.0), ratio=(0.5, 0.8)),
  HorizontalFlip(always_apply=False, p=0.5),
  VerticalFlip(always_apply=False, p=0.5),
  OneOf([
    Downscale(always_apply=False, p=0.1, scale_min=0.75, scale_max=0.95),
    Downscale(always_apply=False, p=0.1, scale_min=0.75, scale_max=0.95),
    Downscale(always_apply=False, p=0.8, scale_min=0.75, scale_max=0.95),
  ], p=0.125),
  OneOf([
    RandomToneCurve(always_apply=False, p=0.5, scale=0.3),
    RandomBrightnessContrast(always_apply=False, p=0.5, brightness_limit=(-0.1, 0.2), contrast_limit=(-0.4, 0.5), brightness_by_max=True),
  ], p=0.5),
  OneOf([
    ShiftScaleRotate(always_apply=False, p=0.6, shift_limit_x=(-0.1, 0.1), shift_limit_y=(-0.2, 0.2), scale_limit=(-0.15000000000000002, 0.1499999999999999), rotate_limit=(-30, 30), interpolation=1, border_mode=0, value=0, mask_value=None, rotate_method='largest_box'),
    ElasticTransform(always_apply=False, p=0.2

100%|██████████████████████████████████| 41055/41055 [00:00<00:00, 85793.45it/s]


Done loading rsna-breast-cancer-detection with 41055 samples.
DATASET TOTAL LENGTH: 41055 with positive percent = 0.021191085129704055
Num pos: 870, Num neg: 40185
RATIO PER EPOCHS: [0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429
 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429
 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429
 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429
 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429
 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429
 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429
 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429
 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429
 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429
 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429
 0.1429 0.1429 0.1429 0.1429 0.1429 0

Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927

Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927

Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927
Pre-compute for ratio = 0.1429
ONE POS MODE: Switch num_pos_samples to 45927

100%|██████████████████████████████████| 13651/13651 [00:00<00:00, 81497.69it/s]
Scheduled epochs: 300. LR stepped per epoch.


Done loading rsna-breast-cancer-detection with 13651 samples.
DATASET TOTAL LENGTH: 13651 with positive percent = 0.021097355505091203
SAMPLER: None
Using pos weight: None
TRAIN LOSS FUNCTION: BinaryCrossEntropyPosSmoothOnly()
VAL LOSS FUNCTION: CrossEntropyLoss()
START EPOCH 0
Set epoch to 0 with sampler ratio = 0.1429
Current number of updates: 0
DENSE CKPT START/END: [10, 18], interval = 22964


0it [00:00, ?it/s]

STARTING COMPUTE EPOCH 0 SAMPLE INDEXS...
Current epoch sampler ratio: 0.1429
45927 iters with 45927 samples


6344it [3:05:27,  1.75s/it]


### Select and copy the best checkpoints

In [16]:
!PYTHONPATH=$(pwd):$PYTHONPATH python3 utils/select_classification_best_ckpts.py --mode fully_reproduce

### Convert best Convnext models to TensorRT

In [17]:
# pip install nvidia-pyindex
# pip install onnx-graphsurgeon


# pythonpath => # pythonpath => to set location of custom Python libraries that are not installed in the site packages directory (the global default location).
!PYTHONPATH=$(pwd)/timm:$PYTHONPATH python3 utils/convert_convnext_tensorrt.py \
    --mode reproduce


torch_tensorrt is not installed yet.
Using global configuration (SETTINGS.json):
--------------------------------------------------------------------------------
ASSETS_DIR: ./assets/
MODEL_CHECKPOINT_DIR: ./checkpoints/
MODEL_FINAL_SELECTION_DIR: ./assets/reproduce/
PROCESSED_DATA_DIR: ./datasets/processed/
RAW_DATA_DIR: 
SUBMISSION_DIR: ./submissions/
TEMP_DIR: ./tmp/
__JSON_PATH__: /media/na/e0adac50-20ce-4eb4-9c9d-98faf82ddd46/rsna_breast/SETTINGS.json
--------------------------------------------------------------------------------




MODEL CHECKPOINTS:
 ['./assets/reproduce/best_convnext_fold_1.pth.tar', './assets/reproduce/best_convnext_fold_2.pth.tar', './assets/reproduce/best_convnext_fold_3.pth.tar']
TENSORRT ENGINE WILL BE SAVED TO ./assets/reproduce/best_ensemble_convnext_small_batch2_fp32.engine
USING BACKEND torch2trt
Loading model from ./assets/reproduce/best_convnext_fold_1.pth.tar
Data config: {'input_size': (3, 384, 384), 'interpolation': 'bicubic', 'mean': (0.485, 0.

In [None]:
#2147483647

In [None]:
#594279733