In [1]:
%load_ext autoreload
%autoreload 2 

In [2]:
from fastai.tabular.all import * 
from mock import Mock
from tabnet.utils import *
from tabnet.model import *

In [3]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

In [4]:
class LinDecoder(Module):
    """Plain MLP decoder head.

    Sums the input over dim 1, widens it through n_d -> 256 -> 512 -> 1024 with
    `LinBnDrop(..., act=Mish())` layers, then applies two bias-norm-free linear heads and
    returns `[categorical outputs (n_cat), continuous outputs (n_cont)]` concatenated on dim 1.
    Extra `**kwargs` are accepted and ignored.
    """
    def __init__(self, n_cat, n_cont, n_d, ps=0.1, **kwargs):
        store_attr()

        widths = [n_d, 256, 512, 1024]
        self.decoder = nn.Sequential(*[LinBnDrop(w_in, w_out, p=ps, act=Mish())
                                       for w_in, w_out in zip(widths[:-1], widths[1:])])

        # Output heads: no batchnorm and no activation on the 1024-wide representation.
        self.decoder_cont = nn.Sequential(LinBnDrop(1024, n_cont, p=ps, bn=False, act=None))
        self.decoder_cat = LinBnDrop(1024, n_cat, p=ps, bn=False, act=None)

    def forward(self, x):
        # Collapse dim 1 (presumably the encoder-step axis — TODO confirm) before decoding.
        hidden = self.decoder(x.sum(dim=1))
        cats, conts = self.decoder_cat(hidden), self.decoder_cont(hidden)
        return torch.cat([cats, conts], dim=1)

In [5]:
def _create_shared_blocks(n_in, n_out, n_shared):
    return [_initial_block(n_in, n_out)] + \
            [_rest_block(n_out) for _ in range(n_shared-1)]

def _initial_block(n_in, n_out):
    return nn.Linear(n_in, 2*n_out, bias=False)

def _rest_block(n):
    return nn.Linear(n, 2*n, bias=False)


In [6]:
class TabNetDec2(TabNetBase):
    """TabNet-style decoder made of `n_dec_steps` parallel steps whose outputs are summed.

    Each step is `self._create_feature_transform(shared)` (inherited from `TabNetBase`, not
    visible here) followed by `nn.Linear(n_d + n_a, n_cat + n_cont)`. One list of shared linear
    layers is built once and handed to every step's feature transform.
    """
    def __init__(self, n_cat, n_cont, n_d, n_a, n_shared_ft_blocks, n_dec_steps, **kwargs):
        store_attr()
        super().__init__(n_d=n_d, n_a=n_a, n_shared_ft_blocks=n_shared_ft_blocks, **kwargs)

        width = self.n_d + self.n_a
        shared = _create_shared_blocks(self.n_d, width, self.n_shared_ft_blocks)

        decoder_steps = []
        for _ in range(self.n_dec_steps):
            transform = self._create_feature_transform(shared)
            projection = nn.Linear(self.n_d + self.n_a, n_cat + n_cont)
            decoder_steps.append(nn.Sequential(transform, projection))
        self.steps = nn.ModuleList(decoder_steps)

    def forward(self, x):
        # Collapse dim 1 (presumably the encoder-step axis — TODO confirm), then add every
        # step's reconstruction of the same input.
        summed = x.sum(dim=1)
        return sum(step(summed) for step in self.steps)

In [7]:
class MRL1(Module):
    """Masked reconstruction loss with squared errors, normalised per feature column.

    `self.S` is never assigned in this class; the caller must set it before `forward` runs.
    Inputs are multiplied by `(1 - S)`, so it is presumably a 0/1 mask broadcastable to
    `preds` (TODO confirm where it is set). `lambda_reg` and `eps` are stored but unused.
    """
    def __init__(self, lambda_reg=1e-4, eps=1e-5):
        store_attr()

    def forward(self, preds, targ):
        keep = 1 - self.S
        masked_preds, masked_targ = preds * keep, targ * keep
        # Per-column L2 norm of the centred (masked) target.
        centred = masked_targ - masked_targ.mean(dim=0)
        col_norm = centred.pow(2).sum(dim=0).sqrt()
        # Skip (near-)constant columns so the division cannot produce inf/NaN.
        usable = col_norm >= 1e-6
        scaled = (masked_preds - masked_targ)[:, usable] / col_norm[usable]
        return scaled.pow(2).sum(dim=1).mean()

