In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder

# Function to load news data and encode categories
def load_news_data(filepath):
    news = pd.read_csv(filepath, sep='\t', names=['News_ID', 'Category', 'SubCategory', 'Title', 'Abstract', 'Entities', 'Relations'])
    
    # Convert categories and subcategories to integers
    category_encoder = LabelEncoder()
    subcategory_encoder = LabelEncoder()

    news['Category'] = category_encoder.fit_transform(news['Category'])
    news['SubCategory'] = subcategory_encoder.fit_transform(news['SubCategory'])

    return news, len(category_encoder.classes_), len(subcategory_encoder.classes_)

# Load entity and relation embeddings
entity_embeddings_df = pd.read_csv('./MINDsmall_train/entity_embedding.vec', delimiter='\t', header=None, index_col=0)
entity_embeddings = entity_embeddings_df.values

relation_embeddings_df = pd.read_csv('./MINDsmall_train/relation_embedding.vec', delimiter='\t', header=None, index_col=0)
relation_embeddings = relation_embeddings_df.values

# Function to pad sequences
def pad_sequences(sequences, max_length=None):
    if max_length is None:
        max_length = max(len(seq) for seq in sequences)
    padded_sequences = [seq[:max_length] + [0] * max(0, max_length - len(seq)) for seq in sequences]
    return torch.tensor(padded_sequences, dtype=torch.long)

# Function to get entity/relation embeddings for a sequence
def get_embedding_for_sequence(sequence, embedding_dict, embed_dim, device):
    embeddings = []
    max_length = 0  # 최대 임베딩 길이 초기화
    for items in sequence:
        item_embeddings = [embedding_dict[item] for item in items if item in embedding_dict]
        if item_embeddings:
            embeddings.append(torch.sum(torch.stack(item_embeddings), dim=0))
            max_length = max(max_length, len(item_embeddings))  # 가장 긴 임베딩 길이 갱신
        else:
            embeddings.append(torch.zeros(embed_dim, device=device))
    
    # 모든 임베딩의 길이를 가장 긴 임베딩 길이로 맞춤
    embeddings = [F.pad(embeddings[i], (0, max_length - len(item_embeddings)), 'constant', 0).unsqueeze(0) for i, item_embeddings in enumerate(embeddings)]
    return torch.cat(embeddings, dim=0)



# Read behaviors.tsv and news.tsv files
behaviors = pd.read_csv('./MINDsmall_train/behaviors.tsv', delimiter='\t', header=None)
news, num_categories, num_subcategories = load_news_data('./MINDsmall_train/news.tsv')
news_ids = news['News_ID'].unique()

# Assign column names
behaviors.columns = ['Impression ID', 'User ID', 'Time', 'History', 'Impressions']

# Assign integer indices to user and news IDs
user_ids = behaviors['User ID'].unique()
user2idx = {o: i for i, o in enumerate(user_ids)}
idx2user = {i: o for i, o in enumerate(user_ids)}

news2idx = {o: i for i, o in enumerate(news_ids)}
idx2news = {i: o for i, o in enumerate(news_ids)}

# Replace user and news IDs with integer indices
behaviors['User ID'] = behaviors['User ID'].apply(lambda x: user2idx[x])
behaviors['History'] = behaviors['History'].fillna('')
behaviors['History'] = behaviors['History'].apply(lambda x: [news2idx[i] for i in x.split() if i in news2idx])

# Replace impressions with integer indices
news2idx['unknown'] = len(news2idx)

def process_impressions(impressions):
    return [news2idx.get(i.split('-')[0], news2idx['unknown']) for i in impressions.split()]

behaviors['Impressions'] = behaviors['Impressions'].apply(process_impressions)
# Split train and test behaviors
train_behaviors, test_behaviors = train_test_split(behaviors, test_size=0.2, random_state=42)

