In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

total_df = pd.read_pickle("../data/engage_rating.pkl")
rating_df = pd.read_pickle("../data/sessoin_rating.pkl")
rating_df

Unnamed: 0,user_id,date,item_id,rating
0,0,2022-07-22,"[71482, 53839, 32319, 64573]","[5.0, 5.0, 1.0, 5.0]"
1,0,2022-07-22,"[53839, 32319, 64573, 44799]","[5.0, 1.0, 5.0, 5.0]"
2,0,2022-07-22,"[32319, 64573, 44799, 56313]","[1.0, 5.0, 5.0, 5.0]"
3,0,2022-07-22,"[64573, 44799, 56313, 40697]","[5.0, 5.0, 5.0, 1.0]"
4,0,2022-07-22,"[44799, 56313, 40697, 41653]","[5.0, 5.0, 1.0, 1.0]"
...,...,...,...,...
21798,997,2023-03-20,"[68239, 68144, 67857, 62937]","[5.0, 5.0, 5.0, 5.0]"
21799,999,2014-12-03,"[16923, 19623, 19207, 3573]","[5.0, 4.0, 5.0, 5.0]"
21800,999,2018-12-26,"[40583, 38774, 38410, 61047]","[5.0, 5.0, 5.0, 5.0]"
21801,999,2019-02-10,"[40158, 39077, 38698, 42245]","[4.0, 5.0, 3.0, 5.0]"


In [2]:
total_df

Unnamed: 0,user_id,item_id,parent_asin,rating,timestamp
0,0,28833,B01G15TGCU,1.0,1491601677000
0,0,29361,B01HRPJQ2S,5.0,1475448112000
0,0,32319,B06XQYN77L,1.0,1658498138614
0,0,37990,B07F9SG3RX,5.0,1673250487226
0,0,40014,B07L4GYFK9,5.0,1677469171342
...,...,...,...,...,...
999,999,69440,B0BYSG291N,5.0,1559594592832
999,999,70383,B0C37XY2JZ,5.0,1540929869776
999,999,71229,B0C662M3GG,5.0,1547336147049
999,999,71551,B0C78KPQYH,5.0,1552672042595


In [3]:
len(total_df["item_id"].unique()), total_df["item_id"].max()

(72319, 72318)

In [4]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset

item_sequences = rating_df['item_id'].tolist()  # item_id를 리스트로 변환
user_ids = rating_df['user_id'].tolist()


class SessionDataset(Dataset):
    def __init__(self, sequences):
        self.sequences = sequences

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        return torch.tensor(sequence[:-1]), torch.tensor(sequence[-1])
    
train_data, test_data = train_test_split(item_sequences, test_size=0.2, random_state=42)

# Hyperparameters
embedding_dim = 128
hidden_dim = 128
num_epochs = 1000
learning_rate = 5e-5
batch_size = 512
# 512 1e-4 0.4369
# 512 53-5 0.4404

train_loader = DataLoader(SessionDataset(train_data), batch_size=batch_size, shuffle=True)
test_loader = DataLoader(SessionDataset(test_data), batch_size=batch_size, shuffle=False)

In [5]:
import torch.optim as optim
class GRURecommender(nn.Module):
    def __init__(self, num_items, embedding_dim, hidden_dim):
        super(GRURecommender, self).__init__()
        self.embedding = nn.Embedding(num_items, embedding_dim)
        self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_items)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.gru(x)
        x = self.relu(x[:, -1, :])
        x = self.fc(x)
        return x

num_items = 368228

model = GRURecommender(num_items, embedding_dim, hidden_dim).cuda()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [6]:
from sklearn.metrics import precision_score, recall_score
from tqdm import tqdm
from copy import deepcopy
import warnings
warnings.filterwarnings('always')

model_parameters = deepcopy(model.state_dict())

def precision_at_k(preds, target, k=20):
    top_k_preds = preds.topk(k, dim=1).indices.cpu().numpy()
    target = target.cpu().numpy()
    
    y_true = np.isin(top_k_preds, target[:, None]) 
    y_pred = np.ones_like(y_true)  
    
    precision_scores = []
    for true, pred in zip(y_true, y_pred):
        if np.sum(pred) == 0:  
            precision_scores.append(0)
        else:
            precision_scores.append(precision_score(true, pred, zero_division=0))
    return np.mean(precision_scores)

