In [None]:
# Copyright 2022 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
import dataclasses
import os
import tensorflow.compat.v1 as tf
import itertools
import json
from typing import Collection, Mapping

from IPython import display
import numpy as np
import pandas as pd

In [None]:
# Sub-folder of <PROJECT_ROOT>/experiments that holds this run's results.
SUBDIR = 'your_subdir'
# Root of the project checkout; replace the placeholder before running.
PROJECT_ROOT = '/your/project/folder/here'

In [None]:
# Model names; used in result folder names and as table row labels.
MODELS = ['dyrep', 'edgebank', 'jodie', 'tgn', 'tgn_structmap']

# One entry per dataset, formatted '<dataset>;<train split>;<val split>;<test split>'.
DATA = [
    'tgbl_wiki;cc-subgraph;cc-subgraph;cc-subgraph',
    'tgbl_review;cc-subgraph;cc-subgraph;cc-subgraph',
    'tgbl_flight;AS;AS;AF',
    'tgbl_comment;cc-subgraph;cc-subgraph;cc-subgraph',
    'tgbl_coin;cc-subgraph;cc-subgraph;cc-subgraph',
]

# Experiment names; each one prefixes the result files read from disk.
EXPERIMENTS = ['transductive', 'transfer_no_warmstart', 'transfer_warmstart']

In [None]:
# Result containers, plus parsing of DATA into DatasetSpecs.
@dataclasses.dataclass
class DatasetSpec:
  """A dataset name together with its train, val and test split names."""
  # Dataset identifier, e.g. 'tgbl_wiki'.
  dataset: str = ''
  # Split identifiers as written in DATA, e.g. 'cc-subgraph' or 'AS'.
  train_split: str = ''
  val_split: str = ''
  test_split: str = ''

@dataclasses.dataclass
class ExperimentResults:
  """Summary metrics and per-batch metric tables for one experiment run."""
  # Experiment name, e.g. 'transfer_warmstart'.
  experiment: str = ''
  # Summary metrics (parsed JSON) for the train/val and test phases.
  train_results: dict[str, float] = dataclasses.field(default_factory=dict)
  test_results: dict[str, float] = dataclasses.field(default_factory=dict)
  # Per-batch metric tables; they stay empty DataFrames unless filled in.
  val_warmstart_loss_metrics: pd.DataFrame = dataclasses.field(
      default_factory=pd.DataFrame)
  val_loss_metrics: pd.DataFrame = dataclasses.field(
      default_factory=pd.DataFrame)
  test_warmstart_loss_metrics: pd.DataFrame = dataclasses.field(
      default_factory=pd.DataFrame)
  test_loss_metrics: pd.DataFrame = dataclasses.field(
      default_factory=pd.DataFrame)

@dataclasses.dataclass
class ModelResults:
  """One model's ExperimentResults, keyed by experiment name."""
  model: str = ''
  experiment_results: dict[str, ExperimentResults] = dataclasses.field(
      default_factory=dict)

@dataclasses.dataclass
class DatasetResults:
  """One dataset's ModelResults, keyed by model name."""
  dataset: str = ''
  model_results: dict[str, ModelResults] = dataclasses.field(
      default_factory=dict)

DATASET_SPECS = []
DATASETS = []
for dataset_string in DATA:
  dataset, train_split, val_split, test_split = dataset_string.split(';')
  DATASETS.append(dataset)
  DATASET_SPECS.append(DatasetSpec(dataset, train_split, val_split, test_split))

RESULTS_SUBDIR = os.path.join(PROJECT_ROOT, 'experiments', SUBDIR)

In [None]:
def _empty_results_table() -> pd.DataFrame:
  """Returns an all-NaN table indexed (Model, Experiment) x (Dataset, Metric)."""
  return pd.DataFrame(
      index=pd.MultiIndex.from_product(
          [MODELS, EXPERIMENTS],
          names=['Model', 'Experiment']
      ),
      columns=pd.MultiIndex.from_product(
          [DATASETS, ['auc', 'mrr']],
          names=['Dataset', 'Metric']
      )
  )


def _load_json(folder: str, filename: str):
  """Reads `folder/filename` via tf.io.gfile and parses it as JSON."""
  with tf.io.gfile.GFile(os.path.join(folder, filename), 'r') as f:
    return json.load(f)


def _load_csv(folder: str, filename: str) -> pd.DataFrame:
  """Reads `folder/filename` via tf.io.gfile into a DataFrame."""
  with tf.io.gfile.GFile(os.path.join(folder, filename), 'r') as f:
    return pd.read_csv(f)


