<a href="https://colab.research.google.com/github/caravanuden/clipme/blob/main/CS224W_ClipMe.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Clip Me! Project Code

This is the official code for our CS224W final project Clip Me! where we aim to model the Spotify Podcast Dataset as a single graph with clip and topic nodes, with Clip2Clip if the clips are from the same episode and Clip2Topic edges if the clip is labelled with that specific topic. Our end goal is to recommend clips from the graph based off of a user-given clip query. To do so, we model our task as link prediction in a semi-inductive setting and use the Clip2Clip same episode edges as a proxy for clip similarity.

Our code base uses the PyG, NetworkX and DeepSNAP libraries to model the graph and GNN. We additionally use Pytorch Lightning (with optional logging to Weights & Biases) to wrap our dataloading and training/testing to simplify our code. 

## Setup

In this section, we install all required packages, mount to Google Drive, import packages and define constants.

All the required files to run this notebook can be found [here](https://drive.google.com/drive/folders/1MpP7Odqw7o0LWS4UoHkqltSB3IiCWoRW?usp=sharing). 

### Install dependencies

In [1]:
!pip install pytorch-lightning
!pip install torch-scatter -f https://data.pyg.org/whl/torch-1.13.1+cu116.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-1.13.1+cu116.html
!pip install torch-geometric

!pip install -q git+https://github.com/snap-stanford/deepsnap.git
!pip install faiss-cpu
!pip install sentence-transformers
!pip install bertopic
!pip install wandb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-1.13.1+cu116.html
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-1.13.1+cu116.html
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
  Preparing metadata (setup.py) ... [?25l[?25hdone
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/



### Mounting on Google Drive

In [2]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

%cd /content/drive/MyDrive/"CS224W Final Project"
# %cd /content/drive/MyDrive/"[FOLDER -- TO REPLACE!]"

Mounted at /content/drive
/content/drive/.shortcut-targets-by-id/1nt4grspk05bfDsz_lJogqO137Fds1j-V/CS224W Final Project


### Imports and constants

In [3]:
import ast
import os
import copy

os.environ["TOKENIZERS_PARALLELISM"] = "false"

import collections
import pandas as pd
import torch
from typing import List, Optional
import random
import json
import string
import tqdm
import pickle
import itertools
import numpy as np

from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from bertopic import BERTopic
import nltk
from nltk.stem.porter import PorterStemmer
from nltk.corpus import stopwords
from sklearn.feature_extraction.text import CountVectorizer

from torch_geometric.data import InMemoryDataset, HeteroData, download_url, extract_zip
from torch_geometric.transforms import ToUndirected

nltk.download("stopwords")
nltk.download("punkt")
nltk_stop_words = stopwords.words("english")

# Defining set of directories
# DATA_DIR = '/content/drive/MyDrive/[FOLDER -- TO REPLACE!]'
DATA_DIR = '/content/drive/MyDrive/CS224W Final Project/'
RAW_DATA_DIR = os.path.join(DATA_DIR, "raw")
INTERMEDIATE_DATA_DIR = os.path.join(DATA_DIR, "intermediate")
PROCESSED_DATA_DIR = os.path.join(DATA_DIR, "processed")

os.makedirs(RAW_DATA_DIR, exist_ok=True)
os.makedirs(INTERMEDIATE_DATA_DIR, exist_ok=True)
os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)

N_CLIPS = 5

SENTENCE_SIMILARITY_THRESHOLD = 0.6
CLIP_LENGTH = 300

KEYBERT_VOCAB_NGRAM_RANGE = (1, 2)
BERTOPIC_MAX_DF = 0.5
BERTOPIC_MMR_DIVERSITY = 0.2
BERTOPIC_MIN_TOPIC_SIZE = 100  # or 10
BERTOPIC_MIN_SAMPLES = 1
BERTOPIC_TOP_N_WORDS = 10
BERTOPIC_REDUCE_OUTLIERS = False
BERTOPIC_REDUCE_OUTLIERS_THRESHOLD = 0.1

SUPERVISION_EDGE = ("clip", "same_episode", "clip")
LOCAL_CLIP2TOPIC_EDGE = ("clip", "has_topic_local", "topic")
GLOBAL_CLIP2TOPIC_EDGE = ("clip", "has_topic_global", "topic")

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


## Data preprocessing

Data preprocessing was quite intensive. Please refer to the full [Git repository](https://github.com/caravanuden/clipme) for these steps and scripts. It involved multiple steps:
1. Getting the [Spotify Podcast Dataset](https://podcastsdataset.byspotify.com/). You'll need to request access first! Then you can download and unzip the dataset.
2. Sampling the podcast dataset to only select English-language "interview" podcasts. Please see the script for that in `src/datasets/create_subset.py`.
3. The following steps are done automatically in `src/datasets/preprocess_clips.py`:
> a. Randomly sample 5 clips from a podcast episode and extract the transcript from each clip (`extract_clip_transcript` and lines 205-239).\
> b. Extract the KeyBERT vocabulary from the clips (`extract_vocab` and lines 245-250).\
> c. Extract the BERTopic topics from the vocab and clips (`extract_clip_topics` and lines 252-255).

At the end of this process, you'll have a set of topics and clips saved to a CSV file of your choice!

If this does not work for you, you can also use our [preprocessed PyG data slice](https://drive.google.com/file/d/1yLnOQnl6-vMFN7MqxxvpbXwBQZiRBM7P/view?usp=sharing), which integrates with our SpotifyPodcastDataset and dataloaders.

### Topic model visualizations
Now, we can visualize some of the topics in our BERTopic global model.

In [40]:
from bertopic import BERTopic
topic_model = BERTopic.load('/content/drive/MyDrive/CS224W Final Project/intermediate/bertopic_model_clips5_100_1_10')

In [41]:
# generate the inter-topic distance map
topic_model.visualize_topics()

In [42]:
# visual the inter-topic similarity matrix
topic_model.visualize_heatmap(top_n_topics=25)

In [43]:
# visualize the keywords for the top eight largest topics
topic_model.visualize_barchart()

## Data loading

In this section, we detail how we create our heterogeous graph from a pre-processed CSV file with all the clip and topic data (please see the "Data preprocessing" section for more information for how we get the CSV). Specifically, 
- **Clip node features** are intialized with the **sBERT embeddings** of the clip transcript
- **Topic node features** are **one-hot** embeddings of the topic ID
  - We use both *local* (fine-grained) and *global* (coarse-grained) topic nodes
- **clip-clip edges** are defined if both clips are from the **same episode**
- **clip-topic edges** are defined if the **clip belongs to that topic**
- **topic-topic** edges are defined by if the **topics share more than 1 keyword**, where keywords come from a pre-defined list which were derived using BERTopic. **Edge weights** are defined by the **Jaccard similarity (IOU)** between the topic keyword lists

Note: all edges are undirected.


### Helpers

In [4]:
# Jaccard similarity calculation for topic-topic edge weights
def jaccard(list1, list2):
    intersection = len(list(set(list1).intersection(list2)))
    union = (len(list1) + len(list2)) - intersection
    return float(intersection) / union


class SequenceEncoder(object):
    def __init__(self, model_name="all-MiniLM-L6-v2", device=None):
        self.device = device
        self.model = SentenceTransformer(model_name, device=device)

    @torch.no_grad()
    def __call__(self, df):
        """
        sBERT embedder that converts clip text into embeddings.
        Default sBERT model is "all-MiniLM-L6-v2".
        
        Args:
          df: N clip transcripts stored in a dataframe.
        Returns:
          (N, embed_dim) torch tensor of embeddings generated by the `model_name` sbERT model
        """
        x = self.model.encode(
            df.values,
            show_progress_bar=True,
            convert_to_tensor=True,
            device=self.device,
        )
        return x.cpu()


class OneHotEncoder(object):
    """
    Encodes topics from x into one-hot embeddings. 
    Note that topic IDs are in the range [0, num_topics].
    Args:
      x: N distinct topics (BERTopic IDs) stored in Pandas series (DataFrame col).
    Returns:
      (N, N) torch tensor of one-hot embeddings generated from the topics.
    """
    def __call__(self, x):
        return torch.eye(x.shape[0])

### Creating the heterogeneous dataset

In [5]:
class SpotifyPodcastDataset(InMemoryDataset):
    def __init__(self, root="", transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        self.raw_data = os.path.join(
            INTERMEDIATE_DATA_DIR, "episode_subset.tsv"
        )

        self.clip_path = os.path.join(
            INTERMEDIATE_DATA_DIR,
            f"processed_clips_w_global_local_topics.csv",
        )
        self.clip_mapping_path = os.path.join(
            INTERMEDIATE_DATA_DIR, f"processed_clip_nodes_mapping.pkl"
        )
        self.local_topic_mapping_path = os.path.join(
            INTERMEDIATE_DATA_DIR, f"processed_local_topic_nodes_mapping.pkl"
        )
        self.global_topic_mapping_path = os.path.join(
            INTERMEDIATE_DATA_DIR, f"processed_global_topic_nodes_mapping.pkl"
        )

        self.node_types = ["clip", "local_topic", "global_topic"]
        self.edge_types = {
            "has_topic_local": {
                "edge_name": ("clip", "has_topic_local", "local_topic"),
                "src_index_col": "episode_clip_uri",
                "src_mapping": "clip",
                "dst_index_col": "topic_id_local",
                "dst_mapping": "local_topic",
            },
            "has_topic_global": {
                "edge_name": ("clip", "has_topic_global", "global_topic"),
                "src_index_col": "episode_clip_uri",
                "src_mapping": "clip",
                "dst_index_col": "topic_id_global",
                "dst_mapping": "global_topic",
            },
            "topic_keyword_sim_local": {
                "edge_name": ("local_topic", "topic_keyword_sim_local", "local_topic"),
                "src_index_col": "topic_id_local",
                "src_keyword_col": "topic_top_n_words_local",
                "src_mapping": "local_topic",
                "dst_index_col": "topic_id_local",
                "dst_keyword_col": "topic_top_n_words_local",
                "dst_mapping": "local_topic",
            },
            "topic_keyword_sim_global": {
                "edge_name": (
                    "global_topic",
                    "topic_keyword_sim_global",
                    "global_topic",
                ),
                "src_index_col": "topic_id_global",
                "src_keyword_col": "topic_top_n_words_global",
                "src_mapping": "global_topic",
                "dst_index_col": "topic_id_global",
                "dst_keyword_col": "topic_top_n_words_global",
                "dst_mapping": "global_topic",
            },
            "topic_keyword_sim_local_global": {
                "edge_name": (
                    "local_topic",
                    "topic_keyword_sim_local_global",
                    "global_topic",
                ),
                "src_index_col": "topic_id_local",
                "src_keyword_col": "topic_top_n_words_local",
                "src_mapping": "local_topic",
                "dst_index_col": "topic_id_global",
                "dst_keyword_col": "topic_top_n_words_global",
                "dst_mapping": "global_topic",
            },
            "same_episode": {
                "edge_name": ("clip", "same_episode", "clip"),
                "src_index_col": "episode_clip_uri",
                "src_mapping": "clip",
                "dst_index_col": "episode_clip_uri",
                "dst_mapping": "clip",
            },
        }

        

        if not os.path.exists(PROCESSED_DATA_DIR + "/data_slice.p"):
            self.process()  # For debugging
        self.data, self.slices = torch.load(
            PROCESSED_DATA_DIR + "/data_slice.p"
        )
        
        self._set_node_start_indices()
        # print(self.start_indices)

    @property
    def raw_file_names(self):
        raw_data = pd.read_table(self.raw_data)
        return [f for f in raw_data.path]

    @property
    def processed_file_names(self):
        return [
            os.path.join(PROCESSED_DATA_DIR, f)
            for f in os.listdir(PROCESSED_DATA_DIR)
        ]

    def _load_df(self):
        df = pd.read_csv(
            self.clip_path, on_bad_lines="skip"
        )  ## make sure no duplicates!
        df = df.dropna().reset_index(drop=True)
        df = df.loc[~df[["episode_clip_uri"]].duplicated()]

        return df

    def _set_node_start_indices(self):
        with open(self.clip_mapping_path, 'rb') as f:
          mapping = pickle.load(f)['mapping']
          num_clips = len(mapping)
        
        with open(self.local_topic_mapping_path, 'rb') as f:
          mapping = pickle.load(f)['mapping']
          num_local_topics = len(mapping)

        self.start_indices = {
            "clip": 0,
            "local_topic": len(num_clips),  # start after clip nodes
            "global_topic": len(num_clips + num_local_topics)
            + len(
                df[["topic_id_local"]].drop_duplicates()
            ),  # start after local topic nodes
        }

    def _load_nodes(
        self, df, index_col, mapping_path, encoders=None, start_idx=0, **kwargs
    ):
        mapping = {
            index: i + start_idx
            for i, index in enumerate(sorted(list(set(df[index_col]))))
        }

        x = None
        if encoders is not None:
            xs = [encoder(df[col]) for col, encoder in encoders.items()]
            x = torch.cat(xs, dim=-1)

        with open(mapping_path, "wb") as f:
            pickle.dump({"x": x, "mapping": mapping}, f)

        print(f"=> Finished getting x (shape={x.shape}), mapping (len={len(mapping)})")
        assert x.shape[0] == len(mapping)
        return x, mapping

    def _draw_episode_edges(
        self, src_index_col, src_mapping, dst_index_col, dst_mapping, df
    ):
        """
        For each episode, get all the corresponding clips (node ids)
        For a single set of clips, do itertools combo pairs
        """
        edge_index = []
        episode_uris = df["episode_uri"].unique()
        for episode_uri in tqdm.tqdm(episode_uris, total=len(episode_uris)):
            clips_for_episode = df.loc[df["episode_uri"] == episode_uri][
                src_index_col
            ].values
            nodes_for_episode = [src_mapping[uri] for uri in clips_for_episode]

            edges_for_episode = torch.as_tensor(
                list(itertools.permutations(nodes_for_episode, 2))
            ).T.long()

            edge_index.append(edges_for_episode)

        edge_index = torch.hstack(edge_index)
        assert edge_index.shape[0] == 2 and len(edge_index.shape) == 2
        return edge_index, []

    def _draw_clip_topic_edges(
        self, src_index_col, src_mapping, dst_index_col, dst_mapping, df, ignore=[]
    ):
        """
        For each unique topic, get all the corresponding clips (node ids)
        For a single set of clips, map them all to the same topic node
        Don't skip noise topic (-1)
        """
        edge_index = []
        for topic_id in dst_mapping.keys():
            if topic_id in ignore:
                continue
            clips_for_topic = df.loc[df[dst_index_col] == topic_id][
                src_index_col
            ].values
            clip_nodes_for_topic = [src_mapping[uri] for uri in clips_for_topic]

            edges_for_topic_x = torch.as_tensor(clip_nodes_for_topic)
            edges_for_topic_y = torch.as_tensor(
                [dst_mapping[topic_id]] * len(clip_nodes_for_topic)
            )
            edges_for_topic = torch.vstack([edges_for_topic_x, edges_for_topic_y])

            edge_index.append(edges_for_topic)

        edge_index = torch.hstack(edge_index)
        assert edge_index.shape[0] == 2 and len(edge_index.shape) == 2
        return edge_index, []

    def _draw_jaccard_topic_keyword_edges(
        self,
        src_index_col,
        src_mapping,
        dst_index_col,
        dst_mapping,
        df,
        sep="|",
        src_keyword_col=None,
        dst_keyword_col=None,
    ):
        """
        Compute jaccard similarity of shared keywords for each pair of topics.
        """
        edge_index = []
        edge_attr = []

        if src_keyword_col is None:
            src_keyword_col = "topic_top_n_words"
            dst_keyword_col = "topic_top_n_words"

        src_df = df[[src_index_col, src_keyword_col]].drop_duplicates()  # topics only
        dst_df = df[[dst_index_col, dst_keyword_col]].drop_duplicates()  # topics only

        for i in tqdm.tqdm(range(len(src_df))):
            # words_i = set(df.iloc[i]['topic_top_n_words'].split(sep))
            words_i = set(ast.literal_eval(src_df.iloc[i][src_keyword_col]))
            node_i = src_mapping[src_df.iloc[i][src_index_col]]
            for j in range(len(dst_df)):
                # words_j = set(df.iloc[j]['topic_top_n_words'].split(sep))
                words_j = set(ast.literal_eval(dst_df.iloc[j][dst_keyword_col]))
                node_j = dst_mapping[dst_df.iloc[j][dst_index_col]]

                jaccard_sim = jaccard(words_i, words_j)

                # Add edges for both directions
                if jaccard_sim > 0.0:
                    edge_index.append(torch.tensor([[node_i, node_j]]).T)
                    edge_attr.append(jaccard_sim)

        edge_index = torch.hstack(edge_index)
        edge_attr = torch.tensor(edge_attr)

        assert (
            edge_index.shape[0] == 2
            and len(edge_index.shape) == 2
            and edge_index.shape[1] == edge_attr.shape[0]
        )
        return edge_index, edge_attr

    def _load_edges(
        self,
        edge_type,
        src_index_col,
        src_mapping,
        dst_index_col,
        dst_mapping,
        df,
        src_keyword_col=None,
        dst_keyword_col=None,
    ):
        if edge_type == "same_episode":
            return self._draw_episode_edges(
                src_index_col, src_mapping, dst_index_col, dst_mapping, df
            )
        elif "has_topic" in edge_type:
            return self._draw_clip_topic_edges(
                src_index_col, src_mapping, dst_index_col, dst_mapping, df
            )
        elif "keyword_sim" in edge_type:
            return self._draw_jaccard_topic_keyword_edges(
                src_index_col,
                src_mapping,
                dst_index_col,
                dst_mapping,
                df,
                src_keyword_col=src_keyword_col,
                dst_keyword_col=dst_keyword_col,
            )
        else:
            raise NotImplementedError(f"{edge_type} not implemented!")

    def process(self):
        df = self._load_df()
        self._set_node_start_indices(df)

        clip_x, clip_mapping = self._load_nodes(
            df=df,
            index_col="episode_clip_uri",
            mapping_path=self.clip_mapping_path,
            encoders={
                "transcript": SequenceEncoder(),
            },
            start_idx=self.start_indices["clip"],
        )
        print("=> Finished loading clip nodes")

        # assert len(df) == len(clip_mapping)

        local_topic_x, local_topic_mapping = self._load_nodes(
            df=df[["topic_id_local"]].drop_duplicates(),
            index_col="topic_id_local",
            mapping_path=self.local_topic_mapping_path,
            encoders={
                "topic_id_local": OneHotEncoder(),
            },
            start_idx=self.start_indices["local_topic"],
        )
        print("=> Finished loading local topic nodes")

        global_topic_x, global_topic_mapping = self._load_nodes(
            df=df[["topic_id_global"]].drop_duplicates(),
            index_col="topic_id_global",
            mapping_path=self.global_topic_mapping_path,
            encoders={
                "topic_id_global": OneHotEncoder(),
            },
            start_idx=self.start_indices["global_topic"],
        )
        print("=> Finished loading global topic nodes")

        mappings = {
            "clip": clip_mapping,
            "local_topic": local_topic_mapping,
            "global_topic": global_topic_mapping,
        }

        data = HeteroData()
        data["clip"].x = clip_x
        data["local_topic"].x = local_topic_x
        data["global_topic"].x = global_topic_x

        for edge_type, edge_type_map in self.edge_types.items():
            edge_type_map_cpy = copy.copy(edge_type_map)
            del edge_type_map_cpy["edge_name"]

            src_mapping = edge_type_map_cpy.pop("src_mapping")
            dst_mapping = edge_type_map_cpy.pop("dst_mapping")

            print(f"=> Creating edges for {edge_type}")
            edge_index, edge_attr = self._load_edges(
                edge_type=edge_type,
                df=df,
                src_mapping=mappings[src_mapping],
                dst_mapping=mappings[dst_mapping],
                **edge_type_map_cpy,
            )

            data[edge_type_map["edge_name"]].edge_index = edge_index
            if edge_attr != []:
                data[edge_type_map["edge_name"]].edge_attr = edge_attr

        data = ToUndirected()(data)
        data_list = [data]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), PROCESSED_DATA_DIR + "/data_slice.p")


### Creating the HeteroGraph from the HeteroData

In this section, we transform the HeteroData that was derived from the pre-processed CSV file of clip and topic data into a HeteroGraph on which a GNN can train. We further split our graph in an inductive setting into train/validation/test splits. We then wrap the dataloader in a Pytorch Lightning Data Module. Overall, this involves:
- Converting the HeteroData into a NetworkX DiGraph
- Converting the DiGraph into a DeepSNAP HeteroGraph
- Making n=10 copies of the HeteroGraph (this is so that we can perform an inductive split on the graph)
- Converting the list of HeteroGraphs into a GraphDataset for inductive splitting with negatively sampled edges
- Splitting the GraphDatset into train/validation/test splits

In [6]:
import copy
import torch
import pytorch_lightning as pl

from torch.utils.data import DataLoader
import networkx as nx
from deepsnap.dataset import GraphDataset
from deepsnap.hetero_graph import HeteroGraph
from deepsnap.batch import Batch
from torch_geometric.data import HeteroData

In [7]:
class SpotifyPodcastDataModule(pl.LightningDataModule):
    def __init__(
        self, val_split=0.05, test_split=0.1, neg_sampling_ratio=2.0, edge_message_ratio=0.8, keep_hetero=False, save_splits=False
    ):
        super().__init__()
        
        self.save_splits = save_splits
        self.neg_sampling_ratio = neg_sampling_ratio
        self.edge_message_ratio = edge_message_ratio
        self.split_ratio = [1 - val_split - test_split, val_split, test_split]

        self.spotify_dataset = SpotifyPodcastDataset()
        self.data = self.spotify_dataset.data

        self.node_types = self.data.node_types
        self.edge_types = self.data.edge_types

        self.input_dim = self.data[self.node_types[0]].x.shape[-1]

        # Split graph
        hetero = self.convert_graph_to_deepsnap()
        self.split_dataset_from_heterograph(hetero)

        # Update edge types (because perhaps duplicate edges that cancel out an edge type)
        # print(f"Old edge types: {self.edge_types}")
        self.edge_types = list(hetero.edge_type.keys())
        # print(f"Updated edge types to: {self.edge_types}")

        if keep_hetero:
            self.hetero = hetero
        else:
            self.hetero = None
            del hetero

    def convert_graph_to_deepsnap(self):
        G = nx.DiGraph()  # Directed because of clip2topic edges
        nodes_lst = []
        for node_type in self.node_types:
            nodes = self.data[node_type].x
            start_idx = self.spotify_dataset.start_indices[node_type]
            nodes_type_lst = [
                (start_idx + id_, {"node_type": node_type, "node_feature": nodes[id_]})
                for id_ in range(nodes.shape[0])
            ]

            nodes_lst.extend(nodes_type_lst)

        G.add_nodes_from(nodes_lst)

        edge_lst = []
        for i, edge_type in enumerate(self.data.edge_types):
            edges = self.data[edge_type].edge_index

            edge_feats = (
                self.data[edge_type].edge_attr
                if "edge_attr" in self.data[edge_type]
                else None
            )

            if edge_feats is not None:
                edge_type_lst = [
                    (
                        edges[0, id_].item(),
                        edges[1, id_].item(),
                        {"edge_type": edge_type[1], "weight": edge_feats[id_]},
                    )
                    for id_ in range(edges.shape[-1])
                ]
            else:
                edge_type_lst = [
                    (
                        edges[0, id_].item(),
                        edges[1, id_].item(),
                        {"edge_type": edge_type[1], "weight": 1},
                    )
                    for id_ in range(edges.shape[-1])
                ]  # dummy edge weight
            edge_lst.extend(edge_type_lst)

        G.add_edges_from(edge_lst)

        self.nx_graph = G

        hetero = HeteroGraph(G, directed=False)

        for edge_type in self.data.edge_types:
            assert (
                hetero.edge_index[edge_type].shape[-1]
                == self.data[edge_type].edge_index.shape[-1]
            )

        return hetero

    def split_dataset_from_heterograph(self, hetero):
        self.dataset = GraphDataset(
            [hetero],
            task="link_pred",
            edge_train_mode="disjoint",
            edge_message_ratio=self.edge_message_ratio,
            edge_negative_sampling_ratio=self.neg_sampling_ratio,
        )

        self.dataset_train, self.dataset_val, self.dataset_test = self.dataset.split(
            transductive=True, split_ratio=self.split_ratio
        )

        if self.save_splits:
          train_edge_label_index = self.dataset_train[0].edge_label_index[SUPERVISION_EDGE]
          train_edge_label = self.dataset_train[0].edge_label[SUPERVISION_EDGE]
          val_edge_label_index = self.dataset_val[0].edge_label_index[SUPERVISION_EDGE]
          val_edge_label = self.dataset_val[0].edge_label[SUPERVISION_EDGE]
          test_edge_label_index = self.dataset_test[0].edge_label_index[SUPERVISION_EDGE]
          test_edge_label = self.dataset_test[0].edge_label[SUPERVISION_EDGE]
          
          train_split = {"edge_label_index": train_edge_label_index, "edge_label": train_edge_label}
          val_split = {"edge_label_index": val_edge_label_index, "edge_label": val_edge_label}
          test_split = {"edge_label_index": test_edge_label_index, "edge_label": test_edge_label}
          
          data_splits = {'train': train_split, 'val': val_split, 'test': test_split}
          
          with open(DATA_SPLITS_PATH, 'wb') as f:
              pickle.dump(data_splits, f)
          
          print("=> Finished saving datasplits")


    def train_dataloader(self):
        return DataLoader(
            self.dataset_train,
            shuffle=True,
            batch_size=1,
            num_workers=0,
            collate_fn=Batch.collate(),
        )

    def val_dataloader(self):
        return DataLoader(
            self.dataset_val,
            shuffle=False,
            batch_size=1,
            num_workers=0,
            collate_fn=Batch.collate(),
        )

    def test_dataloader(self):
        return DataLoader(
            self.dataset_test,
            shuffle=False,
            batch_size=1,
            num_workers=0,
            collate_fn=Batch.collate(),
        )


## Defining the RGraphSAGE Model

In this section, we define our RGraphSAGE model. Since we have edge weights, we use `GraphCONV` and then specify our aggregation method to be *mean aggregation*.



In [8]:
import torch
from torch_geometric.nn import Linear, to_hetero, GraphConv

In [9]:
"""
Code adapted from https://medium.com/@pytorch_geometric/link-prediction-on-heterogeneous-graphs-with-pyg-6d5c29677c70
"""

class RGraphSAGE(torch.nn.Module):
    def __init__(
        self,
        hidden_dim,
        num_layers=2,
    ):
        super().__init__()

        self.input_lin = Linear(-1, hidden_dim)
        self.output_lin = Linear(hidden_dim, hidden_dim)

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            # Note: GraphConv + mean aggregation = GraphSAGE
            conv = GraphConv(hidden_dim, hidden_dim, aggr="mean")
            self.convs.append(conv)

    def forward(self, x, edge_index, edge_weight):
        x = self.input_lin(x)
        for conv in self.convs:
            x = conv(x, edge_index, edge_weight)

        x = self.output_lin(x)

        return x


class Classifier(torch.nn.Module):
    def forward(self, x1, x2, edge_label_index):
        # Get node embeddings
        node_feat_x1 = x1[edge_label_index[0]]  # (num_labels, embed_dim)
        node_feat_x2 = x2[edge_label_index[1]]  # (num_labels, embed_dim)

        # Dot product of node embeddings to get logits for edge existence
        return (node_feat_x1 * node_feat_x2).sum(dim=-1)  # (num_labels, )


class Model(torch.nn.Module):
    def __init__(
        self,
        hidden_dim,
        num_layers=2,
        node_types=None,
        edge_types=None,
        checkpoint=None,
    ):
        super().__init__()
        self.gnn = RGraphSAGE(hidden_dim, num_layers)

        self.gnn = to_hetero(self.gnn, metadata=(node_types, edge_types))
        self.classifier = Classifier()

        self.node_types = node_types

        if checkpoint:
            self.load_from_lightning_checkpoint(checkpoint)

    def load_from_lightning_checkpoint(self, checkpoint_path):
        """
        Load model checkpoint from Pytorch Lightning Module checkpoint.
        """
        checkpoint = torch.load(checkpoint_path)

        ckpt = checkpoint["state_dict"]
        # Remap keys
        ckpt = {k.replace("model.", ""): v for k, v in ckpt.items()}

        msg = self.load_state_dict(ckpt, strict=False)

        print("=" * 80)
        print(msg)
        print("=" * 80)

    def forward(self, x_dict, edge_index_dict, edge_weight_dict, edge_label_index):
        x_dict = self.gnn(x_dict, edge_index_dict, edge_weight_dict)

        pred = self.classifier(
            x_dict[SUPERVISION_EDGE[0]], x_dict[SUPERVISION_EDGE[2]], edge_label_index
        )

        return pred


## Metrics

Since our task is link prediction, we use the standard binary classification metric AUROC.

In [10]:
from sklearn.metrics import roc_auc_score

def compute_auroc(y, prob):
    if type(y) == torch.Tensor:
        y = y.detach().cpu().numpy()
    if type(prob) == torch.Tensor:
        prob = prob.detach().cpu().numpy()

    auroc = roc_auc_score(y, prob)
    return auroc

## PyTorch Lightning Module

In this section, we define the Pytorch Lightning Module wrapper that modularizes our training and testing loops. 

In [11]:
class RGraphSAGELightning(pl.LightningModule):
    def __init__(self, args):
        super().__init__()

        self.lr = args.lr
        self.model = Model(
            hidden_dim=args.hidden_dim,
            node_types=args.node_types,
            edge_types=args.edge_types,
            num_layers=args.num_layers,
        )

        self.loss_fn = torch.nn.BCEWithLogitsLoss()

        self.supervision_edge = SUPERVISION_EDGE

        self.training_step_outputs = []
        self.validation_step_outputs = []
        self.test_step_outputs = []

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self.shared_step(batch, "test")

    def on_train_epoch_end(self):
        self.shared_epoch_end(self.training_step_outputs, "train")
        self.training_step_outputs.clear()

    def on_validation_epoch_end(self):
        self.shared_epoch_end(self.validation_step_outputs, "val")
        self.validation_step_outputs.clear()

    def on_test_epoch_end(self):
        self.shared_epoch_end(self.test_step_outputs, "test")
        self.test_step_outputs.clear()

    def shared_step(self, batch, split):
        out = self.model(
            batch.node_feature,
            batch.edge_index,
            batch.weight,
            batch.edge_label_index[self.supervision_edge],
        )
        y = batch.edge_label[self.supervision_edge].float()
        loss = self.loss_fn(out, y)

        self.log(
            f"{split}_loss",
            loss,
            on_epoch=True,
            on_step=True,
            logger=True,
            prog_bar=True,
        )
        return_dict = {"y": y, "logits": out, "loss": loss}

        if split == "train":
            self.training_step_outputs.append(return_dict)
        elif split == "val":
            self.validation_step_outputs.append(return_dict)
        else:
            self.test_step_outputs.append(return_dict)

        return return_dict

    def shared_epoch_end(self, step_outputs, split):
        logit = torch.cat([x["logits"] for x in step_outputs])
        prob = torch.sigmoid(logit)
        y = torch.cat([x["y"] for x in step_outputs])

        auroc = compute_auroc(y, prob)
        self.log(f"{split}_auroc", auroc, on_epoch=True, prog_bar=True, logger=True)

        # print(f"\nEpoch: {self.current_epoch} \t {split}_auroc: {round(auroc, 3)}")


## Checkpointing Utils

Utility functions for checkpointing and getting the best checkpoint for a single run.

In [12]:
import yaml

def get_best_ckpt_path(ckpt_paths, ascending=False):
    """
    Get best checkpoint path from a list of checkpoint paths

    ckpt_paths: JSON file with ckpt path to metric pair
    ascending: sort paths based on ascending or descending metrics
    """

    with open(ckpt_paths, "r") as stream:
        ckpts = yaml.safe_load(stream)

    ckpts_df = pd.DataFrame.from_dict(ckpts, orient="index").reset_index()
    ckpts_df.columns = ["path", "metric"]
    best_ckpt_path = (
        ckpts_df.sort_values("metric", ascending=ascending).head(1)["path"].item()
    )

    return best_ckpt_path


def save_best_checkpoints(checkpoint_callback, savedir, return_best=True):
    ckpt_paths = os.path.join(savedir, "best_ckpts.yaml")
    checkpoint_callback.to_yaml(filepath=ckpt_paths)
    if return_best:
        ascending = False
        best_ckpt_path = get_best_ckpt_path(ckpt_paths, ascending)
        return best_ckpt_path

## Defining the PyTorch Lightning Trainer

In this section, we define the main function that we call for training/testing. This involves instantiating the SpotifyPodcastDataModule, checkpoint directory, an optional WandB logger and the PyTorch Lightning Trainer.

In [13]:
from dateutil import tz
import datetime

from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities.seed import isolate_rng
from pytorch_lightning import loggers as pl_loggers

In [14]:
def main(args):
    dm = SpotifyPodcastDataModule(
        val_split=args.val_split,
        test_split=args.test_split,
        neg_sampling_ratio=args.neg_sampling_ratio,
        edge_message_ratio=args.edge_message_ratio,
    )

    args.node_types = dm.node_types
    args.edge_types = dm.edge_types

    if args.checkpoint:
        print("=" * 80)
        print(f"*** Loading checkpoint: {args.checkpoint}")
        print("=" * 80)
        model = RGraphSAGELightning.load_from_checkpoint(args.checkpoint, args=args)
    else:
        model = RGraphSAGELightning(args)

    # Modify checkpoint name
    now = datetime.datetime.now(tz.tzlocal())
    timestamp = now.strftime("%Y_%m_%d_%H_%M_%S")
    args.savedir += f"/{timestamp}"

    os.makedirs(args.savedir, exist_ok=True)

    checkpoint_callback = ModelCheckpoint(
        monitor="val_auroc",
        dirpath=args.savedir,
        save_last=True,
        mode="max",
        save_top_k=3,
    )

    early_stop_callback = EarlyStopping(
        monitor="val_loss_epoch", 
        min_delta=0., 
        patience=args.patience, 
        verbose=False, 
        mode="min"
    )

    # Logger
    logger = None
    if args.log_to_wandb:
        os.makedirs(args.wandb_savedir, exist_ok=True)
        args.wandb_name += f"/{timestamp}"

        logger_type = "WandbLogger"
        logger_class = getattr(pl_loggers, logger_type)
        logger = logger_class(
            name=args.wandb_name,
            project=args.wandb_project,
            save_dir=args.wandb_savedir,
        )

    trainer = Trainer(
        max_epochs=args.num_epochs,
        callbacks=[checkpoint_callback],
        logger=logger,
        log_every_n_steps=1,
        check_val_every_n_epoch=1,
    )

    if args.train:
        trainer.fit(model, dm)
        best_ckpt = save_best_checkpoints(
            checkpoint_callback, args.savedir, return_best=True
        )
        print(f"Best checkpoint path: {best_ckpt}")
        args.checkpoint = best_ckpt

    if args.test:
        trainer.test(model=model, datamodule=dm)

## Defining arguments

In place of an argument parser, we create this TrainArguments class.

In [15]:
class TrainArguments():
  def __init__(
      self, 
      hidden_dim=128, 
      num_layers=2, 
      lr=0.001, 
      patience=5,
      num_epochs=50, 
      val_split=0.1, 
      test_split=0.2, 
      neg_sampling_ratio=1.0,  # 2.0
      edge_message_ratio=0.8,  # 0.5
      train=True, 
      test=True,
      checkpoint=None,
      savedir=".",
      log_to_wandb=True,
      wandb_savedir=DATA_DIR,
      wandb_name="clipme",
      wandb_project="clipme",
      ):
    
    self.hidden_dim = hidden_dim
    self.num_layers = num_layers
    self.lr = lr
    self.patience = patience
    self.num_epochs = num_epochs
    self.val_split = val_split
    self.test_split = test_split
    self.neg_sampling_ratio = neg_sampling_ratio
    self.edge_message_ratio = edge_message_ratio
    self.train = train
    self.test = test
    self.checkpoint = checkpoint
    self.savedir = savedir
    self.log_to_wandb = log_to_wandb
    self.wandb_savedir = wandb_savedir
    self.wandb_name = wandb_name
    self.wandb_project = wandb_project

# Run Training + Testing

We run training and testing for 10 different random seeds.

In [16]:
for seed in range(10):
  seed_everything(seed)
  args = TrainArguments(savedir=f'{DATA_DIR}/ckpt', log_to_wandb=True)
  main(args)

INFO:lightning_fabric.utilities.seed:Global seed set to 0
[34m[1mwandb[0m: Currently logged in as: [33mcvanuden[0m. Use [1m`wandb login --relogin`[0m to force relogin


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type              | Params
----------------------------------------------
0 | model   | Model             | 444 K 
1 | loss_fn

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Best checkpoint path: /content/drive/.shortcut-targets-by-id/1nt4grspk05bfDsz_lJogqO137Fds1j-V/CS224W Final Project/ckpt/2023_03_23_02_32_00/epoch=49-step=50.ckpt


Testing: 0it [00:00, ?it/s]

INFO:lightning_fabric.utilities.seed:Global seed set to 1


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       test_auroc            0.924690192417929
     test_loss_epoch        0.49393850564956665
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type              | Params
----------------------------------------------
0 | model   | Model             | 444 K 
1 | loss_fn

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Best checkpoint path: /content/drive/.shortcut-targets-by-id/1nt4grspk05bfDsz_lJogqO137Fds1j-V/CS224W Final Project/ckpt/2023_03_23_02_33_06/epoch=47-step=48.ckpt


Testing: 0it [00:00, ?it/s]

INFO:lightning_fabric.utilities.seed:Global seed set to 2


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       test_auroc           0.9387171070720612
     test_loss_epoch        0.4821975529193878
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type              | Params
----------------------------------------------
0 | model   | Model             | 444 K 
1 | loss_fn

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Best checkpoint path: /content/drive/.shortcut-targets-by-id/1nt4grspk05bfDsz_lJogqO137Fds1j-V/CS224W Final Project/ckpt/2023_03_23_02_34_07/epoch=49-step=50.ckpt


Testing: 0it [00:00, ?it/s]

INFO:lightning_fabric.utilities.seed:Global seed set to 3


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       test_auroc           0.9348022107025259
     test_loss_epoch        0.4863097071647644
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type              | Params
----------------------------------------------
0 | model   | Model             | 444 K 
1 | loss_fn

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Best checkpoint path: /content/drive/.shortcut-targets-by-id/1nt4grspk05bfDsz_lJogqO137Fds1j-V/CS224W Final Project/ckpt/2023_03_23_02_35_06/epoch=49-step=50.ckpt


Testing: 0it [00:00, ?it/s]

INFO:lightning_fabric.utilities.seed:Global seed set to 4


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       test_auroc           0.9338177854760247
     test_loss_epoch        0.48798730969429016
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type              | Params
----------------------------------------------
0 | model   | Model             | 444 K 
1 | loss_fn

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Best checkpoint path: /content/drive/.shortcut-targets-by-id/1nt4grspk05bfDsz_lJogqO137Fds1j-V/CS224W Final Project/ckpt/2023_03_23_02_36_07/epoch=49-step=50.ckpt


Testing: 0it [00:00, ?it/s]

INFO:lightning_fabric.utilities.seed:Global seed set to 5


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       test_auroc           0.9144613400178818
     test_loss_epoch        0.5009862184524536
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type              | Params
----------------------------------------------
0 | model   | Model             | 444 K 
1 | loss_fn

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Best checkpoint path: /content/drive/.shortcut-targets-by-id/1nt4grspk05bfDsz_lJogqO137Fds1j-V/CS224W Final Project/ckpt/2023_03_23_02_37_06/epoch=49-step=50.ckpt


Testing: 0it [00:00, ?it/s]

INFO:lightning_fabric.utilities.seed:Global seed set to 6


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       test_auroc           0.9239858498472379
     test_loss_epoch        0.49454909563064575
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type              | Params
----------------------------------------------
0 | model   | Model             | 444 K 
1 | loss_fn

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Best checkpoint path: /content/drive/.shortcut-targets-by-id/1nt4grspk05bfDsz_lJogqO137Fds1j-V/CS224W Final Project/ckpt/2023_03_23_02_38_08/epoch=49-step=50.ckpt


Testing: 0it [00:00, ?it/s]

INFO:lightning_fabric.utilities.seed:Global seed set to 7


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       test_auroc            0.924557311516942
     test_loss_epoch        0.4947493374347687
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type              | Params
----------------------------------------------
0 | model   | Model             | 444 K 
1 | loss_fn

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Best checkpoint path: /content/drive/.shortcut-targets-by-id/1nt4grspk05bfDsz_lJogqO137Fds1j-V/CS224W Final Project/ckpt/2023_03_23_02_39_09/epoch=47-step=48.ckpt


Testing: 0it [00:00, ?it/s]

INFO:lightning_fabric.utilities.seed:Global seed set to 8


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       test_auroc           0.9333957549542955
     test_loss_epoch        0.48873189091682434
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type              | Params
----------------------------------------------
0 | model   | Model             | 444 K 
1 | loss_fn

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Best checkpoint path: /content/drive/.shortcut-targets-by-id/1nt4grspk05bfDsz_lJogqO137Fds1j-V/CS224W Final Project/ckpt/2023_03_23_02_40_10/epoch=49-step=50.ckpt


Testing: 0it [00:00, ?it/s]

INFO:lightning_fabric.utilities.seed:Global seed set to 9


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       test_auroc           0.9183536927439886
     test_loss_epoch        0.5008264780044556
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type              | Params
----------------------------------------------
0 | model   | Model             | 444 K 
1 | loss_fn

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Best checkpoint path: /content/drive/.shortcut-targets-by-id/1nt4grspk05bfDsz_lJogqO137Fds1j-V/CS224W Final Project/ckpt/2023_03_23_02_41_11/epoch=49-step=50.ckpt


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       test_auroc           0.9190118047572464
     test_loss_epoch        0.49875780940055847
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [44]:
import numpy as np
import scipy.stats

def mean_confidence_interval(data, confidence=0.95):
    a = 1.0 * np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n-1)
    return m, m-h, m+h

test_aurocs = [0.924690192417929, 0.9387171070720612, 0.9348022107025259, 0.9338177854760247, 0.9144613400178818, 0.9239858498472379, 0.924557311516942, 0.9333957549542955, 0.9183536927439886, 0.9190118047572464]
mci = mean_confidence_interval(test_aurocs)
print(f'Test AUROC: {mci[0]} with 95% CI ({mci[1]}, {mci[2]})!')

Test AUROC: 0.9265793049506132 with 95% CI (0.9207386444491326, 0.9324199654520939)!


# Run Prediction

In this section, we define our clip retrieval predictor. Specifically, for a single **clip query**, we **retrieve the top k nearest clips from our graph**. We do this by adding the query clip node to the graph in addition to its corresponding Clip2Topic edges. Then, for each of the candidate clips sampled from the set of clips sharing the same topic, we add the Clip2Clip edge connecting the query to the candidate clip node and perform link prediction on that edge. We then retrieve the top k candidate clip nodes.

## Defining Predictor

In [24]:
class Predictor():
    """
    This will not run without a clip dataframe! 
    You;ll need to do the data preprocessing described near the top of this Colab first.

    You can either explicitly pass in the clip dataframe, 
    or load it from the underlying SpotifyPodcastDataset.
    """
    def __init__(self, args, clip_df=None):
        self.spotify_datamodule = SpotifyPodcastDataModule()
        
        self.node_types = self.spotify_datamodule.node_types
        self.edge_types = self.spotify_datamodule.edge_types
        
        self.start_indices = self.spotify_datamodule.spotify_dataset.start_indices
        
        self.model = Model(
            hidden_dim=args['hidden_dim'],
            node_types=self.node_types,
            edge_types=self.edge_types,
            num_layers=args['num_layers'],
            checkpoint=args['checkpoint']
        )
        
        if clip_df is None:
          self.clip_df = self.spotify_datamodule.spotify_dataset._load_df()
        else:
          self.clip_df = clip_df

        self.clip_mapping_path = self.spotify_datamodule.spotify_dataset.clip_mapping_path
        self.local_topic_mapping_path = self.spotify_datamodule.spotify_dataset.local_topic_mapping_path
        self.global_topic_mapping_path = self.spotify_datamodule.spotify_dataset.global_topic_mapping_path
        
        with open(self.spotify_datamodule.spotify_dataset.clip_mapping_path, 'rb') as f:
            self.clip_mapping = pickle.load(f)['mapping']
        
        self.clip_mapping_rev = {value:key for key, value in self.clip_mapping.items()} # reverse the mapping
        
        with open(self.spotify_datamodule.spotify_dataset.local_topic_mapping_path, 'rb') as f:
            self.local_topic_mapping = pickle.load(f)['mapping']
            
        with open(self.spotify_datamodule.spotify_dataset.global_topic_mapping_path, 'rb') as f:
            self.global_topic_mapping = pickle.load(f)['mapping']
        
        self.encoder = SentenceTransformer(args['sbert_model_name'], device=None)
        
        self.device = 'cpu'
        self.model.to(self.device)
        
        self.k = args['k']
        
        self.num_existing_nodes = self.spotify_datamodule.spotify_dataset.data.num_nodes
        self.num_existing_clip_nodes = self.spotify_datamodule.spotify_dataset.data[SUPERVISION_EDGE[0]].x.shape[0]
        self.num_topic_nodes = self.num_existing_nodes - self.num_existing_clip_nodes
    
    def add_new_nodes_and_edges(self, new_nodes, new_edges):
        # Add new node and edges to existing graph
        G = self.spotify_datamodule.nx_graph
        
        nodes_lst = []
        
        num_nodes = self.spotify_datamodule.spotify_dataset.data.num_nodes
        
        for node_type in self.spotify_datamodule.spotify_dataset.data.node_types:
            nodes = self.spotify_datamodule.spotify_dataset.data[node_type].x.to(self.device)
            
            # Append new clip nodes
            num_new_nodes = 0
            if node_type in new_nodes:
                nodes = torch.vstack([nodes, new_nodes[node_type]])
                num_new_nodes = new_nodes[node_type].shape[0]
            
                nodes_type_lst = [(id_ + num_nodes, {'node_type': node_type, 'node_feature': nodes[id_].to(self.device)}) for id_ in range(num_new_nodes)]
                num_nodes += num_new_nodes
            
                nodes_lst.extend(nodes_type_lst)
        
        edge_lst = []
        for i, edge_type in enumerate(self.edge_types):
            # Append new edges
            if edge_type in new_edges:
                edge_index = new_edges[edge_type]['edge_index']
                
                if 'edge_attr' in new_edges[edge_type]:
                    edge_attr = new_edges[edge_type]['edge_attr']
                    edge_type_lst = [(edge_index[0, id_].item(), edge_index[1, id_].item(), {'edge_type': edge_type[1], 'weight': edge_attr[id_]}) for id_ in range(edge_index.shape[-1])]
                else: 
                    edge_type_lst = [(edge_index[0, id_].item(), edge_index[1, id_].item(), {'edge_type': edge_type[1], 'weight': 1}) for id_ in range(edge_index.shape[-1])]
                
                edge_lst.extend(edge_type_lst)
        
        G.add_nodes_from(nodes_lst)
        G.add_edges_from(edge_lst)
        
        return HeteroGraph(G, directed=False)
    
    
    def get_new_nodes(self, clip_transcripts):
        """
        Embed the clip transcripts using sBERT
        """
        new_nodes = self.encoder.encode(
            clip_transcripts,
            show_progress_bar=False,
            convert_to_tensor=True,
            device=self.device,
        ) # (batch_size, embed_dim)
        
        self.num_new_nodes = new_nodes.shape[0]
        
        return {'clip': new_nodes}
        
        
    def get_new_edges(self, clip_neighbour_candidate, local_topic_id, global_topic_id):
        """
        Add Clip2Clip (same episode) edges and Clip2Topic edges (global and local)
        """
        
        new_edges_dict = {}
        
        # Add new Clip2Clip edges (both directions)
        # Note: new node id = self.num_existing_nodes
        new_edges = torch.tensor([[clip_neighbour_candidate, self.num_existing_nodes], [self.num_existing_nodes, clip_neighbour_candidate]]).T
        
        # Edge index label will point to all the newly added same_episode edges 
        # Note: very hacky way to get the edge_index_label
        new_edges_dict[SUPERVISION_EDGE] = {'edge_index': new_edges.to(self.device)}
        self.edge_index_label = torch.where(
            new_edges_dict[SUPERVISION_EDGE]['edge_index'] > self.num_existing_clip_nodes,
            new_edges_dict[SUPERVISION_EDGE]['edge_index'] - self.num_topic_nodes,
            new_edges_dict[SUPERVISION_EDGE]['edge_index']
        )
        
        # new global and local clip2topic edges -- for message passing
        if local_topics:
            new_clip2topic_edges = torch.tensor([[self.num_existing_nodes, local_topic_id + self.start_indices['local_topic']], 
                                                 [local_topic_id + self.start_indices['local_topic'], self.num_existing_nodes]]).T
            new_edges_dict[LOCAL_CLIP2TOPIC_EDGE] = {'edge_index': new_clip2topic_edges.to(self.device)}
        
        if global_topics:
            new_clip2topic_edges = torch.tensor([[self.num_existing_nodes, global_topic_id + self.start_indices['global_topic']],
                                                  [global_topic_id + self.start_indices['global_topic'], self.num_existing_nodes]]).T
            new_edges_dict[GLOBAL_CLIP2TOPIC_EDGE] = {'edge_index': new_clip2topic_edges.to(self.device)}

        return new_edges_dict
    
    def predict_for_candidate(self, clip_neighbour_candidate, clip_transcript, local_topic_id, global_topic_id, return_feats=True):
        """
        Edge prediction:
        - Connect query node to candidate node
        - Edge label index would correspond to the newly added edges
        - Run forward pass of the GNN to get probabilities for each of those newly connected edges
        - Pick top k connected nodes and return node information
        """
        
        # Add nodes and edges to graph
        new_nodes = self.get_new_nodes([clip_transcript])
        new_edges = self.get_new_edges(clip_neighbour_candidate, local_topic_id, global_topic_id)
        self.new_graph = self.add_new_nodes_and_edges(new_nodes, new_edges)
        
        # GNN forward pass
        logits = self.model(
            self.new_graph.node_feature, 
            self.new_graph.edge_index, 
            self.new_graph.weight,
            self.edge_index_label
        ) # (2, )
        # print("=> Finished forward pass of model")
        
        # probs = torch.sigmoid(logits, dim=-1)
        probs = torch.sigmoid(logits)
        prob = torch.mean(probs)
        return prob
        
    def sample_clips_from_same_topic(self, local_topics, global_topics, num_to_sample=10):
        """
        We sample a maximum of num_to_sample clips that share the same local and global topics, respectively.
        """
        all_sampled_clips = []
        for i in range(len(local_topics)):
            local_topic = local_topics[i]
            global_topic = global_topics[i]
            
            clips_with_same_local = self.clip_df.loc[self.clip_df['topic_id_local'] == local_topic]['episode_clip_uri'].values
            clips_with_same_local = [self.clip_mapping[clip_uri] for clip_uri in clips_with_same_local]
            clips_with_same_global = self.clip_df.loc[self.clip_df['topic_id_global'] == global_topic]['episode_clip_uri'].values
            clips_with_same_global = [self.clip_mapping[clip_uri] for clip_uri in clips_with_same_global]
            
            print(f'{len(clips_with_same_local)} clips with same local topic')
            print(f'{len(clips_with_same_global)} clips with same global topic')
            sampled_clips = random.sample(clips_with_same_local, np.min([num_to_sample, len(clips_with_same_local)]))
            sampled_clips.extend(random.sample(clips_with_same_global, np.min([num_to_sample, len(clips_with_same_global)])))
            
            all_sampled_clips.append(sampled_clips)

        return all_sampled_clips
        
        
    def predict(
        self, 
        clip_transcripts: List[str], 
        local_topics, 
        global_topics, 
        sampled_clips_from_same_topic, 
        return_feats=True
    ):
        all_logits = []
        for i, clip_transcript in enumerate(clip_transcripts):
            clip_logits = []
            for clip_num in tqdm.tqdm(sampled_clips_from_same_topic[i]):
                logits = self.predict_for_candidate(clip_num, clip_transcript, local_topics[i], global_topics[i], 
                                                    return_feats)
                clip_logits.append(logits)
            all_logits.append(clip_logits)
        all_logits = torch.as_tensor(all_logits)
        knn = torch.topk(all_logits, self.k, dim=-1).indices # (num_new_nodes, k)

        f = lambda i, j : sampled_clips_from_same_topic[i][j]
        knn = torch.as_tensor([f(i, knn[i][j]) for i in range(knn.shape[0]) for j in range(knn.shape[1])]).reshape((knn.shape))
        print(knn)
        
        # Return features (list of df in sorted order from most to least nearest neighbour)
        if return_feats:
            # Get corresponding episode_clip_uris and then corresponding row in clip_df
            all_knns = []
            for i in range(knn.shape[0]):
                knn_for_node = pd.DataFrame(columns=self.clip_df.columns)
                for j in range(knn.shape[1]):
                    episode_clip_uri = self.clip_mapping_rev[knn[i, j].item()]
                    df_row = self.clip_df.loc[self.clip_df['episode_clip_uri'] == episode_clip_uri]
                    knn_for_node = pd.concat([knn_for_node, df_row])
                
                all_knns.append(knn_for_node.reset_index(drop=True))
            knn = all_knns
            
        return knn

## Defining arguments

In [20]:
seed_everything(0)
args = {
    'hidden_dim': 128,
    'num_layers': 2,
    'sbert_model_name': 'all-MiniLM-L6-v2',
    'checkpoint': 'ckpt/2023_03_23_02_33_06/epoch=47-step=48.ckpt',
    'k': 3,
    'num_to_sample': 30
}

INFO:lightning_fabric.utilities.seed:Global seed set to 0


## Example 0: Science and Fitness
We first get the global and local topics for the new inputs, and then inspect the top-K retrieved clips.

In [21]:
clip_queries = [
    """
    Why do humans need sleep? Let's go with the big first question. OK, well, the answer I'll start with is the one that I always default to when there's a why question, because I wasn't consulted in the design phase. So, so, so I wriggle my way out of giving a absolute answer. Right. But there's one mechanism that's very clear that superimportant, which is that the longer we are awake, the more adenosine accumulates in our brain and adenosine binds to adenosine receptors. No surprise there. And it creates the feeling of sleepiness independent of time of day or night. So there are two mechanisms. One is we get sleepy. As adenosine accumulates, the longer we've been awake, the more adenosine has accumulated in our system, but how sleepy we get for a given amount of adenosine depends on where we are in this so-called circadian cycle. And the circadian cycle is just very, very well conserved oscillation. It's a temperature oscillation where you go from a low point. Typically, if you're awake during the day, in your sleep at night, you're your lowest temperature point will be. Like 3:00 a.m., 4:00 a.m., and then your temperature will start to creep up as you wake up in the morning and then it'll peak in the late afternoon and then it'll start to drop again toward the evening and then you get sleep again. That oscillation in temperature takes twenty four hours plus or minus your. Yeah. Plus or minus an hour. And I don't even though I wasn't consulted at the design phase, I do not think it's a coincidence that it's aligned to the 24 hour spin of the Earth on its axis and the fact that we tend to be bathed in sunlight for a portion of that spin and in darkness for the other portion that's been. So there are two mechanisms, the adenosine accumulation and the circadian time point that we happen to be at, and those converge to create a sense of sleepiness or wakefulness. The simple way to reveal these two mechanisms to uncouple them is stay up for twenty four hours and you will find that even though you've been let's say you stay up midnight, 2:00 a.m., 3:00 a.m., provided you're on a regular schedule like that, I follow not like the kind that you follow. You get I will get very sleepy around 3:00, 4:00 a.m. but then around 5:00 or 6:00 or 7:00 a.m., which is my normal wake up time, I'll start to feel more alert, even though adenosine has been accumulating further. So adenosine is higher for me the longer I stay up. And yet I feel more alert than I did a few hours ago. And that's because these are two interacting forces. So adenosine makes you sleepy. And then just how sleepy or how awake you feel also depends on where you are in this temperature oscillation that takes 24 hours. OK, so that's fascinating. So there's a bunch of oscillations going on and then it kind of through the evolutionary process, have evolved to all be aligned somewhat and they interplay. So you said your body temperature goes up and down, the chemicals in your brain that oscillate. And then there's the actual oscillation of the sun in the in the sky. So all of that together. Has some impact on each other and somehow that all results in us wanting to go to sleep every night. Right. So and we can get right into the meat of this issue. I guess we just dove right in. But the so the temperature oscillation is the effect of the circadian clock. So every cell in our body has a 24 hour rhythm that's dictated by genes like clock per bommel. This is one of the great successes of biology. They give a Nobel Prize to Reppert and I don't know if Reppert got it. Forgive me, but sorry if you got it. Steve, congratulations. If you didn't, I'm sorry I wasn't on the committee nonetheless. Did beautiful work, Steve Ripperton and others. But Mike Ross Bashan, like other people, worked out these mechanisms and flies and bacteria and mammals. There are these genes that create 24 hour oscillations in gene expression, et cetera, in every cell of our body. But what aligns those is a signal from the master circadian clock, which sits right above the roof of the mouth called the suprachiasmatic nucleus. And that clock synchronizes all the clocks of the body to this general temperature rhythm by way of controlling systemic temperature, which makes perfect sense if you want to create a general oscillation in all the tissues and organs of the body, use temperature. And so that work on temperature, if people want to explore further, was Joe Takahashi, who was at Northwestern, now at UT Southwestern in Dallas. And it is absolutely clear that humans do better on a diagonal schedule sorry, less than a nocturnal schedule because you could say, well, provided I sleep and push adenosine back downhill, which is what happens when we sleep. Adenosine is then reduced and provided I am on more or less a 24 hour schedule. Why should it matter that I'm awake when the sun's out and and I'm asleep when the sun is down? But it it turns out that if you look at health metrics, people that are strictly nocturnal do far worse on immune function or metabolic function, et cetera, than people who are diurnal, who are awake during the daytime and animals that are nocturnal.
    """
]

In [22]:
local_topic_model = BERTopic.load(os.path.join(INTERMEDIATE_DATA_DIR, 'bertopic_model_clips5_10_1_10'))
global_topic_model = BERTopic.load(os.path.join(INTERMEDIATE_DATA_DIR, 'bertopic_model_clips5_100_1_10'))

print('=> Fetching local topics')
local_topics, _ = local_topic_model.transform(clip_queries)
print('=> Fetching global topics')
global_topics, _ = global_topic_model.transform(clip_queries)

# topic+=1 so we don't ignore -1 noise topic!
# this pre-processing was done during training too
local_topics = [t + 1 for t in local_topics]
global_topics = [t + 1 for t in global_topics]

=> Fetching local topics


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

=> Fetching global topics


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [25]:
predictor = Predictor(args)
sampled_clips_from_same_topic = predictor.sample_clips_from_same_topic(local_topics, global_topics, args['num_to_sample'])
knn = predictor.predict(clip_queries, local_topics, global_topics, sampled_clips_from_same_topic)

<All keys matched successfully>
127 clips with same local topic
117 clips with same global topic


100%|██████████| 60/60 [04:14<00:00,  4.24s/it]


tensor([[26403, 19595, 35739]])


In [26]:
for i in range(len(knn)):
  print(f"Query: {clip_queries[i]}")

  for j, df_row in knn[i].iterrows():
      local_topic_name = df_row['topic_name_local']
      global_topic_name = df_row['topic_name_global']
      transcript = df_row['transcript']
      show_name, episode_name = df_row['show_name'], df_row['episode_name']
      episode_description = df_row['episode_description']
      
      print("="*80)
      print(f"{j}-NN:")
      print(f"Show name: {show_name}")
      print(f"Episode name: {episode_name}")
      print(f"Local topic: {local_topic_name}")
      print(f"Global topic: {global_topic_name}")
      print(f"Transcript: {transcript}")
      print("="*80)
      print()

Query: 
    Why do humans need sleep? Let's go with the big first question. OK, well, the answer I'll start with is the one that I always default to when there's a why question, because I wasn't consulted in the design phase. So, so, so I wriggle my way out of giving a absolute answer. Right. But there's one mechanism that's very clear that superimportant, which is that the longer we are awake, the more adenosine accumulates in our brain and adenosine binds to adenosine receptors. No surprise there. And it creates the feeling of sleepiness independent of time of day or night. So there are two mechanisms. One is we get sleepy. As adenosine accumulates, the longer we've been awake, the more adenosine has accumulated in our system, but how sleepy we get for a given amount of adenosine depends on where we are in this so-called circadian cycle. And the circadian cycle is just very, very well conserved oscillation. It's a temperature oscillation where you go from a low point. Typically, if y

## Example 1: True Crime
We first get the global and local topics for the new inputs, and then inspect the top-K retrieved clips.

In [48]:
clip_queries = [
    """
    At 10 p.m., Cancellara entered a convenience store across from the Sheraton. The clerk brewed him a fresh cup of coffee and watched as he returned to the hotel. It was the last time anyone saw Danny Cancellara alive. At nine a.m. the next morning, Casolaro, housekeeper in Virginia, received several threatening phone calls. One anonymous voice said, I will cut up his body and throw it to the sharks. A half hour later, another man called and simply said, Drop dead. The phone rang several more times that day, but the housekeeper said there was no one on the line, only music or silence. Back in Martinsburg, a Sheraton maid entered Casolaro hotel room at 150 p.m. She stepped into the bathroom and was horrified to find Casolaro lying naked in the bathtub dead. Within an hour, officers arrived to assess the scene. The coroner noted that Casolaro had several deep lacerations on his wrists. On the desk was a short note. It reportedly ended with the reassurance, God will let me in. There was no sign of a struggle based on this evidence, Kessler, whose death was ruled a suicide. However, there were some irregularities. Casolaro research notes, including the documents he'd gotten from Turner, were gone. Furthermore, the police report contained errors that couldn't be explained. It mentioned that there were plastic bags found in the tub with Casolaro, his body. There was also a used shoelace draped around his neck. As far as we can tell, these items weren't inspected or dusted for fingerprints and their presence was never explained. Even stranger, the coroner didn't follow the standard procedure of filtering the water as they drained the tub, though, a razor blade was recovered. Other important evidence may have been washed away. Authorities notified Kessler's family of his death two days later, on Monday, August 12th. They presented it as a clear cut case of suicide. The caseloads were shocked by the news. Tony recalled his brother's warning that if something were to happen to him, not to believe it was an accident. Haunted by this, he asked what the autopsy had shown. Incredibly, the police hadn't scheduled one. Tony was also perturbed to learn that none of Casolaro research was recovered from the hotel room, either. Tony demanded the police perform a thorough investigation and an autopsy. But Casolaro’s body had already been transferred to a funeral home and embalmed, meaning it was injected with preservatives. This made an accurate autopsy far more difficult. It also seemed like yet another suspicious clue. Before a body can be embalmed, the deceased's family has to consent for the procedure. So what happened to Casolaro body wasn't just unusual, it was actually illegal. As a result, his brother Tony demanded to know how this happened, but he was simply told it was a mistake, one that couldn't be explained. Despite this obstacle, the West Virginia medical examiner attempted an autopsy. He found the cause of death to be blood loss from multiple deep slashes on the reporter's wrists. This corroborated the conclusion of suicide. But according to Casolaro brother, the examiner also found two mysterious bruises on Casolaro Head. This contradicted the original police report, saying there was no sign of a struggle. Finally, the medical examiner found trace amounts of hydrocodone and an unidentifiable antidepressant in Casolaro system. Tony claimed his brother didn't have a prescription for painkillers or antidepressants due to these Incongruence sees the police returned to Kessler's room to search for additional evidence. Unfortunately, it had already been professionally cleaned in the aftermath, authorities maintain that the reporter's death was a suicide case. Casolaro body was brought back to Virginia at his funeral. The family was left with one final mystery before his casket was lowered into the ground.
    """
]

In [49]:
local_topic_model = BERTopic.load(os.path.join(INTERMEDIATE_DATA_DIR, 'bertopic_model_clips5_10_1_10'))
global_topic_model = BERTopic.load(os.path.join(INTERMEDIATE_DATA_DIR, 'bertopic_model_clips5_100_1_10'))

print('=> Fetching local topics')
local_topics, _ = local_topic_model.transform(clip_queries)
print('=> Fetching global topics')
global_topics, _ = global_topic_model.transform(clip_queries)

# topic+=1 so we don't ignore -1 noise topic!
# this pre-processing was done during training too
local_topics = [t + 1 for t in local_topics]
global_topics = [t + 1 for t in global_topics]

=> Fetching local topics


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

=> Fetching global topics


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [51]:
predictor = Predictor(args)
sampled_clips_from_same_topic = predictor.sample_clips_from_same_topic(local_topics, global_topics, args['num_to_sample'])
knn = predictor.predict(clip_queries, local_topics, global_topics, sampled_clips_from_same_topic)

<All keys matched successfully>
250 clips with same local topic
662 clips with same global topic


100%|██████████| 60/60 [04:18<00:00,  4.30s/it]

tensor([[ 2017, 16181,  2014]])





In [52]:
for i in range(len(knn)):
  print(f"Query: {clip_queries[i]}")

  for j, df_row in knn[i].iterrows():
      local_topic_name = df_row['topic_name_local']
      global_topic_name = df_row['topic_name_global']
      transcript = df_row['transcript']
      show_name, episode_name = df_row['show_name'], df_row['episode_name']
      episode_description = df_row['episode_description']
      
      print("="*80)
      print(f"{j}-NN:")
      print(f"Show name: {show_name}")
      print(f"Episode name: {episode_name}")
      print(f"Local topic: {local_topic_name}")
      print(f"Global topic: {global_topic_name}")
      print(f"Transcript: {transcript}")
      print("="*80)
      print()

Query: 
    
0-NN:
Show name: Not Guilty
Episode name: Insanity: John Hinckley Jr. Pt. 2
Local topic: 6_investigators_sutcliffe_detectives_murders
Global topic: 9_police_investigators_murders_detectives
Transcript: at a range and there was evidence of a potential thwarted attempt to find President Jimmy Carter in Nashville. Hinckley had also laid out his plans in several Love Letters to Jodie Foster this all directly. Did the defense's insanity claims Adelman explain to the jury that they would hear from several psychiatrists who would all insist that Hinckley was in complete control of his actions when he pulled the trigger. He was seeking an easy route to fame. He emphasized the fact that Hinckley committed his crime in front of hundreds of eyewitnesses and dozens of reporters and news cameras thousands more Across the Nation saw his Jen's and he wanted it that way to drive home that point Adelman played for the jury a videotape of the scene outside the Washington Hilton Hotel everyo

## Example 2: Sports
We first get the global and local topics for the new inputs, and then inspect the top-K retrieved clips.

In [31]:
clip_queries = [
    """
    Here we are. All right, quarterbacks. Running backs, this is our rookie overview show now we're doing this little earlier than we have done in years past. The anticipation just too strong. Got to put on the dynasty pants and start at the quarterback position. Starting with Trevor Lawrence, it would be. I hear he's good. Yeah, I would say it would be a shock heard around the world if Trevor Lawrence somehow didn't go number one. So we know where he's going. This is an advantage, right, in breaking down a player's future potential. Everybody knows who Trevor Lawrence is, number one recruit coming out of high school and delivered on all the promise. Right. Coming into the league. And this was tanking for Trevor two years ago for Miami. This is a player that is built like an NFL quarterback above average. Six to twenty. Yeah. Whoof, great athleticism. Great arm strength. Yes. Can make all the throws a mature quarterback, like of all the quarterbacks coming out, he's one that I think has displayed the most pro traits, most the most consistently. And he's got above average, you know, escape ability. I think one thing that was really interesting when looking at Trevor Lawrence and I'll get your reaction to it is that he had the highest percentage of attempts in the run, like the run pass options and in screen passes, which is one of those things that, you know, you can't blame a quarterback for what offense they're in in college. They have to just execute that offense. They have lost a few weapons. I mean, you had lost Higgins the year before and then they had some this their offense this year was more of a this is what we have to do. And Trevor Lawrence still succeeded right. In that offense. So I don't look at that as a point against Trevor Lawrence, his entire profile. When you watch his games, it's not Dwayne Haskins like he he had to modify this year, but he still can. He still goes down the field. Yeah. One of the things that puts Trevor Lawrence at everybody's number one spot is longevity. You don't have the single season. Like if I take Joe Burrow's final season and I put that against Trevor Lawrence's final season, Joe Burrow was better to me. I like I like the tape. I don't think Trevor Lawrence is a perfect prospect, you know, but nobody is is 100 percent perfect. What you like about Trevor Lawrence is that, as Andy alluded to, he came in as the number one recruit from high school and then he he improved every year. So you have so much tape as we're going to talk through these other quarterbacks. Not everybody has this long track record of success where he's dealt with the being the guy, he's dealt with the hype, he's dealt with the media, he's overcome everything. He's done everything you wanted to see in college. So he's pretty much everyone's number one pick. And and specifically on this show we're talking about for fantasy. And while he isn't going to be the most prolific rusher, he is mobile. He will add yards on the ground. I see him as a comp to a Justin Herbert or an Aaron Rodgers. That's kind of where I've got him. And I think he'll be great for fantasy, because when you have that capital in that hype, you guarantee he's got like if he sucks, if he comes in and flames out, you still have four years of being an NFL starter and producing something for fantasy. And he's not going to suck. So in rookie draft, because we're going through these guys based off of our our fantasy rookie rankings currently before we know landing places. Are you right now very confident that Jason, he would be the first quarterback you would take in a rookie pick, even though, like guys like Justin Fields and in Lance, who they have massive rushing upside, especially trade, Lance, you would still feel more comfortable going with? I would feels like the safer pick. Yes, because of the because of the safety. Because I do think Trevor Lawrence is probably a ten year starter and is just a you know, he really is. We've heard it over and over and over the best prospect since Andrew Luck. But he he is he's one of those rare guys that was touted super young and has come through every stage along the way. Yeah. n his size, physicality, the ability to run the football and give you a baseline every game like Andrew Luck is a good comp for him, too. And I guess the question then for for fantasy players, because I think we're entering a new era right. Where you can be a rookie quarterback and make a fantasy impact. Justin Herbert did a last year. Jubera did it at times last year. We saw Baker Mayfield do it at times in early season. Cuyler did it. This is no longer you're not hamstrung because a lot of the offenses in the NFL now are innovating. And he goes to Jacksonville. But you don't go to the same Doug Marone offense, right? You have right. You have a renovated scheme. You have some weapons that I think we all like there. Yeah. Lévesque National, Chaak, James Robinson, Marvin and Marvin Jones now. So how long before Lawrence threatens as a fantasy, um, contributor? Is it is it a year one expectation for you because he comes out with the capital? Or is this something that you are you're looking two or three years down the line? I think he's a year one streamer. That's that's how I view him. I don't think he's going to be a locked and loaded Justin Herbert. You've got to start him every week once he catches fire in the rookie year. But I think he's going to have a ton of pretty much more.
    """
]

In [32]:
local_topic_model = BERTopic.load(os.path.join(INTERMEDIATE_DATA_DIR, 'bertopic_model_clips5_10_1_10'))
global_topic_model = BERTopic.load(os.path.join(INTERMEDIATE_DATA_DIR, 'bertopic_model_clips5_100_1_10'))

print('=> Fetching local topics')
local_topics, _ = local_topic_model.transform(clip_queries)
print('=> Fetching global topics')
global_topics, _ = global_topic_model.transform(clip_queries)

# topic+=1 so we don't ignore -1 noise topic!
# this pre-processing was done during training too
local_topics = [t + 1 for t in local_topics]
global_topics = [t + 1 for t in global_topics]

=> Fetching local topics


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

=> Fetching global topics


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [33]:
predictor = Predictor(args)
sampled_clips_from_same_topic = predictor.sample_clips_from_same_topic(local_topics, global_topics, args['num_to_sample'])
knn = predictor.predict(clip_queries, local_topics, global_topics, sampled_clips_from_same_topic)

<All keys matched successfully>
596 clips with same local topic
22637 clips with same global topic


100%|██████████| 60/60 [04:23<00:00,  4.39s/it]

tensor([[38445, 24739, 27873]])





In [34]:
for i in range(len(knn)):
  print(f"Query: {clip_queries[i]}")

  for j, df_row in knn[i].iterrows():
      local_topic_name = df_row['topic_name_local']
      global_topic_name = df_row['topic_name_global']
      transcript = df_row['transcript']
      show_name, episode_name = df_row['show_name'], df_row['episode_name']
      episode_description = df_row['episode_description']
      
      print("="*80)
      print(f"{j}-NN:")
      print(f"Show name: {show_name}")
      print(f"Episode name: {episode_name}")
      print(f"Local topic: {local_topic_name}")
      print(f"Global topic: {global_topic_name}")
      print(f"Transcript: {transcript}")
      print("="*80)
      print()

Query: 
    Here we are. All right, quarterbacks. Running backs, this is our rookie overview show now we're doing this little earlier than we have done in years past. The anticipation just too strong. Got to put on the dynasty pants and start at the quarterback position. Starting with Trevor Lawrence, it would be. I hear he's good. Yeah, I would say it would be a shock heard around the world if Trevor Lawrence somehow didn't go number one. So we know where he's going. This is an advantage, right, in breaking down a player's future potential. Everybody knows who Trevor Lawrence is, number one recruit coming out of high school and delivered on all the promise. Right. Coming into the league. And this was tanking for Trevor two years ago for Miami. This is a player that is built like an NFL quarterback above average. Six to twenty. Yeah. Whoof, great athleticism. Great arm strength. Yes. Can make all the throws a mature quarterback, like of all the quarterbacks coming out, he's one that I 

# Running the sBERT Baseline

In this section, we run our sBERT baseline where we run sBERT on our clip transcripts and compute cosine similarity between clip node features to predict links.

In [54]:
test_mapping = pickle.load(open(os.path.join(DATA_DIR, "data_splits_seed_1234.pkl"), "rb"))['test']

edge_index = test_mapping['edge_label_index']
edge_label = test_mapping['edge_label']
edge_preds = torch.zeros(edge_label.shape)

model = SentenceTransformer("all-MiniLM-L6-v2")
if torch.cuda.is_available():
    model.to("cuda:0")

all_clips = pd.read_csv(os.path.join(INTERMEDIATE_DATA_DIR, "processed_clips_w_global_local_topics.csv"))
all_transcripts = list(all_clips['transcript'])
all_embeds = {}

for i in tqdm.tqdm(range(edge_index.shape[1])):
    n1, n2 = edge_index[0, i], edge_index[1, i]

    if n1 not in all_embeds:
        n1_clip_text = all_transcripts[n1]
        n1_embed = model.encode(n1_clip_text)
        all_embeds.update({n1: n1_embed})

    if n2 not in all_embeds:
        n2_clip_text = all_transcripts[n2]
        n2_embed = model.encode(n2_clip_text)
        all_embeds.update({n2: n2_embed})
        
    edge_preds[i] = max(0, cos_sim(all_embeds[n1], all_embeds[n2]))

100%|██████████| 120366/120366 [58:48<00:00, 34.11it/s]


In [55]:
baseline_test_auroc = compute_auroc(edge_label, edge_preds)
print(f"Baseline test AUROC: {baseline_test_auroc}")

Baseline test AUROC: 0.7566061440364726
