In [1]:
import os
# Export CUDA_VISIBLE_DEVICES before torch is imported further down in this cell.
# NOTE(review): presumably limits the process to physical GPUs 4-7 -- confirm on the target host.
os.environ["CUDA_VISIBLE_DEVICES"] = '4,5,6,7'

import pandas as pd
import numpy as np

import cv2
from torch.utils.data import Dataset
from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)
from albumentations.pytorch import ToTensorV2


from torch.utils.data import Dataset,DataLoader
import torch
from torch import nn
from torch.nn.modules.loss import _WeightedLoss
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
from model import CassvaImgClassifier
import torch.nn.functional as F
from utils import seed_everything
from sklearn.model_selection import StratifiedKFold


from albumentations.core.transforms_interface import ImageOnlyTransform
import albumentations
from PIL import Image, ImageOps, ImageEnhance
from torchvision import transforms
from tqdm import tqdm_notebook as tqdm
import apex

In [3]:
# Seed through the project helper imported from utils. The literal 2021 duplicates
# CFG['seed'], which is only defined in the next cell.
# NOTE(review): which RNGs seed_everything covers is not visible here -- confirm.
seed_everything(2021)

In [4]:
# Run configuration shared by the cells below.
# Keys marked "not read" are not referenced by any cell visible in this notebook.
CFG = {
    'test_fold': 0,  # fold held out as the validation split
    'seed': 2021,  # not read: seed_everything is called with a literal 2021
    'model_arch': 'tf_efficientnet_b4_ns',  # passed to CassvaImgClassifier
    'log_file': "/home/samenko/Cassava/logs/augmix_cosLR_Parallel4_Adam_albAugs.log",  # valid_one_epoch appends to it
    'img_size': 512,  # crop/resize edge length used by TRAIN_AUGS and TEST_AUGS
    'epochs': 20,  # number of training epochs; also used as the scheduler's T_0
    'train_bs': 8*4,  # NOTE(review): presumably 8 samples per GPU x 4 GPUs -- confirm
    'valid_bs': 32,
    'T_0': 10,  # not read: the scheduler is built with T_0=CFG['epochs']
    'lr': 1e-4,  # Adam learning rate (an earlier value of 0.1 was left as a comment here)
    'min_lr': 1e-6,  # eta_min of the cosine scheduler
    'weight_decay':1e-6,
    'num_workers': 8,  # DataLoader workers for both loaders
    'accum_iter': 2, # intended to support gradient accumulation for a larger effective batch; not read
    'verbose_step': 1,  # progress-bar refresh interval in valid_one_epoch
    'device': 'cuda:0',  # not read: the notebook builds torch.device('cuda') directly
    'fp16': True,  # not read: the apex.amp initialisation is commented out
    'print_freq':1  # only referenced from a commented-out print in the training loop
}

In [5]:
# Load the training manifest and tag every row with a stratified fold id (0-4).
data = pd.read_csv('/home/data/Cassava/train.csv')
data['fold'] = 0
strkf = StratifiedKFold(n_splits=5)
fold_col = data.columns.get_loc('fold')
# split() yields *positional* row indices, so write through iloc instead of
# matching them against index labels; enumerate replaces the manual counter.
for fold_id, (_train_index, test_index) in enumerate(strkf.split(data.image_id, data.label)):
    data.iloc[test_index, fold_col] = fold_id

# Hold out CFG['test_fold'] for validation and train on the remaining folds.
train_data = data[data.fold != CFG['test_fold']].reset_index(drop=True)
val_data = data[data.fold == CFG['test_fold']].reset_index(drop=True)

In [6]:
def get_img(path):
    """Read an image file with OpenCV and return it with the channel axis reversed.

    Args:
        path: Location of the image file on disk.

    Returns:
        The array produced by ``cv2.imread`` with its last axis reversed
        (BGR -> RGB). The result is a view of the decoded array, not a copy.

    Raises:
        FileNotFoundError: If ``cv2.imread`` returns ``None``, which it does
            instead of raising when the file is missing or cannot be decoded.
    """
    im_bgr = cv2.imread(path)
    if im_bgr is None:
        # Fail here with the offending path rather than with an opaque
        # "'NoneType' object is not subscriptable" on the slice below.
        raise FileNotFoundError(f"cv2.imread could not read image: {path!r}")
    return im_bgr[:, :, ::-1]

