In [1]:
import torch
import numpy as np
from datasets import load_dataset, load_metric
from transformers import DeiTModel, DeiTFeatureExtractor, DeiTForImageClassification, TrainingArguments, Trainer
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize, 
                                    ToTensor)
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import Mixup
from timm.models import create_model
from timm.optim import create_optimizer
from timm.utils import ModelEma, NativeScaler, get_state_dict
from timm.scheduler import create_scheduler
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from PIL import Image
from losses import DistillationLoss
from pathlib import Path
from engine import train_one_epoch, evaluate
import time
import torch.backends.cudnn as cudnn
from torchvision import datasets, transforms
from torchvision.datasets.folder import ImageFolder, default_loader
from timm.data import create_transform
import os
import utils
import json
import datetime

In [10]:
checkpoint = torch.load('./DeiT/out_imgnet/checkpoint.pth')

In [2]:
datasets.CIFAR10('./data', download=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data


Dataset CIFAR10
    Number of datapoints: 50000
    Root location: ./data
    Split: Train

In [2]:
cifar_path = './data'
imgnet_path = '/home/ecbm4040/.cache/huggingface/datasets/downloads/extracted/0201c2598c4cf28a3ea355e57a1b281d33183822b6848bd3c1ea3285835cb028/ILSVRC/Data/CLS-LOC'

In [48]:
class build_args():
    def __init__(self, 
                 epochs=1,
                 batch_size=256,
                 input_size=224, color_jitter=0.4, 
                 aa='rand-m9-mstd0.5-inc1', train_interpolation='bicubic', 
                 reprob=0.25, remode='pixel', recount=1,
                 distributed=False, distillation_type='none',
                 device='cuda', seed=0, finetune='', eval_mode=False,
                 data_set='IMNET', data_path = './data', num_workers=4, pin_mem=True,
                 mixup=0.8, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0,
                 mixup_switch_prob=0.5, mixup_mode='batch', smoothing=0.1,
                 model='deit_tiny_patch16_224', drop=0.0, drop_path=0.1,
                 model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False,
                 sched='cosine', lr=5e-4, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0,
                 warmup_lr=1e-6, min_lr=1e-5, opt='adamw', opt_eps=1e-8, opt_betas=None,
                 clip_grad=None, momentum=0.9, weight_decay=0.05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10,
                 teacher_model='regnety_160', teacher_path='', output_dir='./deit_output', resume='',
                 distillation_alpha=0.5, distillation_tau=1.0):
        
        self.epochs = epochs
        self.batch_size = batch_size
        self.output_dir = output_dir
        self.resume = resume
        self.start_epoch=0
        
        # model
        self.input_size = input_size
        self.model = model
        self.finetune = finetune
        self.drop = drop
        self.drop_path = drop_path
        self.model_ema = model_ema
        self.model_ema_decay = model_ema_decay
        self.model_ema_force_cpu = model_ema_force_cpu
        
        # optimizer
        self.opt = opt
        self.opt_eps = opt_eps
        self.opt_betas = opt_betas
        self.clip_grad = clip_grad
        self.momentum = momentum
        self.weight_decay = weight_decay
        
        # augmentation
        self.color_jitter = color_jitter
        self.aa = aa
        self.train_interpolation = train_interpolation
        self.smoothing = smoothing
        self.reprob = reprob
        self.remode = remode
        self.recount = recount
        self.distributed = distributed
        self.distillation_type = distillation_type
        self.device = device
        self.seed = seed
        self.eval = eval_mode
        self.data_set = data_set
        self.data_path = data_path
        self.num_workers = num_workers
        self.pin_mem = pin_mem
        self.nb_classes = 0
        
        # mixup
        self.mixup = mixup
        self.cutmix = cutmix
        self.cutmix_minmax = cutmix_minmax
        self.mixup_prob = mixup_prob
        self.mixup_switch_prob = mixup_switch_prob
        self.mixup_mode = mixup_mode
        
        # learning rate schedule
        self.sched = sched
        self.lr = lr
        self.lr_noise = lr_noise
        self.lr_noise_pct = lr_noise_pct
        self.lr_noise_std = lr_noise_std
        self.warmup_lr = warmup_lr
        self.min_lr = min_lr
        self.decay_epochs = 30
        self.warmup_epochs = 5
        self.cooldown_epochs = 10
        
        # distillation
        self.teacher_model = teacher_model
        self.teacher_path = teacher_path
        self.distillation_alpha = distillation_alpha
        self.distillation_tau = distillation_tau
    
args = build_args(data_set='IMNET', data_path=imgnet_path)

In [4]:
if args.distillation_type != 'none' and args.finetune and not args.eval:
    raise NotImplementedError("Finetuning with distillation not yet supported")

In [5]:
device = torch.device(args.device)

In [6]:
# fix the seed for reproducibility
seed = args.seed
torch.manual_seed(seed)
np.random.seed(seed)

In [7]:
# input do not change, set cuda to find best algorithm
cudnn.benchmark = True

In [69]:
def build_transform(is_train, args):
    resize_im = args.input_size > 32
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with
            # RandomCrop
            transform.transforms[0] = transforms.RandomCrop(
                args.input_size, padding=4)
        return transform

    t = []
    if resize_im:
        size = int((256 / 224) * args.input_size)
        t.append(
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),  # to maintain same ratio w.r.t. 224 images
        )
        t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)