In [8]:
class MRL2(Module):
    """Masked reconstruction loss with absolute (L1) errors, normalised per feature column.

    `self.S` is never assigned in this class; the caller must set it before `forward` runs.
    Inputs are multiplied by `(1 - S)`, so it is presumably a 0/1 mask broadcastable to
    `preds` (TODO confirm where it is set). `lambda_reg` and `eps` are stored but unused.
    """
    def __init__(self, lambda_reg=1e-4, eps=1e-5): store_attr()

    def forward(self, preds, targ):
        "Row-mean of sum_j |error_j / norm_j| over the columns whose target norm is non-negligible."
        preds, targ = preds*(1-self.S), targ*(1-self.S)
        # Per-column L2 norm of the centred (masked) target: sqrt(sum_b (t_b - mean t)^2).
        norm = (targ - targ.mean(dim=0)).pow(2).sum(dim=0).sqrt()
        # A constant or fully masked column has norm 0. Dividing by it would turn the whole
        # batch loss into inf/NaN, so drop such columns, using the same threshold as MRL1.
        norm_mask = norm >= 1e-6
        error = (preds - targ)[:, norm_mask]
        loss = (error / norm[norm_mask]).abs().sum(dim=1).mean()
        return loss

# Tests

### SS Forest: self-supervised (SS) pretraining on the UCI Forest Covertype data

In [9]:
# Root folder for downloaded/extracted datasets, relative to the notebook's working directory.
data_dir = Path('data')

In [10]:
def extract_gzip(file, dest=None):
    """Decompress the gzip archive `file` into directory `dest`, naming the output `file.stem`.

    `file` and `dest` may be `str` or `Path`. When `dest` is None the output is written next to
    the archive. `dest` is created if missing. Returns the path of the decompressed file
    (callers that ignore the return value are unaffected).
    """
    import gzip
    import shutil
    file = Path(file)
    # Normalise dest: a missing dest falls back to the archive's folder, and a str dest becomes
    # a Path so the `/` join below works.
    dest = Path(dest) if dest is not None else file.parent
    dest.mkdir(parents=True, exist_ok=True)
    out_path = dest / file.stem
    with gzip.open(file, 'rb') as f_in, open(out_path, 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)
    return out_path

In [11]:
# UCI Covertype data ships as a single gzipped file (not a tarball), so a custom
# extract_func is passed to fastai's untar_data instead of its default extractor.
forest_type_url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz'
# forest_path is later read with pd.read_csv, so it is expected to point at the extracted
# data file (NOTE(review): confirm the exact path untar_data returns for this fastai version).
forest_path = untar_data(forest_type_url, dest=data_dir, extract_func=extract_gzip)

In [12]:
target = "Covertype"

# 4 Wilderness_Area columns followed by 40 Soil_Type columns, each numbered from 1.
cat_names = ([f"Wilderness_Area{i}" for i in range(1, 5)] +
             [f"Soil_Type{i}" for i in range(1, 41)])

cont_names = [
    "Elevation", "Aspect", "Slope",
    "Horizontal_Distance_To_Hydrology", "Vertical_Distance_To_Hydrology",
    "Horizontal_Distance_To_Roadways",
    "Hillshade_9am", "Hillshade_Noon", "Hillshade_3pm",
    "Horizontal_Distance_To_Fire_Points",
]

# The CSV is read with header=None, so this order must match the raw file's column order.
feature_columns = [*cont_names, *cat_names, target]

params = {'cont_names': cont_names, 'y_names': target, 'cat_names': cat_names}
procs = [Categorify, FillMissing, Normalize]
model_params = dict(
    n_d=64, n_a=64, n_steps=5, virtual_batch_size=512, gamma=1.5, bs=16 * 1024,
    lambda_sparse=1e-4, momentum=0.7, n_shared_ft_blocks=2, n_independent_ft_blocks=2,
    n_dec_steps=10, p=0.8, curriculum=True,
)

In [13]:
# Subsample 200k rows. The fixed random_state keeps the same subset across Restart & Run All,
# so results from different runs/sessions are computed on identical data.
df = pd.read_csv(forest_path, header=None, names=feature_columns).sample(n=200_000, random_state=42)
df.shape

(200000, 55)

In [14]:
# Default validation fraction; the sweep loop later reuses this name as its loop variable.
val_pct = 0.20

In [15]:
loss_func = MRL2()

# Alternative head (previously left commented out): LinDecoder(n_cat, n_cont, **kwargs).
def head(n_cat, n_cont, **kwargs):
    "Decoder-head factory passed as `decoder_head`: builds a `TabNetDec2` with `n_cat + n_cont` outputs."
    return TabNetDec2(n_cat, n_cont, **kwargs)

