Implementation of test of parity scaling laws in Jupyter notebook.

The code contained here is very similar to that in the rest of the repo, but is not guaranteed to be identical.

This notebook is included for ease of use.

In [1]:
# Package imports

# pip install if necessary
'''
!pip install pandas
!pip install numpy
!pip install tqdm
!pip install matplotlib
!pip install seaborn
!pip install torch
# pathlib and logging are part of the Python standard library and do not need to be installed
'''

import numpy as np
import random
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from pathlib import Path
import logging
import torch.nn.functional as F


In [13]:
# Parameters

# Parameters
# The task-code and message widths are named once so that 'input_size' below is
# derived from them instead of being a third hard-coded number that can drift.
_LEN_TASKCODE = 4
_LEN_MESSAGE = 16

# Experiment hyper-parameters shared by the cells below.
PARAMS = {
    'n_tasks': 1,  # number of unique tasks being trained over
    'len_taskcode': _LEN_TASKCODE,  # number of bits in the task code
    'num_checks': 5,  # number of message positions drawn per task to compute the parity output
    'len_message': _LEN_MESSAGE,  # number of bits in the message
    'num_samples': 1000,  # number of samples per generated training dataset (a random task is drawn per sample)
    'input_size': _LEN_TASKCODE + _LEN_MESSAGE,  # len_taskcode + len_message. Used for model initialisation
    'output_size': 1,  # output size of the model. 1 for binary classification. Do not change
    'learning_rate': 0.005,  # constant learning rate. Could introduce a scheduler?
    'batch_size': 32,  # batch size used in training. Will be the same throughout a run
    'flop_budget': 1e10,  # total number of estimated flops expended per training run
    'task_sample_freq': 1e5,  # FLOPs between task-specific evaluations. Can give a big performance hit
    # FLOPs between intermediate progress plots.
    # NOTE(review): this was annotated "flop_budget/5", but 1e10 / 5 is 2e9, not 2e7 -- confirm the intended value.
    'plot_freq': 2e7,
    'samples_per_task': 100,  # number of samples to generate for each task in evaluation
}

# Define a range of model configurations
# Model sizes swept by the experiment: depth and hidden width grow together.
MODEL_CONFIGS = [
    {"num_layers": 2, "hidden_size": 8},
    {"num_layers": 4, "hidden_size": 16},
    {"num_layers": 6, "hidden_size": 32},
    {"num_layers": 8, "hidden_size": 64},
    {"num_layers": 10, "hidden_size": 128},
    {"num_layers": 12, "hidden_size": 256},
    {"num_layers": 14, "hidden_size": 512},
]

# The notebook refers to this list under both spellings (MODEL_CONFIGS and
# model_configs); both names point at the same list object.
model_configs = MODEL_CONFIGS

In [3]:
# Helper functions
    
def generate_random_binary_string(length):
    """Return a string of `length` characters, each picked from '0'/'1' with random.choice."""
    bits = []
    for _ in range(length):
        bits.append(random.choice(['0', '1']))
    return ''.join(bits)

def generate_dict(n_tasks, len_taskcode, num_checks, len_message):
    """Generate the task codes and their associated check-bit positions.

    Args:
        n_tasks: Number of distinct task codes to create.
        len_taskcode: Number of bits in each task code.
        num_checks: Number of message positions drawn for each task.
        len_message: Message length; positions are drawn from [0, len_message - 1].

    Returns:
        Dict mapping each task code (a '0'/'1' string) to its list of positions.

    Raises:
        ValueError: If more distinct codes are requested than `len_taskcode`
            bits can represent (the rejection loop below could never finish).
    """
    max_codes = 2 ** max(len_taskcode, 0)
    if n_tasks > max_codes:
        raise ValueError(
            f"Cannot generate {n_tasks} unique task codes with only "
            f"{len_taskcode} bits (at most {max_codes} distinct codes exist)"
        )

    tasks_dict = {}
    while len(tasks_dict) < n_tasks:
        binary_string = generate_random_binary_string(len_taskcode)
        if binary_string in tasks_dict:
            # Code already taken: draw again.
            continue
        # NOTE(review): positions are drawn with replacement, so one position can
        # appear twice; a repeated position contributes twice to the parity sum
        # and therefore cancels out. Confirm this is intended.
        tasks_dict[binary_string] = [
            random.randint(0, len_message - 1) for _ in range(num_checks)
        ]
    return tasks_dict

