In [1]:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torch.distributions import Dirichlet
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import os

In [2]:
from typing import List

import torch as ch
import torchvision

from ffcv.fields import IntField, RGBImageField, FloatField
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder, FloatDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.pipeline.operation import Operation
from ffcv.transforms import RandomHorizontalFlip, Cutout, \
    RandomTranslate, Convert, ToDevice, ToTensor, ToTorchImage
from ffcv.transforms.common import Squeeze
from ffcv.writer import DatasetWriter

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [4]:
from torch.utils.data import Dataset
class myCIFAR10_train():
    def __init__(self):
        """
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        """
        
        train_dataset = torchvision.datasets.CIFAR10(root = './data', train = True, 
                                                     #transform = transform_train, 
                                                     #transform = transform_tensor,
                                                     download = True)
        alpha = torch.ones(len(train_dataset))#.to(device)   
        dirichlet = Dirichlet(alpha)
        weights = dirichlet.sample()
        
        self.dataset = train_dataset
        self.weights = weights
    
    def __getitem__(self, idx):
        data_tuple = self.dataset[idx]
        weights = self.weights[idx]
        return (data_tuple[0], data_tuple[1], float(weights.numpy()))
    
    def __len__(self):
        return len(self.dataset)

class myCIFAR10_test():
    def __init__(self):
        """
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        """
        
        test_dataset = torchvision.datasets.CIFAR10(root = './data', train = False, 
                                                     #transform = transform_test, 
                                                     download = True)
        alpha = torch.ones(len(test_dataset))#.to(device)   
        dirichlet = Dirichlet(alpha)
        weights = dirichlet.sample()
        
        self.dataset = test_dataset
        self.weights = weights
    
    def __getitem__(self, idx):
        data_tuple = self.dataset[idx]
        weights = self.weights[idx]
        return (data_tuple[0], data_tuple[1], float(weights.numpy()))
    
    def __len__(self):
        return len(self.dataset)

In [5]:
batch_size_train = 128
batch_size_test = 100
train_data = myCIFAR10_train()
test_data = myCIFAR10_test()
# train_loader = torch.utils.data.DataLoader(dataset= train_data, batch_size = batch_size_train, shuffle = True)
# test_loader = torch.utils.data.DataLoader(dataset= test_data, batch_size = batch_size_test, shuffle = False)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# train_data[0]
print(type(train_data[0][0]))
print(type(train_data[0][1]))
print(type(train_data[0][2]))
#print(train_data[0][2].shape)
#print(train_data[0][2].dtype)

<class 'PIL.Image.Image'>
<class 'int'>
<class 'float'>


In [7]:
# test_data[0]
print(type(test_data[0][0]))
print(type(test_data[0][1]))
print(type(test_data[0][2]))
#print(test_data[0][2].shape)
#print(test_data[0][2].dtype)

<class 'PIL.Image.Image'>
<class 'int'>
<class 'float'>


In [8]:
datasets = {
    'train': train_data,
    'test': test_data
}

for (name, ds) in datasets.items():
    writer = DatasetWriter(f'/tmp/cifar_{name}.beton', {
        'image': RGBImageField(),
        'label': IntField(),
        'weight': FloatField(),
    })
    writer.from_indexed_dataset(ds)

100%|██████████| 50000/50000 [00:00<00:00, 123871.28it/s]
100%|██████████| 10000/10000 [00:00<00:00, 33013.82it/s]


In [9]:
# Note that statistics are wrt to uin8 range, [0,255].
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]

BATCH_SIZE = 512

loaders = {}
for name in ['train', 'test']:
    label_pipeline: List[Operation] = [IntDecoder(), ToTensor(), ToDevice('cuda:0'), Squeeze()]
    image_pipeline: List[Operation] = [SimpleRGBImageDecoder()]
    weight_pipeline: List[Operation] = [FloatDecoder(), ToTensor(), ToDevice('cuda:0'), Squeeze()]

    # Add image transforms and normalization
    if name == 'train':
        image_pipeline.extend([
            RandomHorizontalFlip(),
            RandomTranslate(padding=2),
            Cutout(8, tuple(map(int, CIFAR_MEAN))), # Note Cutout is done before normalization.
        ])
    image_pipeline.extend([
        ToTensor(),
        ToDevice('cuda:0', non_blocking=True),
        ToTorchImage(),
        Convert(ch.float16),
        torchvision.transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])

    # Create loaders
    loaders[name] = Loader(f'/tmp/cifar_{name}.beton',
                            batch_size=BATCH_SIZE,
                            num_workers=8,
                            order=OrderOption.RANDOM,
                            drop_last=(name == 'train'),
                            pipelines={'image': image_pipeline,
                                       'label': label_pipeline,
                                       'weight': weight_pipeline
                                       })

In [10]:
class Mul(ch.nn.Module):
    def __init__(self, weight):
        super(Mul, self).__init__()
        self.weight = weight
    def forward(self, x): return x * self.weight

class Flatten(ch.nn.Module):
    def forward(self, x): return x.view(x.size(0), -1)

class Residual(ch.nn.Module):
    def __init__(self, module):
        super(Residual, self).__init__()
        self.module = module
    def forward(self, x): return x + self.module(x)

def conv_bn(channels_in, channels_out, kernel_size=3, stride=1, padding=1, groups=1):
    return ch.nn.Sequential(
            ch.nn.Conv2d(channels_in, channels_out,
                         kernel_size=kernel_size, stride=stride, padding=padding,
                         groups=groups, bias=False),
            ch.nn.BatchNorm2d(channels_out),
            ch.nn.ReLU(inplace=True)
    )

