In [1]:
import pandas as pd
import torch
import pytorch_lightning as pl
from io import BytesIO
from urllib.request import urlopen
from zipfile import ZipFile
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm


from torchrecsys.datasets import InteractionsDataset, SequenceDataset
from torchrecsys.models import BaseModel
from torchrecsys.task import Ranking
from torchrecsys.layers import BruteForceLayer
import torchrecsys as trs

In [2]:
candidates = pd.read_csv("data/candidate_items.csv")
train_purchases = pd.read_csv("data/train_purchases.csv")
train_sessions = pd.read_csv("data/train_sessions.csv",  parse_dates=['date'])

test_sessions = pd.read_csv("data/test_leaderboard_sessions.csv")
final_test = pd.read_csv("data/test_final_sessions.csv")
all_interactions = pd.concat([train_sessions, train_purchases], ignore_index=True)

In [3]:
# Create list of browsed items in the session ordered by timestamp. WE DONT KNOW IF THE PURCHASED ITEM IS ALWAYS AFTER THE LASTEST BROWSED ITEM!! CHECK ON THIS
train_sessions = train_sessions.sort_values(['date']).groupby('session_id')["item_id"].apply(list).to_frame().reset_index()

# Rename target id 
train_purchases = train_purchases.rename(columns={"item_id": "purchased_item_id"})

In [4]:
#Merge sessions with their corresponding bough item. ##WARNING WE ARE DROPPING A LOT OF SESSIONS BECAUSE OF THIS. bert encoder idea.
train_sessions = train_sessions.merge(train_purchases[["session_id", "purchased_item_id"]], left_on='session_id', right_on='session_id').rename(columns={"item_id": "session_history"})

In [5]:
train_sessions.session_history.apply(len).mean()

4.74382

In [6]:
n_session, n_item = all_interactions.session_id.max()+1, all_interactions.item_id.max()+1

In [7]:
ds =  train_sessions[["session_history","purchased_item_id"]].values
train_ds, val_ds = train_test_split(ds, test_size=0.33, random_state=42)

window_length = 5
train_ds = SequenceDataset(train_ds, max_len=window_length)
val_ds = SequenceDataset(val_ds, max_len=window_length)

train_dataloader = torch.utils.data.DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=6)
val_dataloader = torch.utils.data.DataLoader(val_ds, batch_size=128, shuffle=False, num_workers=2)

In [8]:
class retrievalModel(trs.BaseModel):
    def __init__(self):
        super().__init__()
        self.item_embeddings = torch.nn.Embedding(n_item, 128)
#         self.session_gru = torch.nn.GRU(input_size=128, hidden_size=128, num_layers=1, batch_first=True)
        self.linear = torch.nn.Linear(window_length*128, 128)
        
        self.session_model = torch.nn.Sequential(
            self.item_embeddings,
            torch.nn.Flatten(start_dim=1),
            self.linear,
        )
        
        self.item_model = torch.nn.Sequential(
            self.item_embeddings,
        )
            
        self.leakyrelu = torch.nn.LeakyReLU()
        self.task = Ranking()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

    def forward(self, x):
        # Query, the session data
        session_embeddings = self.session_model(x[0])
        
        # Candidate, the target transaction
        item_embeddings = self.item_model(x[1])

        # The task computes the loss and the metrics.
        # return session_embeddings, item_embeddings
        return session_embeddings, item_embeddings
    
    
model = retrievalModel().cuda()

In [9]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# training
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[EarlyStopping(monitor="val_loss", patience=3)])
trainer.fit(model, train_dataloader, val_dataloader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type       | Params
-----------------------------------------------
0 | item_embeddings | Embedding  | 3.6 M 
1 | linear          | Linear     | 82.0 K
2 | session_model   | Sequential | 3.7 M 
3 | item_model      | Sequential | 3.6 M 
4 | leakyrelu       | LeakyReLU  | 0     
-----------------------------------------------
3.7 M     Trainable params
0         Non-trainable params
3.7 M     Total params
14.738    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

# Score

In [10]:
aux = test_sessions.sort_values(['date']).groupby('session_id')["item_id"].apply(list).to_frame().reset_index().rename(columns={"item_id": "session_history"})
ds =  aux[["session_history"]].values

window_length = 5
test_ds = SequenceDataset(ds, max_len=window_length, training=False)
test_dataloader = torch.utils.data.DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=2)


In [16]:
model = model.cuda()

In [17]:
brute_layer = BruteForceLayer(model.session_model, k=100)

unique_item_mapped = torch.stack(list(map(lambda x: model.item_model(torch.tensor(x).cuda()), all_interactions.item_id.unique())))

brute_layer.index(unique_item_mapped)

In [18]:
scores = []
indices = []
for batch in tqdm(test_dataloader):
    r = brute_layer(torch.tensor(batch).cuda())
    scores.append(r[0])
    indices.append(r[1])

  0%|          | 0/391 [00:00<?, ?it/s]

  r = brute_layer(torch.tensor(batch).cuda())


In [19]:
xdf = torch.cat(indices).cpu().numpy()

In [20]:
px = pd.DataFrame({
    "session_id": aux.session_id,
    "item_id": xdf.tolist(),
})

In [21]:
px = px.explode("item_id")
px["rank"] = px.groupby(["session_id"]).cumcount()+1

In [22]:
px.to_csv("results.csv", index=False)

In [23]:
px

Unnamed: 0,session_id,item_id,rank
0,26,18227,1
0,26,22922,2
0,26,14130,3
0,26,16836,4
0,26,21395,5
...,...,...,...
49999,4439757,22922,96
49999,4439757,21941,97
49999,4439757,16972,98
49999,4439757,13430,99
