In [None]:
# Jupyter Notebook: notebooks/analysis.ipynb

# Import necessary libraries
import json
import os
import traceback

import matplotlib.pyplot as plt
import torch

from parse_config import ConfigParser
from train import main as train_model

# Define paths to configuration files
config_dir = "../configs/"
config_files = ["overfit.json", "underfit.json", "optimal.json"]

# Results per configuration file name: {"train_loss": [...], "val_loss": [...]}.
# Configurations whose training fails get no entry.
results = {}

# Train one model per configuration. A failing run is reported and skipped so
# the remaining configurations still get trained.
for config_file in config_files:
    print(f"Training with configuration: {config_file}")
    config_path = os.path.join(config_dir, config_file)

    # Load configuration. An explicit encoding keeps JSON parsing independent
    # of the platform's locale default (e.g. cp1252 on Windows).
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    # Wrap the raw dict in the project's ConfigParser, which is what
    # train_model is called with.
    config_parser = ConfigParser(config)

    # Train the model. Only the training call is guarded; the success message
    # lives in `else` so it can never be reported as a training error.
    try:
        train_model(config_parser)
    except Exception as e:
        # Notebook-level boundary: keep going with the next configuration,
        # but print the full traceback so the failure can be located.
        print(f"Error during training with {config_file}: {e}")
        traceback.print_exc()
        continue
    else:
        print(f"Training completed for {config_file}")

    # Collect results (e.g. training and validation loss curves).
    # TODO: fill these from the run's logs. `log_dir` is the root `save_dir`
    # from the config; where the training script writes its loss history
    # below it is not visible here — confirm before parsing.
    log_dir = config["trainer"]["save_dir"]
    results[config_file] = {
        "train_loss": [],  # Replace with actual training loss data
        "val_loss": [],    # Replace with actual validation loss data
    }

# Visualize results. Each configuration gets one colour; training loss is drawn
# solid and validation loss dashed so the two curves of a run pair up visually.
plt.figure(figsize=(12, 6))
n_plotted = 0  # configurations that contributed at least one curve
for config_file, data in results.items():
    series_to_plot = [
        (data["train_loss"], "Train Loss", "-"),
        (data["val_loss"], "Validation Loss", "--"),
    ]
    # Empty series (e.g. the placeholders created by the training loop) would
    # only add legend entries with nothing drawn. len() is used instead of
    # truthiness so NumPy arrays are accepted as well as lists.
    if not any(len(values) for values, _, _ in series_to_plot):
        print(f"No loss data recorded for {config_file}; skipping in plot.")
        continue
    # "CN" colour strings index matplotlib's default colour cycle, so both
    # curves of this configuration share one colour.
    color = f"C{n_plotted}"
    for values, kind, linestyle in series_to_plot:
        if len(values) == 0:
            continue
        # Epochs are numbered from 1; assumes one loss value per epoch.
        epochs = range(1, len(values) + 1)
        plt.plot(
            epochs,
            values,
            color=color,
            linestyle=linestyle,
            label=f"{config_file} - {kind}",
        )
    n_plotted += 1

plt.title("Training and Validation Loss Curves")
plt.xlabel("Epochs")
plt.ylabel("Loss")
# legend() with no labelled artists only emits a warning, so skip it.
if n_plotted:
    plt.legend()
plt.grid()
plt.show()

# Analyze results
def _diagnose_fit(train_loss, val_loss, gap_tolerance=0.10, min_improvement=0.10):
    """Return a heuristic fit diagnosis for one run's loss curves.

    The run is flagged as overfitting when the final validation loss exceeds
    the final training loss by more than ``gap_tolerance`` (relative to the
    training loss). It is flagged as underfitting when, over at least two
    epochs, training loss dropped by less than ``min_improvement`` (relative
    to its first value). Otherwise it is considered well-balanced. Both
    thresholds are rules of thumb and should be tuned per task.

    Args:
        train_loss: Non-empty sequence of training losses, one per epoch.
        val_loss: Non-empty sequence of validation losses, one per epoch.
        gap_tolerance: Allowed relative train/validation gap.
        min_improvement: Minimum relative drop in training loss.

    Returns:
        "overfitting", "underfitting" or "well-balanced".
    """
    final_train = train_loss[-1]
    final_val = val_loss[-1]
    # Validation loss clearly above training loss: the model fits the training
    # data better than it generalizes. (A lower validation loss is NOT a sign
    # of overfitting.) Multiplying instead of dividing avoids a zero division
    # when training loss reaches 0.
    if final_val - final_train > gap_tolerance * abs(final_train):
        return "overfitting"
    # Training loss that barely moved means the model cannot fit even the
    # training data. This needs at least two points to measure progress.
    if len(train_loss) > 1:
        improvement = train_loss[0] - final_train
        if improvement < min_improvement * abs(train_loss[0]):
            return "underfitting"
    return "well-balanced"


print("Analysis:")
for config_file, data in results.items():
    print(f"\nConfiguration: {config_file}")
    train_loss = data["train_loss"]
    val_loss = data["val_loss"]
    # The loss lists may still be empty placeholders; indexing [-1] on them
    # would raise IndexError and abort the analysis of every configuration.
    if len(train_loss) == 0 or len(val_loss) == 0:
        print("No loss data recorded; skipping analysis.")
        continue
    print(f"Final Training Loss: {train_loss[-1]}")
    print(f"Final Validation Loss: {val_loss[-1]}")
    print(f"Observation: Model is {_diagnose_fit(train_loss, val_loss)}.")

# Document findings and conclusions.
# TODO: replace the placeholder sentences with the actual observations once
# the analysis has been run on real loss data.
print(
    "\nConclusions:",
    "1. Overfitting configuration shows...",
    "2. Underfitting configuration shows...",
    "3. Optimal configuration shows...",
    sep="\n",
)