In [25]:
import wandb
import csv
from datetime import datetime
import pandas as pd
import json
import matplotlib.pyplot as plt

In [26]:
# Initialize the API
api = wandb.Api()

# Your project name and optional: entity (username or team name)
project_name = "conj_grad_results"
entity_name = "marccgrau"

In [27]:
run_sets = {
    "MNIST_MLP_Small": ["j2gde3kp", "ilwk4wno", "guexx00f", "tv4niig7"],
    "MNIST_MLP_Large": ["63ali7n8", "6zm55eno", "b2qnpy74", "5ful2rbb"],
    "MNIST_CNN_Small": ["kszrm0ea", "0asafs59", "2vgne902", "ic22ebl7"],
    "MNIST_CNN_Large": ["ujixgcdo", "5wuos9yi", "t51i5dks", "p7pa6auj"],
    "FASHION_MNIST_MLP_Small": ["8r1iwdlm", "8fy84vkq", "og30l29x", "dfrvc4pc"],
    "FASHION_MNIST_MLP_Large": ["bav2ftzc", "kzdms3ox", "guorim5o", "jopq3zzs"],
    "FASHION_MNIST_CNN_Small": ["zbhi3m4v", "tvmlui3r", "4srm7ig0", "49sp1iv0"],
    "FASHION_MNIST_CNN_Large": ["aj0xkyq5", "gp0t0ttr", "xwebyszi", "rqu7d7kz"],
    "CIFAR10_MLP_Small": ["cvnblxwv", "o7u7uoel", "f41nyx0o", "9y1avmjz"],
    "CIFAR10_MLP_Large": ["ugyyag4m", "i3jrz9xj", "x9p8tk5u", "a4kj6di2"],
    "CIFAR10_CNN_Small": ["a1o1h84h", "l7vykqin", "6z963duj", "9xx0jg46"],
    "CIFAR10_CNN_Large": ["zdoxnku5", "dwvmuowu", "06jvk85p", "82gwhc77"],
    "CIFAR100_MLP_Small": ["a94flful", "v6wl96y6", "l5iyk4v9", "xjqitgyg"],
    "CIFAR100_MLP_Large": ["jbw1w7hv", "uy0682dq", "jw4qwwgp", "8cd5h1uq"],
    "CIFAR100_CNN_Small": ["ma3spfwg", "kr0kggi4", "66rnbay1", "1zohaktl"],
    "CIFAR100_CNN_Large": ["a9pu12y0", "kcyecx03", "12ojnyk5", "dmud2b8q"],
    "SVHN_MLP_Small": ["bbic60c6", "2eqky0ah", "7ju2wb1l", "1qbwh5tn"],
    "SVHN_MLP_Large": ["jicl7fsd", "svhcuy47", "kluj48tv", "7p1h9mmn"],
    "SVHN_CNN_Small": ["qss10jlt", "99b4sqqd", "2ei08gkw", "qa8sio24"],
    "SVHN_CNN_Large": ["fodz7vhy", "lbc6am62", "2fv1ki4e", "hnkq207z"],
    "MAX_ITERS_CIFAR10": ["pcffz9ad", "hzlfid1d", "bebh07m6", "cvnblxwv", "i490wk2m", "wr99r3rv", "gogt478c", "bgomne1m", "9sg89zkk", "0g8pf5rr"],
    "MAX_ITERS_SVHN": ["di23fp6s", "zt2g483n", "dkvzv005", "2eqky0ah", "7w43iiz4", "cj7zb9cx", "esyzwpfl", "99ue77o6", "9a9cci4l", "ir3fszqb"]
}

