In [1]:
pip install polars

Note: you may need to restart the kernel to use updated packages.


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import ast
import numpy as np
from transformers import BertModel, BertConfig
from sklearn.model_selection import train_test_split
import polars as pl
import pickle

In [3]:
df = pl.read_parquet('/kaggle/input/cxc-2025-rnn-data/rnn_data.parquet')
df.head()

user_id,session_id,event_type_mapped_list,target
str,str,list[i64],i64
"""afe99d2f-4fce-4584-a360-967b87…","""1715551789566""",[553],634
"""de762acc-c1cd-4308-8e5e-80ba1d…","""1721053697224""",[349],531
"""6ddede71-f391-48ba-9d87-32cf6b…","""1730825985104""",[349],238
"""6ddede71-f391-48ba-9d87-32cf6b…","""1730825985104""","[349, 238]",691
"""6ddede71-f391-48ba-9d87-32cf6b…","""1730825985104""","[349, 238, 691]",490


In [2]:
with open('/kaggle/input/cxc-2025-rnn-data/idx_to_event.pkl', 'rb') as f:
    idx_to_event = pickle.load(f, encoding='latin1')

# Convert Polars columns to Python lists for compatibility with the dataset (potientially use for different)
sequences = df["event_type_mapped_list"].to_list()  # Converts to list of lists
targets = df["target"].to_list()  # Converts to list of integers

# Define hyperparameters
num_items = len(idx_to_event) + 1
hidden_size = 128
num_layers = 2
num_heads = 2
max_seq_len = 10  # Adjusted for longer sequences if needed
learning_rate = 1e-4
batch_size = 32
num_epochs = 50

# Define the dataset class
class BERT4RecDataset(Dataset):
    def __init__(self, df, max_seq_len=3):
        self.sequences = df["event_type_mapped_list"].to_list()
        self.targets = df["target"].to_list()
        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]      # e.g. [10, 22]
        target = self.targets[idx]     # e.g. 15

        # 1) Convert to list if not already
        if not isinstance(seq, list):
            seq = []

        # 2) Pad or truncate to length = max_seq_len
        seq_len = len(seq)
        if seq_len < self.max_seq_len:
            seq = [740] * (self.max_seq_len - seq_len) + seq
        else:
            seq = seq[-self.max_seq_len:]

        # 3) Convert to tensors
        input_seq = torch.tensor(seq, dtype=torch.long)       # shape: (max_seq_len,)
        target_action = torch.tensor(target, dtype=torch.long) # shape: ()

        # 4) Return
        return input_seq, target_action

# Define the BERT4Rec model
class BERT4Rec(nn.Module):
    def __init__(self, num_items, hidden_size, num_layers, num_heads, max_seq_len):
        super(BERT4Rec, self).__init__()
        
        self.config = BertConfig(
            vocab_size=num_items,
            hidden_size=hidden_size,
            num_hidden_layers=num_layers,
            num_attention_heads=num_heads,
            intermediate_size=hidden_size * 4,
            max_position_embeddings=max_seq_len,
            pad_token_id=740
        )
        
        self.bert = BertModel(self.config)
        self.fc = nn.Linear(hidden_size, num_items)

    def forward(self, input_ids):
        attention_mask = (input_ids != 0).long()
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state
        last_token_hidden = last_hidden_state[:, -1, :]  # (batch_size, hidden_size)
        logits = self.fc(last_token_hidden)              # (batch_size, num_items)
        return logits

df = df.filter(
    pl.col('user_id') != 'EMPTY'
).with_columns(
    pl.col('user_id').map_elements(lambda x: x[-1], return_dtype=pl.String).alias('train')
)

train_df = df
test_df = df.filter(
    (pl.col('train') == 'e') | (pl.col('train') == '6')
)

val_df = df.filter(
    (pl.col('train') == 'f') | (pl.col('train') == '9')
)

# Create datasets
train_dataset = BERT4RecDataset(train_df, max_seq_len)
val_dataset = BERT4RecDataset(val_df, max_seq_len)
test_dataset = BERT4RecDataset(test_df, max_seq_len)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Create model, optimizer, etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BERT4Rec(num_items, hidden_size, num_layers, num_heads, max_seq_len).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

# Training Loop
for epoch in range(num_epochs):
    model.train()
    total_train_loss = 0.0
    
    for input_seq, target_action in train_loader:
        input_seq = input_seq.to(device)
        target_action = target_action.to(device)

        optimizer.zero_grad()
        logits = model(input_seq)                 # (batch_size, num_items)
        loss = criterion(logits, target_action)    # raw logits + CrossEntropyLoss
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()
    
    avg_train_loss = total_train_loss / len(train_loader)

    # Validation Loop
    model.eval()
    total_val_loss = 0.0

    with torch.no_grad():
        for input_seq, target_action in val_loader:
            input_seq = input_seq.to(device)
            target_action = target_action.to(device)

            logits = model(input_seq)
            loss = criterion(logits, target_action)
            total_val_loss += loss.item()

    avg_val_loss = total_val_loss / len(val_loader)
    print(f"[Epoch {epoch+1}/{num_epochs}] "
          f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

# Test Loop
model.eval()
total_test_loss = 0.0
correct = 0
total = 0

with torch.no_grad():
    for input_seq, target_action in test_loader:
        input_seq = input_seq.to(device)
        target_action = target_action.to(device)

        logits = model(input_seq)
        loss = criterion(logits, target_action)
        total_test_loss += loss.item()

        # Calculate accuracy
        _, predicted = torch.max(logits, 1)
        correct += (predicted == target_action).sum().item()
        total += target_action.size(0)

avg_test_loss = total_test_loss / len(test_loader)
accuracy = correct / total
print(f"Test Loss: {avg_test_loss:.4f}, Test Accuracy: {accuracy:.4f}")

[Epoch 1/20] Train Loss: 1.1087, Val Loss: 1.0224
[Epoch 2/20] Train Loss: 0.9848, Val Loss: 0.9857
[Epoch 3/20] Train Loss: 0.9585, Val Loss: 0.9679
[Epoch 4/20] Train Loss: 0.9442, Val Loss: 0.9563
[Epoch 5/20] Train Loss: 0.9346, Val Loss: 0.9502
[Epoch 6/20] Train Loss: 0.9275, Val Loss: 0.9422
[Epoch 7/20] Train Loss: 0.9223, Val Loss: 0.9383
[Epoch 8/20] Train Loss: 0.9185, Val Loss: 0.9380
[Epoch 9/20] Train Loss: 0.9149, Val Loss: 0.9311
[Epoch 10/20] Train Loss: 0.9123, Val Loss: 0.9308
[Epoch 11/20] Train Loss: 0.9099, Val Loss: 0.9276
[Epoch 12/20] Train Loss: 0.9079, Val Loss: 0.9266
[Epoch 13/20] Train Loss: 0.9062, Val Loss: 0.9233
[Epoch 14/20] Train Loss: 0.9045, Val Loss: 0.9263
[Epoch 15/20] Train Loss: 0.9033, Val Loss: 0.9194
[Epoch 16/20] Train Loss: 0.9023, Val Loss: 0.9262
[Epoch 17/20] Train Loss: 0.9007, Val Loss: 0.9238
[Epoch 18/20] Train Loss: 0.9000, Val Loss: 0.9202
[Epoch 19/20] Train Loss: 0.8989, Val Loss: 0.9176
[Epoch 20/20] Train Loss: 0.8983, Val Lo

In [18]:
def recommend_next_action(user_sequence):
    """Predict next-action probabilities for a given user sequence,
       then return top-k predictions and their probabilities."""
    model.eval()
    with torch.no_grad():
        # Convert user_sequence to a batched tensor of shape (1, max_seq_len)
        seq_len = len(user_sequence)
        if seq_len < max_seq_len:
            user_sequence = [0]*(max_seq_len - seq_len) + user_sequence
        else:
            user_sequence = user_sequence[-max_seq_len:]

        input_seq = torch.tensor(user_sequence, dtype=torch.long).unsqueeze(0).to(device)
        logits = model(input_seq)  # shape (1, num_items)

        # Convert logits to probabilities
        probabilities = torch.softmax(logits, dim=-1)  # shape (1, num_items)

        # Top-k next items
        top_k = 1
        topk_values, topk_indices = torch.topk(probabilities, k=top_k, dim=-1)  # shape (1, k) each
        topk_values = topk_values.squeeze(0).cpu().numpy()
        topk_indices = topk_indices.squeeze(0).cpu().numpy()

    return topk_indices, topk_values

user_seq = [1, 2]  # user has done events 1, then 2
predicted_actions, probs = recommend_next_action(user_seq)

In [19]:
predicted_actions, probs

(array([460]), array([0.4038925], dtype=float32))

In [5]:
torch.save(model.state_dict(), "bert4rec_model.pth")

In [21]:
user_sequence = [1,2]
seq_len = len(user_sequence)
if seq_len < max_seq_len:
    user_sequence = [0]*(max_seq_len - seq_len) + user_sequence
else:
    user_sequence = user_sequence[-max_seq_len:]
input_seq = torch.tensor(user_sequence, dtype=torch.long).unsqueeze(0).to(device)
logits = model(input_seq)
probabilities = torch.softmax(logits, dim=-1)
probabilities_rounded = torch.round(probabilities * 100) / 100

In [23]:
topk_values, topk_indices = torch.topk(probabilities_rounded, k=5, dim=-1)
topk_values, topk_indices

(tensor([[0.4000, 0.2100, 0.0400, 0.0400, 0.0300]], device='cuda:0',
        grad_fn=<TopkBackward0>),
 tensor([[460, 553,  20,  16, 607]], device='cuda:0'))

In [26]:
model_load = torch.load('/kaggle/working/bert4rec_model.pth', weights_only=False)

In [25]:
model(input_seq)

tensor([[ -7.1945,  -5.3142,  -6.6874, -13.1927,  -6.6779,  -2.8280,  -8.4970,
         -13.1918,  -4.5702,  -5.2656, -13.1918,  -8.6398,  -0.7440,  -8.2066,
         -13.1928,  -7.1815,   1.2277,  -4.4531, -12.5453, -13.1909,   1.2523,
          -1.9722, -12.7996,  -4.5923,  -3.7978, -10.3480, -12.4543,  -4.2849,
         -13.1918, -13.1919, -13.1914, -13.1923,  -4.4153, -12.6341,  -8.4430,
          -6.8269,  -5.7460, -13.1912,  -4.6628, -13.1926, -10.9377, -13.1928,
          -7.7358,  -3.3614, -13.1925,  -7.6509,  -7.0793,  -6.0046, -13.1913,
         -13.1915,  -4.3490,  -3.4037,  -3.5399, -12.6127, -13.1927,  -0.0771,
          -5.0143,  -6.6612,  -8.1107, -13.1916,  -8.8396, -13.1913, -13.1919,
         -13.1925, -13.1921, -13.1922, -13.1918, -13.1909, -13.1917, -13.1922,
         -13.1909, -13.1914, -13.1921, -13.1917, -11.9839, -13.1931,  -4.0807,
          -4.7916,  -0.0411,  -2.6555,  -8.6739,  -5.7545, -13.1921,  -8.6077,
         -10.2615, -13.1925, -13.1913, -11.9031,  -3