In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from audio.audio import *
import fastai
import gc
import fastprogress
import inspect

In [3]:
def f_args(kwargs, *funcs):
    args = []
    for f in funcs:
        args += inspect.getfullargspec(f).args

    return keep_only(args, kwargs)

def keep_only(args, kwargs):
    dic = {}
    for k, v in kwargs.items(): 
        if k in args:
            dic[k] = v
    return dic

In [4]:
class progress_disabled():
    ''' Context manager to disable the progress update bar and Recorder print'''
    def __init__(self,learn:Learner):
        self.learn = learn
    def __enter__(self):
        fastprogress.fastprogress.NO_BAR = True
        fastai.basic_train.master_bar, fastai.basic_train.progress_bar = fastprogress.force_console_behavior()
        self.learn.callback_fns[0] = partial(Recorder,add_time=True,silent=True) #silence recorder
        
        return self.learn
    
    def __exit__(self,type,value,traceback):
        fastai.basic_train.master_bar, fastai.basic_train.progress_bar = master_bar,progress_bar
        self.learn.callback_fns[0] = partial(Recorder,add_time=True)

In [5]:
@dataclass
class HyperParameters():
    base_arch=models.resnet34
    freeze=True
    schedule=[(1,slice(1e-4, 1e-3))]
    silent_learn=False
    normalize=imagenet_stats
    size=224
    bs=32
    
    def __getitem__(self, i):
        return self.__getattribute__(i)

In [84]:
class Experiement:
    
    def __init__(self, params=HyperParameters()):
        self.params = params
        self.schedule = params.schedule
        self.silent = params.silent_learn
        self.freeze = params.freeze
        self.normalize = params.normalize
    
    def __str__(self):
        return str(self.params.__dict__)
    
    def create_transforms(self, **kwargs):
        d  = f_args(kwargs, get_transforms)
        return get_transforms(**d)
        
    def create_databunch(self, *args, **kwargs):
        raise NotImplemented
    
    def create_learner(self, db, **kwargs):
        d = f_args(kwargs, cnn_learner, Learner)
        return cnn_learner(db, **d)
    
    def fit(self, it, lr, l):
        l.fit_one_cycle(it, lr)

    def run(self):
        learn = self._gen_learn()
        if not self.freeze: learn.unfreeze()
        for it, lr in self.schedule:
            if self.silent:
                with progress_disabled(learn) as l: self.fit(it,lr,l)
            else:
                print(self)
                self.fit(it,lr,learn)
        return learn
    
    def _gen_learn(self):
        args = self.params.__dict__
        tfms = self.create_transforms(**args)
        db = self.create_databunch(tfms, **args)
        if self.normalize is not None:
            db = db.normalize(self.normalize)
        return self.create_learner(db, **args)
    
    def _one_try(self):
        try:
            
            learn = self._gen_learn()
            with progress_disabled(learn) as l:
                l.fit(1)
        except e:
            return e

In [86]:
_EXCLUDED = ['normalize']

def create_hparameters(base=HyperParameters(), **kwargs):
    vks = []

    for k, v in kwargs.items():
        if k in _EXCLUDED: setattr(base, k, v) 
        elif isinstance(v, (tuple, range)): vks.append(k)
        else: setattr(base, k, v)
    
    if len(vks) < 1: return [base]
    es = []
    k = vks[0]
    for v in kwargs[k]:
        b = deepcopy(base)
        setattr(b,k,v)
        args = deepcopy(kwargs)
        del args[k]
        es += create_hparameters(b, **args)
    return es

hps = create_hparameters(base_arch=(models.resnet18, models.resnet34), freeze=(True, False), size=128, metrics=accuracy, bs=32, normalize=(imagenet_stats))
for hp in hps:
    print(hp.__dict__)

{'size': 128, 'metrics': <function accuracy at 0x7f26b8f431e0>, 'bs': 32, 'normalize': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 'base_arch': <function resnet18 at 0x7f26b89fe2f0>, 'freeze': True}
{'size': 128, 'metrics': <function accuracy at 0x7f26b8f431e0>, 'bs': 32, 'normalize': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 'base_arch': <function resnet18 at 0x7f26b89fe2f0>, 'freeze': False}
{'size': 128, 'metrics': <function accuracy at 0x7f26b8f431e0>, 'bs': 32, 'normalize': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 'base_arch': <function resnet34 at 0x7f26b89fe730>, 'freeze': True}
{'size': 128, 'metrics': <function accuracy at 0x7f26b8f431e0>, 'bs': 32, 'normalize': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 'base_arch': <function resnet34 at 0x7f26b89fe730>, 'freeze': False}


