In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

total_df = pd.read_pickle("../../data/processed/rating_engage.pkl")
rating_df = pd.read_pickle("../../data/processed/rating_session.pkl")
rating_df

Unnamed: 0,user_id,date,item_id,rating
0,0,2022-07-22,"[71482, 53839, 32319, 64573]","[5.0, 5.0, 1.0, 5.0]"
1,0,2022-07-22,"[53839, 32319, 64573, 44799]","[5.0, 1.0, 5.0, 5.0]"
2,0,2022-07-22,"[32319, 64573, 44799, 56313]","[1.0, 5.0, 5.0, 5.0]"
3,0,2022-07-22,"[64573, 44799, 56313, 40697]","[5.0, 5.0, 5.0, 1.0]"
4,0,2022-07-22,"[44799, 56313, 40697, 41653]","[5.0, 5.0, 1.0, 1.0]"
...,...,...,...,...
21798,997,2023-03-20,"[68239, 68144, 67857, 62937]","[5.0, 5.0, 5.0, 5.0]"
21799,999,2014-12-03,"[16923, 19623, 19207, 3573]","[5.0, 4.0, 5.0, 5.0]"
21800,999,2018-12-26,"[40583, 38774, 38410, 61047]","[5.0, 5.0, 5.0, 5.0]"
21801,999,2019-02-10,"[40158, 39077, 38698, 42245]","[4.0, 5.0, 3.0, 5.0]"


In [2]:
total_df

Unnamed: 0,user_id,item_id,rating,item_len
0,0,28833,1.0,123
0,0,29361,5.0,123
0,0,32319,1.0,123
0,0,37990,5.0,123
0,0,40014,5.0,123
...,...,...,...,...
999,999,44930,0,193
999,999,14950,0,193
999,999,50080,0,193
999,999,22395,0,193


In [3]:
len(total_df["item_id"].unique()), total_df["item_id"].max()

(72319, 72318)

In [4]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset

item_sequences = rating_df['item_id'].tolist()  # item_id를 리스트로 변환
user_ids = rating_df['user_id'].tolist()


class SessionDataset(Dataset):
    def __init__(self, sequences):
        self.sequences = sequences

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        return torch.tensor(sequence[:-1]), torch.tensor(sequence[-1])
    
train_data, test_data = train_test_split(item_sequences, test_size=0.2, random_state=42)

# Hyperparameters
embedding_dim = 128
hidden_dim = 128
num_epochs = 1000
learning_rate = 5e-5
batch_size = 512
# 512 1e-4 0.4369
# 512 53-5 0.4404

train_loader = DataLoader(SessionDataset(train_data), batch_size=batch_size, shuffle=True)
test_loader = DataLoader(SessionDataset(test_data), batch_size=batch_size, shuffle=False)

In [5]:
import torch.optim as optim
import sys
sys.path.append("../")
from model import GRURecommender

num_items = 368228

model = GRURecommender(num_items, embedding_dim, hidden_dim).cuda()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [6]:
from sklearn.metrics import precision_score, recall_score
from tqdm import tqdm
from copy import deepcopy
import warnings
warnings.filterwarnings('always')

model_parameters = deepcopy(model.state_dict())

def precision_at_k(preds, target, k=20):
    top_k_preds = preds.topk(k, dim=1).indices.cpu().numpy()
    target = target.cpu().numpy()
    
    y_true = np.isin(top_k_preds, target[:, None]) 
    y_pred = np.ones_like(y_true)  
    
    precision_scores = []
    for true, pred in zip(y_true, y_pred):
        if np.sum(pred) == 0:  
            precision_scores.append(0)
        else:
            precision_scores.append(precision_score(true, pred, zero_division=0))
    return np.mean(precision_scores)

def train_and_evaluate(model, train_loader, val_loader, criterion, optimizer, num_epochs, early_stopping_patience=5):
    global model_parameters
    best_val_precision = -float('inf')
    patience_counter = 0
    
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for inputs, target in tqdm(train_loader):
            optimizer.zero_grad()
            outputs = model(inputs.cuda())
            loss = criterion(outputs, target.cuda())

            loss.backward()
            optimizer.step()
            total_loss += loss.detach().cpu().item()

        val_loss, val_precision = evaluate(model, val_loader, criterion)
        
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}, '
              f'Val Loss: {val_loss:.4f}, Precision@20: {val_precision:.4f}')
        
        if val_precision > best_val_precision:
            best_val_precision = val_precision
            patience_counter = 0
            model_parameters = deepcopy(model.state_dict())
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                print("Early stopping triggered")
                break

def evaluate(model, data_loader, criterion):
    model.eval()
    total_loss = 0
    total_precision = 0
    total_samples = 0
    
    with torch.no_grad():
        for inputs, target in tqdm(data_loader):
            outputs = model(inputs.cuda())
            loss = criterion(outputs, target.cuda())
            total_loss += loss.detach().cpu().item()
            
            total_precision += precision_at_k(outputs, target, k=20) * inputs.size(0)
            total_samples += inputs.size(0)
            
    avg_loss = total_loss / len(data_loader)
    avg_precision = total_precision / total_samples
    return avg_loss, avg_precision

