In [1]:
%load_ext autoreload
%autoreload 2 

In [2]:
from fastai.tabular.all import * 
from mock import Mock
from tabnet.utils import *
from tabnet.model import *

In [3]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

In [4]:
class LinDecoder(Module):
    """MLP decoder head that emits categorical outputs followed by continuous outputs.

    The input is summed over dim 1, passed through three LinBnDrop layers
    (n_d -> 256 -> 512 -> 1024, Mish activation, dropout ``ps``), and then fed to two
    output projections built with ``bn=False, act=None``. The result is
    ``cat([cat_out (n_cat cols), cont_out (n_cont cols)], dim=1)``.

    NOTE(review): the summed dim 1 presumably indexes TabNet decision steps — confirm with the encoder.
    Extra ``**kwargs`` are accepted and ignored by the layers built here.
    """
    def __init__(self, n_cat, n_cont, n_d, ps=0.1, **kwargs):
        store_attr()

        widths = [n_d, 256, 512, 1024]
        # Hidden stack; layers are created in the same order as before (n_d->256, 256->512, 512->1024).
        self.decoder = nn.Sequential(*[
            LinBnDrop(n_in, n_out, p=ps, act=Mish())
            for n_in, n_out in zip(widths[:-1], widths[1:])
        ])
        # Output projections: no batch norm, no activation. The continuous head stays wrapped in a
        # one-element Sequential so parameter names (decoder_cont.0.*) are unchanged.
        self.decoder_cont = nn.Sequential(LinBnDrop(widths[-1], n_cont, p=ps, bn=False, act=None))
        self.decoder_cat = LinBnDrop(widths[-1], n_cat, p=ps, bn=False, act=None)

    def forward(self, x):
        hidden = self.decoder(x.sum(dim=1))
        # Categorical head first, matching the original call order (matters for dropout RNG draws).
        cat_out = self.decoder_cat(hidden)
        cont_out = self.decoder_cont(hidden)
        return torch.cat([cat_out, cont_out], dim=1)

In [5]:
def _create_shared_blocks(n_in, n_out, n_shared):
    return [_initial_block(n_in, n_out)] + \
            [_rest_block(n_out) for _ in range(n_shared-1)]

def _initial_block(n_in, n_out):
    return nn.Linear(n_in, 2*n_out, bias=False)

def _rest_block(n):
    return nn.Linear(n, 2*n, bias=False)


In [6]:
class TabNetDec2(TabNetBase):
    """Multi-step decoder head built from TabNet feature transformers.

    The input is summed over dim 1 and sent through ``n_dec_steps`` stacks, each of the form
    ``feature transform -> Linear(n_d + n_a, n_cat + n_cont)``. The step outputs are added together.
    Every step is given the same list of shared linear layers, so those weights are shared
    across steps (assuming ``_create_feature_transform`` uses the list as given — TODO confirm).
    """
    def __init__(self, n_cat, n_cont, n_d, n_a, n_shared_ft_blocks, n_dec_steps, **kwargs):
        store_attr()
        super().__init__(n_d=n_d, n_a=n_a, n_shared_ft_blocks=n_shared_ft_blocks, **kwargs)

        hidden = self.n_d + self.n_a
        shared = _create_shared_blocks(self.n_d, hidden, self.n_shared_ft_blocks)
        self.steps = nn.ModuleList([
            nn.Sequential(self._create_feature_transform(shared), nn.Linear(hidden, n_cat + n_cont))
            for _ in range(self.n_dec_steps)
        ])

    def forward(self, x):
        pooled = x.sum(dim=1)
        # sum() starts from 0, exactly like an explicit running total over the steps.
        return sum(step(pooled) for step in self.steps)

In [7]:
class MRL1(Module):
    """Normalized squared-error loss between ``preds`` and ``targ``.

    Both tensors are multiplied by ``1 - self.S``. Each column's error is then divided by
    ``sqrt(sum((targ - targ.mean(0))**2))`` for that column over the batch. Squared values are
    summed across columns and averaged over rows. Columns whose normaliser is below 1e-6 are dropped.

    NOTE(review): ``self.S`` is not set in this class; presumably a callback assigns it before the
    loss is called — confirm. ``lambda_reg`` and ``eps`` are stored but unused by ``forward``.
    Assumes 2-D ``(rows, columns)`` inputs, given the ``[:, mask]`` indexing.
    """
    def __init__(self, lambda_reg=1e-4, eps=1e-5): store_attr()

    def forward(self, preds, targ):
        keep = 1 - self.S
        preds, targ = preds * keep, targ * keep
        col_scale = (targ - targ.mean(dim=0)).pow(2).sum(dim=0).sqrt()
        # Drop columns with (near-)zero spread so the division stays finite.
        usable = col_scale >= 1e-6
        scaled_err = (preds - targ)[:, usable] / col_scale[usable]
        return scaled_err.pow(2).sum(dim=1).mean()

In [8]:
class MRL2(Module):
    """Normalized L1 loss between ``preds`` and ``targ`` (L1 counterpart of MRL1).

    Both tensors are multiplied by ``1 - self.S``. Each column's error is then divided by
    ``sqrt(sum((targ - targ.mean(0))**2))`` for that column over the batch. Absolute values are
    summed across columns and averaged over rows.

    Columns whose normaliser is below 1e-6 are skipped, matching MRL1. This covers columns that
    are constant within the batch, including columns fully zeroed by the ``1 - S`` factor.
    Dividing by a zero normaliser used to turn the whole loss into inf/NaN.

    NOTE(review): ``self.S`` is not set in this class; presumably a callback assigns it before the
    loss is called — confirm. ``lambda_reg`` and ``eps`` are stored but unused by ``forward``.
    Assumes 2-D ``(rows, columns)`` inputs, given the ``[:, mask]`` indexing.
    """
    def __init__(self, lambda_reg=1e-4, eps=1e-5): store_attr()

    def forward(self, preds, targ):
        preds, targ = preds*(1-self.S), targ*(1-self.S)
        norm = (targ - targ.mean(dim=0)).pow(2).sum(dim=0).sqrt()
        # Same threshold as MRL1: ignore (near-)constant columns instead of dividing by ~0.
        norm_mask = norm >= 1e-6
        error = (preds - targ)[:, norm_mask]

        loss = (error / norm[norm_mask]).abs().sum(dim=1).mean()
        return loss

# Tests

### Self-supervised pretraining: Forest Cover Type (UCI Covertype)

In [9]:
data_dir = Path('data')  # relative to the working directory; identical to Path('./data')

