In [1]:
import os
from typing import Callable, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import root_mean_squared_error
from sklearn.model_selection import train_test_split

In [2]:
SEED = 42  # single source of randomness for the whole notebook

# torch: parameter initialisation and DataLoader shuffling.
# numpy: its global RNG is what train_test_split falls back to when called without random_state.
torch.manual_seed(SEED)
np.random.seed(SEED)

In [3]:
DATA_DIR = ""  # directory holding the competition CSVs; "" resolves paths relative to the working directory


def read_data_df(test_size: float = 0.01) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Reads the observed ratings and splits them into training and validation sets.

    Inputs:
        test_size: Fraction of rows held out for validation. The default 0.01 gives a
            99/1 train/validation split. The split is random and, because no
            random_state is passed, draws from numpy's global RNG.

    Outputs: (train_df, valid_df). Both have integer ``sid`` and ``pid`` columns plus the
        remaining columns of ``train_ratings.csv`` (e.g. ``rating``).
    """

    df = pd.read_csv(os.path.join(DATA_DIR, "train_ratings.csv"))

    # Each rating is keyed by a combined "sid_pid" string such as "12_345";
    # split it into two integer id columns.
    df[["sid", "pid"]] = df["sid_pid"].str.split("_", expand=True)
    df = df.drop("sid_pid", axis=1)
    df["sid"] = df["sid"].astype(int)
    df["pid"] = df["pid"].astype(int)

    train_df, valid_df = train_test_split(df, test_size=test_size)
    return train_df, valid_df


def read_data_matrix(
    df: pd.DataFrame,
    num_scientists: Optional[int] = None,
    num_papers: Optional[int] = None,
) -> np.ndarray:
    """Returns a dense matrix view of the ratings in ``df``.

    Rows are scientists (sid) and columns are papers (pid). Cells without a rating are NaN.

    Inputs:
        df: Frame with ``sid``, ``pid`` and ``rating`` columns and at most one row per
            (sid, pid) pair (``DataFrame.pivot`` raises on duplicates).
        num_scientists, num_papers: If given, the matrix is reindexed to rows
            0..num_scientists-1 and/or columns 0..num_papers-1. Row i is then sid i and
            column j is pid j, even when some ids have no rating in ``df``. If omitted
            (the default), only ids present in ``df`` appear, in sorted order.

    Outputs: 2-D float array.

    Raises:
        ValueError: if an id in ``df`` falls outside a requested range. Reindexing would
            otherwise drop it silently.
    """

    if num_scientists is not None and not df["sid"].between(0, num_scientists - 1).all():
        raise ValueError(f"sid values must lie in [0, {num_scientists})")
    if num_papers is not None and not df["pid"].between(0, num_papers - 1).all():
        raise ValueError(f"pid values must lie in [0, {num_papers})")

    mat = df.pivot(index="sid", columns="pid", values="rating")
    if num_scientists is not None:
        mat = mat.reindex(index=range(num_scientists))
    if num_papers is not None:
        mat = mat.reindex(columns=range(num_papers))
    return mat.values


def evaluate(valid_df: pd.DataFrame, pred_fn: Callable[[np.ndarray, np.ndarray], np.ndarray]) -> float:
    """Scores a predictor against held-out ratings.

    Inputs:
        valid_df: Frame with ``sid``, ``pid`` and ``rating`` columns, for example the
            validation part returned by read_data_df.
        pred_fn: Maps (array of sids, array of pids) to an array of predicted ratings.

    Outputs: Root-mean-squared error between ``valid_df["rating"]`` and the predictions.
    """
    sid_arr = valid_df["sid"].values
    pid_arr = valid_df["pid"].values
    predicted = pred_fn(sid_arr, pid_arr)

    observed = valid_df["rating"].values
    return root_mean_squared_error(observed, predicted)


def make_submission(pred_fn: Callable[[np.ndarray, np.ndarray], np.ndarray], filename: os.PathLike):
    """Writes a Kaggle submission CSV.

    The (sid, pid) pairs to score come from ``sample_submission.csv`` in DATA_DIR. The
    frame's ``rating`` column is set to ``pred_fn``'s output and saved without the index.

    Inputs:
        pred_fn: Maps (array of sids, array of pids) to an array of scores.
        filename: Destination path of the CSV.
    """
    submission = pd.read_csv(os.path.join(DATA_DIR, "sample_submission.csv"))

    # "sid_pid" strings such as "12_345" -> two integer id arrays
    id_parts = submission["sid_pid"].str.split("_", expand=True)
    sid_arr = id_parts[0].astype(int).values
    pid_arr = id_parts[1].astype(int).values

    submission["rating"] = pred_fn(sid_arr, pid_arr)
    submission.to_csv(filename, index=False)

In [4]:
# Load the observed ratings and hold out a small random validation split (ratio set in read_data_df).
train_df, valid_df = read_data_df()

In [5]:
# Prefer the GPU when one is visible to torch; everything below moves tensors to `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using: {device}")

Using: cuda


In [6]:
def get_dataset(df: pd.DataFrame) -> torch.utils.data.Dataset:
    """Wraps the ``sid``, ``pid`` and ``rating`` columns of ``df`` in a TensorDataset.

    Each item is a (sid, pid, rating) triple. The ids keep their numpy integer dtype and
    the ratings are cast to float32.
    """
    columns = [torch.from_numpy(df[name].to_numpy()) for name in ("sid", "pid", "rating")]
    columns[-1] = columns[-1].float()
    return torch.utils.data.TensorDataset(*columns)

In [7]:
def read_data_matrix(
    df: pd.DataFrame,
    num_scientists: Optional[int] = None,
    num_papers: Optional[int] = None,
) -> np.ndarray:
    """Returns a dense matrix view of the ratings in ``df``.

    Rows are scientists (sid) and columns are papers (pid). Cells without a rating are NaN.

    Inputs:
        df: Frame with ``sid``, ``pid`` and ``rating`` columns and at most one row per
            (sid, pid) pair (``DataFrame.pivot`` raises on duplicates).
        num_scientists, num_papers: If given, the matrix is reindexed to rows
            0..num_scientists-1 and/or columns 0..num_papers-1. Row i is then sid i and
            column j is pid j, even when some ids have no rating in ``df``. If omitted
            (the default), only ids present in ``df`` appear, in sorted order.

    Outputs: 2-D float array.

    Raises:
        ValueError: if an id in ``df`` falls outside a requested range. Reindexing would
            otherwise drop it silently.
    """

    if num_scientists is not None and not df["sid"].between(0, num_scientists - 1).all():
        raise ValueError(f"sid values must lie in [0, {num_scientists})")
    if num_papers is not None and not df["pid"].between(0, num_papers - 1).all():
        raise ValueError(f"pid values must lie in [0, {num_papers})")

    mat = df.pivot(index="sid", columns="pid", values="rating")
    if num_scientists is not None:
        mat = mat.reindex(index=range(num_scientists))
    if num_papers is not None:
        mat = mat.reindex(columns=range(num_papers))
    return mat.values


def impute_values(mat: np.ndarray) -> np.ndarray:
    """Returns a copy of ``mat`` with NaN (unobserved) cells replaced by 0.0.

    Note: ``np.nan_to_num`` also maps +inf/-inf to the largest/smallest finite value of the dtype.
    """
    return np.nan_to_num(mat, nan=0.0)

In [8]:
# Dense scientist x paper rating matrix built from the training split only; unobserved
# cells become 0 after imputation.
# The model looks rows/columns up by raw id (Y[sid, :], Y[:, pid]). The matrix is therefore
# reindexed to the full id range: an id with no training rating would otherwise be dropped
# by pivot and shift every later row/column.
N_SCIENTISTS, N_PAPERS = 10_000, 1_000
assert train_df["sid"].between(0, N_SCIENTISTS - 1).all(), f"sid outside [0, {N_SCIENTISTS})"
assert train_df["pid"].between(0, N_PAPERS - 1).all(), f"pid outside [0, {N_PAPERS})"
Y = (
    train_df.pivot(index="sid", columns="pid", values="rating")
    .reindex(index=range(N_SCIENTISTS), columns=range(N_PAPERS))
    .to_numpy()
)
Y = impute_values(Y)

In [9]:
def read_data_tbr() -> pd.DataFrame:
    """Loads the to-be-read (wishlist) pairs stored in ``train_tbr.csv`` under DATA_DIR."""
    tbr_path = os.path.join(DATA_DIR, "train_tbr.csv")
    return pd.read_csv(tbr_path)

In [10]:
wishlist_df = read_data_tbr()
wishlist_df["rating"] = 1  # every listed (sid, pid) pair counts as an implicit positive

# Binary scientist x paper wishlist matrix: 1 where the paper is on the scientist's
# to-be-read list, 0 elsewhere. Reindexing to the full id range keeps row i == sid i and
# column j == pid j for scientists AND papers that have no wishlist entries. It does this
# with one vectorised step instead of a per-id scan plus row-by-row concatenation.
assert wishlist_df["sid"].between(0, 9_999).all(), "sid outside [0, 10000)"
assert wishlist_df["pid"].between(0, 999).all(), "pid outside [0, 1000)"
wishlist = (
    wishlist_df.pivot(index="sid", columns="pid", values="rating")
    .reindex(index=range(10_000), columns=range(1_000))
    .to_numpy()
)
wishlist = impute_values(wishlist)

In [18]:
class DeepMatrixFactorizationModel(nn.Module):
    """Rating predictor built from four MLP towers over two fixed side matrices.

    For a (sid, pid) pair the model reads the scientist's row and the paper's column of
    ``Y`` and of ``rating``. It encodes each of the four vectors with its own two-layer
    ReLU MLP into a ``dim``-vector, concatenates them, and maps the result to a scalar
    with ``final_nn``.

    Both matrices must have shape (num_scientists, num_papers). The first Linear layer of
    each tower expects rows of length ``num_papers`` and columns of length ``num_scientists``.

    Leakage note: if ``Y`` contains the ratings being trained on, the input rows of a
    training pair include that pair's own target. Held-out pairs are 0 in an imputed ``Y``.
    ``mask_target=True`` zeroes the (sid, pid) cell of ``Y`` in the inputs to remove this
    mismatch. The default False keeps the original behaviour.
    """

    def __init__(self, num_scientists: int, num_papers: int, dim: int, hidden_dim: int, Y: np.ndarray, rating: np.ndarray, mask_target: bool = False):
        super().__init__()

        # Side matrices are buffers: moved by .to(device) and saved in the state_dict, but not trained.
        self.register_buffer("Y", torch.from_numpy(Y).float())
        self.register_buffer("rating", torch.from_numpy(rating).float())
        self.mask_target = mask_target

        # Keep this construction order: under a fixed seed it determines the parameter initialisation.
        self.scientist_nn = self._tower(num_papers, hidden_dim, dim)
        self.paper_nn = self._tower(num_scientists, hidden_dim, dim)
        self.srating_nn = self._tower(num_papers, hidden_dim, dim)
        self.prating_nn = self._tower(num_scientists, hidden_dim, dim)

        self.final_nn = nn.Sequential(
            nn.Linear(4 * dim, 2 * dim),
            nn.ReLU(),
            nn.Linear(2 * dim, dim),
            nn.ReLU(),
            nn.Linear(dim, 1),
        )

    @staticmethod
    def _tower(in_features: int, hidden_dim: int, dim: int) -> nn.Sequential:
        """Two-layer ReLU MLP encoding one matrix row or column into a ``dim``-vector."""
        return nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, dim),
            nn.ReLU(),
        )

    def forward(self, sid: torch.Tensor, pid: torch.Tensor) -> torch.Tensor:
        """
        Inputs:
            sid: [B,], int -- row indices into the side matrices
            pid: [B,], int -- column indices into the side matrices

        Outputs: [B,], float -- unclamped rating predictions
        """
        # Advanced indexing returns fresh tensors, so the masking below never touches the buffers.
        scientist_row = self.Y[sid, :]        # [B, num_papers]
        paper_row = self.Y[:, pid].T          # [B, num_scientists]

        if self.mask_target:
            # Hide the value being predicted: Y[sid, pid] sits at column pid of
            # scientist_row and at column sid of paper_row (one cell per batch row).
            batch_idx = torch.arange(sid.shape[0], device=sid.device)
            scientist_row[batch_idx, pid] = 0.0
            paper_row[batch_idx, sid] = 0.0

        srating_row = self.rating[sid, :]     # [B, num_papers]
        prating_row = self.rating[:, pid].T   # [B, num_scientists]

        features = torch.cat(
            [
                self.scientist_nn(scientist_row),
                self.paper_nn(paper_row),
                self.srating_nn(srating_row),
                self.prating_nn(prating_row),
            ],
            dim=1,
        )
        # Concatenate-then-MLP head (not a dot product): [B, 4*dim] -> [B, 1] -> [B]
        return self.final_nn(features).squeeze(1)

In [19]:
# Define model (10k scientists, 1k papers, 32-dimensional embeddings) and optimizer
# Model over 10k scientists x 1k papers: each tower has 35 hidden units and a 35-dim
# output. It is conditioned on the training rating matrix Y and the wishlist matrix and
# optimised with Adam.
EMBED_DIM = 35
HIDDEN_DIM = 35
LEARNING_RATE = 5e-5
model = DeepMatrixFactorizationModel(10_000, 1_000, EMBED_DIM, HIDDEN_DIM, Y, wishlist).to(device)
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [20]:
# Mini-batch loaders over (sid, pid, rating) triples; only the training loader shuffles.
loader_opts = dict(batch_size=32, num_workers=2, pin_memory=True)
train_dataset = get_dataset(train_df)
valid_dataset = get_dataset(valid_df)
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_opts)
valid_loader = torch.utils.data.DataLoader(valid_dataset, shuffle=False, **loader_opts)

In [21]:
from tqdm.notebook import tqdm
MODEL_NAME = "deeper_final_layer"
OUTPUT_DIR = "models"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Per-epoch metrics; written to CSV after training.
history = {
    'epoch': [],
    'train_loss': [],
    'val_rmse': [],
}

# The checkpoint with the lowest validation RMSE seen so far is kept at best_model_path.
best_val_rmse = float('inf')
best_model_path = os.path.join(OUTPUT_DIR, f"{MODEL_NAME}_best_val_rmse.pth")

NUM_EPOCHS = 60
epochs = tqdm(range(NUM_EPOCHS), desc="Epochs")
for epoch in epochs:
    # ---- Train for one epoch ----
    total_loss = 0.0
    total_data = 0
    model.train()
    train_batch = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Train]", leave=False)
    for sid, pid, ratings in train_batch:
        # The loader pins host memory, so non_blocking copies can overlap with compute.
        sid = sid.to(device, non_blocking=True)
        pid = pid.to(device, non_blocking=True)
        ratings = ratings.to(device, non_blocking=True)

        # Make prediction and compute loss
        pred = model(sid, pid)
        loss = F.mse_loss(pred, ratings)

        # Compute gradients w.r.t. loss and take a step in that direction
        optim.zero_grad()
        loss.backward()
        optim.step()

        # Sample-weighted running sum, so the epoch mean is exact even with a short last batch.
        total_data += len(sid)
        total_loss += len(sid) * loss.item()

    # ---- Evaluate on validation data ----
    total_val_mse = 0.0
    total_val_data = 0
    model.eval()
    valid_batch = tqdm(valid_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Valid]", leave=False)
    # Evaluation needs no gradients; without no_grad every batch would build an autograd graph.
    with torch.no_grad():
        for sid, pid, ratings in valid_batch:
            sid = sid.to(device, non_blocking=True)
            pid = pid.to(device, non_blocking=True)
            ratings = ratings.to(device, non_blocking=True)

            # Clamp predictions to [1, 5] before scoring, since all ground-truth ratings lie in that range.
            pred = model(sid, pid).clamp(1, 5)
            mse = F.mse_loss(pred, ratings)

            # Keep track of running metrics
            total_val_data += len(sid)
            total_val_mse += len(sid) * mse.item()

    curr_rmse = (total_val_mse / total_val_data) ** 0.5
    curr_loss = total_loss / total_data
    history['epoch'].append(epoch+1)
    history['train_loss'].append(curr_loss)
    history['val_rmse'].append(curr_rmse)

    saved_this_epoch = False
    if curr_rmse < best_val_rmse:
        best_val_rmse = curr_rmse
        torch.save(model.state_dict(), best_model_path)
        saved_this_epoch = True

    postfix_str = f"Train Loss={curr_loss:.3f}, Valid RMSE={curr_rmse:.3f}, Best RMSE={best_val_rmse:.3f}"
    if saved_this_epoch:
        postfix_str += " (Saved Best Model)"
    epochs.set_postfix_str(postfix_str)

Epochs:   0%|          | 0/60 [00:00<?, ?it/s]

Epoch 1/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 1/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 2/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 2/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 3/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 3/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 4/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 4/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 5/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 5/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 6/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 6/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 7/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 7/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 8/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 8/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 9/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 9/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 10/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 10/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 11/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 11/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 12/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 12/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 13/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 13/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 14/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 14/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 15/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 15/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 16/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 16/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 17/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 17/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 18/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 18/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 19/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 19/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 20/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 20/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 21/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 21/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 22/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 22/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 23/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 23/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 24/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 24/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 25/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 25/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 26/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 26/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 27/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 27/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 28/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 28/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 29/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 29/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 30/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 30/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 31/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 31/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 32/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 32/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 33/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 33/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 34/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 34/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 35/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 35/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 36/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 36/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 37/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 37/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 38/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 38/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 39/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 39/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 40/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 40/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 41/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 41/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 42/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 42/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 43/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 43/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 44/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 44/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 45/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 45/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 46/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 46/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 47/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 47/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 48/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 48/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 49/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 49/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 50/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 50/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 51/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 51/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 52/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 52/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 53/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 53/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 54/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 54/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 55/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 55/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 56/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 56/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 57/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 57/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 58/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 58/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 59/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 59/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

Epoch 60/60 [Train]:   0%|          | 0/34904 [00:00<?, ?it/s]

Epoch 60/60 [Valid]:   0%|          | 0/353 [00:00<?, ?it/s]

In [22]:
# Save the per-epoch metrics (epoch, train_loss, val_rmse) for later plotting and comparison.
results_df = pd.DataFrame.from_dict(history)
output_dir = "training_results"
os.makedirs(output_dir, exist_ok=True)
csv_filename = os.path.join(output_dir, f"{MODEL_NAME}_metrics.csv")
results_df.to_csv(csv_filename, index=False)

In [23]:
def batched_pred_fn(sids, pids, batch_size=1024):
    """Predicts clamped ratings for (sid, pid) pairs using the best saved checkpoint.

    Side effect: loads ``OUTPUT_DIR/{MODEL_NAME}_best_val_rmse.pth`` into the global
    ``model`` and switches it to eval mode.

    Inputs:
        sids, pids: Equal-length integer numpy arrays of scientist / paper ids.
        batch_size: Number of pairs scored per forward pass.

    Outputs: numpy array with one prediction in [1, 5] per pair (empty if no pairs).
    """
    best_model_path = os.path.join(OUTPUT_DIR, f"{MODEL_NAME}_best_val_rmse.pth")
    # map_location lets a checkpoint saved on one device load onto `device`;
    # weights_only restricts unpickling to tensors and plain containers.
    model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))
    model.eval()

    num_samples = len(sids)
    if num_samples == 0:
        # np.concatenate raises on an empty list.
        return np.empty(0, dtype=np.float32)

    results = []
    # Inference only: no autograd graph, whether or not the caller already disabled grads.
    # The CUDA cache is not emptied per batch; doing so just forces re-allocation every step.
    with torch.no_grad():
        for start in range(0, num_samples, batch_size):
            stop = start + batch_size
            batch_sids = torch.from_numpy(sids[start:stop]).to(device)
            batch_pids = torch.from_numpy(pids[start:stop]).to(device)
            batch_preds = model(batch_sids, batch_pids).clamp(1, 5).cpu().numpy()
            results.append(batch_preds)

    return np.concatenate(results)

# Score the best checkpoint on the held-out validation split.
with torch.no_grad():
    torch.cuda.empty_cache()  # release cached GPU memory left over from training before scoring
    val_score = evaluate(valid_df, pred_fn=batched_pred_fn)

print(f"Validation RMSE: {val_score:.3f}")

Validation RMSE: 0.856


In [24]:
# Write the Kaggle submission file using the best checkpoint.
SUBMISSION_FILE = "deep_matrix_submission_no_dropout.csv"
with torch.no_grad():
    make_submission(pred_fn=batched_pred_fn, filename=SUBMISSION_FILE)