# 0. 기본 세팅

In [1]:
# !pip install -U sympy

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder
from collections import defaultdict
import datetime as dt

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
file_path = '/content/drive/MyDrive/transaction_categorized.csv'  # 파일 경로

purchase = pd.read_csv(file_path)
purchase_cycle = pd.read_csv('/content/drive/MyDrive/Purchase_Cycle.csv')

In [4]:
purchase

Unnamed: 0,Order Date,ASIN/ISBN (Product Code),Category,Survey ResponseID
0,2018-12-04,B0143RTB1E,"('Electronics', 'Computers & Accessories', 'Co...",R_01vNIayewjIIKMF
1,2018-12-22,B01MA1MJ6H,"('Electronics', 'Headphones, Earbuds & Accesso...",R_01vNIayewjIIKMF
2,2018-12-25,B06XWF9HML,"('Beauty & Personal Care', 'Shave & Hair Remov...",R_01vNIayewjIIKMF
3,2018-12-25,B00837ZOI0,"('Beauty & Personal Care', 'Shave & Hair Remov...",R_01vNIayewjIIKMF
4,2019-02-18,B01GFB2E9M,"('Electronics', 'Computers & Accessories', 'Co...",R_01vNIayewjIIKMF
...,...,...,...,...
1593070,2021-04-01,B015ZRTHVA,"('Health & Household', 'Oral Care', 'Baby & Ch...",R_zfqnsBzlOAKibzb
1593071,2021-04-14,B00QGCXPRG,"('Health & Household', 'Household Supplies', '...",R_zfqnsBzlOAKibzb
1593072,2021-05-22,B015ZRTHVA,"('Health & Household', 'Oral Care', 'Baby & Ch...",R_zfqnsBzlOAKibzb
1593073,2021-12-01,B015ZRTHVA,"('Health & Household', 'Oral Care', 'Baby & Ch...",R_zfqnsBzlOAKibzb


In [5]:
purchase_cycle

Unnamed: 0,Order Date,ASIN/ISBN (Product Code),Category,Survey ResponseID,Purchase_Cycle
0,2018-12-04,B0143RTB1E,"('Electronics', 'Computers & Accessories', 'Co...",R_01vNIayewjIIKMF,127.32
1,2018-04-22,B0143RTB1E,"('Electronics', 'Computers & Accessories', 'Co...",R_1LYvldXrvEt6Cel,127.32
2,2020-01-10,B0143RTB1E,"('Electronics', 'Computers & Accessories', 'Co...",R_1QoPHtffEy2UzWJ,127.32
3,2018-02-27,B0143RTB1E,"('Electronics', 'Computers & Accessories', 'Co...",R_1q4eVtHN10S5gPB,127.32
4,2018-10-06,B0143RTB1E,"('Electronics', 'Computers & Accessories', 'Co...",R_1rpKvwP7VAtzJqH,127.32
...,...,...,...,...,...
1593070,2019-03-05,1475025289,"('Books', 'Religion & Spirituality', 'Worship ...",R_zfqnsBzlOAKibzb,11.38
1593071,2019-04-10,B001G0LC1E,"('Movies & TV', 'Featured Categories', 'DVD', ...",R_zfqnsBzlOAKibzb,11.38
1593072,2019-06-04,0674362810,"('Books', 'History', 'World', 'Religious', 'Ge...",R_zfqnsBzlOAKibzb,11.38
1593073,2019-06-04,B0000Y7L7G,"('Home & Kitchen', 'Kitchen & Dining', 'Kitche...",R_zfqnsBzlOAKibzb,11.38


# 1. 전처리

In [6]:
deduped_cycle = purchase_cycle.drop_duplicates(
    subset=['Order Date', 'ASIN/ISBN (Product Code)', 'Category', 'Survey ResponseID']
)

In [9]:
merge_keys = ['Order Date', 'ASIN/ISBN (Product Code)', 'Category', 'Survey ResponseID']
merged = purchase.merge(
    deduped_cycle[['Order Date', 'ASIN/ISBN (Product Code)', 'Category', 'Survey ResponseID', 'Purchase_Cycle']],
    on=merge_keys,
    how='left'
)
merged

