# Text Codification Pipeline

This notebook implements the **Inductive Text Pipeline** for the recommender system.

**Architecture:**
1. **Item Representation:**
   - **Review Aggregation:** An Attention mechanism summarizes up to 10 reviews into a single vector.
   - **Gated Fusion:** A learnable gate fuses the `Overview Embedding` with the `Aggregated Review Embedding`. If reviews are missing, the model learns to rely on the overview.

2. **Interaction Representation:**
   - Each **rating** (1-5) is mapped to a learnable embedding, which is added to the Item Text Embedding.
   - $h_{interaction} = h_{item\_text} + e_{rating}$

3. **User Representation:**
   - A sequence-based Attention mechanism (User Query) attends to the user's history of interactions to generate a dynamic user profile.

In [21]:
%%capture
!pip install torch_geometric

In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import os
import json
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
from torch_geometric.data import HeteroData
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax
from recomender_metrics import evaluate_recommendations, print_evaluation_results

EMBED_DIM = 1024  # Size of Alibaba-NLP/gte-large-en-v1.5 embeddings
RATING_DIM = 1024
HEADS = 4
DROPOUT = 0.1
BATCH_SIZE = 512
LR = 1e-3
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"Using device: {DEVICE}")

Using device: cuda


## 1. Data Loading (Pre-computed Embeddings)

We load the `.pt` file containing the SentenceTransformer embeddings for Overviews and Reviews.

We align these tensors with our internal `movie_map` so that `movie_idx` 0 corresponds to the 0th row in these tensors.
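
The alignment step can be pictured with plain Python lists. This is a minimal sketch, not the notebook's loader: the IDs, vectors, and `movie_map` contents below are made-up toy values, and lists stand in for tensors.

```python
# Hypothetical toy data: three raw embedding rows keyed by original CSV IDs.
raw_ids = [50, 10, 99]                      # order in which rows were saved
raw_rows = [[0.5, 0.5], [0.1, 0.1], [0.9, 0.9]]

# movie_map assigns each known movie ID a dense row index.
movie_map = {10: 0, 50: 1, 77: 2}

# Start from zeros so movies without an embedding keep a zero row.
aligned = [[0.0, 0.0] for _ in range(len(movie_map))]
hits = 0
for i, mid in enumerate(raw_ids):
    if mid in movie_map:                    # ID 99 is unknown and is skipped
        aligned[movie_map[mid]] = raw_rows[i]
        hits += 1

print(aligned)  # [[0.1, 0.1], [0.5, 0.5], [0.0, 0.0]]
print(hits)     # 2
```

Movie 77 has an index but no saved embedding, so its row stays at zero; this is the situation the `has_reviews_mask` flag is meant to expose for the review branch.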

In [23]:
def load_text_embeddings(pt_file_path, movie_map):
    """
    Loads pre-computed text embeddings and aligns them with the movie_map.
    """
    print(f"Loading text embeddings from {pt_file_path}...")

    # Load the dictionary saved by the preprocessing step
    data = torch.load(pt_file_path, map_location='cpu', weights_only=False)

    raw_ids = data['movie_ids']       # The original CSV IDs
    ov_embs = data['overview_embs']   # (N_samples, Dim)
    rev_embs = data['review_embs']    # (N_samples, 10, Dim)
    masks = data['review_mask']       # (N_samples, 10) - 1 if review exists, 0 if padding

    num_movies = len(movie_map)
    dim = ov_embs.shape[1]
    max_rev = rev_embs.shape[1]

    # Initialize aligned tensors
    # ao: Aligned Overviews, ar: Aligned Reviews, am: Aligned Masks
    ao = torch.zeros((num_movies, dim), dtype=torch.float32)
    ar = torch.zeros((num_movies, max_rev, dim), dtype=torch.float32)
    am = torch.zeros((num_movies, max_rev), dtype=torch.float32)

    # We also need a mask to know if a movie has ANY reviews at all
    # This helps the Gate know when to ignore the review branch entirely
    has_reviews_mask = torch.zeros((num_movies, 1), dtype=torch.float32)

    hits = 0
    for i, mid in enumerate(raw_ids):
        # raw_ids might be integers, strings, or 0-d tensors depending on previous steps.
        # Unwrap tensors to plain Python values so the key type can match movie_map keys.
        if torch.is_tensor(mid):
            mid = mid.item()
        if mid in movie_map:
            idx = movie_map[mid]
            ao[idx] = ov_embs[i]
            ar[idx] = rev_embs[i]
            am[idx] = masks[i]

            # If the sum of the mask > 0, we have at least one review
            if masks[i].sum() > 0:
                has_reviews_mask[idx] = 1.0
            hits += 1

    print(f"Aligned {hits} movies out of {len(raw_ids)} raw embeddings.")

    return ao, ar, am, has_reviews_mask

## 2. Review Aggregator & Gated Fusion

This module handles the "Static Item" representation.

1. **`ReviewAttention`**: Compresses $N$ reviews into 1 vector.
2. **`ItemTextEncoder`**: Contains the gating logic.
   - Formula: $h_{final} = \lambda \cdot h_{overview} + (1 - \lambda) \cdot h_{reviews}$
   - We compute both, but if `has_reviews_mask` is 0, we force the gate to strictly use Overview.
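
The override for missing reviews is a small piece of arithmetic that can be checked by hand. The sketch below uses plain Python floats and lists instead of tensors; the function name `fuse` and the numbers are illustrative, and `alpha` plays the role of $\lambda$ in the formula above.

```python
def fuse(h_overview, h_reviews, gate, has_reviews):
    """Blend two vectors; force the gate to 1.0 when has_reviews is 0."""
    # has_reviews = 1 -> alpha = gate; has_reviews = 0 -> alpha = 1.0
    alpha = gate * has_reviews + (1.0 - has_reviews)
    fused = [alpha * o + (1.0 - alpha) * r for o, r in zip(h_overview, h_reviews)]
    return alpha, fused

h_ov = [1.0, 2.0]
h_rev = [3.0, 6.0]

alpha_a, fused_a = fuse(h_ov, h_rev, gate=0.25, has_reviews=1.0)
alpha_b, fused_b = fuse(h_ov, h_rev, gate=0.25, has_reviews=0.0)

print(alpha_a, fused_a)  # 0.25 [2.5, 5.0]
print(alpha_b, fused_b)  # 1.0 [1.0, 2.0]
```

With reviews present the learned gate value is used unchanged; without reviews the output equals the overview vector regardless of what the gate network produced.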

In [24]:
class ReviewAttention(nn.Module):
    """
    Aggregates multiple reviews into a single embedding using Attention.
    """
    def __init__(self, dim):
        super().__init__()
        # A learnable 'Query' vector that asks: "What is the consensus of these reviews?"
        self.query = nn.Sequential(
            nn.Linear(dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1, bias=False)
        )

    def forward(self, review_embs, review_mask):
        """
        review_embs: (Batch, Num_Reviews, Dim)
        review_mask: (Batch, Num_Reviews) - 1 for valid, 0 for pad
        """
        # (B, 10, 1)
        attn_scores = self.query(review_embs)

        # Masking: fill padded positions with a large negative score so they
        # receive ~0 weight after softmax. review_mask is unsqueezed to (B, 10, 1)
        # to match attn_scores.
        attn_scores = attn_scores.masked_fill(review_mask.unsqueeze(-1) == 0, -1e9)

        # (B, 10, 1)
        attn_weights = F.softmax(attn_scores, dim=1)

        # Weighted sum: (B, 10, 1) * (B, 10, Dim) -> (B, 10, Dim) -> Sum -> (B, Dim)
        aggregated = torch.sum(attn_weights * review_embs, dim=1)
        return aggregated

class ItemTextEncoder(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.review_agg = ReviewAttention(dim)

        # Gating mechanism
        # Projects concatenated [Overview, Reviews] to a scalar weight [0, 1]
        self.gate_net = nn.Sequential(
            nn.Linear(dim * 2, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

        self.norm = nn.LayerNorm(dim)

    def forward(self, ov_embs, rev_embs, rev_mask, has_rev_mask):
        """
        ov_embs: (Batch, Dim)
        rev_embs: (Batch, 10, Dim)
        rev_mask: (Batch, 10)
        has_rev_mask: (Batch, 1) - 1 if movie has reviews, 0 if not
        """
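        # Aggregate Reviews
        # If a movie has NO reviews, every slot is masked and the softmax falls back
        # to uniform weights over the padded rows (zeros, if padding is zero-filled),
        # so h_rev carries no signal; the gate override below handles that case.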
        h_rev = self.review_agg(rev_embs, rev_mask)

        # Compute Gate
        # Concatenate overview and reviews to decide how much to trust each
        combined = torch.cat([ov_embs, h_rev], dim=1)
        alpha = self.gate_net(combined) # (Batch, 1)

        # Apply Logic for Missing Reviews
        # If has_rev_mask is 0 (no reviews), we force alpha to 1.0 (Trust Overview 100%)
        # This overrides the learned gate for cold-start review items
        alpha = alpha * has_rev_mask + (1.0 - has_rev_mask)

        # Fusion
        h_final = alpha * ov_embs + (1 - alpha) * h_rev

        return self.norm(h_final)
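
The masking inside `ReviewAttention` can be traced on a tiny example. This is a sketch in plain Python rather than PyTorch; `masked_attention_pool` is a hypothetical helper, and the scores are hand-picked instead of coming from a learned scoring network.

```python
import math

def masked_attention_pool(scores, mask, vectors, fill=-1e9):
    """Softmax over scores with padded slots suppressed, then a weighted sum."""
    masked = [s if m == 1 else fill for s, m in zip(scores, mask)]
    peak = max(masked)                      # subtract the max for stability
    exps = [math.exp(s - peak) for s in masked]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(vectors[0])
    pooled = [sum(w * v[d] for w, v in zip(weights, vectors)) for d in range(dim)]
    return weights, pooled

# Two real reviews with equal scores, one padded slot holding a zero vector.
scores = [0.0, 0.0, 5.0]
mask = [1, 1, 0]
vectors = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]

weights, pooled = masked_attention_pool(scores, mask, vectors)
print(weights)  # [0.5, 0.5, 0.0]
print(pooled)   # [0.5, 0.5]
```

The padded slot has the highest raw score but ends up with zero weight. If every slot is masked, the weights become uniform over the padding, which is why the encoder overrides the gate for movies without reviews.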

## 3. User History Encoder

This module handles the "Dynamic User" representation.

1. **Input**: A sequence of Item Embeddings (from step 2) + Rating IDs (1-5).
2. **Interaction Embedding**: $h_{item} + Embedding(Rating)$.
3. **Aggregation**: A static "User Query" attention mechanism. This allows the model to look at the history and extract a fixed-size user vector.
   - *Note:* Index 0 is reserved for padding, so the rating embedding table has 6 rows (0 = pad, 1-5 = ratings).
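
The three steps above can be sketched without tensors. The example below is a simplification under stated assumptions: the rating table and the user query are hand-written 2-d vectors rather than learned parameters, and the key/value projections are treated as identity.

```python
import math

# Hypothetical 2-d rating table; row 0 is the padding row and stays at zero.
rating_table = {0: [0.0, 0.0], 1: [-1.0, 0.0], 5: [1.0, 0.0]}
user_query = [1.0, 0.0]                     # stands in for the learned query

def encode_user(item_vecs, ratings, mask):
    dim = len(user_query)
    # Interaction vector: h_item + Embedding(rating)
    inter = [[x + r for x, r in zip(v, rating_table[rt])]
             for v, rt in zip(item_vecs, ratings)]
    # Scaled dot-product score of each interaction against the single query
    scores = [sum(q * x for q, x in zip(user_query, v)) / math.sqrt(dim)
              for v in inter]
    # Suppress padded positions before the softmax
    scores = [s if m == 1 else -1e9 for s, m in zip(scores, mask)]
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    user_vec = [sum(w * v[d] for w, v in zip(weights, inter)) for d in range(dim)]
    return weights, user_vec

items = [[0.0, 1.0], [0.0, 1.0], [0.0, 0.0]]    # last slot is padding
weights, user_vec = encode_user(items, ratings=[5, 1, 0], mask=[1, 1, 0])
print(weights)
print(user_vec)
```

Both real items share the same text vector, so the difference in attention weight comes only from the rating embedding: with this query, the 5-star interaction receives more weight than the 1-star one, and the padded slot receives none.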

In [25]:
class UserHistoryAttention(nn.Module):
    def __init__(self, dim, num_rating_levels=6): # 0=Pad, 1-5=Ratings
        super().__init__()

        # Rating Embedding (Additive)
        self.rating_emb = nn.Embedding(num_rating_levels, dim, padding_idx=0)

        # The "User Persona" query
        # "Given my history, who am I?"
        self.user_query = nn.Parameter(torch.randn(1, dim))

        # Attention Keys/Values projection
        self.key_layer = nn.Linear(dim, dim)
        self.value_layer = nn.Linear(dim, dim)

        self.dropout = nn.Dropout(DROPOUT)
        self.norm = nn.LayerNorm(dim)

    def forward(self, item_vectors, ratings, padding_mask):
        """
        item_vectors: (Batch, Seq_Len, Dim) - The text embeddings of movies watched
        ratings: (Batch, Seq_Len) - Integers 1-5 (0 for pad)
        padding_mask: (Batch, Seq_Len) - 1 for valid, 0 for pad
        """
        # Inject Ratings
        # Additive composition as discussed
        r_emb = self.rating_emb(ratings)
        interaction_vecs = item_vectors + r_emb

        # Attention Mechanism
        # Query: (1, Dim) -> Broadcast to (Batch, 1, Dim)
        B = item_vectors.size(0)
        Q = self.user_query.expand(B, 1, -1)

        K = self.key_layer(interaction_vecs) # (B, Seq, Dim)
        V = self.value_layer(interaction_vecs)

        # Scores: Q * K^T
        # (B, 1, D) @ (B, D, Seq) -> (B, 1, Seq)
        scores = torch.bmm(Q, K.transpose(1, 2)) / (interaction_vecs.size(-1) ** 0.5)

        # Mask out padding in history
        scores = scores.masked_fill(padding_mask.unsqueeze(1) == 0, -1e9)

        weights = F.softmax(scores, dim=-1) # (B, 1, Seq)
        weights = self.dropout(weights)

        # Weighted Sum
        # (B, 1, Seq) @ (B, Seq, D) -> (B, 1, D)
        user_vector = torch.bmm(weights, V).squeeze(1)

        return self.norm(user_vector)


## 4. The Full Pipeline Wrapper

This class orchestrates the whole flow. It registers the static embeddings as buffers, so they move with the module to whichever device it is placed on (the GPU here), and processes a batch of user histories.
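
The wrapper relies on a flatten, look up, reshape pattern to encode every movie in every history in one pass. A minimal sketch with nested lists in place of tensors (the table contents and indices are invented):

```python
# Hypothetical per-item vectors, indexed by movie_idx.
item_table = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]

history = [[2, 1], [0, 2]]                  # (B=2, S=2) movie indices
B, S = len(history), len(history[0])

# Flatten so that all B*S lookups happen in one step.
flat_ids = [mid for row in history for mid in row]            # (B*S,)
flat_embs = [item_table[mid] for mid in flat_ids]             # (B*S, Dim)

# Restore the (B, S, Dim) layout expected by the user encoder.
seq_embs = [flat_embs[b * S:(b + 1) * S] for b in range(B)]

print(flat_ids)     # [2, 1, 0, 2]
print(seq_embs[1])  # [[0.0, 0.0], [2.0, 2.0]]
```

Row-major flattening is what makes the final reshape safe: position `b * S + s` in the flat list always corresponds to user `b`, step `s`.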

In [26]:
class TextCodificationPipeline(nn.Module):
    def __init__(self, ov_embs, rev_embs, rev_masks, has_rev_masks, dim=1024):
        super().__init__()

        # Register the static data as buffers (not learnable parameters, but part of state)
        # This keeps them on the correct device
        self.register_buffer('static_ov', ov_embs)
        self.register_buffer('static_rev', rev_embs)
        self.register_buffer('static_rev_mask', rev_masks)
        self.register_buffer('static_has_rev', has_rev_masks)

        self.item_encoder = ItemTextEncoder(dim)
        self.user_encoder = UserHistoryAttention(dim)

    def forward(self, history_movie_ids, history_ratings, history_mask):
        """
        Generates User Embeddings from their history.

        history_movie_ids: (Batch, Seq_Len) - Indices in the movie_map
        history_ratings: (Batch, Seq_Len) - Rating values (1-5)
        history_mask: (Batch, Seq_Len) - Mask for user history length
        """

        # Lookup Item Features for the batch
        # We flatten the batch to process all movies at once
        B, S = history_movie_ids.shape
        flat_ids = history_movie_ids.view(-1) # (B*S)

        batch_ov = self.static_ov[flat_ids]         # (B*S, Dim)
        batch_rev = self.static_rev[flat_ids]       # (B*S, 10, Dim)
        batch_rm = self.static_rev_mask[flat_ids]   # (B*S, 10)
        batch_has_r = self.static_has_rev[flat_ids] # (B*S, 1)

        # Compute Item Text Embeddings
        # Result: (B*S, Dim)
        flat_item_embs = self.item_encoder(batch_ov, batch_rev, batch_rm, batch_has_r)

        # Reshape back to sequence
        # (B, S, Dim)
        seq_item_embs = flat_item_embs.view(B, S, -1)

        # Compute User Embedding
        user_emb = self.user_encoder(seq_item_embs, history_ratings, history_mask)

        return user_emb

In [27]:
import json
import torch
import pandas as pd
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T

def load_data_and_build_graph_with_context(context_dir='context_data'):
    print("Loading Graph Data with Context...")

    # 1. Load Standard Data
    movies_df = pd.read_csv('movies_graph_ready.csv')
    train_df = pd.read_csv('u1.base', sep='\t', names=['user_id', 'movie_id', 'rating', 'timestamp'])
    test_df = pd.read_csv('u1.test', sep='\t', names=['user_id', 'movie_id', 'rating', 'timestamp'])

    # 2. Load Context/Metadata
    movie_ctx_df = pd.read_csv(f'{context_dir}/movie_context.csv')
    with open(f'{context_dir}/context_metadata.json', 'r') as f:
        ctx_meta = json.load(f)

    # 3. Load Auxiliary Files (Genres, etc.) for Evaluation Mappings
    genres_df = pd.read_csv('nodes/genres.csv')

    # --- MAPPINGS ---
    # Ensure strict alignment between movie_id and tensor index
    movie_map = {mid: i for i, mid in enumerate(movies_df['ml_movie_id'])}
    # Reverse map for evaluation
    id_to_movie = {v: k for k, v in movie_map.items()}

    user_map = {uid: i for i, uid in enumerate(train_df['user_id'].unique())}
    id_to_user = {v: k for k, v in user_map.items()}

    # Evaluation Mappings (Genres)
    movie_genres_dict = {}
    id_to_genre_name = dict(zip(genres_df['id'], genres_df['name']))

    for _, row in movies_df.iterrows():
        mid = str(row['ml_movie_id'])
        try:
            g_ids = json.loads(row['genres'])
            names = {id_to_genre_name.get(gid, 'unknown') for gid in g_ids}
            movie_genres_dict[mid] = names
        except (TypeError, ValueError):  # missing or malformed genre list
            movie_genres_dict[mid] = set()

    # Bundle all mappings
    mappings = {
        'user_map': user_map,
        'movie_map': movie_map,
        'id_to_user': id_to_user,
        'id_to_movie': id_to_movie,
        'movie_genres': movie_genres_dict
    }

    # --- GRAPH CONSTRUCTION ---
    data = HeteroData()

    # A. Initialize Movie Features (The Context)
    feature_cols = ['year_bucket', 'budget_bucket', 'revenue_bucket', 'popularity_bucket',
                   'vote_avg_bucket', 'vote_count_bucket', 'runtime_bucket']

    num_movies = len(movie_map)
    movie_feat_tensor = torch.zeros((num_movies, len(feature_cols)), dtype=torch.long)

    # Fill tensor based on movie_map indices
    for _, row in movie_ctx_df.iterrows():
        if row['ml_movie_id'] in movie_map:
            idx = movie_map[row['ml_movie_id']]
            vals = row[feature_cols].values.astype(int)
            movie_feat_tensor[idx] = torch.tensor(vals)

    data['movie'].x = movie_feat_tensor
    data['movie'].num_nodes = num_movies
    data['user'].num_nodes = len(user_map) # Important for PyG to know user count

    # B. Build Attribute Nodes & Edges
    def load_aux_edges(filename, node_type, col_name_in_movies):
        df = pd.read_csv(f'nodes/{filename}')
        node_map = {id_: i for i, id_ in enumerate(df['id'])}
        data[node_type].num_nodes = len(node_map)
        # Attribute nodes carry no features; use their own indices as .x for embedding lookup
        data[node_type].x = torch.arange(len(node_map), dtype=torch.long)

        src, dst = [], []
        for _, row in movies_df.iterrows():
            if row['ml_movie_id'] not in movie_map: continue
            mid = movie_map[row['ml_movie_id']]
            try:
                item_ids = json.loads(row[col_name_in_movies])
                for iid in item_ids:
                    if iid in node_map:
                        src.append(mid)
                        dst.append(node_map[iid])
            except (TypeError, ValueError): continue

        if len(src) > 0:
            data['movie', f'has_{node_type}', node_type].edge_index = torch.tensor([src, dst], dtype=torch.long)
            data[node_type, f'{node_type}_of', 'movie'].edge_index = torch.tensor([dst, src], dtype=torch.long)

    load_aux_edges('genres.csv', 'genre', 'genres')
    load_aux_edges('directors.csv', 'director', 'directors')
    load_aux_edges('keywords.csv', 'keyword', 'keywords')
    load_aux_edges('writers.csv', 'writer', 'writers')

    return data, mappings, ctx_meta, train_df, test_df
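
The edge construction in `load_aux_edges` boils down to turning JSON-encoded ID lists into two parallel index lists. The sketch below reproduces that idea with the standard library only; the rows, IDs, and maps are invented for illustration.

```python
import json

# Hypothetical rows: (ml_movie_id, JSON-encoded list of genre IDs).
rows = [(101, "[28, 12]"), (102, "[12]"), (103, "not json")]
movie_map = {101: 0, 102: 1, 103: 2}
genre_map = {12: 0, 28: 1}                  # genre ID -> node index

src, dst = [], []
for movie_id, raw in rows:
    try:
        genre_ids = json.loads(raw)
    except (TypeError, ValueError):         # malformed cell: skip this movie
        continue
    for gid in genre_ids:
        if gid in genre_map:
            src.append(movie_map[movie_id])
            dst.append(genre_map[gid])

forward_edges = [src, dst]                  # movie -> genre
reverse_edges = [dst, src]                  # genre -> movie
print(forward_edges)  # [[0, 0, 1], [1, 0, 0]]
print(reverse_edges)  # [[1, 0, 0], [0, 0, 1]]
```

Storing both directions is what lets the graph encoder pass messages from movies to attributes in one layer and back in the next.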

In [28]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, GATConv, Linear

class ItemGraphEncoder(nn.Module):
    def __init__(self, data_metadata, ctx_metadata, hidden_dim=64, heads=2):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Initial Embeddings (Metadata Injection)
        # Movie Context Embeddings
        self.movie_feat_embs = nn.ModuleList()
        feat_keys = ['year_bucket', 'budget_bucket', 'revenue_bucket', 'popularity_bucket',
                     'vote_avg_bucket', 'vote_count_bucket', 'runtime_bucket']

        for key in feat_keys:
            vocab_size = ctx_metadata['movie_features'][key]
            self.movie_feat_embs.append(nn.Embedding(vocab_size, hidden_dim))

        # Attribute embeddings: one learnable vector per Genre, Director, etc.
        # Populated later by init_node_embs(), once the node counts are known.
        self.attr_embs = nn.ModuleDict()

        # GNN Layers (The Third Relation Logic)

        # Layer 1: Movies -> Attributes
        # Attributes learn from the movies they contain
        # "Action Genre" becomes a mix of "Terminator" + "Matrix" features
        self.conv1 = HeteroConv({
            ('movie', f'has_{nt}', nt): GATConv(hidden_dim, hidden_dim // heads, heads=heads, add_self_loops=False)
            for nt in ['genre', 'director', 'keyword', 'writer']
        }, aggr='sum')

        # Layer 2: Attributes -> Movies
        # Movies learn from their Enriched Attributes
        # "Terminator" updates based on the "Enriched Action Genre"
        self.conv2 = HeteroConv({
            (nt, f'{nt}_of', 'movie'): GATConv(hidden_dim, hidden_dim // heads, heads=heads, add_self_loops=False)
            for nt in ['genre', 'director', 'keyword', 'writer']
        }, aggr='sum')

        self.movie_lin = nn.Linear(hidden_dim * len(feat_keys), hidden_dim)
        self.norm = nn.LayerNorm(hidden_dim)
        self.relu = nn.LeakyReLU(0.2)

    def init_node_embs(self, num_nodes_dict):
        """ Initialize learnable embeddings for attribute nodes (Genre, etc.) """
        for ntype, count in num_nodes_dict.items():
            if ntype != 'movie' and ntype != 'user':
                self.attr_embs[ntype] = nn.Embedding(count, self.hidden_dim)

    def forward(self, x_dict, edge_index_dict):
        """
        x_dict: {'movie': [N, 7] (Metadata Indices), 'genre': [M] (Indices)...}
        """

        # Prepare Initial Features
        # Movies: embed each metadata bucket, then concatenate
        m_indices = x_dict['movie'] # [N, 7]
        m_emb_list = [emb(m_indices[:, i]) for i, emb in enumerate(self.movie_feat_embs)]
        # Concatenate and project: [N, 7 * Dim] -> [N, Dim]
        h_movie = self.movie_lin(torch.cat(m_emb_list, dim=1))

        # Attributes: Look up learnable embeddings
        h_dict = {'movie': h_movie}
        for ntype, emb in self.attr_embs.items():
            if ntype in x_dict:
                h_dict[ntype] = emb(x_dict[ntype])

        # Message Passing

        # Layer 1: Propagate Movie Info TO Attributes
        # We need to act on edges like ('movie', 'has_genre', 'genre')
        out_l1 = self.conv1(h_dict, edge_index_dict)

        # Apply activation/norm to the attribute updates.
        # conv1 only targets attribute node types, so out_l1 should not contain
        # 'movie'; the check below is a guard.
        for ntype in out_l1:
            if ntype in h_dict and ntype != 'movie':
                h_dict[ntype] = self.norm(self.relu(out_l1[ntype]))

        # Layer 2: Propagate Attribute Info BACK TO Movies
        # h_dict now holds the updated attribute embeddings together with the
        # initial movie embeddings, which serve as the destination-side input.

        out_l2 = self.conv2(h_dict, edge_index_dict)

        # Final Movie Embedding = Original Context + Graph Context
        # out_l2 only contains 'movie' as its target type
        final_movie = h_movie + out_l2['movie']

        return self.norm(final_movie)
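
The two-layer flow (movies to attributes, then attributes back to movies, plus a residual) can be illustrated with scalar features. Note the substitution: this sketch uses plain mean aggregation instead of the GAT attention used in `ItemGraphEncoder`, and all values are invented.

```python
def mean_aggregate(edges, source_feats, num_targets):
    """Average source features over incoming edges for each target node."""
    sums = [0.0] * num_targets
    counts = [0] * num_targets
    for s, t in edges:
        sums[t] += source_feats[s]
        counts[t] += 1
    return [sums[t] / counts[t] if counts[t] else 0.0 for t in range(num_targets)]

movie_feats = [1.0, 3.0, 10.0]              # scalar features for 3 movies
movie_to_genre = [(0, 0), (1, 0), (2, 1)]   # movies 0 and 1 share genre 0
genre_to_movie = [(g, m) for m, g in movie_to_genre]

# Layer 1: each genre summarizes the movies attached to it.
genre_feats = mean_aggregate(movie_to_genre, movie_feats, num_targets=2)
# Layer 2: movies read back from the enriched genres, plus a residual.
graph_ctx = mean_aggregate(genre_to_movie, genre_feats, num_targets=3)
final = [m + g for m, g in zip(movie_feats, graph_ctx)]

print(genre_feats)  # [2.0, 10.0]
print(final)        # [3.0, 5.0, 20.0]
```

Movies 0 and 1 receive the same graph context because they share a genre, which is the sense in which a movie is informed by its attribute neighbours.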


In [29]:

class UserGraphHistoryEncoder(nn.Module):
    def __init__(self, hidden_dim, num_time_buckets=64):
        super().__init__()

        # Interaction Embeddings
        # Rating Embedding (0 = padding, 1-5 = ratings)
        self.rating_emb = nn.Embedding(6, hidden_dim, padding_idx=0)

        # Time Embedding (Relative Days bucketed)
        self.time_emb = nn.Embedding(num_time_buckets, hidden_dim)

        # Attention Mechanism
        # Multi-Head Attention to capture different "views" of history
        self.mha = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=HEADS, batch_first=True)

        # "User Persona" Queries
        # A single learnable query attends over the history; there is no
        # self-attention among the history items themselves.
        self.user_query = nn.Parameter(torch.randn(1, 1, hidden_dim))

        self.norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, movie_graph_embs, ratings, time_buckets, mask):
        """
        movie_graph_embs: [Batch, Seq, Dim] (From GNN)
        ratings: [Batch, Seq]
        time_buckets: [Batch, Seq]
        """

        # Compose Interaction Vector
        # v = h_movie + e_rating + e_time
        r_vec = self.rating_emb(ratings)
        t_vec = self.time_emb(time_buckets)

        x = movie_graph_embs + r_vec + t_vec

        # Attention
        # User Query Attention ("What defines me?")
        B, S, D = x.shape
        query = self.user_query.expand(B, 1, D)

        # Key Padding Mask: True where padding exists (PyTorch convention)
        # Our 'mask' is usually 1 for valid, 0 for pad. PyTorch MHA expects True for Ignored positions.
        key_padding_mask = (mask == 0)

        # Attention(Q=User, K=History, V=History)
        attn_out, _ = self.mha(query, x, x, key_padding_mask=key_padding_mask)

        user_rep = attn_out.squeeze(1)

        return self.norm(user_rep)

In [30]:

class GraphCodificationPipeline(nn.Module):
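    def __init__(self, data_metadata, ctx_metadata, num_nodes_dict, hidden_dim=384):
        super().__init__()
        # NOTE: DeepRecommenderSystem blends text and graph vectors elementwise,
        # so hidden_dim must equal the text pipeline's dim when the two are fused.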

        self.item_gnn = ItemGraphEncoder(data_metadata, ctx_metadata, hidden_dim)
        self.item_gnn.init_node_embs(num_nodes_dict)

        self.user_encoder = UserGraphHistoryEncoder(hidden_dim)

    def forward(self, x_dict, edge_index_dict,
                history_movie_indices, history_ratings, history_times, history_mask):
        """
        1. Run GNN on the FULL Item Graph to get all Movie Embeddings.
        2. Look up the specific movies in the user's history.
        3. Encode the User History.
        """

        # Get ALL Movie Graph Embeddings
        # h_movies_all: [Num_Total_Movies, Dim]
        h_movies_all = self.item_gnn(x_dict, edge_index_dict)

        # Lookup History
        # Flatten batch to index efficiently
        B, S = history_movie_indices.shape
        flat_ids = history_movie_indices.view(-1)

        # Gather embeddings for the sequence
        # Note: history_movie_indices must be aligned with GNN output indices
        seq_embs = h_movies_all[flat_ids].view(B, S, -1)

        # Encode User
        user_graph_emb = self.user_encoder(seq_embs, history_ratings, history_times, history_mask)

        return user_graph_emb

In [31]:

class DeepRecommenderSystem(nn.Module):
    def __init__(self, text_pipeline, graph_pipeline, hidden_dim=1024, head_dim=64):
        super().__init__()

        self.text_pipeline = text_pipeline
        self.graph_pipeline = graph_pipeline

        # --- 1. SELF-ATTENTION FUSION (View merging) ---
        # "Which view matters more? Text or Graph?"
        # We learn a weight alpha to combine them: u = alpha * u_text + (1-alpha) * u_graph
        # Implements logic from Paper

        # Fusion for Users
        self.user_fusion_gate = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

        # Fusion for Items
        self.item_fusion_gate = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

        # --- 2. CROSS-ATTENTION (User-Item Interaction) ---
        # "How does this specific user feature relate to this item feature?"
        # Implements logic from Paper
        self.cross_attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=HEADS, batch_first=True)
        self.layer_norm = nn.LayerNorm(hidden_dim)

        # --- 3. FINAL PREDICTION MLP ---
        self.predictor = nn.Sequential(
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1) # Output Score
        )

    def get_item_embeddings(self, item_ids, item_graph_x, item_graph_edge_index):
        """
        Extracts and fuses Item embeddings from both pipelines.
        """
        # Text View (From pre-computed buffers in text_pipeline)
        # item_ids needs to be flat for indexing
        # Note: We access the internal encoders directly

        # Text Encoder
        # Retrieve raw inputs from the text pipeline's buffers
        b_ov = self.text_pipeline.static_ov[item_ids]
        b_rev = self.text_pipeline.static_rev[item_ids]
        b_mask = self.text_pipeline.static_rev_mask[item_ids]
        b_has_rev = self.text_pipeline.static_has_rev[item_ids]

        i_text = self.text_pipeline.item_encoder(b_ov, b_rev, b_mask, b_has_rev)

        # Graph Encoder
        # We run the GNN to get ALL embeddings, then select the ones we need
        # Optimization: In inference, we cache 'all_embs', but for training we re-run
        # to get gradients through the GNN.
        all_graph_embs = self.graph_pipeline.item_gnn(item_graph_x, item_graph_edge_index)
        i_graph = all_graph_embs[item_ids]

        # Fuse Views (Self-Attention)
        combined = torch.cat([i_text, i_graph], dim=-1)
        alpha = self.item_fusion_gate(combined)

        i_final = alpha * i_text + (1 - alpha) * i_graph
        return i_final, i_text, i_graph

    def get_user_embeddings(self, batch_hist, batch_graph_data):
        """
        Generates and fuses User embeddings from history.
        batch_hist: Tuple (ids, rates, masks, times)
        batch_graph_data: Tuple (x_dict, edge_index_dict)
        """
        ids, rates, masks, times = batch_hist
        x_dict, edge_idx = batch_graph_data

        # Text Pipeline (User)
        u_text = self.text_pipeline(ids, rates, masks)

        # Graph Pipeline (User)
        u_graph = self.graph_pipeline(x_dict, edge_idx, ids, rates, times, masks)

        # Fuse Views (Self-Attention)
        combined = torch.cat([u_text, u_graph], dim=-1)
        alpha = self.user_fusion_gate(combined)

        u_final = alpha * u_text + (1 - alpha) * u_graph

        return u_final, u_text, u_graph

    def forward(self, user_emb, item_emb):
        """
        Computes the score using Cross-Attention + MLP.
        user_emb: (Batch, Dim)
        item_emb: (Batch, Dim)
        """
        # nn.MultiheadAttention expects (Batch, Seq, Dim) with batch_first=True.
        # The user is the Query and the item is the Key/Value; each is a sequence of length 1.

        u_q = user_emb.unsqueeze(1) # (B, 1, D)
        i_k = item_emb.unsqueeze(1) # (B, 1, D)

        # Interaction: User attends to Item
        attn_out, _ = self.cross_attn(u_q, i_k, i_k)
        attn_out = attn_out.squeeze(1)

        interaction = attn_out * user_emb # Hadamard product to emphasize alignment

        score = self.predictor(interaction)
        return score

    def get_all_embeddings(self, data, hist_ids, hist_rates, hist_times, hist_masks):
        """
        Generates Fused Embeddings for ALL Users and ALL Items for inference.
        """
        self.eval() # Ensure eval mode
        with torch.no_grad():
            all_graph_items = self.graph_pipeline.item_gnn(data.x_dict, data.edge_index_dict)

            # We assume the static buffers in text_pipeline cover all items in order 0..N
            i_ov = self.text_pipeline.static_ov
            i_rev = self.text_pipeline.static_rev
            i_mask = self.text_pipeline.static_rev_mask
            i_has = self.text_pipeline.static_has_rev

            i_text_all = self.text_pipeline.item_encoder(i_ov, i_rev, i_mask, i_has)
            i_comb = torch.cat([i_text_all, all_graph_items], dim=-1)
            i_alpha = self.item_fusion_gate(i_comb)
            i_all_fused = i_alpha * i_text_all + (1 - i_alpha) * all_graph_items

            # Text View (User History -> Text Pipe)
            u_text_all = self.text_pipeline(hist_ids, hist_rates, hist_masks)

            # Graph View (User History -> Graph Pipe)
            # We need the specific graph embeddings for the items in history
            B, S = hist_ids.shape
            flat_ids = hist_ids.view(-1)
            # Lookup from the GNN output we just computed
            seq_graph_embs = all_graph_items[flat_ids].view(B, S, -1)

            u_graph_all = self.graph_pipeline.user_encoder(
                seq_graph_embs, hist_rates, hist_times, hist_masks
            )

            # Fuse Users
            u_comb = torch.cat([u_text_all, u_graph_all], dim=-1)
            u_alpha = self.user_fusion_gate(u_comb)
            u_all_fused = u_alpha * u_text_all + (1 - u_alpha) * u_graph_all

            return u_all_fused, i_all_fused
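
One property of the cross-attention call in `forward` is worth keeping in mind: the query attends over exactly one key, and a softmax over a single score is always 1.0. A small standalone check (the `softmax` helper is written here for illustration):

```python
import math

def softmax(xs):
    peak = max(xs)
    exps = [math.exp(x - peak) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# One key per query, as in forward(): whatever the score, the weight is 1.0.
single_key_weights = [softmax([score]) for score in (-7.5, 0.0, 42.0)]
print(single_key_weights)  # [[1.0], [1.0], [1.0]]
```

Under that reading, the attention output is a projection of the item vector alone, and the user vector enters the score through the Hadamard product that follows.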

In [32]:

class HybridLoss(nn.Module):
    def __init__(self, temperature=0.1, aux_weight=0.1):
        super().__init__()
        self.temp = temperature
        self.aux_weight = aux_weight # Weight for the contrastive loss
        self.bpr_loss = nn.LogSigmoid()

    def forward(self, pos_scores, neg_scores, u_text, u_graph):
        """
        pos_scores: Scores for Ground Truth items
        neg_scores: Scores for Negative sampled items
        u_text, u_graph: The separate views of the user (before fusion)
        """

        # Main Task: BPR Loss (Maximize Pos - Neg)
        # Loss = -log(sigmoid(pos - neg))
        loss_main = -torch.mean(self.bpr_loss(pos_scores - neg_scores))

        # Aux Task: Self-Supervised Contrastive Loss
        # We want u_text and u_graph to represent the SAME user.
        # Maximize similarity(u_text, u_graph) for the same user in the batch.

        # Normalize
        u_t_norm = F.normalize(u_text, dim=1)
        u_g_norm = F.normalize(u_graph, dim=1)

        # Cosine Similarity matrix (Batch x Batch)
        logits = torch.matmul(u_t_norm, u_g_norm.T) / self.temp

        # Labels: The diagonal (0, 1, 2...) are the positive pairs
        batch_size = u_text.shape[0]
        labels = torch.arange(batch_size, device=u_text.device)

        # InfoNCE Loss (Cross Entropy)
        loss_aux = F.cross_entropy(logits, labels)

        # Total
        return loss_main + (self.aux_weight * loss_aux), loss_main.item(), loss_aux.item()
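
Both terms of `HybridLoss` can be reproduced with the standard library, which makes their behaviour easy to inspect. The sketch below is a scalar re-implementation for illustration only; the similarity matrices are hypothetical cosine similarities, not model outputs.

```python
import math

def bpr_loss(pos_scores, neg_scores):
    """Mean of -log(sigmoid(pos - neg)), written as log(1 + exp(-(pos - neg)))."""
    terms = [math.log1p(math.exp(-(p - n))) for p, n in zip(pos_scores, neg_scores)]
    return sum(terms) / len(terms)

def info_nce(sim_matrix, temperature=0.1):
    """Cross-entropy over rows of a similarity matrix with diagonal targets."""
    losses = []
    for i, row in enumerate(sim_matrix):
        logits = [s / temperature for s in row]
        peak = max(logits)
        log_z = peak + math.log(sum(math.exp(l - peak) for l in logits))
        losses.append(log_z - logits[i])
    return sum(losses) / len(losses)

# Hypothetical similarities between the text and graph views of 2 users.
aligned = [[1.0, 0.0], [0.0, 1.0]]     # each user matches its own other view
swapped = [[0.0, 1.0], [1.0, 0.0]]     # views are matched to the wrong user

print(bpr_loss([0.0], [0.0]))                  # log(2), about 0.693
print(info_nce(aligned) < info_nce(swapped))   # True

aux_weight = 0.1
total = bpr_loss([2.0], [0.0]) + aux_weight * info_nce(aligned)
print(total)
```

The BPR term shrinks as the positive score pulls ahead of the negative one, while the contrastive term is small only when each user's text view is most similar to that same user's graph view.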


In [33]:
class BPRDataset(Dataset):
    def __init__(self, train_df, user_map, movie_map):
        self.data = []
        valid_u = set(user_map.keys())
        valid_m = set(movie_map.keys())
        self.user_hist = defaultdict(set)
        self.all_items = list(movie_map.values())

        pos_df = train_df[train_df['rating'] >= 4]

        for u, m in zip(pos_df['user_id'], pos_df['movie_id']):
            if u in valid_u and m in valid_m:
                uid, mid = user_map[u], movie_map[m]
                self.data.append((uid, mid))
                self.user_hist[uid].add(mid)

    def __len__(self): return len(self.data)

    def __getitem__(self, idx):
        u, pos = self.data[idx]
        while True:
            neg = np.random.choice(self.all_items)
            if neg not in self.user_hist[u]: break
        return torch.tensor(u), torch.tensor(pos), torch.tensor(neg)
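
The rejection-sampling idea in `__getitem__` can be sketched on its own. This version differs from the notebook in two labeled ways: it uses Python's `random` module instead of NumPy, and it adds a bounded retry (`max_tries`, a hypothetical parameter) so that a user who has interacted with every item cannot cause an endless loop.

```python
import random

def sample_negative(user_positives, num_items, rng, max_tries=1000):
    """Draw a random item index that is not among the user's positives."""
    for _ in range(max_tries):
        candidate = rng.randrange(num_items)
        if candidate not in user_positives:
            return candidate
    raise RuntimeError("no negative found; the user may have rated every item")

rng = random.Random(0)                      # seeded for reproducibility
positives = {0, 1, 2}                       # items user 7 rated >= 4
triples = [(7, pos, sample_negative(positives, num_items=6, rng=rng))
           for pos in sorted(positives)]
print(triples)
```

Each triple is `(user, positive, negative)`, the shape consumed by the BPR loss. As in the notebook, only positively rated items are excluded, so an item the user rated poorly can still be drawn as a negative.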

In [34]:
def generate_recommendations(model, data, mappings, train_df, hist_tensors, k=50):
    print("Generating Recommendations...")
    model.eval()

    # Unpack tensors
    hist_ids, hist_rates, hist_times, hist_masks = hist_tensors

    # Compute ALL embeddings (inference only, so no autograd graph is needed)
    # u_all is ordered by Tensor Index (0, 1, 2... N)
    with torch.no_grad():
        u_all, i_all = model.get_all_embeddings(data, hist_ids, hist_rates, hist_times, hist_masks)

    # Build User History for filtering
    # Keys are the raw user_id values from train_df; they must have the same
    # type as the values in id_to_user for the lookup below to match
    user_history = train_df.groupby('user_id')['movie_id'].apply(set).to_dict()

    # We will iterate through users by their TENSOR INDEX (0..N)
    # This ensures u_all[0] matches the Real User ID at mappings['id_to_user'][0]
    num_users = len(mappings['user_map'])
    tensor_indices = list(range(num_users))

    id_to_movie = mappings['id_to_movie']
    id_to_user = mappings['id_to_user']

    recommendations = {}
    BATCH = 100

    with torch.no_grad():
        for i in tqdm(range(0, num_users, BATCH)):
            # Batch of User Tensor Indices
            batch_indices = tensor_indices[i : i + BATCH]
            batch_u_emb = u_all[batch_indices] # (Batch, Dim)

            # Score against ALL items
            # Shape: (Batch, Num_Items)
            scores = []
            for u_vec in batch_u_emb:
                # Expand User to match all items: (Num_Items, Dim)
                u_repeated = u_vec.unsqueeze(0).expand(i_all.size(0), -1)

                # Predict
                u_scores = model(u_repeated, i_all).squeeze(-1)
                scores.append(u_scores)

            scores = torch.stack(scores)

            # Rank
            _, top_indices = torch.sort(scores, descending=True)
            top_indices = top_indices.cpu().numpy()

            # Decode to Real IDs
            for j, u_tensor_idx in enumerate(batch_indices):
                # Retrieve Real User ID
                real_u_id_int = id_to_user[u_tensor_idx]
                real_u_id_str = str(real_u_id_int)

                # Retrieve Seen Set (using int key)
                seen = user_history.get(real_u_id_int, set())

                recs = []
                for item_tensor_idx in top_indices[j]:
                    real_item_int = id_to_movie[item_tensor_idx]

                    # Filter 'seen' (Training Data)
                    if real_item_int not in seen:
                        recs.append(str(real_item_int))
                        if len(recs) == k: break

                recommendations[real_u_id_str] = recs

    return recommendations
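
The rank, filter, and decode step at the end of `generate_recommendations` can be isolated for a single user. This is a minimal pure-Python sketch; `top_k_unseen` is a hypothetical helper name and the scores and ID map are toy values.

```python
def top_k_unseen(scores, id_to_item, seen, k):
    """Rank tensor indices by score, skip items in `seen`, return k real IDs as strings."""
    ranked = sorted(range(len(scores)), key=lambda idx: scores[idx], reverse=True)
    recs = []
    for idx in ranked:
        item = id_to_item[idx]      # tensor index -> real item ID
        if item in seen:            # drop items from the training history
            continue
        recs.append(str(item))
        if len(recs) == k:
            break
    return recs

toy_scores = [0.9, 0.1, 0.8, 0.5]
toy_id_to_item = {0: 10, 1: 11, 2: 12, 3: 13}
print(top_k_unseen(toy_scores, toy_id_to_item, seen={10}, k=2))
```

Item `10` has the highest score but is filtered because it is in the seen set, so the next two items in score order are returned. Filtering after a full sort is why the function sorts all scores rather than taking only the first `k`.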

In [35]:
def prepare_eval_data(train_df, test_df, movies_df, mappings):
    print("Constructing Evaluation Dictionaries...")
    ground_truth = (test_df[test_df['rating'] >= 4]
                    .groupby('user_id')
                    .apply(lambda x: dict(zip(x['movie_id'].astype(str), x['rating'])))
                    .to_dict())
    ground_truth = {str(k): v for k, v in ground_truth.items()}

    item_popularity = train_df.groupby('movie_id')['rating'].count().to_dict()
    item_popularity = {str(k): v for k, v in item_popularity.items()}

    item_features = mappings['movie_genres']
    all_items = set(str(m) for m in mappings['movie_map'].keys())
    return ground_truth, item_popularity, item_features, all_items
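
To make the shape of `ground_truth` concrete, here is a minimal sketch that builds the same `{user: {movie: rating}}` dictionary from plain tuples and scores a recommendation list against it. `build_ground_truth` and `recall_at_k` are hypothetical helpers written for illustration; they are not the implementation in `recomender_metrics`, whose internals are not shown in this notebook.

```python
from collections import defaultdict

def build_ground_truth(test_rows, min_rating=4):
    """Map str(user_id) -> {str(movie_id): rating} for ratings >= min_rating."""
    truth = defaultdict(dict)
    for user_id, movie_id, rating in test_rows:
        if rating >= min_rating:
            truth[str(user_id)][str(movie_id)] = rating
    return dict(truth)

def recall_at_k(recommendations, ground_truth, k):
    """Average share of each user's relevant items found in their top-k list."""
    per_user = []
    for user, relevant in ground_truth.items():
        top_k = recommendations.get(user, [])[:k]
        hits = sum(1 for item in top_k if item in relevant)
        per_user.append(hits / len(relevant))
    return sum(per_user) / len(per_user) if per_user else 0.0

rows = [(1, 10, 5), (1, 11, 3), (1, 12, 4), (2, 10, 2)]
truth = build_ground_truth(rows)
recs = {"1": ["12", "99"]}
print(truth)
print(recall_at_k(recs, truth, k=2))
```

User `2` has no rating of 4 or higher and therefore does not appear in the ground truth. All keys are strings, matching the string IDs that `generate_recommendations` returns.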

In [None]:
import torch.optim as optim
from torch.utils.data import DataLoader

print("=== 1. INITIALIZATION ===")

# Load Graph Data & Metadata
data, mappings, ctx_meta, train_df, test_df = load_data_and_build_graph_with_context()
data = data.to(DEVICE)

# Load Text Embeddings
ao, ar, am, has_rev = load_text_embeddings('movie_text_embeddings(2).pt', mappings['movie_map'])

# Initialize Pipelines
# Text Pipeline
text_pipe = TextCodificationPipeline(
    ao.to(DEVICE), ar.to(DEVICE), am.to(DEVICE), has_rev.to(DEVICE), dim=1024
).to(DEVICE)

# Graph Pipeline
# We need the node counts for the GNN initialization
num_nodes_dict = {nt: data[nt].num_nodes for nt in data.node_types}
graph_pipe = GraphCodificationPipeline(
    data.metadata(), ctx_meta, num_nodes_dict, hidden_dim=1024
).to(DEVICE)

# Deep Recommender
model = DeepRecommenderSystem(text_pipe, graph_pipe, hidden_dim=1024).to(DEVICE)


print("\n=== 2. PRE-COMPUTING USER HISTORIES ===")
# We need to turn the variable-length sequences in u1.base into fixed tensors
# that we can look up instantly during the training loop.

# Load raw ratings
ratings_df = pd.read_csv('u1.base', sep='\t', names=['user_id', 'movie_id', 'rating', 'timestamp'])
ratings_df['movie_id'] = ratings_df['movie_id'].map(mappings['movie_map']) # Map to internal IDs
ratings_df = ratings_df.dropna().sort_values(['user_id', 'timestamp'])

# Create Tensors: [Num_Users, Max_Seq_Len]
num_users = len(mappings['user_map'])
max_len = 50

# Initialize with Padding (0)
hist_ids = torch.zeros((num_users, max_len), dtype=torch.long)
hist_rates = torch.zeros((num_users, max_len), dtype=torch.long)
hist_times = torch.zeros((num_users, max_len), dtype=torch.long)
hist_masks = torch.zeros((num_users, max_len), dtype=torch.float32)

# Time Bucketing Helper
# We discretize "Time since first interaction" into 64 buckets
def get_time_buckets(timestamps, n_buckets=64):
    if len(timestamps) < 2: return [0] * len(timestamps)
    # Normalize to 0-1
    t_min, t_max = min(timestamps), max(timestamps)
    if t_max == t_min: return [0] * len(timestamps)

    norm_time = [(t - t_min) / (t_max - t_min) for t in timestamps]
    # Scale to 0-(n-1)
    return [int(t * (n_buckets - 1)) for t in norm_time]

print("Building history tensors...")
for uid, group in tqdm(ratings_df.groupby('user_id')):
    if uid not in mappings['user_map']: continue
    u_idx = mappings['user_map'][uid]

    # Get sequences
    mids = group['movie_id'].values.astype(int)[-max_len:]
    rates = group['rating'].values.astype(int)[-max_len:]
    times = get_time_buckets(group['timestamp'].values)[-max_len:]

    seq_len = len(mids)

    # Fill tensors with right padding: real interactions occupy positions
    # [0, seq_len); the remaining slots keep the value 0 and are excluded via hist_masks
    hist_ids[u_idx, :seq_len] = torch.tensor(mids)
    hist_rates[u_idx, :seq_len] = torch.tensor(rates)
    hist_times[u_idx, :seq_len] = torch.tensor(times)
    hist_masks[u_idx, :seq_len] = 1.0

# Move to GPU for fast lookup
hist_ids = hist_ids.to(DEVICE)
hist_rates = hist_rates.to(DEVICE)
hist_times = hist_times.to(DEVICE)
hist_masks = hist_masks.to(DEVICE)

print("\n=== 3. TRAINING LOOP (WITH LEAKAGE FIX) ===")
dataset = BPRDataset(train_df, mappings['user_map'], mappings['movie_map'])
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

optimizer = optim.Adam(model.parameters(), lr=LR)
loss_fn = HybridLoss(aux_weight=0.1).to(DEVICE)

loss_history = []

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    total_main = 0
    total_aux = 0

    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for u_idx, pos_idx, neg_idx in pbar:
        u_idx, pos_idx, neg_idx = u_idx.to(DEVICE), pos_idx.to(DEVICE), neg_idx.to(DEVICE)

        # ---------------------------------------------------------
        # STEP 1: GLOBAL GRAPH UPDATE
        # ---------------------------------------------------------
        all_graph_items = model.graph_pipeline.item_gnn(data.x_dict, data.edge_index_dict)

        # ---------------------------------------------------------
        # STEP 2: USER EMBEDDINGS (Two Views)
        # ---------------------------------------------------------
        b_h_ids = hist_ids[u_idx].clone()
        b_h_rates = hist_rates[u_idx].clone()
        b_h_times = hist_times[u_idx] # No clone needed if read-only
        b_h_mask = hist_masks[u_idx].clone()

        mask_prob = torch.rand_like(b_h_ids.float())
        bert_mask = (mask_prob < 0.1) & (b_h_mask == 1)

        # Randomly hide ~10% of the valid history positions (BERT-style masking)
        # so the user encoder cannot rely on seeing the complete history
        b_h_ids[bert_mask] = 0
        b_h_mask[bert_mask] = 0
        b_h_rates[bert_mask] = 0

        # Text Pipeline (User)
        u_text = model.text_pipeline(b_h_ids, b_h_rates, b_h_mask)

        # Graph Pipeline (User)
        B, S = b_h_ids.shape
        flat_hist_ids = b_h_ids.view(-1)
        seq_graph_embs = all_graph_items[flat_hist_ids].view(B, S, -1)

        u_graph = model.graph_pipeline.user_encoder(
            seq_graph_embs, b_h_rates, b_h_times, b_h_mask
        )

        # Fusion
        u_combined = torch.cat([u_text, u_graph], dim=-1)
        u_alpha = model.user_fusion_gate(u_combined)
        u_final = u_alpha * u_text + (1 - u_alpha) * u_graph

        # ---------------------------------------------------------
        # STEP 3: ITEM EMBEDDINGS (Pos & Neg)
        # ---------------------------------------------------------
        def get_item_rep(ids):
            i_txt_ov = model.text_pipeline.static_ov[ids]
            i_txt_rev = model.text_pipeline.static_rev[ids]
            i_txt_mask = model.text_pipeline.static_rev_mask[ids]
            i_txt_has = model.text_pipeline.static_has_rev[ids]
            i_text_emb = model.text_pipeline.item_encoder(i_txt_ov, i_txt_rev, i_txt_mask, i_txt_has)
            i_graph_emb = all_graph_items[ids]
            comb = torch.cat([i_text_emb, i_graph_emb], dim=-1)
            alpha = model.item_fusion_gate(comb)
            return alpha * i_text_emb + (1 - alpha) * i_graph_emb

        pos_items_final = get_item_rep(pos_idx)
        neg_items_final = get_item_rep(neg_idx)

        # ---------------------------------------------------------
        # STEP 4: PREDICTION & LOSS
        # ---------------------------------------------------------

        pos_scores = model(u_final, pos_items_final)
        neg_scores = model(u_final, neg_items_final)

        loss, l_main, l_aux = loss_fn(pos_scores, neg_scores, u_text, u_graph)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_main += l_main
        total_aux += l_aux

        pbar.set_postfix({'Loss': loss.item(), 'Main': l_main, 'Aux': l_aux})

    avg_loss = total_loss / len(loader)
    loss_history.append(avg_loss)
    print(f"Epoch {epoch+1} Avg Loss: {avg_loss:.4f}")

=== 1. INITIALIZATION ===
Loading Graph Data with Context...
Loading text embeddings from movie_text_embeddings(2).pt...
Aligned 1638 movies out of 1638 raw embeddings.


=== 2. PRE-COMPUTING USER HISTORIES ===
Building history tensors...


100%|██████████| 943/943 [00:00<00:00, 1531.64it/s]



=== 2. TRAINING LOOP (WITH LEAKAGE FIX) ===


Epoch 1/10: 100%|██████████| 86/86 [00:48<00:00,  1.77it/s, Loss=0.509, Main=7.21e-5]


Epoch 1 Avg Loss: 0.6459


Epoch 2/10: 100%|██████████| 86/86 [00:48<00:00,  1.77it/s, Loss=0.469, Main=0.000487]


Epoch 2 Avg Loss: 0.5885


Epoch 3/10: 100%|██████████| 86/86 [00:49<00:00,  1.75it/s, Loss=0.426, Main=0.00274]


Epoch 3 Avg Loss: 0.5588


Epoch 4/10: 100%|██████████| 86/86 [00:48<00:00,  1.77it/s, Loss=0.354, Main=0.00278]


Epoch 4 Avg Loss: 0.4724


Epoch 5/10: 100%|██████████| 86/86 [00:48<00:00,  1.77it/s, Loss=0.318, Main=0.0151]


Epoch 5 Avg Loss: 0.4272


Epoch 6/10: 100%|██████████| 86/86 [00:48<00:00,  1.77it/s, Loss=0.273, Main=0.0131]


Epoch 6 Avg Loss: 0.3843


Epoch 7/10: 100%|██████████| 86/86 [00:48<00:00,  1.78it/s, Loss=0.273, Main=0.0022]


Epoch 7 Avg Loss: 0.3713


Epoch 8/10: 100%|██████████| 86/86 [00:48<00:00,  1.76it/s, Loss=0.256, Main=0.0188]


Epoch 8 Avg Loss: 0.3491


Epoch 9/10: 100%|██████████| 86/86 [00:48<00:00,  1.76it/s, Loss=0.234, Main=0.0114]


Epoch 9 Avg Loss: 0.3261


Epoch 10/10:   6%|▌         | 5/86 [00:02<00:45,  1.79it/s, Loss=0.336, Main=0.0048]