ALL_RESULTS = {}
train_results_df = _empty_results_table()
test_results_df = _empty_results_table()
for dataset_spec in DATASET_SPECS:
  dataset_results = DatasetResults(dataset=dataset_spec.dataset)
  for model in MODELS:
    model_results = ModelResults(model=model)
    model_dataset_folder = os.path.join(
        RESULTS_SUBDIR,
        dataset_spec.dataset, 'results',
        f'{model}_{dataset_spec.dataset}_{dataset_spec.train_split}_{dataset_spec.val_split}'
    )
    for experiment in EXPERIMENTS:
      # Experiments whose name lacks 'no_warmstart' also have warmstart CSVs.
      has_warmstart = 'no_warmstart' not in experiment
      experiment_results = ExperimentResults(experiment=experiment)
      # Full row / column keys for the results tables.
      row_key = (model, experiment)

      # Extract results for train.
      experiment_results.train_results = _load_json(
          model_dataset_folder, f'{experiment}_results_train.json')
      experiment_results.val_loss_metrics = _load_csv(
          model_dataset_folder, f'{experiment}_val_loss.csv')
      if has_warmstart:
        experiment_results.val_warmstart_loss_metrics = _load_csv(
            model_dataset_folder, f'{experiment}_val_warmstart_loss.csv')
      # Assign with a single .loc call. The chained form
      # `df.loc[row].at[col] = v` writes into a temporary row Series and is
      # silently lost under pandas Copy-on-Write.
      train_results_df.loc[row_key, (dataset_spec.dataset, 'auc')] = (
          experiment_results.train_results['auc'])
      train_results_df.loc[row_key, (dataset_spec.dataset, 'mrr')] = (
          experiment_results.train_results['val mrr'])

      # Extract results for test.
      # TODO: save test_split-specific results in their own subfolder of the train/val split folder.
      experiment_results.test_results = _load_json(
          model_dataset_folder,
          f'{experiment}_results_test_{dataset_spec.test_split}.json')
      experiment_results.test_loss_metrics = _load_csv(
          model_dataset_folder, f'{experiment}_test_loss.csv')
      if has_warmstart:
        experiment_results.test_warmstart_loss_metrics = _load_csv(
            model_dataset_folder, f'{experiment}_test_warmstart_loss.csv')
      # TODO: align metric variable names across train and test.
      test_results_df.loc[row_key, (dataset_spec.dataset, 'auc')] = (
          experiment_results.test_results['test auc'])
      test_results_df.loc[row_key, (dataset_spec.dataset, 'mrr')] = (
          experiment_results.test_results['test mrr'])

      model_results.experiment_results[experiment] = experiment_results

    dataset_results.model_results[model] = model_results
  ALL_RESULTS[dataset_spec.dataset] = dataset_results

In [None]:
import copy
import seaborn as sns
from matplotlib import pyplot as plt

def _melt_single_metric(metrics_df, start_index, metric_name, experiment, period):
  """Returns one metric column of `metrics_df` as a long-format plot frame.

  `metrics_df` is not modified: the batch index is added to a copy via
  `assign`, so DataFrames cached in the results dict stay untouched.

  Args:
    metrics_df: per-batch metrics table containing a `metric_name` column.
    start_index: x-axis (batch_index) value given to the first row.
    metric_name: column to extract.
    experiment: value written to the 'experiment' column.
    period: value written to the 'period' column ('warmstart' or 'eval').

  Returns:
    DataFrame with columns batch_index, metric_name, metric_value,
    experiment, period.
  """
  long_df = metrics_df.assign(
      batch_index=list(range(start_index, start_index + len(metrics_df)))
  ).melt(
      id_vars=['batch_index'],
      value_vars=[metric_name],
      value_name='metric_value',
      var_name='metric_name',
  )
  long_df['experiment'] = experiment
  long_df['period'] = period
  return long_df