def train_and_evaluate(model, train_loader, val_loader, criterion, optimizer, num_epochs, early_stopping_patience=5):
    global model_parameters
    best_val_precision = -float('inf')
    patience_counter = 0
    
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for inputs, target in tqdm(train_loader):
            optimizer.zero_grad()
            outputs = model(inputs.cuda())
            loss = criterion(outputs, target.cuda())

            loss.backward()
            optimizer.step()
            total_loss += loss.detach().cpu().item()

        val_loss, val_precision = evaluate(model, val_loader, criterion)
        
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}, '
              f'Val Loss: {val_loss:.4f}, Precision@20: {val_precision:.4f}')
        
        if val_precision > best_val_precision:
            best_val_precision = val_precision
            patience_counter = 0
            model_parameters = deepcopy(model.state_dict())
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                print("Early stopping triggered")
                break

def evaluate(model, data_loader, criterion):
    model.eval()
    total_loss = 0
    total_precision = 0
    total_samples = 0
    
    with torch.no_grad():
        for inputs, target in tqdm(data_loader):
            outputs = model(inputs.cuda())
            loss = criterion(outputs, target.cuda())
            total_loss += loss.detach().cpu().item()
            
            total_precision += precision_at_k(outputs, target, k=20) * inputs.size(0)
            total_samples += inputs.size(0)
            
    avg_loss = total_loss / len(data_loader)
    avg_precision = total_precision / total_samples
    return avg_loss, avg_precision

train_and_evaluate(model, train_loader, test_loader, criterion, optimizer, num_epochs)


  0%|          | 0/35 [00:00<?, ?it/s]

100%|██████████| 35/35 [00:02<00:00, 14.50it/s]
100%|██████████| 9/9 [00:06<00:00,  1.41it/s]


Epoch [1/1000], Loss: 12.8237, Val Loss: 12.8197, Precision@20: 0.0018


100%|██████████| 35/35 [00:02<00:00, 14.88it/s]
100%|██████████| 9/9 [00:06<00:00,  1.47it/s]


Epoch [2/1000], Loss: 12.7924, Val Loss: 12.8125, Precision@20: 0.0028


100%|██████████| 35/35 [00:02<00:00, 15.24it/s]
100%|██████████| 9/9 [00:06<00:00,  1.48it/s]


Epoch [3/1000], Loss: 12.7610, Val Loss: 12.8053, Precision@20: 0.0052


100%|██████████| 35/35 [00:02<00:00, 15.02it/s]
100%|██████████| 9/9 [00:06<00:00,  1.49it/s]


Epoch [4/1000], Loss: 12.7263, Val Loss: 12.7978, Precision@20: 0.0093


100%|██████████| 35/35 [00:02<00:00, 14.87it/s]
100%|██████████| 9/9 [00:06<00:00,  1.48it/s]


Epoch [5/1000], Loss: 12.6925, Val Loss: 12.7898, Precision@20: 0.0164


100%|██████████| 35/35 [00:02<00:00, 14.91it/s]
100%|██████████| 9/9 [00:06<00:00,  1.50it/s]


Epoch [6/1000], Loss: 12.6555, Val Loss: 12.7815, Precision@20: 0.0260


100%|██████████| 35/35 [00:02<00:00, 15.65it/s]
100%|██████████| 9/9 [00:06<00:00,  1.48it/s]


Epoch [7/1000], Loss: 12.6172, Val Loss: 12.7725, Precision@20: 0.0379


100%|██████████| 35/35 [00:02<00:00, 15.65it/s]
100%|██████████| 9/9 [00:05<00:00,  1.52it/s]


Epoch [8/1000], Loss: 12.5732, Val Loss: 12.7627, Precision@20: 0.0483


100%|██████████| 35/35 [00:02<00:00, 15.63it/s]
100%|██████████| 9/9 [00:05<00:00,  1.51it/s]


Epoch [9/1000], Loss: 12.5275, Val Loss: 12.7520, Precision@20: 0.0591


100%|██████████| 35/35 [00:02<00:00, 15.63it/s]
100%|██████████| 9/9 [00:05<00:00,  1.53it/s]


Epoch [10/1000], Loss: 12.4780, Val Loss: 12.7403, Precision@20: 0.0690


100%|██████████| 35/35 [00:02<00:00, 15.58it/s]
100%|██████████| 9/9 [00:05<00:00,  1.53it/s]


Epoch [11/1000], Loss: 12.4215, Val Loss: 12.7273, Precision@20: 0.0784