In [10]:
def extract_gzip(file, dest=None):
    """Decompress a single ``.gz`` file into a directory.

    Intended as the ``extract_func`` passed to ``untar_data``.

    Args:
        file: path (str or Path) to the gzip archive.
        dest: directory (str or Path) to write into. Defaults to the archive's own directory.
            It is created if missing.

    Returns:
        The Path of the decompressed file, ``dest / file.stem``
        (e.g. ``covtype.data.gz`` -> ``covtype.data``).
    """
    import gzip
    import shutil
    from pathlib import Path

    file = Path(file)
    # The old `dest or Path(dest)` called Path(None) when dest was None (TypeError) and left a
    # str dest as a str, so `dest / ...` failed. Normalise explicitly instead.
    dest = Path(dest) if dest is not None else file.parent
    dest.mkdir(parents=True, exist_ok=True)
    out_path = dest / file.stem
    with gzip.open(file, 'rb') as f_in, open(out_path, 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)
    return out_path

In [11]:
# UCI Covertype data ships as a single gzipped, header-less CSV, so the custom gzip extractor is used.
forest_type_url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz'
# Download into data_dir and decompress with extract_gzip; forest_path is later read with pd.read_csv.
forest_path = untar_data(forest_type_url, dest=data_dir, extract_func=extract_gzip)

In [12]:
target = "Covertype"

# 4 Wilderness_Area columns followed by 40 Soil_Type columns.
cat_names = ([f"Wilderness_Area{i}" for i in range(1, 5)] +
             [f"Soil_Type{i}" for i in range(1, 41)])

cont_names = [
    "Elevation", "Aspect", "Slope",
    "Horizontal_Distance_To_Hydrology", "Vertical_Distance_To_Hydrology",
    "Horizontal_Distance_To_Roadways",
    "Hillshade_9am", "Hillshade_Noon", "Hillshade_3pm",
    "Horizontal_Distance_To_Fire_Points",
]

# Column order of the header-less CSV: continuous, then categorical, then the target.
feature_columns = [*cont_names, *cat_names, target]

params = {"cont_names": cont_names, "y_names": target, "cat_names": cat_names}
procs = [Categorify, FillMissing, Normalize]
# NOTE(review): `procs` is not part of `params`, so it is not handed to tabular_pandas via **params — confirm it is needed.
model_params = {
    'n_d': 64, 'n_a': 64, 'n_steps': 5,
    'virtual_batch_size': 512, 'gamma': 1.5, 'bs': 1024 * 16,
    'lambda_sparse': 1e-4, 'momentum': 0.7,
    'n_shared_ft_blocks': 2, 'n_independent_ft_blocks': 2,
    'n_dec_steps': 10,
    'p': 0.8,  # used as TabularMasking's `p`
}

In [13]:
# The raw file has no header row; take a fixed-seed 200k-row subsample so runs are reproducible.
df = pd.read_csv(forest_path, header=None, names=feature_columns).sample(n=200_000, random_state=42)
df.shape

(200000, 55)

In [14]:
# Validation fraction for the pretraining split, and whether TabularMasking uses a curriculum.
val_pct, curriculum = 0.2, False

In [15]:
to = tabular_pandas(df, **params, tabular_type=TabularPandasIdentity, val_pct=val_pct)
# Normalized L1 loss on the decoder output.
loss_func = MRL2()

def head(n_cat, n_cont, **kwargs):
    """Decoder factory used for pretraining.

    `LinDecoder(n_cat, n_cont, **kwargs)` is the MLP alternative with the same call shape.
    """
    return TabNetDec2(n_cat, n_cont, **kwargs)

In [17]:
# Sweep over large validation fractions (few labelled training rows), 3 repeats each.
# Presumably `before`/`after` are classifier scores without/with self-supervised pretraining.
vals = []
# Loop over `pct` rather than `val_pct` so the global `val_pct` (used to build `to`) is not clobbered.
for pct in sorted([0.8, 0.9, 0.99]*3):
    print(pct)
    (before, after) = score_before_after_ss(df, params, val_pct=pct, decoder_head=head, 
                        loss_func=loss_func, cycle_lr=[(250, slice(5e-3, 1e-1)), (30, slice(1e-3, 1e-1/2)),
                                                       (250, slice(5e-3, 1e-1))], **model_params)
    vals.append((before.item(), after.item(), pct))
    # Rewrite the CSV after every run so partial results survive an interrupted sweep.
    res = pd.DataFrame(vals, columns=['before','after','val'])
    res.to_csv('forest_res.csv')

0.8
{'n_d': 64, 'n_a': 64, 'n_steps': 5, 'virtual_batch_size': 512, 'gamma': 1.5, 'bs': 16384, 'lambda_sparse': 0.0001, 'momentum': 0.7, 'n_shared_ft_blocks': 2, 'n_independent_ft_blocks': 2, 'n_dec_steps': 10, 'p': 0.8}


epoch,train_loss,valid_loss,accuracy,time
0,4.139518,3.737472,0.076131,00:01
1,3.725219,3.341648,0.313619,00:01
2,3.334324,2.997022,0.429888,00:01
3,2.977854,2.546309,0.476969,00:01
4,2.638077,1.815441,0.495012,00:01
5,2.358522,1.351173,0.487125,00:01
6,2.141742,1.252292,0.481375,00:01
7,1.972914,1.241264,0.489013,00:01
8,1.838879,1.235695,0.489038,00:01
9,1.730273,1.204323,0.488831,00:01


epoch,train_loss,valid_loss,mse,time
0,0.433419,0.224169,0.264376,00:06
1,0.327559,0.343282,4.049806,00:06
2,0.260606,0.181392,0.814618,00:06
3,0.236425,0.201597,1.787496,00:06
4,0.218239,0.210605,6.216951,00:06
5,0.207673,0.136232,0.317053,00:06
6,0.201271,0.370284,71.247108,00:06
7,0.192511,0.166269,1.196614,00:06
8,0.18573,0.16358,0.428983,00:06
9,0.18051,0.166755,0.441369,00:06


epoch,train_loss,valid_loss,accuracy,time
0,4.061196,3.542257,0.26755,00:01
1,3.757657,1.895187,0.447587,00:01
2,3.454114,1.969438,0.367231,00:01
3,3.145954,1.92448,0.363963,00:01
4,2.839163,1.828862,0.414181,00:01
5,2.565533,1.648348,0.472731,00:01
6,2.342066,1.881508,0.425687,00:01
7,2.166826,1.656682,0.433781,00:01
8,2.02941,1.534977,0.477206,00:01
9,1.915222,1.474438,0.481969,00:01


