In [1]:
from __future__ import print_function, division

import os
import copy
import time
import random
import argparse
import numpy as np
from urllib import request
from zipfile import ZipFile
from functools import partial


import timm
import torch, torchvision

from torch import nn
from timm.loss import BinaryCrossEntropy, LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

from torchvision import datasets, models
from torchvision.transforms import autoaugment, transforms
from torchvision.transforms.functional import InterpolationMode

from torch.utils.data import random_split
from torch.utils.data.dataloader import default_collate
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

import transforms as T
import utils

from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining, pb2


import simpleargs

In [2]:
seed = 99
random.seed(seed)
np.random.seed(seed)
torch.random.manual_seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f0ee014d950>

In [3]:
def load_data(data_dir='./hymenoptera_data'):
    augment = [transforms.RandomResizedCrop(size=224, interpolation=InterpolationMode.BILINEAR),
               transforms.RandomHorizontalFlip()]
    ######### Recipe 2 #########           
    if cfg.num_ops:
        augment.append(autoaugment.RandAugment(num_ops=cfg.num_ops, magnitude=cfg.magnitude, num_magnitude_bins=cfg.num_magnitude_bins, interpolation=cfg.interpolation))
        
    if cfg.num_magnitude_bins:
        augment.append(autoaugment.TrivialAugmentWide(num_magnitude_bins=cfg.num_magnitude_bins, interpolation=cfg.interpolation))

    augment.extend([
            transforms.PILToTensor(),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),])  
    ######### Recipe 4 #########
    if cfg.random_erase_prob:
        augment.append(transforms.RandomErasing(p=cfg.random_erase_prob))
    
    ######### Recipe 9 #########
    if cfg.train_crop_size:
        augment[0] = transforms.RandomResizedCrop(size=cfg.train_crop_size, interpolation=cfg.interpolate)  

    augment = transforms.Compose(augment)
    
    valid_augment = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.PILToTensor(),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
                       
    trainset = datasets.ImageFolder(root=os.path.join(data_dir, 'train'), transform=augment)
    validset = datasets.ImageFolder(root=os.path.join(data_dir, 'val'), transform=valid_augment)
    num_classes = len(trainset.classes)
    return trainset, validset, num_classes