def build_dataset(is_train, args):
    transform = build_transform(is_train, args)

    if args.data_set == 'CIFAR':
        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
        nb_classes = 100
    elif args.data_set == 'IMNET':
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000

    return dataset, nb_classes

In [52]:
dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)

In [70]:
dataset_val, _ = build_dataset(is_train=False, args=args)

In [54]:
sampler_train = torch.utils.data.RandomSampler(dataset_train)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)

In [55]:
data_loader_train = torch.utils.data.DataLoader(
    dataset_train, sampler=sampler_train,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    pin_memory=args.pin_mem,
    drop_last=True,
    )

data_loader_val = torch.utils.data.DataLoader(
    dataset_val, sampler=sampler_val,
    batch_size=int(1.5 * args.batch_size),
    num_workers=args.num_workers,
    pin_memory=args.pin_mem,
    drop_last=False
    )

In [17]:
# use mixup strategy
mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
    mixup_fn = Mixup(
        mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
        prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
        label_smoothing=args.smoothing, num_classes=args.nb_classes)

In [18]:
print(f"Creating model: {args.model}")

Creating model: deit_tiny_patch16_224


In [19]:
model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=None)

In [20]:
model.to(device)

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU()
        (drop1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=768, out_features=192, bias=True)
        (drop2): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((192,), ep

In [21]:
model_ema = None
if args.model_ema:
    # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
    model_ema = ModelEma(
        model,
        decay=args.model_ema_decay,
        device='cpu' if args.model_ema_force_cpu else '',
        resume='')

In [22]:
model_without_ddp = model

In [23]:
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)

number of params: 5717416


In [24]:
linear_scaled_lr = args.lr*args.batch_size/512.0

In [25]:
args.lr = linear_scaled_lr
optimizer = create_optimizer(args, model_without_ddp)
loss_scaler = NativeScaler()
lr_scheduler, _ = create_scheduler(args, optimizer)
criterion = LabelSmoothingCrossEntropy()

In [26]:
if mixup_active:
    # smoothing is handled with mixup label transform
    criterion = SoftTargetCrossEntropy()
elif args.smoothing:
    criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
else:
    criterion = torch.nn.CrossEntropyLoss()

In [27]:
teacher_model = None

In [28]:
# wrap the criterion in our custom DistillationLoss, which
# just dispatches to the original criterion if args.distillation_type is 'none'
criterion = DistillationLoss(
    criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
)

In [42]:
output_dir = Path(args.output_dir)

In [30]:
if args.resume:
    if args.resume.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            args.resume, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(args.resume, map_location='cpu')
    model_without_ddp.load_state_dict(checkpoint['model'])
    if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1
        if args.model_ema:
            utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
        if 'scaler' in checkpoint:
            loss_scaler.load_state_dict(checkpoint['scaler'])
    lr_scheduler.step(args.start_epoch)