0.8
{'n_d': 64, 'n_a': 64, 'n_steps': 5, 'virtual_batch_size': 512, 'gamma': 1.5, 'bs': 16384, 'lambda_sparse': 0.0001, 'momentum': 0.7, 'n_shared_ft_blocks': 2, 'n_independent_ft_blocks': 2, 'n_dec_steps': 10, 'p': 0.8}


epoch,train_loss,valid_loss,accuracy,time
0,3.726172,3.58581,0.337281,00:01
1,3.250528,2.408999,0.551031,00:01
2,2.845424,1.92609,0.464575,00:01
3,2.475776,1.649496,0.561069,00:01
4,2.186068,1.281489,0.5364,00:01
5,1.973571,1.130179,0.572688,00:01
6,1.81288,1.048169,0.5957,00:01
7,1.685146,1.019445,0.578094,00:01
8,1.582782,1.033322,0.620663,00:01
9,1.499103,1.07184,0.610575,00:01


epoch,train_loss,valid_loss,mse,time
0,0.38947,0.273165,0.308772,00:06
1,0.281935,0.116235,0.260398,00:06
2,0.233775,0.145821,0.280364,00:06
3,0.213038,0.161714,0.285398,00:06
4,0.209578,0.194656,9.771667,00:06
5,0.210321,0.167695,1.074916,00:06
6,0.203976,0.209994,3.978366,00:06
7,0.200576,0.194375,1.010305,00:06
8,0.197356,0.24851,6.281077,00:06
9,0.194695,0.179911,0.408046,00:06


epoch,train_loss,valid_loss,accuracy,time
0,3.997751,3.210176,0.085094,00:01
1,3.749769,3.086391,0.018,00:01
2,3.506187,2.012424,0.344356,00:01
3,3.256727,2.20423,0.387994,00:01
4,2.998519,1.625342,0.424444,00:01
5,2.745578,1.588727,0.497019,00:01
6,2.517602,1.939551,0.420369,00:01
7,2.333917,2.020789,0.438713,00:01
8,2.18776,2.912214,0.405944,00:01
9,2.067704,2.209891,0.415269,00:01


0.8
{'n_d': 64, 'n_a': 64, 'n_steps': 5, 'virtual_batch_size': 512, 'gamma': 1.5, 'bs': 16384, 'lambda_sparse': 0.0001, 'momentum': 0.7, 'n_shared_ft_blocks': 2, 'n_independent_ft_blocks': 2, 'n_dec_steps': 10, 'p': 0.8}


epoch,train_loss,valid_loss,accuracy,time
0,3.774495,3.430426,0.448856,00:01
1,3.355254,3.434314,0.484437,00:01
2,2.970603,2.275056,0.488687,00:01
3,2.640136,1.949161,0.494669,00:01
4,2.356958,1.462575,0.475406,00:01
5,2.135918,1.323517,0.46895,00:01
6,1.962436,1.226029,0.473981,00:01
7,1.828547,1.185676,0.537425,00:01
8,1.719703,1.152342,0.524806,00:01
9,1.628528,1.12934,0.549956,00:01


epoch,train_loss,valid_loss,mse,time
0,0.40954,0.226167,0.358604,00:06
1,0.300474,0.138199,0.314854,00:06
2,0.244769,0.127898,0.394979,00:06
3,0.220534,0.156833,0.399586,00:06
4,0.209007,0.187931,2.073673,00:06
5,0.202595,0.207638,2.988817,00:06
6,0.201268,0.249086,3.332977,00:06
7,0.197788,0.189968,0.657509,00:06
8,0.197826,0.179951,0.868157,00:06
9,0.19234,0.151371,0.434526,00:06


epoch,train_loss,valid_loss,accuracy,time
0,4.301208,8.199581,0.000244,00:01
1,4.026732,3.102846,0.464269,00:01
2,3.774995,2.245693,0.485519,00:01
3,3.528883,1.91049,0.488394,00:01
4,3.283266,2.361567,0.490194,00:01
5,3.029251,2.477825,0.489487,00:01
6,2.783129,3.75542,0.490294,00:01
7,2.565883,2.998173,0.486475,00:01
8,2.384658,3.569346,0.468769,00:01
9,2.236729,2.859421,0.461794,00:01


0.9
{'n_d': 64, 'n_a': 64, 'n_steps': 5, 'virtual_batch_size': 512, 'gamma': 1.5, 'bs': 16384, 'lambda_sparse': 0.0001, 'momentum': 0.7, 'n_shared_ft_blocks': 2, 'n_independent_ft_blocks': 2, 'n_dec_steps': 10, 'p': 0.8}


epoch,train_loss,valid_loss,accuracy,time
0,4.108922,3.75726,0.039644,00:01
1,3.892401,3.592696,0.326378,00:01
2,3.677664,3.39091,0.410372,00:01
3,3.483351,3.133496,0.453456,00:01
4,3.286992,2.929617,0.488811,00:01
5,3.081979,2.25169,0.476306,00:01
6,2.874218,2.013922,0.508917,00:01
7,2.680626,1.701505,0.512189,00:01
8,2.512136,1.475123,0.504828,00:01
9,2.366914,1.394394,0.496717,00:01


epoch,train_loss,valid_loss,mse,time
0,0.43638,0.274154,0.296914,00:06
1,0.321497,0.207905,0.271735,00:06
2,0.270775,0.126897,0.230755,00:06
3,0.236852,0.144304,0.239685,00:06
4,0.219969,0.179269,2.307996,00:06
5,0.211396,0.25376,3.297386,00:06
6,0.200483,0.177804,3.239417,00:06
7,0.190901,0.219955,2.289304,00:06
8,0.181709,0.187552,0.554123,00:06
9,0.179652,0.245513,5.113589,00:06


epoch,train_loss,valid_loss,accuracy,time
0,4.095952,3.062075,0.362672,00:01
1,3.891111,4.142721,0.247006,00:01
2,3.713277,6.711753,0.085033,00:01
3,3.531996,3.873977,0.475278,00:01
4,3.346938,4.208906,0.473372,00:01
5,3.196444,3.454185,0.463339,00:01
6,3.03231,4.841044,0.479244,00:01
7,2.861942,5.483731,0.486789,00:01
8,2.692423,3.335411,0.457139,00:01
9,2.53593,4.174023,0.378694,00:01


0.9
{'n_d': 64, 'n_a': 64, 'n_steps': 5, 'virtual_batch_size': 512, 'gamma': 1.5, 'bs': 16384, 'lambda_sparse': 0.0001, 'momentum': 0.7, 'n_shared_ft_blocks': 2, 'n_independent_ft_blocks': 2, 'n_dec_steps': 10, 'p': 0.8}