class CassavaDataset(Dataset):
    """Map-style dataset yielding ``(PIL RGB image, label)`` pairs.

    Args:
        df: DataFrame with an ``image_id`` column (file names) and a ``label``
            column. Rows are addressed by position, so any index is accepted.
        transforms: Callable invoked as ``transforms(image=array)`` that
            returns a mapping whose ``'image'`` entry is the transformed array
            (the albumentations calling convention).
        img_dir: Directory containing the image files. Defaults to the
            training-image directory this notebook was written against.
    """

    def __init__(self, df, transforms, img_dir='/home/data/Cassava/train_images/'):
        super().__init__()
        self.df = df
        self.transforms = transforms
        self.img_dir = img_dir

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Positional scalar lookups: O(1) per sample instead of building a
        # boolean mask over the whole frame, and independent of index labels.
        image_name = self.df['image_id'].iat[idx]
        label = self.df['label'].iat[idx]
        img = get_img(os.path.join(self.img_dir, image_name))
        img = self.transforms(image=img)['image']
        # The AugMix operations downstream work on PIL images, hence the conversion.
        img = Image.fromarray(img.astype('uint8'), 'RGB')
        return img, label

In [7]:
# Training-time albumentations pipeline. Normalisation and tensor conversion
# are deliberately absent: CassavaDataset turns the result back into a PIL
# image and the torchvision `preprocess` transform handles both afterwards.
TRAIN_AUGS = Compose([
    RandomResizedCrop(CFG['img_size'], CFG['img_size']),
    Transpose(p=0.5),
    HorizontalFlip(p=0.5),
    VerticalFlip(p=0.5),
    ShiftScaleRotate(p=0.5),
    HueSaturationValue(hue_shift_limit=20, sat_shift_limit=20, val_shift_limit=20, p=0.5),
    RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
    CoarseDropout(p=0.5),
    Cutout(p=0.5),
], p=1.)


# Evaluation pipeline: centre crop followed by a resize to the same size.
TEST_AUGS = Compose([
    CenterCrop(CFG['img_size'], CFG['img_size'], p=1.0),
    Resize(CFG['img_size'], CFG['img_size']),
], p=1.)

# Validation data must not go through the random training augmentations.
val_ds = CassavaDataset(val_data, TEST_AUGS)

In [8]:
# Final per-image step applied to the PIL images produced by the dataset:
# convert to a tensor, then normalise each channel with the mean/std below.
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

In [9]:
import augmentations
class AugMixDataset(torch.utils.data.Dataset):
    """Wrap a ``(image, target)`` dataset so samples come out in AugMix form.

    Depending on the flags, an item is one of:
      * ``val_mode=True``  -> ``(preprocess(image), target)``
      * ``no_jsd=True``    -> ``(aug(image, preprocess), target)``
      * otherwise          -> ``((clean, aug_1, aug_2), target)``

    ``val_mode`` takes precedence over ``no_jsd``.
    """

    def __init__(self, dataset, preprocess, no_jsd=False, val_mode = False):
        self.dataset = dataset
        self.preprocess = preprocess
        self.no_jsd = no_jsd
        self.val_mode = val_mode

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        image, target = self.dataset[i]

        # Validation: only the deterministic preprocessing, no AugMix.
        if self.val_mode:
            return self.preprocess(image), target

        if not self.no_jsd:
            # Clean view plus two independently augmented views.
            clean = self.preprocess(image)
            first_view = aug(image, self.preprocess)
            second_view = aug(image, self.preprocess)
            return (clean, first_view, second_view), target

        # Single augmented view.
        return aug(image, self.preprocess), target