100%|██████████| 35/35 [00:02<00:00, 15.74it/s]
100%|██████████| 9/9 [00:05<00:00,  1.51it/s]


Epoch [12/1000], Loss: 12.3610, Val Loss: 12.7126, Precision@20: 0.0872


100%|██████████| 35/35 [00:02<00:00, 14.86it/s]
100%|██████████| 9/9 [00:05<00:00,  1.53it/s]


Epoch [13/1000], Loss: 12.2905, Val Loss: 12.6957, Precision@20: 0.0970


100%|██████████| 35/35 [00:02<00:00, 15.56it/s]
100%|██████████| 9/9 [00:05<00:00,  1.52it/s]


Epoch [14/1000], Loss: 12.2125, Val Loss: 12.6760, Precision@20: 0.1060


100%|██████████| 35/35 [00:02<00:00, 15.58it/s]
100%|██████████| 9/9 [00:05<00:00,  1.51it/s]


Epoch [15/1000], Loss: 12.1184, Val Loss: 12.6523, Precision@20: 0.1154


100%|██████████| 35/35 [00:02<00:00, 15.60it/s]
100%|██████████| 9/9 [00:05<00:00,  1.52it/s]


Epoch [16/1000], Loss: 12.0069, Val Loss: 12.6231, Precision@20: 0.1244


100%|██████████| 35/35 [00:02<00:00, 15.59it/s]
100%|██████████| 9/9 [00:05<00:00,  1.51it/s]


Epoch [17/1000], Loss: 11.8727, Val Loss: 12.5861, Precision@20: 0.1329


100%|██████████| 35/35 [00:02<00:00, 15.60it/s]
100%|██████████| 9/9 [00:05<00:00,  1.51it/s]


Epoch [18/1000], Loss: 11.6956, Val Loss: 12.5377, Precision@20: 0.1436


100%|██████████| 35/35 [00:02<00:00, 15.52it/s]
100%|██████████| 9/9 [00:05<00:00,  1.52it/s]


Epoch [19/1000], Loss: 11.4718, Val Loss: 12.4733, Precision@20: 0.1530


100%|██████████| 35/35 [00:02<00:00, 15.59it/s]
100%|██████████| 9/9 [00:05<00:00,  1.51it/s]


Epoch [20/1000], Loss: 11.1769, Val Loss: 12.3914, Precision@20: 0.1636


100%|██████████| 35/35 [00:02<00:00, 15.10it/s]
100%|██████████| 9/9 [00:05<00:00,  1.51it/s]


Epoch [21/1000], Loss: 10.8134, Val Loss: 12.3028, Precision@20: 0.1741


100%|██████████| 35/35 [00:02<00:00, 15.56it/s]
100%|██████████| 9/9 [00:05<00:00,  1.51it/s]


Epoch [22/1000], Loss: 10.4063, Val Loss: 12.2410, Precision@20: 0.1878


100%|██████████| 35/35 [00:02<00:00, 15.44it/s]
100%|██████████| 9/9 [00:05<00:00,  1.54it/s]


Epoch [23/1000], Loss: 10.0382, Val Loss: 12.2384, Precision@20: 0.1997


100%|██████████| 35/35 [00:02<00:00, 15.75it/s]
100%|██████████| 9/9 [00:05<00:00,  1.50it/s]


Epoch [24/1000], Loss: 9.7682, Val Loss: 12.2871, Precision@20: 0.2088


100%|██████████| 35/35 [00:02<00:00, 15.66it/s]
100%|██████████| 9/9 [00:05<00:00,  1.53it/s]


Epoch [25/1000], Loss: 9.5830, Val Loss: 12.3544, Precision@20: 0.2135


100%|██████████| 35/35 [00:02<00:00, 15.65it/s]
100%|██████████| 9/9 [00:05<00:00,  1.52it/s]


Epoch [26/1000], Loss: 9.4552, Val Loss: 12.4182, Precision@20: 0.2185


100%|██████████| 35/35 [00:02<00:00, 15.72it/s]
100%|██████████| 9/9 [00:06<00:00,  1.49it/s]


Epoch [27/1000], Loss: 9.3536, Val Loss: 12.4703, Precision@20: 0.2213


100%|██████████| 35/35 [00:02<00:00, 15.69it/s]
100%|██████████| 9/9 [00:05<00:00,  1.52it/s]


Epoch [28/1000], Loss: 9.2793, Val Loss: 12.5133, Precision@20: 0.2256


