In [1]:
# fastai star import first, so the explicit imports below take precedence over anything it re-exports
from fastai.vision.all import *
from copy import deepcopy
from datetime import datetime
import torch
torch.cuda.is_available()  # sanity check: True means PyTorch can see a CUDA device for training

True

In [2]:
# Dataset root for the combined CIFAR10/MNIST images. Resolved from the current user's home
# directory rather than a hard-coded absolute path, so the notebook runs unchanged on other machines.
path = Path.home() / ".fastai" / "data" / "cifar10_mnist"

In [3]:
def _tag_after(f, marker):
    """Return the single character that follows the first `marker` in the *file name* of `f`.

    Only the file name is searched: directory names elsewhere in the path may also contain
    the marker (e.g. `_n` in `my_notes`) and must not be mistaken for the tag.
    Raises ValueError if the marker is missing or nothing follows it.
    """
    name = Path(f).name
    _, found, rest = name.partition(marker)
    if not found or not rest:
        raise ValueError(f"{marker!r} tag not found in file name {name!r}")
    return rest[0]

def label_func(f):
    "CIFAR label: the character right after `_y` in the file name (e.g. `img_y3_n7.png` -> '3')."
    return _tag_after(f, "_y")

def noise_func(f):
    "MNIST label: the character right after `_n` in the file name (e.g. `img_y3_n7.png` -> '7')."
    return _tag_after(f, "_n")

In [4]:
def get_dls(task="CIFAR", data_path=None, n_valid=10000, num_workers=4):
    """Build fastai DataLoaders over the combined CIFAR10/MNIST image folder.

    task:        "CIFAR" labels images with `label_func`; any other value uses `noise_func`.
    data_path:   directory scanned for images; defaults to the notebook-level `path`.
    n_valid:     the first `n_valid` items returned by `get_image_files` form the validation set.
    num_workers: number of DataLoader worker processes.
    """
    if data_path is None:
        data_path = path
    dblock = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        # NOTE(review): an index-based split relies on get_image_files returning files in a
        # stable order; confirm the first `n_valid` files really are the intended hold-out set.
        splitter=IndexSplitter(list(range(n_valid))),
        get_items=get_image_files,
        get_y=label_func if task == "CIFAR" else noise_func,
        batch_tfms=[Normalize],
        n_inp=1,
    )
    return dblock.dataloaders(data_path, num_workers=num_workers)

In [5]:
class SimpleProgress(Callback):
    "Minimal progress display: after every batch, emit one green bar glyph with no newline."
    def after_batch(self):
        tick = '\033[92m' + u"\u258F"  # ANSI green colour code + LEFT ONE EIGHTH BLOCK glyph
        print(tick, end="")

In [14]:
class RunMNIST(Callback):
    """After every epoch of the outer (CIFAR) run, fine-tune a snapshot of its model on MNIST.

    Per-epoch MNIST metrics go to `mnist_after_cifar_e{epoch}.csv`.
    """
    # Replacement head: create_head(512, 10), the same arguments used to inspect the CIFAR head.
    head_nf, head_n_out = 512, 10
    # MNIST fine-tuning schedule.
    mnist_epochs, mnist_lr = 3, 0.002

    def after_batch(self):
        # One blue tick per batch of the outer run.
        print('\033[94m' + u"\u258F", end="")

    def after_epoch(self):
        epoch = self.epoch
        # nn.Module has no detach()/clone() (those are tensor methods). A deep copy keeps the
        # MNIST fine-tune from modifying the weights the CIFAR run continues to train.
        model = deepcopy(self.learn.model)
        model[1] = create_head(self.head_nf, self.head_n_out)
        # cnn_learner expects an architecture constructor for `arch`, not an already-built model.
        # Wrap the copied model directly instead. default_split gives body/head parameter
        # groups, so fine_tune can still freeze the body for its first phase.
        mnist_learner = Learner(
            get_dls(task="MNIST"),
            model,
            loss_func=F.cross_entropy,
            metrics=accuracy,
            splitter=default_split,
            cbs=[CSVLogger(fname=f"mnist_after_cifar_e{epoch}.csv"), SimpleProgress],
        )
        with mnist_learner.no_bar():
            mnist_learner.fine_tune(epochs=self.mnist_epochs, base_lr=self.mnist_lr)

In [7]:
# resnet18 trained from scratch on the CIFAR labels. CSVLogger records the CIFAR metrics, and
# RunMNIST runs an MNIST fine-tune after every epoch.
cifar_learner = cnn_learner(
    dls=get_dls(task="CIFAR"),
    arch=resnet18,
    pretrained=False,
    loss_func=F.cross_entropy,
    metrics=accuracy,
    cbs=[CSVLogger(fname="cifar.csv"), RunMNIST],
)

In [12]:
# Stand-alone head with the same arguments RunMNIST uses, built only for inspection.
b = create_head(nf=512, n_out=10)

In [13]:
b  # display the stand-alone head to compare it with the head of the CIFAR model

Sequential(
  (0): AdaptiveConcatPool2d(
    (ap): AdaptiveAvgPool2d(output_size=1)
    (mp): AdaptiveMaxPool2d(output_size=1)
  )
  (1): Flatten(full=False)
  (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Dropout(p=0.25, inplace=False)
  (4): Linear(in_features=1024, out_features=512, bias=False)
  (5): ReLU(inplace=True)
  (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): Dropout(p=0.5, inplace=False)
  (8): Linear(in_features=512, out_features=10, bias=False)
)

In [11]:
cifar_learner.model[1]  # the head cnn_learner attached to the resnet18 body (the part RunMNIST replaces)

Sequential(
  (0): AdaptiveConcatPool2d(
    (ap): AdaptiveAvgPool2d(output_size=1)
    (mp): AdaptiveMaxPool2d(output_size=1)
  )
  (1): Flatten(full=False)
  (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Dropout(p=0.25, inplace=False)
  (4): Linear(in_features=1024, out_features=512, bias=False)
  (5): ReLU(inplace=True)
  (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): Dropout(p=0.5, inplace=False)
  (8): Linear(in_features=512, out_features=10, bias=False)
)

In [8]:
# Two epochs at lr=2e-3. fastai's progress bar is suppressed because the callbacks print their own ticks.
with cifar_learner.no_bar():
    cifar_learner.fit(n_epoch=2, lr=0.002)

[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94m▏[94

[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92

[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92

[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92

[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92m▏[92