In [28]:
def save_to_csv(filename, data):
    with open(filename, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerows(data)

In [52]:
color_mapping = {
    'nlcg': '#6C5B7B',  # pastel purple
    'adam': '#C06C84',  # pastel pink
    'rmsprop': '#F67280',  # pastel red
    'sgd': '#F8B195',  # pastel orange
}

In [53]:
for run_set_name, run_ids in run_sets.items():
    csv_data = []
    csv_data.append(["Run ID", "Run Name", "Optim Name", "Created Timestamp", "Train Loss", "Test Loss", "Highest Training Accuracy", "Highest Testing Accuracy", "Function Calls", "Gradient Calls", "Total Steps"])
    
    train_loss_data = {}
    
    for run_id in run_ids:
        run = api.run(path=f"marccgrau/conj_grad_results/{run_id}")
        run_config = json.loads(run.json_config)
        optim_name = run_config['optimizer']['value']['name']
        
        # Access logged metrics for the run
        history = run.history(keys=["_timestamp", "loss", "val_loss", "train_accuracy", "test_accuracy", "nb_function_calls", "nb_gradient_calls", "steps"], samples=50000)
        
        if optim_name == "NLCGEager":
            optim_name = "NLCG"
        if optim_name not in train_loss_data:
            train_loss_data[optim_name] = {'steps': [], 'loss': []}
        train_loss_data[optim_name]['steps'].extend(history.get('steps', []))
        train_loss_data[optim_name]['loss'].extend(history.get('loss', []))
        
        # Convert the created timestamp to a readable format
        created_time = None
        if len(history._timestamp) > 0:
            created_time = datetime.utcfromtimestamp(history._timestamp[0]).strftime('%Y-%m-%d %H:%M:%S')
        
        # Extract metrics
        min_train_loss = min(history.get('loss', [None]))
        min_test_loss = min(history.get('val_loss', [None]))
        max_train_accuracy = max(history.get('train_accuracy', [None]))
        max_test_accuracy = max(history.get('test_accuracy', [None]))
        func_calls = max(history.get('nb_function_calls', [None]))
        grad_calls = max(history.get('nb_gradient_calls', [None]))
        total_steps = max(history.get('steps', [None]))
        
        # Append data to CSV list
        csv_data.append([run.id, run.name, optim_name, created_time, min_train_loss, min_test_loss, max_train_accuracy, max_test_accuracy, func_calls, grad_calls, total_steps])
    
    # Determine the highest loss value from all the other data series to get a nicer graph
    all_other_losses = [data['loss'] for optim_name, data in train_loss_data.items() if optim_name != 'NLCG' and data['loss']]
    if all_other_losses:
        highest_loss = max([max(losses) for losses in all_other_losses])
    else:
        highest_loss = 0  # Default value if no losses are found

    plt.figure(figsize=(10, 6))
    for optim_name, data in train_loss_data.items():
        # If the optimizer is NLCG, prepend the highest loss value and a 0 step
        if optim_name == 'NLCG':
            data['loss'].insert(0, highest_loss)
            data['steps'].insert(0, 0)
        
        # Use the color mapping to set the color of the line
        plt.plot(data['steps'], data['loss'], label=optim_name, color=color_mapping.get(optim_name.lower()), linewidth=2)


    plt.xlabel('Total Steps')
    plt.ylabel('Train Loss')
    #plt.title(f'Train Loss for {run_set_name.replace("_", " ")}')

    # Remove top and right spines
    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Use a light grid
    plt.grid(axis='y', linestyle='--', linewidth=0.5, alpha=0.6)

    # Place the legend at the bottom with a transparent background
    legend = plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.20), ncol=4, frameon=False)
    plt.subplots_adjust(bottom=0.25)  # Adjust the bottom margin to make space for the legend


    # Save the plot
    plot_filename = f"../output/plots/{run_set_name}_train_loss.png"
    plt.savefig(plot_filename)
    plt.close()

    print(f"Plot for {run_set_name} saved to {plot_filename}")
    
    # Save the current run set data to a separate CSV
    csv_filename = f"../output/tables/{run_set_name}_results.csv"
    save_to_csv(csv_filename, csv_data)
    print(f"Data for {run_set_name} saved to {csv_filename}")