epoch,train_loss,valid_loss,accuracy,time
0,4.513237,4.079873,0.000167,00:01
1,4.304941,3.804252,0.053606,00:01
2,4.099479,3.648759,0.188611,00:01
3,3.889983,3.276885,0.372283,00:01
4,3.691117,3.224535,0.433378,00:01
5,3.474872,2.638498,0.397778,00:01
6,3.251215,2.373159,0.397417,00:01
7,3.037549,1.891449,0.397828,00:01
8,2.841826,1.64501,0.406083,00:01
9,2.667726,1.375621,0.445728,00:01


epoch,train_loss,valid_loss,mse,time
0,0.442565,0.262776,0.391719,00:06
1,0.338352,0.238861,0.266655,00:06
2,0.275996,0.18276,0.24259,00:06
3,0.247917,0.148118,0.241771,00:06
4,0.229503,0.676445,111.28331,00:06
5,0.217078,0.196504,0.711122,00:06
6,0.212503,0.249932,1.2992,00:06
7,0.200471,0.201026,0.855883,00:06
8,0.188616,0.168555,0.564935,00:06
9,0.181197,0.168578,0.286142,00:06


epoch,train_loss,valid_loss,accuracy,time
0,4.11638,4.125823,0.373067,00:01
1,3.960918,2.532278,0.486944,00:01
2,3.799617,3.450465,0.472144,00:01
3,3.64151,6.149631,0.404183,00:01
4,3.477764,5.070448,0.3699,00:01
5,3.305907,5.594799,0.437506,00:01
6,3.13439,4.069432,0.515933,00:01
7,2.962692,9.650477,0.4169,00:01
8,2.791376,5.287096,0.492622,00:01
9,2.635165,7.346716,0.520383,00:01


0.9
{'n_d': 64, 'n_a': 64, 'n_steps': 5, 'virtual_batch_size': 512, 'gamma': 1.5, 'bs': 16384, 'lambda_sparse': 0.0001, 'momentum': 0.7, 'n_shared_ft_blocks': 2, 'n_independent_ft_blocks': 2, 'n_dec_steps': 10, 'p': 0.8}


epoch,train_loss,valid_loss,accuracy,time
0,4.144164,3.677304,0.086672,00:01
1,3.935677,3.505528,0.303061,00:01
2,3.721814,3.173599,0.471878,00:01
3,3.47997,2.62145,0.484428,00:01
4,3.229294,2.503603,0.483139,00:01
5,2.971328,1.897852,0.490189,00:01
6,2.741946,1.746477,0.478983,00:01
7,2.54547,1.516796,0.481683,00:01
8,2.384405,1.413778,0.479272,00:01
9,2.251055,1.329707,0.463606,00:01


epoch,train_loss,valid_loss,mse,time
0,0.404577,0.194979,0.269156,00:06
1,0.317575,0.242182,0.57898,00:06
2,0.264727,0.223519,3.070946,00:06
3,0.229812,0.132729,0.282196,00:06
4,0.213152,0.130171,0.250881,00:06
5,0.202135,0.147389,0.239597,00:06
6,0.192508,0.216254,2.014632,00:06
7,0.195644,0.239412,3.342099,00:06
8,0.190008,0.245973,2.205659,00:06
9,0.185829,0.177898,0.562191,00:06


epoch,train_loss,valid_loss,accuracy,time
0,4.286608,6.24538,6e-06,00:01
1,4.063598,3.430273,0.000311,00:01
2,3.834165,2.891346,0.055467,00:01
3,3.6206,2.111374,0.431661,00:01
4,3.422087,2.151081,0.478067,00:01
5,3.234396,1.84001,0.487211,00:01
6,3.048969,1.606295,0.476456,00:01
7,2.870111,1.368573,0.497367,00:01
8,2.698742,1.263747,0.483717,00:01
9,2.542425,1.475217,0.480983,00:01


0.99
{'n_d': 64, 'n_a': 64, 'n_steps': 5, 'virtual_batch_size': 512, 'gamma': 1.5, 'bs': 16384, 'lambda_sparse': 0.0001, 'momentum': 0.7, 'n_shared_ft_blocks': 2, 'n_independent_ft_blocks': 2, 'n_dec_steps': 10, 'p': 0.8}


epoch,train_loss,valid_loss,accuracy,time
0,3.821472,3.696147,0.364924,00:01
1,3.438746,3.162863,0.523222,00:01
2,3.087744,2.072742,0.607778,00:01
3,2.758304,1.871698,0.621061,00:01
4,2.469301,1.527684,0.612131,00:01
5,2.245564,1.532975,0.602717,00:01
6,2.063101,1.155545,0.621985,00:01
7,1.916866,1.001468,0.661313,00:01
8,1.800308,0.94603,0.657763,00:01
9,1.699147,0.934956,0.642283,00:01


epoch,train_loss,valid_loss,mse,time
0,0.384979,0.220591,0.651068,00:06
1,0.278645,0.179378,0.763952,00:06
2,0.235253,0.160273,0.53369,00:06
3,0.213647,0.154689,1.202471,00:06
4,0.206146,0.24783,9.778967,00:06
5,0.203627,0.162769,0.978642,00:06
6,0.203162,0.169747,0.747024,00:06
7,0.211615,0.203619,2.140245,00:06
8,0.20752,0.175232,0.2802,00:06
9,0.199723,0.188502,0.782293,00:06


epoch,train_loss,valid_loss,accuracy,time
0,4.233669,4.176957,0.032354,00:01
1,3.885441,2.91602,0.248258,00:01
2,3.54676,2.099547,0.373025,00:01
3,3.209785,1.851298,0.408354,00:01
4,2.907145,1.549081,0.355712,00:01
5,2.636059,1.732926,0.369677,00:01
6,2.41211,1.827572,0.369727,00:01
7,2.235273,1.454508,0.492141,00:01
8,2.094964,1.467043,0.465222,00:01
9,1.979247,1.382216,0.422212,00:01


0.99
{'n_d': 64, 'n_a': 64, 'n_steps': 5, 'virtual_batch_size': 512, 'gamma': 1.5, 'bs': 16384, 'lambda_sparse': 0.0001, 'momentum': 0.7, 'n_shared_ft_blocks': 2, 'n_independent_ft_blocks': 2, 'n_dec_steps': 10, 'p': 0.8}


epoch,train_loss,valid_loss,accuracy,time
0,3.845534,3.793951,0.482692,00:01
1,3.494403,2.841597,0.493566,00:01
2,3.184952,2.464523,0.484232,00:01
3,2.906916,2.342381,0.489899,00:01
4,2.651425,1.925036,0.488576,00:01
5,2.419274,1.437495,0.495217,00:01
6,2.227436,1.303676,0.466808,00:01
7,2.074139,1.198937,0.490773,00:01
8,1.952275,1.20642,0.495616,00:01
9,1.847792,1.154836,0.553015,00:01