# Define dataset class
class NewsDataset(Dataset):
    def __init__(self, user_tensor, sequence_tensor, impression_tensor, category_tensor, subcategory_tensor, entity_tensor, relation_tensor):
        self.user_tensor = user_tensor
        self.sequence_tensor = sequence_tensor
        self.impression_tensor = impression_tensor
        self.category_tensor = category_tensor
        self.subcategory_tensor = subcategory_tensor
        self.entity_tensor = entity_tensor
        self.relation_tensor = relation_tensor

    def __getitem__(self, index):
        return (
            self.user_tensor[index],
            self.sequence_tensor[index],
            self.impression_tensor[index % self.impression_tensor.shape[0]],
            self.category_tensor[index % self.category_tensor.shape[0]],
            self.subcategory_tensor[index % self.subcategory_tensor.shape[0]],
            self.entity_tensor[index % self.entity_tensor.shape[0]],
            self.relation_tensor[index % self.relation_tensor.shape[0]]
        )

    def __len__(self):
        return self.sequence_tensor.shape[0]

# Define the model
class CausalRec(nn.Module):
    def __init__(self, num_users, num_items, num_categories, num_subcategories, embed_dim, num_gru_layers, device):
        super(CausalRec, self).__init__()
        
        self.num_items = num_items

        self.user_embedding = nn.Embedding(num_users, embed_dim)
        self.item_embedding = nn.Embedding(num_items, embed_dim)
        self.category_embedding = nn.Embedding(num_categories, embed_dim)
        self.subcategory_embedding = nn.Embedding(num_subcategories, embed_dim)

        self.gru = nn.GRU(embed_dim, embed_dim, num_layers=num_gru_layers, batch_first=True)
        self.fc = nn.Linear(embed_dim, 1)

        self.lambda_ = torch.zeros(1, requires_grad=False).to(device)
        self.mu_ = torch.zeros(1, requires_grad=False).to(device)

        self.temperature = 1.0
        self.eps = 1e-10

    def forward(self, user_indices, sequence, item_indices, category_indices, subcategory_indices, entities, relations):
        user_embed = self.user_embedding(user_indices)
        item_embed = self.item_embedding(item_indices % self.num_items)

        category_embed = self.category_embedding(category_indices).unsqueeze(1).expand(-1, item_embed.size(1), -1)
        subcategory_embed = self.subcategory_embedding(subcategory_indices).unsqueeze(1).expand(-1, item_embed.size(1), -1)
        entity_embed = get_embedding_for_sequence(entities, entity_embeddings, embed_dim, device).unsqueeze(1).expand(-1, item_embed.size(1), -1)
        relation_embed = get_embedding_for_sequence(relations, relation_embeddings, embed_dim, device).unsqueeze(1).expand(-1, item_embed.size(1), -1)

        item_embed = item_embed.unsqueeze(2).expand(-1, -1, category_embed.size(1), -1)
        item_embed = item_embed[:, :1, :4, :]  # 크기를 (4, 1, 4, 100)로 변경


        out, _ = self.gru(sequence)
        out = self.fc(out)

        user_embed_expanded = user_embed.unsqueeze(1).expand(-1, item_embed.size(1), -1)
        out_expanded = out.unsqueeze(1).expand(-1, item_embed.size(1), -1)

        prob = torch.sigmoid((user_embed_expanded * item_embed).sum(dim=-1, keepdim=True) + out_expanded)
        
        # constraint라는 제약조건을 설정, prob의 평균(클릭확률 vs Non클릭확률)이 0.5에 가깝게 유지
        constraint = (prob.mean() - 0.5) ** 2
        self.lambda_ = self.lambda_ + self.mu_ * constraint
        
        # Gumbel softmax - Reparametrization
        u = torch.rand(prob.shape).to(device)
        gumbel_noise = -torch.log(-torch.log(u + self.eps) + self.eps)
        logit = (prob + gumbel_noise) / self.temperature
        prob = torch.sigmoid(logit)

        return prob



    
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define hyperparameters
num_users = len(user_ids)
num_items = len(news_ids)
embed_dim = 100
num_gru_layers = 2
max_length = 100

# Instantiate the model
model = CausalRec(num_users, num_items, num_categories, num_subcategories, embed_dim, num_gru_layers, device)
model.to(device)

# Define loss function and optimizer
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Create datasets and data loaders
train_dataset = NewsDataset(
    torch.tensor(train_behaviors['User ID'].values, dtype=torch.long),
    pad_sequences(train_behaviors['History'].values, max_length).float(),
    pad_sequences(train_behaviors['Impressions'].values).float(),
    torch.tensor(news['Category'].values, dtype=torch.long),
    torch.tensor(news['SubCategory'].values, dtype=torch.long),
    entity_embeddings,
    relation_embeddings
)
test_dataset = NewsDataset(
    torch.tensor(test_behaviors['User ID'].values, dtype=torch.long),
    pad_sequences(test_behaviors['History'].values, max_length).float(),
    pad_sequences(test_behaviors['Impressions'].values).float(),
    torch.tensor(news['Category'].values, dtype=torch.long),
    torch.tensor(news['SubCategory'].values, dtype=torch.long),
    entity_embeddings,
    relation_embeddings
)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