def aug(image, preprocess):
    """Perform AugMix augmentations and compute the mixture.

    Reads its settings from the notebook-level ``args`` dict (``all_ops``,
    ``mixture_width``, ``mixture_depth``, ``aug_severity``) and the operation
    lists from the ``augmentations`` module.

    Args:
      image: Input image; must provide ``.copy()`` (a PIL.Image in this notebook).
      preprocess: Preprocessing function which should return a torch tensor.

    Returns:
      mixed: Augmented and mixed image, ``(1 - m) * clean + m * mix`` with
        ``m ~ Beta(1, 1)`` and ``mix`` a Dirichlet-weighted sum of the
        preprocessed augmentation chains.
    """
    aug_list = augmentations.augmentations
    if args['all_ops']:
        aug_list = augmentations.augmentations_all

    ws = np.float32(np.random.dirichlet([1] * args['mixture_width']))
    m = np.float32(np.random.beta(1, 1))

    # Preprocess the clean image once and reuse it for both the accumulator
    # template and the final blend instead of running the transform twice.
    clean = preprocess(image)
    mix = torch.zeros_like(clean)
    for i in range(args['mixture_width']):
        image_aug = image.copy()
        # A non-positive configured depth means "sample 1-3 ops per chain".
        depth = args['mixture_depth'] if args['mixture_depth'] > 0 else np.random.randint(
            1, 4)
        for _ in range(depth):
            op = np.random.choice(aug_list)
            image_aug = op(image_aug, args['aug_severity'])
        # Preprocessing commutes since all coefficients are convex
        mix += ws[i] * preprocess(image_aug)

    mixed = (1 - m) * clean + m * mix
    return mixed

In [10]:
# AugMix settings read by aug().
# mixture_depth is -1 so that aug() draws a fresh depth in [1, 3] for every
# augmentation chain. Drawing it here with np.random.randint froze one depth
# for the whole run and made the value change whenever this cell was re-run.
args ={"all_ops":True,
      "mixture_width":3,
      "mixture_depth":-1,
       "aug_severity":3
      }

In [11]:
# Base datasets: random augmentations for training, deterministic ones for validation.
train_ds = CassavaDataset(df=train_data, transforms=TRAIN_AUGS)
val_ds = CassavaDataset(df=val_data, transforms=TEST_AUGS)

# AugMix wrappers: training yields (clean, aug1, aug2) triples for the JSD
# loss, validation only runs `preprocess` on each image.
am_train_data = AugMixDataset(dataset=train_ds, preprocess=preprocess, no_jsd=False)
am_val_data = AugMixDataset(dataset=val_ds, preprocess=preprocess, no_jsd=False, val_mode=True)

In [12]:
# train_ds = CassavaDataset(train_data, train_transform)
# val_ds = CassavaDataset(val_data, val_transform)
# am_train_data = AugMixDataset(train_ds, preprocess, no_jsd=False)

In [13]:
# Shuffled loader over the AugMix training set.
train_loader = DataLoader(
    am_train_data,
    batch_size=CFG['train_bs'],
    shuffle=True,
    num_workers=CFG['num_workers'],
    pin_memory=True,
)

# Ordered loader over the (non-augmented) AugMix validation wrapper.
val_loader = DataLoader(
    am_val_data,
    batch_size=CFG['valid_bs'],
    shuffle=False,
    num_workers=CFG['num_workers'],
    pin_memory=True,
)

In [14]:
# Build the classifier with one output per distinct label, wrap it in
# DataParallel and move it to CUDA.
device = torch.device('cuda')
model = nn.DataParallel(
    CassvaImgClassifier(CFG['model_arch'], data.label.nunique(), pretrained=True)
).cuda()

In [15]:
# Adam over every model parameter, configured from CFG.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=CFG['lr'],
    weight_decay=CFG['weight_decay'],
)

# Mixed precision through apex is currently switched off:
# model, optimizer = apex.amp.initialize(
#                 model,
#                 optimizer,
#                 opt_level='O1')

# Cosine annealing with warm restarts; the first cycle length is CFG['epochs'].
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=CFG['epochs'],
    T_mult=1,
    eta_min=CFG['min_lr'],
    last_epoch=-1,
)

# Plain cross-entropy, used for the validation loss.
loss_val = nn.CrossEntropyLoss().to(device)