In [31]:
if args.eval:
    test_stats = evaluate(data_loader_val, model, device)
    print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")

In [32]:
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
max_accuracy = 0.0

Start training for 1 epochs


In [33]:
epoch = 0
train_stats = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, model_ema, mixup_fn,
            set_training_mode=args.finetune == ''  # keep in eval mode during finetuning
            )

Epoch: [0]  [   0/5004]  eta: 8:33:41  lr: 0.000001  loss: 6.9520 (6.9520)  time: 6.1593  data: 5.2085  max mem: 6466
Epoch: [0]  [  10/5004]  eta: 2:06:21  lr: 0.000001  loss: 6.9459 (6.9437)  time: 1.5181  data: 0.9482  max mem: 6531
Epoch: [0]  [  20/5004]  eta: 2:01:50  lr: 0.000001  loss: 6.9392 (6.9441)  time: 1.2322  data: 0.6977  max mem: 6531
Epoch: [0]  [  30/5004]  eta: 1:51:58  lr: 0.000001  loss: 6.9392 (6.9437)  time: 1.2587  data: 0.7233  max mem: 6531
Epoch: [0]  [  40/5004]  eta: 1:51:35  lr: 0.000001  loss: 6.9395 (6.9423)  time: 1.2249  data: 0.6839  max mem: 6531
Epoch: [0]  [  50/5004]  eta: 1:48:28  lr: 0.000001  loss: 6.9395 (6.9428)  time: 1.2568  data: 0.7132  max mem: 6531
Epoch: [0]  [  60/5004]  eta: 1:49:25  lr: 0.000001  loss: 6.9431 (6.9433)  time: 1.2852  data: 0.7407  max mem: 6531
Epoch: [0]  [  70/5004]  eta: 1:46:15  lr: 0.000001  loss: 6.9451 (6.9436)  time: 1.2364  data: 0.6907  max mem: 6531
Epoch: [0]  [  80/5004]  eta: 1:46:38  lr: 0.000001  los

Epoch: [0]  [ 700/5004]  eta: 1:28:31  lr: 0.000001  loss: 6.9234 (6.9337)  time: 1.2125  data: 0.6242  max mem: 6531
Epoch: [0]  [ 710/5004]  eta: 1:28:12  lr: 0.000001  loss: 6.9264 (6.9336)  time: 1.2319  data: 0.6425  max mem: 6531
Epoch: [0]  [ 720/5004]  eta: 1:28:06  lr: 0.000001  loss: 6.9264 (6.9335)  time: 1.2385  data: 0.6512  max mem: 6531
Epoch: [0]  [ 730/5004]  eta: 1:27:46  lr: 0.000001  loss: 6.9291 (6.9335)  time: 1.2213  data: 0.6331  max mem: 6531
Epoch: [0]  [ 740/5004]  eta: 1:27:42  lr: 0.000001  loss: 6.9287 (6.9334)  time: 1.2327  data: 0.6460  max mem: 6531
Epoch: [0]  [ 750/5004]  eta: 1:27:23  lr: 0.000001  loss: 6.9278 (6.9334)  time: 1.2493  data: 0.6641  max mem: 6531
Epoch: [0]  [ 760/5004]  eta: 1:27:15  lr: 0.000001  loss: 6.9289 (6.9333)  time: 1.2186  data: 0.6309  max mem: 6531
Epoch: [0]  [ 770/5004]  eta: 1:26:57  lr: 0.000001  loss: 6.9266 (6.9332)  time: 1.2185  data: 0.6297  max mem: 6531
Epoch: [0]  [ 780/5004]  eta: 1:26:50  lr: 0.000001  los

