## CIFAR-10

In [1]:
# IPython setup: render matplotlib figures inline and automatically reload
# edited modules before each cell runs (autoreload mode 2 = all modules).
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
from fastai.conv_learner import *
from fastai.models.cifar10.wideresnet import wrn_22_cat, wrn_22, WideResNetConcat
# Location of the CIFAR-10 ImageFolder tree; create it (and any parents) if absent.
PATH = Path("data/cifar10/")
PATH.mkdir(parents=True, exist_ok=True)

# Let cuDNN benchmark and cache the fastest convolution algorithms; this pays off
# when input shapes stay the same from batch to batch.
torch.backends.cudnn.benchmark = True

In [3]:
# CIFAR-10 label names.
classes = tuple('plane car bird cat deer dog frog horse ship truck'.split())

# (mean, std) per colour channel — presumably CIFAR-10 training-set statistics.
# NOTE(review): nothing visible reads `stats`; torch_loader hard-codes its own
# Normalize values, whose std (0.2023, 0.1994, 0.2010) differs from this one.
stats = (
    np.array([0.4914, 0.48216, 0.44653]),   # mean
    np.array([0.24703, 0.24349, 0.26159]),  # std
)

# Loader batch size, random-crop side length, and DataLoader worker processes.
bs, sz, workers = 1024, 32, 7

In [4]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
def pad_array(arr, p, padding_mode):
    """Pad the two leading (spatial) axes of ``arr`` by ``p`` on each side.

    Any trailing axes (e.g. a channel axis) are left unpadded, so both 2-D
    (H, W) and 3-D (H, W, C) arrays are accepted. ``padding_mode`` is passed
    to ``np.pad`` as its ``mode`` (e.g. 'reflect', 'constant', 'edge').
    """
    pad_width = ((p, p), (p, p)) + ((0, 0),) * (arr.ndim - 2)
    return np.pad(arr, pad_width, padding_mode)


def pad(img, p=4, padding_mode='reflect'):
    """Return a new PIL image: ``img`` padded by ``p`` pixels on every side.

    Works for single-channel (grayscale) as well as multi-channel images.
    In ``torch_loader`` it runs before ``RandomCrop`` so crops can shift the
    image by up to ``p`` pixels.
    """
    return Image.fromarray(pad_array(np.asarray(img), p, padding_mode))

def torch_loader(data_path, size, prefetcher=True, padding=4, batch_size=None, num_workers=None):
    """Build CIFAR-10 train / validation / TTA loaders and wrap them in a ``ModelData``.

    Args:
        data_path: ``Path`` holding ``train/`` and ``test/`` ImageFolder trees;
            ``download_cifar10(data_path)`` is called first if ``train/`` is missing.
        size: side length of the random training crop; also stored as ``data.sz``.
        prefetcher: if True, wrap every loader in ``DataPrefetcher`` (which
            creates a CUDA stream, so it needs a GPU).
        padding: pixels of padding (``pad``'s default reflect mode) added on
            each side before the random crop.
        batch_size: loader batch size; ``None`` uses the notebook-level ``bs``.
        num_workers: DataLoader worker processes; ``None`` uses the
            notebook-level ``workers``.

    Returns:
        ``ModelData`` built from the train and validation loaders, with extra
        attributes ``sz`` (= ``size``) and ``aug_dl`` (test images passed through
        the training augmentations — presumably consumed by ``learn.TTA()``).
    """
    from functools import partial

    if not os.path.exists(data_path/'train'): download_cifar10(data_path)
    # Fall back to the notebook globals so existing calls behave as before.
    if batch_size is None: batch_size = bs
    if num_workers is None: num_workers = workers

    traindir = str(data_path/'train')
    valdir = str(data_path/'test')
    # NOTE(review): this std (0.2023, 0.1994, 0.2010) differs from the std in the
    # notebook's `stats` tuple (0.24703, 0.24349, 0.26159). It is left unchanged
    # so runs stay comparable with the recorded results.
    tfms = [transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]

    train_tfms = transforms.Compose([
        partial(pad, p=padding),
        transforms.RandomCrop(size),
        transforms.RandomHorizontalFlip(),
    ] + tfms)

    def make_loader(folder, tfm, shuffle):
        # All three splits share the same DataLoader settings; only shuffling differs.
        return DataLoader(datasets.ImageFolder(folder, tfm), batch_size=batch_size,
                          shuffle=shuffle, num_workers=num_workers, pin_memory=True)

    train_loader = make_loader(traindir, train_tfms, shuffle=True)
    val_loader = make_loader(valdir, transforms.Compose(tfms), shuffle=False)
    # Test images with *training* augmentations, exposed as data.aug_dl below.
    aug_loader = make_loader(valdir, train_tfms, shuffle=False)

    if prefetcher:
        train_loader = DataPrefetcher(train_loader)
        val_loader = DataPrefetcher(val_loader)
        aug_loader = DataPrefetcher(aug_loader)

    data = ModelData(data_path, train_loader, val_loader)
    data.sz = size
    data.aug_dl = aug_loader
    return data