Unnamed: 0,Order Date,ASIN/ISBN (Product Code),Category,Survey ResponseID,Purchase_Cycle
0,2018-12-04,B0143RTB1E,"('Electronics', 'Computers & Accessories', 'Co...",R_01vNIayewjIIKMF,127.32
1,2018-12-22,B01MA1MJ6H,"('Electronics', 'Headphones, Earbuds & Accesso...",R_01vNIayewjIIKMF,127.32
2,2018-12-25,B06XWF9HML,"('Beauty & Personal Care', 'Shave & Hair Remov...",R_01vNIayewjIIKMF,43.66
3,2018-12-25,B00837ZOI0,"('Beauty & Personal Care', 'Shave & Hair Remov...",R_01vNIayewjIIKMF,233.71
4,2019-02-18,B01GFB2E9M,"('Electronics', 'Computers & Accessories', 'Co...",R_01vNIayewjIIKMF,127.32
...,...,...,...,...,...
1593070,2021-04-01,B015ZRTHVA,"('Health & Household', 'Oral Care', 'Baby & Ch...",R_zfqnsBzlOAKibzb,127.32
1593071,2021-04-14,B00QGCXPRG,"('Health & Household', 'Household Supplies', '...",R_zfqnsBzlOAKibzb,127.32
1593072,2021-05-22,B015ZRTHVA,"('Health & Household', 'Oral Care', 'Baby & Ch...",R_zfqnsBzlOAKibzb,127.32
1593073,2021-12-01,B015ZRTHVA,"('Health & Household', 'Oral Care', 'Baby & Ch...",R_zfqnsBzlOAKibzb,127.32


In [10]:
# 고유값 추출
asin_list = merged['ASIN/ISBN (Product Code)'].unique()
cat_list = merged['Category'].astype(str).unique()
user_list = merged['Survey ResponseID'].unique()

# 딕셔너리로 인코딩
asin2id = {asin: i+1 for i, asin in enumerate(asin_list)}  # 0은 padding용
cat2id = {cat: i+1 for i, cat in enumerate(cat_list)}
user2id = {user: i for i, user in enumerate(user_list)}

merged['ASIN_ID'] = merged['ASIN/ISBN (Product Code)'].map(asin2id)
merged['Category_ID'] = merged['Category'].astype(str).map(cat2id)
merged['User_ID'] = merged['Survey ResponseID'].map(user2id)

In [11]:
merged

Unnamed: 0,Order Date,ASIN/ISBN (Product Code),Category,Survey ResponseID,Purchase_Cycle,ASIN_ID,Category_ID,User_ID
0,2018-12-04,B0143RTB1E,"('Electronics', 'Computers & Accessories', 'Co...",R_01vNIayewjIIKMF,127.32,1,1,0
1,2018-12-22,B01MA1MJ6H,"('Electronics', 'Headphones, Earbuds & Accesso...",R_01vNIayewjIIKMF,127.32,2,2,0
2,2018-12-25,B06XWF9HML,"('Beauty & Personal Care', 'Shave & Hair Remov...",R_01vNIayewjIIKMF,43.66,3,3,0
3,2018-12-25,B00837ZOI0,"('Beauty & Personal Care', 'Shave & Hair Remov...",R_01vNIayewjIIKMF,233.71,4,4,0
4,2019-02-18,B01GFB2E9M,"('Electronics', 'Computers & Accessories', 'Co...",R_01vNIayewjIIKMF,127.32,5,5,0
...,...,...,...,...,...,...,...,...
1593070,2021-04-01,B015ZRTHVA,"('Health & Household', 'Oral Care', 'Baby & Ch...",R_zfqnsBzlOAKibzb,127.32,213637,2580,4690
1593071,2021-04-14,B00QGCXPRG,"('Health & Household', 'Household Supplies', '...",R_zfqnsBzlOAKibzb,127.32,791451,258,4690
1593072,2021-05-22,B015ZRTHVA,"('Health & Household', 'Oral Care', 'Baby & Ch...",R_zfqnsBzlOAKibzb,127.32,213637,2580,4690
1593073,2021-12-01,B015ZRTHVA,"('Health & Household', 'Oral Care', 'Baby & Ch...",R_zfqnsBzlOAKibzb,127.32,213637,2580,4690