100%|██████████| 35/35 [00:02<00:00, 15.04it/s]
100%|██████████| 9/9 [00:05<00:00,  1.53it/s]


Epoch [29/1000], Loss: 9.2153, Val Loss: 12.5500, Precision@20: 0.2294


100%|██████████| 35/35 [00:02<00:00, 15.60it/s]
100%|██████████| 9/9 [00:05<00:00,  1.52it/s]


Epoch [30/1000], Loss: 9.1507, Val Loss: 12.5799, Precision@20: 0.2322


100%|██████████| 35/35 [00:02<00:00, 15.75it/s]
100%|██████████| 9/9 [00:05<00:00,  1.53it/s]


Epoch [31/1000], Loss: 9.1015, Val Loss: 12.6102, Precision@20: 0.2362


100%|██████████| 35/35 [00:02<00:00, 15.64it/s]
100%|██████████| 9/9 [00:06<00:00,  1.48it/s]


Epoch [32/1000], Loss: 9.0560, Val Loss: 12.6370, Precision@20: 0.2386


100%|██████████| 35/35 [00:02<00:00, 15.41it/s]
100%|██████████| 9/9 [00:05<00:00,  1.51it/s]


Epoch [33/1000], Loss: 9.0144, Val Loss: 12.6627, Precision@20: 0.2407


100%|██████████| 35/35 [00:02<00:00, 15.65it/s]
100%|██████████| 9/9 [00:05<00:00,  1.52it/s]


Epoch [34/1000], Loss: 8.9701, Val Loss: 12.6893, Precision@20: 0.2418


100%|██████████| 35/35 [00:02<00:00, 15.75it/s]
100%|██████████| 9/9 [00:05<00:00,  1.53it/s]


Epoch [35/1000], Loss: 8.9326, Val Loss: 12.7153, Precision@20: 0.2424


100%|██████████| 35/35 [00:02<00:00, 15.66it/s]
100%|██████████| 9/9 [00:05<00:00,  1.52it/s]


Epoch [36/1000], Loss: 8.8953, Val Loss: 12.7400, Precision@20: 0.2424


100%|██████████| 35/35 [00:02<00:00, 15.13it/s]
100%|██████████| 9/9 [00:06<00:00,  1.49it/s]


Epoch [37/1000], Loss: 8.8526, Val Loss: 12.7655, Precision@20: 0.2432


100%|██████████| 35/35 [00:02<00:00, 15.74it/s]
100%|██████████| 9/9 [00:05<00:00,  1.52it/s]


Epoch [38/1000], Loss: 8.8229, Val Loss: 12.7916, Precision@20: 0.2440


100%|██████████| 35/35 [00:02<00:00, 15.64it/s]
100%|██████████| 9/9 [00:05<00:00,  1.51it/s]


Epoch [39/1000], Loss: 8.7924, Val Loss: 12.8143, Precision@20: 0.2447


100%|██████████| 35/35 [00:02<00:00, 15.65it/s]
100%|██████████| 9/9 [00:05<00:00,  1.51it/s]


Epoch [40/1000], Loss: 8.7604, Val Loss: 12.8400, Precision@20: 0.2443


100%|██████████| 35/35 [00:02<00:00, 15.57it/s]
100%|██████████| 9/9 [00:05<00:00,  1.52it/s]


Epoch [41/1000], Loss: 8.7335, Val Loss: 12.8613, Precision@20: 0.2451


100%|██████████| 35/35 [00:02<00:00, 15.54it/s]
100%|██████████| 9/9 [00:05<00:00,  1.50it/s]


Epoch [42/1000], Loss: 8.6998, Val Loss: 12.8861, Precision@20: 0.2457


100%|██████████| 35/35 [00:02<00:00, 15.58it/s]
100%|██████████| 9/9 [00:05<00:00,  1.51it/s]


Epoch [43/1000], Loss: 8.6722, Val Loss: 12.9078, Precision@20: 0.2457


100%|██████████| 35/35 [00:02<00:00, 15.60it/s]
100%|██████████| 9/9 [00:05<00:00,  1.51it/s]


Epoch [44/1000], Loss: 8.6398, Val Loss: 12.9298, Precision@20: 0.2469


100%|██████████| 35/35 [00:02<00:00, 15.16it/s]
100%|██████████| 9/9 [00:06<00:00,  1.49it/s]


Epoch [45/1000], Loss: 8.6151, Val Loss: 12.9548, Precision@20: 0.2471