NUM_CLASSES = 10
model = ch.nn.Sequential(
    conv_bn(3, 64, kernel_size=3, stride=1, padding=1),
    conv_bn(64, 128, kernel_size=5, stride=2, padding=2),
    Residual(ch.nn.Sequential(conv_bn(128, 128), conv_bn(128, 128))),
    conv_bn(128, 256, kernel_size=3, stride=1, padding=1),
    ch.nn.MaxPool2d(2),
    Residual(ch.nn.Sequential(conv_bn(256, 256), conv_bn(256, 256))),
    conv_bn(256, 128, kernel_size=3, stride=1, padding=0),
    ch.nn.AdaptiveMaxPool2d((1, 1)),
    Flatten(),
    ch.nn.Linear(128, NUM_CLASSES, bias=False),
    Mul(0.2)
)
model = model.to(memory_format=ch.channels_last).cuda()

In [11]:
import numpy as np
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, lr_scheduler

EPOCHS = 24

opt = SGD(model.parameters(), lr=.5, momentum=0.9, weight_decay=5e-4)
iters_per_epoch = 50000 // BATCH_SIZE
lr_schedule = np.interp(np.arange((EPOCHS+1) * iters_per_epoch),
                        [0, 5 * iters_per_epoch, EPOCHS * iters_per_epoch],
                        [0, 1, 0])
scheduler = lr_scheduler.LambdaLR(opt, lr_schedule.__getitem__)
scaler = GradScaler()
loss_fn = CrossEntropyLoss(label_smoothing=0.1)

In [12]:
from tqdm import tqdm

for ep in range(EPOCHS):

    # calculate accuracy
    model.eval()
    if ep%4==0:
        with ch.no_grad():
            total_correct, total_num = 0., 0.
            for ims, labs, w in tqdm(loaders['test']):
                with autocast():
                    out = (model(ims) + model(ch.fliplr(ims))) / 2. # Test-time augmentation
                    total_correct += out.argmax(1).eq(labs).sum().cpu().item()
                    total_num += ims.shape[0]
            print(f'Accuracy: {total_correct / total_num * 100:.1f}%')

    # model training
    model.train()
    for ims, labs, w in tqdm(loaders['train']):
        opt.zero_grad(set_to_none=True)
        with autocast():
            out = model(ims)
            loss = loss_fn(out, labs)

        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        scheduler.step()    

model.eval()


100%|██████████| 20/20 [00:03<00:00,  6.12it/s]


Accuracy: 8.8%


100%|██████████| 97/97 [00:11<00:00,  8.10it/s]
100%|██████████| 97/97 [00:09<00:00, 10.60it/s]
100%|██████████| 97/97 [00:09<00:00, 10.60it/s]
100%|██████████| 97/97 [00:09<00:00, 10.46it/s]
100%|██████████| 20/20 [00:01<00:00, 17.08it/s]


Accuracy: 66.8%


100%|██████████| 97/97 [00:09<00:00, 10.61it/s]
100%|██████████| 97/97 [00:09<00:00, 10.62it/s]
100%|██████████| 97/97 [00:09<00:00, 10.48it/s]
100%|██████████| 97/97 [00:09<00:00, 10.60it/s]
100%|██████████| 20/20 [00:01<00:00, 17.25it/s]


Accuracy: 72.5%


100%|██████████| 97/97 [00:09<00:00, 10.47it/s]
100%|██████████| 97/97 [00:09<00:00, 10.55it/s]
100%|██████████| 97/97 [00:09<00:00, 10.65it/s]
100%|██████████| 97/97 [00:09<00:00, 10.64it/s]
100%|██████████| 20/20 [00:01<00:00, 16.96it/s]


Accuracy: 80.7%


100%|██████████| 97/97 [00:09<00:00, 10.60it/s]
100%|██████████| 97/97 [00:09<00:00, 10.44it/s]
100%|██████████| 97/97 [00:09<00:00, 10.43it/s]
100%|██████████| 97/97 [00:09<00:00, 10.52it/s]
100%|██████████| 20/20 [00:01<00:00, 16.92it/s]


Accuracy: 82.1%


100%|██████████| 97/97 [00:09<00:00, 10.62it/s]
100%|██████████| 97/97 [00:09<00:00, 10.55it/s]
100%|██████████| 97/97 [00:09<00:00, 10.56it/s]
100%|██████████| 97/97 [00:09<00:00, 10.50it/s]
100%|██████████| 20/20 [00:01<00:00, 16.85it/s]


Accuracy: 88.1%


100%|██████████| 97/97 [00:09<00:00, 10.49it/s]
100%|██████████| 97/97 [00:09<00:00, 10.56it/s]
100%|██████████| 97/97 [00:09<00:00, 10.49it/s]
100%|██████████| 97/97 [00:09<00:00, 10.50it/s]


Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (1): Sequential(
    (0): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (2): Residual(
    (module): Sequential(
      (0): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (1): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True

In [13]:
with ch.no_grad():
    total_correct, total_num = 0., 0.
    for ims, labs, w in tqdm(loaders['test']):
        with autocast():
            out = (model(ims) + model(ch.fliplr(ims))) / 2. # Test-time augmentation
            total_correct += out.argmax(1).eq(labs).sum().cpu().item()
            total_num += ims.shape[0]

    print(f'Accuracy: {total_correct / total_num * 100:.1f}%')

100%|██████████| 20/20 [00:01<00:00, 17.18it/s]

Accuracy: 92.2%



