# Skip to content
# This repository has been archived by the owner. It is now read-only.
# Switch branches/tags
# Go to file
# Cannot retrieve contributors at this time
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/42_tabular.model.ipynb (unless otherwise specified).
# Public names exported by `from <this module> import *`; `_one_emb_sz` is intentionally private.
__all__ = ['emb_sz_rule', 'get_emb_sz', 'TabularModel', 'tabular_config']
# Cell
import torch
from ..torch_basics import *
from .core import *
# Cell
def emb_sz_rule(n_cat):
    "Heuristic embedding width `round(1.6 * n_cat**0.56)` for a variable with `n_cat` categories, capped at 600."
    suggested = round(1.6 * n_cat ** 0.56)
    # Never exceed 600 dimensions, however many categories there are.
    return suggested if suggested < 600 else 600
# Cell
def _one_emb_sz(classes, n, sz_dict=None):
    "Return `(n_cat, sz)` for column `n`: its number of classes and an embedding size from `sz_dict`, else from `emb_sz_rule`."
    overrides = {} if sz_dict is None else sz_dict
    n_cat = len(classes[n])
    default_sz = int(emb_sz_rule(n_cat))  # rule-of-thumb fallback
    return n_cat, overrides.get(n, default_sz)
# Cell
def get_emb_sz(to, sz_dict=None):
    "List one `(n_cat, emb_sz)` pair per name in `to.cat_names`, using `to.classes` and optional per-column overrides in `sz_dict`."
    sizes = []
    for cat_name in to.cat_names:
        sizes.append(_one_emb_sz(to.classes, cat_name, sz_dict))
    return sizes
# Cell
class TabularModel(Module):
    """Basic model for tabular data.

    Each categorical column is embedded, the continuous columns are optionally
    batch-normalised, both are concatenated along dim 1, and the result is fed
    through a stack of `LinBnDrop` blocks. Every hidden block uses a ReLU and the
    output block has no activation. If `y_range` is given, `SigmoidRange(*y_range)`
    is appended at the end.
    """
    def __init__(self, emb_szs, n_cont, out_sz, layers, ps=None, embed_p=0.,
                 y_range=None, use_bn=True, bn_final=False, bn_cont=True):
        # emb_szs: iterable of `(ni, nf)` pairs passed to `Embedding(ni, nf)`,
        #          e.g. the output of `get_emb_sz`.
        # n_cont:  number of continuous inputs (BatchNorm1d width, added to input width).
        # out_sz:  width of the final linear layer.
        # layers:  widths of the hidden layers.
        # ps:      dropout per hidden layer; None -> 0 everywhere, scalar -> broadcast.
        # embed_p: dropout applied to the concatenated embeddings.
        # y_range: arguments for the trailing `SigmoidRange`, or None for no range.
        # use_bn / bn_final: batchnorm flag for hidden layers / also for the final layer.
        # bn_cont: whether to batch-normalise the continuous inputs.
        # NOTE(review): there is no super().__init__() call; presumably `Module` from
        # torch_basics handles nn.Module initialisation. Confirm before changing.

        # Copy to lists so tuples (or other listy sequences) work with the `+` below.
        layers = list(layers)
        ps = ifnone(ps, [0]*len(layers))
        if not is_listy(ps): ps = [ps]*len(layers)
        ps = list(ps)
        self.embeds = nn.ModuleList([Embedding(ni, nf) for ni,nf in emb_szs])
        self.emb_drop = nn.Dropout(embed_p)
        self.bn_cont = nn.BatchNorm1d(n_cont) if bn_cont else None
        n_emb = sum(e.embedding_dim for e in self.embeds)
        self.n_emb,self.n_cont = n_emb,n_cont
        sizes = [n_emb + n_cont] + layers + [out_sz]
        # ReLU after every hidden layer; no activation on the output layer.
        actns = [nn.ReLU(inplace=True) for _ in range(len(sizes)-2)] + [None]
        # The output layer gets dropout 0 and batchnorm only when `bn_final` is set.
        # zip() silently truncates if `ps` is longer than `layers`.
        _layers = [LinBnDrop(sizes[i], sizes[i+1], bn=use_bn and (i!=len(actns)-1 or bn_final), p=p, act=a)
                   for i,(p,a) in enumerate(zip(ps+[0.],actns))]
        if y_range is not None: _layers.append(SigmoidRange(*y_range))
        self.layers = nn.Sequential(*_layers)

    def forward(self, x_cat, x_cont=None):
        "Embed `x_cat` column by column, join with (normalised) `x_cont`, and run the layer stack."
        if self.n_emb != 0:
            # Column i of x_cat goes through embedding i (assumes x_cat is
            # (batch, n_categorical) -- TODO confirm); results are joined on dim 1.
            x = [e(x_cat[:,i]) for i,e in enumerate(self.embeds)]
            x = torch.cat(x, 1)
            x = self.emb_drop(x)
        if self.n_cont != 0:
            if x_cont is None:
                raise ValueError("TabularModel was built with n_cont > 0 but `x_cont` is None")
            if self.bn_cont is not None: x_cont = self.bn_cont(x_cont)
            x = torch.cat([x, x_cont], 1) if self.n_emb != 0 else x_cont
        return self.layers(x)
# Cell
def tabular_config(**kwargs):
    "Collect the given keyword arguments into a plain dict to use as a tabular model config."
    config = dict(kwargs)
    return config