In [None]:
#default_exp model.modelmanager

In [None]:
#hide
#from collections.abc import Iterable,Iterator,Generator,Sequence
#from collections import OrderedDict,defaultdict,Counter,namedtuple


In [None]:
# Reload edited modules (e.g. the lib.* packages imported below) automatically while developing.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# ModelManager

> inspect, freeze/unfreeze, save/load and adapt (replace the head of) a model

In [None]:
# export

from lib.data.lists import *
from lib.model.model import *
from lib.learner.learner import *
from functools import partial
from typing import Callable

import torch
from torch import Tensor
from torch import nn


class ModelManager():
    """Convenience wrapper around an `nn.Module`: inspection, (un)freezing, persistence and head replacement.

    `freeze` and `adapt_model` index `self.model` like an `nn.Sequential` whose element 0
    is the body (this is the layout `adapt_model` produces) — TODO confirm for other models.
    """

    def __init__(self, model):
        # The wrapped model; `adapt_model` replaces it with a new nn.Sequential.
        self.model = model

    def find_modules(self, condition):
        """Return the submodules of `self.model` selected by `condition` (via `find_submodules`)."""
        return find_submodules(self.model, condition)

    @staticmethod
    def _run_in_eval(module, fn):
        """Call `fn()` with `module` in eval mode and autograd disabled, then restore every submodule's mode.

        Used for forward passes that only need outputs/shapes, so they do not update
        BatchNorm running statistics, apply dropout, or build an autograd graph.
        """
        # Record per-submodule modes so models with a mix of train/eval submodules are restored exactly.
        modes = [(m, m.training) for m in module.modules()]
        module.eval()
        try:
            with torch.no_grad():
                return fn()
        finally:
            for m, training in modes: m.training = training

    def summary(self, xb:Tensor, only_leaves=True, print_mod=False):
        """Run `xb` through the model and print the output shape of each hooked module.

        Args:
            xb: an input batch the model accepts.
            only_leaves: if True, hook every module found by `find_modules` that is neither an
                `nn.Sequential` nor a `ResBlock`; otherwise hook only the model's direct children.
            print_mod: print the full module repr instead of only its type.
        """
        def show(hook, mod, inp, out):
            if print_mod: print(f"\n{mod}\n{out.shape}")
            else:         print(f"{type(mod)} {out.shape}")

        if only_leaves:
            mods = self.find_modules(lambda m: not isinstance(m, (nn.Sequential, ResBlock)))
        else:
            mods = self.model.children()
        with Hooks(mods, show):
            self._run_in_eval(self.model, lambda: self.model(xb))

    def grads_summary(self):
        """Print the `requires_grad` status of each leaf module's own (non-recursive) parameters.

        Status is "True " (all trainable), "False" (all frozen), "Mixed" (some of each) or
        "None " (no direct parameters); every label is five characters so the output lines up.
        """
        modules = self.find_modules(condition=lambda m: not isinstance(m, nn.Sequential))
        for module in modules:
            if len(list(module.children())) != 0: continue  # report leaf modules only
            flags = [p.requires_grad for p in module.parameters(recurse=False)]
            if not flags:        status = "None "
            elif all(flags):     status = "True "
            elif not any(flags): status = "False"
            else:                status = "Mixed"
            print(f"requires_grad: {status} : {type(module).__name__}")

    def save(self, path, subdir="models", name="iw5"):
        """Save `self.model.state_dict()` to `path/subdir/name`, creating missing directories."""
        mdl_path = Path(path)/subdir
        mdl_path.mkdir(parents=True, exist_ok=True)
        torch.save(self.model.state_dict(), mdl_path/name)

    def load(self, path, subdir="models", name="iw5"):
        """Load a state dict written by `save` from `path/subdir/name` into `self.model`."""
        # NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
        st = torch.load(Path(path)/subdir/name)
        self.model.load_state_dict(st)

    @staticmethod
    def set_grad(module, requires_grad, train_bn=False):
        """Set `requires_grad` on `module`'s own (non-recursive) parameters.

        `nn.BatchNorm2d` modules are always left untouched (they keep whatever
        `requires_grad` they already had).
        NOTE(review): `train_bn` is accepted for API compatibility but currently ignored.
        """
        if isinstance(module, nn.BatchNorm2d): return
        for p in module.parameters(recurse=False):
            p.requires_grad_(requires_grad)

    def change_requires_grad_(self, modules, requires_grad, train_bn):
        """Apply `set_grad` to every non-Sequential submodule found under each of `modules`."""
        condition = lambda m: not isinstance(m, nn.Sequential)
        # Collect first, then modify — `find_submodules` is the module-level helper that
        # `find_modules` also uses (the class itself defines no `find_submodules`).
        selection = [sub for m in modules for sub in find_submodules(m, condition)]
        for sub in selection:
            ModelManager.set_grad(sub, requires_grad, train_bn)

    def freeze(self, train_bn=False):
        """Make `self.model[0]` (the body) non-trainable and `self.model[1:]` (the head) trainable."""
        self.change_requires_grad_([self.model[0]], requires_grad=False, train_bn=train_bn)
        self.change_requires_grad_(self.model[1:],  requires_grad=True,  train_bn=train_bn)

    def unfreeze(self, train_bn=False):
        """Make every child of `self.model` trainable (subject to `set_grad`'s BatchNorm2d skip)."""
        self.change_requires_grad_(self.model, requires_grad=True, train_bn=train_bn)

    def getFirstbatch(self, databunch:DataBunch, normalization:Callable):
        """Return `(xb, yb)` of one batch from `databunch` with `normalization` applied to `xb`.

        `normalization` is passed as the `tfm` of a `BatchTransformXCallback`; the batch is
        captured by `GetOneBatchCallback` during a one-epoch `Learner.fit` (presumably that
        callback stops the run after the first batch — TODO confirm in lib.learner).

        Raises:
            RuntimeError: if the `GetOneBatchCallback` cannot be found on the learner.
        """
        cbfs  = [partial(BatchTransformXCallback, tfm=normalization), GetOneBatchCallback]
        learn = Learner(self.model, databunch, loss_func=None)
        learn.fit(1, opt=None, cb_funcs=cbfs)
        cb    = learn.find_subcription_by_cls(GetOneBatchCallback)
        if cb is None:
            raise RuntimeError("GetOneBatchCallback not found on the learner; cannot fetch a batch")
        return cb.xb, cb.yb

    def adapt_model(self, databunch, normalization):
        """Replace the model's head with AdaptiveAvgPool2d(1) -> Flatten -> Linear(ni, databunch.c_out).

        The body kept is `self.model[:cut]`, where `cut` is the index of the first
        `nn.AdaptiveAvgPool2d` child. `ni` is dim 1 of the body's output on one batch
        fetched with `getFirstbatch`. (An `AdaptiveConcatPool2d` head would need `ni*2`.)

        Raises:
            ValueError: if the model has no `nn.AdaptiveAvgPool2d` child to cut at.
        """
        cut = next((i for i, o in enumerate(self.model.children())
                    if isinstance(o, nn.AdaptiveAvgPool2d)), None)
        if cut is None:
            raise ValueError("adapt_model: model has no nn.AdaptiveAvgPool2d child to cut at")
        m_cut = self.model[:cut]

        xb, _ = self.getFirstbatch(databunch, normalization)
        pred  = self._run_in_eval(m_cut, lambda: m_cut(xb))  # only the output shape is needed
        ni    = pred.shape[1]

        self.model = nn.Sequential(
            m_cut,
            nn.AdaptiveAvgPool2d(1),
            Flatten(),
            nn.Linear(ni, databunch.c_out),
        )

    def predict(self, input_data, tfm_input):
        """Return the model output for a copy of `input_data` transformed by `tfm_input`.

        Tensors are copied with `clone().detach()` (avoiding the warning `torch.tensor`
        emits for tensor input); other data goes through `torch.tensor`. The forward pass
        runs in eval mode with autograd disabled, so dropout is off and BatchNorm uses its
        running statistics; the model's previous modes are restored afterwards.
        """
        x = input_data.clone().detach() if isinstance(input_data, Tensor) else torch.tensor(input_data)
        return self._run_in_eval(self.model, lambda: self.model(tfm_input(x)))

