In [47]:
import os
import json
from collections import OrderedDict

import torch
from torch import nn
from torch import optim
from tqdm.notebook import tqdm

from torch.utils.data import Dataset, DataLoader
from typing import Dict, List, Tuple, Any
from random import choice

import numpy as np

Connect to google drive

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## 1) Load the data

In [32]:
data_dir = '/content/drive/My Drive/datasets/2009-skoltech-hack/data'

# read the data
data_path = os.path.join(data_dir, 'train_data_lognorm_full.json')
with open(data_path, 'r') as json_file:
    data = json.load(json_file)

# read the labels
labels_path = os.path.join(data_dir, 'train_labels.json')
with open(labels_path, 'r') as json_file:
    labels = json.load(json_file)

print('Loaded')

Loaded


## 2) Define Dataset and DataLoader

In [33]:
class ProductDataset(Dataset):
    def __init__(self, 
                 data: Dict[int, List[Any]], 
                 labels: Dict[int, List[int]]):
        """
        data: Dict with structure: {client_id: List of features}
        labels: Dict with structure: {client_id: List with 7 zeros or ones}
        """
        super(ProductDataset, self).__init__()
        self.data = data
        self.labels = labels
        
        self.ix_to_key = dict()
        
        for ix, key in enumerate(self.data):
            self.ix_to_key[ix] = key
            

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        
        cust_id = self.ix_to_key[idx] # get customer id
        
        products = np.zeros(7, dtype=np.float32)
        for label_ix, product_status in enumerate(self.labels[cust_id]):
            if product_status:
                products[label_ix] = 1
                
        # randomly close one product
        chosen_ix = choice(np.where(products == 1)[0])
        
        # zero that position
        products[chosen_ix] = 0
        
        # form one-hot target
        targets = np.zeros_like(products, dtype=np.float32)
        targets[chosen_ix] = 1
        
        item = {
            "features": {},
            "targets": targets,
        }

        features = np.array(self.data[cust_id])
        features = np.hstack((features, products))
        
        item["features"] = torch.from_numpy(features).float()


        return item

In [86]:
BATCH_SIZE = 16

product_dataset = ProductDataset(data=data, 
                                 labels=labels)

product_dataloader = DataLoader(dataset=product_dataset, 
                                batch_size = BATCH_SIZE, 
                                shuffle=True, 
                                num_workers = 4)

In [87]:
product_dataset[2]['features'].shape

torch.Size([169])

## 3) Define the model

In [88]:
params = {
    'input_dim': 162+7,
    'emb_dim': 64,
    'transformer_nhead': 2,
    'transformer_dim_feedforward': 32,
    'transformer_dropout': 0.1,
    'dense_unit': 128,
    'num_layers': 2,
    'n_products': 7
}

In [89]:
class ProductModel(nn.Module):
    def __init__(self, params):
        super().__init__()

        # Let's say embedding
        self.embedding = nn.Linear(in_features=params['input_dim'], 
                                   out_features=params["emb_dim"])

        transformer_blocks = []
        for i in range(params["num_layers"]):
            transformer_block = nn.TransformerEncoderLayer(
                d_model=params["emb_dim"],
                nhead=params["transformer_nhead"],
                dim_feedforward=params["transformer_dim_feedforward"],
                dropout=params["transformer_dropout"],
            )
            transformer_blocks.append(
                (f"transformer_block_{i}", transformer_block)
            )

        self.transformer_encoder = nn.Sequential(
            OrderedDict(transformer_blocks)
        )

        self.linear = nn.Linear(
            in_features=params["emb_dim"], out_features=params["dense_unit"]
        )
        self.scorer = nn.Linear(
            in_features=params["dense_unit"],
            out_features=params["n_products"],
        )

    def forward(self, features):

        emb_features = self.embedding(features).unsqueeze(1)

        transformer_output = self.transformer_encoder(emb_features)
        pooling = torch.mean(transformer_output, dim=1)
        linear = torch.tanh(self.linear(pooling))
        merch_logits = self.scorer(linear)

        return merch_logits

### Sanity check

In [90]:
model = ProductModel(params=params)