# Seems to speed up training by ~2%
class DataPrefetcher:
    """Wrap a DataLoader so the next batch is copied to the GPU on a side CUDA
    stream while the current batch is being consumed.

    Args:
        loader: DataLoader yielding ``(input, target)`` tensor pairs; must expose
            ``.dataset`` and support ``len()``.
        stop_after: if an int, yield at most this many batches per pass;
            ``None`` iterates the whole loader.
    """

    def __init__(self, loader, stop_after=None):
        self.loader = loader
        self.dataset = loader.dataset
        self.stream = torch.cuda.Stream()
        self.stop_after = stop_after
        self.next_input = None
        self.next_target = None

    def __len__(self):
        return len(self.loader)

    def preload(self):
        """Fetch the next batch and start its host-to-device copy on ``self.stream``.

        Sets ``next_input``/``next_target`` to None once the loader is exhausted.
        """
        try:
            self.next_input, self.next_target = next(self.loaditer)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            # `async=True` is a SyntaxError on Python >= 3.7 (`async` is a keyword);
            # `non_blocking=True` is PyTorch's replacement spelling of the same flag.
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)

    def __iter__(self):
        limit = self.stop_after if isinstance(self.stop_after, int) else None
        count = 0
        self.loaditer = iter(self.loader)
        self.preload()
        while self.next_input is not None:
            # Check before yielding so exactly `stop_after` batches are produced
            # (the old post-yield `count > stop_after` check produced one extra).
            if limit is not None and count >= limit:
                break
            # Make the compute stream wait until the side-stream copy has finished.
            # NOTE(review): PyTorch's CUDA-semantics docs suggest
            # `tensor.record_stream(torch.cuda.current_stream())` for tensors made on
            # one stream and used on another — consider it if memory reuse bites.
            torch.cuda.current_stream().wait_stream(self.stream)
            batch_input = self.next_input
            batch_target = self.next_target
            self.preload()  # start copying the following batch right away
            count += 1
            yield batch_input, batch_target

In [5]:
# Default (prefetching) loaders at the configured crop size; used by the training cells.
data = torch_loader(PATH, size=sz)