def generate_dataset(tasks_dict, num_samples, len_taskcode, len_message):
    """Build `num_samples` rows, each for a task drawn at random from `tasks_dict`.

    Every row is the task code followed by a fresh random message; the label is
    the parity (sum mod 2) of the message bits at that task's check positions.

    Returns:
        [data, value]: a (num_samples, len_taskcode + len_message) float array
        and a (num_samples,) float array of parity labels.
    """
    width = len_taskcode + len_message
    data = np.zeros((num_samples, width))
    value = np.zeros(num_samples)

    for row in range(num_samples):
        task_code = np.random.choice(list(tasks_dict))
        check_positions = tasks_dict[task_code]
        message = generate_random_binary_string(len_message)

        ones = 0
        for position in check_positions:
            ones += int(message[position])

        code_bits = np.array(list(task_code))
        message_bits = np.array(list(message))
        data[row] = np.concatenate((code_bits, message_bits))
        value[row] = ones % 2

    return [data, value]

def generate_dataset_for_task(task_code, tasks_dict, num_samples, len_taskcode, len_message):
    """Build `num_samples` rows that all belong to the single task `task_code`.

    Same row layout and parity label as generate_dataset, but the task is
    fixed rather than drawn per sample. Used primarily for evaluation.

    Returns:
        [data, value]: a (num_samples, len_taskcode + len_message) float array
        and a (num_samples,) float array of parity labels.
    """
    width = len_taskcode + len_message
    data = np.zeros((num_samples, width))
    value = np.zeros(num_samples)
    check_positions = tasks_dict[task_code]

    for row in range(num_samples):
        message = generate_random_binary_string(len_message)

        ones = 0
        for position in check_positions:
            ones += int(message[position])

        code_bits = np.array(list(task_code))
        message_bits = np.array(list(message))
        data[row] = np.concatenate((code_bits, message_bits))
        value[row] = ones % 2

    return [data, value]

class CustomDataset(Dataset):
    """Dataset over a DataFrame whose last column is the target.

    All columns except the last become the feature tensor; both tensors are
    converted to float and moved to `device` once, at construction time.
    """

    def __init__(self, dataframe, device):
        feature_frame = dataframe.iloc[:, :-1]
        target_column = dataframe.iloc[:, -1]

        # Go through numpy so each tensor is built and transferred in one step.
        features = torch.from_numpy(feature_frame.values).float()
        targets = torch.from_numpy(target_column.values).float()

        self.data = features.to(device)
        self.target = targets.to(device)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.target[idx]
        return sample, label

class NeuralNetwork(nn.Module):
    """Fully connected network: Linear+ReLU, then (Linear, BatchNorm1d, ReLU) blocks, then a Linear head.

    `num_layers - 2` hidden blocks are created, so the number of Linear layers
    equals `num_layers` when `num_layers >= 2`. The head output is returned
    without any activation.
    """

    def __init__(self, input_size, output_size, num_layers, hidden_size):
        super().__init__()
        stack = [nn.Linear(input_size, hidden_size)]
        for _ in range(num_layers - 2):
            stack.append(nn.Linear(hidden_size, hidden_size))
            stack.append(nn.BatchNorm1d(hidden_size))
        stack.append(nn.Linear(hidden_size, output_size))
        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        *body, head = self.layers

        # Input projection followed by ReLU.
        x = F.relu(body[0](x))

        # The remaining body layers come in (Linear, BatchNorm1d) pairs; ReLU is
        # applied only after the normalisation of each pair.
        for linear, norm in zip(body[1::2], body[2::2]):
            x = F.relu(norm(linear(x)))

        return head(x)

In [4]:
# Plot intermediate progress