In [12]:
merged_clean = merged[['Order Date', 'ASIN_ID', 'Category_ID', 'User_ID', 'Purchase_Cycle']].copy()
merged_clean

Unnamed: 0,Order Date,ASIN_ID,Category_ID,User_ID,Purchase_Cycle
0,2018-12-04,1,1,0,127.32
1,2018-12-22,2,2,0,127.32
2,2018-12-25,3,3,0,43.66
3,2018-12-25,4,4,0,233.71
4,2019-02-18,5,5,0,127.32
...,...,...,...,...,...
1593070,2021-04-01,213637,2580,4690,127.32
1593071,2021-04-14,791451,258,4690,127.32
1593072,2021-05-22,213637,2580,4690,127.32
1593073,2021-12-01,213637,2580,4690,127.32


# 2. 시퀀스 생성

In [13]:
def generate_multimodal_sequences(df, min_len=6, max_len=50, step_size=3):
    """
    Generate random-length sequences with overlap for multimodal SASRec.
    Returns: list of (user_id, asin_seq, cat_seq, timestamp_seq)
    """
    import random

    grouped = df.groupby('User_ID')
    sequences = []

    for user_id, group in grouped:
        group_sorted = group.sort_values('Order Date')
        asin_seq = group_sorted['ASIN_ID'].tolist()
        cat_seq = group_sorted['Category_ID'].tolist()
        time_seq = pd.to_datetime(group_sorted['Order Date']).tolist()

        pos = 0
        L = len(asin_seq)

        while pos + min_len < L:
            max_possible_len = min(max_len, L - pos)
            if max_possible_len < min_len:
                break

            rand_len = random.randint(min_len, max_possible_len)

            asin_slice = asin_seq[pos:pos + rand_len]
            cat_slice = cat_seq[pos:pos + rand_len]
            time_slice = time_seq[pos:pos + rand_len]

            sequences.append((user_id, asin_slice, cat_slice, time_slice))
            pos += step_size

    return sequences

In [14]:
seqs = generate_multimodal_sequences(
    merged_clean,
    min_len=6,      # 최소한 train(4)+val+test = 6 필요
    max_len=50,
    step_size=3,
)

print(f"총 생성된 시퀀스 수: {len(seqs)}")

총 생성된 시퀀스 수: 523182


In [15]:
print(seqs[0])  # (user_id, [cat_id, ...])

(0, [1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 5, 6, 7, 2], [Timestamp('2018-12-04 00:00:00'), Timestamp('2018-12-22 00:00:00'), Timestamp('2018-12-25 00:00:00'), Timestamp('2018-12-25 00:00:00'), Timestamp('2019-02-18 00:00:00'), Timestamp('2019-02-18 00:00:00'), Timestamp('2019-04-23 00:00:00'), Timestamp('2019-05-02 00:00:00'), Timestamp('2019-05-02 00:00:00')])


In [16]:
def split_multimodal_sequence_triplet(sequences, min_train_len=4):
    """
    For (user_id, asin_seq, cat_seq, time_seq),
    split into train/val/test triplets.
    """
    train_data, val_data, test_data = [], [], []

    for user_id, asin_seq, cat_seq, time_seq in sequences:
        if len(asin_seq) < min_train_len + 3:
            continue

        # 시퀀스 부분
        train_asin = asin_seq[:-3]
        val_asin   = asin_seq[:-2]
        test_asin  = asin_seq[:-1]

        train_cat = cat_seq[:-3]
        val_cat   = cat_seq[:-2]
        test_cat  = cat_seq[:-1]

        train_time = time_seq[:-3]
        val_time   = time_seq[:-2]
        test_time  = time_seq[:-1]

        # 타겟
        train_target_asin = asin_seq[-3]
        val_target_asin   = asin_seq[-2]
        test_target_asin  = asin_seq[-1]

        train_target_cat = cat_seq[-3]
        val_target_cat   = cat_seq[-2]
        test_target_cat  = cat_seq[-1]

        train_target_time = time_seq[-3]
        val_target_time   = time_seq[-2]
        test_target_time  = time_seq[-1]

        train_data.append((user_id, train_asin, train_cat, train_time,
                           train_target_asin, train_target_cat, train_target_time))
        val_data.append((user_id, val_asin, val_cat, val_time,
                         val_target_asin, val_target_cat, val_target_time))
        test_data.append((user_id, test_asin, test_cat, test_time,
                          test_target_asin, test_target_cat, test_target_time))

    return train_data, val_data, test_data