In [18]:
vals = []
# Each validation fraction is run twice, with repeats grouped together (same as sorted([...]*2)).
val_fracs = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6]
for val_pct in val_fracs:
    print(val_pct)
    # Rebuilt every iteration so each run gets a fresh schedule object.
    lr_schedule = [(350, slice(5e-3, 1e-1)),
                   (100, slice(1e-3, 1e-1/2)),
                   (350, slice(5e-3, 1e-1))]
    before, after = score_before_after_ss(df, params, val_pct=val_pct, decoder_head=head,
                                          loss_func=loss_func, cycle_lr=lr_schedule, **model_params)
    vals.append((before.item(), after.item(), val_pct))
    # Rewritten after every run so an interrupted sweep still leaves partial results on disk.
    res = pd.DataFrame(vals, columns=['before', 'after', 'val'])
    res.to_csv('forest_res.csv')

0.2
{'n_d': 64, 'n_a': 64, 'n_steps': 5, 'virtual_batch_size': 512, 'gamma': 1.5, 'bs': 16384, 'lambda_sparse': 0.0001, 'momentum': 0.7, 'n_shared_ft_blocks': 2, 'n_independent_ft_blocks': 2, 'n_dec_steps': 10, 'p': 0.8, 'curriculum': True}


epoch,train_loss,valid_loss,accuracy,time
0,2.925017,2.380807,0.482475,00:03
1,1.951404,1.10925,0.602825,00:03
2,1.526027,0.902281,0.663025,00:03
3,1.294128,0.826152,0.672075,00:03
4,1.142889,0.80364,0.67305,00:03
5,1.038442,0.780601,0.6786,00:03
6,0.962587,0.747324,0.685675,00:03
7,0.905106,0.726009,0.693925,00:03
8,0.860777,0.70204,0.694225,00:03
9,0.828547,0.711876,0.697775,00:03


epoch,train_loss,valid_loss,mse,time
0,0.670779,0.201447,0.24683,00:07
1,0.449109,0.169819,0.230805,00:07
2,0.343572,0.135362,0.245165,00:07
3,0.284154,0.129907,0.245388,00:07
4,0.246157,0.139271,0.249408,00:07
5,0.219997,0.154999,0.361067,00:07
6,0.2009,0.128274,0.239571,00:07
7,0.186258,0.160028,0.557504,00:07
8,0.176567,0.196912,1.811295,00:07
9,0.170934,0.159659,0.596846,00:07


epoch,train_loss,valid_loss,accuracy,time
0,2.549387,1.760486,0.50135,00:03
1,1.810809,1.657297,0.50635,00:03
2,1.505343,1.268475,0.522725,00:03
3,1.32632,2.15841,0.498225,00:03
4,1.203488,1.933436,0.437775,00:03
5,1.111365,1.67636,0.4897,00:03
6,1.039941,1.296988,0.58095,00:03
7,0.981495,1.802179,0.582375,00:03
8,0.935443,1.157109,0.6046,00:03
9,0.897912,0.996482,0.627625,00:03


0.2
{'n_d': 64, 'n_a': 64, 'n_steps': 5, 'virtual_batch_size': 512, 'gamma': 1.5, 'bs': 16384, 'lambda_sparse': 0.0001, 'momentum': 0.7, 'n_shared_ft_blocks': 2, 'n_independent_ft_blocks': 2, 'n_dec_steps': 10, 'p': 0.8, 'curriculum': True}


epoch,train_loss,valid_loss,accuracy,time
0,2.583025,2.126178,0.488575,00:03
1,1.782283,1.26156,0.488025,00:02
2,1.439053,1.08765,0.522725,00:02
3,1.251008,0.852971,0.646025,00:02
4,1.127322,0.900209,0.64045,00:02
5,1.034545,0.823823,0.63795,00:02
6,0.963329,0.77027,0.661975,00:02
7,0.908466,0.748801,0.67495,00:02
8,0.864825,0.703361,0.698975,00:02
9,0.829618,0.689838,0.702775,00:02


epoch,train_loss,valid_loss,mse,time
0,0.645279,0.275506,0.261162,00:05
1,0.431691,0.174811,0.233458,00:05
2,0.336277,0.149451,0.246447,00:05
3,0.279745,0.165573,0.266106,00:05
4,0.243146,0.133441,0.253561,00:05
5,0.216703,0.127497,0.233456,00:05
6,0.197901,0.127179,0.237474,00:05
7,0.185306,0.135227,0.270134,00:05
8,0.175313,0.145041,0.456044,00:05
9,0.169427,0.274154,7.038696,00:11


