In [1]:
# %config Completer.use_jedi = False
import pandas as pd
import numpy as np
from sklearn.datasets import fetch_covtype
import matplotlib.pyplot as plt
from collections import defaultdict, OrderedDict
import torch
from torch import nn
import torch.nn.functional as F
import math
from khds import Optimizer
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
import xgboost as xgb

In [2]:
class LazyQuantileNorm(nn.Module):

    def __init__(self, quantiles=50, momentum=.003, 
                 track_running_stats=True, noise=0.1,
                 use_stats_for_train=True, boost=True, predefined=None):
        super(LazyQuantileNorm, self).__init__()

        if type(quantiles) is int:
            quantiles = torch.arange(quantiles-1) / (quantiles - 2)
        else:
            assert type(quantiles) is torch.Tensor and len(quantiles.shape) == 1                        
        
        self.register_buffer("boundaries", None)
        self.register_buffer("quantiles", quantiles)
        self.bernoulli = torch.distributions.bernoulli.Bernoulli(probs=noise)

        self.lr = momentum
        self.boost = boost
#         self.noise = noise
        
        
        assert (not use_stats_for_train) or (use_stats_for_train and track_running_stats)

        self.track_running_stats = track_running_stats
        self.use_stats_for_train = use_stats_for_train

        
    def forward(self, x):
        
        shape = x.shape
        
        if not self.track_running_stats or self.boundaries is None:
            boundaries = torch.quantile(x, self.quantiles, dim=0).transpose(0, 1)
            if self.boundaries is None:
                self.boundaries = boundaries
        
        else:
            
            if self.training:
                
                q = self.quantiles.view(1, 1, -1)
                b = self.boundaries.unsqueeze(0)
                xv = x.unsqueeze(-1).detach()
                q_th = (q * (xv-b) > (1-q) * (b-xv)).float()
                q_grad = (- q * q_th + (1 - q) * (1 - q_th)) * (~torch.isinf(xv)).float()
                q_grad = q_grad.sum(dim=0)
    
                if self.boost:
                    q = self.quantiles.unsqueeze(0)
                    factor = (torch.max(1 / (q + 1e-3), 1 / (1 - q + 1e-3))) ** 0.5
                else:
                    factor = 1
                    
                self.boundaries = self.boundaries - self.lr * factor * q_grad
        
        
        if (self.training and self.use_stats_for_train) or (not self.training and self.track_running_stats):
            boundaries = self.boundaries
        
        xq = torch.searchsorted(boundaries, x.transpose(0, 1)).transpose(0, 1)
        
        if self.training:
            n = (2 * torch.randint(1, size=xq.shape, device=xq.device) - 1) * self.bernoulli.sample(xq.shape).to(xq.device)
            xq = torch.clamp(xq + n.long(), min=0, max=len(self.quantiles)-1)
        
        xq = xq + 1
        xq[torch.isinf(x)] = 0
        
        return xq

    
class GBN(torch.nn.Module):
    """
    Ghost Batch Normalization
    https://arxiv.org/abs/1705.08741
    """

    def __init__(self, input_dim, virtual_batch_size=128, momentum=0.01):
        super(GBN, self).__init__()

        self.input_dim = input_dim
        self.virtual_batch_size = virtual_batch_size
        self.bn = nn.BatchNorm1d(self.input_dim, momentum=momentum)

    def forward(self, x):
        chunks = x.chunk(int(np.ceil(x.shape[0] / self.virtual_batch_size)), 0)
        res = [self.bn(x_) for x_ in chunks]

        return torch.cat(res, dim=0)

    