# Training loop
num_epochs = 10

eta = 1.1
delta = 0.1

for epoch in range(num_epochs):
    total_loss = 0
    num_batches = 0

    model.train()

    for user_indices, sequence, item_indices, category_indices, subcategory_indices, entities, relations in train_loader:
        user_indices = user_indices.to(device)
        sequence = sequence.to(device)
        item_indices = item_indices.to(device, dtype=torch.long)
        category_indices = category_indices.to(device)
        subcategory_indices = subcategory_indices.to(device)

        optimizer.zero_grad()

        outputs = model(
            user_indices, sequence, item_indices, category_indices, subcategory_indices, entities, relations
        )

        labels = torch.ones_like(outputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # 제약조건이 만족되지 않을 경우(constraint가 delta * prev_constraint보다 큰 경우), Lagrange multiplier인 mu_를 eta배 증가시킴으로써 제약조건을 강화 
        constraint = (outputs.mean() - 0.5) ** 2
        model.lambda_ = model.lambda_ + model.mu_ * constraint

        if constraint > delta * prev_constraint:
            model.mu_ = eta * model.mu_

        prev_constraint = constraint

        total_loss += loss.item()
        num_batches += 1

    average_loss = total_loss / num_batches
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {average_loss:.4f}")

Epoch 1/10, Loss: 0.2841
Epoch 2/10, Loss: 0.2839
Epoch 3/10, Loss: 0.2837
Epoch 4/10, Loss: 0.2838
Epoch 5/10, Loss: 0.2840
Epoch 6/10, Loss: 0.2838
Epoch 7/10, Loss: 0.2841
Epoch 8/10, Loss: 0.2840
Epoch 9/10, Loss: 0.2837
Epoch 10/10, Loss: 0.2840


In [6]:
def hit_at_k(predictions, targets, k):
    '''
    Calculate hit@k score
    predictions: sorted list of predictions
    targets: list of targets
    '''
    top_k_preds = predictions[:k]
    if len(set(top_k_preds).intersection(set(targets))) > 0:
        return 1
    return 0


def ndcg_at_k(predictions, targets, k):
    '''
    Calculate ndcg@k score
    predictions: sorted list of predictions
    targets: list of targets
    '''
    dcg = 0
    for i, pred in enumerate(predictions[:k]):
        if pred in targets:
            dcg += 1 / np.log2(i + 2)  # i + 2 because i starts at 0 and the formula starts at 1

    idcg = 0  # ideal dcg; best possible dcg given the targets
    for i in range(len(targets)):
        idcg += 1 / np.log2(i + 2)
        
    return dcg / idcg

# Evaluation loop
model.eval()

hit1_total = 0
hit5_total = 0
ndcg5_total = 0
num_batches = 0

for user_indices, sequence, item_indices, category_indices, subcategory_indices, entities, relations in test_loader:
    user_indices = user_indices.to(device)
    sequence = sequence.to(device)
    item_indices = item_indices.to(device, dtype=torch.long)
    category_indices = category_indices.to(device)
    subcategory_indices = subcategory_indices.to(device)

    with torch.no_grad():
        outputs = model(
            user_indices, sequence, item_indices, category_indices, subcategory_indices, entities, relations
        )

    predictions = outputs.flatten().argsort().tolist()
    targets = item_indices.flatten().tolist()

    hit1_total += hit_at_k(predictions, targets, 1)
    hit5_total += hit_at_k(predictions, targets, 5)
    ndcg5_total += ndcg_at_k(predictions, targets, 5)

    num_batches += 1

hit1 = hit1_total / num_batches
hit5 = hit5_total / num_batches
ndcg5 = ndcg5_total / num_batches

print("Hit@1: ", hit1)
print("Hit@5: ", hit5)
print("NDCG@5: ", ndcg5)

Hit@1:  0.030449738820231876
Hit@5:  0.14995540833227164
NDCG@5:  0.0006621525547102801