epoch,train_loss,valid_loss,mse,time
0,0.42811,0.165234,0.270907,00:06
1,0.328746,0.120202,0.249819,00:06
2,0.272107,0.22207,1.277551,00:06
3,0.240351,0.144639,0.27212,00:06
4,0.217708,0.143709,0.346313,00:06
5,0.206628,0.234741,4.375148,00:06
6,0.201165,0.160777,0.87466,00:06
7,0.190227,0.134083,0.396544,00:06
8,0.184265,0.208931,2.976059,00:06
9,0.187685,0.19395,0.949008,00:06


epoch,train_loss,valid_loss,accuracy,time
0,3.724135,3.172584,0.100071,00:01
1,3.365587,1.969381,0.489798,00:01
2,3.033286,1.856953,0.483136,00:01
3,2.741342,2.569924,0.489126,00:01
4,2.497225,1.776818,0.488061,00:01
5,2.300479,1.75182,0.488439,00:01
6,2.130813,2.227268,0.488152,00:01
7,1.994123,1.698404,0.487338,00:01
8,1.877038,1.542839,0.493308,00:01
9,1.780722,1.333005,0.484061,00:01


0.99
{'n_d': 64, 'n_a': 64, 'n_steps': 5, 'virtual_batch_size': 512, 'gamma': 1.5, 'bs': 16384, 'lambda_sparse': 0.0001, 'momentum': 0.7, 'n_shared_ft_blocks': 2, 'n_independent_ft_blocks': 2, 'n_dec_steps': 10, 'p': 0.8}


epoch,train_loss,valid_loss,accuracy,time
0,3.89645,3.79501,0.47802,00:01
1,3.560219,3.01498,0.478919,00:01
2,3.214176,2.417084,0.487859,00:01
3,2.895096,2.000299,0.478712,00:01
4,2.612025,1.744881,0.436086,00:01
5,2.375031,1.503954,0.478606,00:01
6,2.193384,1.339937,0.468788,00:01
7,2.051649,1.286672,0.442298,00:01
8,1.931222,1.244841,0.443308,00:01
9,1.831763,1.251488,0.452455,00:01


epoch,train_loss,valid_loss,mse,time
0,0.474374,0.191109,0.261088,00:07
1,0.340417,0.151098,0.428898,00:06
2,0.275203,0.195647,1.137054,00:06
3,0.251974,0.226837,2.023103,00:06
4,0.231769,0.160876,0.301772,00:06
5,0.211401,0.173226,1.950643,00:06
6,0.21114,0.183071,0.539062,00:06
7,0.205109,0.180771,0.946773,00:06
8,0.198333,0.160334,0.702493,00:06
9,0.191213,0.194855,0.374489,00:06


epoch,train_loss,valid_loss,accuracy,time
0,3.724545,3.996199,0.03352,00:01
1,3.381918,1.896366,0.418152,00:01
2,3.024107,1.802779,0.47649,00:01
3,2.70599,1.755567,0.486667,00:01
4,2.433202,2.122566,0.488071,00:01
5,2.216664,2.166627,0.488162,00:01
6,2.040185,1.874195,0.487783,00:01
7,1.90456,1.676913,0.476884,00:01
8,1.796764,1.510167,0.478581,00:01
9,1.707939,1.385555,0.451146,00:01


In [None]:
# Same sweep with smaller validation fractions (more labelled rows) and a longer LR schedule.
vals = []
# Loop over `pct` rather than `val_pct` so the global `val_pct` (used to build `to`) is not clobbered.
for pct in sorted([0.2, 0.4, 0.6]*3):
    print(pct)
    (before, after) = score_before_after_ss(df, params, val_pct=pct, decoder_head=head, 
                        loss_func=loss_func, cycle_lr=[(350, slice(5e-3, 1e-1)), (100, slice(1e-3, 1e-1/2)),
                                                       (350, slice(5e-3, 1e-1))], **model_params)
    vals.append((before.item(), after.item(), pct))
    # Rewrite the CSV after every run so partial results survive an interrupted sweep.
    res = pd.DataFrame(vals, columns=['before','after','val'])
    res.to_csv('forest_res_06.csv')

0.2
{'n_d': 64, 'n_a': 64, 'n_steps': 5, 'virtual_batch_size': 512, 'gamma': 1.5, 'bs': 16384, 'lambda_sparse': 0.0001, 'momentum': 0.7, 'n_shared_ft_blocks': 2, 'n_independent_ft_blocks': 2, 'n_dec_steps': 10, 'p': 0.8}


epoch,train_loss,valid_loss,accuracy,time
0,2.542612,2.047883,0.48175,00:03
1,1.723152,1.324158,0.488,00:03
2,1.411138,1.289773,0.488,00:03
3,1.243166,1.210041,0.487925,00:03
4,1.132474,1.132369,0.510275,00:03
5,1.053638,1.080519,0.5469,00:03
6,0.992569,0.98555,0.58605,00:03
7,0.94292,0.889092,0.607125,00:03
8,0.902542,0.789889,0.64315,00:03
9,0.867982,0.757093,0.666575,00:03


epoch,train_loss,valid_loss,mse,time
0,0.425444,0.192531,0.259698,00:06
1,0.296967,0.184311,0.274223,00:06
2,0.242917,0.165568,0.628115,00:06
3,0.210719,0.196862,0.36514,00:06
4,0.189669,0.142374,0.477136,00:06
5,0.171603,0.123528,0.311544,00:06
6,0.16167,0.120942,0.385832,00:06
7,0.152687,0.156968,1.675259,00:06
8,0.147407,0.150677,0.954363,00:06
9,0.145116,0.201453,2.610399,00:06


epoch,train_loss,valid_loss,accuracy,time
0,2.893256,1.534404,0.489175,00:03
1,2.06,1.247031,0.4828,00:03
2,1.676297,1.211532,0.4513,00:03
3,1.462592,1.113515,0.479275,00:03
4,1.327069,1.312574,0.478175,00:03
5,1.231087,1.325644,0.490625,00:03
6,1.158819,1.888572,0.357275,00:03
7,1.102047,1.890496,0.3255,00:03
8,1.057727,1.356754,0.4806,00:03
9,1.021686,1.721212,0.47825,00:03


0.2
{'n_d': 64, 'n_a': 64, 'n_steps': 5, 'virtual_batch_size': 512, 'gamma': 1.5, 'bs': 16384, 'lambda_sparse': 0.0001, 'momentum': 0.7, 'n_shared_ft_blocks': 2, 'n_independent_ft_blocks': 2, 'n_dec_steps': 10, 'p': 0.8}