class Sparsemax(nn.Module):
    """Sparsemax function."""

    def __init__(self, dim=None):
        """Initialize sparsemax activation
        
        Args:
            dim (int, optional): The dimension over which to apply the sparsemax function.
        """
        super(Sparsemax, self).__init__()

        self.dim = -1 if dim is None else dim

    def forward(self, input):
        """Forward function.
        Args:
            input (torch.Tensor): Input tensor. First dimension should be the batch size
        Returns:
            torch.Tensor: [batch_size x number_of_logits] Output tensor
        """
        # Sparsemax currently only handles 2-dim tensors,
        # so we reshape to a convenient shape and reshape back after sparsemax
        input = input.transpose(0, self.dim)
        original_size = input.size()
        input = input.reshape(input.size(0), -1)
        input = input.transpose(0, 1)
        dim = 1

        number_of_logits = input.size(dim)

        # Translate input by max for numerical stability
        input = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input)

        # Sort input in descending order.
        # (NOTE: Can be replaced with linear time selection method described here:
        # http://stanford.edu/~jduchi/projects/DuchiShSiCh08.html)
        zs = torch.sort(input=input, dim=dim, descending=True)[0]
        range = torch.arange(start=1, end=number_of_logits + 1, step=1, device=input.device, dtype=input.dtype).view(1, -1)
        range = range.expand_as(zs)

        # Determine sparsity of projection
        bound = 1 + range * zs
        cumulative_sum_zs = torch.cumsum(zs, dim)
        is_gt = torch.gt(bound, cumulative_sum_zs).type(input.type())
        k = torch.max(is_gt * range, dim, keepdim=True)[0]

        # Compute threshold function
        zs_sparse = is_gt * zs

        # Compute taus
        taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k
        taus = taus.expand_as(input)

        # Sparsemax
        self.output = torch.max(torch.zeros_like(input), input - taus)

        # Reshape back to original shape
        output = self.output
        output = output.transpose(0, 1)
        output = output.reshape(original_size)
        output = output.transpose(0, self.dim)

        return output

    def backward(self, grad_output):
        """Backward function."""
        dim = 1

        nonzeros = torch.ne(self.output, 0)
        sum = torch.sum(grad_output * nonzeros, dim=dim) / torch.sum(nonzeros, dim=dim)
        self.grad_input = nonzeros * (grad_output - sum.expand_as(grad_output))

        return self.grad_input


def preprocess_feature(v, nq=20):
    '''
    get vector of features and calculate
    quantiles/categories
    
    returns vc, categories
    
    vc - categorical representation of v
    categories - names of categories (if quantile it is (a, b])
    
    currently does not handle nan.
    '''
    
    if type(v) is not pd.Series:
        v = pd.Series(v)
    
    # for now we use a simple rule to distinguish between categorical and numerical features
    n = v.nunique()
    
    if n > nq:
        
        c_type = 'numerical'
        
        q = (np.arange(nq + 1)) / nq
        
        vc = pd.qcut(v, q, labels=False, duplicates='drop')
        categories = v.quantile(q).values[:-1]
        vc = vc.fillna(-1).values
        
    else:
        
        c_type = 'categorical'
        
        vc, categories = pd.factorize(v)
       
    
    # allocate nan value
    categories = np.insert(categories, 0, np.nan)
    vc = vc + 1
        
    return vc, categories, c_type


def preprocess_table(df, nq=20):
    
    metadata = defaultdict(OrderedDict)
    n = 0
    dfc = {}
    
    for c in df.columns:
        vc, categories, c_type = preprocess_feature(df[c], nq=nq)
        
        m = len(categories)
        metadata['n_features'][c] = m
        metadata['categories'][c] = categories
        metadata['aggregated_n_features'][c] = n
        metadata['c_type'][c] = c_type
        
        vc = vc + n
        n = n + m
        dfc[c] = vc
       
    dfc = pd.DataFrame(dfc).astype(np.int64)
    
    metadata['total_features'] = n
        
    return dfc, metadata


class RuleLayer(nn.Module):

    def __init__(self, n_rules, e_dim_in, e_dim_out, bias=True, pos_enc=None, dropout=0.0):
        super(RuleLayer, self).__init__()

        self.query = nn.Parameter(torch.empty((n_rules, e_dim_out)))
        nn.init.kaiming_uniform_(self.query, a=math.sqrt(5))

        self.key = nn.Linear(e_dim_in, e_dim_out, bias=bias)
        self.value = nn.Linear(e_dim_in, e_dim_out, bias=bias)
        self.e_dim_out = e_dim_out
        self.sparsemax = Sparsemax(dim=1)
        self.tau = 1.
        
#         self.sparsemax = nn.Softmax(dim=1)

        if pos_enc is None:
            
#             _, _, V = torch.svd(torch.randn(e_dim_in, e_dim_in))
#             pos_enc = V.T.clone()
    