def plot_progress(loss_data, accuracy_data, task_accuracy_data, cumulative_flops, exp_dir):
    """
    Plot and save training progress, and dump the plotted data as CSV.

    Writes `intermediate_plots/progress_<flops>_flops.png` plus
    `intermediate_data/metrics_<flops>_flops.csv` and
    `intermediate_data/task_accuracies_<flops>_flops.csv` under `exp_dir`.
    If there is no loss or accuracy history yet, a warning is logged and
    nothing is written.

    Args:
        loss_data: List of tuples (flops, loss)
        accuracy_data: List of tuples (flops, accuracy)
        task_accuracy_data: Dict of lists of tuples (flops, accuracy) for each task
        cumulative_flops: Current total FLOPs
        exp_dir: Path (or path string) to experiment directory
    """
    if not loss_data or not accuracy_data:
        # zip(*[]) cannot be unpacked into two series, so there is nothing to draw.
        logging.warning(f"No training history to plot at {cumulative_flops:.2e} FLOPs")
        return

    exp_dir = Path(exp_dir)

    # parents=True so the first call also works when exp_dir itself is missing.
    plots_dir = exp_dir / "intermediate_plots"
    plots_dir.mkdir(parents=True, exist_ok=True)
    data_dir = exp_dir / "intermediate_data"
    data_dir.mkdir(parents=True, exist_ok=True)

    flops_loss, losses = zip(*loss_data)
    flops_acc, accuracies = zip(*accuracy_data)

    # Split each non-empty task history once; reused for the plot and the CSV.
    task_series = {}
    for task_idx, task_data in task_accuracy_data.items():
        if task_data:
            flops_task, task_accuracies = zip(*task_data)
            task_series[task_idx] = (flops_task, task_accuracies)

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

    # Plot loss
    ax1.plot(flops_loss, losses)
    ax1.set_xlabel('FLOPs')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training Loss vs FLOPs')
    ax1.set_xscale('log')
    ax1.set_yscale('log')

    # Plot overall accuracy
    ax2.plot(flops_acc, accuracies)
    ax2.set_xlabel('FLOPs')
    ax2.set_ylabel('Accuracy')
    ax2.set_title('Overall Accuracy vs FLOPs')
    ax2.set_xscale('log')

    # Plot task-specific accuracies
    for task_idx, (flops_task, task_accuracies) in task_series.items():
        ax3.plot(flops_task, task_accuracies, label=f'Task {task_idx}')
    ax3.set_xlabel('FLOPs')
    ax3.set_ylabel('Accuracy')
    ax3.set_title('Task-Specific Accuracy vs FLOPs')
    ax3.set_xscale('log')
    if task_series:
        # Only draw a legend when at least one labelled line exists.
        ax3.legend()

    # Save and close this specific figure rather than whatever pyplot
    # currently regards as the active one.
    fig.tight_layout()
    plot_path = plots_dir / f"progress_{cumulative_flops:.2e}_flops.png"
    fig.savefig(plot_path)
    plt.close(fig)

    logging.info(f"Saved progress plot at {cumulative_flops:.2e} FLOPs to {plot_path}")

    # Save loss and accuracy data for later analysis
    df_metrics = pd.DataFrame({
        'flops': flops_loss,
        'loss': losses,
        'accuracy': accuracies
    })
    df_metrics.to_csv(data_dir / f"metrics_{cumulative_flops:.2e}_flops.csv", index=False)

    # Save task-specific accuracy data
    task_data_dict = {}
    for task_idx, (flops_task, task_accuracies) in task_series.items():
        task_data_dict[f'task_{task_idx}_flops'] = flops_task
        task_data_dict[f'task_{task_idx}_accuracy'] = task_accuracies

    df_tasks = pd.DataFrame(task_data_dict)
    df_tasks.to_csv(data_dir / f"task_accuracies_{cumulative_flops:.2e}_flops.csv", index=False)

In [5]:
# Main plot

def main_plot(all_loss_data, all_accuracy_data, all_task_accuracy_data, all_flops):
    """Show the final three-panel comparison across all model configurations.

    Entry i of every argument belongs to `model_configs[i]`.

    Args:
        all_loss_data: Per model, a list of (flops, loss) tuples.
        all_accuracy_data: Per model, a list of (flops, accuracy) tuples.
        all_task_accuracy_data: Per model, a dict (or sequence) mapping an
            integer task index to a list of (flops, accuracy) tuples.
        all_flops: Per model, the final cumulative FLOP count. Not used by
            the plots; kept so existing callers do not need to change.
    """
    def config_label(config):
        return f'{config["num_layers"]}x{config["hidden_size"]}'

    # NOTE(review): an earlier reminder here suggested dropping the last loss
    # point of each run (the final step may use a partial batch). No trimming
    # is performed.
    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    for config, loss_data in zip(model_configs, all_loss_data):
        if not loss_data:
            continue
        flops, losses = zip(*loss_data)
        plt.loglog(flops, losses, label=config_label(config))
    plt.xlabel('Cumulative FLOPs')
    plt.ylabel('Loss')
    plt.title('Loss vs FLOPs')
    plt.legend()

    plt.subplot(1, 3, 2)
    for config, accuracy_data in zip(model_configs, all_accuracy_data):
        if not accuracy_data:
            continue
        flops, accuracies = zip(*accuracy_data)
        plt.semilogx(flops, accuracies, label=config_label(config))
    plt.xlabel('Cumulative FLOPs')
    plt.ylabel('Accuracy')
    plt.title('Accuracy vs FLOPs')
    plt.legend()

    plt.subplot(1, 3, 3)
    for config, task_data in zip(model_configs, all_task_accuracy_data):
        # Iterate over the tasks actually recorded instead of relying on a
        # global task count.
        per_task = task_data.items() if isinstance(task_data, dict) else enumerate(task_data)
        for task, series in per_task:
            if not series:
                continue
            flops, accuracies = zip(*series)
            plt.semilogx(flops, accuracies, label=f'Task {task+1} - {config_label(config)}')
    plt.xlabel('Cumulative FLOPs')
    plt.ylabel('Task-specific Accuracy')
    plt.title('Task-specific Accuracies vs FLOPs')
    plt.legend()

    plt.tight_layout()
    plt.show()