Epoch: [0]  [1400/5004]  eta: 1:13:29  lr: 0.000001  loss: 6.9210 (6.9281)  time: 1.2137  data: 0.6280  max mem: 6531
Epoch: [0]  [1410/5004]  eta: 1:13:13  lr: 0.000001  loss: 6.9215 (6.9280)  time: 1.1844  data: 0.6001  max mem: 6531
Epoch: [0]  [1420/5004]  eta: 1:13:02  lr: 0.000001  loss: 6.9163 (6.9279)  time: 1.1819  data: 0.5991  max mem: 6531
Epoch: [0]  [1430/5004]  eta: 1:12:47  lr: 0.000001  loss: 6.9103 (6.9278)  time: 1.2057  data: 0.6223  max mem: 6531
Epoch: [0]  [1440/5004]  eta: 1:12:39  lr: 0.000001  loss: 6.9186 (6.9278)  time: 1.2519  data: 0.6650  max mem: 6531
Epoch: [0]  [1450/5004]  eta: 1:12:24  lr: 0.000001  loss: 6.9197 (6.9277)  time: 1.2360  data: 0.6490  max mem: 6531
Epoch: [0]  [1460/5004]  eta: 1:12:14  lr: 0.000001  loss: 6.9191 (6.9276)  time: 1.2171  data: 0.6319  max mem: 6531
Epoch: [0]  [1470/5004]  eta: 1:11:58  lr: 0.000001  loss: 6.9205 (6.9276)  time: 1.2093  data: 0.6237  max mem: 6531
Epoch: [0]  [1480/5004]  eta: 1:11:50  lr: 0.000001  los

Epoch: [0]  [2100/5004]  eta: 0:58:55  lr: 0.000001  loss: 6.9200 (6.9240)  time: 1.2154  data: 0.6581  max mem: 6531
Epoch: [0]  [2110/5004]  eta: 0:58:40  lr: 0.000001  loss: 6.9210 (6.9240)  time: 1.2059  data: 0.6496  max mem: 6531
Epoch: [0]  [2120/5004]  eta: 0:58:29  lr: 0.000001  loss: 6.9112 (6.9239)  time: 1.1641  data: 0.6056  max mem: 6531
Epoch: [0]  [2130/5004]  eta: 0:58:15  lr: 0.000001  loss: 6.9110 (6.9239)  time: 1.1823  data: 0.6206  max mem: 6531
Epoch: [0]  [2140/5004]  eta: 0:58:05  lr: 0.000001  loss: 6.9138 (6.9238)  time: 1.2146  data: 0.6518  max mem: 6531
Epoch: [0]  [2150/5004]  eta: 0:57:50  lr: 0.000001  loss: 6.9079 (6.9237)  time: 1.1991  data: 0.6376  max mem: 6531
Epoch: [0]  [2160/5004]  eta: 0:57:39  lr: 0.000001  loss: 6.9120 (6.9237)  time: 1.1717  data: 0.6092  max mem: 6531
Epoch: [0]  [2170/5004]  eta: 0:57:25  lr: 0.000001  loss: 6.9136 (6.9236)  time: 1.1578  data: 0.5944  max mem: 6531
Epoch: [0]  [2180/5004]  eta: 0:57:14  lr: 0.000001  los

Epoch: [0]  [2800/5004]  eta: 0:44:24  lr: 0.000001  loss: 6.9104 (6.9208)  time: 1.1997  data: 0.6372  max mem: 6531
Epoch: [0]  [2810/5004]  eta: 0:44:11  lr: 0.000001  loss: 6.9104 (6.9208)  time: 1.1854  data: 0.6266  max mem: 6531
Epoch: [0]  [2820/5004]  eta: 0:44:00  lr: 0.000001  loss: 6.9095 (6.9207)  time: 1.1631  data: 0.6036  max mem: 6531
Epoch: [0]  [2830/5004]  eta: 0:43:47  lr: 0.000001  loss: 6.9095 (6.9207)  time: 1.1883  data: 0.6339  max mem: 6531
Epoch: [0]  [2840/5004]  eta: 0:43:35  lr: 0.000001  loss: 6.9104 (6.9207)  time: 1.2137  data: 0.6601  max mem: 6531
Epoch: [0]  [2850/5004]  eta: 0:43:22  lr: 0.000001  loss: 6.9022 (6.9206)  time: 1.1904  data: 0.6338  max mem: 6531
Epoch: [0]  [2860/5004]  eta: 0:43:11  lr: 0.000001  loss: 6.9022 (6.9205)  time: 1.1608  data: 0.6083  max mem: 6531
Epoch: [0]  [2870/5004]  eta: 0:42:58  lr: 0.000001  loss: 6.9059 (6.9205)  time: 1.1838  data: 0.6304  max mem: 6531
Epoch: [0]  [2880/5004]  eta: 0:42:46  lr: 0.000001  los

