# FlatGrad: Proof of Concept Experiment

This notebook runs the proof-of-concept experiment for lambda measurability and target lambda regularization.

## What this does:
- Clones the FlatGrad repository
- Installs dependencies
- Runs the proof-of-concept experiment with target lambda regularization


## Setup: Clone Repository and Install Dependencies


In [None]:
# Clone the repository
# If you are working from a fork, substitute its URL here
!git clone https://github.com/jacobposchl/flatgrad
# git names the directory after the last URL component, so it is lowercase
%cd flatgrad


In [None]:
# Install requirements
%pip install -r requirements.txt


## Run Proof of Concept Experiment

The experiment will:
- Train models on MNIST and CIFAR-10
- Test different target lambda values (default: [-4.0, -2.0, -1.0, 0, 1.0, 2.0, 4.0])
- Generate plots and save results to `results/proof_of_concept/`


In [None]:
# Run with default target lambdas: [-4.0, -2.0, -1.0, 0, 1.0, 2.0, 4.0]
!python experiments/proof_of_concept/proof_of_concept.py


## Optional: Run with Custom Target Lambdas

You can specify custom target lambda values and number of epochs:


In [None]:
# Example: Run with custom target lambdas
# !python experiments/proof_of_concept/proof_of_concept.py --target-lambda -3.0 -2.0 -1.0 0.0 1.0

# You can also specify number of epochs (default is 50)
# !python experiments/proof_of_concept/proof_of_concept.py --epochs 30 --target-lambda -2.0 -1.0 0.0
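
If you would rather assemble these commands programmatically (for example, to loop over several configurations), the sketch below builds the same command line from Python values. The helper name `build_command` is hypothetical and not part of the repository; only the script path and the `--epochs` and `--target-lambda` flags come from the examples above.

```python
import shlex

# Script path taken from the cells above.
SCRIPT = "experiments/proof_of_concept/proof_of_concept.py"


def build_command(target_lambdas=None, epochs=None, python="python"):
    """Return the experiment command as a list of arguments.

    Hypothetical helper: flags that are left as None are omitted,
    so the script falls back to its own defaults.
    """
    cmd = [python, SCRIPT]
    if epochs is not None:
        cmd += ["--epochs", str(int(epochs))]
    if target_lambdas:
        cmd.append("--target-lambda")
        cmd += [str(float(value)) for value in target_lambdas]
    return cmd


# Print the command in copy-pasteable form before running it.
print(shlex.join(build_command([-2.0, -1.0, 0.0], epochs=30)))
```

To launch the run from Python, the returned list can be passed to `subprocess.run(cmd, check=True)`.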


## View Results

Results are saved to `results/proof_of_concept/`:
- `results.txt`: Summary of all experiments
- `mnist/` and `cifar10/`: Plots and metrics for each dataset
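
As a portable alternative to shell commands, the following sketch walks the results directory and groups the files by top-level subfolder. The function name `summarize_results` is an assumption made for illustration. It relies only on the directory layout described above and returns an empty mapping if the experiment has not been run yet.

```python
from pathlib import Path


def summarize_results(root="results/proof_of_concept"):
    """Group result files by their top-level subfolder under `root`.

    Hypothetical helper: files directly under `root` are grouped under ".".
    """
    root = Path(root)
    summary = {}
    if not root.is_dir():
        return summary  # nothing to report before the experiment has run
    for path in sorted(root.rglob("*")):
        if path.is_file():
            rel = path.relative_to(root)
            group = rel.parts[0] if len(rel.parts) > 1 else "."
            summary.setdefault(group, []).append(rel.as_posix())
    return summary


for group, files in summarize_results().items():
    print(f"{group}: {len(files)} file(s)")
    for name in files:
        print(f"  {name}")
```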


In [None]:
# List result files
!ls -la results/proof_of_concept/


In [None]:
# View summary results
!head -50 results/proof_of_concept/results.txt


In [None]:
# Display plots inline (works in Colab or any Jupyter environment)
import os

from IPython.display import Image, display

# Example plots for each dataset
plot_paths = [
    "results/proof_of_concept/mnist/lambda_evolution/multi_reg_comparison.png",
    "results/proof_of_concept/mnist/metrics_vs_reg/accuracy.png",
    "results/proof_of_concept/cifar10/lambda_evolution/multi_reg_comparison.png",
    "results/proof_of_concept/cifar10/metrics_vs_reg/accuracy.png",
]

for path in plot_paths:
    if os.path.exists(path):
        display(Image(path))
        print(f"\n{path}\n")
    else:
        # Report missing plots instead of skipping them silently
        print(f"Not found: {path}")