In [6]:
# Count FLOPs
from typing import Tuple, Dict # if used, this should be moved to the import section

# No longer in use
#def count_flops_fvcore(model, input_size):
#    input_tensor = torch.randn(2, input_size)
#    flops = FlopCountAnalysis(model, input_tensor)
#    return flops.total() // 2

class FlopCounter:
    """Analytic FLOP estimate for a model that exposes its layers as `model.layers`.

    Only nn.Linear and nn.BatchNorm1d layers are counted; one ReLU is charged
    after every Linear layer except the final entry of `model.layers`.
    All counts are for one batch of `batch_size` samples.
    """

    def __init__(self, model: nn.Module, input_size: int, batch_size: int):
        self.model, self.input_size, self.batch_size = model, input_size, batch_size

    def count_linear_flops(self, in_features: int, out_features: int) -> Dict[str, int]:
        """
        Estimate FLOPs of one linear layer for one batch.

        Forward: per output element, in_features multiplications and
        in_features - 1 additions (the bias addition is not counted).
        Backward: weight gradient, bias gradient and input gradient.
        """
        batch = self.batch_size
        matmul_flops = batch * in_features * out_features * 2

        forward = batch * out_features * (2 * in_features - 1)
        # dL/dW and dL/dx are each charged as a full matmul; dL/db is one
        # addition per output element.
        backward = matmul_flops + batch * out_features + matmul_flops

        return dict(forward=forward, backward=backward)

    def count_batch_norm_flops(self, num_features: int) -> Dict[str, int]:
        """
        Estimate FLOPs of one batch-normalisation layer for one batch.

        A flat per-element cost is charged: 7 operations forward (mean,
        variance, normalisation, scale and shift) and 10 operations backward
        (gradients for gamma, beta and the input). Both are coarse estimates.
        """
        elements = self.batch_size * num_features
        return dict(forward=elements * (7), backward=elements * 10)

    def count_relu_flops(self, num_elements: int) -> Dict[str, int]:
        """
        Estimate FLOPs of a ReLU over `num_elements` values.

        One comparison per element forward, one multiplication per element
        backward.
        """
        return dict(forward=num_elements, backward=num_elements)

    def calculate_total_flops(self) -> Tuple[int, int]:
        """
        Sum the per-layer estimates over every entry of `model.layers`.

        Returns tuple of (forward_flops, backward_flops)
        """
        totals = {"forward": 0, "backward": 0}

        def accumulate(flops: Dict[str, int]) -> None:
            totals["forward"] += flops["forward"]
            totals["backward"] += flops["backward"]

        # Width of the activations leaving the most recent Linear layer; this
        # is what a following BatchNorm1d / ReLU operates on.
        width = self.input_size

        for layer in self.model.layers:
            is_linear = isinstance(layer, nn.Linear)

            if is_linear:
                accumulate(self.count_linear_flops(layer.in_features, layer.out_features))
                width = layer.out_features
            elif isinstance(layer, nn.BatchNorm1d):
                accumulate(self.count_batch_norm_flops(width))

            # One ReLU per Linear layer, except for the output layer.
            if is_linear and layer is not self.model.layers[-1]:
                accumulate(self.count_relu_flops(self.batch_size * width))

        return totals["forward"], totals["backward"]

def get_flops_per_pass(model: nn.Module, input_size: int, batch_size: int) -> Tuple[int, int]:
    """
    Convenience wrapper: build a FlopCounter and return its totals.

    Args:
        model: PyTorch model exposing its layers as `model.layers`
        input_size: Size of input features
        batch_size: Batch size used in training

    Returns:
        Tuple of (forward_flops, backward_flops) for one batch
    """
    return FlopCounter(model, input_size, batch_size).calculate_total_flops()


