In [None]:
# Grab the metrics recorded during training from the SCANVI model. The plotting
# cell iterates this with `.items()` and expects pandas DataFrame/Series values,
# so it is presumably a mapping of metric name -> per-epoch table (confirm
# against the scvi-tools `history` property).
history = getattr(scanvi_model, "history")

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

def _history_to_frame(history):
    """Combine a mapping of per-metric training tables into one wide DataFrame.

    The first column of each table is kept and renamed to the metric's key.
    All metrics are then aligned side by side on their shared index, with NaN
    wherever a metric has no entry for an index value.

    Parameters
    ----------
    history : mapping
        Metric name -> pandas DataFrame or Series. Values of any other type,
        and DataFrames with no columns, are reported with ``print`` and skipped.

    Returns
    -------
    pandas.DataFrame
        One column per usable metric. If no metric was usable, an empty
        DataFrame is returned, because ``pd.concat`` raises on an empty list.
    """
    columns = []
    for metric_name, metric_data in history.items():
        if isinstance(metric_data, pd.Series):
            metric_data = metric_data.to_frame()
        if not isinstance(metric_data, pd.DataFrame):
            print(f"{metric_name} is not a DataFrame or Series.")
            continue
        if metric_data.shape[1] == 0:
            print(f"{metric_name} has no columns; skipping.")
            continue
        # Keep only the first column. Extra columns would otherwise keep their
        # original labels and could collide with other metrics in the concat.
        columns.append(metric_data.iloc[:, 0].rename(metric_name))
    if not columns:
        return pd.DataFrame()
    return pd.concat(columns, axis=1)


def _plot_metrics(df, series, title, ylabel, ylim=None):
    """Plot selected columns of *df* against its index on a new 6x4 figure.

    Parameters
    ----------
    df : pandas.DataFrame
        Table whose index is used as the x-axis (labelled "Epoch").
    series : iterable of (str, str)
        ``(column, legend label)`` pairs. A column missing from *df* is
        reported with ``print`` and skipped instead of raising ``KeyError``.
    title, ylabel : str
        Figure title and y-axis label.
    ylim : tuple of float, optional
        ``(bottom, top)`` passed to ``plt.ylim``. Matplotlib autoscales when
        this is ``None``.
    """
    plt.figure(figsize=(6, 4))
    plotted_any = False
    for column, label in series:
        if column not in df.columns:
            print(f"{column} not found in training history; skipping.")
            continue
        plt.plot(df[column], label=label)
        plotted_any = True
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    if plotted_any:
        # With no labelled artists, plt.legend() would only emit a warning.
        plt.legend()
    if ylim is not None:
        plt.ylim(*ylim)
    plt.tight_layout()
    plt.show()


df_history = _history_to_frame(history)

# df_history now has one column per metric, indexed like the source tables
# (presumably by epoch; confirm against the scvi-tools `history` format).

_plot_metrics(
    df_history,
    [("elbo_train", "Train ELBO")],
    title="ELBO Over Epochs",
    ylabel="ELBO",
)

_plot_metrics(
    df_history,
    [
        ("train_loss_epoch", "Train Loss"),
        ("reconstruction_loss_train", "Reconstruction Loss"),
    ],
    title="Losses Over Epochs",
    ylabel="Loss Value",
)

# Keep the zoomed [0.9, 1.0] view. If any accuracy/F1 value falls below 0.9,
# lower the bottom bound so those epochs are not clipped off the plot.
classification_cols = [
    col for col in ("train_accuracy", "train_f1_score") if col in df_history.columns
]
lowest_score = df_history[classification_cols].min().min()
y_bottom = 0.9 if pd.isna(lowest_score) else min(0.9, float(lowest_score))
_plot_metrics(
    df_history,
    [("train_accuracy", "Train Accuracy"), ("train_f1_score", "Train F1 Score")],
    title="Classification Metrics Over Epochs",
    ylabel="Metric",
    ylim=(y_bottom, 1.0),
)