In [17]:
train_data, val_data, test_data = split_multimodal_sequence_triplet(seqs, min_train_len=4)

print(f"Train seqs: {len(train_data)}")
print(f"Val seqs:   {len(val_data)}")
print(f"Test seqs:  {len(test_data)}")

Train seqs: 507670
Val seqs:   507670
Test seqs:  507670


In [18]:
print(seqs[0])

(0, [1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 5, 6, 7, 2], [Timestamp('2018-12-04 00:00:00'), Timestamp('2018-12-22 00:00:00'), Timestamp('2018-12-25 00:00:00'), Timestamp('2018-12-25 00:00:00'), Timestamp('2019-02-18 00:00:00'), Timestamp('2019-02-18 00:00:00'), Timestamp('2019-04-23 00:00:00'), Timestamp('2019-05-02 00:00:00'), Timestamp('2019-05-02 00:00:00')])


In [19]:
print(train_data[0])

(0, [1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 5], [Timestamp('2018-12-04 00:00:00'), Timestamp('2018-12-22 00:00:00'), Timestamp('2018-12-25 00:00:00'), Timestamp('2018-12-25 00:00:00'), Timestamp('2019-02-18 00:00:00'), Timestamp('2019-02-18 00:00:00')], 7, 6, Timestamp('2019-04-23 00:00:00'))


In [20]:
print(val_data[0])

(0, [1, 2, 3, 4, 5, 6, 7], [1, 2, 3, 4, 5, 5, 6], [Timestamp('2018-12-04 00:00:00'), Timestamp('2018-12-22 00:00:00'), Timestamp('2018-12-25 00:00:00'), Timestamp('2018-12-25 00:00:00'), Timestamp('2019-02-18 00:00:00'), Timestamp('2019-02-18 00:00:00'), Timestamp('2019-04-23 00:00:00')], 8, 7, Timestamp('2019-05-02 00:00:00'))


In [21]:
print(test_data[0])

(0, [1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 3, 4, 5, 5, 6, 7], [Timestamp('2018-12-04 00:00:00'), Timestamp('2018-12-22 00:00:00'), Timestamp('2018-12-25 00:00:00'), Timestamp('2018-12-25 00:00:00'), Timestamp('2019-02-18 00:00:00'), Timestamp('2019-02-18 00:00:00'), Timestamp('2019-04-23 00:00:00'), Timestamp('2019-05-02 00:00:00')], 9, 2, Timestamp('2019-05-02 00:00:00'))


# 3. Dataset 클래스로 변환

In [22]:
class MultiModalSASRecDataset(Dataset):
    def __init__(self, data, num_asins, max_len=50, pad_id=0):
        self.data = data
        self.num_asins = num_asins
        self.max_len = max_len
        self.pad_id = pad_id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        user_id, asin_seq, cat_seq, time_seq, target_asin, target_cat, target_time = self.data[idx]

        # 최근 max_len 개만 자르기
        asin_seq = asin_seq[-self.max_len:]
        cat_seq = cat_seq[-self.max_len:]
        time_seq = time_seq[-self.max_len:]

        # 패딩
        pad_len = self.max_len - len(asin_seq)
        padded_asin = [self.pad_id] * pad_len + asin_seq
        padded_cat = [self.pad_id if c is None else c for c in ([self.pad_id] * pad_len + cat_seq)]
        padded_time = [None] * pad_len + time_seq

        # ✅ negative asin 샘플링
        neg_asin = self.sample_negative(set(asin_seq + [target_asin]))

        return {
            "user_id": user_id,
            "asin_seq": torch.tensor(padded_asin, dtype=torch.long),
            "cat_seq": torch.tensor(padded_cat, dtype=torch.long),
            "time_seq": padded_time,
            "target_asin": torch.tensor(target_asin, dtype=torch.long),
            "neg_asin": torch.tensor(neg_asin, dtype=torch.long),
            "target_cat": torch.tensor(self.pad_id if target_cat is None else target_cat, dtype=torch.long),
            "target_time": target_time
        }

    def sample_negative(self, exclude_asins: set):
        import random
        while True:
            neg = random.randint(1, self.num_asins - 1)  # 0은 padding용이라 제외
            if neg not in exclude_asins:
                return neg