epoch,train_loss,valid_loss,accuracy,time
0,2.511116,1.714741,0.491875,00:03
1,1.713044,1.224997,0.484375,00:03
2,1.388932,1.243315,0.49225,00:03
3,1.205593,1.187258,0.495225,00:03
4,1.08995,1.05739,0.521675,00:03
5,1.007218,0.903597,0.58345,00:03
6,0.945434,0.801752,0.632575,00:03
7,0.898721,0.753362,0.657825,00:03
8,0.861343,0.728651,0.675575,00:03
9,0.830243,0.761109,0.663075,00:03


epoch,train_loss,valid_loss,mse,time
0,0.427182,0.152538,0.251744,00:06
1,0.30156,0.169692,0.281264,00:06
2,0.242689,0.107352,0.233324,00:06
3,0.203358,0.114165,0.2336,00:06
4,0.182859,0.104486,0.259688,00:06
5,0.166996,0.114525,0.252742,00:06
6,0.155385,0.105283,0.234046,00:06
7,0.148638,0.170666,0.726533,00:06
8,0.145308,0.133794,0.812536,00:06
9,0.140801,0.124067,0.465987,00:06


epoch,train_loss,valid_loss,accuracy,time
0,3.1194,1.790334,0.364875,00:03
1,2.293689,2.592206,0.39775,00:03
2,1.85807,2.406851,0.337125,00:03
3,1.625405,2.0029,0.296425,00:03
4,1.481557,2.445015,0.2768,00:03
5,1.382728,1.308672,0.4584,00:03
6,1.309696,1.194678,0.50045,00:03
7,1.252189,1.108769,0.535925,00:03
8,1.207957,1.288518,0.556425,00:03
9,1.172304,1.612916,0.50975,00:03


0.2
{'n_d': 64, 'n_a': 64, 'n_steps': 5, 'virtual_batch_size': 512, 'gamma': 1.5, 'bs': 16384, 'lambda_sparse': 0.0001, 'momentum': 0.7, 'n_shared_ft_blocks': 2, 'n_independent_ft_blocks': 2, 'n_dec_steps': 10, 'p': 0.8}


epoch,train_loss,valid_loss,accuracy,time
0,2.797851,2.315633,0.4841,00:03
1,1.914398,1.336477,0.483675,00:03
2,1.553723,1.274852,0.4907,00:03
3,1.354516,1.230929,0.4918,00:03
4,1.222311,1.178519,0.511675,00:03
5,1.127417,1.143088,0.53605,00:03
6,1.056412,1.03752,0.566025,00:03
7,1.00148,0.92336,0.603575,00:03
8,0.956178,0.816556,0.644525,00:03
9,0.916994,0.740257,0.6823,00:03


epoch,train_loss,valid_loss,mse,time
0,0.44129,0.264464,0.289075,00:06
1,0.310948,0.175302,0.252907,00:06
2,0.248991,0.100514,0.250075,00:06
3,0.210834,0.119019,0.237687,00:06
4,0.187722,0.117496,0.246469,00:06
5,0.172652,0.098166,0.234636,00:06
6,0.161081,0.111802,0.257463,00:06
7,0.154311,0.128344,0.427807,00:06
8,0.14969,0.151649,0.567499,00:06
9,0.146121,0.218064,4.295249,00:06


epoch,train_loss,valid_loss,accuracy,time
0,3.331735,1.733756,0.395725,00:03
1,2.495712,2.881155,0.45965,00:03
2,1.991578,1.65104,0.388,00:03
3,1.716689,1.284178,0.470375,00:03
4,1.538463,1.688674,0.4485,00:03
5,1.408174,1.739173,0.468575,00:03
6,1.311865,10.15593,0.377375,00:03
7,1.23932,4.284959,0.51275,00:03
8,1.182898,3.329475,0.5054,00:03
9,1.138292,2.323338,0.511625,00:03


0.4
{'n_d': 64, 'n_a': 64, 'n_steps': 5, 'virtual_batch_size': 512, 'gamma': 1.5, 'bs': 16384, 'lambda_sparse': 0.0001, 'momentum': 0.7, 'n_shared_ft_blocks': 2, 'n_independent_ft_blocks': 2, 'n_dec_steps': 10, 'p': 0.8}


epoch,train_loss,valid_loss,accuracy,time
0,2.658454,1.952277,0.478837,00:02
1,1.859892,1.267524,0.49775,00:02
2,1.52006,1.093912,0.583438,00:02
3,1.318324,0.932787,0.638963,00:02
4,1.178398,0.892951,0.650125,00:02
5,1.076821,0.856163,0.656887,00:02
6,1.000306,0.772874,0.676337,00:02
7,0.942217,0.72733,0.684362,00:02
8,0.895449,0.709951,0.68795,00:02
9,0.857286,0.692686,0.696113,00:02


epoch,train_loss,valid_loss,mse,time
0,0.441267,0.359997,0.353373,00:06
1,0.315038,0.160771,0.281045,00:06
2,0.249769,0.124134,0.32288,00:06
3,0.21518,0.133957,0.364263,00:06
4,0.190354,0.109746,0.318783,00:06
5,0.173169,0.124863,0.31441,00:06
6,0.160857,0.124602,0.348287,00:06
7,0.15381,0.143329,0.580632,00:06
8,0.148448,0.154374,0.484542,00:06
9,0.147051,0.172225,2.29053,00:06


epoch,train_loss,valid_loss,accuracy,time
0,2.576403,1.551624,0.514625,00:02
1,1.823246,1.19885,0.473938,00:02
2,1.508004,1.028284,0.538063,00:02
3,1.334645,0.997051,0.531587,00:02
4,1.221432,0.988863,0.546037,00:02
5,1.138862,0.979631,0.556387,00:02
6,1.074136,1.038468,0.580725,00:02
7,1.021301,1.171168,0.542813,00:02
8,0.976058,1.385954,0.525563,00:02
9,0.937849,1.034886,0.54845,00:02


0.4
{'n_d': 64, 'n_a': 64, 'n_steps': 5, 'virtual_batch_size': 512, 'gamma': 1.5, 'bs': 16384, 'lambda_sparse': 0.0001, 'momentum': 0.7, 'n_shared_ft_blocks': 2, 'n_independent_ft_blocks': 2, 'n_dec_steps': 10, 'p': 0.8}


