In [None]:
# @title Imports

import os
import tempfile
import warnings

import matplotlib as mpl
from matplotlib import font_manager
from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D
from google3.learning.deepmind.xmanager2.client import xmanager_api
from google3.pyglib import gfile
from google3.pyglib.function_utils import memoize


# google3-relative location of the Google Sans fonts. It is joined onto the
# '/google_src/head/depot' prefix in import_default_google_fonts() below.
_GOOGLE_SANS_PATH = (
    'google3/third_party/googlefonts/api/googlerestricted/googlesans/'
)


@memoize.Memoize()
def import_google3_fonts(font_path: str) -> None:
  """Copy fonts out of google3 and register them with Matplotlib.

  The fonts are first copied to a local temporary location and then added to
  Matplotlib's default font manager.

  Args:
    font_path: Either a directory containing .ttf fonts or a single .ttf file.
      Anything that is not a directory is treated as a font file.
  """
  if not gfile.IsDirectory(font_path):
    # Reserve a unique local '.ttf' name. Leaving the `with` block closes the
    # placeholder, so the copy below is what actually creates the file.
    with tempfile.NamedTemporaryFile(suffix='.ttf') as placeholder:
      local_font = placeholder.name
    gfile.Copy(font_path, local_font)
    discovered_fonts = [local_font]
  else:
    # Mirror the whole directory locally, then let Matplotlib find the fonts.
    staging_dir = tempfile.mkdtemp()
    gfile.RecursivelyCopyDir(font_path, staging_dir, overwrite=True)
    discovered_fonts = font_manager.findSystemFonts(fontpaths=staging_dir)

  # Register every discovered font with the default font manager.
  for discovered_font in discovered_fonts:
    font_manager.fontManager.addfont(discovered_font)


def import_default_google_fonts() -> None:
  """Register the Google Sans font family with Matplotlib."""
  # google3 paths are resolved against the head depot under /google_src.
  google_sans_dir = os.path.join('/google_src/head/depot', _GOOGLE_SANS_PATH)
  import_google3_fonts(google_sans_dir)


# Make the Google fonts known to Matplotlib before any rcParams refer to them.
import_default_google_fonts()

# Never truncate DataFrames when they are displayed.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

# Silence UserWarnings raised from matplotlib modules.
warnings.filterwarnings('ignore', category=UserWarning, module='matplotlib')

xm_client = xmanager_api.XManagerApi(xm_deployment_env='alphabet')

# A single base size drives every text element; only the legend is smaller.
MEDIUM_SIZE = 14
mpl.rcParams.update({
    # Text sizes.
    'font.size': MEDIUM_SIZE,
    'axes.titlesize': MEDIUM_SIZE,
    'axes.labelsize': MEDIUM_SIZE,
    'xtick.labelsize': MEDIUM_SIZE,
    'ytick.labelsize': MEDIUM_SIZE,
    'legend.fontsize': MEDIUM_SIZE - 5,
    'figure.titlesize': MEDIUM_SIZE,
    # Axes frame.
    'axes.linewidth': 1,
    'axes.edgecolor': '#777777',
    'axes.facecolor': '#FFFFFF',
    # Typeface registered by import_default_google_fonts().
    'font.family': 'Google Sans',
})

elegant_palette = sns.color_palette('muted')


def log_float_formatter(y, pos):
  """Render a tick value as fixed-point text with two decimals.

  Args:
    y: Numeric tick value.
    pos: Tick position; accepted but not used.

  Returns:
    The value formatted with exactly two digits after the decimal point.
  """
  del pos  # Unused.
  return format(y, '.2f')

# Two-tone blue palette (dark, light) read by the figure cells below.
DARKER_BLUE, LIGHTER_BLUE = '#3182BD', '#9ECAE1'

In [None]:
# @title Few-Shot Classification

