In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

total_df = pd.read_pickle("../../data/processed/rating_engage.pkl")
rating_df = pd.read_pickle("../../data/processed/rating_session.pkl")
rating_df

Unnamed: 0,user_id,date,item_id,rating
0,0,2014-05-15,"[15400, 8141, 26820, 39007]","[5.0, 5.0, 5.0, 5.0]"
1,0,2014-05-15,"[8141, 26820, 39007, 4646]","[5.0, 5.0, 5.0, 3.0]"
2,0,2014-12-06,"[18665, 21455, 23236, 21297]","[5.0, 5.0, 5.0, 5.0]"
3,0,2015-02-11,"[25341, 50734, 59076, 12715]","[4.0, 5.0, 2.0, 2.0]"
4,0,2015-02-11,"[50734, 59076, 12715, 13616]","[5.0, 2.0, 2.0, 5.0]"
...,...,...,...,...
37260,1996,2020-01-28,"[31772, 76693, 49373, 47628]","[5.0, 2.0, 5.0, 5.0]"
37261,1998,2014-12-03,"[27463, 5156, 28049, 24243]","[5.0, 5.0, 4.0, 5.0]"
37262,1998,2018-12-26,"[57897, 83702, 54822, 55331]","[5.0, 5.0, 5.0, 5.0]"
37263,1998,2019-02-10,"[55864, 57289, 55227, 55751]","[4.0, 4.0, 3.0, 5.0]"


In [2]:
total_df

Unnamed: 0,user_id,item_id,rating,item_len
0,0,556,2.0,93
0,0,843,2.0,93
0,0,1039,5.0,93
0,0,3865,5.0,93
0,0,4646,3.0,93
...,...,...,...,...
1999,1999,72419,0,106
1999,1999,82072,0,106
1999,1999,7212,0,106
1999,1999,19116,0,106


In [3]:
len(total_df["item_id"].unique()), total_df["item_id"].max()

(97718, 97717)

In [4]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset

item_sequences = rating_df['item_id'].tolist()  # item_id를 리스트로 변환
user_ids = rating_df['user_id'].tolist()


class SessionDataset(Dataset):
    def __init__(self, sequences):
        self.sequences = sequences

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        return torch.tensor(sequence[:-1]), torch.tensor(sequence[-1])
    
train_data, test_data = train_test_split(item_sequences, test_size=0.2, random_state=42)

# Hyperparameters
embedding_dim = 128
hidden_dim = 128
num_epochs = 1000
learning_rate = 5e-5
batch_size = 512
# 512 1e-4 0.4369
# 512 53-5 0.4404

train_loader = DataLoader(SessionDataset(train_data), batch_size=batch_size, shuffle=True)
test_loader = DataLoader(SessionDataset(test_data), batch_size=batch_size, shuffle=False)

In [5]:
import torch.optim as optim
import sys
sys.path.append("../")
from model import GRURecommender

num_items = total_df["item_id"].max() + 1

model = GRURecommender(num_items, embedding_dim, hidden_dim).cuda()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [6]:
from sklearn.metrics import precision_score, recall_score
from tqdm import tqdm
from copy import deepcopy
import warnings
warnings.filterwarnings('always')

model_parameters = deepcopy(model.state_dict())

def precision_at_k(preds, target, k=20):
    top_k_preds = preds.topk(k, dim=1).indices.cpu().numpy()
    target = target.cpu().numpy()
    
    y_true = np.isin(top_k_preds, target[:, None]) 
    y_pred = np.ones_like(y_true)  
    
    precision_scores = []
    for true, pred in zip(y_true, y_pred):
        if np.sum(pred) == 0:  
            precision_scores.append(0)
        else:
            precision_scores.append(precision_score(true, pred, zero_division=0))
    return np.mean(precision_scores)

def train_and_evaluate(model, train_loader, val_loader, criterion, optimizer, num_epochs, early_stopping_patience=5):
    global model_parameters
    best_val_precision = -float('inf')
    patience_counter = 0
    
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for inputs, target in tqdm(train_loader):
            optimizer.zero_grad()
            outputs = model(inputs.cuda())
            loss = criterion(outputs, target.cuda())

            loss.backward()
            optimizer.step()
            total_loss += loss.detach().cpu().item()

        val_loss, val_precision = evaluate(model, val_loader, criterion)
        
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}, '
              f'Val Loss: {val_loss:.4f}, Precision@20: {val_precision:.4f}')
        
        if val_precision > best_val_precision:
            best_val_precision = val_precision
            patience_counter = 0
            model_parameters = deepcopy(model.state_dict())
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                print("Early stopping triggered")
                break