In [23]:
def multimodal_collate_fn(batch):
    """
    Custom collate function to handle datetime + tensor fields and negative sampling.
    """
    batch_user = [b["user_id"] for b in batch]
    batch_asin = torch.stack([b["asin_seq"] for b in batch])
    batch_cat = torch.stack([b["cat_seq"] for b in batch])
    batch_time = [b["time_seq"] for b in batch]  # list of list of datetime

    batch_target_asin = torch.stack([b["target_asin"] for b in batch])
    batch_neg_asin = torch.stack([b["neg_asin"] for b in batch])  # ✅ 추가
    batch_target_cat = torch.stack([b["target_cat"] for b in batch])
    batch_target_time = [b["target_time"] for b in batch]

    return {
        "user_id": batch_user,
        "asin_seq": batch_asin,
        "cat_seq": batch_cat,
        "time_seq": batch_time,
        "target_asin": batch_target_asin,
        "neg_asin": batch_neg_asin,  # ✅ 추가
        "target_cat": batch_target_cat,
        "target_time": batch_target_time
    }

In [24]:
asin2cat = merged_clean.drop_duplicates('ASIN_ID').set_index('ASIN_ID')['Category_ID'].to_dict()

asin2cycle = merged_clean.drop_duplicates('ASIN_ID').set_index('ASIN_ID')['Purchase_Cycle'].to_dict()

In [25]:
MAX_LEN = 50
PAD_ID = 0  # now safe because cat_id starts from 1
NUM_ASINS = max(asin2cat.keys()) + 1  # padding 포함

train_dataset = MultiModalSASRecDataset(train_data, num_asins=NUM_ASINS, max_len=MAX_LEN, pad_id=PAD_ID)
val_dataset   = MultiModalSASRecDataset(val_data,   num_asins=NUM_ASINS, max_len=MAX_LEN, pad_id=PAD_ID)
test_dataset  = MultiModalSASRecDataset(test_data,  num_asins=NUM_ASINS, max_len=MAX_LEN, pad_id=PAD_ID)

In [26]:
sample = train_dataset[0]

print("user_id:", sample['user_id'])
print("ASIN sequence:", sample['asin_seq'])       # Tensor of length MAX_LEN
print("Category sequence:", sample['cat_seq'])    # Tensor of length MAX_LEN
print("Timestamp sequence:", sample['time_seq'])  # list of datetime
print("Target ASIN:", sample['target_asin'])      # scalar tensor
print("Target Category:", sample['target_cat'])   # scalar tensor
print("Target Timestamp:", sample['target_time']) # datetime object
print("Negative ASIN:", sample['neg_asin'])       # scalar tensor

user_id: 0
ASIN sequence: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4,
        5, 6])
Category sequence: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4,
        5, 5])
Timestamp sequence: [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, Timestamp('2018-12-04 00:00:00'), Timestamp('2018-12-22 00:00:00'), Timestamp('2018-12-25 00:00:00'), Timestamp('2018-12-25 00:00:00'), Timestamp('2019-02-18 00:00:00'), Timestamp('2019-02-18 00:00:00')]
Target ASIN: tensor(7)
Target Category: tensor(6)
Target Timestamp: 2019-04-23 00:00:00
Negative ASIN: tensor(48361)

In [27]:
# DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=multimodal_collate_fn)
val_loader   = DataLoader(val_dataset,   batch_size=64, shuffle=False, collate_fn=multimodal_collate_fn)
test_loader  = DataLoader(test_dataset,  batch_size=64, shuffle=False, collate_fn=multimodal_collate_fn)

