In [None]:
import pandas as pd
import wandb
from tqdm.notebook import tqdm
import pickle
from os.path import exists
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import math
from matplotlib.ticker import MaxNLocator
#...

# Notebook-wide default font settings handed to matplotlib's rc machinery.
font = {'family': 'times', 'size': 14}

matplotlib.rc('font', family=font['family'], size=font['size'])

In [None]:
class Experiment:
    """Snapshot of a single run's name, config, summary, tags and history.

    ``run.history()`` is called once, at construction time, and its result is
    kept on ``self.history``; the run object itself stays reachable through
    ``self.run``.
    """

    def __init__(self, run):
        """Copy the attributes of ``run`` that the notebook reads later."""
        self.run = run
        self.name = run.name
        self.config = run.config
        self.summary = run.summary
        self.history = run.history()
        self.tags = run.tags

    def get_id(self):
        """Return the ``(formula, mol_idx)`` pair taken from the config."""
        cfg = self.config
        return cfg['formula'], cfg['mol_idx']

    def get_history(self):
        """Return the running total of the history's ``additional_steps``."""
        steps = list(self.history['additional_steps'])
        return np.cumsum(np.array(steps))

In [None]:
def fetch(project, entity="some-random-foo"):
    """Download every run of a W&B project and wrap each in an ``Experiment``.

    Args:
        project: Name of the project whose runs are listed.
        entity: Owner of the project. Defaults to the value that used to be
            hard-coded, so existing ``fetch(project)`` calls are unaffected.

    Returns:
        A list of ``Experiment`` objects, one per run that could be wrapped.
        Runs for which ``Experiment(run)`` raised are skipped and listed on
        stdout instead of being dropped silently.
    """
    api = wandb.Api()
    experiments = []
    skipped = []
    for run in tqdm(api.runs(f"{entity}/{project}")):
        try:
            experiments.append(Experiment(run))
        except Exception as err:
            # Still best-effort, but ``Exception`` (unlike a bare ``except``)
            # lets KeyboardInterrupt abort a long download, and the failure
            # is remembered so it can be reported below.
            skipped.append((getattr(run, "name", "<unknown>"), err))

    if skipped:
        print(f"Skipped {len(skipped)} run(s) that could not be loaded:")
        for run_name, err in skipped:
            print(f"  {run_name}: {err!r}")
    return experiments

In [None]:
# Fetch every run of the "scale_master" project once; later cells reuse `raw`.
raw = fetch("scale_master")


In [None]:
# Sanity check: print the main attributes of every Experiment object in raw.
# Lines are printed one at a time, so if get_id() or get_history() raises a
# KeyError the earlier lines for that run are already on screen; the error is
# reported and the loop moves on to the next run.
for exp in raw:
    try:
        print(f"Name: {exp.name}")
        print(f"Config: {exp.config}")
        print(f"Summary: {exp.summary}")
        print(f"Tags: {exp.tags}")
        print(f"Run Group: {exp.run.group}")  # Assuming run.group gives the run group
        print(f"ID: {exp.get_id()}")
        print(f"History: {exp.get_history()}")
    except KeyError as e:
        print(f"KeyError encountered: {e}")

In [None]:
def get_exp(raw_exp, model):
    """Bucket the ``mswag_push`` scaling runs of ``model`` by device count.

    A run is kept when its run group is ``"size"`` or ``"size3"``, its config
    ``model`` entry equals ``model`` and its name contains ``"mswag_push"``.
    Runs that do not match are ignored without touching their other config
    entries, so unrelated runs lacking ``num_device`` no longer raise.

    Args:
        raw_exp: Iterable of experiment objects exposing ``name``, a dict-like
            ``config`` and ``run.group``.
        model: Value compared against the ``model`` config entry.

    Returns:
        ``{'mswag_push': {num_device: [experiment, ...]}}``. The device counts
        1, 2 and 4 are always present (possibly empty); any other device count
        gets a key only when a matching run uses it.
    """
    known_methods = (
        "mswag_push", "svgd_push", "ensemble_push",
        "ensemble", "svgd", "mswag",
    )
    scaling_groups = ("size", "size3")

    def test_method(method, exp):
        # True when ``method`` is a known method name that occurs in the run name.
        return method in known_methods and method in exp.name

    exps = {'mswag_push': {dev: [] for dev in [1, 2, 4]}}

    for exp in raw_exp:
        if exp.run.group not in scaling_groups:
            continue
        if exp.config.get("model") != model:
            continue
        if not test_method("mswag_push", exp):
            continue

        # Only read num_device for runs that passed the filters; a matching
        # run without it cannot be placed in a bucket and is skipped.
        num_device = exp.config.get("num_device")
        if num_device is None:
            continue
        exps['mswag_push'].setdefault(num_device, []).append(exp)

    return exps