def evaluate(model, data_loader, criterion):
    model.eval()
    total_loss = 0
    total_precision = 0
    total_samples = 0
    
    with torch.no_grad():
        for inputs, target in tqdm(data_loader):
            outputs = model(inputs.cuda())
            loss = criterion(outputs, target.cuda())
            total_loss += loss.detach().cpu().item()
            
            total_precision += precision_at_k(outputs, target, k=20) * inputs.size(0)
            total_samples += inputs.size(0)
            
    avg_loss = total_loss / len(data_loader)
    avg_precision = total_precision / total_samples
    return avg_loss, avg_precision

train_and_evaluate(model, train_loader, test_loader, criterion, optimizer, num_epochs)


100%|██████████| 59/59 [00:01<00:00, 40.83it/s]
100%|██████████| 15/15 [00:09<00:00,  1.51it/s]


Epoch [1/1000], Loss: 11.4936, Val Loss: 11.4865, Precision@20: 0.0072


100%|██████████| 59/59 [00:01<00:00, 45.03it/s]
100%|██████████| 15/15 [00:09<00:00,  1.54it/s]


Epoch [2/1000], Loss: 11.4459, Val Loss: 11.4755, Precision@20: 0.0150


100%|██████████| 59/59 [00:01<00:00, 44.68it/s]
100%|██████████| 15/15 [00:09<00:00,  1.56it/s]


Epoch [3/1000], Loss: 11.3967, Val Loss: 11.4640, Precision@20: 0.0287


100%|██████████| 59/59 [00:01<00:00, 44.60it/s]
100%|██████████| 15/15 [00:09<00:00,  1.54it/s]


Epoch [4/1000], Loss: 11.3436, Val Loss: 11.4515, Precision@20: 0.0461


100%|██████████| 59/59 [00:01<00:00, 41.76it/s]
100%|██████████| 15/15 [00:09<00:00,  1.54it/s]


Epoch [5/1000], Loss: 11.2842, Val Loss: 11.4377, Precision@20: 0.0632


100%|██████████| 59/59 [00:01<00:00, 44.32it/s]
100%|██████████| 15/15 [00:09<00:00,  1.56it/s]


Epoch [6/1000], Loss: 11.2178, Val Loss: 11.4221, Precision@20: 0.0809


100%|██████████| 59/59 [00:01<00:00, 44.88it/s]
100%|██████████| 15/15 [00:09<00:00,  1.56it/s]


Epoch [7/1000], Loss: 11.1402, Val Loss: 11.4042, Precision@20: 0.0971


100%|██████████| 59/59 [00:01<00:00, 45.05it/s]
100%|██████████| 15/15 [00:09<00:00,  1.54it/s]


Epoch [8/1000], Loss: 11.0493, Val Loss: 11.3831, Precision@20: 0.1119


100%|██████████| 59/59 [00:01<00:00, 44.81it/s]
100%|██████████| 15/15 [00:09<00:00,  1.55it/s]


Epoch [9/1000], Loss: 10.9380, Val Loss: 11.3572, Precision@20: 0.1276


100%|██████████| 59/59 [00:01<00:00, 44.50it/s]
100%|██████████| 15/15 [00:09<00:00,  1.56it/s]


Epoch [10/1000], Loss: 10.7962, Val Loss: 11.3248, Precision@20: 0.1430


100%|██████████| 59/59 [00:01<00:00, 40.91it/s]
100%|██████████| 15/15 [00:09<00:00,  1.54it/s]


Epoch [11/1000], Loss: 10.6067, Val Loss: 11.2843, Precision@20: 0.1619


100%|██████████| 59/59 [00:01<00:00, 44.27it/s]
100%|██████████| 15/15 [00:09<00:00,  1.55it/s]


Epoch [12/1000], Loss: 10.3484, Val Loss: 11.2448, Precision@20: 0.1839


100%|██████████| 59/59 [00:01<00:00, 44.00it/s]
100%|██████████| 15/15 [00:09<00:00,  1.55it/s]


Epoch [13/1000], Loss: 10.0335, Val Loss: 11.2478, Precision@20: 0.2082