In [17]:
def valid_one_epoch(model, val_loader, loss_fn, epoch):
    """Run one validation pass, print and log the accuracy.

    Relies on the notebook-level globals ``device``, ``CFG`` (keys
    ``verbose_step`` and ``log_file``) and ``tqdm``.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        val_loader: Sized iterable of ``(inputs, targets)`` batches.
        loss_fn: Callable ``loss_fn(logits, targets)`` returning a scalar tensor.
        epoch: Epoch number, used only in the progress bar and the log line.

    Returns:
        ``(accuracy, mean_loss)`` over all validation samples.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    loss_sum = 0
    sample_num = 0
    preds_all = []
    targets_all = []
    num_batches = len(val_loader)
    pbar = tqdm(enumerate(val_loader), total=num_batches)
    model = model.eval()
    with torch.no_grad():
        for step, (x, y_true) in pbar:
            x = x.to(device).float()
            y_true = y_true.to(device).long()
            y_pred = model(x)
            preds_all.append(torch.argmax(y_pred, 1).cpu().numpy())
            targets_all.append(y_true.cpu().numpy())
            batch_size = y_true.shape[0]
            # Weight by batch size so the running value is a per-sample mean.
            loss_sum += loss_fn(y_pred, y_true).item() * batch_size
            sample_num += batch_size
            # Refresh inside the loop so the bar shows the running loss while
            # validation is in progress, not only once after the last batch.
            if ((step + 1) % CFG['verbose_step'] == 0) or ((step + 1) == num_batches):
                description = f'val epoch {epoch} loss: {loss_sum / sample_num:.4f}'
                pbar.set_description(description)
    if sample_num == 0:
        raise ValueError('val_loader yielded no samples; nothing to validate')
    preds_all = np.concatenate(preds_all)
    targets_all = np.concatenate(targets_all)
    accuracy = (preds_all == targets_all).mean()
    print('validation multi-class accuracy = {:.4f}'.format(accuracy))
    with open(CFG['log_file'], 'a+') as logger:
        logger.write(f"Epoch: {epoch} val acc = {accuracy}\n")
    return accuracy, loss_sum / sample_num

In [None]:
# Training loop: AugMix cross-entropy + Jensen-Shannon consistency loss,
# followed by one validation pass per epoch.
for epoch in range(CFG['epochs']):

    model = model.train()
    loss_ema = 0.
    num_batches = len(train_loader)
    pbar = tqdm(enumerate(train_loader), total=num_batches, position=0, leave=True)
    for step, (images, targets) in pbar:
        optimizer.zero_grad()
        # AugMix: each batch is a (clean, aug1, aug2) triple; run all three
        # through the model in one forward pass and split the logits again.
        images_all = torch.cat(images, 0).cuda()
        targets = targets.cuda()
        logits_all = model(images_all)
        logits_clean, logits_aug1, logits_aug2 = torch.split(logits_all, images[0].size(0))
        # Supervised loss on the clean view only.
        loss = F.cross_entropy(logits_clean, targets)
        p_clean = F.softmax(logits_clean, dim=1)
        p_aug1 = F.softmax(logits_aug1, dim=1)
        p_aug2 = F.softmax(logits_aug2, dim=1)
        # Clamp mixture distribution to avoid exploding KL divergence
        p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7, 1).log()
        # Jensen-Shannon consistency term across the three views, weighted by 12.
        loss += 12 * (F.kl_div(p_mixture, p_clean, reduction='batchmean') +
                      F.kl_div(p_mixture, p_aug1, reduction='batchmean') +
                      F.kl_div(p_mixture, p_aug2, reduction='batchmean')) / 3.
        loss.backward()
        optimizer.step()
        # The scheduler's T_0 is expressed in epochs, so advance it by a
        # fractional epoch per batch. A bare step() per batch would restart
        # the cosine cycle every T_0 *iterations* instead.
        scheduler.step(epoch + step / num_batches)
        loss_value = loss.item()
        loss_ema = loss_ema * 0.9 + loss_value * 0.1
        description = f'tain epoch {epoch} loss_ema: {loss_ema:.4f} loss: {loss_value:.4f}'
        pbar.set_description(description)

    valid_one_epoch(model, val_loader, loss_val, epoch)
    

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  """


HBox(children=(FloatProgress(value=0.0, max=535.0), HTML(value='')))