In [2]:
%load_ext autoreload
%autoreload 2 
#default_exp utils

In [3]:
#exporti
from fastai.tabular.all import * 
from tabnet.model import * 

In [4]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

# Model-creation functions

### Classifier

In [18]:
#export
@delegates(TabNetBase.__init__)
def TabNetClassifier(head_func, to, **kwargs):
    "Build a `TabNet` with head `head_func`, sized from `to` (embedding sizes, number of continuous columns, `to.c` outputs)."
    emb_sizes = get_emb_sz(to)
    n_continuous = len(to.cont_names)
    return TabNet(head_func, emb_szs=emb_sizes, n_cont=n_continuous, n_out=to.c, **kwargs)

### Self-supervised

In [19]:
#export
@delegates(TabNetBase.__init__)
def TabNetSelfSupervised(head_func, to, bs=1024, **kwargs):
    """Build a `TabNet` with head `head_func` whose output size matches its inputs.

    `n_out` is `len(get_emb_sz(to)) + len(to.cont_names)`, i.e. (presumably) one output
    per categorical and per continuous column, so the model can reconstruct its inputs.
    `bs` is accepted and discarded on purpose: callers such as `tabnet_df_self_sup`
    forward a kwargs dict that also carries the batch size, and it must not reach `TabNet`.
    """
    emb_szs = get_emb_sz(to)  # computed once, reused for the input and output sizes
    n_cont = len(to.cont_names)
    return TabNet(head_func, emb_szs=emb_szs, n_cont=n_cont, n_out=len(emb_szs) + n_cont, **kwargs)

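# Self-Supervised Data Loader

Each batch repeats its input tensors as targets, so the self-supervised model can be trained to reconstruct its own inputs.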

In [20]:
#export
class ReadTabBatchIdentity(ItemTransform):
    "Convert a tabular batch into `(cats[, conts], cats[, conts])`: the input tensors followed by an identical copy used as targets."
    def __init__(self, to): store_attr()

    def encodes(self, to):
        xs = [tensor(to.cats).long()]
        if to.with_cont: xs.append(tensor(to.conts).float())
        # Repeat the tensors: the leading copy is meant as model input and the trailing copy as the
        # reconstruction target (the caller sets `dls.n_inp` to match)
        batch = tuple(xs) * 2
        return batch if to.device is None else to_device(batch, to.device)

In [21]:
#export
class TabularPandasIdentity(TabularPandas):
    "`TabularPandas` subclass for self-supervised data; its `_dl_type` is pointed at `TabDataLoaderIdentity` once that loader is defined."

In [22]:
#export
@delegates()
class TabDataLoaderIdentity(TabDataLoader):
    "`TabDataLoader` whose default `after_batch` pipeline ends in `ReadTabBatchIdentity`, so every batch repeats its inputs as targets."
    do_item = noops

    def __init__(self, dataset, bs=16, shuffle=False, after_batch=None, num_workers=0, **kwargs):
        if after_batch is None:
            batch_tfms = L(TransformBlock().batch_tfms)
            after_batch = batch_tfms + ReadTabBatchIdentity(dataset)
        super().__init__(dataset, bs=bs, shuffle=shuffle, after_batch=after_batch,
                         num_workers=num_workers, **kwargs)

    # Rows for a batch are taken positionally from the underlying tabular object
    def create_batch(self, b): return self.dataset.iloc[b]

# Make `TabularPandasIdentity(...).dataloaders(...)` build this loader type
TabularPandasIdentity._dl_type = TabDataLoaderIdentity

# Experiment Helpers 

In [29]:
#export
def tabular_pandas(df, cat_names, cont_names, y_names, val_pct=0.2, tabular_type=TabularPandas, seed=None):
    """Wrap `df` in `tabular_type` with a random train/validation split.

    - `val_pct`: fraction of rows assigned to the validation set.
    - `tabular_type`: `TabularPandas` or a subclass (e.g. `TabularPandasIdentity`).
    - `seed`: optional integer; when given, the split is reproducible, so several
      objects built from the same `df` share identical train/validation rows.
      `None` (default) keeps the previous unseeded behaviour.

    Columns are preprocessed with `Categorify`, `FillMissing` and `Normalize`, and
    the target is wrapped in a `CategoryBlock`.
    """
    splits = RandomSplitter(valid_pct=val_pct, seed=seed)(range_of(df))
    return tabular_type(df, procs=[Categorify, FillMissing, Normalize], cat_names=cat_names,
                        cont_names=cont_names, y_names=y_names, splits=splits, y_block=CategoryBlock())

In [30]:
#export 
@delegates(TabNetClassifier)
def tabnet_df_classifier(df, cat_names, cont_names, y_names, cbs=None, enc=None, val_pct=0.2, tabnet_args=None, **kwargs):
    """Build a TabNet classification `Learner` (cross-entropy loss, accuracy metric) for `df`.

    - `tabnet_args`: dict of keyword arguments forwarded to `TabNetClassifier`
      (model hyper-parameters); `None` means no extra arguments.
    - `kwargs['bs']`: batch size for the dataloaders; must be supplied.
    - `kwargs['lambda_sparse']`: value passed to `MaskRegularizer`; must be supplied.
    - `enc`: optional encoder module that replaces `model.enc` (e.g. a pre-trained one).
    - `cbs`: extra callbacks, appended after `SetPrior` and `MaskRegularizer`.
    - `val_pct`: validation fraction for the random split.
    """
    to = tabular_pandas(df, cat_names, cont_names, y_names, val_pct=val_pct)
    dls = to.dataloaders(bs=kwargs['bs'])
    # `tabnet_args` is an explicit parameter now; it used to be a free name that only
    # resolved through leftover global state, while the caller's value was dropped.
    model = TabNetClassifier(linear_head, to, **(tabnet_args or {}))
    if enc is not None: model.enc = enc
    cbs = [SetPrior(), MaskRegularizer(kwargs['lambda_sparse']), *(cbs or [])]
    return Learner(dls, model, CrossEntropyLossFlat(), cbs=cbs, metrics=[accuracy])

In [31]:
#export
@delegates(TabNetSelfSupervised)
def tabnet_df_self_sup(df, cat_names, cont_names, y_names, cbs=None, curriculum=False, tabnet_args=None, mask_p=0.8, **kwargs):
    """Build a self-supervised TabNet `Learner` on `df`, trained with `TabularMasking` and `MaskReconstructionLoss`.

    - `kwargs['bs']`: batch size for the dataloaders; must be supplied.
    - `kwargs['lambda_sparse']`: value passed to `MaskRegularizer`; must be supplied.
    - `tabnet_args`: optional dict of model keyword arguments. It is merged over
      `kwargs` (its keys win), and the result is forwarded to `TabNetSelfSupervised`.
    - `mask_p`: `p` given to `TabularMasking` (default 0.8, the previous hard-coded value).
    - `curriculum`: forwarded to `TabularMasking`.
    - `cbs`: extra callbacks, appended after the built-in ones.
    """
    to = tabular_pandas(df, cat_names, cont_names, y_names, tabular_type=TabularPandasIdentity)
    dls = to.dataloaders(bs=kwargs['bs'])
    # Identity batches repeat their inputs; assuming `to.with_cont`, a batch is
    # (cats, conts, cats, conts), so the first two tensors are the model inputs.
    dls.n_inp = 2
    cbs = [SetPrior(), TabularMasking(p=mask_p, curriculum=curriculum),
           MaskRegularizer(kwargs['lambda_sparse']), *(cbs or [])]
    model_args = {**kwargs, **(tabnet_args or {})}
    model = TabNetSelfSupervised(tabnet_decoder, to, **model_args)
    return Learner(dls, model, cbs=cbs, loss_func=MaskReconstructionLoss())

In [32]:
#export 
@delegates(tabnet_df_self_sup)
def score_before_after_ss(df, ds_params, model_params, cycle_lr=((10, 1e-1/2),)*3, seed=None, **kwargs):
    """Compare a TabNet classifier trained from scratch with one whose encoder was pre-trained self-supervised.

    Steps:
    1. Train a classifier, then record `get_preds()`.
    2. Train a self-supervised model.
    3. Train a new classifier that reuses the self-supervised encoder, then record `get_preds()`.

    - `ds_params`: keyword arguments shared by both builders (column names, `bs`,
      `lambda_sparse`, ...).
    - `model_params`: passed to both builders as `tabnet_args`.
    - `cycle_lr`: three `(epochs, lr)` pairs for the three `fit_one_cycle` calls.
    - `seed`: seed applied with `set_seed` before each learner is built, so all three
      share one train/validation split (as long as they use the same `val_pct`).
      If `None`, one random seed is drawn and reused for all three builds.
      Note: this reseeds the global python/numpy/torch RNGs.
    - `kwargs`: forwarded to `tabnet_df_self_sup` (e.g. `curriculum`).

    Returns `(before, after)`, the `get_preds()` results of the two classifiers.
    """
    import random
    if seed is None: seed = random.randint(0, 2**31 - 1)

    # The random split is the first RNG consumer in each builder; reseeding makes the
    # before/after scores comparable (same validation rows).
    set_seed(seed)
    learn = tabnet_df_classifier(df, **ds_params, tabnet_args=model_params)
    learn.fit_one_cycle(*cycle_lr[0])
    before = learn.get_preds()

    set_seed(seed)
    learn_ss = tabnet_df_self_sup(df, **ds_params, tabnet_args=model_params, **kwargs)
    learn_ss.fit_one_cycle(*cycle_lr[1])

    set_seed(seed)
    learn = tabnet_df_classifier(df, **ds_params, tabnet_args=model_params, enc=learn_ss.model.enc)
    learn.fit_one_cycle(*cycle_lr[2])
    after = learn.get_preds()

    return (before, after)

# Export

In [33]:
from nbdev.export import notebook2script
# Export the `#export`/`#exporti` cells of the project notebooks to the library modules
# (the output below lists each converted notebook)
notebook2script()

Converted 01_core.ipynb.
Converted 02_model.ipynb.
Converted 03_experiments.ipynb.
Converted 04_utils.ipynb.
Converted index.ipynb.