Epoch: [0]  [3500/5004]  eta: 0:30:10  lr: 0.000001  loss: 6.9038 (6.9181)  time: 1.1691  data: 0.6133  max mem: 6531
Epoch: [0]  [3510/5004]  eta: 0:29:58  lr: 0.000001  loss: 6.9022 (6.9181)  time: 1.1791  data: 0.6246  max mem: 6531
Epoch: [0]  [3520/5004]  eta: 0:29:46  lr: 0.000001  loss: 6.9056 (6.9180)  time: 1.1713  data: 0.6191  max mem: 6531
Epoch: [0]  [3530/5004]  eta: 0:29:33  lr: 0.000001  loss: 6.9081 (6.9180)  time: 1.1552  data: 0.6024  max mem: 6531
Epoch: [0]  [3540/5004]  eta: 0:29:22  lr: 0.000001  loss: 6.9068 (6.9180)  time: 1.1563  data: 0.6010  max mem: 6531
Epoch: [0]  [3550/5004]  eta: 0:29:09  lr: 0.000001  loss: 6.8986 (6.9179)  time: 1.1835  data: 0.6261  max mem: 6531
Epoch: [0]  [3560/5004]  eta: 0:28:58  lr: 0.000001  loss: 6.8996 (6.9179)  time: 1.2111  data: 0.6551  max mem: 6531
Epoch: [0]  [3570/5004]  eta: 0:28:45  lr: 0.000001  loss: 6.8984 (6.9178)  time: 1.2056  data: 0.6473  max mem: 6531
Epoch: [0]  [3580/5004]  eta: 0:28:33  lr: 0.000001  los

Epoch: [0]  [4200/5004]  eta: 0:16:05  lr: 0.000001  loss: 6.9035 (6.9156)  time: 1.1720  data: 0.6158  max mem: 6531
Epoch: [0]  [4210/5004]  eta: 0:15:52  lr: 0.000001  loss: 6.9011 (6.9156)  time: 1.1650  data: 0.6018  max mem: 6531
Epoch: [0]  [4220/5004]  eta: 0:15:41  lr: 0.000001  loss: 6.9034 (6.9155)  time: 1.2030  data: 0.6426  max mem: 6531
Epoch: [0]  [4230/5004]  eta: 0:15:28  lr: 0.000001  loss: 6.9035 (6.9155)  time: 1.2297  data: 0.6760  max mem: 6531
Epoch: [0]  [4240/5004]  eta: 0:15:17  lr: 0.000001  loss: 6.9044 (6.9155)  time: 1.1891  data: 0.6345  max mem: 6531
Epoch: [0]  [4250/5004]  eta: 0:15:04  lr: 0.000001  loss: 6.9042 (6.9154)  time: 1.1616  data: 0.6063  max mem: 6531
Epoch: [0]  [4260/5004]  eta: 0:14:53  lr: 0.000001  loss: 6.9036 (6.9154)  time: 1.2022  data: 0.6431  max mem: 6531
Epoch: [0]  [4270/5004]  eta: 0:14:40  lr: 0.000001  loss: 6.8983 (6.9153)  time: 1.2164  data: 0.6546  max mem: 6531
Epoch: [0]  [4280/5004]  eta: 0:14:29  lr: 0.000001  los