In [None]:
def my_plot(model, method, exps, x_unit="", y_unit="", norm=False):
    """Plot mean recorded time against particle count, one line per device count.

    Both axes are in log2 space: x is ``log2(num_particles)`` and y is
    ``log2`` of the mean of the timing entry of each run's history
    (``swag_epoch_time`` when ``method == "mswag_push"``, ``time`` otherwise).
    The figure is saved to ``media/{model}_{method}.pdf``; the ``media``
    directory is created when missing. Sets the global matplotlib font family
    to "DejaVu Sans" as a side effect.

    Args:
        model: Model name, used in the title and the output file name.
        method: Key into ``exps``; also selects the timing entry and the title.
        exps: Mapping ``{method: {num_device: [experiment, ...]}}``. Only the
            device counts 1, 2 and 4 are drawn.
        x_unit: Text inserted into the x-axis label.
        y_unit: Text inserted into the y-axis label.
        norm: Accepted for backward compatibility; not used.

    Raises:
        KeyError: If ``method`` is not a key of ``exps``.
    """
    import os  # function-scope import: only used to create the output folder

    plt.rcParams["font.family"] = "DejaVu Sans"

    def _one(exps, mswag):
        """Return ``(log2 mean times, sorted particle counts)`` for ``exps``."""
        column = "swag_epoch_time" if mswag else "time"
        times = {}
        for exp in exps:
            try:
                # Runs sharing a particle count overwrite each other; the
                # last one seen wins.
                times[exp.config["num_particles"]] = np.log2(exp.history[column].mean())
            except Exception:
                # Best-effort: runs lacking the config entry or the timing
                # data are left out of the plot. Not a bare ``except`` so an
                # interrupt still stops the cell.
                continue
        particles = sorted(times)
        return [times[p] for p in particles], particles

    # (device count, marker, line style, legend label)
    series = [
        (1, 'o', '-', "1 Device"),
        (2, 's', '--', "2 Devices"),
        (4, '^', ':', "4 Devices"),
    ]

    plt.figure()

    by_device = exps[method]
    use_swag_time = method == "mswag_push"
    for device, marker, linestyle, label in series:
        if device in by_device:
            times, particles = _one(by_device[device], use_swag_time)
            plt.plot(np.log2(particles), times, marker=marker, linestyle=linestyle, label=label)

    plt.grid(True)  # Adding grid lines

    # The plotted values are log2 exponents, so each tick must sit AT the
    # exponent it is labelled with. The y range starts at -2, hence
    # arange(-2, 9) rather than arange(11), which shifted every label by two.
    plt.ylim(-2, 8)
    y_exponents = np.arange(-2, 9)
    plt.yticks(y_exponents, [rf'$2^{{{e}}}$' for e in y_exponents])
    x_exponents = np.arange(10)
    plt.xticks(x_exponents, [rf'$2^{{{e}}}$' for e in x_exponents])
    plt.ylabel(f'Seconds ({y_unit} log scale)')
    plt.xlabel(f'Particles ({x_unit} log scale)')

    if method == "ensemble_push":
        method_title = "Ensemble"
    elif method == "mswag_push":
        method_title = "MSWAG"
    else:
        method_title = "Stein VGD"

    plt.title(f"{method_title} Push Scaling on {model}")
    plt.legend(loc='upper left')

    # savefig does not create missing directories.
    os.makedirs("media", exist_ok=True)
    plt.savefig(f'media/{model}_{method}.pdf', format='pdf')


# Example usage:
# my_plot("ModelName", "mswag_push", your_experiments_data, x_unit="Particles", y_unit="Seconds")


In [None]:
# Model and method names iterated over by the plotting cells below.
models = ["transformer2"]
methods = ["mswag_push"]

In [None]:
# Draw the particle-scaling figure for every (model, method) pair.
for model in models:
    # get_exp depends only on the model, so the runs are bucketed once per
    # model instead of once per (model, method) pair.
    exps = get_exp(raw, model)
    for method in methods:
        print(exps)
        my_plot(model, method, exps, norm=True)