class CnnModelManager(ModelManager):
    """ModelManager for convolutional models, adding weight initialisation."""

    def initialize(self, is_resnet:bool, uniform:bool=False, a=0.0, nonlinearity="relu"):
        """Delegate weight initialisation to `self.model.initialize(uniform, a)` for XResNet models.

        Any other model is left unchanged — including the plain `nn.Sequential` that
        `adapt_model` installs.
        NOTE(review): `is_resnet` and `nonlinearity` are accepted but not used.
        """
        if not isinstance(self.model, XResNet):
            return
        self.model.initialize(uniform, a)



# Tests

# Export scripts

In [None]:
# Export the `# export` cells of every notebook into the library module named by its
# `#default_exp` directive; the output lists each converted notebook.
from nbdev.export import notebook2script
notebook2script()

Converted 00_test.ipynb.
Converted 01_data.external.ipynb.
Converted 02_lists.ipynb.
Converted 03_images.ipynb.
Converted 05_Learner.ipynb.
Converted 05_model.ipynb.
Converted 06_modelmanger.ipynb.
Converted 07_optimizers.ipynb.
Converted app_image_01_mnist_optimizers.ipynb.
Converted app_image_02_imagenette_optimizers.ipynb.
Converted fin_01_candlestick.ipynb.
Converted fin_02_simfin_data.ipynb.
Converted index.ipynb.
