In [None]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import gc

from typing import Tuple
from sklearn.model_selection import train_test_split


In [None]:
# ---------------------------------------------------------------------------
# Problem size and reproducibility
# ---------------------------------------------------------------------------
NUM_SCIENTISTS = 10000  # size of the scientist-id (sid) embedding tables
NUM_PAPERS = 1000       # size of the paper-id (pid) embedding tables
SEED = 42

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
# Data: fixed lengths of the padded context / wishlist sequences per example.
SID_CONTEXT_SIZE = 30
SID_WISHLIST_SIZE = 34
PID_CONTEXT_SIZE = 100
PID_WISHLIST_SIZE = 50

# Model
EMBEDDING_DIM = 16
DROPOUT_RATE = 0.3
NUM_HEADS = 4

# Training
L2_REG = 1e-4  # used as Adam's weight_decay
LEARNING_RATE = 1e-3
BATCH_SIZE = 64
EPOCHS = 10

# Seed every RNG the notebook draws from: Python's `random` (context/wishlist
# sampling), NumPy, and torch.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)





In [None]:
DATA_DIR = "data/"

def read_data_df() -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Load the ratings and the wishlist table from ``DATA_DIR``.

    ``train_ratings.csv`` is keyed by a combined ``sid_pid`` column, which is
    split into integer ``sid`` and ``pid`` columns. The ratings are then split
    75/25 into a training and a validation set.

    Returns:
        ``(train_df, valid_df, tbr_df)``: the training ratings, the validation
        ratings, and the unmodified contents of ``train_tbr.csv``.
    """
    df = pd.read_csv(os.path.join(DATA_DIR, "train_ratings.csv"))

    # "sid_pid" -> integer "sid" and "pid" columns
    df[["sid", "pid"]] = df["sid_pid"].str.split("_", expand=True)
    df = df.drop("sid_pid", axis=1)
    df["sid"] = df["sid"].astype(int)
    df["pid"] = df["pid"].astype(int)

    # No random_state is passed, so the split follows NumPy's global RNG
    # (seeded with np.random.seed in the constants cell). Re-running this cell
    # without re-seeding gives a different split, which would no longer match
    # any datasets already cached on disk from an earlier split.
    train_df, valid_df = train_test_split(df, test_size=0.25)

    # Wishlist table
    tbr_df = pd.read_csv(os.path.join(DATA_DIR, "train_tbr.csv"))

    return train_df, valid_df, tbr_df

def save_predictions_csv(model, train_loader, valid_loader, device, train_output_file: str, valid_output_file: str):
    """
    Run a trained model over the training and validation loaders and write its
    predictions to CSV files.

    Each CSV has the columns ``sid``, ``pid`` and ``predicted``, with one row per
    example yielded by the loader, in loader order. Predictions are written
    exactly as the model returns them; they are NOT clamped or rounded.

    Args:
        model: module called as ``model(sid, pid, context, wishlist)`` that
            returns one score per example.
        train_loader: loader yielding ``(sid, pid, rating, context, wishlist)``.
        valid_loader: loader with the same batch layout.
        device: device the model inputs are moved to.
        train_output_file: CSV path for the training predictions.
        valid_output_file: CSV path for the validation predictions.

    Side effect: leaves ``model`` in evaluation mode.
    """
    model.eval()

    def collect_predictions(loader) -> pd.DataFrame:
        """Run inference over ``loader`` and return a (sid, pid, predicted) frame."""
        sids, pids, preds = [], [], []
        with torch.no_grad():
            for sid, pid, _rating, context, wishlist in loader:
                # The ratings are not needed for inference, so they are not moved.
                predictions = model(sid.to(device), pid.to(device), context.to(device), wishlist.to(device))
                sids.extend(sid.tolist())
                pids.extend(pid.tolist())
                preds.extend(predictions.reshape(-1).cpu().tolist())
        # Building the frame from named columns keeps the header even when the
        # loader is empty.
        return pd.DataFrame({"sid": sids, "pid": pids, "predicted": preds})

    for loader, output_file, label in (
        (train_loader, train_output_file, "Train"),
        (valid_loader, valid_output_file, "Validation"),
    ):
        collect_predictions(loader).to_csv(output_file, index=False)
        print(f"{label} predictions saved to {output_file}")

## Build Datasets

From each entry $(sid, pid, rating)$ in the original dataset, we generate two new entries: one for the **sid-dataset** and one for the **pid-dataset**. The **sid-dataset** entry is enriched by adding a **context** and a **wishlist** as follows:

* **Context:**
  The context contains up to `context_size` pairs $(pid', rating')$ for the *other* papers $pid' \neq pid$ that the same scientist ($sid$) has rated in the same data split. If more than `context_size` such pairs exist, a random subset of `context_size` of them is sampled; if there are fewer, the remaining slots are filled with the padding value $(-1, 0)$.

* **Wishlist:**
  The wishlist contains up to `wishlist_size` paper IDs from the scientist's ($sid$'s) wishlist, excluding the target paper $pid$. If the wishlist is longer than `wishlist_size`, a random subset is sampled; if it is shorter, the remaining entries are padded with $-1$.



The **pid-dataset** is built analogously:

* **Context:**
  The context contains up to `context_size` pairs $(sid', rating')$ for the *other* scientists $sid' \neq sid$ who rated the paper $pid$ in the same data split. If more than `context_size` such pairs exist, a random subset of `context_size` of them is sampled; if there are fewer, the remaining slots are filled with the padding value $(-1, 0)$.

* **Wishlist:**
  The wishlist contains up to `wishlist_size` IDs of scientists who have the paper $pid$ on their wishlist, excluding the target scientist $sid$. If there are more than `wishlist_size` such scientists, a random subset is sampled; if there are fewer, the remaining entries are padded with $-1$.





In [None]:
def get_dataset_sid(df: pd.DataFrame, wishlist_df: pd.DataFrame, context_size: int, wishlist_size: int, save_path: str) -> torch.utils.data.Dataset:
    """
    Build (or load from cache) the SID-centric dataset.

    Every row ``(sid, pid, rating)`` of ``df`` becomes one example
    ``(sid, pid, rating, context, wishlist)``:

    - ``context``: int tensor of shape ``(context_size, 2)`` holding
      ``(pid, rating)`` pairs for the *other* papers the same ``sid`` rated in
      ``df``. It is randomly sampled when more than ``context_size`` exist,
      otherwise padded with ``(-1, 0)``.
    - ``wishlist``: int tensor of shape ``(wishlist_size,)`` holding the papers
      on ``sid``'s wishlist in ``wishlist_df``, excluding ``pid``. It is randomly
      sampled when there are too many, otherwise padded with ``-1``.

    Args:
        df: ratings with ``sid``, ``pid`` and ``rating`` columns.
        wishlist_df: wishlist entries with ``sid`` and ``pid`` columns.
        context_size: fixed context length K.
        wishlist_size: fixed wishlist length W.
        save_path: cache file. If it already exists, its dataset is returned
            as-is and the other arguments are ignored. Delete it after changing
            the data split or the sizes.

    Returns:
        ``TensorDataset(sids, pids, ratings, contexts, wishlists)``.
    """
    if os.path.exists(save_path):
        print(f"Loading dataset from {save_path}")
        # The cache is a pickled TensorDataset (written below). torch.load's
        # weights-only mode (the default since torch 2.6) rejects it, so full
        # unpickling is requested. Only point save_path at caches this function
        # wrote itself.
        return torch.load(save_path, weights_only=False)

    sids = torch.from_numpy(df["sid"].to_numpy())
    pids = torch.from_numpy(df["pid"].to_numpy())
    ratings = torch.from_numpy(df["rating"].to_numpy()).float()

    # sid -> [(pid, rating), ...] over all of sid's ratings in df
    sid_to_context = df.groupby("sid")[["pid", "rating"]].apply(
        lambda x: list(zip(x["pid"], x["rating"]))
    ).to_dict()
    # sid -> [pid, ...] on sid's wishlist
    sid_to_wishlist = wishlist_df.groupby("sid")["pid"].apply(list).to_dict()

    contexts = []
    wishlists = []

    # Iterate plain Python ints rather than 0-d tensors (cheaper per row).
    for sid, pid in zip(df["sid"].tolist(), df["pid"].tolist()):
        # Context: other papers rated by this sid (the target pair is excluded).
        context = [(p, r) for p, r in sid_to_context[sid] if p != pid]
        if len(context) >= context_size:
            context = random.sample(context, context_size)
        else:
            context += [(-1, 0.0)] * (context_size - len(context))
        contexts.append(torch.tensor(context, dtype=torch.int))

        # Wishlist: build a NEW list. The lists stored in `sid_to_wishlist` are
        # shared by every row of the same sid. Removing `pid` from them, or
        # appending padding to them, in place would corrupt later rows.
        wishlist = [p for p in sid_to_wishlist.get(sid, []) if p != pid]
        if len(wishlist) >= wishlist_size:
            wishlist = random.sample(wishlist, wishlist_size)
        else:
            wishlist += [-1] * (wishlist_size - len(wishlist))
        wishlists.append(torch.tensor(wishlist, dtype=torch.int))

    if contexts:
        sid_context = torch.stack(contexts)     # (N, K, 2)
        sid_wishlist = torch.stack(wishlists)   # (N, W)
    else:
        # torch.stack rejects an empty list; produce empty tensors of the same rank.
        sid_context = torch.empty((0, context_size, 2), dtype=torch.int)
        sid_wishlist = torch.empty((0, wishlist_size), dtype=torch.int)

    dataset = torch.utils.data.TensorDataset(sids, pids, ratings, sid_context, sid_wishlist)
    torch.save(dataset, save_path)

    return dataset


def get_dataset_pid(df: pd.DataFrame, wishlist_df: pd.DataFrame, context_size: int, wishlist_size: int, save_path: str) -> torch.utils.data.Dataset:
    """
    Build (or load from cache) the PID-centric dataset.

    Every row ``(sid, pid, rating)`` of ``df`` becomes one example
    ``(sid, pid, rating, context, wishlist)``:

    - ``context``: int tensor of shape ``(context_size, 2)`` holding
      ``(sid, rating)`` pairs for the *other* scientists who rated the same
      ``pid`` in ``df``. It is randomly sampled when more than ``context_size``
      exist, otherwise padded with ``(-1, 0)``.
    - ``wishlist``: int tensor of shape ``(wishlist_size,)`` holding the sids
      that have ``pid`` on their wishlist in ``wishlist_df``, excluding ``sid``.
      It is randomly sampled when there are too many, otherwise padded with
      ``-1``.

    Args:
        df: ratings with ``sid``, ``pid`` and ``rating`` columns.
        wishlist_df: wishlist entries with ``sid`` and ``pid`` columns.
        context_size: fixed context length K.
        wishlist_size: fixed wishlist length W.
        save_path: cache file. If it already exists, its dataset is returned
            as-is and the other arguments are ignored. Delete it after changing
            the data split or the sizes.

    Returns:
        ``TensorDataset(sids, pids, ratings, contexts, wishlists)``.
    """
    if os.path.exists(save_path):
        print(f"Loading dataset from {save_path}")
        # The cache is a pickled TensorDataset (written below). torch.load's
        # weights-only mode (the default since torch 2.6) rejects it, so full
        # unpickling is requested. Only point save_path at caches this function
        # wrote itself.
        return torch.load(save_path, weights_only=False)

    sids = torch.from_numpy(df["sid"].to_numpy())
    pids = torch.from_numpy(df["pid"].to_numpy())
    ratings = torch.from_numpy(df["rating"].to_numpy()).float()

    # pid -> [(sid, rating), ...] over all ratings of pid in df
    pid_to_context = df.groupby("pid")[["sid", "rating"]].apply(
        lambda x: list(zip(x["sid"], x["rating"]))
    ).to_dict()
    # pid -> [sid, ...] of scientists who wishlisted pid
    pid_to_wishlist = wishlist_df.groupby("pid")["sid"].apply(list).to_dict()

    contexts = []
    wishlists = []

    # Iterate plain Python ints rather than 0-d tensors (cheaper per row).
    for sid, pid in zip(df["sid"].tolist(), df["pid"].tolist()):
        # Context: other scientists who rated this pid (the target pair is excluded).
        context = [(s, r) for s, r in pid_to_context[pid] if s != sid]
        if len(context) >= context_size:
            context = random.sample(context, context_size)
        else:
            context += [(-1, 0.0)] * (context_size - len(context))
        contexts.append(torch.tensor(context, dtype=torch.int))

        # Wishlist: build a NEW list. The lists stored in `pid_to_wishlist` are
        # shared by every row of the same pid. Removing `sid` from them, or
        # appending padding to them, in place would corrupt later rows.
        wishlist = [s for s in pid_to_wishlist.get(pid, []) if s != sid]
        if len(wishlist) >= wishlist_size:
            wishlist = random.sample(wishlist, wishlist_size)
        else:
            wishlist += [-1] * (wishlist_size - len(wishlist))
        wishlists.append(torch.tensor(wishlist, dtype=torch.int))

    if contexts:
        pid_context = torch.stack(contexts)     # (N, K, 2)
        pid_wishlist = torch.stack(wishlists)   # (N, W)
    else:
        # torch.stack rejects an empty list; produce empty tensors of the same rank.
        pid_context = torch.empty((0, context_size, 2), dtype=torch.int)
        pid_wishlist = torch.empty((0, wishlist_size), dtype=torch.int)

    dataset = torch.utils.data.TensorDataset(sids, pids, ratings, pid_context, pid_wishlist)
    torch.save(dataset, save_path)

    return dataset


### Combining SID and PID Predictions

This function constructs a PyTorch dataset by merging predictions from both the SID-based and PID-based models with the original ratings. Each entry in the resulting dataset contains:
- `sid`: Scientist ID
- `pid`: Paper ID
- `rating`: Actual rating provided
- `predicted_sid`: Prediction from the SID-based model
- `predicted_pid`: Prediction from the PID-based model

The final dataset is saved for efficient reuse.


In [None]:
def get_dataset_combined(rating_file: str, pred_sid_file: str, pred_pid_file: str, save_path: str) -> torch.utils.data.Dataset:
    """
    Join the SID- and PID-model predictions with the ground-truth ratings.

    All four arguments are file names relative to ``DATA_DIR``. The prediction
    files must have ``sid``, ``pid`` and ``predicted`` columns. The rating file
    must have a combined ``sid_pid`` key and a ``rating`` column. Only the
    ``(sid, pid)`` pairs present in all three files are kept (inner joins).

    The result is cached at ``DATA_DIR/save_path``. If that file already
    exists, it is returned as-is and the CSV files are not read.

    Returns:
        ``TensorDataset(sids, pids, ratings, preds_sid, preds_pid)``.

    Raises:
        ValueError: if the three files share no ``(sid, pid)`` pair.
    """
    full_save_path = os.path.join(DATA_DIR, save_path)

    if os.path.exists(full_save_path):
        print(f"Loading dataset from {full_save_path}")
        # Pickled TensorDataset written below. torch.load's weights-only default
        # (torch >= 2.6) would reject it.
        return torch.load(full_save_path, weights_only=False)

    rating_df = pd.read_csv(os.path.join(DATA_DIR, rating_file))
    pred_sid_df = pd.read_csv(os.path.join(DATA_DIR, pred_sid_file))
    pred_pid_df = pd.read_csv(os.path.join(DATA_DIR, pred_pid_file))

    # "sid_pid" -> integer "sid" and "pid" columns
    rating_df[["sid", "pid"]] = rating_df["sid_pid"].str.split("_", expand=True)
    rating_df = rating_df.drop(columns=["sid_pid"])
    rating_df["sid"] = rating_df["sid"].astype(int)
    rating_df["pid"] = rating_df["pid"].astype(int)

    # Both prediction files name their score "predicted"; the suffixes turn
    # them into "predicted_sid" / "predicted_pid".
    merged_df = pd.merge(pred_sid_df, pred_pid_df, on=["sid", "pid"], how="inner", suffixes=("_sid", "_pid"))
    merged_df = pd.merge(merged_df, rating_df, on=["sid", "pid"], how="inner")

    if merged_df.empty:
        # An empty dataset would otherwise only surface later as a division
        # by zero in the training loop.
        raise ValueError(
            f"No (sid, pid) pairs shared by {rating_file}, {pred_sid_file} and {pred_pid_file}"
        )

    sids = torch.tensor(merged_df["sid"].values, dtype=torch.int)
    pids = torch.tensor(merged_df["pid"].values, dtype=torch.int)
    ratings = torch.tensor(merged_df["rating"].values, dtype=torch.float)
    preds_sid = torch.tensor(merged_df["predicted_sid"].values, dtype=torch.float)
    preds_pid = torch.tensor(merged_df["predicted_pid"].values, dtype=torch.float)

    dataset = torch.utils.data.TensorDataset(sids, pids, ratings, preds_sid, preds_pid)

    torch.save(dataset, full_save_path)
    print(f"Combined dataset saved to {full_save_path}")

    return dataset


## Attention-Based Recommender Models

These two models implement attention-based collaborative filtering using both scientist (SID) and paper (PID) embeddings:

- `AttentionRecommenderSID`: Focuses on attention over past paper ratings of each scientist.
- `AttentionRecommenderPID`: Focuses on attention over scientist interactions for each paper.

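Each model:
- Embeds scientists, papers, and ratings.
- Uses multi-head cross-attention. The target entity (the paper for the SID model, the scientist for the PID model) and the wishlist entries query the context; the target then attends over the wishlist.
- Applies a learned sigmoid gate to combine the two attention outputs.
- Predicts the final rating with a two-layer MLP head whose input also includes the dot product of the scientist and paper embeddings.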


In [None]:
class AttentionRecommenderSID(nn.Module):
    """
    Attention-based rating predictor centred on the scientist (SID-centric).

    Per example, the model sees the target ``sid`` and ``pid``, a context of
    ``(pid, rating)`` pairs for other papers the scientist rated, and paper ids
    from the scientist's wishlist.

    Padding: context and wishlist paper ids equal to ``-1`` mark empty slots.
    ``nn.Embedding`` does not accept negative indices (IndexError on CPU,
    device-side assert on CUDA). Such ids are therefore remapped to the extra
    padding row ``num_pids`` (the row ``padding_idx=-1`` refers to) before the
    lookup. Context ratings use ``0`` as padding, which is the
    ``rating_embedding`` ``padding_idx``.
    """

    def __init__(self, num_sids: int, num_pids: int, emb_dim: int, dropout_rate: float, num_heads: int):
        super().__init__()

        # Row index of the padding entry in `pid_embedding`. This is a plain int
        # attribute, so it does not change the state dict.
        self.pid_pad_idx = num_pids

        # Embedding layers (`pid_embedding` has one extra row for padding;
        # ratings are presumably 1..5, with 0 reserved for padding)
        self.sid_embedding = nn.Embedding(num_sids, emb_dim)
        self.pid_embedding = nn.Embedding(num_pids + 1, emb_dim, padding_idx=-1)
        self.rating_embedding = nn.Embedding(6, emb_dim, padding_idx=0)

        # Layer norms for stabilization
        self.norm_sid = nn.LayerNorm(emb_dim)
        self.norm_pid = nn.LayerNorm(emb_dim)
        self.norm_context_pid = nn.LayerNorm(emb_dim)
        self.norm_rating = nn.LayerNorm(emb_dim)
        self.output_norm = nn.LayerNorm(emb_dim)

        # Attention layers
        self.context_attention = nn.MultiheadAttention(emb_dim, num_heads, dropout=dropout_rate, batch_first=True)
        self.wishlist_attention = nn.MultiheadAttention(emb_dim, num_heads, dropout=dropout_rate, batch_first=True)

        # Projection layers
        self.proj_attn_1 = nn.Linear(emb_dim, emb_dim)
        self.proj_attn_2 = nn.Linear(emb_dim, emb_dim)

        self.dropout = nn.Dropout(dropout_rate)

        # Gating and final prediction layers
        self.gate_layer = nn.Linear(2 * emb_dim, emb_dim)
        self.fc1 = nn.Linear(emb_dim + 1, emb_dim)
        self.fc2 = nn.Linear(emb_dim, 1)

    def _pid_index(self, ids: torch.Tensor) -> torch.Tensor:
        """Return paper ids as int64 with negative (padding) ids mapped to the padding row."""
        ids = ids.long()
        return ids.masked_fill(ids < 0, self.pid_pad_idx)

    def forward(self, sid, pid, context_pid_ratings, wishlist):
        """
        Args:
            sid: (B,) scientist ids
            pid: (B,) target paper ids
            context_pid_ratings: (B, K, 2), each entry (pid, rating); padding (-1, 0)
            wishlist: (B, W) paper ids; padding -1
        Returns:
            Predicted rating: (B,) (not clamped)
        """
        context_pids = self._pid_index(context_pid_ratings[:, :, 0])    # (B, K)
        context_ratings = context_pid_ratings[:, :, 1].long()           # (B, K)
        wishlist_pids = self._pid_index(wishlist)                       # (B, W)

        # Embedding lookups + normalization
        sid_embed = self.norm_sid(self.sid_embedding(sid))      # (B, D)
        pid_embed = self.norm_pid(self.pid_embedding(pid))      # (B, D)
        wishlist_embed = self.norm_pid(self.pid_embedding(wishlist_pids))                   # (B, W, D)
        context_pid_embed = self.norm_context_pid(self.pid_embedding(context_pids))         # (B, K, D)
        context_rating_embed = self.norm_rating(self.rating_embedding(context_ratings))     # (B, K, D)

        # Cross-attention: target paper + wishlist papers query the context
        # (keys: context papers, values: their ratings).
        # NOTE(review): no key_padding_mask is passed, so padded context slots
        # still receive attention weight.
        query = torch.cat([pid_embed.unsqueeze(1), wishlist_embed], dim=1)      # (B, W+1, D)
        context_attn, _ = self.context_attention(query, context_pid_embed, context_rating_embed)    # (B, W+1, D)
        context_attn = self.output_norm(self.proj_attn_1(context_attn))         # (B, W+1, D)

        rating_pid = context_attn[:, 0, :]          # (B, D)
        ratings_wishlist = context_attn[:, 1:, :]   # (B, W, D)

        # Target paper attends over the wishlist papers' context summaries
        rating_wishlist, _ = self.wishlist_attention(pid_embed.unsqueeze(1), wishlist_embed, ratings_wishlist)  # (B, 1, D)
        rating_wishlist = self.proj_attn_2(rating_wishlist.squeeze(1))      # (B, D)

        # Fusion via a sigmoid gate conditioned on the (sid, pid) embeddings
        gate = torch.sigmoid(self.gate_layer(torch.cat([sid_embed, pid_embed], dim=1)))
        fused = gate * rating_pid + (1 - gate) * rating_wishlist        # (B, D)

        # Dot-product interaction term
        bias = torch.sum(sid_embed * pid_embed, dim=-1, keepdim=True)   # (B, 1)

        x = torch.cat([bias, fused], dim=1)         # (B, D+1)
        x = F.relu(self.fc1(self.dropout(x)))       # (B, D)
        x = self.fc2(self.dropout(x)).squeeze(1)    # (B,)

        return x


class AttentionRecommenderPID(nn.Module):
    """
    Attention-based rating predictor centred on the paper (PID-centric).

    Per example, the model sees the target ``sid`` and ``pid``, a context of
    ``(sid, rating)`` pairs for other scientists who rated the paper, and the
    ids of scientists who have the paper on their wishlist.

    Padding: context and wishlist scientist ids equal to ``-1`` mark empty
    slots. ``nn.Embedding`` does not accept negative indices (IndexError on CPU,
    device-side assert on CUDA). Such ids are therefore remapped to the extra
    padding row ``num_sids`` (the row ``padding_idx=-1`` refers to) before the
    lookup. Context ratings use ``0`` as padding, which is the
    ``rating_embedding`` ``padding_idx``.
    """

    def __init__(self, num_sids: int, num_pids: int, emb_dim: int, dropout_rate: float, num_heads: int):
        super().__init__()

        # Row index of the padding entry in `sid_embedding`. This is a plain int
        # attribute, so it does not change the state dict.
        self.sid_pad_idx = num_sids

        # Embedding layers (`sid_embedding` has one extra row for padding;
        # ratings are presumably 1..5, with 0 reserved for padding)
        self.sid_embedding = nn.Embedding(num_sids + 1, emb_dim, padding_idx=-1)
        self.pid_embedding = nn.Embedding(num_pids, emb_dim)
        self.rating_embedding = nn.Embedding(6, emb_dim, padding_idx=0)

        # Layer norms for stabilization
        self.norm_sid = nn.LayerNorm(emb_dim)
        self.norm_pid = nn.LayerNorm(emb_dim)
        self.norm_context_sid = nn.LayerNorm(emb_dim)
        self.norm_rating = nn.LayerNorm(emb_dim)
        self.output_norm = nn.LayerNorm(emb_dim)

        # Attention layers
        self.context_attention = nn.MultiheadAttention(emb_dim, num_heads, dropout=dropout_rate, batch_first=True)
        self.wishlist_attention = nn.MultiheadAttention(emb_dim, num_heads, dropout=dropout_rate, batch_first=True)

        # Projection layers
        self.proj_attn_1 = nn.Linear(emb_dim, emb_dim)
        self.proj_attn_2 = nn.Linear(emb_dim, emb_dim)

        self.dropout = nn.Dropout(dropout_rate)

        # Gating and final prediction layers
        self.gate_layer = nn.Linear(2 * emb_dim, emb_dim)
        self.fc1 = nn.Linear(emb_dim + 1, emb_dim)
        self.fc2 = nn.Linear(emb_dim, 1)

    def _sid_index(self, ids: torch.Tensor) -> torch.Tensor:
        """Return scientist ids as int64 with negative (padding) ids mapped to the padding row."""
        ids = ids.long()
        return ids.masked_fill(ids < 0, self.sid_pad_idx)

    def forward(self, sid, pid, context_pid_ratings, wishlist):
        """
        Args:
            sid: (B,) target scientist ids
            pid: (B,) paper ids
            context_pid_ratings: (B, K, 2), each entry (sid, rating); padding (-1, 0).
                (The parameter name is kept for call compatibility.)
            wishlist: (B, W) scientist ids; padding -1
        Returns:
            Predicted rating: (B,) (not clamped)
        """
        context_sids = self._sid_index(context_pid_ratings[:, :, 0])    # (B, K)
        context_ratings = context_pid_ratings[:, :, 1].long()           # (B, K)
        wishlist_sids = self._sid_index(wishlist)                       # (B, W)

        # Embedding lookups + normalization
        sid_embed = self.norm_sid(self.sid_embedding(sid))      # (B, D)
        pid_embed = self.norm_pid(self.pid_embedding(pid))      # (B, D)
        wishlist_embed = self.norm_sid(self.sid_embedding(wishlist_sids))                   # (B, W, D)
        context_sid_embed = self.norm_context_sid(self.sid_embedding(context_sids))         # (B, K, D)
        context_rating_embed = self.norm_rating(self.rating_embedding(context_ratings))     # (B, K, D)

        # Cross-attention: target scientist + wishlist scientists query the
        # context (keys: context scientists, values: their ratings).
        # NOTE(review): no key_padding_mask is passed, so padded context slots
        # still receive attention weight.
        query = torch.cat([sid_embed.unsqueeze(1), wishlist_embed], dim=1)      # (B, W+1, D)
        context_attn, _ = self.context_attention(query, context_sid_embed, context_rating_embed)    # (B, W+1, D)
        context_attn = self.output_norm(self.proj_attn_1(context_attn))         # (B, W+1, D)

        rating_sid = context_attn[:, 0, :]          # (B, D)
        ratings_wishlist = context_attn[:, 1:, :]   # (B, W, D)

        # Target scientist attends over the wishlist scientists' context summaries
        rating_wishlist, _ = self.wishlist_attention(sid_embed.unsqueeze(1), wishlist_embed, ratings_wishlist)  # (B, 1, D)
        rating_wishlist = self.proj_attn_2(rating_wishlist.squeeze(1))      # (B, D)

        # Fusion via a sigmoid gate conditioned on the (sid, pid) embeddings
        gate = torch.sigmoid(self.gate_layer(torch.cat([sid_embed, pid_embed], dim=1)))
        fused = gate * rating_sid + (1 - gate) * rating_wishlist    # (B, D)

        # Dot-product interaction term
        bias = torch.sum(sid_embed * pid_embed, dim=-1, keepdim=True)   # (B, 1)

        x = torch.cat([bias, fused], dim=1)         # (B, D+1)
        x = F.relu(self.fc1(self.dropout(x)))       # (B, D)
        x = self.fc2(self.dropout(x)).squeeze(1)    # (B,)

        return x


In [None]:
def train_model(model, optim, device, epochs, train_loader, valid_loader):
    """
    Train ``model`` with an MSE objective and report RMSE after every epoch.

    Both loaders must yield 5-tuples ``(sid, pid, rating, a, b)``. The model is
    called as ``model(sid, pid, a, b)`` and must return one prediction per
    example, matching ``rating``'s shape.

    Args:
        model: the module to train (updated in place).
        optim: optimizer over ``model``'s parameters.
        device: ``torch.device`` (or device string) the batches are moved to.
        epochs: number of passes over ``train_loader``.
        train_loader: training batches.
        valid_loader: validation batches, evaluated after each epoch.

    Returns:
        list of ``(train_rmse, valid_rmse)`` tuples, one per epoch. The train
        RMSE is accumulated over the epoch's mini-batches in training mode
        (dropout active, weights changing). An empty loader yields ``nan``.
    """

    def rmse(sum_sq: float, count: int) -> float:
        return (sum_sq / count) ** 0.5 if count else float("nan")

    gc.collect()
    # torch.mps is only usable with Apple's MPS backend; skip it elsewhere
    # instead of failing on CPU/CUDA machines.
    if torch.device(device).type == "mps":
        torch.mps.empty_cache()

    history = []
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        total_data = 0

        # Training loop
        for sid, pid, rating, context, wishlist in train_loader:
            sid, pid, rating = sid.to(device), pid.to(device), rating.to(device)
            context, wishlist = context.to(device), wishlist.to(device)

            predictions = model(sid, pid, context, wishlist)
            loss = F.mse_loss(predictions, rating)

            optim.zero_grad()
            loss.backward()
            optim.step()

            # mse_loss averages over the batch; re-weight by batch size so the
            # epoch figure is a per-example mean even with a short last batch.
            total_loss += loss.item() * sid.size(0)
            total_data += sid.size(0)

        # Validation loop
        model.eval()
        total_val_loss = 0.0
        total_val_data = 0

        with torch.no_grad():
            for sid, pid, rating, context, wishlist in valid_loader:
                sid, pid, rating = sid.to(device), pid.to(device), rating.to(device)
                context, wishlist = context.to(device), wishlist.to(device)

                predictions = model(sid, pid, context, wishlist)
                mse = F.mse_loss(predictions, rating)

                total_val_loss += mse.item() * sid.size(0)
                total_val_data += sid.size(0)

        train_rmse = rmse(total_loss, total_data)
        val_rmse = rmse(total_val_loss, total_val_data)
        history.append((train_rmse, val_rmse))

        print(f"[Epoch {epoch + 1}/{epochs}] Train RMSE={train_rmse:.3f}, Valid RMSE={val_rmse:.3f}")

    return history


`CrossLayer`: A single layer of feature crossing, based on:

$$
\mathbf{x}_{l+1} = \mathbf{x}_0 \cdot (\mathbf{w}^T \mathbf{x}_l) + \mathbf{b} + \mathbf{x}_l
$$

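* Each layer raises the degree of the explicitly modeled feature interactions by one; a single layer captures pairwise interactions.
* Layers are applied in sequence, each receiving the original input $\mathbf{x}_0$ and the previous layer's output, so that higher-order cross terms are modeled explicitly.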


`RecommenderFinal`: It combines: 
* **Embeddings** for `sid` and `pid`.
* **Cross Layers** for explicit feature interactions.
* **Deep MLP** for non-linear modeling.

As additional inputs, it uses the predictions of the two attention-based recommenders above.
* **Inputs**: `sid`, `pid`, `pred_sid`, `pred_pid`
* **Output**: Rating prediction, clamped to $[1, 5]$


In [None]:
class CrossLayer(nn.Module):
    """
    A single explicit feature-crossing layer.

    For each row, it computes ``x0 * (w^T x) + b + x``, where ``x0`` is the
    network input and ``x`` is the previous layer's output (both ``(B, D)``).
    """

    def __init__(self, input_dim):
        super().__init__()
        # Learnable cross weight `w` and bias `b`, both of shape (D,) and drawn
        # from a standard normal (weight first, then bias).
        # NOTE(review): unit-variance init for both w and b; confirm intended.
        self.weight = nn.Parameter(torch.randn(input_dim))
        self.bias = nn.Parameter(torch.randn(input_dim))

    def forward(self, x0, x):
        """Return ``x0 * (w^T x) + b + x`` for every row of the batch."""
        projection = (x * self.weight).sum(dim=1, keepdim=True)  # (B, 1)
        return x0 * projection + self.bias + x                    # (B, D)


class RecommenderFinal(nn.Module):
    """
    Stacking recommender on top of the two attention models' predictions.

    The feature vector ``[sid embedding | pid embedding | pred_sid | pred_pid]``
    goes through a stack of :class:`CrossLayer` s and then through an MLP. The
    resulting score is clamped to ``[1, 5]``.
    """

    def __init__(self, num_sids, num_pids, emb_dim, hidden_dim, num_cross_layers, dropout_rate):
        super().__init__()

        self.sid_embedding = nn.Embedding(num_sids, emb_dim)
        self.pid_embedding = nn.Embedding(num_pids, emb_dim)

        # Two embeddings plus the two scalar predictions
        feature_dim = emb_dim + emb_dim + 2

        self.cross_layers = nn.ModuleList(
            [CrossLayer(feature_dim) for _ in range(num_cross_layers)]
        )

        half_hidden = hidden_dim // 2
        self.deep = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, half_hidden),
            nn.ReLU(),
            nn.Linear(half_hidden, 1),  # single regression output
        )

    def forward(self, sid, pid, pred_sid, pred_pid):
        """
        Args:
            sid, pid: (B,) ids
            pred_sid, pred_pid: (B,) predictions of the SID / PID models
        Returns:
            (B,) ratings clamped to [1, 5]
        """
        x0 = torch.cat(
            [
                self.sid_embedding(sid),
                self.pid_embedding(pid),
                pred_sid.unsqueeze(1),
                pred_pid.unsqueeze(1),
            ],
            dim=1,
        )  # (B, 2D+2)

        crossed = x0
        for cross in self.cross_layers:
            crossed = cross(x0, crossed)

        # NOTE: clamping inside forward also zeroes the gradient for
        # predictions that fall outside [1, 5] during training.
        return self.deep(crossed).squeeze(-1).clamp(1, 5)


## Model instantiation and training

In [None]:
# Prefer Apple's MPS backend, then CUDA, and fall back to the CPU instead of
# failing on machines without MPS.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model_sid = AttentionRecommenderSID(num_sids=NUM_SCIENTISTS, num_pids=NUM_PAPERS, emb_dim=EMBEDDING_DIM, dropout_rate=DROPOUT_RATE, num_heads=NUM_HEADS).to(device)
optim_sid = torch.optim.Adam(model_sid.parameters(), lr=LEARNING_RATE, weight_decay=L2_REG)

model_pid = AttentionRecommenderPID(num_sids=NUM_SCIENTISTS, num_pids=NUM_PAPERS, emb_dim=EMBEDDING_DIM, dropout_rate=DROPOUT_RATE, num_heads=NUM_HEADS).to(device)
optim_pid = torch.optim.Adam(model_pid.parameters(), lr=LEARNING_RATE, weight_decay=L2_REG)

In [None]:
# 75/25 train/validation split of the ratings, plus the wishlist table.
(train_df, valid_df, tbr_df) = read_data_df()

In [None]:
# SID-centric datasets (cached under data/) and their loaders.
# The construction order is kept as-is: each uncached build draws from
# Python's `random` module.
train_dataset_sid = get_dataset_sid(
    df=train_df,
    wishlist_df=tbr_df,
    context_size=SID_CONTEXT_SIZE,
    wishlist_size=SID_WISHLIST_SIZE,
    save_path="data/sid_train_dataset",
)
valid_dataset_sid = get_dataset_sid(
    df=valid_df,
    wishlist_df=tbr_df,
    context_size=SID_CONTEXT_SIZE,
    wishlist_size=SID_WISHLIST_SIZE,
    save_path="data/sid_valid_dataset",
)

train_loader_sid = torch.utils.data.DataLoader(train_dataset_sid, batch_size=BATCH_SIZE, shuffle=True)
valid_loader_sid = torch.utils.data.DataLoader(valid_dataset_sid, batch_size=BATCH_SIZE, shuffle=False)

# PID-centric datasets and their loaders.
train_dataset_pid = get_dataset_pid(
    df=train_df,
    wishlist_df=tbr_df,
    context_size=PID_CONTEXT_SIZE,
    wishlist_size=PID_WISHLIST_SIZE,
    save_path="data/pid_train_dataset",
)
valid_dataset_pid = get_dataset_pid(
    df=valid_df,
    wishlist_df=tbr_df,
    context_size=PID_CONTEXT_SIZE,
    wishlist_size=PID_WISHLIST_SIZE,
    save_path="data/pid_valid_dataset",
)

train_loader_pid = torch.utils.data.DataLoader(train_dataset_pid, batch_size=BATCH_SIZE, shuffle=True)
valid_loader_pid = torch.utils.data.DataLoader(valid_dataset_pid, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Fit the SID-centric model, checkpoint its weights, and export its
# predictions for both splits (they become inputs of the combined model).
train_model(
    model=model_sid,
    optim=optim_sid,
    device=device,
    epochs=EPOCHS,
    train_loader=train_loader_sid,
    valid_loader=valid_loader_sid,
)
torch.save(model_sid.state_dict(), "attention_recommender_sid.pt")

save_predictions_csv(
    model_sid,
    train_loader_sid,
    valid_loader_sid,
    device,
    train_output_file="data/sid_train_predictions.csv",
    valid_output_file="data/sid_valid_predictions.csv",
)

In [None]:
# Fit the PID-centric model, checkpoint its weights, and export its
# predictions for both splits (they become inputs of the combined model).
train_model(
    model=model_pid,
    optim=optim_pid,
    device=device,
    epochs=EPOCHS,
    train_loader=train_loader_pid,
    valid_loader=valid_loader_pid,
)
torch.save(model_pid.state_dict(), "attention_recommender_pid.pt")

save_predictions_csv(
    model_pid,
    train_loader_pid,
    valid_loader_pid,
    device,
    train_output_file="data/pid_train_predictions.csv",
    valid_output_file="data/pid_valid_predictions.csv",
)

In [None]:
# Join each split's SID/PID predictions with the true ratings.
# NOTE(review): the train-split predictions come from models fitted on that
# same split (in-sample), while the validation ones are out-of-sample, so the
# combined model may learn to over-trust the base predictions. Out-of-fold
# predictions for the training split would avoid this.
train_dataset_combined = get_dataset_combined(
    rating_file="train_ratings.csv",
    pred_sid_file="sid_train_predictions.csv",
    pred_pid_file="pid_train_predictions.csv",
    save_path="combined_train_dataset",
)
valid_dataset_combined = get_dataset_combined(
    rating_file="train_ratings.csv",
    pred_sid_file="sid_valid_predictions.csv",
    pred_pid_file="pid_valid_predictions.csv",
    save_path="combined_valid_dataset",
)

train_loader_combined = torch.utils.data.DataLoader(train_dataset_combined, batch_size=BATCH_SIZE, shuffle=True)
valid_loader_combined = torch.utils.data.DataLoader(valid_dataset_combined, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Stacking model on top of the two attention recommenders' predictions.
model_combined = RecommenderFinal(
    num_sids=NUM_SCIENTISTS,
    num_pids=NUM_PAPERS,
    emb_dim=64,
    hidden_dim=128,
    num_cross_layers=2,
    dropout_rate=0.2,
).to(device)
optim_combined = torch.optim.Adam(
    model_combined.parameters(),
    lr=LEARNING_RATE,
    weight_decay=L2_REG,
)

In [None]:
# The combined loaders yield (sid, pid, rating, pred_sid, pred_pid) batches, and
# the model is called as model_combined(sid, pid, pred_sid, pred_pid). That is
# exactly the batch layout and call convention train_model expects, so the
# shared loop is reused instead of a hand-copied one. train_model also does the
# memory clean-up and reports train and validation RMSE on the same scale (the
# copied loop printed train MSE next to validation RMSE).
train_model(
    model_combined,
    optim_combined,
    device,
    EPOCHS,
    train_loader_combined,
    valid_loader_combined,
)