#         self.register_buffer('pos_enc', pos_enc)

            self.pos_enc = nn.Parameter(torch.empty((e_dim_in, e_dim_out)))
            nn.init.kaiming_uniform_(self.pos_enc, a=math.sqrt(5))
        
    def forward(self, x):
        
        b, nf, ne = x.shape
        
        ##############
        pos = self.pos_enc[:nf].unsqueeze(0).repeat(b, 1, 1)  
        x = x + pos
        ##############
                        
        k = self.key(x)
        v = self.value(x)
        q = self.query
                
        a = k @ q.T / math.sqrt(self.e_dim_out)
        a_prob = self.sparsemax(a / self.tau).transpose(1, 2)
        
        r = torch.bmm(a_prob, v)
        
        return r, a_prob


class mySequential(nn.Sequential):
    def forward(self, *input):
        for module in self._modules.values():
            input = module(*input)
        return input

    
class ResRuleLayer(nn.Module):

    def __init__(self, n_rules, e_dim, bias=True, activation='gelu', dropout=0.0, n_out=1, n_features=None):
        super(ResRuleLayer, self).__init__()

        self.bn1 = GBN(e_dim, virtual_batch_size=256, momentum=0.1)
        self.rl1 = RuleLayer(n_rules, e_dim, e_dim, bias=bias, dropout=dropout)
        self.sl1 = RuleLayer(n_rules, e_dim, e_dim, bias=bias, dropout=dropout)
        self.bn2 = GBN(e_dim, virtual_batch_size=256, momentum=0.1)
        self.rl2 = RuleLayer(n_rules, e_dim, e_dim, bias=bias, dropout=dropout)
        self.sl2 = RuleLayer(n_rules, e_dim, e_dim, bias=bias, dropout=dropout)
        self.activation = getattr(F, activation)

        self.last_rule = RuleLayer(1, e_dim, e_dim, bias=bias, dropout=dropout)
        
        self.lin = nn.Sequential(nn.Flatten(start_dim=-2, end_dim=-1),
                               nn.Linear(e_dim, n_out, bias=bias))
        
    def forward(self, x, e, y):
                        
        r = x
        r = self.bn1(r.transpose(1, 2)).transpose(1, 2)
        r = self.activation(r)
                
        r1, ai = self.rl1(r)
        s1, ai = self.sl1(r)
        
        r = torch.sigmoid(s1) * r1

        r = self.bn2(r.transpose(1, 2)).transpose(1, 2)
        r = self.activation(r)
        
        r2, ai = self.rl2(r)        
        s2, ai = self.sl2(r)
        
        r = torch.sigmoid(s2) * r2
                
        r = r + x
                
        y.append(self.lin(self.last_rule(r)[0]))
                
        return r, e, y
    
class RuleNet(nn.Module):

    def __init__(self, n_features, features_offset, embedding_dim=256, n_rules=128, 
                 n_layers=5, dropout=0.2, n_out=1, bias=True, activation='gelu', noise=0.1,
                 quantiles=50, predefined_boundaries=None):
        super(RuleNet, self).__init__()
            
        self.q_norm = LazyQuantileNorm(quantiles=quantiles, predefined=predefined_boundaries, noise=noise)
        self.register_buffer('features_offset', features_offset) 
        
        self.emb = nn.Embedding(n_features, embedding_dim, sparse=True)
        
        self.first_rule = RuleLayer(n_rules, embedding_dim, embedding_dim, 
                                    bias=bias, dropout=dropout)
        self.rules = mySequential(*[ResRuleLayer(n_rules, embedding_dim, 
                                                 bias=bias, activation=activation, 
                                                 dropout=dropout, n_out=n_out) 
                                    for _ in range(n_layers)],
                                  )    
        
    def forward(self, x_num, x_cat):
        
        x_num = self.q_norm(x_num)
        
        x = torch.cat([x_num, x_cat], dim=1) + self.features_offset
                
        e = self.emb(x)        
        x, _ = self.first_rule(e)
                
        x, _, y = self.rules(x, e, [])

        y = torch.stack(y, dim=1).sum(dim=1)
        
        return y, None, None

    
