In [None]:
import wandb
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
def _get_summary(config):
    """Flatten a run configuration into the fields reported per run.

    Args:
        config: Mapping holding the run's configuration values.

    Returns:
        dict: The required hyper-parameters (some under a renamed key), the
        three optional transformer settings (the string "None" when a key is
        absent), and "dataset_names" (the configured value when it is a
        string, otherwise the placeholder "Multiple datasets.").

    Raises:
        KeyError: If any of the required configuration keys is missing.
    """
    # (output field, key in `config`) pairs; every one of these must exist.
    required_fields = (
        ("initial_lr", "init_lr"),
        ("mask_ratio", "mask_ratio"),
        ("d_model", "d_model"),
        ("train_batch_size", "train_batch_size"),
        ("patch_length", "patch_len"),
        ("torch_dtype", "torch_dtype"),
        ("train_ratio", "train_ratio"),
        ("val_ratio", "val_ratio"),
        ("test_ratio", "test_ratio"),
        ("random_seed", "random_seed"),
        ("model_name", "model_name"),
    )
    summary = {field: config[key] for field, key in required_fields}

    # Optional settings are reported as the literal string "None" when unset.
    optional_keys = (
        "transformer_type",
        "randomly_initialize_backbone",
        "transformer_backbone",
    )
    for key in optional_keys:
        summary[key] = config[key] if key in config else "None"

    # Only a single dataset given as a string is reported verbatim.
    dataset_names = config["dataset_names"] if "dataset_names" in config else None
    if isinstance(dataset_names, str):
        summary["dataset_names"] = dataset_names
    else:
        summary["dataset_names"] = "Multiple datasets."
    return summary

def _get_metadata_summary(metadata):
    """Pick the execution details of a run out of its metadata mapping.

    Args:
        metadata: Mapping with the run's metadata; must contain a nested
            "git" mapping with a "commit" entry.

    Returns:
        dict: Host, GPU, program, arguments, user, start time, git commit and
        code path of the run. "code_path" is the string "None" when the
        metadata has no "codePath" entry.

    Raises:
        KeyError: If a required metadata key is missing.
    """
    # Output field -> key under which the value is stored in `metadata`.
    field_to_key = {
        "hostname": "host",
        "gpu_type": "gpu",
        "program": "program",
        "args": "args",
        "username": "username",
        "execution_start_time": "startedAt",
    }
    meta = {field: metadata[key] for field, key in field_to_key.items()}
    meta["git_commit"] = metadata["git"]["commit"]
    meta["code_path"] = metadata["codePath"] if "codePath" in metadata else "None"
    return meta

In [None]:
# Collect one summary row per W&B run of the project and store the table as CSV.
api = wandb.Api()

runs = api.runs("timeseries-foundation-model/Time-series Foundation Model")

summary_list = []
for run in tqdm(runs, total=len(runs)):
    # Runs whose config has no `dataset_names` entry are left out of the table.
    if "dataset_names" not in run.config:
        continue

    try:
        run_summary = {
            **_get_summary(run.config),
            **_get_metadata_summary(run.metadata),
        }
    except Exception as error:
        # Skip the run entirely. Falling through here would re-use the values
        # of the previously processed run (or fail with a NameError on the
        # first iteration) and record them under this run's name and id.
        print(f"Failed to get summary for {run.name} ({error!r})")
        continue

    run_summary["run_name"] = run.name
    run_summary["run_id"] = run.id
    run_summary["notes"] = run.notes
    summary_list.append(run_summary)

# Column order of the exported table; `run_id` becomes the index.
columns = ['run_name', 'transformer_backbone', 'd_model',
           'model_name', 'transformer_type',  'randomly_initialize_backbone',
           'patch_length', 'random_seed', 'initial_lr',
           'mask_ratio', 'train_batch_size', 'torch_dtype',
           'train_ratio', 'val_ratio', 'test_ratio',
           'hostname', 'git_commit', 'notes', 'gpu_type',
           'program', 'code_path', 'args', 'username',
           'execution_start_time', 'dataset_names']

summary = pd.DataFrame(summary_list).set_index("run_id")[columns]
summary.to_csv("../../assets/data/wandb_runs_summary.csv")

In [None]:
# Preview the first rows of the per-run summary table.
summary.head()

In [None]:
# Names of the runs to look up in the summary table.
pretraining_runs = [
    # Old runs
    "fearless-planet-52",
    # New pre-training runs
    "charmed-bee-241",
    "zesty-frost-243",
    "hopeful-water-240",
    # Ablation experiments
    "Pre-training - small - patch 4",
    "Pre-training - small - patch 8, pos emb False",
]

# Show only the rows whose run name is one of the names listed above.
is_pretraining_run = summary["run_name"].isin(pretraining_runs)
summary.loc[is_pretraining_run]

In [None]:
import os

### Pre-training runs
# W&B run paths have the form "entity/project/run_id" and always use forward
# slashes, so they are assembled with str.join instead of os.path.join (which
# would insert backslashes on Windows).
# NOTE: `run` is the loop variable left over from the cell that builds the
# summary table; only its entity and project are read here.
project_path = "/".join([run.entity, run.project])