# F1 per (model, shot count). In every score list the first five values belong
# to "LSMV2" and the last five to "LSMv1", matching the "Model" column.
shot_counts = [25, 50, 75, 100, 500]
data = {
    "Model": ["LSMV2"] * 5 + ["LSMv1"] * 5,
    "Shots": shot_counts * 2,
    "Hypertension_F1": [
        0.43332, 0.49379, 0.53199, 0.57191, 0.60128,  # LSMV2
        0.4341, 0.51284, 0.51375, 0.55638, 0.58777,  # LSMv1
    ],
    "Anxiety_F1": [
        0.55896, 0.58799, 0.59982, 0.61486, 0.63638,  # LSMV2
        0.55654, 0.56636, 0.56299, 0.56371, 0.62072,  # LSMv1
    ],
}

# Wide -> long: one row per (model, shots, task).
df = pd.DataFrame(data)
df_long = df.melt(
    id_vars=["Model", "Shots"],
    value_vars=["Hypertension_F1", "Anxiety_F1"],
    var_name="Task",
    value_name="F1",
)
df_long["Task"] = df_long["Task"].str.replace("_F1", "")
# Shot counts become strings so the x-axis is categorical.
df_long["Shots"] = df_long["Shots"].astype(str)

# Marker encodes the model, colour encodes the task.
marker_dict = {"LSMv1": "o", "LSMV2": "s"}
task_palette = {"Hypertension": "#3182BD", "Anxiety": "#9ECAE1"}

sns.set(style="white", font_scale=1.2)
fig, ax = plt.subplots(figsize=(5, 5))

# One solid line per (task, model) pair.
for (task, model), group in df_long.groupby(["Task", "Model"]):
  sns.lineplot(
      data=group,
      x="Shots",
      y="F1",
      ax=ax,
      color=task_palette[task],
      marker=marker_dict[model],
      markersize=9,
      linewidth=1.5,
      linestyle="-",
  )

# Two separate legends: line colour for the task, black marker for the model.
task_handles = [
    Line2D(
        [0],
        [0],
        color=task_palette[task_name],
        lw=2.5,
        linestyle="-",
        label=task_name,
    )
    for task_name in ("Hypertension", "Anxiety")
]
model_handles = [
    Line2D(
        [0],
        [0],
        color="black",
        marker=marker_shape,
        linestyle="",
        markersize=9,
        label=model_label,
    )
    for marker_shape, model_label in (("s", "Ours"), ("o", "LSM"))
]

task_legend = ax.legend(
    handles=task_handles,
    title="",
    loc="lower center",
    bbox_to_anchor=(0.45, -0.03),
    ncol=1,
    frameon=False,
)
# Re-attach the first legend; the second legend() call would replace it.
ax.add_artist(task_legend)
ax.legend(
    handles=model_handles,
    title="",
    loc="lower right",
    bbox_to_anchor=(1.0, -0.03),
    frameon=False,
)

ax.set_title("Few Shot Performance on Discriminative Tasks")
ax.set_xlabel("Shots")
ax.set_ylabel("F1 Score")
sns.despine()
plt.tight_layout()
plt.show()

In [None]:
#@title Missing Breakdown - Sensor Imputation

# Missingness levels in percent, used as categorical x positions.
missing_levels = ["30", "40", "50", "60", "70", "80"]

# MSE at each missingness level, grouped by imputation task and then by model.
data = {
    "Random Imputation": {
        "LSM v2": [0.24638, 0.23804, 0.22285, 0.20237, 0.17039, 0.17164],
        "LSM v1": [0.29002, 0.29619, 0.29366, 0.3005, 0.27011, 0.31816]
    },
    "Temporal Imputation": {
        "LSM v2": [0.55057, 0.52793, 0.45241, 0.42996, 0.3638, 0.3491],
        "LSM v1": [0.74412, 0.71616, 0.62528, 0.56129, 0.53848, 0.4347]
    },
    "Modality Imputation": {
        "LSM v2": [0.22351, 0.25024, 0.27439, 0.29475, 0.3459, 0.36079],
        "LSM v1": [0.38711, 0.42923, 0.43008, 0.43354, 0.52298, 0.49306]
    }
}

# One panel per imputation task, all sharing the y-axis.
fig, axes = plt.subplots(1, 3, figsize=(11, 4), sharey=True)