class EmbNet(nn.Module):

    def __init__(self, n_features, embedding_dim=256, n_out=1, bias=True):
        super(EmbNet, self).__init__()

        self.emb = nn.Embedding(n_features, embedding_dim, sparse=True)
        self.lin = nn.Linear(embedding_dim, n_out, bias=bias)
    
        
    def forward(self, x):
                
        x = self.emb(x).mean(dim=1)
        y = self.lin(x)
        
        return y, None
    
class BaseNet(nn.Module):

    def __init__(self, n_features, n_hidden=256, n_out=1, bias=True):
        super(BaseNet, self).__init__()

        self.lin = nn.Sequential(nn.Linear(n_features, n_hidden, bias=bias),
                              nn.BatchNorm1d(n_hidden),
                             nn.GELU(),
                             nn.Linear(n_hidden, n_hidden, bias=bias),
                              nn.BatchNorm1d(n_hidden),
                             nn.GELU(),
                             nn.Linear(n_hidden, n_hidden, bias=bias),
                              nn.BatchNorm1d(n_hidden),
                             nn.GELU(),
                             nn.Linear(n_hidden, n_out, bias=bias))    
        
    def forward(self, x):
                        
        y = self.lin(x)
        
        return y, None

In [3]:
data = fetch_covtype()
device = 1
quantiles = 100

labels = data['target']
df = pd.DataFrame(data['data'])

dfc, metadata = preprocess_table(df, nq=quantiles)
labels, cat = pd.factorize(labels)

categorical_columns = [i for i, c in enumerate(metadata['c_type'].values()) if c == 'categorical']
numerical_columns = [i for i, c in enumerate(metadata['c_type'].values()) if c == 'numerical']

x_cat = torch.LongTensor(df.values[:, categorical_columns] + 1)
x_num = torch.FloatTensor(df.values[:, numerical_columns])

predefined_boundaries = torch.quantile(x_num, torch.arange(quantiles+1) / quantiles, dim=0).transpose(0, 1)

features_offset = torch.LongTensor([v for v in metadata['aggregated_n_features'].values()])
total_features = metadata['total_features']

y = torch.LongTensor(labels).to(device)
x_cat = x_cat.to(device)
x_num = x_num.to(device)

train_indices, test_indices = train_test_split(torch.arange(len(y)), test_size=116203, random_state=3463)
train_indices, validation_indices = train_test_split(train_indices, test_size=92962, random_state=3464)

In [4]:
batch = 256 * 4
batch_test = 256 * 4

epoch_length = len(train_indices) // batch
epoch_length_val = len(validation_indices) // batch
epoch_length_test = len(test_indices) // batch

n_epochs = 500
br_d = 0.005
n_layers = 4

min_lr = 3e-6
br = 1.
bernoulli = torch.distributions.bernoulli.Bernoulli(probs=br)

net = RuleNet(total_features, features_offset, n_out=int(y.max()+1),  embedding_dim=128, n_rules=64, 
              n_layers=n_layers, activation='gelu', dropout=0.0, quantiles=quantiles, noise=.2,
              predefined_boundaries=predefined_boundaries).to(device)

optimizer = Optimizer(net, dense_ars={'lr': 1e-3, 'eps': 1e-8, 'weight_decay': 0}, 
                      sparse_args={'lr': 1e-1, 'eps': 1e-8})

scheduler_dense = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer.dense, mode='min', factor=1 / math.sqrt(10), patience=16, threshold=0, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, verbose=True)
scheduler_sparse = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer.sparse, mode='min', factor=1 / math.sqrt(10), patience=16, threshold=0, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, verbose=True)

best_weight = net.state_dict()
best_acc = 0