train_and_evaluate(model, train_loader, test_loader, criterion, optimizer, num_epochs)


100%|██████████| 35/35 [00:02<00:00, 14.86it/s]
100%|██████████| 9/9 [00:06<00:00,  1.49it/s]


Epoch [1/1000], Loss: 12.8215, Val Loss: 12.8184, Precision@20: 0.0012


100%|██████████| 35/35 [00:02<00:00, 15.89it/s]
100%|██████████| 9/9 [00:05<00:00,  1.51it/s]


Epoch [2/1000], Loss: 12.7888, Val Loss: 12.8110, Precision@20: 0.0021


100%|██████████| 35/35 [00:02<00:00, 15.71it/s]
100%|██████████| 9/9 [00:05<00:00,  1.54it/s]


Epoch [3/1000], Loss: 12.7563, Val Loss: 12.8035, Precision@20: 0.0042


100%|██████████| 35/35 [00:02<00:00, 15.72it/s]
100%|██████████| 9/9 [00:05<00:00,  1.53it/s]


Epoch [4/1000], Loss: 12.7217, Val Loss: 12.7957, Precision@20: 0.0081


100%|██████████| 35/35 [00:02<00:00, 15.72it/s]
100%|██████████| 9/9 [00:05<00:00,  1.53it/s]


Epoch [5/1000], Loss: 12.6858, Val Loss: 12.7876, Precision@20: 0.0143


100%|██████████| 35/35 [00:02<00:00, 15.80it/s]
100%|██████████| 9/9 [00:05<00:00,  1.53it/s]


Epoch [6/1000], Loss: 12.6499, Val Loss: 12.7789, Precision@20: 0.0231


100%|██████████| 35/35 [00:02<00:00, 15.12it/s]
100%|██████████| 9/9 [00:05<00:00,  1.51it/s]


Epoch [7/1000], Loss: 12.6093, Val Loss: 12.7696, Precision@20: 0.0333


100%|██████████| 35/35 [00:02<00:00, 15.66it/s]
100%|██████████| 9/9 [00:05<00:00,  1.51it/s]


Epoch [8/1000], Loss: 12.5674, Val Loss: 12.7596, Precision@20: 0.0434


100%|██████████| 35/35 [00:02<00:00, 15.73it/s]
100%|██████████| 9/9 [00:05<00:00,  1.50it/s]


Epoch [9/1000], Loss: 12.5197, Val Loss: 12.7487, Precision@20: 0.0526


100%|██████████| 35/35 [00:02<00:00, 15.67it/s]
100%|██████████| 9/9 [00:05<00:00,  1.52it/s]


Epoch [10/1000], Loss: 12.4684, Val Loss: 12.7367, Precision@20: 0.0614


100%|██████████| 35/35 [00:02<00:00, 15.87it/s]
100%|██████████| 9/9 [00:05<00:00,  1.54it/s]


Epoch [11/1000], Loss: 12.4126, Val Loss: 12.7233, Precision@20: 0.0706


100%|██████████| 35/35 [00:02<00:00, 15.77it/s]
100%|██████████| 9/9 [00:05<00:00,  1.52it/s]


Epoch [12/1000], Loss: 12.3493, Val Loss: 12.7081, Precision@20: 0.0794


100%|██████████| 35/35 [00:02<00:00, 15.74it/s]
100%|██████████| 9/9 [00:05<00:00,  1.51it/s]


Epoch [13/1000], Loss: 12.2778, Val Loss: 12.6906, Precision@20: 0.0883


100%|██████████| 35/35 [00:02<00:00, 15.81it/s]
100%|██████████| 9/9 [00:05<00:00,  1.51it/s]


Epoch [14/1000], Loss: 12.1956, Val Loss: 12.6700, Precision@20: 0.0974


100%|██████████| 35/35 [00:02<00:00, 15.71it/s]
100%|██████████| 9/9 [00:05<00:00,  1.52it/s]


Epoch [15/1000], Loss: 12.0986, Val Loss: 12.6451, Precision@20: 0.1072


100%|██████████| 35/35 [00:02<00:00, 15.11it/s]
100%|██████████| 9/9 [00:05<00:00,  1.52it/s]


Epoch [16/1000], Loss: 11.9845, Val Loss: 12.6141, Precision@20: 0.1167


100%|██████████| 35/35 [00:02<00:00, 15.73it/s]
100%|██████████| 9/9 [00:05<00:00,  1.52it/s]


Epoch [17/1000], Loss: 11.8404, Val Loss: 12.5745, Precision@20: 0.1270


100%|██████████| 35/35 [00:02<00:00, 15.77it/s]
100%|██████████| 9/9 [00:05<00:00,  1.51it/s]


