In [2]:
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.init
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import gc

from tqdm.autonotebook import tqdm
import os
tqdm.pandas()



In [2]:
!nvidia-smi

Wed Jun 19 15:36:45 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.56       Driver Version: 418.56       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  GeForce GTX 108...  On   | 00000000:08:00.0 Off |                  N/A |
| 20%   48C    P2    51W / 250W |      1MiB / 11178MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage    

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [4]:
train_dir = './train-data'
train_x = pd.read_csv(os.path.join(train_dir, 'train_x.csv'))
# Class will start with 0
train_y = pd.read_csv(os.path.join(train_dir, 'train_y.csv')) - 1
test_x = pd.read_csv(os.path.join(train_dir, 'test_x.csv'))

In [5]:
categorical_features = ['gender', 'city', 'prodName', 'color', 'carrier']

In [None]:
train_x.fillna(train_x.mean(), inplace=True)
test_x.fillna(train_x.mean(), inplace=True)

In [None]:
categories_dict = []
for cate in categorical_features:
    # (dict_size, embedding_size)
    dict_size = max(train_x[cate].unique()) + 1
    categories_dict.append((dict_size, min(dict_size * 2, 100)))

In [None]:
categories_dict

[(2, 4), (364, 100), (227, 100), (136, 100), (4, 8)]

In [None]:
class TabularDataset(Dataset):
    def __init__(self, x: pd.DataFrame, categories=None, y=None):

      self.num_samples = x.shape[0]

      if y is not None:
          # Train
          self.y = y.values.reshape(-1, 1)
      else:
          # Test
          self.y = np.zeros((self.num_samples, 1))

      self.cat_cols = categories
      self.cont_cols = [col for col in x.columns if col not in categories]

      self.x_cont = x[self.cont_cols].astype(np.float32).values

      self.x_cate = x[categories].astype(np.int64).values

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return [self.x_cont[idx], self.x_cate[idx], self.y[idx]]

![](images/nn-model.png)


In [None]:
class TabularNN(nn.Module):
    def __init__(self, categories_dict, num_continuous, num_classes, layers, dropout):
        super().__init__()
        
        self.embeds = nn.ModuleList([nn.Embedding(x, y) for x, y in categories_dict])
        num_embed_out = sum([y for _, y in categories_dict])
        num_continuous = num_continuous
        self.num_classes = num_classes
        
        # FC
        fc1 = nn.Linear(num_embed_out + num_continuous, layers[0])
        self.fc_list = nn.ModuleList([fc1] + [nn.Linear(layers[i], layers[i+1]) 
                                              for i in range(len(layers) - 1)])
        
        for layer in self.fc_list:
            nn.init.kaiming_normal_(layer.weight.data, nonlinearity='relu')
            
        self.fc_out = nn.Linear(layers[-1], num_classes)
        nn.init.kaiming_normal_(self.fc_out.weight.data, nonlinearity='relu')
        
        # BN
        self.continuous_var_bn = nn.BatchNorm1d(num_continuous)
        self.bn_layers = nn.ModuleList([nn.BatchNorm1d(layer)
                                        for layer in layers])

        # Dropout
        self.embed_dropout = nn.Dropout(dropout)
        self.droput_layers = nn.ModuleList([nn.Dropout(dropout) for layer in layers])
        
    def forward(self, categorical, continuous):
        
        embeded = [layer(categorical[:, i]) for i, layer in enumerate(self.embeds)]
        embeded = torch.cat(embeded, 1)
        embeded = self.embed_dropout(embeded)
        
        norm_continuous = self.continuous_var_bn(continuous)
        
        # concat categorical and continuous data as input for fc
        x = torch.cat([embeded, norm_continuous], 1)
        
        for fc, bn, dropout in zip(self.fc_list, self.bn_layers, self.droput_layers):
            x = F.relu(fc(x))
            x = bn(x)
            x = dropout(x)
            
        return self.fc_out(x)
            

In [None]:
net = TabularNN(categories_dict, train_x.shape[1] - len(categorical_features), 
                len(train_y['age_group'].unique()), [250, 100], 0.5)
net = net.to(device)
net