In [None]:
for epoch in range(n_epochs):
    
    train_indices = train_indices[torch.randperm(len(train_indices))]
    
    net.train()
    stats = defaultdict(list)
    for i in tqdm(range(epoch_length)):
        
        indices = train_indices[i * batch: (i + 1) * batch]
        xi_cat = x_cat[indices]
        xi_num = x_num[indices]
        yi = y[indices]
                
        mask_num = bernoulli.sample(sample_shape=xi_num.shape).bool().to(device)
        xi_num_masked = xi_num.clone()
        xi_num_masked[~mask_num] = torch.tensor(-np.inf, device=xi_num.device)
        
        mask_cat = bernoulli.sample(sample_shape=xi_cat.shape).long().to(device)
        xi_cat_masked = xi_cat * mask_cat
        
        yhat, _, _ = net(xi_num_masked, xi_cat_masked)
        loss = F.cross_entropy(yhat, yi, reduction='sum')
    
        loss_t = loss
        
        optimizer.zero_grad()
        loss_t.backward()
        optimizer.step()
        
        stats['train_loss'].append(float(loss))

        stats['train_accuracy'].append(float((yhat.argmax(dim=1) == yi).float().mean()))
        stats['lr'].append(float(optimizer.dense.param_groups[0]['lr']))
        
    net.eval()
    
    with torch.no_grad():
        for i in tqdm(range(epoch_length_test)):

            indices = test_indices[i * batch: (i + 1) * batch]
            xi_cat = x_cat[indices]
            xi_num = x_num[indices]
            yi = y[indices]
            
            yhat, _, _ = net(xi_num, xi_cat)
            loss = F.cross_entropy(yhat, yi, reduction='sum')    
    
            stats['test_loss'].append(float(loss))
            stats['test_accuracy'].append(float((yhat.argmax(dim=1) == yi).float().mean()))
            
        for i in tqdm(range(epoch_length_val)):

            indices = validation_indices[i * batch: (i + 1) * batch]
            xi_cat = x_cat[indices]
            xi_num = x_num[indices]
            yi = y[indices]
            
            yhat, _, _ = net(xi_num, xi_cat)
            loss = F.cross_entropy(yhat, yi, reduction='sum')    
    
            stats['val_loss'].append(float(loss))
            stats['val_accuracy'].append(float((yhat.argmax(dim=1) == yi).float().mean()))
    
    acc = float(np.mean(stats['val_accuracy']))
    val_loss = float(np.mean(stats['val_loss']))
    train_loss = float(np.mean(stats['train_loss']))
        
    scheduler_dense.step(val_loss)
    scheduler_sparse.step(val_loss)
    
    if acc > best_acc:
        best_weight = net.state_dict()
        best_acc = acc
        print(f'epoch {epoch}: Update best weights')

    print(f"epoch: {epoch}")
    print(f'bernoulli: {br}')
    for k, v in stats.items():
        print(f"{k}: {np.mean(v)}")
        
    if train_loss < val_loss:
        br = max(0, br - br_d)
    else:  
        br = min(br + br_d, 1)
        
    bernoulli = torch.distributions.bernoulli.Bernoulli(probs=br)
    if float(optimizer.dense.param_groups[0]['lr']) < min_lr:
        break
            

  0%|          | 0/363 [00:00<?, ?it/s]

  xq = torch.searchsorted(boundaries, x.transpose(0, 1)).transpose(0, 1)
  xq = torch.searchsorted(boundaries, x.transpose(0, 1)).transpose(0, 1)


  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch 0: Update best weights
