## CIFAR 10

In [2]:
# IPython magics: show matplotlib figures inline in the notebook output.
%matplotlib inline
# (Re)load the autoreload extension and select mode 2, so edited modules are
# re-imported without restarting the kernel.
%reload_ext autoreload
%autoreload 2

PosixPath('/home/paperspace')

In [27]:
from fastai.conv_learner import *
from fastai.models.cifar10.wideresnet import wrn_22_cat, wrn_22, WideResNetConcat
# Turn on the cuDNN benchmark flag (presumably to let cuDNN auto-tune its
# convolution algorithms for the fixed 32x32 inputs used below — confirm).
torch.backends.cudnn.benchmark = True
# Root of the CIFAR-10 image folders, under the current user's home directory.
PATH = Path.home()/"data/cifar10/"
# Create the directory if it is missing; no error when it already exists.
os.makedirs(PATH,exist_ok=True)

In [28]:
%pwd

'/home/paperspace/fastai/courses/dl2'

In [29]:
# Names of the ten CIFAR-10 categories.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
# Two 3-element arrays, presumably per-channel (mean, std) — TODO confirm.
# NOTE(review): `torch_loader` below does not read `stats`; it hard-codes its
# own normalisation constants.
stats = (
    np.array([0.4914, 0.48216, 0.44653]),
    np.array([0.24703, 0.24349, 0.26159]),
)
# Worker-process count handed to every DataLoader built by `torch_loader`.
workers = 7

In [30]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
def pad(img, p=4, padding_mode='reflect'):
    """Pad an image by `p` pixels on each border and return it as a PIL image.

    Args:
        img: Anything `np.asarray` can convert to an (H, W) or (H, W, C)
            array, e.g. a PIL image.
        p: Number of pixels added on the top, bottom, left and right.
        padding_mode: Forwarded to `np.pad` as its `mode` argument.

    Returns:
        A new image created with `Image.fromarray` from the padded array.
    """
    arr = np.asarray(img)
    # Pad only the two leading (spatial) axes. Trailing axes such as the
    # channel axis get zero padding, so 2-D greyscale arrays work as well as
    # 3-D colour arrays (the original hard-coded exactly three axes).
    pad_width = ((p, p), (p, p)) + ((0, 0),) * (arr.ndim - 2)
    return Image.fromarray(np.pad(arr, pad_width, mode=padding_mode))