Plot for MNIST_MLP_Small saved to ../output/plots/MNIST_MLP_Small_train_loss.png
Data for MNIST_MLP_Small saved to ../output/tables/MNIST_MLP_Small_results.csv


No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.


Plot for MNIST_MLP_Large saved to ../output/plots/MNIST_MLP_Large_train_loss.png
Data for MNIST_MLP_Large saved to ../output/tables/MNIST_MLP_Large_results.csv
Plot for MNIST_CNN saved to ../output/plots/MNIST_CNN_train_loss.png
Data for MNIST_CNN saved to ../output/tables/MNIST_CNN_results.csv
Plot for FASHION_MNIST_MLP_Small saved to ../output/plots/FASHION_MNIST_MLP_Small_train_loss.png
Data for FASHION_MNIST_MLP_Small saved to ../output/tables/FASHION_MNIST_MLP_Small_results.csv


No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.


Plot for FASHION_MNIST_MLP_Large saved to ../output/plots/FASHION_MNIST_MLP_Large_train_loss.png
Data for FASHION_MNIST_MLP_Large saved to ../output/tables/FASHION_MNIST_MLP_Large_results.csv
Plot for FASHION_MNIST_CNN saved to ../output/plots/FASHION_MNIST_CNN_train_loss.png
Data for FASHION_MNIST_CNN saved to ../output/tables/FASHION_MNIST_CNN_results.csv
Plot for CIFAR10_MLP_Small saved to ../output/plots/CIFAR10_MLP_Small_train_loss.png
Data for CIFAR10_MLP_Small saved to ../output/tables/CIFAR10_MLP_Small_results.csv


No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.


Plot for CIFAR10_MLP_Large saved to ../output/plots/CIFAR10_MLP_Large_train_loss.png
Data for CIFAR10_MLP_Large saved to ../output/tables/CIFAR10_MLP_Large_results.csv
Plot for CIFAR10_CNN saved to ../output/plots/CIFAR10_CNN_train_loss.png
Data for CIFAR10_CNN saved to ../output/tables/CIFAR10_CNN_results.csv
Plot for CIFAR100_MLP_Small saved to ../output/plots/CIFAR100_MLP_Small_train_loss.png
Data for CIFAR100_MLP_Small saved to ../output/tables/CIFAR100_MLP_Small_results.csv


No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.


Plot for CIFAR100_MLP_Large saved to ../output/plots/CIFAR100_MLP_Large_train_loss.png
Data for CIFAR100_MLP_Large saved to ../output/tables/CIFAR100_MLP_Large_results.csv
Plot for CIFAR100_CNN saved to ../output/plots/CIFAR100_CNN_train_loss.png
Data for CIFAR100_CNN saved to ../output/tables/CIFAR100_CNN_results.csv
Plot for SVHN_MLP_Small saved to ../output/plots/SVHN_MLP_Small_train_loss.png
Data for SVHN_MLP_Small saved to ../output/tables/SVHN_MLP_Small_results.csv


No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.


Plot for SVHN_MLP_Large saved to ../output/plots/SVHN_MLP_Large_train_loss.png
Data for SVHN_MLP_Large saved to ../output/tables/SVHN_MLP_Large_results.csv
Plot for SVHN_CNN saved to ../output/plots/SVHN_CNN_train_loss.png
Data for SVHN_CNN saved to ../output/tables/SVHN_CNN_results.csv
Plot for MAX_ITERS_CIFAR10 saved to ../output/plots/MAX_ITERS_CIFAR10_train_loss.png
Data for MAX_ITERS_CIFAR10 saved to ../output/tables/MAX_ITERS_CIFAR10_results.csv
Plot for MAX_ITERS_SVHN saved to ../output/plots/MAX_ITERS_SVHN_train_loss.png
Data for MAX_ITERS_SVHN saved to ../output/tables/MAX_ITERS_SVHN_results.csv
