In [None]:
# Import modules
import matplotlib.pyplot as plt 
from matplotlib.colors import to_rgb
import pandas as pd
from scripts.utils.io import load_dict
from glob import glob
import numpy as np
import itertools

# Load target metrics ('giant' is the target plotted in red below,
# 'miniature' the generated graph plotted in black)
metrics = {
    'giant': load_dict(snakemake.input[0]),
    'miniature': load_dict(snakemake.input[1])
}

# Load trajectories: one parquet file per replica
files = glob(f"{snakemake.input[2]}/*.parquet")
if not files:
    raise FileNotFoundError(
        f"No trajectory files (*.parquet) found in {snakemake.input[2]}"
    )

# The replica index is the integer between the last '_' and the extension
# of the file name (e.g. '..._3.parquet' -> 3).
indexed_files = {
    int(file.split('/')[-1].split('_')[-1].split('.')[-2]): file
    for file in files
}

# glob returns files in arbitrary order; build the dict in ascending replica
# order so every later loop over it (plot order, colour cycle, legend) is
# deterministic.
trajectories_df = {
    index: pd.read_parquet(indexed_files[index]).reset_index()
    for index in sorted(indexed_files)
}

# Retrieve names of target metrics: every column after the first three.
# NOTE(review): presumably the first three are the restored index, 'Beta' and
# 'Energy' -- confirm against the parquet schema.
first_df = next(iter(trajectories_df.values()))
targets = list(first_df.columns[3:])

# Calculate relative errors on target metrics. The denominator is the absolute
# target value so the error stays non-negative when a target metric is negative.
for target in targets:
    reference = metrics['giant'][target]
    for df in trajectories_df.values():
        df['error_' + target] = np.abs(df[target] - reference) / np.abs(reference)

trajectories_df

In [None]:
# Every unordered pair of target metrics defines one plane to plot
combinations = [pair for pair in itertools.combinations(targets, r=2)]
n_combinations = len(combinations)

In [None]:
def plot_trajectories(metric_1,
                      metric_2,
                      ax):
    """Scatter every replica's trajectory in the (metric_1, metric_2) plane.

    Reads the module-level ``trajectories_df`` (replica index -> DataFrame)
    and ``metrics`` (values for the 'giant' and 'miniature' graphs).

    Parameters
    ----------
    metric_1 : str
        Column plotted on the x axis; also used (capitalized) as x label.
    metric_2 : str
        Column plotted on the y axis; also used (capitalized) as y label.
    ax : matplotlib Axes
        Axes to draw on.

    Each replica gets one colour from the 'tab10' colormap. Point opacity
    rises with row position as (i / n) ** 4, so rows near the end of a
    trajectory are the most visible. The 'giant' values are marked with a
    red cross and the 'miniature' values with a black cross.
    """
    # Label plot
    ax.set_xlabel(metric_1.capitalize())
    ax.set_ylabel(metric_2.capitalize())

    # Plot trajectories
    cmap = plt.get_cmap('tab10')
    for replica, trajectory in trajectories_df.items():
        n = len(trajectory[metric_1])

        # Per-point RGBA array. The colormap index is wrapped because an
        # integer past the last entry is clipped to the final colour, which
        # would make all replicas >= cmap.N look identical.
        color = np.zeros((n, 4))
        color[:, :3] = to_rgb(cmap(replica % cmap.N))

        # Alpha channel: fade in along the trajectory
        color[:, 3] = (np.arange(n) / n) ** 4

        ax.scatter(trajectory[metric_1],
                   trajectory[metric_2],
                   0.5,
                   c=color,
                   linewidth=0)

    # Plot target
    ax.scatter(metrics['giant'][metric_1],
               metrics['giant'][metric_2],
               10,
               marker='x',
               color='r')

    # Plot generated graph
    ax.scatter(metrics['miniature'][metric_1],
               metrics['miniature'][metric_2],
               10,
               marker='x',
               color='k')
    
# Visualize metrics in planes: one panel per pair of target metrics, in a single row
fig, axes = plt.subplots(1, n_combinations, figsize=(3 * n_combinations, 3), dpi=300, squeeze=False)

# Plot visualizations on planes.
# With squeeze=False, `axes` is a 2-D array of shape (1, n_combinations): the
# panels run along the columns of its only row, so iterate over axes[0]
# (indexing axes[i][0] would fail for every i >= 1).
for ax, (trajectory_x, trajectory_y) in zip(axes[0], combinations):
    plot_trajectories(trajectory_x, trajectory_y, ax)

plt.savefig(snakemake.output[0], bbox_inches='tight')

In [None]:
# One stacked panel per plotted series: two fixed series plus one per target metric.
# NOTE(review): the figure height scales with len(targets), not n_axes -- confirm this is intended.
n_axes = len(targets) + 2
fig, axes = plt.subplots(nrows=n_axes, ncols=1, figsize=(6, len(targets) * 3), dpi=300)

def plot_trajectories(metric,
                      ax):
    """Draw column ``metric`` of every replica's trajectory as a thin line on ``ax``.

    Iterates over the module-level ``trajectories_df`` in its stored order,
    adding one line per replica. Note that this definition replaces the
    two-metric ``plot_trajectories`` defined earlier in the notebook.
    """
    for replica_df in trajectories_df.values():
        series = replica_df[metric]
        ax.plot(series, linewidth=0.5)

# Beta and Energy first, then one error curve per target metric
for i, metric in enumerate(['Beta','Energy'] + ['error_' + target for target in targets]):
    plot_trajectories(metric,axes[i])
    axes[i].set_ylabel(metric.capitalize())

    # Only the bottom panel carries the x ticks and label
    if i < n_axes-1:
        axes[i].set_xticks([])
    else:
        axes[i].set_xlabel("Iteration")

axes[0].set_title("Metric trajectories")

# Lines are drawn in the iteration order of trajectories_df, so label them with
# the actual replica keys in that same order; counting 0..n-1 would attach the
# wrong name whenever the dict is not stored in ascending, gap-free order.
axes[0].legend([f"Replica {replica}" for replica in trajectories_df])

plt.tight_layout()
plt.savefig(snakemake.output[1],bbox_inches='tight')