def torch_loader(data_path, size, bs, val_bs=None, prefetcher=True):
    """Build a `ModelData` holding train, validation and augmented-validation loaders.

    Images are read with `ImageFolder` from `data_path/'train'` and
    `data_path/'test'`. Nothing is downloaded here, so both folders must
    already exist.

    Args:
        data_path: Dataset root; must support `/` joining (e.g. a `Path`).
        size: Crop size given to `RandomCrop`; also stored as `data.sz`.
        bs: Batch size of the training loader and of the augmented loader.
        val_bs: Batch size of the validation loader; `bs` is used when falsy.
        prefetcher: If true, every loader is wrapped in `DataPrefetcher`.

    Returns:
        `ModelData(data_path, train_loader, val_loader)` with two extra
        attributes: `sz` (the crop size) and `aug_dl` (the test folder read
        through the training-time augmentations).
    """
    if not val_bs:
        val_bs = bs

    train_dir = str(data_path/'train')
    valid_dir = str(data_path/'test')

    # Tail shared by every pipeline: tensor conversion, then normalisation.
    base_tfms = [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
    plain_tfms = transforms.Compose(base_tfms)
    # Augmentation: pad with `pad`'s defaults, random crop back to `size`,
    # random horizontal flip.
    # TODO: use `padding` rather than assuming 4
    augment_tfms = transforms.Compose(
        [pad, transforms.RandomCrop(size), transforms.RandomHorizontalFlip()] + base_tfms
    )

    def make_loader(folder, tfms, batch_size, shuffle):
        # `workers` is the module-level worker count.
        return DataLoader(
            datasets.ImageFolder(folder, tfms),
            batch_size=batch_size, shuffle=shuffle,
            num_workers=workers, pin_memory=True)

    train_loader = make_loader(train_dir, augment_tfms, bs, True)
    val_loader = make_loader(valid_dir, plain_tfms, val_bs, False)
    # Same images as validation, but passed through the training augmentations.
    aug_loader = make_loader(valid_dir, augment_tfms, bs, False)

    if prefetcher:
        train_loader, val_loader, aug_loader = (
            DataPrefetcher(dl) for dl in (train_loader, val_loader, aug_loader)
        )

    data = ModelData(data_path, train_loader, val_loader)
    data.sz = size
    data.aug_dl = aug_loader
    return data

# Seems to speed up training by ~2%
class DataPrefetcher():
    """Wrap a data loader so the next batch is copied to the GPU in advance.

    While the caller works on one `(input, target)` pair, the following pair
    is fetched from the wrapped loader and moved to the GPU on a separate
    CUDA stream.

    Args:
        loader: Iterable of `(input, target)` pairs that also exposes
            `.dataset` and `len()`.
        stop_after: Optional `int`; iteration ends once more than this many
            batches have been yielded. Any non-`int` value disables the limit.
            NOTE(review): as written this yields up to `stop_after + 1`
            batches, and `len()` ignores the limit — confirm this is intended.
    """

    def __init__(self, loader, stop_after=None):
        self.loader = loader
        self.dataset = loader.dataset
        # Side stream on which the host-to-device copies are issued.
        self.stream = torch.cuda.Stream()
        self.stop_after = stop_after
        self.next_input = None
        self.next_target = None

    def __len__(self):
        """Return the number of batches in the wrapped loader."""
        return len(self.loader)

    def preload(self):
        """Fetch the next batch and start copying it to the GPU.

        Sets `next_input` and `next_target` to `None` when the underlying
        iterator is exhausted. Requires `__iter__` to have created
        `self.loaditer` first.
        """
        try:
            self.next_input, self.next_target = next(self.loaditer)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            # `non_blocking=True` replaces the older `async=True` spelling,
            # which is a syntax error on Python 3.7+ (`async` is reserved).
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)

    def __iter__(self):
        """Yield `(input, target)` pairs, prefetching one batch ahead."""
        count = 0
        self.loaditer = iter(self.loader)
        self.preload()
        while self.next_input is not None:
            # Make the current stream wait for the side-stream copy before
            # handing the batch to the caller.
            torch.cuda.current_stream().wait_stream(self.stream)
            batch_input = self.next_input
            batch_target = self.next_target
            # Start fetching the following batch before yielding this one.
            self.preload()
            count += 1
            yield batch_input, batch_target
            if type(self.stop_after) is int and (count > self.stop_after):
                break

In [31]:
# m = WideResNetConcat(num_groups=3, N=3, num_classes=10, k=1, drop_p=0.)

In [35]:
from fastai.models.cifar10.wideresnet import wrn_22

In [None]:
# Experiment: train wrn_22 with SGD (Nesterov, momentum 0.9) at lr=0.5.
bs=128
sz=32
# The fourth argument (512) is `val_bs`, the validation batch size.
data = torch_loader(PATH, sz, bs, 512)

# m = PreActResNet(PreActBlock, [2,2,2,2], concatpool=True)
# m = ResNet18()
m = wrn_22()

# m = FP16(m.cuda())
learn = Learner.from_model_data(m, data)
# presumably switches the learner to half precision — confirm against fastai's Learner.half
learn.half()
learn.crit = F.cross_entropy
# learn.opt_fn = optim.Adam
learn.metrics = [accuracy]
wd=1e-4
lr=0.5
# learn.clip = 1e-2

learn.opt_fn = partial(optim.SGD, nesterov=True, momentum=0.9)