100%|██████████| 35/35 [00:02<00:00, 15.58it/s]
100%|██████████| 9/9 [00:05<00:00,  1.53it/s]


Epoch [46/1000], Loss: 8.5890, Val Loss: 12.9762, Precision@20: 0.2467


100%|██████████| 35/35 [00:02<00:00, 15.40it/s]
100%|██████████| 9/9 [00:06<00:00,  1.50it/s]


Epoch [47/1000], Loss: 8.5610, Val Loss: 12.9982, Precision@20: 0.2484


100%|██████████| 35/35 [00:02<00:00, 15.77it/s]
100%|██████████| 9/9 [00:05<00:00,  1.54it/s]


Epoch [48/1000], Loss: 8.5380, Val Loss: 13.0203, Precision@20: 0.2482


100%|██████████| 35/35 [00:02<00:00, 15.69it/s]
100%|██████████| 9/9 [00:05<00:00,  1.51it/s]


Epoch [49/1000], Loss: 8.5107, Val Loss: 13.0429, Precision@20: 0.2484


100%|██████████| 35/35 [00:02<00:00, 15.59it/s]
100%|██████████| 9/9 [00:06<00:00,  1.49it/s]


Epoch [50/1000], Loss: 8.4830, Val Loss: 13.0649, Precision@20: 0.2486


100%|██████████| 35/35 [00:02<00:00, 15.47it/s]
100%|██████████| 9/9 [00:05<00:00,  1.51it/s]


Epoch [51/1000], Loss: 8.4638, Val Loss: 13.0870, Precision@20: 0.2486


100%|██████████| 35/35 [00:02<00:00, 15.68it/s]
100%|██████████| 9/9 [00:05<00:00,  1.52it/s]


Epoch [52/1000], Loss: 8.4423, Val Loss: 13.1070, Precision@20: 0.2504


100%|██████████| 35/35 [00:02<00:00, 14.96it/s]
100%|██████████| 9/9 [00:05<00:00,  1.54it/s]


Epoch [53/1000], Loss: 8.4198, Val Loss: 13.1286, Precision@20: 0.2500


100%|██████████| 35/35 [00:02<00:00, 15.74it/s]
100%|██████████| 9/9 [00:05<00:00,  1.52it/s]


Epoch [54/1000], Loss: 8.3972, Val Loss: 13.1486, Precision@20: 0.2501


100%|██████████| 35/35 [00:02<00:00, 15.57it/s]
100%|██████████| 9/9 [00:05<00:00,  1.52it/s]


Epoch [55/1000], Loss: 8.3695, Val Loss: 13.1696, Precision@20: 0.2505


100%|██████████| 35/35 [00:02<00:00, 15.59it/s]
100%|██████████| 9/9 [00:05<00:00,  1.51it/s]


Epoch [56/1000], Loss: 8.3465, Val Loss: 13.1908, Precision@20: 0.2516


100%|██████████| 35/35 [00:02<00:00, 15.70it/s]
100%|██████████| 9/9 [00:05<00:00,  1.53it/s]


Epoch [57/1000], Loss: 8.3261, Val Loss: 13.2126, Precision@20: 0.2515


100%|██████████| 35/35 [00:02<00:00, 15.33it/s]
100%|██████████| 9/9 [00:06<00:00,  1.49it/s]


Epoch [58/1000], Loss: 8.3043, Val Loss: 13.2334, Precision@20: 0.2503


100%|██████████| 35/35 [00:02<00:00, 15.58it/s]
100%|██████████| 9/9 [00:05<00:00,  1.54it/s]


Epoch [59/1000], Loss: 8.2808, Val Loss: 13.2557, Precision@20: 0.2502


100%|██████████| 35/35 [00:02<00:00, 15.78it/s]
100%|██████████| 9/9 [00:05<00:00,  1.50it/s]


Epoch [60/1000], Loss: 8.2605, Val Loss: 13.2766, Precision@20: 0.2501


100%|██████████| 35/35 [00:02<00:00, 14.94it/s]
100%|██████████| 9/9 [00:05<00:00,  1.52it/s]

Epoch [61/1000], Loss: 8.2388, Val Loss: 13.2963, Precision@20: 0.2505
Early stopping triggered





In [7]:
model.load_state_dict(model_parameters)

avg_loss, avg_precision = evaluate(model, test_loader, criterion)
print(avg_loss, avg_precision)

100%|██████████| 9/9 [00:06<00:00,  1.50it/s]

13.190789116753471 0.2515822059160743