In [6]:
'''Pre-activation ResNet in PyTorch.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Identity Mappings in Deep Residual Networks. arXiv:1603.05027
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable


class AdaptiveConcatPool2d(nn.Module):
    """Concatenate adaptive max pooling and adaptive average pooling along channels.

    The output has twice the input's channels: the max-pooled maps first, then
    the average-pooled maps.

    Args:
        sz: output size for both pools; ``None`` (or any falsy value) means (1, 1).
    """

    def __init__(self, sz=None):
        super().__init__()
        output_size = sz if sz else (1, 1)
        self.ap = nn.AdaptiveAvgPool2d(output_size)
        self.mp = nn.AdaptiveMaxPool2d(output_size)

    def forward(self, x):
        max_pooled = self.mp(x)
        avg_pooled = self.ap(x)
        return torch.cat((max_pooled, avg_pooled), dim=1)
    

class PreActBlock(nn.Module):
    '''Pre-activation version of the BasicBlock.

    Computes ``conv2(relu(bn2(conv1(relu(bn1(x)))))) + shortcut``. The shortcut
    is ``x`` itself, or a strided 1x1 conv of ``relu(bn1(x))`` when the stride
    or the channel count changes.
    '''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        # Submodules are created in the same order as before so seeded
        # initialisation produces identical weights.
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # Presumably a no-op (PyTorch initialises BatchNorm bias to zero) — kept as-is.
        self.bn2.bias.data.zero_()
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False))

    def forward(self, x):
        pre = F.relu(self.bn1(x), inplace=True)
        residual = self.shortcut(pre) if hasattr(self, 'shortcut') else x
        hidden = F.relu(self.bn2(self.conv1(pre)), inplace=True)
        result = self.conv2(hidden)
        result += residual
        return result


class PreActBottleneck(nn.Module):
    '''Pre-activation version of the original Bottleneck module.

    1x1 reduce -> 3x3 (carries the stride) -> 1x1 expand to ``expansion * planes``
    channels, with BatchNorm + ReLU before each conv. The shortcut is ``x``, or a
    strided 1x1 conv of ``relu(bn1(x))`` when the stride or channel count changes.
    '''
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        # Creation order matches the original so seeded initialisation is unchanged.
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)

        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False))

    def forward(self, x):
        pre = F.relu(self.bn1(x))
        residual = self.shortcut(pre) if hasattr(self, 'shortcut') else x
        out = self.conv1(pre)
        for norm, conv in ((self.bn2, self.conv2), (self.bn3, self.conv3)):
            out = conv(F.relu(norm(out)))
        out += residual
        return out


class PreActResNet(nn.Module):
    '''Pre-activation ResNet: 3x3 stem conv, four residual stages, pooling and a linear head.

    Stages use 64/128/256/512 planes with strides 1/2/2/2; the pooled features
    feed a single ``nn.Linear`` classifier.

    Args:
        block: residual block class; must define ``expansion`` and be
            constructible as ``block(in_planes, planes, stride)``.
        num_blocks: four ints, the number of blocks in each stage.
        num_classes: number of output classes.
        concatpool: if True, pool with AdaptiveConcatPool2d (max + average,
            doubling the head's input width); otherwise adaptive max pooling.

    ``forward`` returns log-probabilities (log_softmax over the class axis).
    Because log_softmax is idempotent, feeding this output to
    ``nn.CrossEntropyLoss`` gives the same loss as ``nn.NLLLoss``.
    '''
    def __init__(self, block, num_blocks, num_classes=10, concatpool=False):
        super(PreActResNet, self).__init__()
        self.in_planes = 64  # running channel count, advanced by _make_layer

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.pool = AdaptiveConcatPool2d() if concatpool else nn.AdaptiveMaxPool2d((1,1))

        # Concat pooling doubles the pooled feature count (bool -> 0/1 here).
        self.linear = nn.Linear(512*block.expansion*(concatpool+1), num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        '''Stack ``num_blocks`` blocks; only the first one uses ``stride``.'''
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.pool(out)
        out = out.view(out.size(0), -1)
        # Explicit dim=1 (the class axis of the 2-D logits). Calling log_softmax
        # without `dim` is deprecated and emits a UserWarning. For 2-D input the
        # implicit choice was already dim=1, so the outputs are unchanged.
        return F.log_softmax(self.linear(out), dim=1)

# Factories for common stage depths. Extra keyword arguments (e.g. num_classes,
# concatpool) are forwarded to PreActResNet.
def preact_resnet18(**kwargs): return PreActResNet(PreActBlock, [2,2,2,2], **kwargs)
def preact_resnet2332(**kwargs): return PreActResNet(PreActBlock, [2,3,3,2], **kwargs)
def preact_resnet3333(**kwargs): return PreActResNet(PreActBlock, [3,3,3,3], **kwargs)
def preact_resnet34(**kwargs): return PreActResNet(PreActBlock, [3,4,6,3], **kwargs)
def preact_resnet50(**kwargs): return PreActResNet(PreActBottleneck, [3,4,6,3], **kwargs)
def preact_resnet101(**kwargs): return PreActResNet(PreActBottleneck, [3,4,23,3], **kwargs)
def preact_resnet152(**kwargs): return PreActResNet(PreActBottleneck, [3,8,36,3], **kwargs)

# Backward-compatible camelCase names (the other factories are snake_case).
preActResNet101 = preact_resnet101
preActResNet152 = preact_resnet152


In [7]:
# m = WideResNetConcat(num_groups=3, N=3, num_classes=10, k=1, drop_p=0.)

In [8]:
def combine_tta_preds(preds, orig_weight=0.6):
    """Blend TTA predictions: ``orig_weight`` on the un-augmented pass and the
    remaining ``1 - orig_weight`` on the mean of the augmented passes.

    Args:
        preds: array-like stacked along axis 0 with the un-augmented predictions
            first, i.e. expected shape (1 + n_aug, n_samples, n_classes).
        orig_weight: weight given to the un-augmented predictions.

    Returns:
        Array of shape ``preds.shape[1:]`` with the blended predictions.
    """
    preds = np.asarray(preds)
    if len(preds) == 1:  # no augmented passes to blend in
        return preds[0]
    # Use the mean, not the sum: summing n_aug passes would give the augmented
    # predictions n_aug times the intended weight.
    return orig_weight * preds[0] + (1 - orig_weight) * preds[1:].mean(0)


def get_TTA_accuracy(learn, orig_weight=0.6):
    """Return validation accuracy of ``learn`` using blended TTA predictions.

    ``learn.TTA()`` returns stacked predictions and targets; the predictions are
    blended by ``combine_tta_preds`` (default 60% un-augmented / 40% augmented).
    With PreActResNet the predictions are log-probabilities, so the blend is in
    log space.
    """
    preds, targs = learn.TTA()
    return accuracy_np(combine_tta_preds(preds, orig_weight), targs)

def get_TTA_accuracy_2(learn):
    """Print and return validation accuracy of ``learn`` from averaged TTA probabilities.

    ``learn.TTA()`` is expected to return log-probabilities stacked along axis 0
    (one entry per TTA pass), as produced by PreActResNet's log_softmax head.
    ``np.exp`` turns them back into probabilities, which are averaged over the
    passes before scoring.

    Returns:
        The accuracy value from ``accuracy``. It is now returned as well as
        printed, matching ``get_TTA_accuracy``.
    """
    log_preds, y = learn.TTA()
    probs = np.mean(np.exp(log_preds), 0)
    acc = accuracy(torch.FloatTensor(probs), torch.LongTensor(y))
    print('TTA acc:', acc)
    return acc

In [12]:
# Run: PreActResNet with stage depths [2,2,2,2] (as preact_resnet18) and concat
# pooling, converted with learn.half(), Adam, wd=5e-4, learn.clip=0.3.
# 26-epoch schedule: lr 1e-4 -> 1e-3 (2 ep), -> 1e-2 (10 ep), -> 1e-3 (10 ep),
# -> 1e-5 (4 ep). In the middle two phases momentum moves opposite to lr.
# `lr` is not read by this cell; each TrainingPhase supplies its own lr range.
# loss_scale=512 is presumably a static loss scale for the half-precision model.
m = PreActResNet(PreActBlock, [2,2,2,2], concatpool=True)
learn = Learner.from_model_data(m, data)
learn.half()
learn.crit = nn.CrossEntropyLoss()
learn.opt_fn = optim.Adam
learn.metrics = [accuracy]
wd=5e-4
lr=1e-3
learn.clip = 3e-1

def_phase = {'opt_fn':optim.Adam, 'wds':wd}

phases = [
    TrainingPhase(**def_phase, epochs=2, lr=(1e-4, 1e-3), lr_decay=DecayType.LINEAR, momentum=(0.85,0.95), momentum_decay=DecayType.LINEAR, wd_loss=False),
    TrainingPhase(**def_phase, epochs=10, lr=(1e-3,1e-2), lr_decay=DecayType.LINEAR, momentum=(0.95,0.85), momentum_decay=DecayType.LINEAR, wd_loss=False),
    TrainingPhase(**def_phase, epochs=10, lr=(1e-2,1e-3), lr_decay=DecayType.LINEAR, momentum=(0.85,0.95), momentum_decay=DecayType.LINEAR, wd_loss=False),
    TrainingPhase(**def_phase, epochs=4, lr=(1e-3,1e-5), lr_decay=DecayType.LINEAR, momentum=(0.95,0.98), momentum_decay=DecayType.LINEAR, wd_loss=False)
]

learn.fit_opt_sched(phases, loss_scale=512)

HBox(children=(IntProgress(value=0, description='Epoch', max=26), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.559481   1.375577   0.494     
    1      1.214859   1.258109   0.5722                   
    2      0.98733    1.075784   0.6534                    
    3      0.851114   1.346923   0.6442                    
    4      0.722213   1.100872   0.6906                    
    5      0.624781   0.689455   0.7868                    
    6      0.563165   0.756333   0.7731                    
    7      0.522025   0.709459   0.7825                    
    8      0.478817   0.630427   0.7977                    
    9      0.427421   0.535656   0.8182                    
    10     0.393279   0.941095   0.7433                    
    11     0.372102   0.756953   0.7801                    
    12     0.356929   0.435244   0.8555                    
    13     0.294046   0.357349   0.884                     
    14     0.245216   0.412861   0.8711                    
    15     0.20782    0.334846   0.8983                   

[0.260223046875, 0.9327999989509582]

In [13]:
# Run: PreActResNet [2,2,2,2] + concat pooling, learn.half(), Adam, wd=5e-4,
# learn.clip=0.3. 26-epoch schedule with every lr doubled (peak 2e-2):
# 2e-4 -> 2e-3 (2 ep), -> 2e-2 (10 ep), -> 2e-3 (10 ep), -> 1e-5 (4 ep).
m = PreActResNet(PreActBlock, [2,2,2,2], concatpool=True)
learn = Learner.from_model_data(m, data)
learn.half()
learn.crit = nn.CrossEntropyLoss()
learn.opt_fn = optim.Adam
learn.metrics = [accuracy]
wd=5e-4
lr=1e-3
learn.clip = 3e-1

def_phase = {'opt_fn':optim.Adam, 'wds':wd}

phases = [
    TrainingPhase(**def_phase, epochs=2, lr=(2e-4, 2e-3), lr_decay=DecayType.LINEAR, momentum=(0.85,0.95), momentum_decay=DecayType.LINEAR, wd_loss=False),
    TrainingPhase(**def_phase, epochs=10, lr=(2e-3,2e-2), lr_decay=DecayType.LINEAR, momentum=(0.95,0.85), momentum_decay=DecayType.LINEAR, wd_loss=False),
    TrainingPhase(**def_phase, epochs=10, lr=(2e-2,2e-3), lr_decay=DecayType.LINEAR, momentum=(0.85,0.95), momentum_decay=DecayType.LINEAR, wd_loss=False),
    TrainingPhase(**def_phase, epochs=4, lr=(2e-3,1e-5), lr_decay=DecayType.LINEAR, momentum=(0.95,0.98), momentum_decay=DecayType.LINEAR, wd_loss=False)
]

learn.fit_opt_sched(phases, loss_scale=512)

HBox(children=(IntProgress(value=0, description='Epoch', max=26), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.518948   2.029938   0.3991    
    1      1.204774   1.092209   0.6287                   
    2      0.968324   1.203947   0.608                     
    3      0.823453   1.125828   0.6542                    
    4      0.713277   1.419574   0.6458                    
    5      0.63577    1.110034   0.6782                    
    6      0.569573   0.858042   0.7264                    
    7      0.512871   0.714342   0.767                     
    8      0.468518   0.972811   0.7007                    
    9      0.434033   0.677549   0.7749                    
    10     0.403709   0.594791   0.8069                    
    11     0.395056   0.605798   0.809                     
    12     0.360575   0.507203   0.8432                    
    13     0.336032   0.437429   0.8599                    
    14     0.279051   0.442298   0.8669                    
    15     0.240927   0.377872   0.8849                   

[0.2800693359375, 0.9268000018119812]

In [14]:
# Run: PreActResNet [2,2,2,2] + concat pooling, learn.half(), Adam, wd=5e-4,
# learn.clip=0.3. Longer 34-epoch schedule with peak lr 1e-2:
# 1e-4 -> 1e-3 (4 ep), -> 1e-2 (12 ep), -> 1e-3 (12 ep), -> 1e-5 (6 ep).
m = PreActResNet(PreActBlock, [2,2,2,2], concatpool=True)
learn = Learner.from_model_data(m, data)
learn.half()
learn.crit = nn.CrossEntropyLoss()
learn.opt_fn = optim.Adam
learn.metrics = [accuracy]
wd=5e-4
lr=1e-3
learn.clip = 3e-1

def_phase = {'opt_fn':optim.Adam, 'wds':wd}

phases = [
    TrainingPhase(**def_phase, epochs=4, lr=(1e-4, 1e-3), lr_decay=DecayType.LINEAR, momentum=(0.85,0.95), momentum_decay=DecayType.LINEAR, wd_loss=False),
    TrainingPhase(**def_phase, epochs=12, lr=(1e-3,1e-2), lr_decay=DecayType.LINEAR, momentum=(0.95,0.85), momentum_decay=DecayType.LINEAR, wd_loss=False),
    TrainingPhase(**def_phase, epochs=12, lr=(1e-2,1e-3), lr_decay=DecayType.LINEAR, momentum=(0.85,0.95), momentum_decay=DecayType.LINEAR, wd_loss=False),
    TrainingPhase(**def_phase, epochs=6, lr=(1e-3,1e-5), lr_decay=DecayType.LINEAR, momentum=(0.95,0.98), momentum_decay=DecayType.LINEAR, wd_loss=False)
]

learn.fit_opt_sched(phases, loss_scale=512)

HBox(children=(IntProgress(value=0, description='Epoch', max=34), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.55699    1.578994   0.4194    
    1      1.21989    1.157563   0.6014                   
    2      0.990838   1.016119   0.6649                    
    3      0.830221   0.946743   0.6957                    
    4      0.712191   0.817422   0.7304                    
    5      0.641922   0.723941   0.7628                    
    6      0.588504   0.621541   0.7904                    
    7      0.543973   1.61429    0.6102                    
    8      0.515282   0.738418   0.765                     
    9      0.464268   0.784556   0.7783                    
    10     0.428416   0.890738   0.7562                    
    11     0.404087   1.065597   0.7035                    
    12     0.387005   0.783281   0.7591                    
    13     0.35338    0.517973   0.8383                    
    14     0.322273   0.474601   0.8458                    
    15     0.310378   0.796406   0.7807                   

[0.2893052734375, 0.9308000016212463]

In [16]:
# Run: PreActResNet [2,2,2,2] + concat pooling, learn.half(), Adam, wd=5e-4,
# learn.clip=0.3. 30-epoch schedule with peak lr 4e-2:
# 1e-4 -> 1e-3 (2 ep), -> 4e-2 (12 ep), -> 1e-3 (12 ep), -> 1e-5 (4 ep).
# The recorded run aborted during the third epoch with "value cannot be
# converted to type Half without overflow: inf" (half-precision overflow).
m = PreActResNet(PreActBlock, [2,2,2,2], concatpool=True)
learn = Learner.from_model_data(m, data)
learn.half()
learn.crit = nn.CrossEntropyLoss()
learn.opt_fn = optim.Adam
learn.metrics = [accuracy]
wd=5e-4
lr=1e-3
learn.clip = 3e-1

def_phase = {'opt_fn':optim.Adam, 'wds':wd}

phases = [
    TrainingPhase(**def_phase, epochs=2, lr=(1e-4, 1e-3), lr_decay=DecayType.LINEAR, momentum=(0.85,0.95), momentum_decay=DecayType.LINEAR, wd_loss=False),
    TrainingPhase(**def_phase, epochs=12, lr=(1e-3,4e-2), lr_decay=DecayType.LINEAR, momentum=(0.95,0.85), momentum_decay=DecayType.LINEAR, wd_loss=False),
    TrainingPhase(**def_phase, epochs=12, lr=(4e-2,1e-3), lr_decay=DecayType.LINEAR, momentum=(0.85,0.95), momentum_decay=DecayType.LINEAR, wd_loss=False),
    TrainingPhase(**def_phase, epochs=4, lr=(1e-3,1e-5), lr_decay=DecayType.LINEAR, momentum=(0.95,0.98), momentum_decay=DecayType.LINEAR, wd_loss=False)
]

learn.fit_opt_sched(phases, loss_scale=512)

HBox(children=(IntProgress(value=0, description='Epoch', max=30), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.546197   1.556898   0.4582    
    1      1.200931   1.157685   0.6213                   
 98%|█████████▊| 48/49 [00:07<00:00,  6.69it/s, loss=1.01]

RuntimeError: value cannot be converted to type Half without overflow: inf

In [10]:
# Short check: only the 2-epoch warm-up phase (lr 1e-4 -> 1e-3), with
# wd=1e-3 and learn.clip=0.5, on PreActResNet [2,2,2,2] + concat pooling.
m = PreActResNet(PreActBlock, [2,2,2,2], concatpool=True)
learn = Learner.from_model_data(m, data)
learn.half()
learn.crit = nn.CrossEntropyLoss()
learn.opt_fn = optim.Adam
learn.metrics = [accuracy]
wd=1e-3
lr=1e-3
learn.clip = .5

def_phase = {'opt_fn':optim.Adam, 'wds':wd}

phases = [
    TrainingPhase(**def_phase, epochs=2, lr=(1e-4, 1e-3), lr_decay=DecayType.LINEAR, momentum=(0.85,0.95), momentum_decay=DecayType.LINEAR, wd_loss=False)
]

learn.fit_opt_sched(phases, loss_scale=512)

HBox(children=(IntProgress(value=0, description='Epoch', max=2), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.547924   1.418103   0.4774    
    1      1.210919   1.399253   0.5572                   



[1.399253125, 0.5572]

In [11]:
# Short check: only the 2-epoch warm-up phase (lr 1e-4 -> 1e-3), wd=1e-3.
# Unlike the neighbouring runs, this cell does not set learn.clip, so the fresh
# Learner keeps whatever fastai's default is.
m = PreActResNet(PreActBlock, [2,2,2,2], concatpool=True)
learn = Learner.from_model_data(m, data)
learn.half()
learn.crit = nn.CrossEntropyLoss()
learn.opt_fn = optim.Adam
learn.metrics = [accuracy]
wd=1e-3
lr=1e-3

def_phase = {'opt_fn':optim.Adam, 'wds':wd}

phases = [
    TrainingPhase(**def_phase, epochs=2, lr=(1e-4, 1e-3), lr_decay=DecayType.LINEAR, momentum=(0.85,0.95), momentum_decay=DecayType.LINEAR, wd_loss=False)
]

learn.fit_opt_sched(phases, loss_scale=512)

HBox(children=(IntProgress(value=0, description='Epoch', max=2), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.552995   1.385894   0.4987    
    1      1.204416   1.219089   0.5984                   



[1.219088671875, 0.5984000007629394]

In [None]:
# Short check: only the 2-epoch warm-up phase (lr 1e-4 -> 1e-3), with
# wd=5e-4 and learn.clip=0.3.
m = PreActResNet(PreActBlock, [2,2,2,2], concatpool=True)
learn = Learner.from_model_data(m, data)
learn.half()
learn.crit = nn.CrossEntropyLoss()
learn.opt_fn = optim.Adam
learn.metrics = [accuracy]
wd=5e-4
lr=1e-3
learn.clip = 3e-1

def_phase = {'opt_fn':optim.Adam, 'wds':wd}

phases = [
    TrainingPhase(**def_phase, epochs=2, lr=(1e-4, 1e-3), lr_decay=DecayType.LINEAR, momentum=(0.85,0.95), momentum_decay=DecayType.LINEAR, wd_loss=False)
]

learn.fit_opt_sched(phases, loss_scale=512)

In [None]:
# Full 26-epoch run, wd=5e-4, learn.clip=0.3, peak lr 1e-2:
# 1e-4 -> 1e-3 (2 ep), -> 1e-2 (10 ep), -> 1e-3 (10 ep), -> 1e-5 (4 ep).
m = PreActResNet(PreActBlock, [2,2,2,2], concatpool=True)
learn = Learner.from_model_data(m, data)
learn.half()
learn.crit = nn.CrossEntropyLoss()
learn.opt_fn = optim.Adam
learn.metrics = [accuracy]
wd=5e-4
lr=1e-3
learn.clip = 3e-1

def_phase = {'opt_fn':optim.Adam, 'wds':wd}

phases = [
    TrainingPhase(**def_phase, epochs=2, lr=(1e-4, 1e-3), lr_decay=DecayType.LINEAR, momentum=(0.85,0.95), momentum_decay=DecayType.LINEAR, wd_loss=False),
    TrainingPhase(**def_phase, epochs=10, lr=(1e-3,1e-2), lr_decay=DecayType.LINEAR, momentum=(0.95,0.85), momentum_decay=DecayType.LINEAR, wd_loss=False),
    TrainingPhase(**def_phase, epochs=10, lr=(1e-2,1e-3), lr_decay=DecayType.LINEAR, momentum=(0.85,0.95), momentum_decay=DecayType.LINEAR, wd_loss=False),
    TrainingPhase(**def_phase, epochs=4, lr=(1e-3,1e-5), lr_decay=DecayType.LINEAR, momentum=(0.95,0.98), momentum_decay=DecayType.LINEAR, wd_loss=False)
]

learn.fit_opt_sched(phases, loss_scale=512)

In [None]:
# Full 26-epoch run with weight decay raised to wd=1e-3 (learn.clip=0.3,
# peak lr 1e-2; same 2/10/10/4-epoch phase layout).
m = PreActResNet(PreActBlock, [2,2,2,2], concatpool=True)
learn = Learner.from_model_data(m, data)
learn.half()
learn.crit = nn.CrossEntropyLoss()
learn.opt_fn = optim.Adam
learn.metrics = [accuracy]
wd=1e-3
lr=1e-3
learn.clip = 3e-1

def_phase = {'opt_fn':optim.Adam, 'wds':wd}

phases = [
    TrainingPhase(**def_phase, epochs=2, lr=(1e-4, 1e-3), lr_decay=DecayType.LINEAR, momentum=(0.85,0.95), momentum_decay=DecayType.LINEAR, wd_loss=False),
    TrainingPhase(**def_phase, epochs=10, lr=(1e-3,1e-2), lr_decay=DecayType.LINEAR, momentum=(0.95,0.85), momentum_decay=DecayType.LINEAR, wd_loss=False),
    TrainingPhase(**def_phase, epochs=10, lr=(1e-2,1e-3), lr_decay=DecayType.LINEAR, momentum=(0.85,0.95), momentum_decay=DecayType.LINEAR, wd_loss=False),
    TrainingPhase(**def_phase, epochs=4, lr=(1e-3,1e-5), lr_decay=DecayType.LINEAR, momentum=(0.95,0.98), momentum_decay=DecayType.LINEAR, wd_loss=False)
]

learn.fit_opt_sched(phases, loss_scale=512)

In [None]:
# Full 26-epoch run with weight decay lowered to wd=1e-4 (learn.clip=0.3,
# peak lr 1e-2; same 2/10/10/4-epoch phase layout).
m = PreActResNet(PreActBlock, [2,2,2,2], concatpool=True)
learn = Learner.from_model_data(m, data)
learn.half()
learn.crit = nn.CrossEntropyLoss()
learn.opt_fn = optim.Adam
learn.metrics = [accuracy]
wd=1e-4
lr=1e-3
learn.clip = 3e-1

def_phase = {'opt_fn':optim.Adam, 'wds':wd}

phases = [
    TrainingPhase(**def_phase, epochs=2, lr=(1e-4, 1e-3), lr_decay=DecayType.LINEAR, momentum=(0.85,0.95), momentum_decay=DecayType.LINEAR, wd_loss=False),
    TrainingPhase(**def_phase, epochs=10, lr=(1e-3,1e-2), lr_decay=DecayType.LINEAR, momentum=(0.95,0.85), momentum_decay=DecayType.LINEAR, wd_loss=False),
    TrainingPhase(**def_phase, epochs=10, lr=(1e-2,1e-3), lr_decay=DecayType.LINEAR, momentum=(0.85,0.95), momentum_decay=DecayType.LINEAR, wd_loss=False),
    TrainingPhase(**def_phase, epochs=4, lr=(1e-3,1e-5), lr_decay=DecayType.LINEAR, momentum=(0.95,0.98), momentum_decay=DecayType.LINEAR, wd_loss=False)
]

learn.fit_opt_sched(phases, loss_scale=512)

In [None]:
# Restore weights saved under the name 'att6-tta'. No matching save appears in
# this notebook; presumably it comes from an earlier session.
learn.load('att6-tta')

In [None]:
# Extra 4-epoch fine-tune with lr decaying linearly from .04 to .001.
# NOTE: `momentum=(0.95)` is a plain float, not a tuple, so no momentum range is
# given (presumably momentum stays at 0.95).
# Uses `def_phase` (opt_fn / wds) from whichever training cell ran last.
phases = [TrainingPhase(**def_phase, epochs=4, lr=(.04,.001), lr_decay=DecayType.LINEAR, momentum=(0.95))]
learn.fit_opt_sched(phases, data_list=[data], loss_scale=512)

In [None]:
# Same loaders as `data`, but not wrapped in DataPrefetcher (prefetcher=False).
tta_data = torch_loader(PATH, sz, prefetcher=False)

In [None]:
# Point the learner at the non-prefetching loaders before running TTA
# (data_ is presumably the attribute backing learn.data).
learn.data_ = tta_data

In [None]:
# TTA accuracy from a weighted blend of un-augmented and augmented predictions.
get_TTA_accuracy(learn)

In [None]:
# TTA accuracy from probabilities averaged over all TTA passes (printed).
get_TTA_accuracy_2(learn)