TabularNN(
  (embeds): ModuleList(
    (0): Embedding(2, 4)
    (1): Embedding(364, 100)
    (2): Embedding(227, 100)
    (3): Embedding(136, 100)
    (4): Embedding(4, 8)
  )
  (fc_list): ModuleList(
    (0): Linear(in_features=518, out_features=250, bias=True)
    (1): Linear(in_features=250, out_features=100, bias=True)
  )
  (fc_out): Linear(in_features=100, out_features=6, bias=True)
  (continuous_var_bn): BatchNorm1d(206, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn_layers): ModuleList(
    (0): BatchNorm1d(250, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (embed_dropout): Dropout(p=0.5)
  (droput_layers): ModuleList(
    (0): Dropout(p=0.5)
    (1): Dropout(p=0.5)
  )
)

In [None]:
def train(dataset, loader, net, optimizer):
    net.train()
    total_loss = 0
    for cont, cat, y in tqdm(loader, leave=False):
        cont = cont.to(device)
        cat = cat.to(device)
        y = y.long().flatten().to(device)
        optimizer.zero_grad()

        pred = net(cat, cont)
#         print(y.shape)
#         print(pred.shape)
        loss = criterion(pred, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss/float(len(dataset))

def predict(dataset, loader, net):
    net.eval()
    corrects = eval_loss = 0
    result = []
    for cont, cat, y in tqdm(loader):
        cont = cont.to(device)
        cat = cat.to(device)
        pred = net(cat, cont)
        
        result.append(torch.max(pred, 1)[1].view(y.size()).data)
    return result

def evaluate(dataset, loader, net):
    net.eval()
    corrects = eval_loss = 0

    for cont, cat, y in tqdm(loader):
        cont = cont.to(device)
        cat = cat.to(device)
        y = y.long().flatten().to(device)
        pred = net(cat, cont)
        loss = criterion(pred, y)
        
        eval_loss += loss.item()
        corrects += (torch.max(pred, 1)[1].view(y.size()).data == y.data).sum()
    #loss, correct count, accuracy
    return eval_loss/float(len(dataset)), corrects, corrects*100/len(dataset)

In [None]:
import sklearn
import sklearn.model_selection
splits = 3
kfold = sklearn.model_selection.StratifiedKFold(splits, shuffle=True)

In [None]:
batch_size = 502500
epoch = 24
lr = 0.1
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

In [None]:
for train_index, test_index in kfold.split(train_x, train_y):
    net = TabularNN(categories_dict, train_x.shape[1] - len(categorical_features), 
                len(train_y['age_group'].unique()), [250, 120], 0.3)
    net = net.to(device)
    train_dataset = TabularDataset(train_x.iloc[train_index], categorical_features, train_y.iloc[train_index])
    train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
    validation_dataset = TabularDataset(train_x.iloc[test_index], categorical_features, train_y.iloc[test_index])
    validation_loader = DataLoader(validation_dataset, batch_size, shuffle=False)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    # learning rate decay
    scheduler=torch.optim.lr_scheduler.StepLR(optimizer,step_size=3,gamma=0.3) 
    for i in tqdm(range(epoch)):
        print(train(train_dataset, train_loader, net, optimizer))
        print(evaluate(validation_dataset, validation_loader, net))
        scheduler.step()
    print("train acc:", evaluate(train_dataset, train_loader, net))
    del train_dataset
    del train_loader
    print("validation acc:", evaluate(validation_dataset, validation_loader, net))

HBox(children=(IntProgress(value=0, max=24), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

4.32233216676696e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

4.658907570211684e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

4.051324946219962e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

4.241084214181417e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

3.961272847257897e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

4.369236964949579e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

3.256425423130736e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.9847108639504867e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.936762005121428e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.8895170153989322e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.882194356918214e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.85097315379899e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.7273484912076433e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.7108198439901472e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.6937030675179123e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.685325050188463e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.6772977219411767e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.668459591995556e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.6551769983625244e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.6512288549159518e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.6467758508551544e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.6440850994652257e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.6401398028097822e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.6372819803654896e-06



HBox(children=(IntProgress(value=0, max=3), HTML(value='')))


train acc: (2.5633721656874274e-06, tensor(678297, device='cuda:0'), tensor(50, device='cuda:0'))


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))


validation acc: (3.429505581503847e-06, tensor(338335, device='cuda:0'), tensor(50, device='cuda:0'))


HBox(children=(IntProgress(value=0, max=24), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

4.413142933774351e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

5.047018581361913e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

4.291045398854498e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

4.501511178799529e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

4.220747058071307e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

4.010594069068111e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

3.261422844075445e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

3.033129937613188e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.928875097587927e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.8998144526979815e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.8339096859319886e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.827934364774334e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.7388957899008225e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.7221104102348213e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.7103239030980353e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.7025323305557025e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.692832519758993e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.687707765778499e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.6733889508603225e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.669983301589738e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.666580321183845e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.6646308934510645e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.6605795568494656e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.6579663824679247e-06



HBox(children=(IntProgress(value=0, max=3), HTML(value='')))


train acc: (2.588310704302432e-06, tensor(671368, device='cuda:0'), tensor(50, device='cuda:0'))


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))


validation acc: (3.4632133014166534e-06, tensor(333941, device='cuda:0'), tensor(49, device='cuda:0'))


HBox(children=(IntProgress(value=0, max=24), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

4.435211490761825e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

5.0828563065591965e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

4.069436562870144e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

3.885265020117365e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

3.7424096960089175e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

3.8981270701262555e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

3.1578951370943475e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

2.991820351649742e-06


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))