In [None]:
def my_plot2(model, method, exps, x_unit="", y_unit="", norm=False):
    """Plot mean recorded time against parameter count, one line per device count.

    x is the raw ``num_params`` config value; y is ``log2`` of the mean of
    the timing entry of each run's history (``swag_epoch_time`` when
    ``method == "mswag_push"``, ``time`` otherwise). The figure is saved to
    ``media/{model}_{method}_params.pdf`` so that it no longer overwrites the
    file written by ``my_plot`` for the same model and method; the ``media``
    directory is created when missing. Sets the global matplotlib font family
    to "DejaVu Sans" as a side effect.

    Args:
        model: Model name, used in the title and the output file name.
        method: Key into ``exps``; also selects the timing entry and the title.
        exps: Mapping ``{method: {num_device: [experiment, ...]}}``. Only the
            device counts 1, 2 and 4 are drawn.
        x_unit: Text inserted into the x-axis label.
        y_unit: Text inserted into the y-axis label.
        norm: Accepted for backward compatibility; not used.

    Raises:
        KeyError: If ``method`` is not a key of ``exps``.
    """
    import os  # function-scope import: only used to create the output folder

    plt.rcParams["font.family"] = "DejaVu Sans"

    def _one(exps, mswag):
        """Return ``(log2 mean times, sorted parameter counts)`` for ``exps``."""
        column = "swag_epoch_time" if mswag else "time"
        times = {}
        for exp in exps:
            try:
                # Runs sharing a parameter count overwrite each other; the
                # last one seen wins.
                times[exp.config["num_params"]] = np.log2(exp.history[column].mean())
            except Exception:
                # Best-effort: runs lacking the config entry or the timing
                # data are left out of the plot. Not a bare ``except`` so an
                # interrupt still stops the cell.
                continue
        params = sorted(times)
        return [times[p] for p in params], params

    # (device count, marker, line style, legend label)
    series = [
        (1, 'o', '-', "1 Device"),
        (2, 's', '--', "2 Devices"),
        (4, '^', ':', "4 Devices"),
    ]

    plt.figure()

    by_device = exps[method]
    use_swag_time = method == "mswag_push"
    plotted_params = set()
    for device, marker, linestyle, label in series:
        if device in by_device:
            times, params = _one(by_device[device], use_swag_time)
            plt.plot(params, times, marker=marker, linestyle=linestyle, label=label)
            plotted_params.update(params)

    plt.grid(True)

    # The y values are log2 exponents, so each tick must sit AT the exponent
    # it is labelled with. The y range starts at -2, hence arange(-2, 9)
    # rather than arange(11), which shifted every label by two.
    plt.ylim(-2, 8)
    y_exponents = np.arange(-2, 9)
    plt.yticks(y_exponents, [rf'$2^{{{e}}}$' for e in y_exponents])
    # Ticks come from the series actually drawn, so a missing device count
    # cannot leave a name undefined.
    plt.xticks(sorted(plotted_params))

    plt.ylabel(f'Seconds ({y_unit} log scale)')
    plt.xlabel(f'Params ({x_unit})')

    if method == "ensemble_push":
        method_title = "Ensemble"
    elif method == "mswag_push":
        method_title = "MSWAG"
    else:
        method_title = "Stein VGD"

    plt.title(f"{method_title} Push Scaling on {model}")
    plt.legend(loc='upper left')

    # savefig does not create missing directories.
    os.makedirs("media", exist_ok=True)
    plt.savefig(f'media/{model}_{method}_params.pdf', format='pdf')

# Example usage: 
# my_plot2("ModelName", "mswag_push", your_experiments_data, x_unit="Params", y_unit="Seconds")


In [None]:
# Draw the parameter-scaling figure for every (model, method) pair.
for model in models:
    # get_exp depends only on the model, so the runs are bucketed once per
    # model instead of once per (model, method) pair.
    exps = get_exp(raw, model)
    for method in methods:
        print(exps)
        my_plot2(model, method, exps, norm=True)

In [None]:
import pandas as pd
# `display` and `HTML` are public in IPython.display; importing `display`
# from IPython.core.display is deprecated.
from IPython.display import display, HTML

# One row per run of the "size" / "size3" groups:
# [num_params, num_device, num_particles, _runtime]; 'N/A' marks a missing value.
rows = []

for exp in raw:
    try:
        if exp.run.group in ("size", "size3"):
            num_params = exp.config.get('num_params', 'N/A')
            num_devices = exp.config.get('num_device', 'N/A')
            num_particles = exp.config.get('num_particles', 'N/A')
            runtime = exp.summary.get('_runtime', 'N/A')

            rows.append([num_params, num_devices, num_particles, runtime])
    except KeyError as e:
        print(f"KeyError encountered: {e}")

# Create a DataFrame from the list of rows
df = pd.DataFrame(rows, columns=['Number of Parameters', 'Number of Devices', 'Number of Particles', 'Time'])

# Sort by particle count, then device count. The sort key coerces each column
# to numbers so that 'N/A' placeholders become NaN and sort last, instead of
# raising TypeError when strings are compared with numbers. The displayed
# values are left unchanged.
df = df.sort_values(
    by=['Number of Particles', 'Number of Devices'],
    key=lambda col: pd.to_numeric(col, errors='coerce'),
)

# Convert the DataFrame to HTML
html_table = df.to_html(index=False)

# Display the HTML table in the notebook
display(HTML(html_table))