In [77]:
path = untar_data(URLs.FLOWERS)
def process_csv(bunch):
    df = pd.read_csv(path/bunch, header=None)
    df['label'] = df[0].str.split().str[1]
    df[0] = df[0].str.split().str[0]
    return df
train_df = process_csv('train.txt')
train_df.head()

Unnamed: 0,0,label
0,jpg/image_03860.jpg,16
1,jpg/image_06092.jpg,13
2,jpg/image_02400.jpg,42
3,jpg/image_02852.jpg,55
4,jpg/image_07710.jpg,96


In [78]:
class MyExperiment(Experiement):
    
    def create_databunch(self, tfms, bs=2, size=224, **kwargs):
            d = f_args(kwargs, ImageList.from_df)
            data = ImageList.from_df(train_df, path, **d)
            data = data.split_by_rand_pct(.2, seed=2, **d)
            data = data.label_from_df(cols='label', **d)
            data = data.transform(tfms, size=size)
            data = data.databunch(bs=bs)
            return data


In [79]:
def highest_accuracy(_, metrics, __, M):
    best = max(metrics)
    if M is None: return best
    return best if best > M else None

def grid_search(hps, check_batch=True, best=highest_accuracy), save_best=False):
    losses, metrics, vals = [], [], []
    max_metric = 0
    exps = [MyExperiment(hp) for h in hps]
    if check_batch:
        for exp in exps: 
            gc.collect()
            e = exp._one_try()
            if e is not None:
                print("The following configuration threw an error")
                print(exp)
                print(e)
    M = None
    bestl = None
    for exp in exps:
        learn = exp.run() 
        losses += learn.recorder.losses
        metrics += learn.recorder.metrics
        vals += learn.recorder.val_losses
        res = best(losses, metrics, vals, M)
        if res is not None:
            bestl = learn
            M = res
    if save_best: bestl.export()
    return bestl, M

grid_search(hps, check_batch=False)

([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
{'size': 128, 'metrics': <function accuracy at 0x7f26b8f431e0>, 'bs': 32, 'normalize': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 'base_arch': <function resnet34 at 0x7f26b89fe730>, 'freeze': False}


epoch,train_loss,valid_loss,accuracy,time
0,4.700364,3.942179,0.181373,00:06


([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
{'size': 128, 'metrics': <function accuracy at 0x7f26b8f431e0>, 'bs': 32, 'normalize': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 'base_arch': <function resnet34 at 0x7f26b89fe730>, 'freeze': False}


epoch,train_loss,valid_loss,accuracy,time
0,4.726741,3.966913,0.156863,00:05


([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
{'size': 128, 'metrics': <function accuracy at 0x7f26b8f431e0>, 'bs': 32, 'normalize': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 'base_arch': <function resnet34 at 0x7f26b89fe730>, 'freeze': False}


epoch,train_loss,valid_loss,accuracy,time
0,4.787484,4.034906,0.122549,00:05


([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
{'size': 128, 'metrics': <function accuracy at 0x7f26b8f431e0>, 'bs': 32, 'normalize': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 'base_arch': <function resnet34 at 0x7f26b89fe730>, 'freeze': False}


epoch,train_loss,valid_loss,accuracy,time
0,4.7947,4.049545,0.166667,00:05


(Learner(data=ImageDataBunch;
 
 Train: LabelList (816 items)
 x: ImageList
 Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128)
 y: CategoryList
 16,42,55,96,5
 Path: /home/h/.fastai/data/oxford-102-flowers;
 
 Valid: LabelList (204 items)
 x: ImageList
 Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128)
 y: CategoryList
 77,64,24,31,9
 Path: /home/h/.fastai/data/oxford-102-flowers;
 
 Test: None, model=Sequential(
   (0): Sequential(
     (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
     (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (2): ReLU(inplace)
     (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
     (4): Sequential(
       (0): BasicBlock(
         (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
         (bn1): BatchNorm2d(64, eps