Epoch: [0]  [4900/5004]  eta: 0:02:04  lr: 0.000001  loss: 6.8976 (6.9135)  time: 1.2580  data: 0.6926  max mem: 6531
Epoch: [0]  [4910/5004]  eta: 0:01:52  lr: 0.000001  loss: 6.8965 (6.9135)  time: 1.2662  data: 0.7077  max mem: 6531
Epoch: [0]  [4920/5004]  eta: 0:01:40  lr: 0.000001  loss: 6.9050 (6.9135)  time: 1.2419  data: 0.6857  max mem: 6531
Epoch: [0]  [4930/5004]  eta: 0:01:28  lr: 0.000001  loss: 6.9006 (6.9134)  time: 1.2161  data: 0.6621  max mem: 6531
Epoch: [0]  [4940/5004]  eta: 0:01:16  lr: 0.000001  loss: 6.8988 (6.9134)  time: 1.1916  data: 0.6410  max mem: 6531
Epoch: [0]  [4950/5004]  eta: 0:01:04  lr: 0.000001  loss: 6.8988 (6.9134)  time: 1.1965  data: 0.6476  max mem: 6531
Epoch: [0]  [4960/5004]  eta: 0:00:52  lr: 0.000001  loss: 6.8963 (6.9133)  time: 1.1877  data: 0.6407  max mem: 6531
Epoch: [0]  [4970/5004]  eta: 0:00:40  lr: 0.000001  loss: 6.8963 (6.9133)  time: 1.1787  data: 0.6260  max mem: 6531
Epoch: [0]  [4980/5004]  eta: 0:00:28  lr: 0.000001  los

In [34]:
lr_scheduler.step(epoch)

In [43]:
if args.output_dir:
    checkpoint_paths = [output_dir / 'checkpoint.pth']
    for checkpoint_path in checkpoint_paths:
        utils.save_on_master({
            'model': model_without_ddp.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': epoch,
            'model_ema': get_state_dict(model_ema),
            'scaler': loss_scaler.state_dict(),
            'args': args,
        }, checkpoint_path)

In [56]:
test_stats = evaluate(data_loader_val, model, device)

Test:  [  0/131]  eta: 0:19:00  loss: 6.7376 (6.7376)  acc1: 2.6042 (2.6042)  acc5: 4.6875 (4.6875)  time: 8.7055  data: 8.1211  max mem: 6531
Test:  [ 10/131]  eta: 0:04:28  loss: 6.8595 (6.8519)  acc1: 0.0000 (0.4735)  acc5: 0.0000 (1.3494)  time: 2.2185  data: 1.8928  max mem: 6531
Test:  [ 20/131]  eta: 0:03:54  loss: 6.8579 (6.8560)  acc1: 0.0000 (0.4092)  acc5: 0.5208 (1.4013)  time: 1.7842  data: 1.4789  max mem: 6531
Test:  [ 30/131]  eta: 0:03:13  loss: 6.8882 (6.8773)  acc1: 0.0000 (0.3696)  acc5: 0.2604 (1.1341)  time: 1.7464  data: 1.4391  max mem: 6531
Test:  [ 40/131]  eta: 0:03:00  loss: 6.8899 (6.8758)  acc1: 0.0000 (0.3303)  acc5: 0.0000 (0.9909)  time: 1.8458  data: 1.5349  max mem: 6531
Test:  [ 50/131]  eta: 0:02:35  loss: 6.8448 (6.8673)  acc1: 0.0000 (0.2859)  acc5: 0.0000 (1.0672)  time: 1.9316  data: 1.6164  max mem: 6531
Test:  [ 60/131]  eta: 0:02:16  loss: 6.8422 (6.8655)  acc1: 0.0000 (0.3671)  acc5: 0.7812 (1.2893)  time: 1.7852  data: 1.4639  max mem: 6531

In [57]:
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")

Accuracy of the network on the 50000 test images: 0.4%


In [58]:
if max_accuracy < test_stats["acc1"]:
    max_accuracy = test_stats["acc1"]
    if args.output_dir:
        checkpoint_paths = [output_dir / 'best_checkpoint.pth']
        for checkpoint_path in checkpoint_paths:
            utils.save_on_master({
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'model_ema': get_state_dict(model_ema),
                'scaler': loss_scaler.state_dict(),
                'args': args,
            }, checkpoint_path)

In [59]:
print(f'Max accuracy: {max_accuracy:.2f}%')

Max accuracy: 0.36%


In [60]:
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
             **{f'test_{k}': v for k, v in test_stats.items()},
             'epoch': epoch,
             'n_parameters': n_parameters}

In [65]:
if args.output_dir and utils.is_main_process():
    with (output_dir / "log.txt").open("a") as f:
        f.write(json.dumps(log_stats) + "\n")

In [68]:
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))

Training time 2:02:17