100%|██████████| 59/59 [00:01<00:00, 44.16it/s]
100%|██████████| 15/15 [00:09<00:00,  1.57it/s]


Epoch [14/1000], Loss: 9.7522, Val Loss: 11.3254, Precision@20: 0.2341


100%|██████████| 59/59 [00:01<00:00, 44.34it/s]
100%|██████████| 15/15 [00:09<00:00,  1.54it/s]


Epoch [15/1000], Loss: 9.5744, Val Loss: 11.4197, Precision@20: 0.2544


100%|██████████| 59/59 [00:01<00:00, 44.47it/s]
100%|██████████| 15/15 [00:09<00:00,  1.54it/s]


Epoch [16/1000], Loss: 9.4634, Val Loss: 11.4886, Precision@20: 0.2656


100%|██████████| 59/59 [00:01<00:00, 44.07it/s]
100%|██████████| 15/15 [00:09<00:00,  1.55it/s]


Epoch [17/1000], Loss: 9.3814, Val Loss: 11.5402, Precision@20: 0.2738


100%|██████████| 59/59 [00:01<00:00, 43.89it/s]
100%|██████████| 15/15 [00:09<00:00,  1.54it/s]


Epoch [18/1000], Loss: 9.3126, Val Loss: 11.5850, Precision@20: 0.2788


100%|██████████| 59/59 [00:01<00:00, 43.99it/s]
100%|██████████| 15/15 [00:09<00:00,  1.54it/s]


Epoch [19/1000], Loss: 9.2500, Val Loss: 11.6290, Precision@20: 0.2822


100%|██████████| 59/59 [00:01<00:00, 44.52it/s]
100%|██████████| 15/15 [00:09<00:00,  1.54it/s]


Epoch [20/1000], Loss: 9.1943, Val Loss: 11.6690, Precision@20: 0.2860


100%|██████████| 59/59 [00:01<00:00, 44.01it/s]
100%|██████████| 15/15 [00:09<00:00,  1.53it/s]


Epoch [21/1000], Loss: 9.1429, Val Loss: 11.7096, Precision@20: 0.2873


100%|██████████| 59/59 [00:01<00:00, 41.11it/s]
100%|██████████| 15/15 [00:09<00:00,  1.53it/s]


Epoch [22/1000], Loss: 9.0958, Val Loss: 11.7491, Precision@20: 0.2879


100%|██████████| 59/59 [00:01<00:00, 44.41it/s]
100%|██████████| 15/15 [00:09<00:00,  1.53it/s]


Epoch [23/1000], Loss: 9.0513, Val Loss: 11.7857, Precision@20: 0.2881


100%|██████████| 59/59 [00:01<00:00, 43.48it/s]
100%|██████████| 15/15 [00:09<00:00,  1.53it/s]


Epoch [24/1000], Loss: 9.0066, Val Loss: 11.8247, Precision@20: 0.2865


100%|██████████| 59/59 [00:01<00:00, 43.92it/s]
100%|██████████| 15/15 [00:09<00:00,  1.54it/s]


Epoch [25/1000], Loss: 8.9679, Val Loss: 11.8611, Precision@20: 0.2864


100%|██████████| 59/59 [00:01<00:00, 43.92it/s]
100%|██████████| 15/15 [00:09<00:00,  1.53it/s]


Epoch [26/1000], Loss: 8.9269, Val Loss: 11.8959, Precision@20: 0.2854


100%|██████████| 59/59 [00:01<00:00, 43.79it/s]
100%|██████████| 15/15 [00:09<00:00,  1.53it/s]


Epoch [27/1000], Loss: 8.8897, Val Loss: 11.9322, Precision@20: 0.2853


100%|██████████| 59/59 [00:01<00:00, 40.82it/s]
100%|██████████| 15/15 [00:09<00:00,  1.52it/s]

Epoch [28/1000], Loss: 8.8516, Val Loss: 11.9683, Precision@20: 0.2844
Early stopping triggered





In [7]:
model.load_state_dict(model_parameters)

avg_loss, avg_precision = evaluate(model, test_loader, criterion)
print(avg_loss, avg_precision)

100%|██████████| 15/15 [00:09<00:00,  1.53it/s]

11.78567492167155 0.28807191734871856





In [8]:
torch.save({
    "state_dict": model.cpu().state_dict(),
    "num_items": num_items,
    "embedding_dim":embedding_dim,
    "hidden_dim": hidden_dim
}, "../parameters/session.pth")