epoch,train_loss,valid_loss,accuracy,time
0,2.620463,1.910622,0.597763,00:02
1,1.828932,1.087404,0.650625,00:02
2,1.488951,1.043166,0.61885,00:02
3,1.28785,0.976299,0.64135,00:02
4,1.151634,0.896488,0.651062,00:02
5,1.054657,0.825653,0.673513,00:02
6,0.983582,0.781652,0.681125,00:02
7,0.928156,0.751486,0.68,00:02
8,0.885678,0.722891,0.690838,00:02
9,0.85041,0.713059,0.689462,00:02


epoch,train_loss,valid_loss,mse,time
0,0.392891,0.165495,0.288526,00:06
1,0.286281,0.141862,0.290374,00:06
2,0.233244,0.128718,0.26076,00:06
3,0.202359,0.186258,0.965069,00:06
4,0.182222,0.21487,1.265433,00:06
5,0.169104,0.159013,0.462345,00:06
6,0.157859,0.171892,1.112427,00:06
7,0.149343,0.126485,0.352058,00:06
8,0.144212,0.178613,2.737664,00:06
9,0.144147,0.16574,1.365534,00:06


epoch,train_loss,valid_loss,accuracy,time
0,3.320244,2.203493,0.4821,00:02
1,2.423777,2.950368,0.485963,00:02
2,1.925048,1.78564,0.414925,00:02
3,1.645736,1.269273,0.471613,00:02
4,1.466301,1.291502,0.44845,00:02
5,1.342537,1.368177,0.4486,00:02
6,1.249179,1.190985,0.514238,00:02
7,1.178002,1.26084,0.497963,00:02
8,1.120072,1.193889,0.5545,00:02
9,1.070763,1.160099,0.540925,00:02


0.4
{'n_d': 64, 'n_a': 64, 'n_steps': 5, 'virtual_batch_size': 512, 'gamma': 1.5, 'bs': 16384, 'lambda_sparse': 0.0001, 'momentum': 0.7, 'n_shared_ft_blocks': 2, 'n_independent_ft_blocks': 2, 'n_dec_steps': 10, 'p': 0.8}


epoch,train_loss,valid_loss,accuracy,time
0,2.661966,2.190047,0.5449,00:03
1,1.917904,1.307454,0.51565,00:03
2,1.562584,1.012354,0.593138,00:03
3,1.359061,1.08857,0.543025,00:03
4,1.220735,0.882322,0.658475,00:03
5,1.116665,0.817358,0.662425,00:03
6,1.035688,0.799727,0.666488,00:03
7,0.973237,0.789786,0.662063,00:03
8,0.923665,0.748376,0.666925,00:03
9,0.88381,0.722434,0.685875,00:03


epoch,train_loss,valid_loss,mse,time
0,0.434761,0.23211,0.292904,00:07
1,0.319688,0.128304,0.224837,00:07
2,0.26083,0.113026,0.230751,00:07
3,0.222063,0.12429,0.228309,00:07
4,0.195268,0.115404,0.231656,00:07
5,0.179796,0.115347,0.24376,00:07
6,0.170902,0.116304,0.240628,00:07
7,0.161745,0.124504,0.440394,00:07
8,0.152833,0.123808,0.353338,00:07
9,0.145854,0.131982,0.831792,00:07


epoch,train_loss,valid_loss,accuracy,time
0,3.170588,1.821355,0.48855,00:02
1,2.212634,1.931304,0.458162,00:02
2,1.768544,1.786685,0.473388,00:02
3,1.529771,1.469266,0.505638,00:02
4,1.377277,1.347237,0.541188,00:02
5,1.272344,1.144675,0.554213,00:02
6,1.1931,1.188084,0.52205,00:02
7,1.131038,0.994093,0.561363,00:02
8,1.079189,1.172684,0.4886,00:02
9,1.036173,1.628623,0.449713,00:02


0.6
{'n_d': 64, 'n_a': 64, 'n_steps': 5, 'virtual_batch_size': 512, 'gamma': 1.5, 'bs': 16384, 'lambda_sparse': 0.0001, 'momentum': 0.7, 'n_shared_ft_blocks': 2, 'n_independent_ft_blocks': 2, 'n_dec_steps': 10, 'p': 0.8}


epoch,train_loss,valid_loss,accuracy,time
0,3.422367,3.029757,0.546683,00:02
1,2.728855,1.866964,0.553158,00:02
2,2.198314,1.220078,0.595183,00:02
3,1.875152,1.107026,0.581667,00:02
4,1.661659,1.000857,0.600617,00:02
5,1.511032,0.958371,0.621633,00:02
6,1.393942,0.887998,0.645333,00:02
7,1.301815,0.911618,0.621383,00:02
8,1.226161,0.874133,0.654017,00:02
9,1.163633,0.883079,0.630933,00:02


epoch,train_loss,valid_loss,mse,time
0,0.43101,0.366943,0.42578,00:06
1,0.314404,0.126471,0.236889,00:06
2,0.253264,0.139644,0.247976,00:06
3,0.215529,0.090911,0.22881,00:06
4,0.191207,0.104359,0.226914,00:06
5,0.172839,0.12133,0.267857,00:06
6,0.160866,0.109259,0.269155,00:06
7,0.155842,0.133652,0.35903,00:06
8,0.150346,0.165643,0.820758,00:06
9,0.145975,0.13057,0.4345,00:06


epoch,train_loss,valid_loss,accuracy,time
0,3.595074,3.894434,0.487783,00:02
1,3.008435,2.317753,0.488008,00:02
2,2.511711,4.921638,0.488008,00:02
3,2.175171,1.844493,0.488108,00:02
4,1.953831,1.867211,0.492008,00:02
5,1.802418,2.667583,0.513133,00:02
6,1.694982,1.371579,0.399458,00:02
7,1.613646,1.467972,0.492158,00:02
8,1.548957,1.490534,0.502983,00:02
9,1.494435,1.491483,0.484942,00:02


0.6
{'n_d': 64, 'n_a': 64, 'n_steps': 5, 'virtual_batch_size': 512, 'gamma': 1.5, 'bs': 16384, 'lambda_sparse': 0.0001, 'momentum': 0.7, 'n_shared_ft_blocks': 2, 'n_independent_ft_blocks': 2, 'n_dec_steps': 10, 'p': 0.8}


epoch,train_loss,valid_loss,accuracy,time
0,3.572642,3.470832,0.478025,00:02
1,2.818635,1.862957,0.50865,00:02
2,2.239708,1.311928,0.546108,00:02
3,1.883254,1.117377,0.592558,00:02
4,1.648524,1.042188,0.622225,00:02
5,1.482383,0.971415,0.632175,00:02
6,1.358013,1.056541,0.606983,00:02
7,1.261214,1.04173,0.608242,00:02
8,1.184683,0.960777,0.634175,00:02
9,1.124302,0.992899,0.627767,00:02