epoch,train_loss,valid_loss,accuracy,time
0,2.449894,1.454255,0.446075,00:02
1,1.734469,1.20031,0.40215,00:02
2,1.462893,1.484112,0.56365,00:02
3,1.314777,1.414143,0.56465,00:02
4,1.217202,1.110466,0.5783,00:02
5,1.143801,1.08254,0.558825,00:02
6,1.086506,1.147505,0.521175,00:02
7,1.040931,1.056988,0.56255,00:02
8,1.002486,1.06126,0.556575,00:02
9,0.968182,1.257258,0.51855,00:02


0.4
{'n_d': 64, 'n_a': 64, 'n_steps': 5, 'virtual_batch_size': 512, 'gamma': 1.5, 'bs': 16384, 'lambda_sparse': 0.0001, 'momentum': 0.7, 'n_shared_ft_blocks': 2, 'n_independent_ft_blocks': 2, 'n_dec_steps': 10, 'p': 0.8, 'curriculum': True}


epoch,train_loss,valid_loss,accuracy,time
0,2.969132,2.331822,0.5324,00:02
1,2.003845,1.060457,0.588925,00:02
2,1.559512,0.917039,0.65975,00:02
3,1.316756,0.860227,0.664,00:02
4,1.167167,0.831112,0.669312,00:02
5,1.063975,0.812016,0.669475,00:02
6,0.98908,0.772009,0.6726,00:02
7,0.930934,0.746983,0.678963,00:02
8,0.885457,0.720318,0.686725,00:02
9,0.849374,0.703174,0.69405,00:02


epoch,train_loss,valid_loss,mse,time
0,0.660002,0.277593,0.261891,00:05
1,0.444521,0.128585,0.219154,00:05
2,0.34056,0.129952,0.236121,00:06
3,0.282988,0.141752,0.220815,00:06
4,0.248398,0.122189,0.226764,00:06
5,0.223089,0.142206,0.233348,00:05
6,0.204267,0.120484,0.22685,00:05
7,0.189668,0.169576,0.640352,00:06
8,0.180505,0.144991,0.315387,00:05
9,0.174231,0.147165,0.589907,00:05


epoch,train_loss,valid_loss,accuracy,time
0,3.338326,2.355439,0.461688,00:02
1,2.429852,1.457226,0.508788,00:02
2,1.891433,1.13673,0.605613,00:02
3,1.573227,1.103606,0.590075,00:02
4,1.369468,0.985884,0.617837,00:02
5,1.231457,0.897083,0.646288,00:02
6,1.132755,0.888996,0.635938,00:02
7,1.056383,0.929368,0.639063,00:02
8,0.997056,0.830817,0.652675,00:02
9,0.950593,0.864093,0.623275,00:02


0.4
{'n_d': 64, 'n_a': 64, 'n_steps': 5, 'virtual_batch_size': 512, 'gamma': 1.5, 'bs': 16384, 'lambda_sparse': 0.0001, 'momentum': 0.7, 'n_shared_ft_blocks': 2, 'n_independent_ft_blocks': 2, 'n_dec_steps': 10, 'p': 0.8, 'curriculum': True}


epoch,train_loss,valid_loss,accuracy,time
0,2.967971,2.20824,0.490487,00:02
1,2.055822,1.241542,0.475313,00:02
2,1.630423,1.196448,0.486275,00:02
3,1.386012,1.202136,0.48715,00:02
4,1.231148,1.019376,0.562225,00:02
5,1.121989,0.931574,0.581863,00:02
6,1.042203,0.902316,0.606512,00:02
7,0.981393,0.791818,0.646312,00:02
8,0.932565,0.771254,0.6587,00:02
9,0.894304,0.739987,0.670712,00:02


epoch,train_loss,valid_loss,mse,time
0,0.682167,0.258411,0.258378,00:05
1,0.463773,0.197666,0.240301,00:05
2,0.354374,0.125845,0.23184,00:05
3,0.291175,0.145959,0.240079,00:05
4,0.251907,0.130105,0.227967,00:05
5,0.226554,0.128101,0.228244,00:05
6,0.207162,0.127944,0.238033,00:05
7,0.192085,0.137686,0.285756,00:05
8,0.181642,0.155742,0.378213,00:05
9,0.174849,0.151887,0.405793,00:05


epoch,train_loss,valid_loss,accuracy,time
0,2.867232,1.65669,0.471825,00:02
1,2.01367,1.246261,0.489587,00:02
2,1.654469,1.114069,0.514838,00:02
3,1.457713,1.088459,0.522637,00:02
4,1.331164,1.041282,0.521113,00:02
5,1.239544,1.018881,0.5236,00:02
6,1.167972,1.075996,0.509888,00:02
7,1.108273,1.167184,0.511675,00:02
8,1.055366,2.760416,0.368988,00:02
9,1.009129,3.207308,0.382212,00:02