def plot_eval_metric_curves(
    dataset='tgbl_wiki',
    experiments=None,
    model='tgn',
    val=True,
    metric_name='perf',
    master_results_dict=None,
    ax=None,
    show_legend=True,
):
  """Plots per-batch curves of one metric for one (dataset, model) pair.

  Experiments whose name lacks 'no_warmstart' contribute a 'warmstart' curve
  starting at batch 0. Every experiment's 'eval' curve starts right after the
  lead warmstart experiment's warmstart period, so all curves share one x-axis.

  Args:
    dataset: key into `master_results_dict`.
    experiments: experiment names to plot; defaults to EXPERIMENTS. The
      caller's sequence is never modified.
    model: model whose results are plotted.
    val: plot validation metrics if truthy, otherwise test metrics.
    metric_name: metrics-CSV column to plot (e.g. 'loss', 'perf', 'auc').
    master_results_dict: dataset name -> DatasetResults; defaults to the
      current ALL_RESULTS, looked up at call time (not definition time).
    ax: matplotlib Axes to draw on. If None, a new figure is created and shown.
    show_legend: forwarded to seaborn's `legend` argument.

  Returns:
    The long-format DataFrame that was plotted, whether or not `ax` is given.

  Raises:
    ValueError: if `experiments` is empty.
    KeyError: if `metric_name` is not a column of the metrics tables.
  """
  if experiments is None:
    experiments = EXPERIMENTS
  if master_results_dict is None:
    master_results_dict = ALL_RESULTS

  eval_df_string = 'val' if val else 'test'

  # Work on a list copy so global EXPERIMENTS is never reordered and tuples or
  # other iterables are accepted.
  experiment_list = list(experiments)
  if not experiment_list:
    raise ValueError('`experiments` must contain at least one experiment.')

  # Make sure that if there is a warmstart experiment, it comes first: its
  # warmstart length sets where every eval curve starts. As before, the last
  # such experiment in `experiments` is moved to the front. With no warmstart
  # experiment at all, the order is kept and eval curves start at 0.
  warmstart_experiments = [
      name for name in experiment_list if 'no_warmstart' not in name
  ]
  if warmstart_experiments:
    lead_experiment = warmstart_experiments[-1]
    experiment_list.remove(lead_experiment)
    experiment_list.insert(0, lead_experiment)

  results_by_experiment = (
      master_results_dict[dataset].model_results[model].experiment_results
  )
  plot_dataframes = []
  warmstart_end_index = 0
  for idx, experiment in enumerate(experiment_list):
    results = results_by_experiment[experiment]
    if 'no_warmstart' not in experiment:
      warmstart_df = getattr(results, f'{eval_df_string}_warmstart_loss_metrics')
      if idx == 0:
        warmstart_end_index = len(warmstart_df)
      # Each warmstart curve is indexed by its own length, so a length
      # mismatch with the lead experiment no longer raises.
      plot_dataframes.append(_melt_single_metric(
          warmstart_df, 0, metric_name, experiment, 'warmstart'))

    eval_df = getattr(results, f'{eval_df_string}_loss_metrics')
    plot_dataframes.append(_melt_single_metric(
        eval_df, warmstart_end_index, metric_name, experiment, 'eval'))

  # Make plot.
  master_plot_dataframe = pd.concat(plot_dataframes, axis=0)
  lineplot_kwargs = dict(
      data=master_plot_dataframe,
      x='batch_index',
      y='metric_value',
      hue='experiment',
      style='period',
      legend=show_legend,
  )

  if ax is None:
    plt.figure(figsize=(15, 10))
    plt.title(f'{dataset} {eval_df_string} {model} {metric_name}')
    sns.lineplot(**lineplot_kwargs)
    plt.show()
  else:
    ax.set_title('validation' if eval_df_string == 'val' else 'test')
    sns.lineplot(ax=ax, **lineplot_kwargs)
  return master_plot_dataframe

In [None]:
# For every (model, dataset) pair, show the validation and then the test
# curves under markdown headers, followed by the plotted data.
for model in MODELS:
  display.display(display.Markdown(f'# {model}'))

  for dataset in DATASETS:
    display.display(display.Markdown(f'## {dataset}'))

    for val, split_label in ((True, 'validation'), (False, 'test')):
      display.display(display.Markdown(f'### {split_label}'))
      plotted_df = plot_eval_metric_curves(
          model=model,
          dataset=dataset,
          val=val,
      )
      display.display(plotted_df)

In [None]:
def plot_side_by_side_eval_metric_curves(
    *,
    dataset: str,
    experiments: Collection[str] | None = None,
    models: Collection[str] | None = None,
    metric: str = 'perf',
    results: Mapping[str, DatasetResults] | None = None,
):
  """Draws validation and test curves side by side, one row per model.

  Each row is a subfigure holding two axes with a shared y-axis: validation
  on the left, test on the right. Both are filled by plot_eval_metric_curves,
  and only the top-left axes draws a legend.

  Args:
    dataset: dataset whose results are plotted.
    experiments: experiment names; EXPERIMENTS is used when empty or None.
    models: model names, one row each; MODELS is used when empty or None.
    metric: metric column to plot; 'perf' is labelled 'mrr' on the y-axis.
    results: dataset name -> DatasetResults; ALL_RESULTS when empty or None.
  """
  if not experiments:
    experiments = EXPERIMENTS
  if not models:
    models = MODELS
  if not results:
    results = ALL_RESULTS
  # A `Collection` need not support indexing, and plot_eval_metric_curves
  # expects a list of experiments, so materialize both as lists.
  experiments = list(experiments)
  models = list(models)

  fig = plt.figure(
      figsize=(8, len(models) * 3),
      constrained_layout=True,
  )

  # squeeze=False always returns an (nrows, 1) array. The default squeeze=True
  # returns a bare, non-iterable SubFigure when there is only one model.
  subfigs = fig.subfigures(nrows=len(models), ncols=1, squeeze=False)[:, 0]
  ylabel = f'metric={"mrr" if metric == "perf" else metric}'

  for idx, (model, subfig) in enumerate(zip(models, subfigs)):
    subfig.suptitle(f'{model=}', fontsize=15, y=1.04)

    ax = subfig.subplots(nrows=1, ncols=2, sharey='row')
    for jdx, val in enumerate([True, False]):
      plot_eval_metric_curves(
          dataset=dataset,
          experiments=experiments,
          model=model,
          val=val,
          metric_name=metric,
          master_results_dict=results,
          ax=ax[jdx],
          show_legend=(idx == jdx == 0)
      )
      ax[jdx].set_ylabel(ylabel)

  fig.suptitle(f'{dataset=}', fontsize=20, y=1.04)
  plt.show()

In [None]:
# One side-by-side figure per dataset: a row per model, with validation
# curves on the left and test curves on the right.
for dataset in DATASETS:
  plot_side_by_side_eval_metric_curves(
      dataset=dataset,
  )