def _fetch_run(run_id):
    """Return the run with id `run_id` from the project in `project_path`."""
    return api.run(path="/".join([project_path, run_id]))


small_run = _fetch_run("xxrzrfrd") # Small run
base_run = _fetch_run("ligbv4fg") # Base run
large_run = _fetch_run("0diyru56") # Large run

# old_large_run = _fetch_run("new8a30j") # Old large run

# Ablations
small_patch4_run = _fetch_run("wf046knr") # Small run with patch_len = 4
small_flant5_run = _fetch_run("hvyei2nf") # Small run with Flan T5 initialization

In [None]:
# Get step losses
def _step_loss_history(wandb_run):
    """Return the run's logged `_step` / `step_train_loss` rows as a DataFrame."""
    return pd.DataFrame(wandb_run.scan_history(keys=['_step', 'step_train_loss']))


# Model family
small_vals = _step_loss_history(small_run)
base_vals = _step_loss_history(base_run)
large_vals = _step_loss_history(large_run)

# old_large_vals = _step_loss_history(old_large_run)

# Ablations
small_patch4_vals = _step_loss_history(small_patch4_run)
small_flant5_vals = _step_loss_history(small_flant5_run)

In [None]:
# Span of the exponentially weighted moving average applied to each loss curve.
span = 500


def _smoothed_train_loss(history):
    """Return the EWM mean (with `span`) of `step_train_loss` as a NumPy array."""
    return history["step_train_loss"].ewm(span=span).mean().to_numpy()


# Model family
small_losses = _smoothed_train_loss(small_vals)
base_losses = _smoothed_train_loss(base_vals)
large_losses = _smoothed_train_loss(large_vals)

# old_large_losses = _smoothed_train_loss(old_large_vals)

# Ablations
small_patch4_losses = _smoothed_train_loss(small_patch4_vals)
small_flant5_losses = _smoothed_train_loss(small_flant5_vals)

In [None]:
# Scaling used to express the x-axis in billions of processed patches.
billion = 1e9
num_patches = 64*2048 # NOTE: x-axis is only approximate, since some patches may not contain any information. Also patches are repeated
# Position of one epoch on the x-axis.
# NOTE(review): 38374 is presumably the number of steps per epoch -- confirm.
steps_in_one_epoch = num_patches*38374/billion

# One smoothed loss curve per model size.
for losses, color, label in (
    (small_losses, 'red', "  40M Small"),
    (base_losses, 'blue', "125M Base"),
    (large_losses, 'green', "385M Large"),
):
    plt.plot(num_patches*np.arange(1, len(losses)+1)/billion,
             losses, c=color, label=label)

# Epochs
for epoch in (1, 2, 3):
    plt.axvline(x=epoch*steps_in_one_epoch, c='black', linestyle='--', linewidth=0.5)

# Old run
# plt.plot(num_patches*np.arange(1, len(old_large_losses)+1)/billion,
#          old_large_losses, c='pink', label="385M Large (O)")

# Light dashed background grid.
plt.grid(color='lightgray', linestyle='--', linewidth=0.5)

plt.ylim(0.06, 0.2)
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
plt.xlabel("Processed patches (in Billions)", fontsize=20)
plt.ylabel("Training Loss", fontsize=20)
# Only the first two epochs are shown.
plt.xlim(0, 2*steps_in_one_epoch)
plt.legend(fontsize=20)
plt.savefig("../../assets/figures/pretraining/moment_family_training_losses.pdf")
plt.show()

In [None]:
# Smoothed loss of the small model for the two backbone initializations.
for losses, color, label in (
    (small_flant5_losses, 'red', "Small (Flan-T5)"),
    (small_losses, 'blue', "Small (Random)"),
):
    plt.plot(num_patches*np.arange(1, len(losses)+1)/billion,
             losses, c=color, label=label)

# Epochs
for epoch in (1, 2, 3):
    plt.axvline(x=epoch*steps_in_one_epoch, c='black', linestyle='--', linewidth=0.5)

plt.ylim(0.08, 0.2)
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
plt.grid(color='lightgray', linestyle='--', linewidth=0.5)
plt.xlabel("Processed patches (in Billions)", fontsize=20)
plt.ylabel("Training Loss", fontsize=20)
# Only the first two epochs are shown.
plt.xlim(0, 2*steps_in_one_epoch)
plt.legend(fontsize=20)
plt.savefig("../../assets/figures/pretraining/random_vs_flant5_initialization.pdf")
plt.show()

In [None]:
# Same comparison as the linear-scale figure, drawn on a logarithmic x-axis.
for losses, color, label in (
    (small_flant5_losses, 'red', "Small (Flan-T5)"),
    (small_losses, 'blue', "Small (Random)"),
):
    plt.plot(num_patches*np.arange(1, len(losses)+1)/billion,
             losses, c=color, label=label)

# Epochs
for epoch in (1, 2, 3):
    plt.axvline(x=epoch*steps_in_one_epoch, c='black', linestyle='--', linewidth=0.5)

