# SAINT Experiments

Experimenting with the [SAINT](https://github.com/somepago/saint) implementation and bringing it into fastai

In [1]:
!pip install fastai -U >> /dev/null

In [2]:
from fastai.tabular.all import *

In [3]:
path = untar_data(URLs.ADULT_SAMPLE)

dls = TabularDataLoaders.from_csv(path/'adult.csv', path=path, y_names="salary",
    cat_names = ['workclass', 'education', 'marital-status', 'occupation',
                 'relationship', 'race'],
    cont_names = ['age', 'fnlwgt', 'education-num'],
    procs = [Categorify, FillMissing, Normalize])

In [4]:
cat,cont,y = next(iter(dls.valid))

In [5]:
!git clone https://github.com/somepago/saint.git

Cloning into 'saint'...
remote: Enumerating objects: 35, done.[K
remote: Counting objects: 100% (35/35), done.[K
remote: Compressing objects: 100% (26/26), done.[K
remote: Total 35 (delta 12), reused 28 (delta 8), pack-reused 0[K
Unpacking objects: 100% (35/35), done.


In [8]:
%cd saint

/content/saint


In [10]:
!pip install einops >> /dev/null

In [11]:
from models import *

In [19]:
config = {'dataset':'1995_income',
'cont_embeddings':'MLP',
'embedding_size':32,
'transformer_depth':6,
'attention_heads':8,
'attention_dropout':0.1,
'ff_dropout':0.1,
'attentiontype':'colrow',
'lr':0.0001,
'epochs':100,
'batchsize':256,
'savemodelroot':'./bestmodels',
'run_name':'testrun',
'set_seed': 1 ,
'active_log':False,
'pretrain':False,
'pretrain_epochs':50,
'pt_tasks':['contrastive','denoising'],
'pt_aug':[],
'pt_aug_lam':0.1,
'mixup_lam':0.3,

'train_mask_prob':0,
'mask_prob':0,

'ssl_avail_y': 0,
'pt_projhead_style':'diff',
'nce_temp':0.7,

'lam0':0.5,
'lam1':10,
'lam2':1,
'lam3':10,
'final_mlp_style':'sep'}

In [20]:
config['attentiontype']

'colrow'

In [21]:
def change_config(config, fld, val):
    print(f'Initial {fld} configuration: {val}')
    config[fld] = val
    print(f'New {fld} configuration: {val}')
    return config

In [22]:
config = change_config(config, 'ff_dropout', 0.8)
config = change_config(config, 'transformer_depth', 1)

Initial ff_dropout configuration: 0.8
New ff_dropout configuration: 0.8
Initial transformer_depth configuration: 1
New transformer_depth configuration: 1


In [23]:
config['train_mask_prob']

0

In [24]:
mask_params = {
    'mask_prob':config['train_mask_prob'],
    'avail_train_y': 0,
    'test_mask':config['train_mask_prob']
}

In [25]:
config['mask_prob']

0

In [26]:
pt_mask_params = {
    'mask_prob':0,
    'avail_train_y':0,
    'test_mask':0
}

In [34]:
mask = np.ones_like(dls.train.dataset.xs.values)

In [38]:
y_dim = 2

In [40]:
dls.train.normalize.means

{'age': 38.55587546546892,
 'education-num': 10.078429222106934,
 'fnlwgt': 189819.7968060194}

In [41]:
mean, std = dls.train.normalize.means, dls.train.normalize.stds

In [42]:
config['pretrain']

False

In [47]:
dls.train.categorify.classes

{'education': ['#na#', '10th', '11th', '12th', '1st-4th', '5th-6th', '7th-8th', '9th', 'Assoc-acdm', 'Assoc-voc', 'Bachelors', 'Doctorate', 'HS-grad', 'Masters', 'Preschool', 'Prof-school', 'Some-college'],
 'education-num_na': ['#na#', False, True],
 'marital-status': ['#na#', 'Divorced', 'Married-AF-spouse', 'Married-civ-spouse', 'Married-spouse-absent', 'Never-married', 'Separated', 'Widowed'],
 'occupation': ['#na#', '?', 'Adm-clerical', 'Armed-Forces', 'Craft-repair', 'Exec-managerial', 'Farming-fishing', 'Handlers-cleaners', 'Machine-op-inspct', 'Other-service', 'Priv-house-serv', 'Prof-specialty', 'Protective-serv', 'Sales', 'Tech-support', 'Transport-moving'],
 'race': ['#na#', 'Amer-Indian-Eskimo', 'Asian-Pac-Islander', 'Black', 'Other', 'White'],
 'relationship': ['#na#', 'Husband', 'Not-in-family', 'Other-relative', 'Own-child', 'Unmarried', 'Wife'],
 'workclass': ['#na#', '?', 'Federal-gov', 'Local-gov', 'Never-worked', 'Private', 'Self-emp-inc', 'Self-emp-not-inc', 'State-

In [51]:
categorical_dims = {o:len(i) for _, (o,i) in enumerate(dls.train.categorify.classes.items())}

In [52]:
categorical_dims

{'education': 17,
 'education-num_na': 3,
 'marital-status': 8,
 'occupation': 16,
 'race': 6,
 'relationship': 7,
 'workclass': 10}

In [53]:
y_dim = 2

In [54]:
config['ssl_avail_y']

0

In [55]:
train_bsize = config['batchsize']

In [56]:
train_bsize

256

In [58]:
categorical_dims.values()

dict_values([10, 17, 8, 16, 7, 6, 3])

In [65]:
categorical_dims.values()

dict_values([10, 17, 8, 16, 7, 6, 3])

In [66]:
cat_dims = np.append(np.array([10,17,8,16,7,6,3]), np.array([2])).astype(int)

In [67]:
cat_dims

array([10, 17,  8, 16,  7,  6,  3,  2])

In [68]:
dls.cont_names

(#3) ['age','fnlwgt','education-num']

In [69]:
config['embedding_size']

32

In [70]:
config['cont_embeddings']

'MLP'

In [71]:
config['attentiontype']

'colrow'

In [72]:
config['final_mlp_style']

'sep'

In [73]:
model = SAINT(
    categories = tuple(cat_dims),
    num_continuous = len(dls.cont_names),
    dim = config['embedding_size'],
    dim_out = 1,
    depth = config['transformer_depth'],
    heads = config['attention_heads'],
    attn_dropout = config['attention_dropout'],
    ff_dropout = config['ff_dropout'],
    mlp_hidden_mults = (4,2),
    continuous_mean_std = (mean,std),
    cont_embeddings = config['cont_embeddings'],
    attentiontype = config['attentiontype'],
    final_mlp_style = config['final_mlp_style'],
    y_dim = 1
)

In [81]:
mask.shape

(26049, 10)

In [87]:
from augmentations import embed_data_mask

In [88]:
cat.shape, cont.shape

(torch.Size([64, 7]), torch.Size([64, 3]))

In [89]:
mask

array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]])

In [128]:
np.ones_like(np.empty((64,8))).shape

(64, 8)

In [129]:
cat_mask = np.ones_like(np.empty((64,8)))
cont_mask = np.ones_like(cont)

In [139]:
cat_mask[:,-1] = 0

In [140]:
cat_mask

tensor([[1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1

In [131]:
len(dls.cat_names)

7

In [141]:
cat_mask = tensor(cat_mask)
cont_mask = tensor(cont_mask)

In [142]:
y_mask = tensor(np.ones_like(y))

In [143]:
torch.cat((cat_mask,y_mask), dim=1).shape

torch.Size([64, 9])

In [150]:
with torch.no_grad():
    model.eval()
    _ , x_categ_enc, x_cont_enc = embed_data_mask(torch.cat((cat, y), dim=-1), 
                                                  cont, 
                                                  cat_mask.int(), 
                                                  cont_mask.int(), 
                                                  model,
                                                  False)  
    reps = model.transformer(x_categ_enc, x_cont_enc)
    # Train embeds for our ys
    y_reps = reps[:,len(cat_dims)-1,:]
    y_outs = model.mlpfory(y_reps)

In [146]:
x_categ_enc.shape

torch.Size([64, 8, 32])

In [148]:
x_cont_enc.shape

torch.Size([64, 3, 32])

In [84]:
with torch.no_grad():
    model.eval()
    model.cuda()
    out = model()

tensor([[ 5, 12,  5, 13,  2,  5,  1],
        [ 5,  2,  5,  2,  4,  3,  1],
        [ 6,  8,  3,  5,  1,  5,  1],
        [ 5,  2,  3, 15,  1,  5,  1],
        [ 2, 10,  5, 11,  2,  5,  1],
        [ 5, 13,  1, 11,  5,  5,  1],
        [ 5, 16,  1, 11,  2,  5,  1],
        [ 5,  5,  3, 15,  1,  3,  1],
        [ 5,  2,  1, 12,  2,  5,  1],
        [ 3, 10,  3,  5,  1,  5,  1],
        [ 3, 12,  3,  5,  1,  5,  1],
        [ 5, 16,  5, 13,  2,  5,  1],
        [ 5, 16,  1, 11,  4,  5,  1],
        [ 1, 12,  1,  1,  4,  5,  1],
        [ 5, 12,  1,  2,  5,  5,  1],
        [ 5,  6,  3,  4,  1,  5,  1],
        [ 5,  1,  7,  8,  2,  5,  1],
        [ 5, 16,  1,  5,  2,  5,  1],
        [ 3, 16,  5,  9,  4,  3,  1],
        [ 5,  2,  3,  8,  1,  5,  1],
        [ 7, 12,  1, 13,  2,  5,  1],
        [ 7, 16,  3,  4,  1,  5,  1],
        [ 5,  1,  5,  9,  4,  5,  1],
        [ 5, 12,  1,  4,  2,  5,  1],
        [ 5, 10,  3, 11,  1,  5,  1],
        [ 5, 12,  7,  5,  2,  3,  1],
        [ 5,