In [None]:
! pip install -q kaggle

In [None]:
from google.colab import files

# Interactive upload; the next cell expects the uploaded file to be `kaggle.json`
# (the Kaggle API credentials). Never commit that file or its contents.
files.upload()

In [None]:
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle
! chmod 600 ~/.kaggle/kaggle.json

In [None]:
# Download the competition archive(s) (*.zip, extracted below) into the current directory.
!kaggle competitions download -c test-recsys

In [None]:
from google.colab import drive
# Mount Google Drive so data, intermediate files and checkpoints persist across sessions.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Working directory on Drive for extracted data, cached frames and checkpoints.
!mkdir -p "/content/drive/My Drive/sbermarket"

In [None]:
from zipfile import ZipFile
import os

# Extract every downloaded archive in the current directory into the Drive working dir.
archives = [name for name in os.listdir('.') if name.endswith(".zip")]
for archive in archives:
    with ZipFile(archive, 'r') as archive_zip:
        archive_zip.extractall('drive/My Drive/sbermarket')

In [None]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

import tqdm
from itertools import chain
import seaborn as sns

from collections import defaultdict
from typing import List

from zipfile import ZipFile
import os

from torch import nn
from torch.nn import functional as F
from torch.optim import Adam

from torch.nn import TransformerEncoder, TransformerEncoderLayer
import math

import json
import pickle
from pathlib import Path

# Working directory on the mounted Drive; a relative path, so it assumes the
# notebook's cwd is the directory that contains the `drive` mount (/content).
WORKING_DIR = 'drive/My Drive/sbermarket'

# Per-line-item fields extracted for every order by MarketDataset.
columns = ['price', 'quantity', 'discount', 'product_name', 'prod_idx', 'par_idx', 'mast_idx']

# `k` of the MAP@k metric reported during training.
MAX_NPROD = 50

### Find representative users (more than 5 and fewer than 15 orders)

In [None]:
# Number of distinct orders per user, summed over all `tab_2_products*.csv` parts.
df = pd.DataFrame()

part_files = [
    name for name in sorted(os.listdir('drive/My Drive/sbermarket/'))
    if name.startswith('tab_2_products') and name.endswith('csv')
]
for part_file in part_files:
    print(part_file)
    part_counts = pd.DataFrame(
        pd.read_csv('drive/My Drive/sbermarket/' + part_file)
        .groupby('user_id')['order_id']
        .nunique()
    )
    df = df.add(part_counts, fill_value=0)