criterion = nn.BCEWithLogitsLoss()
batch = next(iter(product_dataloader))
output = model(batch['features'])

loss = criterion(output, batch['targets'])
print(loss)

tensor(0.6641, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)


## 4) Train loop without Catalyst

In [73]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [74]:
def calculate_metrics(logits: torch.tensor, 
                      targets: torch.tensor, 
                      num_classes: int = 7):
    """
    Calculates position-wise count matched
    """

    logits = logits.detach().cpu().numpy()
    prediction = np.argmax(logits, axis=1)
    prediction_one_hot = np.zeros((prediction.size, num_classes))
    prediction_one_hot[np.arange(prediction.size), prediction] = 1
    prediction_one_hot = prediction_one_hot.astype(bool)

    targets = targets.detach().cpu().numpy().astype(bool)

    matches = np.logical_and(prediction_one_hot, targets).astype(int)
    count_matches = np.sum(matches, axis=0)
    
    return count_matches

In [96]:
def train(model, 
          train_dataloader, 
          criterion,
          optimizer,
          device: str = 'cuda', 
          num_epochs: int = 10, 
          num_classes: int = 7):
    """
    model: torch nn.Module
    train_dataloader: torch DataLoader
    """

    all_losses = []
    all_counts = []

    model = model.to(device)
    
    for epoch in range(num_epochs):
        running_loss = 0
        num_samples_per_epoch = 0
        class_wise_counts = np.zeros(num_classes, dtype = int)
        for batch in tqdm(train_dataloader):
            # train body

            features = batch['features'].to(device)
            targets = batch['targets'].to(device)

            logits = model(features)

            loss = criterion(logits, batch['targets'].to(device))

            optimizer.zero_grad()
            optimizer.step()

            # save metrics
            with torch.no_grad():
                running_loss += loss.item()
                class_wise_counts += calculate_metrics(logits=logits, targets=targets)

                all_losses.append(running_loss)
                all_counts.append(class_wise_counts / BATCH_SIZE)
                num_samples_per_epoch += len(targets)
        
        # epoch metrics

        running_loss /= len(train_dataloader)

        class_wise_counts = class_wise_counts.astype(np.float64)
        class_wise_counts /= num_samples_per_epoch

        print(f'EPOCH: {epoch + 1}')
        print(f'### BSE loss: {running_loss:.5f}')
        print(f'###### Accuracy:', '\t'.join([str(elem)[:5] for elem in class_wise_counts]))
        print()

    return model, all_losses, all_counts


In [97]:
model = ProductModel(params=params)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters())

trained_model, losses, counts = train(model=model, 
                                      train_dataloader=product_dataloader, 
                                      criterion=criterion, 
                                      optimizer=optimizer, 
                                      device=device)

HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 1
### BSE loss: 0.72125
###### Accuracy: 0.0	0.045	0.005	0.002	0.007	0.010	0.000



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 2
### BSE loss: 0.72152
###### Accuracy: 0.0	0.044	0.006	0.002	0.007	0.010	0.000



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 3
### BSE loss: 0.72125
###### Accuracy: 0.0	0.043	0.006	0.001	0.007	0.011	0.000



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 4
### BSE loss: 0.72142
###### Accuracy: 0.0	0.045	0.005	0.002	0.007	0.011	0.000



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 5
### BSE loss: 0.72128
###### Accuracy: 2.023	0.045	0.006	0.002	0.007	0.010	0.000



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 6
### BSE loss: 0.72119
###### Accuracy: 4.046	0.045	0.006	0.001	0.008	0.011	0.000



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 7
### BSE loss: 0.72129
###### Accuracy: 6.069	0.044	0.006	0.002	0.007	0.011	0.000



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 8
### BSE loss: 0.72127
###### Accuracy: 0.0	0.044	0.006	0.002	0.007	0.011	0.000



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 9
### BSE loss: 0.72129
###### Accuracy: 2.023	0.043	0.006	0.001	0.007	0.010	0.000



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 10
### BSE loss: 0.72137
###### Accuracy: 2.023	0.044	0.006	0.002	0.007	0.011	0.000