for i, (ax, (task, values)) in enumerate(zip(axes, data.items())):
  # Short outward major ticks along the bottom of every panel.
  ax.tick_params(
      axis='x',
      which='major',
      direction='out',
      bottom=True,
      length=3,
      width=1,
  )
  ax.plot(missing_levels, values["LSM v2"], marker='o', label='Ours')
  ax.plot(missing_levels, values["LSM v1"], marker='s', label='LSM')
  ax.set_title(task)
  ax.set_xlabel("Missingness Level (%)")
  if i == 0:
    # Only the left-most panel carries the y ticks, y label and legend.
    ax.tick_params(
        axis='y',
        which='major',
        direction='out',
        left=True,
        length=4,
        width=1,
    )
    ax.set_ylabel("Mean Squared Error (MSE)")
    ax.legend(frameon=False)
plt.tight_layout()
plt.show()

In [None]:
#@title Missing brakdown - Modality Imputation Only


missing_levels = ["30", "40", "50", "60", "70", "80"]
x = range(len(missing_levels))
# Only keep Modality Imputation data
data = {
    "LSM v2": [0.22351, 0.25024, 0.27439, 0.29475, 0.3459, 0.36079],
    "LSM v1": [0.38711, 0.42923, 0.43008, 0.43354, 0.52298, 0.49306]
}
plt.figure(figsize=(4, 3.5))
plt.plot(missing_levels, data["LSM v1"], marker='s', label='LSM', color=LIGHTER_BLUE)
plt.plot(missing_levels, data["LSM v2"], marker='o', label='Ours', color=DARKER_BLUE)

plt.xlabel("Missingness Level (%)")
plt.ylabel("Mean Squared Error (MSE)")
plt.title("Pre-Training Gains")
plt.legend(frameon=False)
plt.tight_layout()
plt.savefig("/tmp/signal_imputation_missing_breakdown.pdf", bbox_inches='tight', format="pdf")
%download_file /tmp/signal_imputation_missing_breakdown.pdf
plt.show()

In [None]:
# === Figure setup ===
fig, axs = plt.subplots(1, 2, figsize=(7, 3.5), dpi=100)

# === Subplot 1: Generative Gains as Bar Plot ===
modalities = [2, 6, 12]
x = np.arange(len(modalities))  # [0, 1, 2]
bar_width = 0.35

LSM_results = [0.73, 0.58, 0.45]
Our_results = [0.17, 0.21, 0.27]

axs[0].bar(x - bar_width/2, LSM_results, width=bar_width, color=LIGHTER_BLUE)
axs[0].bar(x + bar_width/2, Our_results, width=bar_width, color=DARKER_BLUE)

axs[0].set_xlabel("Num of Masked Signals")
axs[0].set_ylabel("Mean Squared Error")
axs[0].set_title("Generative Gains")
axs[0].set_xticks(x)
axs[0].set_xticklabels([str(m) for m in modalities])
axs[0].grid(axis='y', linestyle='--', alpha=0.6)

# === Subplot 2: Discriminative Gains ===
tasks = {
    "Hypertension": {"LSMV2": 0.6367125, "LSMV1": 0.6017575},
    "Anxiety": {"LSMV2": 0.661505, "LSMV1": 0.6250249999999999},
    "Activity": {"LSMV2": 0.4109475, "LSMV1": 0.263265},
}
colors = {'LSMV2': DARKER_BLUE, 'LSMV1': LIGHTER_BLUE}
x_labels = list(tasks.keys())
x_base = np.arange(len(x_labels))
marker_size = 100

for i, task in enumerate(x_labels):
    axs[1].scatter(x_base[i], tasks[task]["LSMV2"], s=marker_size, color=colors['LSMV2'], marker='o')
    axs[1].scatter(x_base[i], tasks[task]["LSMV1"], s=marker_size, color=colors['LSMV1'], marker='o')

axs[1].set_xticks(x_base)
axs[1].set_xticklabels(x_labels, rotation=0)
axs[1].set_ylabel("F1 Score")
axs[1].set_title("Discriminative Gains")
axs[1].grid(axis='y', linestyle='--', alpha=0.6)