# 4. 모델 정의 및 학습 루프

In [28]:
class MultiModalSASRec(nn.Module):
    def __init__(self, num_asins, num_categories, hidden_dim=64, max_len=50, num_layers=2, num_heads=2, dropout=0.2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.max_len = max_len
        self.num_asins = num_asins
        self.num_categories = num_categories

        self.asin_embedding = nn.Embedding(num_asins, hidden_dim, padding_idx=0)
        self.cat_embedding = nn.Embedding(num_categories, hidden_dim, padding_idx=0)
        self.pos_embedding = nn.Embedding(max_len, hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, dropout=dropout, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.layer_norm = nn.LayerNorm(hidden_dim)

    def forward(self, asin_seq, cat_seq):
        B, T = asin_seq.shape
        positions = torch.arange(T, device=asin_seq.device).unsqueeze(0).expand(B, T)

        asin_emb = self.asin_embedding(asin_seq)
        cat_emb = self.cat_embedding(cat_seq)
        pos_emb = self.pos_embedding(positions)

        x = (asin_emb + cat_emb + pos_emb) * (self.hidden_dim ** 0.5)
        x = self.layer_norm(x)

        attn_mask = self._generate_square_subsequent_mask(T, device=x.device)
        padding_mask = asin_seq == 0  # (B, T)

        x = self.transformer(x, mask=attn_mask, src_key_padding_mask=padding_mask)

        return x  # (B, T, hidden_dim)

    def _generate_square_subsequent_mask(self, sz, device):
        return torch.triu(torch.ones(sz, sz, device=device), diagonal=1).bool()

In [29]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 정확한 ASIN 개수
NUM_ASINS = max(asin2cat.keys()) + 1
NUM_CATEGORIES = merged_clean['Category_ID'].max() + 1

# ASIN → Category 매핑 텐서
asin_id_to_cat_id = torch.tensor([
    asin2cat.get(i, 0) for i in range(NUM_ASINS)
], device=device)

# 🔥 최적화: ASIN → Purchase_Cycle 딕셔너리 → 텐서
asin_cycle_map = merged_clean.drop_duplicates('ASIN_ID').set_index('ASIN_ID')['Purchase_Cycle'].to_dict()
asin_id_to_cycle = torch.tensor([
    asin_cycle_map.get(i, 0.0) for i in range(NUM_ASINS)
], dtype=torch.float, device=device)

# 모델 생성
model = MultiModalSASRec(
    num_asins=NUM_ASINS,
    num_categories=NUM_CATEGORIES,
    hidden_dim=128,
    max_len=50,
    num_layers=3,
    num_heads=4,
    dropout=0.2
).to(device)

In [30]:
print(f"asin_id_to_cat_id: {asin_id_to_cat_id.shape}")
print(f"cat_id_to_median: {asin_id_to_cycle.shape}")
print(f"model.asin_embedding: {model.asin_embedding.weight.shape}")

asin_id_to_cat_id: torch.Size([791452])
cat_id_to_median: torch.Size([791452])
model.asin_embedding: torch.Size([791452, 128])


In [None]:
import os

# === BPR Loss 함수 정의 ===
def bpr_loss(pos_scores, neg_scores):
    return -torch.mean(torch.log(torch.sigmoid(pos_scores - neg_scores) + 1e-8))

def sample_negative(batch_size, num_items, true_items):
    neg_items = torch.randint(1, num_items, size=(batch_size,), device=true_items.device)
    mask = neg_items == true_items
    while mask.any():
        neg_items[mask] = torch.randint(1, num_items, size=(mask.sum(),), device=true_items.device)
        mask = neg_items == true_items
    return neg_items

# === Dual Evaluation 함수 ===
def evaluate_dual_metrics(model, dataloader, device, asin_id_to_cat_id, asin_id_to_cycle, k=20):
    model.eval()
    asin_hits, asin_mrr, cat_hits, cat_mrr, total = 0, 0, 0, 0, 0
    penalty_strength = 5.0
    first_batch = True  # 첫 배치에서만 logits 분포 출력

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            asin_seq = batch["asin_seq"].to(device)
            cat_seq = batch["cat_seq"].to(device)
            target_asin = batch["target_asin"].to(device)
            target_cat = batch["target_cat"].to(device)
            time_seqs = batch["time_seq"]
            target_times = batch["target_time"]

            output = model(asin_seq, cat_seq)
            logits = output[:, -1, :] @ model.asin_embedding.weight.T

            # Soft Time Penalty 적용
            gap_days = torch.tensor([
                max((target_times[i] - time_seqs[i][-1]).days, 0)
                for i in range(len(target_times))
            ], device=device).unsqueeze(1)  # (B, 1)

            purchase_cycle = asin_id_to_cycle.unsqueeze(0).expand(logits.size(0), -1)  # (B, num_asins)
            safe_cycle = torch.clamp(purchase_cycle, min=1.0)
            penalty_ratio = torch.clamp(1 - (gap_days / safe_cycle), 0.0, 1.0)

            logits -= penalty_ratio * penalty_strength  # ← ✅ 패널티 적용

            top_k_asin = logits.topk(k, dim=-1).indices  # (B, K)
            top_k_cat = asin_id_to_cat_id[top_k_asin]    # (B, K)

            # ASIN 평가
            asin_hits += (top_k_asin == target_asin.unsqueeze(1)).any(dim=1).sum().item()
            for i in range(target_asin.size(0)):
                pred_asins = top_k_asin[i].tolist()
                if (true_asin := target_asin[i].item()) in pred_asins:
                    asin_mrr += 1.0 / (pred_asins.index(true_asin) + 1)

            # Category 평가
            cat_hits += (top_k_cat == target_cat.unsqueeze(1)).any(dim=1).sum().item()
            for i in range(target_cat.size(0)):
                pred_cats = top_k_cat[i].tolist()
                if (true_cat := target_cat[i].item()) in pred_cats:
                    cat_mrr += 1.0 / (pred_cats.index(true_cat) + 1)

            total += target_asin.size(0)

            if first_batch:
                print(f"[logits] min: {logits.min().item():.4f}, max: {logits.max().item():.4f}, "
                      f"mean: {logits.mean().item():.4f}, std: {logits.std().item():.4f}")
                first_batch = False

    print(f"🌟 ASIN Hit@{k}: {asin_hits / total:.4f} | MRR@{k}: {asin_mrr / total:.4f}")
    print(f"🧠 CAT  Hit@{k}: {cat_hits / total:.4f} | MRR@{k}: {cat_mrr / total:.4f}")
    return asin_hits / total, asin_mrr / total, cat_hits / total, cat_mrr / total

# === 학습 루프 ===
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
save_path = "/content/drive/MyDrive/0627_timeaware_final3.pt"
num_epochs = 3
best_hit = 0.0
patience = 3
min_delta = 0.001
wait = 0

if os.path.exists(save_path):
    model.load_state_dict(torch.load(save_path, map_location=device))
    print(f"📦 Loaded existing model from {save_path}")

for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        asin_seq = batch["asin_seq"].to(device)
        cat_seq = batch["cat_seq"].to(device)
        target_asin = batch["target_asin"].to(device)

        output = model(asin_seq, cat_seq)
        user_vec = output[:, -1, :]

        neg_asin = sample_negative(target_asin.size(0), model.num_asins, target_asin)
        pos_emb = model.asin_embedding(target_asin)
        neg_emb = model.asin_embedding(neg_asin)

        pos_scores = (user_vec * pos_emb).sum(dim=-1)
        neg_scores = (user_vec * neg_emb).sum(dim=-1)
        loss = bpr_loss(pos_scores, neg_scores)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    asin_hit, asin_mrr, val_hit, val_mrr = evaluate_dual_metrics(
        model, val_loader, device, asin_id_to_cat_id, asin_id_to_cycle, k=20
    )

    if val_hit > best_hit + min_delta:
        best_hit = val_hit
        wait = 0
        torch.save(model.state_dict(), save_path)
        print(f"🎉 Best model saved! CAT Hit@20: {val_hit:.4f} → {save_path}")
    else:
        wait += 1
        print(f"⏳ No improvement. Patience: {wait}/{patience}")
        if wait >= patience:
            print("🚑 Early stopping triggered.")
            break

    print(f"📊 Epoch {epoch+1} | Loss: {avg_loss:.4f} | ASIN Hit/MRR@20: {asin_hit:.4f}/{asin_mrr:.4f} | CAT Hit/MRR@20: {val_hit:.4f}/{val_mrr:.4f}")


Epoch 1: 100%|██████████| 7933/7933 [09:40<00:00, 13.67it/s]
Evaluating:   0%|          | 6/7933 [00:00<05:03, 26.15it/s]

[logits] min: nan, max: nan, mean: nan, std: nan


Evaluating: 100%|██████████| 7933/7933 [04:55<00:00, 26.82it/s]


🌟 ASIN Hit@20: 0.0002 | MRR@20: 0.0000
🧠 CAT  Hit@20: 0.0197 | MRR@20: 0.0022
🎉 Best model saved! CAT Hit@20: 0.0197 → /content/drive/MyDrive/0627_timeaware_final3.pt
📊 Epoch 1 | Loss: 0.3805 | ASIN Hit/MRR@20: 0.0002/0.0000 | CAT Hit/MRR@20: 0.0197/0.0022


Epoch 2:  29%|██▉       | 2316/7933 [02:49<06:51, 13.66it/s]

In [47]:
def evaluate_model_on_test(model, test_loader, device, asin_id_to_cat_id, asin_id_to_cycle, k=20, penalty_strength=5.0):
    model.eval()
    asin_hits, asin_mrr, cat_hits, cat_mrr, total = 0, 0, 0, 0, 0

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="🧪 Final Evaluation"):
            asin_seq = batch["asin_seq"].to(device)
            cat_seq = batch["cat_seq"].to(device)
            target_asin = batch["target_asin"].to(device)
            target_cat = batch["target_cat"].to(device)
            time_seqs = batch["time_seq"]
            target_times = batch["target_time"]

            output = model(asin_seq, cat_seq)
            logits = output[:, -1, :] @ model.asin_embedding.weight.T  # (B, num_asins)

            # === Soft Masking ===
            gap_days = torch.tensor([
                max((target_times[i] - time_seqs[i][-1]).days, 0)
                for i in range(len(target_times))
            ], device=device).unsqueeze(1)  # (B, 1)

            purchase_cycle = asin_id_to_cycle.unsqueeze(0).expand(logits.size(0), -1)  # (B, num_asins)
            penalty_ratio = torch.clamp(1 - (gap_days / (purchase_cycle + 1e-6)), 0.0, 1.0)
            logits -= penalty_ratio * penalty_strength

            # === Top-k Ranking ===
            top_k_asin = logits.topk(k, dim=-1).indices
            top_k_cat = asin_id_to_cat_id[top_k_asin]

            # === ASIN 평가 ===
            asin_hits += (top_k_asin == target_asin.unsqueeze(1)).any(dim=1).sum().item()
            for i in range(target_asin.size(0)):
                true = target_asin[i].item()
                pred = top_k_asin[i].tolist()
                if true in pred:
                    asin_mrr += 1.0 / (pred.index(true) + 1)

            # === Category 평가 ===
            cat_hits += (top_k_cat == target_cat.unsqueeze(1)).any(dim=1).sum().item()
            for i in range(target_cat.size(0)):
                true = target_cat[i].item()
                pred = top_k_cat[i].tolist()
                if true in pred:
                    cat_mrr += 1.0 / (pred.index(true) + 1)

            total += target_asin.size(0)

    # === 결과 출력 ===
    print(f"\n✅ Final Evaluation:")
    print(f"🎯 ASIN  Hit@{k}: {asin_hits / total:.4f} | MRR@{k}: {asin_mrr / total:.4f}")
    print(f"🧠 CAT   Hit@{k}: {cat_hits / total:.4f} | MRR@{k}: {cat_mrr / total:.4f}")
    return asin_hits / total, asin_mrr / total, cat_hits / total, cat_mrr / total

In [None]:
model.load_state_dict(torch.load(save_path, map_location=device))
evaluate_model_on_test(
    model, test_loader, device,
    asin_id_to_cat_id, asin_id_to_cycle, k=20
)