# NOTE(review): the meaning of the `use_clr_beta` tuple is defined by fastai's
# `fit`, not in this file — check it before changing the values.
%time learn.fit(lr, 1, wds=wd, cycle_len=30, use_clr_beta=(20,20,0.95,0.85))

HBox(children=(IntProgress(value=0, description='Epoch', max=30), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                   
    0      1.176328   1.536508   0.5127    
    1      0.919135   1.020056   0.6475                      
 93%|█████████▎| 364/391 [00:48<00:03,  7.49it/s, loss=0.796]

In [None]:
# Experiment: train wrn_22 with Adam (betas 0.95 / 0.99) at lr=1e-3.
bs=128
sz=32
# The fourth argument (512) is `val_bs`, the validation batch size.
data = torch_loader(PATH, sz, bs, 512)

# m = PreActResNet(PreActBlock, [2,2,2,2], concatpool=True)
# m = ResNet18()
m = wrn_22()

# m = FP16(m.cuda())
learn = Learner.from_model_data(m, data)
# presumably switches the learner to half precision — confirm against fastai's Learner.half
learn.half()
learn.crit = F.cross_entropy
# learn.opt_fn = optim.Adam
learn.metrics = [accuracy]
wd=1e-4
lr=1e-3
# learn.clip = 1e-2

learn.opt_fn = partial(optim.Adam, betas=(0.95,0.99))

# NOTE(review): the meaning of the `use_clr_beta` tuple is defined by fastai's
# `fit`, not in this file — check it before changing the values.
%time learn.fit(lr, 1, wds=wd, cycle_len=30, use_clr_beta=(10,7.5,0.95,0.85))

In [None]:
# Experiment: same Adam setup as the previous cell, but at lr=3e-3.
bs=128
sz=32
# The fourth argument (512) is `val_bs`, the validation batch size.
data = torch_loader(PATH, sz, bs, 512)

# m = PreActResNet(PreActBlock, [2,2,2,2], concatpool=True)
# m = ResNet18()
m = wrn_22()

# m = FP16(m.cuda())
learn = Learner.from_model_data(m, data)
# presumably switches the learner to half precision — confirm against fastai's Learner.half
learn.half()
learn.crit = F.cross_entropy
# learn.opt_fn = optim.Adam
learn.metrics = [accuracy]
wd=1e-4
lr=3e-3
# learn.clip = 1e-2

learn.opt_fn = partial(optim.Adam, betas=(0.95,0.99))

# NOTE(review): the meaning of the `use_clr_beta` tuple is defined by fastai's
# `fit`, not in this file — check it before changing the values.
%time learn.fit(lr, 1, wds=wd, cycle_len=30, use_clr_beta=(10,7.5,0.95,0.85))

In [5]:

# --
# Model definition
# Derived from models in `https://github.com/kuangliu/pytorch-cifar`

class PreActBlock(nn.Module):
    """Residual block with pre-activation ordering: (BN -> ReLU -> 3x3 conv) twice.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count produced by both convolutions.
        stride: Stride of the first convolution (and of the projection).

    A 1x1 projection named `shortcut` exists only when the stride is not 1 or
    the channel count changes; otherwise the input itself is the skip path.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        def conv3x3(c_in, c_out, step):
            # Bias-free 3x3 convolution that keeps spatial size when step == 1.
            return nn.Conv2d(c_in, c_out, kernel_size=3, stride=step, padding=1, bias=False)

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = conv3x3(out_channels, out_channels, 1)

        shape_preserved = stride == 1 and in_channels == out_channels
        if not shape_preserved:
            projection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
            self.shortcut = nn.Sequential(projection)

    def forward(self, x):
        """Return the two-convolution branch plus the skip connection."""
        activated = F.relu(self.bn1(x))
        # The projection consumes the pre-activated tensor; the identity path
        # uses the raw input.
        if hasattr(self, 'shortcut'):
            residual = self.shortcut(activated)
        else:
            residual = x
        hidden = self.conv1(activated)
        hidden = F.relu(self.bn2(hidden))
        hidden = self.conv2(hidden)
        return hidden + residual


class ResNet18(nn.Module):
    """ResNet-style classifier built from `PreActBlock`s.

    Structure: a 3x3 conv / BN / ReLU stem, four stages of pre-activation
    blocks (64, 128, 256, 256 channels), then global average and max pooling
    concatenated and fed to a linear classifier.

    Args:
        num_blocks: Sequence of four ints, the number of blocks per stage.
        num_classes: Size of the output layer.
    """

    # The default is a tuple, not a list, so it cannot be mutated across calls.
    def __init__(self, num_blocks=(2, 2, 2, 2), num_classes=10):
        super().__init__()

        self.in_channels = 64

        self.prep = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )

        self.layers = nn.Sequential(
            self._make_layer(64, 64, num_blocks[0], stride=1),
            self._make_layer(64, 128, num_blocks[1], stride=2),
            self._make_layer(128, 256, num_blocks[2], stride=2),
            self._make_layer(256, 256, num_blocks[3], stride=2),
        )

        # 512 = 256 average-pooled + 256 max-pooled features (see `forward`).
        self.classifier = nn.Linear(512, num_classes)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        """Build one stage: the first block uses `stride`, the rest use stride 1."""
        block_strides = [stride] + [1] * (num_blocks - 1)
        blocks = []
        for block_stride in block_strides:
            blocks.append(PreActBlock(in_channels=in_channels, out_channels=out_channels, stride=block_stride))
            # Every block after the first maps out_channels -> out_channels.
            in_channels = out_channels

        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return class scores of shape (batch, num_classes)."""
        # Cast the input to the model's own parameter dtype. The original
        # hard-coded `x.half()`, which only works once the model has been
        # converted to fp16; this form behaves the same for an fp16 model and
        # also works for an fp32 one.
        x = x.to(dtype=next(self.parameters()).dtype)
        x = self.prep(x)

        x = self.layers(x)

        # Global average and max pooling, each flattened to (batch, channels).
        x_avg = F.adaptive_avg_pool2d(x, (1, 1))
        x_avg = x_avg.view(x_avg.size(0), -1)

        x_max = F.adaptive_max_pool2d(x, (1, 1))
        x_max = x_max.view(x_max.size(0), -1)

        x = torch.cat([x_avg, x_max], dim=-1)

        x = self.classifier(x)

        return x


In [7]:
def get_TTA_accuracy(learn):
    """Return the accuracy of a weighted blend of the `learn.TTA()` predictions.

    The first prediction set is weighted 0.6; the remaining sets are summed
    over axis 0 and that sum is weighted 0.4. The blend is scored against the
    targets with `accuracy_np`.
    """
    all_preds, targets = learn.TTA()
    # Combine predictions across augmented and non-augmented inputs.
    # presumably index 0 is the non-augmented pass — confirm against learn.TTA
    # NOTE(review): the augmented sets are summed, not averaged, so their
    # combined weight grows with the number of augmentations — confirm intent.
    plain = all_preds[0]
    augmented_total = all_preds[1:].sum(0)
    blended = 0.6 * plain + 0.4 * augmented_total
    return accuracy_np(blended, targets)

def get_TTA_accuracy_2(learn):
    """Print the accuracy of the averaged, exponentiated `learn.TTA()` outputs.

    The outputs are exponentiated and averaged over axis 0. The result and the
    targets are converted to tensors and scored with `accuracy`. The score is
    printed and nothing is returned.
    """
    log_preds, y = learn.TTA()
    # Exponentiate first, then average over the leading axis.
    probs = np.exp(log_preds)
    mean_probs = np.mean(probs, 0)
    pred_tensor = torch.FloatTensor(mean_probs)
    targ_tensor = torch.LongTensor(y)
    acc = accuracy(pred_tensor, targ_tensor)
    print('TTA acc:', acc)