# === Custom Legend on Right Subplot (line markers) ===
custom_lines = [
    Line2D([0], [0], color=LIGHTER_BLUE, lw=2, label='LSM-1'),
    Line2D([0], [0], color=DARKER_BLUE, lw=2, label='LSM-2')
]
axs[1].legend(handles=custom_lines, loc='lower left', frameon=False, fontsize=MEDIUM_SIZE)

# === Final Layout Adjustments ===
plt.tight_layout()
plt.subplots_adjust(bottom=0.25)
plt.savefig("/tmp/teaser_lsm.svg", bbox_inches='tight', format="svg")
%download_file /tmp/teaser_lsm.svg
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Define x-axis data
modalities = ["Accel", "PPG", "EDA", "Temp"]
time_segments = ["Morning", "Afternoon", "Evening", "Night"]

reordered_indices = [1, 2, 3, 0]

# Define y-ticks
yticks_config = {
    "Hypertension": [0.56, 0.60, 0.63, 0.66, 0.69],
    "Anxiety": [0.56, 0.60, 0.63, 0.66, 0.69],
    "Activity": [0.12, 0.25, 0.35, 0.45],
}

# Input data
tasks_modality = {
    "Hypertension (F1)": {
        "LSMV2": [0.63214, 0.61234, 0.6506, 0.65177],
        "LSMV1": [0.58648, 0.56764, 0.62624, 0.62667],
        "baseline_v2": 0.65087,
        "baseline_v1": 0.63988,
    },
    "Anxiety (F1)": {
        "LSMV2": [0.60647, 0.67554, 0.68104, 0.68297],
        "LSMV1": [0.58619, 0.62982, 0.64211, 0.64198],
        "baseline_v2": 0.68301,
        "baseline_v1": 0.66959,
    },
    "Activity (F1)": {
        "LSMV2": [0.2023, 0.49046, 0.48797, 0.46306],
        "LSMV1": [0.13689, 0.28638, 0.28191, 0.34788],
        "baseline_v2": 0.47376,
        "baseline_v1": 0.47038,
    },
}

tasks_time = {
    "Hypertension (F1)": {
        "LSMV2": [0.6181, 0.64841, 0.64774, 0.64857, 0.64918],
        "LSMV1": [0.61993, 0.62182, 0.62173, 0.61823, 0.63576],
        "baseline_v2": 0.65087,
        "baseline_v1": 0.63988,
    },
    "Anxiety (F1)": {
        "LSMV2": [0.64817, 0.6809, 0.68426, 0.68282, 0.68269],
        "LSMV1": [0.62439, 0.63504, 0.6382, 0.63837, 0.65236],
        "baseline_v2": 0.68301,
        "baseline_v1": 0.66959,
    },
    "Activity (F1)": {
        "LSMV2": [0.46501, 0.41696, 0.42291, 0.41227, 0.46549],
        "LSMV1": [0.34788, 0.23725, 0.23137, 0.2344, 0.35554],
        "baseline_v2": 0.47376,
        "baseline_v1": 0.47038,
    },
}

# Reorder and filter temporal segment data
tasks_time_reordered = {
    task: {
        "LSMV2": [values["LSMV2"][i] for i in reordered_indices],
        "LSMV1": [values["LSMV1"][i] for i in reordered_indices],
        "baseline_v2": values["baseline_v2"],
        "baseline_v1": values["baseline_v1"],
    }
    for task, values in tasks_time.items()
}

# Colors
DARKER_BLUE = "#1f77b4"
LIGHTER_BLUE = "#aec7e8"

# Plot setup
fig, axs = plt.subplots(2, 3, figsize=(12, 8), dpi=100)

width = 0.15
marker_style = dict(marker="o", edgecolors="none")
marker_size = 100
colors = {"LSMV2": DARKER_BLUE, "LSMV1": LIGHTER_BLUE}