plt.xscale("log")
plt.xlim(1e-3, None)
plt.ylim(0.06, 2)
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
plt.grid(color='lightgray', linestyle='--', linewidth=0.5)
plt.xlabel("Processed patches (in Billions)", fontsize=20)
plt.ylabel("Training Loss", fontsize=20)
# Cap the axis at two epochs. Only the right limit is set: a left limit of 0
# is not valid on a log-scaled axis (matplotlib warns and ignores it), so the
# left limit stays at the 1e-3 chosen above.
plt.xlim(right=2*steps_in_one_epoch)
plt.legend(fontsize=20)
plt.savefig("../../assets/figures/pretraining/random_vs_flant5_initialization_log.pdf")
plt.show()

In [None]:
# Smoothed loss of the small model for the two patch lengths.
for losses, color, label in (
    (small_patch4_losses, 'red', "Small (Patch 4)"),
    (small_losses, 'blue', "Small (Patch 8)"),
):
    plt.plot(num_patches*np.arange(1, len(losses)+1)/billion,
             losses, c=color, label=label)

# Epochs
for epoch in (1, 2, 3):
    plt.axvline(x=epoch*steps_in_one_epoch, c='black', linestyle='--', linewidth=0.5)

plt.ylim(0.07, 0.2)
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
plt.grid(color='lightgray', linestyle='--', linewidth=0.5)
plt.xlabel("Processed patches (in Billions)", fontsize=20)
plt.ylabel("Training Loss", fontsize=20)
# Only the first two epochs are shown.
plt.xlim(0, 2*steps_in_one_epoch)
plt.legend(fontsize=20)
plt.savefig("../../assets/figures/pretraining/patch_len_4_vs_8.pdf")
plt.show()

### Radar Plot

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

In [None]:
# Axes of the radar chart. Only the number of entries is used below; the axis
# captions are placed by hand with plt.text.
# NOTE(review): "Short-horizon Forecasting" and "Classification" appear in the
# opposite order in the hand-placed captions further down -- confirm which
# order matches the values in `model_values`.
categories = [
    'Imputation',
    "Long-horizon \nForecasting",
    'Short-horizon \nForecasting',
    'Classification',
    'Anomaly\nDetection',
]

# One value per axis and model, in axis order.
model_values = {
    'MOMENT':     [
        0.9929824561,
        0.9958246347,
        0.8849206349,
        0.236097561,
        0.9035087719],
    'GPT4TS':   [
        1,
        0.9826026444,
        0.4345238095,
        0.3502439024,
        0],
    'TimesNet': [
        0.9912280702,
        0.9485038274,
        0.4464285714,
        0.1882926829,
        0.798245614],
}
# Line color and line style of each model, in the order of `model_values`.
colors = [
    'orangered',
    'mediumblue',
    'limegreen',
]

linestyles = ['solid', '--', ':']

# Fail early rather than silently dropping a model without a color or style.
assert len(colors) >= len(model_values), "need one color per model"
assert len(linestyles) >= len(model_values), "need one linestyle per model"

# Evenly spaced angles, with the first one repeated to close the polygon.
N = len(categories)
angles = [n / float(N) * 2 * np.pi for n in range(N)]
angles += angles[:1]

fig, ax = plt.subplots(subplot_kw=dict(polar=True))
ax.set_rlabel_position(0)
# Keep the tick positions but hide the built-in tick labels.
plt.xticks(angles[:-1], [])
plt.yticks([0.0, 0.33, 0.67, 1.0], [])
ax.tick_params(axis='x', which='major', pad=16, labelsize=16)
plt.ylim(0, 1.0)

# Hand-placed axis captions, given as (angle, radius) positions.
plt.text(0, 1.35, 'Imputation', fontsize=16, ha='center')
plt.text(1, 1.25, "Long-horizon Forecasting", fontsize=16, ha='center')
plt.text(2.7, 1.35, "Classification", fontsize=16, ha='center')
plt.text(3.66, 1.37, "Short-horizon\nForecasting", fontsize=16, ha='center')
plt.text(5.25, 1.3, "Anomaly Detection", fontsize=16, ha='center')

# zip pairs each model with its color and style without consuming the lists.
for (model_name, values), color, linestyle in zip(model_values.items(), colors, linestyles):
    values_with_first = values + values[:1]  # close the circle
    ax.plot(angles, values_with_first, label=model_name, color=color, linewidth=1, linestyle=linestyle)
    ax.fill(angles, values_with_first, alpha=0.5, color=color)

plt.grid(color='lightgray', linestyle='--', linewidth=1)
legend = plt.legend(fontsize=16, ncol=3, loc='lower center', bbox_to_anchor=(0.5, -0.25))
# get the lines inside legend box
legend_lines = legend.get_lines()
# bulk-set the line width of all legend lines
plt.setp(legend_lines, linewidth=3)
plt.savefig("../../assets/figures/pretraining/model_comparison.pdf", bbox_inches='tight')
plt.show()