Epoch [18/1000], Loss: 11.6575, Val Loss: 12.5222, Precision@20: 0.1376


100%|██████████| 35/35 [00:02<00:00, 15.68it/s]
100%|██████████| 9/9 [00:05<00:00,  1.51it/s]


Epoch [19/1000], Loss: 11.4189, Val Loss: 12.4525, Precision@20: 0.1490


100%|██████████| 35/35 [00:02<00:00, 15.78it/s]
100%|██████████| 9/9 [00:05<00:00,  1.53it/s]


Epoch [20/1000], Loss: 11.1027, Val Loss: 12.3661, Precision@20: 0.1623


100%|██████████| 35/35 [00:02<00:00, 15.52it/s]
100%|██████████| 9/9 [00:06<00:00,  1.48it/s]


Epoch [21/1000], Loss: 10.7239, Val Loss: 12.2789, Precision@20: 0.1756


100%|██████████| 35/35 [00:02<00:00, 15.73it/s]
100%|██████████| 9/9 [00:05<00:00,  1.53it/s]


Epoch [22/1000], Loss: 10.3162, Val Loss: 12.2301, Precision@20: 0.1876


100%|██████████| 35/35 [00:02<00:00, 15.81it/s]
100%|██████████| 9/9 [00:05<00:00,  1.54it/s]


Epoch [23/1000], Loss: 9.9697, Val Loss: 12.2443, Precision@20: 0.1984


100%|██████████| 35/35 [00:02<00:00, 15.74it/s]
100%|██████████| 9/9 [00:05<00:00,  1.55it/s]


Epoch [24/1000], Loss: 9.7169, Val Loss: 12.3017, Precision@20: 0.2101


100%|██████████| 35/35 [00:02<00:00, 15.00it/s]
100%|██████████| 9/9 [00:05<00:00,  1.54it/s]


Epoch [25/1000], Loss: 9.5473, Val Loss: 12.3715, Precision@20: 0.2219


100%|██████████| 35/35 [00:02<00:00, 15.80it/s]
100%|██████████| 9/9 [00:05<00:00,  1.53it/s]


Epoch [26/1000], Loss: 9.4272, Val Loss: 12.4335, Precision@20: 0.2337


100%|██████████| 35/35 [00:02<00:00, 15.74it/s]
100%|██████████| 9/9 [00:05<00:00,  1.52it/s]


Epoch [27/1000], Loss: 9.3465, Val Loss: 12.4838, Precision@20: 0.2421


100%|██████████| 35/35 [00:02<00:00, 15.73it/s]
100%|██████████| 9/9 [00:06<00:00,  1.50it/s]


Epoch [28/1000], Loss: 9.2686, Val Loss: 12.5248, Precision@20: 0.2455


100%|██████████| 35/35 [00:02<00:00, 15.80it/s]
100%|██████████| 9/9 [00:05<00:00,  1.54it/s]


Epoch [29/1000], Loss: 9.2091, Val Loss: 12.5575, Precision@20: 0.2466


100%|██████████| 35/35 [00:02<00:00, 15.80it/s]
100%|██████████| 9/9 [00:05<00:00,  1.52it/s]


Epoch [30/1000], Loss: 9.1591, Val Loss: 12.5885, Precision@20: 0.2448


100%|██████████| 35/35 [00:02<00:00, 15.70it/s]
100%|██████████| 9/9 [00:05<00:00,  1.53it/s]


Epoch [31/1000], Loss: 9.1077, Val Loss: 12.6173, Precision@20: 0.2458


100%|██████████| 35/35 [00:02<00:00, 15.81it/s]
100%|██████████| 9/9 [00:05<00:00,  1.53it/s]


Epoch [32/1000], Loss: 9.0601, Val Loss: 12.6439, Precision@20: 0.2463


100%|██████████| 35/35 [00:02<00:00, 15.66it/s]
100%|██████████| 9/9 [00:05<00:00,  1.52it/s]


Epoch [33/1000], Loss: 9.0183, Val Loss: 12.6713, Precision@20: 0.2454


100%|██████████| 35/35 [00:02<00:00, 15.06it/s]
100%|██████████| 9/9 [00:05<00:00,  1.52it/s]

Epoch [34/1000], Loss: 8.9732, Val Loss: 12.6966, Precision@20: 0.2453
Early stopping triggered





In [7]:
model.load_state_dict(model_parameters)

avg_loss, avg_precision = evaluate(model, test_loader, criterion)
print(avg_loss, avg_precision)

100%|██████████| 9/9 [00:05<00:00,  1.52it/s]

12.557455168830025 0.24662921348314606





In [8]:
torch.save({
    "state_dict": model.cpu().state_dict(),
    "num_items": num_items,
    "embedding_dim":embedding_dim,
    "hidden_dim": hidden_dim
}, "../parameters/session.pth")