epoch: 0
bernoulli: 1.0
train_loss: 567.2111474239465
train_accuracy: 0.7637337508608816
lr: 0.0010000000000000002
test_loss: 421.02376968248757
test_accuracy: 0.825566924778761
val_loss: 419.6343729654948
val_accuracy: 0.82333984375


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch 1: Update best weights
epoch: 1
bernoulli: 1
train_loss: 360.73697050740896
train_accuracy: 0.853558669077135
lr: 0.0010000000000000002
test_loss: 288.14089722759957
test_accuracy: 0.8847915514380531
val_loss: 288.5475346883138
val_accuracy: 0.8846571180555556


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch 2: Update best weights
epoch: 2
bernoulli: 1
train_loss: 266.8932630050281
train_accuracy: 0.8945608428030303
lr: 0.0010000000000000002
test_loss: 236.74848654417866
test_accuracy: 0.9079611449115044
val_loss: 241.28247290717232
val_accuracy: 0.9064453125


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch 3: Update best weights
epoch: 3
bernoulli: 1
train_loss: 220.73535307576833
train_accuracy: 0.9141781809573003
lr: 0.0010000000000000002
test_loss: 211.49872204265762
test_accuracy: 0.9176403484513275
val_loss: 208.08175133599175
val_accuracy: 0.9197916666666667


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch 4: Update best weights
epoch: 4
bernoulli: 1
train_loss: 192.75361507195086
train_accuracy: 0.9253077651515151
lr: 0.0010000000000000002
test_loss: 187.36777665129804
test_accuracy: 0.9281664823008849
val_loss: 185.59468756781683
val_accuracy: 0.9285590277777778


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch 5: Update best weights
epoch: 5
bernoulli: 1
train_loss: 173.6311140244329
train_accuracy: 0.9333220342630854
lr: 0.0010000000000000002
test_loss: 175.7662018632467
test_accuracy: 0.9330406526548672
val_loss: 175.92877638075086
val_accuracy: 0.9331163194444444


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch 6: Update best weights
epoch: 6
bernoulli: 0.995
train_loss: 171.77966485141724
train_accuracy: 0.9341156594352618
lr: 0.0010000000000000002
test_loss: 168.03470240230055
test_accuracy: 0.9359616980088495
val_loss: 166.21565907796224
val_accuracy: 0.9360134548611111


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch 7: Update best weights
epoch: 7
bernoulli: 1.0
train_loss: 151.60563329691402
train_accuracy: 0.941618780130854
lr: 0.0010000000000000002
test_loss: 159.6523938474402
test_accuracy: 0.9393753456858407
val_loss: 161.5384644402398
val_accuracy: 0.9386067708333333


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch 8: Update best weights
epoch: 8
bernoulli: 0.995
train_loss: 151.9797775857048
train_accuracy: 0.9411910296143251
lr: 0.0010000000000000002
test_loss: 151.53015501309284
test_accuracy: 0.9415704507743363
val_loss: 151.47732043796117
val_accuracy: 0.941796875


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch 9: Update best weights
epoch: 9
bernoulli: 1.0
train_loss: 135.10922594306882
train_accuracy: 0.9482744705578512
lr: 0.0010000000000000002
test_loss: 150.47939705637705
test_accuracy: 0.9433161642699115
val_loss: 149.80362396240236
val_accuracy: 0.9433810763888889


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch 10: Update best weights
epoch: 10
bernoulli: 0.995
train_loss: 137.1325972375791
train_accuracy: 0.9472602444903582
lr: 0.0010000000000000002
test_loss: 141.13543903722172
test_accuracy: 0.9462890625
val_loss: 141.0145282321506
val_accuracy: 0.9466254340277778


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch: 11
bernoulli: 0.99
train_loss: 140.52455519579001
train_accuracy: 0.9456245695592287
lr: 0.0010000000000000002
test_loss: 146.252558142738
test_accuracy: 0.9448371819690266
val_loss: 144.75708880954318
val_accuracy: 0.9457899305555556


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch: 12
bernoulli: 0.985
train_loss: 143.80249897770318
train_accuracy: 0.9444865917699724
lr: 0.0010000000000000002
test_loss: 141.803007345284
test_accuracy: 0.9467989491150443
val_loss: 141.2498352050781
val_accuracy: 0.9464735243055555


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch 13: Update best weights
epoch: 13
bernoulli: 0.99
train_loss: 131.23081816786248
train_accuracy: 0.9490035296143251
lr: 0.0010000000000000002
test_loss: 139.43240808807644
test_accuracy: 0.947956996681416
val_loss: 137.62321302625867
val_accuracy: 0.9490559895833334


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch: 14
bernoulli: 0.985
train_loss: 135.61232913098715
train_accuracy: 0.9479758522727273
lr: 0.0010000000000000002
test_loss: 135.73971429335333
test_accuracy: 0.948933559181416
val_loss: 138.07052451239693
val_accuracy: 0.94873046875


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch: 15
bernoulli: 0.98
train_loss: 138.24750205636352
train_accuracy: 0.9464020532024794
lr: 0.0010000000000000002
test_loss: 137.71739493851112
test_accuracy: 0.9486570105088495
val_loss: 136.89214011298284
val_accuracy: 0.9487630208333333


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch 16: Update best weights
epoch: 16
bernoulli: 0.985
train_loss: 128.17307516891438
train_accuracy: 0.9508006198347108
lr: 0.0010000000000000002
test_loss: 132.7524311437016
test_accuracy: 0.9506792726769911
val_loss: 130.62402301364475
val_accuracy: 0.9515190972222223


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch: 17
bernoulli: 0.98
train_loss: 132.3268650359687
train_accuracy: 0.9485165934917356
lr: 0.0010000000000000002
test_loss: 139.3191803788717
test_accuracy: 0.9489854120575221
val_loss: 138.14589326646592
val_accuracy: 0.9488606770833333


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch 18: Update best weights
epoch: 18
bernoulli: 0.975
train_loss: 137.64546010251215
train_accuracy: 0.9464235752410468
lr: 0.0010000000000000002
test_loss: 130.00270026552994
test_accuracy: 0.9511200221238938
val_loss: 127.74485651652019
val_accuracy: 0.9520073784722223


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch: 19
bernoulli: 0.98
train_loss: 127.12974798449471
train_accuracy: 0.9508813274793388
lr: 0.0010000000000000002
test_loss: 131.88414676632502
test_accuracy: 0.9511113799778761
val_loss: 131.56076532999674
val_accuracy: 0.9507378472222222


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch 20: Update best weights
epoch: 20
bernoulli: 0.975
train_loss: 130.84398636410717
train_accuracy: 0.9489443440082644
lr: 0.0010000000000000002
test_loss: 132.55710527538199
test_accuracy: 0.9515348451327433
val_loss: 129.4068354288737
val_accuracy: 0.95322265625


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch 21: Update best weights
epoch: 21
bernoulli: 0.98
train_loss: 121.22684510286189
train_accuracy: 0.953337530130854
lr: 0.0010000000000000002
test_loss: 130.77729615068014
test_accuracy: 0.951776825221239
val_loss: 126.96390719943577
val_accuracy: 0.9534939236111111


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch 22: Update best weights
epoch: 22
bernoulli: 0.975
train_loss: 125.10246150749774
train_accuracy: 0.9517018551997245
lr: 0.0010000000000000002
test_loss: 125.84782956553771
test_accuracy: 0.9530645049778761
val_loss: 123.45015318128797
val_accuracy: 0.9547309027777777


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch 23: Update best weights
epoch: 23
bernoulli: 0.98
train_loss: 116.34640547066681
train_accuracy: 0.9545185519972452
lr: 0.0010000000000000002
test_loss: 127.76114965118138
test_accuracy: 0.9534361172566371
val_loss: 122.29689017401802
val_accuracy: 0.9551106770833333


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch 24: Update best weights
epoch: 24
bernoulli: 0.975
train_loss: 121.47919365328534
train_accuracy: 0.9530604338842975
lr: 0.0010000000000000002
test_loss: 124.89465541333225
test_accuracy: 0.9551818307522124
val_loss: 122.54635281032986
val_accuracy: 0.9555447048611111


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch: 25
bernoulli: 0.97
train_loss: 125.12954758152817
train_accuracy: 0.951467803030303
lr: 0.0010000000000000002
test_loss: 132.6680505803201
test_accuracy: 0.9524249861725663
val_loss: 130.17428334554037
val_accuracy: 0.9533420138888888


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch: 26
bernoulli: 0.965
train_loss: 130.9014988208277
train_accuracy: 0.9493505724862259
lr: 0.0010000000000000002
test_loss: 128.4939767618095
test_accuracy: 0.9534447594026548
val_loss: 124.55857984754775
val_accuracy: 0.9540473090277778


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch 27: Update best weights
epoch: 27
bernoulli: 0.97
train_loss: 121.52332738871088
train_accuracy: 0.9528828770661157
lr: 0.0010000000000000002
test_loss: 126.07415676960903
test_accuracy: 0.9550521985619469
val_loss: 124.75767508612739
val_accuracy: 0.9558485243055556


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch: 28
bernoulli: 0.965
train_loss: 127.30186227005046
train_accuracy: 0.9505907799586777
lr: 0.0010000000000000002
test_loss: 123.44938976996768
test_accuracy: 0.9550435564159292
val_loss: 122.11313052707249
val_accuracy: 0.9556749131944444


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch: 29
bernoulli: 0.97
train_loss: 118.39264653733939
train_accuracy: 0.9539939523071626
lr: 0.0010000000000000002
test_loss: 123.28369262155178
test_accuracy: 0.9553460315265486
val_loss: 123.78283971150717
val_accuracy: 0.9550238715277778


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch: 30
bernoulli: 0.965
train_loss: 123.85156355088078
train_accuracy: 0.951809465392562
lr: 0.0010000000000000002
test_loss: 122.77132820872079
test_accuracy: 0.9558559181415929
val_loss: 120.0531618754069
val_accuracy: 0.9557834201388888


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch 31: Update best weights
epoch: 31
bernoulli: 0.97
train_loss: 114.61730730040999
train_accuracy: 0.9555731318870524
lr: 0.0010000000000000002
test_loss: 122.45195236881223
test_accuracy: 0.9559855503318584
val_loss: 119.78102696736654
val_accuracy: 0.9568684895833334


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch: 32
bernoulli: 0.965
train_loss: 120.04224120421186
train_accuracy: 0.9531788050964187
lr: 0.0010000000000000002
test_loss: 123.81312371988213
test_accuracy: 0.9556571487831859
val_loss: 120.697838083903
val_accuracy: 0.9562608506944444


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch: 33
bernoulli: 0.96
train_loss: 124.37723135619781
train_accuracy: 0.9514731835399449
lr: 0.0010000000000000002
test_loss: 119.51894351655403
test_accuracy: 0.9571435978982301
val_loss: 117.91636827256944
val_accuracy: 0.9568359375


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch 34: Update best weights
epoch: 34
bernoulli: 0.965
train_loss: 117.28965601645226
train_accuracy: 0.9540262353650137
lr: 0.0010000000000000002
test_loss: 120.60047095644791
test_accuracy: 0.9560892560840708
val_loss: 117.72687225341797
val_accuracy: 0.9571506076388889


  0%|          | 0/363 [00:00<?, ?it/s]

  0%|          | 0/113 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