0.6
{'n_d': 64, 'n_a': 64, 'n_steps': 5, 'virtual_batch_size': 512, 'gamma': 1.5, 'bs': 16384, 'lambda_sparse': 0.0001, 'momentum': 0.7, 'n_shared_ft_blocks': 2, 'n_independent_ft_blocks': 2, 'n_dec_steps': 10, 'p': 0.8, 'curriculum': True}


epoch,train_loss,valid_loss,accuracy,time
0,3.393209,3.126799,0.544033,00:01
1,2.569528,1.595583,0.612742,00:01
2,2.029011,1.055396,0.628625,00:01
3,1.706121,0.985521,0.611367,00:01
4,1.503671,0.903459,0.6494,00:01
5,1.361501,0.869727,0.656008,00:01
6,1.254836,0.84021,0.657408,00:01
7,1.170359,0.820876,0.666217,00:01
8,1.104354,0.811084,0.666717,00:01
9,1.050094,0.796153,0.664425,00:01


epoch,train_loss,valid_loss,mse,time
0,0.67963,0.249104,0.259025,00:05
1,0.455403,0.154037,0.29658,00:05
2,0.351248,0.144277,0.266775,00:05
3,0.291308,0.156525,0.307412,00:05
4,0.254108,0.130939,0.243092,00:05
5,0.225806,0.125354,0.238203,00:05
6,0.206988,0.151917,0.248082,00:05
7,0.191241,0.150027,0.369505,00:05
8,0.179403,0.138648,0.286234,00:05
9,0.17282,0.164837,0.628252,00:05


epoch,train_loss,valid_loss,accuracy,time
0,3.687016,2.711525,0.457708,00:01
1,3.113081,1.966787,0.465975,00:01
2,2.622377,1.511139,0.498858,00:01
3,2.25963,1.705456,0.482083,00:01
4,2.004467,1.220046,0.527158,00:01
5,1.823889,1.338458,0.4806,00:01
6,1.687479,1.54236,0.437525,00:01
7,1.579086,1.670842,0.419183,00:01
8,1.49031,1.588423,0.406733,00:01
9,1.415691,1.609312,0.408925,00:01


0.6
{'n_d': 64, 'n_a': 64, 'n_steps': 5, 'virtual_batch_size': 512, 'gamma': 1.5, 'bs': 16384, 'lambda_sparse': 0.0001, 'momentum': 0.7, 'n_shared_ft_blocks': 2, 'n_independent_ft_blocks': 2, 'n_dec_steps': 10, 'p': 0.8, 'curriculum': True}


epoch,train_loss,valid_loss,accuracy,time
0,3.224725,2.552136,0.417175,00:01
1,2.515308,1.888419,0.495967,00:01
2,2.041324,1.55881,0.488658,00:01
3,1.766384,1.258079,0.492692,00:01
4,1.579731,1.257125,0.48225,00:01
5,1.447387,1.16591,0.5052,00:01
6,1.348395,1.042867,0.5492,00:01
7,1.266287,1.018165,0.554592,00:01
8,1.19823,0.98439,0.551233,00:01
9,1.141161,0.89043,0.618983,00:01


epoch,train_loss,valid_loss,mse,time
0,0.663546,0.334052,0.275109,00:07
1,0.442413,0.124994,0.227834,00:07
2,0.340965,0.122547,0.229913,00:07
3,0.283443,0.123205,0.231778,00:07
4,0.246833,0.139525,0.2306,00:07
5,0.221015,0.117477,0.229261,00:07
6,0.201861,0.144614,0.36846,00:07
7,0.18733,0.150761,0.60034,00:07
8,0.177517,0.158123,0.428454,00:07
9,0.172187,0.173182,0.765062,00:07


epoch,train_loss,valid_loss,accuracy,time
0,3.817264,2.825176,0.296458,00:02
1,3.271517,2.249825,0.431217,00:02
2,2.762721,1.658904,0.46375,00:02
3,2.353474,2.024934,0.481075,00:02
4,2.069092,1.925478,0.377342,00:02
5,1.874271,2.594393,0.393733,00:02
6,1.734284,1.743225,0.50945,00:02
7,1.625026,2.019701,0.469758,00:02
8,1.537776,4.944553,0.455342,00:02
9,1.466219,4.118552,0.492033,00:02


# Export

In [None]:
from nbdev.export import notebook2script
# Export the project's notebook cells into the library package (nbdev v1 API; called with no
# argument it is expected to process every notebook — confirm for the installed nbdev version).
notebook2script()