# Top row: Removed Modality
for ax, (task, values) in zip(axs[0], tasks_modality.items()):
    x = np.arange(len(modalities))
    short_task = task.replace(" (F1)", "")

    ax.axhline(y=values["baseline_v2"], linestyle=":", color=colors["LSMV2"], linewidth=2)
    ax.axhline(y=values["baseline_v1"], linestyle=":", color=colors["LSMV1"], linewidth=2)

    for i in range(len(modalities)):
        ax.plot([x[i] - width, x[i] - width], [values["baseline_v2"], values["LSMV2"][i]], color=colors["LSMV2"], alpha=0.6)
        ax.plot([x[i] + width, x[i] + width], [values["baseline_v1"], values["LSMV1"][i]], color=colors["LSMV1"], alpha=0.6)

    lsmv2_plot = ax.scatter(x - width, values["LSMV2"], s=marker_size, color=colors["LSMV2"], **marker_style)
    lsmv1_plot = ax.scatter(x + width, values["LSMV1"], s=marker_size, color=colors["LSMV1"], **marker_style)

    ax.set_xticks(x)
    ax.set_xticklabels(modalities, ha='center')
    ax.set_yticks(yticks_config[short_task])
    ax.grid(axis="y", linestyle="--", alpha=0.6)

# Bottom row: Removed Temporal Segment
for ax, (task, values) in zip(axs[1], tasks_time_reordered.items()):
    x = np.arange(len(time_segments))
    short_task = task.replace(" (F1)", "")

    ax.axhline(y=values["baseline_v2"], linestyle=":", color=colors["LSMV2"], linewidth=2)
    ax.axhline(y=values["baseline_v1"], linestyle=":", color=colors["LSMV1"], linewidth=2)

    for i in range(len(time_segments)):
        ax.plot([x[i] - width, x[i] - width], [values["baseline_v2"], values["LSMV2"][i]], color=colors["LSMV2"], alpha=0.6)
        ax.plot([x[i] + width, x[i] + width], [values["baseline_v1"], values["LSMV1"][i]], color=colors["LSMV1"], alpha=0.6)

    lsmv2_plot = ax.scatter(x - width, values["LSMV2"], s=marker_size, color=colors["LSMV2"], **marker_style)
    lsmv1_plot = ax.scatter(x + width, values["LSMV1"], s=marker_size, color=colors["LSMV1"], **marker_style)

    ax.set_xticks(x)
    ax.set_xticklabels(time_segments, ha='center', rotation=10)
    ax.set_yticks(yticks_config[short_task])
    ax.grid(axis="y", linestyle="--", alpha=0.6)


# Create custom legend handles
lsmv2_marker = Line2D([0], [0], marker='o', color='w', markerfacecolor=colors["LSMV2"], markersize=10, label='Ours')
lsmv1_marker = Line2D([0], [0], marker='o', color='w', markerfacecolor=colors["LSMV1"], markersize=10, label='LSM')

lsmv2_baseline = Line2D([0], [0], linestyle=":", color=colors["LSMV2"], linewidth=3, label="Ours (Without Removal)")
lsmv1_baseline = Line2D([0], [0], linestyle=":", color=colors["LSMV1"], linewidth=3, label="LSM (Without Removal)")

# Shared legend
fig.legend(
    handles=[lsmv2_marker, lsmv1_marker, lsmv2_baseline, lsmv1_baseline],
    labels=["LSM-2", "LSM-1", "LSM-2 (Without Removal)", "LSM-1 (Without Removal)"],
    loc="lower center",
    bbox_to_anchor=(0.5, -0.07),
    frameon=False,
    ncol=4,
    fontsize=MEDIUM_SIZE + 2,
)
plt.tight_layout()
plt.savefig("/tmp/mnar_results.pdf", bbox_inches='tight', format="pdf")
%download_file /tmp/mnar_results.pdf
plt.show()

In [None]:
from matplotlib.ticker import FixedLocator, FuncFormatter, LogLocator, ScalarFormatter

def _scaling_frame(x_col, x_values, ours_loss, lsm_loss):
  """Build a frame with the "Ours" rows first, followed by the "LSM" rows.

  Args:
    x_col: Name of the column holding the scaled quantity.
    x_values: Values of the scaled quantity, shared by both models.
    ours_loss: "Loss (MSE)" values for "Ours", aligned with `x_values`.
    lsm_loss: "Loss (MSE)" values for "LSM", aligned with `x_values`.

  Returns:
    DataFrame with columns "Model", `x_col` and "Loss (MSE)".
  """
  x_values = list(x_values)
  return pd.DataFrame({
      "Model": ["Ours"] * len(x_values) + ["LSM"] * len(x_values),
      x_col: x_values * 2,
      "Loss (MSE)": list(ours_loss) + list(lsm_loss),
  })