In [7]:
# Training loop

# This function has been significantly modified to incorporate the new count FLOPs function
# There may be bugs down the line related to how data is collected and displayed. This should be more thoroughly debugged

def train_and_evaluate(model, criterion=None, optimizer=None, flop_budget=None, tasks_dict=None,
                       *, params=None, exp_dir=None, model_config=None):
    """Train `model` on freshly generated parity data until a FLOP budget is spent.

    Each pass of the outer loop generates a new dataset, trains on it for one
    epoch, and records the epoch's mean loss and accuracy against the
    cumulative FLOP estimate. Per-task accuracy is evaluated every
    `task_sample_freq` FLOPs and once more when the budget is reached.
    Evaluation forward passes are charged to the budget as well.

    Args:
        model: Network exposing `model.layers` (needed by FlopCounter).
        criterion: Loss taking (outputs, targets) of shape (batch, 1).
            Defaults to nn.BCEWithLogitsLoss().
        optimizer: Optimizer over `model`'s parameters. Defaults to Adam with
            `params['learning_rate']`.
        flop_budget: Estimated FLOPs to spend. Defaults to `params['flop_budget']`.
        tasks_dict: Mapping task code -> check positions. Required.
        params: Hyper-parameter dict; defaults to the module-level PARAMS.
            Reads 'len_taskcode', 'len_message', 'batch_size', 'num_samples',
            'samples_per_task', 'task_sample_freq', 'plot_freq' and, for the
            defaults above, 'flop_budget' and 'learning_rate'.
        exp_dir: Directory for intermediate plots/CSVs. When None, no
            intermediate output is written.
        model_config: Accepted for call compatibility; currently unused.

    Returns:
        (loss_data, accuracy_data, task_accuracy_data, cumulative_flops) where
        the first two are lists of (flops, value) tuples and the third maps a
        task index to such a list.

    Raises:
        ValueError: If `tasks_dict` is missing or 'num_samples' is not positive.
    """
    if tasks_dict is None:
        raise ValueError("tasks_dict is required")

    cfg = PARAMS if params is None else params
    len_taskcode = cfg['len_taskcode']
    len_message = cfg['len_message']
    batch_size = cfg['batch_size']
    num_samples = cfg['num_samples']
    samples_per_task = cfg['samples_per_task']
    task_sample_freq = cfg['task_sample_freq']
    plot_freq = cfg['plot_freq']
    n_features = len_taskcode + len_message

    if num_samples <= 0:
        # An empty epoch would never advance the FLOP counter.
        raise ValueError("params['num_samples'] must be positive")
    if flop_budget is None:
        flop_budget = cfg['flop_budget']

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Move the model first so a default optimizer binds to the on-device parameters.
    model = model.to(device)
    if criterion is None:
        criterion = nn.BCEWithLogitsLoss()
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=cfg['learning_rate'])

    # The model emits raw scores. Under BCEWithLogitsLoss these are logits, so
    # the decision boundary is 0; for any other criterion keep the 0.5 cut-off.
    threshold = 0.0 if isinstance(criterion, nn.BCEWithLogitsLoss) else 0.5

    # Every term in FlopCounter is proportional to batch_size, so counting with
    # batch_size=1 gives a per-sample cost that prices partial batches correctly.
    forward_flops, backward_flops = FlopCounter(
        model, input_size=n_features, batch_size=1
    ).calculate_total_flops()
    train_flops_per_sample = forward_flops + backward_flops

    tasks_list = list(tasks_dict)
    loss_data = []
    accuracy_data = []
    task_accuracy_data = {i: [] for i in range(len(tasks_list))}
    cumulative_flops = 0
    last_task_sample = 0
    last_plot = 0

    print_rate = flop_budget / 1e1
    disp_flops = 0

    while cumulative_flops < flop_budget:
        if cumulative_flops - print_rate > disp_flops:
            print(f'cumulative_flops: {cumulative_flops} - flop_budget: {flop_budget} - Percentage completion: {(cumulative_flops/flop_budget)*100:.2f}%')
            disp_flops = cumulative_flops

        data, value = generate_dataset(tasks_dict, num_samples, len_taskcode, len_message)
        data_loader = _make_loader(data, value, n_features, batch_size, device, shuffle=True)

        epoch_loss = 0.0
        correct = 0
        total = 0
        model.train()

        for inputs, labels in data_loader:
            n_batch = labels.size(0)

            # Forward pass
            outputs = model(inputs)
            batch_loss = criterion(outputs, labels.unsqueeze(1))
            # squeeze(1) keeps the batch dimension even for a single-sample batch.
            predictions = (outputs.squeeze(1) >= threshold).float()
            correct += (predictions == labels).sum().item()
            total += n_batch

            # Backward pass and optimization
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            epoch_loss += batch_loss.item() * n_batch

            # One forward and one backward pass per sample actually processed.
            cumulative_flops += train_flops_per_sample * n_batch

            if cumulative_flops >= flop_budget:
                break

        # Average over the samples actually seen; the epoch may have stopped early.
        loss_data.append((cumulative_flops, epoch_loss / total))
        accuracy_data.append((cumulative_flops, correct / total))

        # Task-specific evaluation
        if cumulative_flops - last_task_sample >= task_sample_freq or cumulative_flops >= flop_budget:
            last_task_sample = cumulative_flops
            model.eval()

            for task_idx, task_code in enumerate(tasks_list):
                data_per_task, value_per_task = generate_dataset_for_task(
                    task_code, tasks_dict, samples_per_task, len_taskcode, len_message
                )
                loader_per_task = _make_loader(
                    data_per_task, value_per_task, n_features, batch_size, device, shuffle=False
                )

                task_correct = 0
                task_total = 0
                with torch.no_grad():
                    for inputs, labels in loader_per_task:
                        outputs = model(inputs)
                        predictions = (outputs.squeeze(1) >= threshold).float()
                        task_correct += (predictions == labels).sum().item()
                        task_total += labels.size(0)
                        cumulative_flops += forward_flops * labels.size(0)

                if task_total:
                    task_accuracy_data[task_idx].append((cumulative_flops, task_correct / task_total))

        if exp_dir is not None and cumulative_flops - last_plot >= plot_freq:
            last_plot = cumulative_flops
            plot_progress(loss_data, accuracy_data, task_accuracy_data, cumulative_flops, exp_dir)

    return loss_data, accuracy_data, task_accuracy_data, cumulative_flops