In [None]:
# Representative users: strictly more than 5 and fewer than 15 distinct orders.
# One combined mask: chained `df[mask1][mask2]` indexing re-indexes the second mask
# and emits a "Boolean Series key will be reindexed" warning.
order_counts = df['order_id']
repr_users = (
    df[(order_counts > 5) & (order_counts < 15)]
    .reset_index()
    .drop('order_id', axis=1)
)
# index=False: only the `user_id` column is stored (no spurious "Unnamed: 0" on reload).
repr_users.to_csv('drive/My Drive/sbermarket/repr_users.csv', index=False)
print(repr_users.shape)

  """Entry point for launching an IPython kernel.


(77368, 1)


The dataset will consist of 77,368 users.

In [None]:
# Reload the cached user list so the notebook can resume without recomputing the counts.
repr_users = pd.read_csv('drive/My Drive/sbermarket/repr_users.csv')

### Filter tables

In [None]:
# Keep only the rows of representative users, one part file at a time, so a whole
# unfiltered file is the largest frame held in memory.
repr_user_ids = set(repr_users['user_id'])  # built once, not once per file
dfs = []
for filename in tqdm.tqdm(sorted(os.listdir('drive/My Drive/sbermarket'))):
    if filename.startswith('tab_2_products') and filename.endswith('csv'):
        print(filename)
        part = pd.read_csv('drive/My Drive/sbermarket/' + filename)
        # Boolean `isin` filter: no set_index / set-intersection round trip, and the
        # file's original row order is preserved.
        dfs.append(part[part['user_id'].isin(repr_user_ids)].reset_index(drop=True))


In [None]:
products_df = pd.concat(dfs).dropna(subset=['product_name', 'master_category_id'])

# Dense integer codes for the id columns (`.cat.codes` follows the sorted category order).
id_columns = {
    'product_id': 'prod_idx',
    'parent_category_id': 'par_idx',
    'master_category_id': 'mast_idx',
}
for src_col, idx_col in id_columns.items():
    as_category = products_df[src_col].astype('category')
    products_df[src_col] = as_category
    products_df[idx_col] = as_category.cat.codes

print(products_df.shape)

(19357908, 14)


In [None]:
# Cache the filtered, index-encoded frame; pickle keeps the category dtypes.
products_df.to_pickle('drive/My Drive/sbermarket/filtered_df.pckl')

In [None]:
def _codes_mapping(col):
    """Map raw id -> the integer code `.cat.codes` assigned to it in the `*_idx` columns."""
    categories = products_df[col].astype('category').cat.categories
    # `.tolist()` yields plain Python scalars, which json can serialise as keys.
    return {raw_id: code for code, raw_id in enumerate(categories.tolist())}


# These mappings must agree with `prod_idx` / `par_idx` / `mast_idx`, which were built
# with `.cat.codes` (sorted category order). Enumerating `.unique()` (order of first
# appearance) produced a different mapping, so model outputs were decoded to wrong ids.
prod2idx = _codes_mapping('product_id')
mast_cat2idx = _codes_mapping('master_category_id')
par_cat2idx = _codes_mapping('parent_category_id')

# Vocabulary sizes = number of codes, i.e. the valid embedding / decoder index range.
nproducts = len(prod2idx)
ncategories = len(par_cat2idx)
ncategories2 = len(mast_cat2idx)
print(nproducts, ncategories, ncategories2)

89931 119 608


In [None]:
# Persist vocabulary sizes and id -> code mappings so later sessions can skip preprocessing.
# JSON turns every mapping key into a string.
meta = dict(
    nproducts=nproducts,
    ncategories=ncategories,
    ncategories2=ncategories2,
    prod2idx=prod2idx,
    mast_cat2idx=mast_cat2idx,
    par_cat2idx=par_cat2idx,
)
Path('drive/My Drive/sbermarket/meta.json').write_text(json.dumps(meta))

In [None]:
# Resume point: reload the cached frame instead of re-running the preprocessing above.
products_df = pd.read_pickle('drive/My Drive/sbermarket/filtered_df.pckl')

In [None]:
# Sanity check: should match the shape printed when the frame was built.
products_df.shape

(19357908, 14)

In [None]:
meta = json.loads(Path('drive/My Drive/sbermarket/meta.json').read_text())

# NOTE: JSON object keys are always strings, so from here on the keys of the
# *2idx mappings are the string form of the original ids.
nproducts, ncategories, ncategories2 = (
    meta[key] for key in ('nproducts', 'ncategories', 'ncategories2')
)
prod2idx, par_cat2idx, mast_cat2idx = (
    meta[key] for key in ('prod2idx', 'par_cat2idx', 'mast_cat2idx')
)

In [None]:
import gc
# Force a garbage-collection pass to reclaim memory held by unreferenced objects.
gc.collect()

193

### PyTorch Dataset

In [None]:
class MarketDataset(Dataset):
    """Per-user order history for next-basket recommendation.

    Item ``i`` describes user ``idx2user[i]``. It maps every name in
    ``self.columns`` to a list with one entry per order. Each entry is the list
    of that column's values for the order's (optionally sub-sampled) line items.

    Args:
        dataset: line-item frame with at least ``user_id``, ``order_id``,
            ``line_item_id`` and every column in ``self.columns``.
        max_nprod: if given, at most this many line items are randomly sampled
            per order; otherwise all are kept.
        max_nord: if given, only the last ``max_nord`` orders (in order of first
            appearance in ``dataset``) are kept; otherwise all orders are kept.
        skip_first: items with an index below this return ``[]`` without any work
            (useful to resume a partially finished inference run). The old
            hard-coded value was 2600.
    """

    def __init__(self, dataset, max_nprod=None, max_nord=None, skip_first=0):
        self.df_dataset = dataset.set_index('user_id')
        self.user2idx = {u_id: i for i, u_id in enumerate(self.df_dataset.index.unique())}
        self.idx2user = {i: u_id for u_id, i in self.user2idx.items()}
        self.columns = ['price', 'quantity', 'discount', 'product_name', 'prod_idx', 'par_idx', 'mast_idx']
        self.max_nprod = float('inf') if max_nprod is None else max_nprod
        self.max_nord = max_nord
        self.skip_first = skip_first

    def __len__(self):
        return len(self.user2idx)

    def __getitem__(self, user_idx):
        if user_idx < self.skip_first:
            return []
        # Best effort: a user that cannot be encoded is replaced by the next one.
        # Iterative, because the old recursive fallback could exhaust the stack.
        # `Exception` rather than a bare `except`, so KeyboardInterrupt still works.
        # NOTE(review): the substitution means item i may hold another user's data.
        for idx in range(user_idx, len(self)):
            try:
                return self._encode_user(self.idx2user[idx])
            except Exception:
                continue
        raise IndexError(f'no encodable user at or after index {user_idx}')

    def _encode_user(self, user_id):
        """Build the per-order column lists for one user."""
        item = defaultdict(list)
        # Double brackets: always a DataFrame, even for a user with a single row.
        user_df = self.df_dataset.loc[[user_id]]
        # One entry per *order* (not per line item), in order of first appearance.
        orders = pd.unique(user_df['order_id'])
        if self.max_nord is not None:
            orders = orders[max(len(orders) - self.max_nord, 0):]
        for order in orders:
            order_df = user_df[user_df['order_id'] == order]
            if len(order_df) > self.max_nprod:
                kept_lines = order_df['line_item_id'].sample(self.max_nprod)
                order_df = order_df[order_df['line_item_id'].isin(kept_lines)]
            for c in self.columns:
                item[c].append(order_df[c].values.tolist())
        return item

In [None]:
from torch.nn.utils.rnn import pad_sequence

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def collate_fn(batch):
    """Stack per-user samples from `MarketDataset` into dense 3-D tensors.

    Every column in ``columns`` except ``product_name`` (text is not encoded) becomes
    a float tensor of shape (batch, n_orders, n_products). The tensor is zero-padded
    to the largest order count and order size in the batch and moved to ``device``.

    Empty samples (``[]``, returned for skipped indices) are dropped. If nothing is
    left, ``[]`` is returned so callers can skip the batch.
    """
    samples = [sample for sample in batch if sample]
    if not samples:
        return []
    feature_cols = [c for c in columns if c != 'product_name']
    n_orders = max(len(sample['prod_idx']) for sample in samples)
    n_prods = max(
        (len(order) for sample in samples for order in sample['prod_idx']), default=0
    )

    items_without_name = {}
    for c in feature_cols:
        # Pad across both the order and the product axis at once. Padding each
        # sample separately made `torch.stack` fail when shapes differed.
        padded = torch.zeros(len(samples), n_orders, n_prods)
        for i, sample in enumerate(samples):
            for j, order in enumerate(sample[c]):
                if order:
                    padded[i, j, :len(order)] = torch.tensor(order, dtype=torch.float32)
        items_without_name[c] = padded.to(device)
    return items_without_name

In [None]:
import gc
# Force a garbage-collection pass to reclaim memory held by unreferenced objects.
gc.collect()

523

### Model

In [None]:
class PositionalEncoding(nn.Module):
    r"""Add fixed sinusoidal position information to a sequence of embeddings.

    The encodings have the same dimension as the embeddings, so the two can be
    summed. Sine and cosine functions of different frequencies are used:

    .. math::
        \text{PE}(pos, 2i)   = \sin\left(pos / 10000^{2i/d_\text{model}}\right)

        \text{PE}(pos, 2i+1) = \cos\left(pos / 10000^{2i/d_\text{model}}\right)

    where :math:`pos` is the position in the sequence and :math:`i` the embedding index.

    Args:
        d_model: the embedding dimension (required); odd values are supported.
        dropout: dropout probability applied to the sum (default=0.1).
        max_len: the maximum supported sequence length (default=5000).

    Examples:
        >>> pos_encoder = PositionalEncoding(d_model)
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than frequencies; trim so the
        # assignment does not fail with a shape mismatch (no-op for even d_model).
        pe[:, 1::2] = torch.cos(position * div_term)[:, :d_model // 2]
        # Stored as (max_len, 1, d_model) so it broadcasts over the batch axis of
        # sequence-first inputs.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        r"""Add the encodings for the first ``x.size(0)`` positions and apply dropout.

        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        Examples:
            >>> output = pos_encoder(x)
        """

        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
        

In [None]:
class TransformerRec(nn.Module):
    """Transformer encoder over a user's order sequence for next-basket prediction.

    Each line item is embedded by concatenating three parts:
    - one learned 1-dim projection per scalar feature (price, quantity, discount);
    - a 201-dim product embedding;
    - a 100-dim parent-category embedding.

    Line items of an order are mean-pooled (padding excluded) into one order vector.
    A causally-masked transformer encoder runs over the order sequence, and a linear
    decoder scores every product at every order position.
    """

    def __init__(self, nproducts, ncategories1, ncategories2, nhead, nhid, nlayers, dropout=0.5):
        super().__init__()

        self.price_emb = nn.Linear(1, 1)
        self.quantity_emb = nn.Linear(1, 1)
        self.discount_emb = nn.Linear(1, 1)

        self.product_id_emb = nn.Embedding(nproducts, 201)
        self.parent_id_emb = nn.Embedding(ncategories1, 100)
        # Registered (and therefore part of the state_dict) although forward() does
        # not use it at the moment.
        self.master_id_emb = nn.Embedding(ncategories2, 100)

        # Width = 3 scalar features + product embedding + parent-category embedding.
        self.ninp = 3 + 201 + 1 * 100

        self.src_mask = None
        self.pos_encoder = PositionalEncoding(self.ninp, dropout)
        encoder_layers = TransformerEncoderLayer(self.ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)

        self.decoder = nn.Linear(self.ninp, nproducts)
        # `is_available` must be *called*: the bare function object is always truthy,
        # which made this 'cuda' even on CPU-only machines.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def _generate_square_subsequent_mask(self, sz):
        """Float causal mask: 0.0 where attention is allowed, -inf above the diagonal."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, src, has_mask=True):
        """Score every product for each order position.

        Args:
            src: dict of (batch, n_orders, n_products) tensors such as those built by
                ``collate_fn``. It must contain ``price``, ``quantity``, ``discount``,
                ``prod_idx`` and ``par_idx``; ``prod_idx == 0`` marks padding.
            has_mask: apply a causal mask so position t only attends to orders <= t.

        Returns:
            (batch, n_orders, nproducts) unnormalised scores.
        """
        tr_len = src['prod_idx'].shape[1]
        if has_mask:
            device = src['prod_idx'].device
            # Rebuild the cached mask on length change *or* after a device move.
            if (self.src_mask is None or self.src_mask.size(0) != tr_len
                    or self.src_mask.device != device):
                self.src_mask = self._generate_square_subsequent_mask(tr_len).to(device)
        else:
            self.src_mask = None

        # NOTE(review): `.cat.codes` start at 0, so the product with code 0 is
        # indistinguishable from padding here.
        pad = src['prod_idx'] == 0

        features = torch.cat([
            self.price_emb(src['price'].float().unsqueeze(-1)),
            self.quantity_emb(src['quantity'].float().unsqueeze(-1)),
            self.discount_emb(src['discount'].float().unsqueeze(-1)),
            self.product_id_emb(src['prod_idx'].long()),
            self.parent_id_emb(src['par_idx'].long()),
        ], -1)
        features = features.masked_fill(pad.unsqueeze(-1), 0.0)

        # Mean over the real (non-padding) line items of each order; the epsilon keeps
        # fully padded orders at 0 instead of NaN.
        counts = (~pad).sum(2).unsqueeze(-1) + 1e-5
        orders = (features.sum(2) / counts).permute(1, 0, 2)  # (n_orders, batch, ninp)

        output = self.transformer_encoder(orders, self.src_mask)
        output = self.decoder(output)
        return output.permute(1, 0, 2)


In [None]:
def get_input_target(batch, device):
    """Split a collated batch into next-order (input, target) pairs.

    Inputs are all orders except the last. The target is ``prod_idx`` of all orders
    except the first, so position t is trained to predict the products of order t + 1.
    """
    inputs = {name: tensor[:, :-1, :].to(device) for name, tensor in batch.items()}
    target = batch['prod_idx'][:, 1:, :].to(device)
    return inputs, target

In [None]:
def loss_fn(pred, target):
    """Masked multi-label cross-entropy for next-basket prediction.

    Every product of order t + 1 is a separate classification target for the
    scores predicted at position t.

    Args:
        pred: (batch, n_steps, nproducts) scores.
        target: (batch, n_steps, n_targets) product codes; 0 is padding.

    Returns:
        Mean cross-entropy over the non-padding targets (0 when there are none).
    """
    n_targets = target.shape[-1]
    # Repeat each step's scores once per target of that step, so flattened row
    # (b, t, p) lines up with target[b, t, p]. `repeat` tiled the whole sequence
    # instead, pairing step t with products of other steps.
    scores = pred.unsqueeze(2).expand(-1, -1, n_targets, -1).reshape(-1, pred.shape[-1])
    flat_target = target.reshape(-1).long()
    loss = F.cross_entropy(scores, flat_target, ignore_index=0, reduction='sum')
    # Average over real targets only, so padding does not dilute the loss.
    return loss / (flat_target != 0).sum().clamp(min=1)

### Heuristic: consider only products the user has already bought

Restricting predictions this way is very useful for achieving a higher metric score.

In [None]:
def filter_pred(pred, inp_prods):
    """Keep scores only for products the user has already bought.

    Args:
        pred: (batch, n_steps, nproducts) scores.
        inp_prods: (batch, ...) product codes from each user's input orders.

    Returns:
        ``pred`` with the scores of products outside that user's history set to 0.
    """
    batch_size = pred.shape[0]
    history = torch.zeros(batch_size, pred.shape[-1], device=pred.device)
    # Vectorised scatter. It handles any input shape; the old per-row loop crashed on
    # a one-product history because `squeeze()` produced a 0-d tensor. It also keeps
    # each user's history separate within a batch.
    history.scatter_(1, inp_prods.reshape(batch_size, -1).long().to(pred.device), 1.0)
    # NOTE(review): unseen products end at score 0, which can still outrank seen
    # products with negative scores. Padding (code 0) always marks product 0 as seen.
    return pred * history.unsqueeze(1)

In [None]:
def apk(actual, predicted, k=10):
    """Average precision at k of one ranked prediction list.

    Args:
        actual: relevant items; falsy items (e.g. the padding code 0) never count as hits.
        predicted: ranked predictions; only the first ``k`` are scored.
        k: cut-off, also used in the normaliser ``min(len(actual), k)``.

    Returns:
        AP@k in [0, 1]; 0.0 when ``actual`` is empty.
    """
    if not actual:
        return 0.0
    relevant = set(actual)
    score = 0.0
    num_hits = 0.0
    for i, p in enumerate(predicted):
        if i >= k:
            break
        if p and p in relevant:
            num_hits += 1.0
            score += num_hits / (i + 1.0)
    return score / min(len(actual), k)

def mapk(true, predicted, k=10):
    '''Mean AP@k over rows.

    true: (n_rows, n_targets) tensor of relevant product codes, 0 = padding.
    predicted: (n_rows, nproducts) tensor of scores; the top ``k`` per row are ranked.
    '''
    top_k = torch.topk(predicted, k)[1]
    # Strip padding so it does not inflate len(actual) in the AP normaliser.
    return np.mean([
        apk([a for a in row if a], p, k)
        for row, p in zip(true.tolist(), top_k.tolist())
    ])

In [None]:
# Training hyperparameters.
n_epoch = 10
n_heads = 4    # must divide the model width TransformerRec.ninp (3 + 201 + 100 = 304)
n_layers = 3
n_hid = 128    # feed-forward width inside each TransformerEncoderLayer
dropout = 0.2
lr = 1e-3
model_file_name = 'model3.pt'  # checkpoint file inside WORKING_DIR

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = TransformerRec(nproducts, ncategories, ncategories2, n_heads,  n_hid, n_layers, dropout)
_ = model.to(device)


In [None]:
# Adam over all model parameters.
optimizer = Adam(model.parameters(), lr=lr)

In [None]:
# batch_size=1: every batch is the order history of a single user.
market_dataset = MarketDataset(products_df)
trainloader = DataLoader(market_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

#### Train loop

In [None]:
# Clear tqdm's registry of live bars so a re-run after an interruption does not stack
# stale bars. (The attribute was misspelled `_instaces`, so this never ran.)
if hasattr(tqdm.tqdm, '_instances'):
    tqdm.tqdm._instances.clear()

if hasattr(tqdm.trange, '_instances'):
    tqdm.trange._instances.clear()

model.train()
for ep in range(n_epoch):
    ep_loss = 0
    map_ep = 0
    batch_id = 0  # number of batches actually trained on this epoch
    for batch in tqdm.tqdm(trainloader):
        if not batch:
            # collate_fn returns [] when the batch contained nothing usable.
            continue
        inputs, target = get_input_target(batch, device)
        if inputs['prod_idx'].shape[1] == 0:
            # A single order yields no (previous order -> next order) pair.
            continue
        batch_id += 1
        pred = model(inputs)
        pred = filter_pred(pred, inputs['prod_idx'].long())
        loss = loss_fn(pred, target.long())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Detached: the metric has no reason to hold on to the autograd graph.
        map_ep += mapk(target[0], pred[0].detach(), MAX_NPROD)
        ep_loss += loss.item()

        if batch_id % 100 == 0:
            print(f'\tIter {batch_id}, epoch loss: {ep_loss/ batch_id}, map_ep: {map_ep/ batch_id}')
        if batch_id % 1000 == 0:
            torch.save(
                {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'loss': ep_loss/ batch_id,
                    'map': map_ep/ batch_id
                }, Path(WORKING_DIR, model_file_name)
            )

    # Average over the batches actually trained on (skipped batches are excluded).
    n_batches = max(batch_id, 1)
    print(f'MAP: {map_ep / n_batches}')
    print(f'Epoch {ep+1}, loss: {ep_loss / n_batches}')

  0%|          | 101/77361 [00:11<2:19:44,  9.21it/s]

	Iter 100, epoch loss: 6.231178970336914, map_ep: 0.06675325912523791


  0%|          | 202/77361 [00:22<2:15:08,  9.52it/s]

	Iter 200, epoch loss: 5.790545735359192, map_ep: 0.07126948063865615


  0%|          | 301/77361 [00:33<2:23:06,  8.97it/s]

	Iter 300, epoch loss: 5.598095037937164, map_ep: 0.071543121596606


  1%|          | 401/77361 [00:44<2:38:37,  8.09it/s]

	Iter 400, epoch loss: 5.404175570607185, map_ep: 0.06921290394426599


  1%|          | 501/77361 [00:56<2:37:32,  8.13it/s]

	Iter 500, epoch loss: 5.231935997247696, map_ep: 0.06689767253140777


  1%|          | 601/77361 [01:07<2:32:01,  8.42it/s]

	Iter 600, epoch loss: 5.114535999298096, map_ep: 0.06724977129533156


  1%|          | 701/77361 [01:18<2:18:57,  9.19it/s]

	Iter 700, epoch loss: 5.022688857657569, map_ep: 0.06927316486489751


  1%|          | 801/77361 [01:29<2:17:51,  9.26it/s]

	Iter 800, epoch loss: 4.9590394744277, map_ep: 0.06917178329759788


  1%|          | 901/77361 [01:40<2:22:04,  8.97it/s]

	Iter 900, epoch loss: 4.902243313392003, map_ep: 0.06883118177460903


  1%|▏         | 999/77361 [01:51<2:23:42,  8.86it/s]

	Iter 1000, epoch loss: 4.834290717840195, map_ep: 0.06881594771660425


  1%|▏         | 1101/77361 [02:06<2:41:41,  7.86it/s]

	Iter 1100, epoch loss: 4.7769488273967395, map_ep: 0.06929719778825977


  2%|▏         | 1201/77361 [02:18<2:19:32,  9.10it/s]

	Iter 1200, epoch loss: 4.7402947451670965, map_ep: 0.06902387860404934


  2%|▏         | 1301/77361 [02:29<2:27:52,  8.57it/s]

	Iter 1300, epoch loss: 4.706100970598367, map_ep: 0.07019363026885725


  2%|▏         | 1401/77361 [02:40<2:22:08,  8.91it/s]

	Iter 1400, epoch loss: 4.672594692877361, map_ep: 0.0698467938350927


  2%|▏         | 1502/77361 [02:51<2:16:28,  9.26it/s]

	Iter 1500, epoch loss: 4.631651905536652, map_ep: 0.06964984629466264


  2%|▏         | 1601/77361 [03:02<2:22:17,  8.87it/s]

	Iter 1600, epoch loss: 4.582132759168744, map_ep: 0.06936992612957611


  2%|▏         | 1702/77361 [03:13<2:15:36,  9.30it/s]

	Iter 1700, epoch loss: 4.550312948016559, map_ep: 0.06939555051008076


  2%|▏         | 1801/77361 [03:24<2:16:13,  9.24it/s]

	Iter 1800, epoch loss: 4.5226596923338045, map_ep: 0.06918716514317919


  2%|▏         | 1901/77361 [03:36<2:19:23,  9.02it/s]

	Iter 1900, epoch loss: 4.494540963329767, map_ep: 0.06914980128046815


  3%|▎         | 1999/77361 [03:47<2:13:21,  9.42it/s]

	Iter 2000, epoch loss: 4.460304840296507, map_ep: 0.06908976499774122


  3%|▎         | 2101/77361 [04:02<2:31:20,  8.29it/s]

	Iter 2100, epoch loss: 4.441436772034281, map_ep: 0.06921887553477342


  3%|▎         | 2201/77361 [04:14<2:17:48,  9.09it/s]

	Iter 2200, epoch loss: 4.425838504894213, map_ep: 0.06924944823245467


  3%|▎         | 2301/77361 [04:25<2:41:15,  7.76it/s]

	Iter 2300, epoch loss: 4.406884370342545, map_ep: 0.06891581007544471


  3%|▎         | 2400/77361 [04:36<2:15:45,  9.20it/s]

	Iter 2400, epoch loss: 4.393369459037979, map_ep: 0.06904325533013533


  3%|▎         | 2500/77361 [04:47<2:24:27,  8.64it/s]

	Iter 2500, epoch loss: 4.375604136013985, map_ep: 0.06923115764731787


  3%|▎         | 2601/77361 [04:58<2:20:27,  8.87it/s]

	Iter 2600, epoch loss: 4.358804405446236, map_ep: 0.06925092014550396


  3%|▎         | 2701/77361 [05:09<2:19:40,  8.91it/s]

	Iter 2700, epoch loss: 4.335560906485275, map_ep: 0.06988044852067177


  4%|▎         | 2800/77361 [05:20<2:16:34,  9.10it/s]

	Iter 2800, epoch loss: 4.3242759546424665, map_ep: 0.06995952012624775


  4%|▎         | 2900/77361 [05:31<2:17:56,  9.00it/s]

	Iter 2900, epoch loss: 4.311643990093264, map_ep: 0.07013191136033775


  4%|▍         | 2999/77361 [05:42<2:21:16,  8.77it/s]

	Iter 3000, epoch loss: 4.29739121200641, map_ep: 0.07010845940897042


  4%|▍         | 3101/77361 [05:56<2:26:45,  8.43it/s]

	Iter 3100, epoch loss: 4.284602497104675, map_ep: 0.0701735602503585


  4%|▍         | 3201/77361 [06:08<2:33:10,  8.07it/s]

	Iter 3200, epoch loss: 4.277457450833172, map_ep: 0.07016581081406342


  4%|▍         | 3301/77361 [06:19<2:09:57,  9.50it/s]

	Iter 3300, epoch loss: 4.268545068231496, map_ep: 0.07030554945231253


  4%|▍         | 3401/77361 [06:30<2:13:06,  9.26it/s]

	Iter 3400, epoch loss: 4.259240198152907, map_ep: 0.07020360032515212


  5%|▍         | 3502/77361 [06:42<2:09:53,  9.48it/s]

	Iter 3500, epoch loss: 4.244847608923912, map_ep: 0.06997844059608364


  5%|▍         | 3601/77361 [06:52<2:21:41,  8.68it/s]

	Iter 3600, epoch loss: 4.231717937373452, map_ep: 0.07011057228051022


  5%|▍         | 3702/77361 [07:04<2:08:35,  9.55it/s]

	Iter 3700, epoch loss: 4.225263693477657, map_ep: 0.06992110480338157


  5%|▍         | 3801/77361 [07:15<2:03:17,  9.94it/s]

	Iter 3800, epoch loss: 4.214515442769779, map_ep: 0.06986464003345796


  5%|▌         | 3901/77361 [07:26<2:17:18,  8.92it/s]

	Iter 3900, epoch loss: 4.205052813062301, map_ep: 0.06993319005667484


  5%|▌         | 3998/77361 [07:37<2:26:51,  8.33it/s]

	Iter 4000, epoch loss: 4.189398554265499, map_ep: 0.06977503344192022


  5%|▌         | 4101/77361 [07:53<2:41:20,  7.57it/s]

	Iter 4100, epoch loss: 4.180827972394664, map_ep: 0.06978364845145925


  5%|▌         | 4201/77361 [08:05<2:13:29,  9.13it/s]

	Iter 4200, epoch loss: 4.172691691461064, map_ep: 0.06967899899366019


  6%|▌         | 4300/77361 [08:16<2:20:38,  8.66it/s]

	Iter 4300, epoch loss: 4.168691414040189, map_ep: 0.06958973971190559


  6%|▌         | 4401/77361 [08:27<2:20:11,  8.67it/s]

	Iter 4400, epoch loss: 4.162341660748829, map_ep: 0.0698408956686384


  6%|▌         | 4501/77361 [08:38<2:07:02,  9.56it/s]

	Iter 4500, epoch loss: 4.159897914065255, map_ep: 0.06971622089313798


  6%|▌         | 4600/77361 [08:49<2:28:01,  8.19it/s]

	Iter 4600, epoch loss: 4.148454149598661, map_ep: 0.069755403166225


  6%|▌         | 4702/77361 [09:00<2:10:24,  9.29it/s]

	Iter 4700, epoch loss: 4.139748374857801, map_ep: 0.06975855135132245


  6%|▌         | 4802/77361 [09:11<2:10:44,  9.25it/s]

	Iter 4800, epoch loss: 4.134256210997701, map_ep: 0.069684012143645


  6%|▋         | 4901/77361 [09:23<2:19:48,  8.64it/s]

	Iter 4900, epoch loss: 4.125025893839038, map_ep: 0.06954198681902339


  6%|▋         | 4999/77361 [09:33<2:03:56,  9.73it/s]

	Iter 5000, epoch loss: 4.119609796071052, map_ep: 0.06944273421074287


  7%|▋         | 5102/77361 [09:49<2:18:03,  8.72it/s]

	Iter 5100, epoch loss: 4.115147534164728, map_ep: 0.06943693215787285


  7%|▋         | 5201/77361 [10:01<2:23:45,  8.37it/s]

	Iter 5200, epoch loss: 4.108381817180377, map_ep: 0.06949123728364191


  7%|▋         | 5301/77361 [10:12<2:10:41,  9.19it/s]

	Iter 5300, epoch loss: 4.103317554087009, map_ep: 0.06943664940157625


  7%|▋         | 5401/77361 [10:23<2:18:48,  8.64it/s]

	Iter 5400, epoch loss: 4.097233120335473, map_ep: 0.0694425666794068


  7%|▋         | 5501/77361 [10:35<2:11:18,  9.12it/s]

	Iter 5500, epoch loss: 4.090009396076202, map_ep: 0.06937422644721904


  7%|▋         | 5601/77361 [10:46<2:03:46,  9.66it/s]

	Iter 5600, epoch loss: 4.084501228758267, map_ep: 0.0694202603430246


  7%|▋         | 5702/77361 [10:57<2:07:03,  9.40it/s]

	Iter 5700, epoch loss: 4.080254271636929, map_ep: 0.06946918223951794


  7%|▋         | 5800/77361 [11:07<2:07:10,  9.38it/s]

	Iter 5800, epoch loss: 4.073483791433532, map_ep: 0.069528469347301


  8%|▊         | 5901/77361 [11:19<2:05:31,  9.49it/s]

	Iter 5900, epoch loss: 4.067159979969769, map_ep: 0.06947514574646546


  8%|▊         | 5999/77361 [11:29<2:06:21,  9.41it/s]

	Iter 6000, epoch loss: 4.064147082149982, map_ep: 0.0695257365870025


  8%|▊         | 6101/77361 [11:44<2:33:18,  7.75it/s]

	Iter 6100, epoch loss: 4.061694294272876, map_ep: 0.06949222414924848


  8%|▊         | 6201/77361 [11:56<2:11:44,  9.00it/s]

	Iter 6200, epoch loss: 4.056542699798461, map_ep: 0.06941196750848848


  8%|▊         | 6301/77361 [12:07<2:18:29,  8.55it/s]

	Iter 6300, epoch loss: 4.049205051963291, map_ep: 0.06947136261874921


  8%|▊         | 6400/77361 [12:18<2:16:25,  8.67it/s]

	Iter 6400, epoch loss: 4.04490090539679, map_ep: 0.06951266968288679


  8%|▊         | 6501/77361 [12:29<2:22:13,  8.30it/s]

	Iter 6500, epoch loss: 4.039758312748028, map_ep: 0.06971361801345717


  9%|▊         | 6600/77361 [12:40<2:20:04,  8.42it/s]

	Iter 6600, epoch loss: 4.035393200751507, map_ep: 0.06971841677572063


  9%|▊         | 6701/77361 [12:51<2:10:09,  9.05it/s]

	Iter 6700, epoch loss: 4.029876702792609, map_ep: 0.06980482072286373


  9%|▉         | 6801/77361 [13:02<2:23:29,  8.20it/s]

	Iter 6800, epoch loss: 4.023916173209162, map_ep: 0.06970081049150621


  9%|▉         | 6901/77361 [13:13<2:09:51,  9.04it/s]

	Iter 6900, epoch loss: 4.019691076987032, map_ep: 0.06962362388343439


  9%|▉         | 6999/77361 [13:24<2:16:15,  8.61it/s]

	Iter 7000, epoch loss: 4.016219884429659, map_ep: 0.06961721756576383


  9%|▉         | 7101/77361 [13:39<2:34:19,  7.59it/s]

	Iter 7100, epoch loss: 4.014595062950967, map_ep: 0.06966472735703821


  9%|▉         | 7201/77361 [13:51<2:13:10,  8.78it/s]

	Iter 7200, epoch loss: 4.0109587234920925, map_ep: 0.06975575624540382


  9%|▉         | 7300/77361 [14:02<1:56:01, 10.06it/s]

	Iter 7300, epoch loss: 4.008149371457426, map_ep: 0.06974280581481622


 10%|▉         | 7402/77361 [14:13<1:56:38, 10.00it/s]

	Iter 7400, epoch loss: 4.003177877672621, map_ep: 0.0697243599805821


 10%|▉         | 7501/77361 [14:24<2:11:15,  8.87it/s]

	Iter 7500, epoch loss: 3.9996716872294744, map_ep: 0.06966011229545982


 10%|▉         | 7601/77361 [14:34<2:10:22,  8.92it/s]

	Iter 7600, epoch loss: 3.9950075845263506, map_ep: 0.06962916133953237


 10%|▉         | 7701/77361 [14:45<2:03:53,  9.37it/s]

	Iter 7700, epoch loss: 3.989031089892635, map_ep: 0.06963379006581293


 10%|█         | 7801/77361 [14:57<2:05:09,  9.26it/s]

	Iter 7800, epoch loss: 3.985587878143176, map_ep: 0.06948306154772138


 10%|█         | 7901/77361 [15:08<3:05:19,  6.25it/s]

	Iter 7900, epoch loss: 3.9817798334963714, map_ep: 0.06953339444234849


 10%|█         | 7998/77361 [15:18<2:05:30,  9.21it/s]

	Iter 8000, epoch loss: 3.9796998354122044, map_ep: 0.06958160677638645


 10%|█         | 8101/77361 [15:33<2:26:57,  7.85it/s]

	Iter 8100, epoch loss: 3.9760749045934207, map_ep: 0.06954309627023805


 11%|█         | 8201/77361 [15:45<2:06:45,  9.09it/s]

	Iter 8200, epoch loss: 3.9724029135340597, map_ep: 0.06949640465539943


 11%|█         | 8301/77361 [15:57<2:27:49,  7.79it/s]

	Iter 8300, epoch loss: 3.9692332521763194, map_ep: 0.06941844041291852


 11%|█         | 8401/77361 [16:08<2:14:31,  8.54it/s]

	Iter 8400, epoch loss: 3.9662894921572436, map_ep: 0.0693863153734732


 11%|█         | 8501/77361 [16:19<2:20:39,  8.16it/s]

	Iter 8500, epoch loss: 3.9649966040989932, map_ep: 0.06939070685251597


 11%|█         | 8600/77361 [16:30<1:59:46,  9.57it/s]

	Iter 8600, epoch loss: 3.9614890445595563, map_ep: 0.06928256399524059


 11%|█         | 8701/77361 [16:41<1:59:51,  9.55it/s]

	Iter 8700, epoch loss: 3.9580903719622516, map_ep: 0.06928858581719988


 11%|█▏        | 8801/77361 [16:52<2:17:08,  8.33it/s]

	Iter 8800, epoch loss: 3.95603223440322, map_ep: 0.06932919624752493


 12%|█▏        | 8902/77361 [17:03<2:00:16,  9.49it/s]

	Iter 8900, epoch loss: 3.9510127390435574, map_ep: 0.06928607432592647


 12%|█▏        | 8999/77361 [17:14<2:10:14,  8.75it/s]

	Iter 9000, epoch loss: 3.9481651131974327, map_ep: 0.06932757526738481


 12%|█▏        | 9101/77361 [17:29<2:22:24,  7.99it/s]

	Iter 9100, epoch loss: 3.9445980128875147, map_ep: 0.06938306091297326


 12%|█▏        | 9201/77361 [17:41<2:16:53,  8.30it/s]

	Iter 9200, epoch loss: 3.9435485060836957, map_ep: 0.06951728217402427


 12%|█▏        | 9300/77361 [17:52<1:59:03,  9.53it/s]

	Iter 9300, epoch loss: 3.9419622288980793, map_ep: 0.06949023829522857


 12%|█▏        | 9401/77361 [18:03<2:14:55,  8.39it/s]

	Iter 9400, epoch loss: 3.9375231585350443, map_ep: 0.06940684561606164


 12%|█▏        | 9501/77361 [18:14<2:15:58,  8.32it/s]

	Iter 9500, epoch loss: 3.933832101006257, map_ep: 0.06938271169700132


 12%|█▏        | 9601/77361 [18:25<2:09:13,  8.74it/s]

	Iter 9600, epoch loss: 3.9303033717970055, map_ep: 0.06938161727434189


 13%|█▎        | 9701/77361 [18:36<2:07:17,  8.86it/s]

	Iter 9700, epoch loss: 3.928847713544197, map_ep: 0.06936569413359864


 13%|█▎        | 9800/77361 [18:47<2:08:42,  8.75it/s]

	Iter 9800, epoch loss: 3.924777972418435, map_ep: 0.0693432469838124


 13%|█▎        | 9901/77361 [18:58<2:10:28,  8.62it/s]

	Iter 9900, epoch loss: 3.9238367747538017, map_ep: 0.06931144880632328


 13%|█▎        | 9998/77361 [19:09<2:05:09,  8.97it/s]

	Iter 10000, epoch loss: 3.9205549569904803, map_ep: 0.0692687413278583


 13%|█▎        | 10062/77361 [19:19<2:02:03,  9.19it/s]

KeyboardInterrupt: ignored

In [None]:
# map_location lets a checkpoint saved on GPU be restored on a CPU-only runtime.
checkpoint = torch.load(Path(WORKING_DIR, model_file_name), map_location=device)
model.load_state_dict(checkpoint['model'])
_ = model.eval()

In [None]:
# NOTE(review): this pickle is not produced anywhere in this notebook. Document how
# `submission1_df.pckl` is built; MarketDataset below needs the same columns as products_df.
submission_df = pd.read_pickle(Path(WORKING_DIR, 'submission1_df.pckl'))

In [None]:
# Force a garbage-collection pass to reclaim memory held by unreferenced objects.
gc.collect()

0

In [None]:
# max_nord / max_nprod are meant to cap the encoded history (orders per user, line items
# per order); see MarketDataset for exactly how they are applied.
test_market_dataset = MarketDataset(submission_df, max_nord=5, max_nprod=20)

In [None]:
# shuffle=False keeps batch_id equal to the dataset index, which the inference loop
# maps back to a user id via idx2user.
test_loader = DataLoader(test_market_dataset, batch_size=1, collate_fn=collate_fn, shuffle=False)

In [None]:
# Inverse of prod2idx: model output index -> original product id (the id is a string
# if prod2idx was reloaded from meta.json).
idx2prod = dict(zip(prod2idx.values(), prod2idx.keys()))

In [None]:
# Resume from partial predictions written by the inference loop below.
# Only unpickle files you created yourself: loading a pickle can execute arbitrary code.
with Path('subm.pkl').open('rb') as f:
    preds = pickle.load(f)

In [None]:
# Resume support: predictions for batches before RESUME_FROM are expected to already be
# in `preds` (loaded from subm.pkl above). For a fresh run, uncomment the next line and
# set RESUME_FROM = 0.
#preds = {'Id': [], 'Predicted': []}
RESUME_FROM = 2600
idx2user = test_market_dataset.idx2user


def _save_partial_preds():
    """Checkpoint the predictions collected so far to subm.pkl."""
    with Path('subm.pkl').open('wb') as f:
        pickle.dump(preds, f)


with torch.no_grad():
    for batch_id, batch in enumerate(test_loader):
        if batch_id < RESUME_FROM:
            continue
        if batch_id % 100 == 0:
            print(batch_id)
        if batch_id % 10000 == 0:
            _save_partial_preds()
        if not batch:
            # Nothing usable for this user: keep the Id with an empty prediction.
            preds['Id'].append(idx2user[batch_id])
            preds['Predicted'].append('')
            continue
        # All orders but the last, mirroring the training inputs.
        # NOTE(review): for a true "next order" prediction the last known order should
        # probably be included as well; confirm what submission_df contains.
        inputs = {k: v[:, :-1, :].to(device) for k, v in batch.items()}
        if inputs['prod_idx'].shape[1] == 0:
            # Only one order is known: condition on it. The old branch reused the previous
            # user's `pred` and stored a list of ints instead of a space-separated string.
            inputs = {k: v.to(device) for k, v in batch.items()}
        scores = model(inputs)
        scores = filter_pred(scores, inputs['prod_idx'].long())
        # MAX_NPROD (= 50) products: the same k as the MAP@k reported during training.
        top_idx = scores[0, -1, :].topk(MAX_NPROD)[1].tolist()
        preds['Id'].append(idx2user[batch_id])
        preds['Predicted'].append(' '.join(str(idx2prod[p]) for p in top_idx))

# Persist the final state as well, not only every 10000 batches.
_save_partial_preds()


2600


In [None]:
# Kaggle submission: one row per user (`Id`) with space-separated product ids (`Predicted`).
# `pd.to_csv` does not exist; the predictions must first be turned into a DataFrame.
pd.DataFrame(preds).to_csv('submission.csv', index=False)