subject_scaling_df = _scaling_frame(
    "Number of Subjects",
    [100, 1000, 10000],
    ours_loss=[1.34262, 0.23113, 0.20482],
    lsm_loss=[1.26583, 0.48524, 0.29725],
)

# Each base count is multiplied by 24 before it is stored in this column.
data_scaling_df = _scaling_frame(
    "Number of Subject-Days",
    [24 * count for count in (1000, 10000, 100000, 1_000_000)],
    ours_loss=[1.48221, 0.26917, 0.21473, 0.19921],
    lsm_loss=[1.4027, 0.66493, 0.3019, 0.29521],
)

compute_scaling_df = _scaling_frame(
    "Number of Training Steps",
    [1000, 10000, 100000],
    ours_loss=[0.93637, 0.25832, 0.1993],
    lsm_loss=[0.81927, 0.32338, 0.295],
)

model_scaling_df = _scaling_frame(
    "Number of Model Parameters",
    [5_800_000, 25_000_000, 110_000_000],
    ours_loss=[0.21967, 0.1993, 0.19211],
    lsm_loss=[0.31411, 0.295, 0.28334],
)


# Seaborn theme for the scaling figure, followed by the lookups keyed by the
# values of the "Model" column (marker_dict / model_dict are read by
# plot_scaling) and by metric name (task_palette).
sns.set(font_scale=1.2, style="white")
marker_dict = dict(LSM="o", Ours="s")
task_palette = dict([("Loss (MSE)", "#3182BD"), ("Activity (F1)", "#9ECAE1")])
model_dict = dict(LSM="#9ECAE1", Ours="#3182BD")


# Plotting helper
# Default y tick positions, selected by the first keyword found in the
# lower-cased panel title (checked in this order).
_YTICKS_BY_TITLE_KEYWORD = (
    ("subject", [1.4, 1.0, 0.6, 0.2]),
    ("data", [1.4, 1.0, 0.6, 0.2]),
    ("compute", [1.0, 0.8, 0.6, 0.4, 0.2]),
    ("model", [0.4, 0.3, 0.2, 0.1]),
)


def _default_yticks(title):
  """Return the y tick positions for `title`, or None if no keyword matches."""
  lowered_title = title.lower()
  for keyword, ticks in _YTICKS_BY_TITLE_KEYWORD:
    if keyword in lowered_title:
      return ticks
  return None