def _make_loader(data, value, n_features, batch_size, device, shuffle):
    """Wrap feature and label arrays in a CustomDataset and return a DataLoader over it."""
    columns = [f'feature_{i}' for i in range(n_features)] + ['target']
    frame = pd.DataFrame(np.concatenate((data, value.reshape(-1, 1)), axis=1), columns=columns)
    return DataLoader(CustomDataset(frame, device), batch_size=batch_size, shuffle=shuffle)

In [11]:
# Main

def main():
    """Run the full sweep: train every model configuration, then plot the comparison.

    Creates a fresh experiment directory under ./experiments, generates one
    shared task dictionary from PARAMS, trains one NeuralNetwork per entry of
    `model_configs`, and passes the collected histories to main_plot.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create experiment directory
    exp_name = f"parity_scaling_flops_{PARAMS['flop_budget']:.0e}"
    exp_dir = _create_versioned_directory(Path("experiments"), exp_name)
    print(f"Experiment directory: {exp_dir}")

    print(f"FLOP budget: {PARAMS['flop_budget']}")

    tasks_dict = generate_dict(
        PARAMS['n_tasks'],
        PARAMS['len_taskcode'],
        PARAMS['num_checks'],
        PARAMS['len_message']
    )
    print(f"Generated tasks dictionary with {len(tasks_dict)} tasks")
    print("tasks_dict = ", tasks_dict.items())

    # One entry per model configuration, in the order main_plot expects.
    all_loss_data = []
    all_accuracy_data = []
    all_task_accuracy_data = []
    all_flops = []

    # Add progress bar for model configurations
    for config in tqdm(model_configs, desc="Training models", position=0, leave=True):
        print(f"\nTraining model with {config['num_layers']} layers and hidden size {config['hidden_size']}")
        model = NeuralNetwork(
            PARAMS['input_size'],
            PARAMS['output_size'],
            config["num_layers"],
            config["hidden_size"]
        ).to(device)

        # NOTE(review): train_and_evaluate must accept the params / exp_dir /
        # model_config keywords for this call to succeed.
        results = train_and_evaluate(
            model=model,
            params=PARAMS,
            tasks_dict=tasks_dict,
            exp_dir=exp_dir,
            model_config=config
        )
        loss_data, accuracy_data, task_accuracy_data, cumulative_flops = results
        all_loss_data.append(loss_data)
        all_accuracy_data.append(accuracy_data)
        all_task_accuracy_data.append(task_accuracy_data)
        all_flops.append(cumulative_flops)

    # Create final plots
    main_plot(all_loss_data, all_accuracy_data, all_task_accuracy_data, all_flops)


def _create_versioned_directory(base_dir, exp_name):
    """Create and return a new `<base_dir>/<exp_name>__<timestamp>_v<N>` directory.

    N starts at 1 and is increased until an unused name is found.
    NOTE(review): the naming scheme is inferred from the existing run directory
    referenced at the end of this notebook -- confirm it matches the repo's
    own helper.
    """
    from datetime import datetime

    base_dir = Path(base_dir)
    base_dir.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    version = 1
    while True:
        candidate = base_dir / f"{exp_name}__{stamp}_v{version}"
        try:
            candidate.mkdir()
        except FileExistsError:
            version += 1
        else:
            return candidate


if __name__ == '__main__':
    main()

In [None]:
# Optional post-processing function

def create_seaborn_plots(exp_dir: Path, epoch_window: int = 10, confidence_interval: float = 0.95):
    """Create seaborn plots from the CSVs saved under `<exp_dir>/intermediate_data`.

    Writes a three-panel summary plus one figure per metric and one
    task-accuracy figure into `<exp_dir>/seaborn_plots`.

    Args:
        exp_dir: Experiment directory. An existing directory is used as given;
            otherwise a string is looked up as `experiments/<exp_dir>` and a
            Path as `experiments/<exp_dir.name>`.
        epoch_window: Approximate number of rows per bin of the `epoch_bin`
            column; also part of the output file names.
        confidence_interval: Confidence level for the error bands. Values in
            (0, 1] are read as fractions and converted to the percentage
            seaborn expects (0.95 -> 95); larger values are passed through.

    Raises:
        ValueError: If the experiment directory, its data directory, or the
            metrics files are missing.
    """
    exp_dir = _resolve_experiment_dir(exp_dir)

    print(f"Current working directory: {Path.cwd()}")
    print(f"Looking for directory: {exp_dir}")

    # Verify experiment directory exists
    if not exp_dir.exists():
        print("\nContents of experiments directory:")
        try:
            for item in Path("experiments").iterdir():
                print(f"  {item.name}")
        except OSError as e:
            print(f"Error listing experiments directory: {e}")
        raise ValueError(f"Experiment directory does not exist: {exp_dir}")

    # Create seaborn plots directory
    seaborn_dir = exp_dir / "seaborn_plots"
    seaborn_dir.mkdir(exist_ok=True)
    logging.info(f"Created seaborn directory at: {seaborn_dir}")

    # Verify data directory exists
    data_dir = exp_dir / "intermediate_data"
    if not data_dir.exists():
        raise ValueError(f"Data directory does not exist: {data_dir}")

    # Get all metrics files
    metrics_files = sorted(data_dir.glob("metrics_*.csv"))
    task_files = sorted(data_dir.glob("task_accuracies_*.csv"))

    logging.info(f"Found {len(metrics_files)} metrics files and {len(task_files)} task files")

    if not metrics_files:
        raise ValueError(f"No metrics files found in {data_dir}")

    # seaborn's ('ci', level) takes the level as a percentage, not a fraction.
    ci_level = confidence_interval * 100 if 0 < confidence_interval <= 1 else confidence_interval

    # Combine all metrics data
    metrics_df = pd.concat([pd.read_csv(file) for file in metrics_files], ignore_index=True)
    # NOTE(review): 'epoch_bin' is computed but the plots below still use the raw
    # 'flops' column as x, so epoch_window does not affect the curves.
    _add_epoch_bins(metrics_df, epoch_window)

    # Task-specific accuracy data (None when no task data was saved)
    task_df = _load_task_accuracy(task_files)
    if task_df is None:
        logging.warning(f"No task accuracy data found in {data_dir}; skipping task plots")
    else:
        _add_epoch_bins(task_df, epoch_window)

    # Set up the plotting style
    sns.set_style("whitegrid")
    sns.set_palette("husl")

    suffix = f"window{epoch_window}_ci{confidence_interval}"

    # Summary figure with three subplots
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 6))

    # 1. Loss plot
    sns.lineplot(data=metrics_df, x='flops', y='loss', errorbar=('ci', ci_level), ax=ax1)
    ax1.set_xscale('log')
    ax1.set_yscale('log')
    ax1.set_title('Training Loss vs FLOPs')
    ax1.set_xlabel('FLOPs')
    ax1.set_ylabel('Loss')

    # 2. Accuracy plot
    sns.lineplot(data=metrics_df, x='flops', y='accuracy', errorbar=('ci', ci_level), ax=ax2)
    ax2.set_xscale('log')
    ax2.set_title('Overall Accuracy vs FLOPs')
    ax2.set_xlabel('FLOPs')
    ax2.set_ylabel('Accuracy')

    # 3. Task-specific accuracy plot
    if task_df is not None:
        sns.lineplot(data=task_df, x='flops', y='accuracy', hue='task',
                     errorbar=('ci', ci_level), ax=ax3)
    ax3.set_xscale('log')
    ax3.set_title('Task-Specific Accuracy vs FLOPs')
    ax3.set_xlabel('FLOPs')
    ax3.set_ylabel('Accuracy')

    fig.tight_layout()
    fig.savefig(seaborn_dir / f"seaborn_summary_{suffix}.png")
    plt.close(fig)

    # Also create separate plots for each metric with error bands
    for metric in ('loss', 'accuracy'):
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.lineplot(data=metrics_df, x='flops', y=metric, errorbar=('ci', ci_level), ax=ax)
        ax.set_xscale('log')
        if metric == 'loss':
            ax.set_yscale('log')
        ax.set_title(f'{metric.capitalize()} vs FLOPs')
        ax.set_xlabel('FLOPs')
        ax.set_ylabel(metric.capitalize())
        fig.tight_layout()
        fig.savefig(seaborn_dir / f"seaborn_{metric}_{suffix}.png")
        plt.close(fig)

    # Create task-specific plot
    if task_df is not None:
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.lineplot(data=task_df, x='flops', y='accuracy', hue='task',
                     errorbar=('ci', ci_level), ax=ax)
        ax.set_xscale('log')
        ax.set_title('Task-Specific Accuracy vs FLOPs')
        ax.set_xlabel('FLOPs')
        ax.set_ylabel('Accuracy')
        fig.tight_layout()
        fig.savefig(seaborn_dir / f"seaborn_task_accuracy_{suffix}.png")
        plt.close(fig)

    logging.info(f"Created seaborn plots in {seaborn_dir}")


def _resolve_experiment_dir(exp_dir):
    """Return the absolute experiment directory, falling back to ./experiments for unknown paths."""
    given = Path(exp_dir)
    if given.is_dir():
        resolved = given
    elif isinstance(exp_dir, str):
        resolved = Path("experiments") / exp_dir
    else:
        resolved = Path("experiments") / given.name
    return resolved.absolute()


def _add_epoch_bins(df, epoch_window):
    """Add an 'epoch_bin' column that groups rows of `df` by their 'flops' value."""
    if df.empty:
        df['epoch_bin'] = pd.Series(dtype=float)
        return df

    # At least one bin, even when there are fewer rows than one window.
    n_bins = max(1, len(df) // max(1, epoch_window))
    try:
        df['epoch_bin'] = pd.qcut(df['flops'], q=n_bins, labels=False, duplicates='drop')
    except ValueError:
        # If qcut fails, use regular cut with logarithmic bins
        df['epoch_bin'] = pd.cut(np.log10(df['flops']), bins=n_bins, labels=False)
    return df


def _load_task_accuracy(task_files):
    """Read the task CSVs into one long frame (flops, accuracy, task), or None if there is no data."""
    frames = []
    for file in task_files:
        df = pd.read_csv(file)
        task_cols = [col for col in df.columns if 'task' in col]
        # Columns are written in (flops, accuracy) pairs per task.
        for flops_col, acc_col in zip(task_cols[0::2], task_cols[1::2]):
            task_num = flops_col.split('_')[1]
            frames.append(pd.DataFrame({
                'flops': df[flops_col],
                'accuracy': df[acc_col],
                'task': f'Task {task_num}'
            }))
    if not frames:
        return None
    return pd.concat(frames, ignore_index=True)

#create_seaborn_plots(Path('parity_scaling_flops_1e+10__20250108_000607_v1'))
# Post-process one specific, previously recorded run.
# NOTE(review): the leading '/' makes this an absolute path at the filesystem
# root -- confirm that is the intended location rather than ./experiments/<run>.
exp_dir = Path('/experiments/parity_scaling_flops_1e+10__20250108_000607_v1')
# Debug output for checking how the path resolves before calling the function:
#print(f"Original exp_dir: {exp_dir}")
#print(f"Absolute exp_dir: {exp_dir.resolve()}")
create_seaborn_plots(exp_dir.resolve(), epoch_window = 100) # original note: larger epoch windows result in slowdown
#create_seaborn_plots(Path('workspace/project/experiments/parity_scaling_flops_1e+10__20250108_000607_v1'))