In [4]:
def train(checkpoint_dir=None, data_dir=None):
    seed = 99
    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    
    trainset, validset, num_classes = load_data(data_dir=data_dir)
    
    model = torch.hub.load("facebookresearch/deit:main", "deit_"+cfg.model+"_patch16_224", pretrained=True)
    for param in model.parameters():
        param.requires_grad = False
    
    num_ftrs = model.head.in_features
    model.head = nn.Linear(num_ftrs, num_classes)
    
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
    model.to(device)
    
    ######### Recipe 10 #########
    ema_model = None
    if (cfg.model_ema_decay and cfg.model_ema_steps):
        adjust = 1 * cfg.batch_size * cfg.model_ema_steps / cfg.epochs
        alpha = 1.0 - cfg.model_ema_decay
        alpha = min(1.0, alpha * adjust)
        ema_model = utils.ExponentialMovingAverage(model, device=device, decay=1.0 - alpha)
        
    
    parameters = model.head.parameters()
    
    ######### Recipe 8 #########                   
    if cfg.weight_decay:
        norm_weight_decay=0.0
        param_groups = utils.split_normalization_params(model)
        wd_groups = [norm_weight_decay, cfg.weight_decay]
        parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
    
    
    ######### Recipe 1 #########
    if cfg.lr:
        optimizer = torch.optim.AdamW(parameters, lr=cfg.lr)
        main_lr_scheduler = CosineAnnealingLR(optimizer, T_max=cfg.epochs - cfg.lr_warmup_epochs)
        warmup_lr_scheduler = LinearLR(optimizer, start_factor=cfg.lr_warmup_decay, total_iters=cfg.lr_warmup_epochs)
        lr_scheduler = SequentialLR(optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[cfg.lr_warmup_epochs])
        
    ######### Recipe 5 #########
    if cfg.mixup_alpha or cfg.cutmix_alpha:
        if cfg.bce:
            train_criterion = BinaryCrossEntropy(smoothing=0.0)
        else:
            train_criterion = SoftTargetCrossEntropy()
    elif cfg.smooth:
        if cfg.bce:
            train_criterion = BinaryCrossEntropy(smoothing=cfg.smooth)
        else:
            train_criterion = LabelSmoothingCrossEntropy(smoothing=cfg.smooth)
    else:
        train_criterion = nn.CrossEntropyLoss()
    train_criterion = train_criterion.to(device)


    valid_criterion = nn.CrossEntropyLoss().to(device)
    
    collate_fn = None
    mixup_transforms = []
    ######### Recipe 6 #########
    if cfg.mixup_alpha:
        mixup_transforms.append(T.RandomMixup(num_classes, p=1.0, alpha=mixup_alpha))
    ######### Recipe 7 #########    
    if cfg.cutmix_alpha:
        mixup_transforms.append(T.RandomCutmix(num_classes, p=1.0, alpha=cfg.cutmix_alpha))
    if mixup_transforms:
        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
        collate_fn = lambda batch: mixupcutmix(*default_collate(batch))
        
        
    if checkpoint_dir:
        checkpoints = torch.load(os.path.join(checkpoint_dir, "checkpoint"))
        model.load_state_dict(checkpoints["model_state"])
        optimizer.load_state_dict(checkpoints["optimizer_state"])
        lr_scheduler.load_state_dict(checkpoints["lr_scheduler_state"])
        if ema_model:
            ema_model.load_state_dict(checkpoints["model_ema"])
    
    
    train_loader = torch.utils.data.DataLoader(
        trainset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        collate_fn=collate_fn)
    
    valid_loader = torch.utils.data.DataLoader(
        validset,
        batch_size=32,
        shuffle=False,
        num_workers=2,
        pin_memory=True)
    
    
    best=0.0
    for epoch in range(cfg.epochs): 
        train_one_epoch(epoch, train_loader, model, ema_model, optimizer, train_criterion, device=device)

        lr_scheduler.step()
        if ema_model:
            valid_loss, accuracy = validate(valid_loader, ema_model, valid_criterion, device=device)
        else:
            valid_loss, accuracy = validate(valid_loader, model, valid_criterion, device=device)
        
          
            
        print(f"Epoch: {epoch} Valid loss:{valid_loss:4.4f} Valid accuracy: {accuracy:4.4f}")
        
        if checkpoint_dir is not None and not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir+"/checkpoint")
        if checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            ckpts = {"model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "lr_scheduler_state": lr_scheduler.state_dict()}
            if ema_model:
                ckpts["ema_state"] = ema_model.state_dict()
            torch.save(ckpts, path)
            
    
    print("Finished Training")
    

In [5]:
def train_one_epoch(epoch, train_loader, model, ema_model, optimizer, criterion, device="cpu"):
    running_loss = 0.0
    epoch_steps = 0
    model.train()
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        epoch_steps += 1
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / epoch_steps))
            running_loss = 0.0

        if ema_model and i % cfg.model_ema_steps == 0:
            ema_model.update_parameters(model)
            if epoch < cfg.lr_warmup_epochs:
                # Reset ema buffer to keep copying weights during warmup period
                ema_model.n_averaged.fill_(0)


def validate(valid_loader, model, criterion, device="cpu"):
    valid_loss = 0.0
    valid_steps = 0
    total = 0
    correct = 0
    model.eval()
    with torch.inference_mode():
        for inputs, labels in valid_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            loss = criterion(outputs, labels)
            valid_loss += loss.cpu().numpy()
            valid_steps += 1

    valid_loss = valid_loss / valid_steps
    accuracy = correct / total
    return  valid_loss, accuracy
   

In [6]:
cfg = simpleargs
data_dir="./hymenoptera_data/"
# cfg.seed=99
cfg.model="tiny"
cfg.epochs=25
###### scheduler type
cfg.bce=True
###### pass to use optimized hparam  
cfg.batch_size=64
cfg.lr=0.00491
cfg.lr_warmup_epochs=6
cfg.lr_warmup_decay=3.7999999e-05
cfg.weight_decay=None
cfg.smooth=None
cfg.mixup_alpha=None
cfg.cutmix_alpha=None
cfg.random_erase_prob=None
cfg.model_ema_steps=None
cfg.model_ema_decay=None
cfg.train_crop_size=None
cfg.interpolation=None
cfg.num_ops=None
cfg.magnitude=None
cfg.num_magnitude_bins=None

In [7]:
train(data_dir=data_dir)

Using cache found in /home/enoch/.cache/torch/hub/facebookresearch_deit_main


Epoch: 0 Valid loss:0.6928 Valid accuracy: 0.6340


AttributeError: 'collections.OrderedDict' object has no attribute 'train'