def plot_scaling(
    ax,
    df,
    x_col,
    title,
    xlabel_name,
    xlabel_subtitle,
    var_name=None,
    value_vars=None,
    ylabel=True,
):
  """Draw one log-log scaling panel with a line per model.

  Line colours and markers are looked up by model name in the module-level
  `model_dict` and `marker_dict`.

  Args:
    ax: Matplotlib axes to draw on.
    df: Wide DataFrame with a "Model" column, the `x_col` column and every
      column named in `value_vars`.
    x_col: Name of the column plotted on the x-axis.
    title: Panel title. It also selects the fixed y tick positions through the
      keywords "subject", "data", "compute" and "model" (case-insensitive);
      for any other title the axis keeps its default tick locator.
    xlabel_name: Bold x-axis label, embedded in a mathtext expression.
    xlabel_subtitle: Gray caption drawn underneath the x-axis label.
    var_name: Optional suffix appended to the x label as " [var_name]".
    value_vars: Columns to melt and plot. Defaults to ["Loss (MSE)"].
    ylabel: If true the y-axis is labelled "Loss"; otherwise the label is
      cleared.
  """
  if value_vars is None:
    value_vars = ["Loss (MSE)"]

  df_long = pd.melt(
      df,
      id_vars=["Model", x_col],
      value_vars=value_vars,
      var_name="Task",
      value_name="Loss",
  )
  # One line per (task, model) pair; only the model decides colour and marker.
  for (_, model), group in df_long.groupby(["Task", "Model"]):
    sns.lineplot(
        data=group,
        x=x_col,
        y="Loss",
        ax=ax,
        color=model_dict[model],
        marker=marker_dict[model],
        markersize=9,
        linewidth=1.5,
        linestyle="-",
    )

  ax.set_title(title, weight='medium')
  ax.set_xlabel(
      r"$\bf{" + xlabel_name + r"}$" + (f" [{var_name}]" if var_name else ""),
      fontsize=MEDIUM_SIZE,
      labelpad=10,
  )
  # Caption below the x label; the position is in axes-fraction coordinates,
  # so adjust the vertical offset (-0.26) if the label spacing changes.
  ax.text(
      0.5,
      -0.26,
      xlabel_subtitle,
      transform=ax.transAxes,
      ha="center",
      va="top",
      fontsize=MEDIUM_SIZE,
      color="gray",
  )
  ax.set_ylabel("Loss" if ylabel else "")
  ax.set_xscale("log")
  ax.set_yscale("log")

  # Fixed major y ticks for recognised titles. Without the guard an
  # unrecognised title would pass None to FixedLocator.
  yticks = _default_yticks(title)
  if yticks is not None:
    ax.yaxis.set_major_locator(FixedLocator(yticks))
  ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: f"{y:.1f}"))

  # Make the major tick marks visible on both axes.
  for tick in ax.yaxis.get_major_ticks():
    tick.tick1line.set_visible(True)
    tick.tick1line.set_markersize(6)

  for tick in ax.xaxis.get_major_ticks():
    tick.tick1line.set_visible(True)
    tick.tick1line.set_markersize(6)
    tick.tick1line.set_markeredgewidth(1)

  # Minor x ticks at 2..9 times every decade from 10^2 to 10^8.
  minor_ticks = [
      decade * multiple
      for decade in [1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8]
      for multiple in range(2, 10)
  ]
  ax.xaxis.set_minor_locator(FixedLocator(minor_ticks))
  ax.tick_params(axis="x", which="minor", length=4, width=0.8, direction="out")

  for tick in ax.xaxis.get_minor_ticks():
    tick.tick1line.set_visible(True)
    tick.tick1line.set_markersize(4)
    tick.tick1line.set_markeredgewidth(0.8)
  # Square plotting area regardless of the figure size.
  ax.set_box_aspect(1)


# Create subplots
fig, axes = plt.subplots(1, 4, figsize=(16, 4), dpi=100)
plot_scaling(
    axes[0],
    subject_scaling_df,
    "Number of Subjects",
    "Subject Scaling",
    xlabel_name="Subjects",
    xlabel_subtitle="Count",
    var_name="",
)
plot_scaling(
    axes[1],
    data_scaling_df,
    "Number of Subject-Days",
    "Data Scaling",
    xlabel_name="Data~Size",
    xlabel_subtitle="Hours",
    var_name="",
    ylabel=False,
)
plot_scaling(
    axes[2],
    compute_scaling_df,
    "Number of Training Steps",
    "Compute Scaling",
    xlabel_name="Training~Steps",
    xlabel_subtitle="Count",
    var_name="",
    ylabel=False,
)
plot_scaling(
    axes[3],
    model_scaling_df,
    "Number of Model Parameters",
    "Model Scaling",
    xlabel_name="Model~Size",
    xlabel_subtitle="Number of Parameters",
    var_name="",
    ylabel=False,
)

model_handles = [
    Line2D(
        [0],
        [0],
        color=model_dict["Ours"],
        marker="s",
        linestyle="-",
        markersize=10,
        label="LSM-2",
    ),
    Line2D(
        [0],
        [0],
        color=model_dict["LSM"],
        marker="o",
        linestyle="-",
        markersize=10,
        label="LSM-1",
    ),
]

# Add combined legend
fig.legend(
    handles=model_handles,
    ncol=4,
    loc="lower center",
    bbox_to_anchor=(0.55, -0.13),
    frameon=False,
    fontsize=MEDIUM_SIZE
)

plt.tight_layout()
plt.savefig("/tmp/scaling_results.svg", bbox_inches='tight', format="svg")
%download_file /tmp/scaling_results.svg
plt.show()