epoch: 35
bernoulli: 0.96
train_loss: 122.28914392224357
train_accuracy: 0.9521511277548209
lr: 0.0010000000000000002
test_loss: 119.2733955045717
test_accuracy: 0.9566250691371682
val_loss: 117.59567498101129
val_accuracy: 0.9569444444444445


  0%|          | 0/363 [00:00<?, ?it/s]

In [None]:
ensambles = 16
bernoulli = torch.distributions.bernoulli.Bernoulli(probs=.98)

dataset_indices = test_indices

y_agg = []
y_hat_agg = []
net.eval()

with torch.no_grad():
    for i in tqdm(range(min(1000, int(len(dataset_indices) / (batch_test // ensambles))))):

        indices = dataset_indices[i * batch_test // ensambles: (i + 1) * batch_test // ensambles]
        xi_cat = x_cat[indices].unsqueeze(1).repeat(1, ensambles, 1).view(len(indices) * ensambles, -1)
        yi = y[indices]
        xi_num = x_num[indices].unsqueeze(1).repeat(1, ensambles, 1).view(len(indices) * ensambles, -1)
        
        mask_num = bernoulli.sample(sample_shape=xi_num.shape).bool().to(device)
        xi_num_masked = xi_num.clone()
        xi_num_masked[~mask_num] = torch.tensor(-np.inf, device=xi_num.device)
        
        mask_cat = bernoulli.sample(sample_shape=xi_cat.shape).long().to(device)
        xi_cat_masked = xi_cat * mask_cat
            
        yhat, _, _ = net(xi_num_masked, xi_cat_masked)

        yhat = torch.softmax(yhat, dim=1)
        yhat = yhat.view(len(indices), ensambles, -1).mean(dim=1)
        
        y_hat_agg.append(yhat)
        y_agg.append(yi)
        

In [None]:
y_agg = torch.cat(y_agg)
y_hat_agg = torch.cat(y_hat_agg, dim=0)

In [None]:
acc = torch.argmax(y_hat_agg, dim=-1) == y_agg

In [None]:
y_hat_agg.max(dim=-1)[0][~acc].mean()

In [None]:
torch.quantile(y_hat_agg.max(dim=-1)[0], 0.5)

In [None]:
acc.float().mean()

## Loading and storing weights

In [None]:
curr_weight = net.state_dict()

In [None]:
net.load_state_dict(best_weight)

In [None]:
net.load_state_dict(curr_weight)

In [None]:
torch.save(net.state_dict(), 'net')