epoch,train_loss,valid_loss,mse,time
0,0.487357,0.292296,0.349379,00:06
1,0.335233,0.161421,0.234944,00:06
2,0.275791,0.113262,0.237314,00:06
3,0.23245,0.122289,0.229871,00:06
4,0.20332,0.09888,0.23097,00:06
5,0.183196,0.11211,0.236135,00:06
6,0.168965,0.122917,0.249181,00:06
7,0.159593,0.133535,0.269581,00:06
8,0.157052,0.212621,0.646718,00:06
9,0.155405,0.153611,0.381058,00:06


In [None]:
# Manual self-supervised pretraining run on the `to` split built earlier, for inspection.
dls = to.dataloaders(bs=model_params['bs'])
# NOTE(review): presumably the first two batch elements (cat, cont) are the model inputs — confirm.
dls.n_inp = 2
cbs = [SetPrior(), TabularMasking(p=model_params['p'], curriculum=curriculum),
       MaskRegularizer(model_params['lambda_sparse'])]
model = TabNetSelfSupervised(head, to, **model_params)

In [None]:
learn = Learner(dls, model, cbs=cbs, loss_func=loss_func, metrics=[mse])

In [None]:
learn.lr_find()

In [None]:
learn.fit_one_cycle(30, slice(1e-3, 1e-1/2))

In [None]:
learn.recorder.plot_loss()

In [None]:
# Predicts on the full `df` (training and validation rows together), so this MSE is not a held-out score.
dl = learn.dls.test_dl(df)
preds, targs = learn.get_preds(dl=dl)
mse(preds, targs)

In [None]:
# Side-by-side predicted vs target values for the first row (assumes 2-D preds/targs).
for p, t in zip(preds[0].tolist(), targs[0].tolist()):
    print(p, t)

In [None]:
# Copy of model_params with a smaller virtual batch size for the classifier run.
mp = {**model_params, 'virtual_batch_size':100}

In [None]:
# Classifier trained on 1% of rows (val_pct=0.99), given the pretrained encoder `learn.model.enc`.
l = tabnet_df_classifier(df, **params, enc=learn.model.enc, val_pct=0.99, **mp)

In [None]:
# Small training splits can have fewer rows than the configured batch size. Fall back to half the
# training set, but never below 1 (a zero batch size would break the DataLoader).
l.dls.train.bs = max(1, l.dls.train.n // 2) if l.dls.train.n < model_params['bs'] else model_params['bs']

In [None]:
# Learning-rate search, then training and loss curves for the classifier built above.
l.lr_find()

In [None]:
l.fit_one_cycle(100, slice(5e-3, 1e-1))

In [None]:
l.recorder.plot_loss()

### Self-supervised pretraining: Iris flowers

In [None]:
# Iris: four continuous measurements plus the class label.
# NOTE: this overwrites the forest `df`, `cont_names` and `cat_names` globals.
data = load_iris()
X, y = data['data'], data['target']
cont_names, cat_names, y_names = ['s_len', 's_wid', 'p_len', 'p_wid'], [], 'target'
df = pd.concat(
    [pd.DataFrame(X, columns=cont_names), pd.DataFrame(y, columns=[y_names])],
    axis=1,
)

In [None]:
model_params = dict(n_d=16, n_a=16, n_steps=3, virtual_batch_size=4, gamma=1.5, bs=32,
                    lambda_sparse=1e-4, momentum=0.7, n_shared_ft_blocks=2, n_independent_ft_blocks=0, p=0.3)
val_pct = 0.2
curriculum = False
# The decoder factory must take (n_cat, n_cont, **kwargs), like the forest head, and forward both
# counts to LinDecoder(n_cat, n_cont, n_d, ...). The old `lambda n_out, ...` passed only one of them.
head = lambda n_cat, n_cont, **kwargs: LinDecoder(n_cat, n_cont, **kwargs)
loss_func = MSELossFlat()

In [None]:
# NOTE(review): positional (cat_names, cont_names, y_names) here vs keyword **params in the forest section — confirm argument order.
to = tabular_pandas(df, cat_names, cont_names, y_names, tabular_type=TabularPandasIdentity, val_pct=val_pct)
dls = to.dataloaders(bs=model_params['bs'])
dls.n_inp = 2
cbs = [SetPrior(), TabularMasking(p=model_params['p'], curriculum=curriculum),
       MaskRegularizer(model_params['lambda_sparse'])]
model = TabNetSelfSupervised(head, to, **model_params)

In [None]:
learn = Learner(dls, model, cbs=cbs, loss_func=loss_func, metrics=[mse])

In [None]:
# Run on the CPU.
learn.dls.cpu()

In [None]:
learn.fit_one_cycle(30, slice(1e-4, 1e-3))

In [None]:
dl = learn.dls.test_dl(df)
preds, targs = learn.get_preds(dl=dl)

In [None]:
# First 10 predicted vs target values of the first row (assumes 2-D preds/targs).
preds[0][:10], targs[0][:10]

### Baseline: supervised fastai tabular model (Iris)

In [None]:
# Supervised fastai tabular baseline on Iris with a random train/valid split; target treated as a category.
splits = RandomSplitter()(range_of(df))
dls = TabularPandas(df, cat_names=[], cont_names=['s_len', 's_wid', 'p_len', 'p_wid'], splits=splits,
              procs=[Categorify, FillMissing, Normalize], y_names='target', y_block=CategoryBlock()).dataloaders(bs=4)

In [None]:
learn = tabular_learner(dls, metrics=[accuracy])

In [None]:
learn.lr_find()

In [None]:
learn.fit_one_cycle(20, 1e-2)

### TabNet

In [None]:
# Supervised TabNet classifier on Iris (no pretraining).
# NOTE(review): virtual_batch_size (512) exceeds bs (30); confirm ghost batch norm handles this as intended.
model_params = dict(n_d=16, n_a=16, n_steps=5, virtual_batch_size=512, gamma=1.5, bs=30,
                    lambda_sparse=1e-4, momentum=0.7, n_shared_ft_blocks=2, n_independent_ft_blocks=2)

In [None]:
learn = tabnet_df_classifier(df, cat_names=[], cont_names=['s_len', 's_wid', 'p_len', 'p_wid'],
                                                     y_names='target', val_pct=0.2, **model_params)

In [None]:
learn.summary()

In [None]:
learn.lr_find()

In [None]:
learn.fit_one_cycle(10, 1e-2)

# Export

In [None]:
# Convert the project's notebooks into library modules with nbdev.
from nbdev.export import notebook2script
notebook2script()