In [1]:
import argparse
import os
import pandas as pd
from typing import List
from stable_baselines3 import PPO
from src.config import Configuration
from src.braid_env import BraidEnv
from src.braid_agent import BraidAgent
from src.callbacks import BraidCallback

In [2]:
# Experiment registry: id -> {"name", "train_n", "test_n"}, where train_n / test_n
# are the strand counts (N) whose data files are used for training and testing.
#   1-3  : same N for training and testing
#   4-5  : train on a smaller N, test on a larger one
#   6-8  : train on a larger N, test on a smaller one
#   9    : train on N=3 and N=7, test on N=5
#   10   : train and test on N=3, 5, 7
#   11   : train on N=3, 5, 7, test on N=9 (not present in any training set)
EXPERIMENTS = {
    exp_id: {"name": name, "train_n": train_n, "test_n": test_n}
    for exp_id, (name, train_n, test_n) in enumerate(
        [
            ("Baseline_N3",   [3],       [3]),
            ("Baseline_N5",   [5],       [5]),
            ("Baseline_N7",   [7],       [7]),
            ("Transfer_3to5", [3],       [5]),
            ("Transfer_5to7", [5],       [7]),
            ("Transfer_5to3", [5],       [3]),
            ("Transfer_7to5", [7],       [5]),
            ("Transfer_7to3", [7],       [3]),
            ("Interpolation", [3, 7],    [5]),
            ("Generalist",    [3, 5, 7], [3, 5, 7]),
            ("OOD_N9",        [3, 5, 7], [9]),
        ],
        start=1,
    )
}

In [3]:
def get_filenames(prefix, strands, crossings=(8, 16, 24), moves=(10, 50, 100)):
    """Build data-file names for every (strands, crossings, moves) combination.

    Names follow ``{prefix}_n{n}_c{c}_m{m}.txt`` and are ordered with strands
    outermost and moves innermost. Only names are produced; the files are not
    checked for existence.

    Args:
        prefix: File-name prefix, e.g. ``"train"``, ``"ft"`` or ``"test"``.
        strands: Strand counts ``n`` to include.
        crossings: Crossing counts ``c`` to include. Defaults to ``(8, 16, 24)``.
        moves: Move counts ``m`` to include. Defaults to ``(10, 50, 100)``.

    Returns:
        List of file names (no directory component); empty if any of
        ``strands``, ``crossings`` or ``moves`` is empty.
    """
    return [
        f"{prefix}_n{n}_c{c}_m{m}.txt"
        for n in strands
        for c in crossings
        for m in moves
    ]

In [4]:
def train(agent: BraidAgent, n_strands: List[int], folder: str,
          total_budget: int = 300_000, min_steps_per_file: int = 5000) -> int:
    """Train ``agent`` sequentially on every data file for the given strand counts.

    Each file gets its own ``BraidEnv`` and a fresh ``BraidCallback``. The step
    budget is split evenly across files, with a per-file floor. Every call to
    ``agent.train`` receives the path from ``agent.config.get_model_path(agent.name)``.
    Files that raise are reported and skipped, so one bad file does not abort
    the whole phase.

    Args:
        agent: Agent to train; its ``config`` supplies DATA_DIR, N_STRANDS,
            MAX_LEN and the model path.
        n_strands: Strand counts whose data files should be used.
        folder: Sub-directory of ``DATA_DIR`` to read from. ``"train"`` selects
            ``train_*`` files; any other value selects ``ft_*`` files.
            ``"finetune"`` additionally passes ``finetune_mode=True`` to ``BraidEnv``.
        total_budget: Total timesteps to spread across all files.
        min_steps_per_file: Lower bound on timesteps per file. This floor can
            push the actual total above ``total_budget``.

    Returns:
        Number of files on which ``agent.train`` completed without raising.
    """
    print(f"\n--- Starting Training Phase ({folder}) ---")
    print(f"Strands: {n_strands}")

    prefix = "train" if folder == "train" else "ft"
    files = get_filenames(prefix, n_strands)

    if files:
        steps_per_file = max(int(total_budget / len(files)), min_steps_per_file)
    else:
        steps_per_file = 0
        print("Warning: No files found to train on.")

    is_finetuning = folder == "finetune"
    log_label = f"{folder}_phase"
    n_trained = 0

    for filename in files:
        full_path = os.path.join(agent.config.DATA_DIR, folder, filename)

        print(f">>> Training on {filename} ({steps_per_file} steps)...")
        try:
            env = BraidEnv(full_path, agent.config.N_STRANDS, agent.config.MAX_LEN,
                           agent.config, finetune_mode=is_finetuning)
            callback = BraidCallback()
            agent.train(env, steps_per_file, agent.config.get_model_path(agent.name),
                        callback=callback, log_name=log_label)
        except ValueError as e:
            # Presumably BraidEnv's signal for an unusable data file — TODO confirm.
            print(f"    SKIPPING {filename}: {e}")
            continue
        except Exception as e:
            # Best-effort: keep going, but show the exception type so programming
            # errors (NameError, TypeError, ...) are not mistaken for data issues.
            print(f"    ERROR on {filename}: {type(e).__name__}: {e}")
            continue
        n_trained += 1

    if files and n_trained == 0:
        # Without this, a phase where every file failed looks like a normal run
        # and the caller goes on to evaluate an agent that was never trained here.
        print(f"Warning: no file in '{folder}' was trained on successfully.")
    return n_trained

In [5]:
def test(agent: BraidAgent, n_strands: List[int],
         episodes_per_file: int = 20, max_steps: int = 200):
    """Evaluate ``agent`` on every test data file for the given strand counts.

    For each ``test_*`` file a ``BraidEnv`` is built and ``episodes_per_file``
    episodes are run with ``agent.solve``. Three per-file metrics are recorded:

    * success rate (percent);
    * optimal rate: the percent of episodes solved in exactly ``optimal_steps``
      moves, when the environment's current braid exposes a positive
      ``optimal_steps``;
    * mean gap: the mean of ``steps - optimal_steps`` over those solved episodes.

    Per-file lines and an overall summary (unweighted mean over files) are
    printed. The table is written to ``<LOG_DIR>/<agent.name>/results_test.csv``.
    Files that raise are reported and skipped.

    Args:
        agent: Agent to evaluate; its ``config`` supplies DATA_DIR, LOG_DIR,
            N_STRANDS and MAX_LEN.
        n_strands: Strand counts whose test files should be evaluated.
        episodes_per_file: Episodes to run per file.
        max_steps: Step limit passed to ``agent.solve`` for each episode.
    """
    print("\n--- Starting Testing Phase ---")
    print(f"Test Strands: {n_strands}")

    test_files = get_filenames("test", n_strands)
    exp_log_dir = os.path.join(agent.config.LOG_DIR, agent.name)

    agent.metrics.reset()
    results_table = []

    for filename in test_files:
        full_path = os.path.join(agent.config.DATA_DIR, "test", filename)
        print(f"Testing on {filename}...", end=" ")

        try:
            env_test = BraidEnv(full_path, agent.config.N_STRANDS, agent.config.MAX_LEN, agent.config)

            file_success = 0
            file_optimal = 0
            optimality_gap_sum = 0
            valid_gap_count = 0

            for _ in range(episodes_per_file):
                env_test.reset()
                # Read the optimum for the braid produced by this reset(),
                # before solve() runs; -1 means "no reference available".
                optimal_steps = getattr(env_test.current_braid, "optimal_steps", -1)

                solved, steps = agent.solve(env_test, max_steps=max_steps)
                if not solved:
                    continue

                file_success += 1
                if optimal_steps > 0:
                    gap = steps - optimal_steps
                    optimality_gap_sum += gap
                    valid_gap_count += 1
                    if gap == 0:
                        file_optimal += 1

            if episodes_per_file > 0:
                rate = (file_success / episodes_per_file) * 100
                opt_rate = (file_optimal / episodes_per_file) * 100
            else:
                rate = opt_rate = 0

            if valid_gap_count > 0:
                avg_gap = optimality_gap_sum / valid_gap_count
                gap_str = f"{avg_gap:.1f}"
            else:
                # NaN rather than a -1 sentinel: `gap` is not constrained to be
                # non-negative, so -1 could be a genuine average. NaN is also
                # skipped by pandas aggregations over the saved CSV.
                avg_gap = float("nan")
                gap_str = "N/A"

            print(f"Score: {rate:.0f}% | Opt: {opt_rate:.0f}% | Gap: {gap_str}")

            results_table.append({
                "File": filename,
                "Score": rate,
                "Optimal_Score": opt_rate,
                "Avg_Gap": avg_gap
            })

        except ValueError as e:
            print(f"FAILED to load: {e}")
        except Exception as e:
            # Best-effort per file; include the type so code bugs stand out.
            print(f"ERROR: {type(e).__name__}: {e}")

    if not results_table:
        print("No results generated.")
        return

    df = pd.DataFrame(results_table)
    # Unweighted mean over files (all files run the same number of episodes).
    avg_score = df["Score"].mean()
    avg_opt = df["Optimal_Score"].mean()

    print(f"\n=== RESULTS: {agent.name} ===")
    print(f"Overall Success Rate: {avg_score:.2f}%")
    print(f"Overall Optimal Rate: {avg_opt:.2f}%")

    os.makedirs(exp_log_dir, exist_ok=True)
    csv_path = os.path.join(exp_log_dir, "results_test.csv")
    df.to_csv(csv_path, index=False)
    print(f"Detailed results saved to {csv_path}")

In [6]:
def run_experiment(exp_id):
    """Run one registered experiment end to end.

    Builds a fresh ``Configuration`` and ``BraidAgent`` sized for the largest
    strand count the experiment uses. It then runs two rounds on that same agent:

    1. Pre-training on the ``"train"`` data folder.
    2. Fine-tuning on the ``"finetune"`` folder under a new agent name.

    Each round is followed by evaluation on the experiment's test strand counts.

    Args:
        exp_id: Key into ``EXPERIMENTS``. Unknown ids print a message and
            return without doing anything.
    """
    spec = EXPERIMENTS.get(exp_id)
    if spec is None:
        print(f"Experiment {exp_id} not found.")
        return

    exp_name = spec['name']
    train_strands = spec['train_n']
    test_strands = spec['test_n']
    banner = '=' * 60

    print(f"\n{banner}")
    print(f"RUNNING EXPERIMENT {exp_id}: {exp_name}")
    print(f"Train Strands: {train_strands} -> Test Strands: {test_strands}")
    print(banner)

    # A single config serves both rounds, so size it for the largest N that
    # appears in either the training or the test set.
    config = Configuration(
        n_strands=max(train_strands + test_strands),
        max_len=100,
        total_timesteps=300_000,
        learning_rate=0.0003,
    )

    pretrained_name = f"Exp{exp_id}_{exp_name}_pretrained"
    tb_log_dir = os.path.join(config.LOG_DIR, f"Exp_{exp_id}_{exp_name}")
    agent = BraidAgent(
        config,
        {
            "learning_rate": config.LEARNING_RATE,
            "ent_coef": config.ENTROPY_COEF,
            "tensorboard_log": tb_log_dir,
        },
        name=pretrained_name,
    )

    # Round 1: pre-train, then evaluate.
    train(agent, train_strands, folder="train")
    print("\n>>> Testing Pre-Trained Model...")
    test(agent, test_strands)

    # Round 2: continue training the same agent under a new name. test() writes
    # its CSV under LOG_DIR/<agent.name>, so the two rounds' results land in
    # separate directories (the model path presumably also differs by name).
    agent.name = f"Exp{exp_id}_{exp_name}_finetuned"
    print(f"\n>>> Switching to Fine-Tuning Mode (New Model Name: {agent.name})")
    train(agent, train_strands, folder="finetune")
    print("\n>>> Testing Fine-Tuned Model...")
    test(agent, test_strands)

    print(f"\nExperiment {exp_id} Complete.")

In [9]:
# Run every experiment in the registry, in id order. Iterating the registry
# keys (rather than a hard-coded range) keeps this loop in sync with
# EXPERIMENTS, including id 11 (OOD_N9). To skip an experiment, filter the
# ids explicitly here.
for exp_id in sorted(EXPERIMENTS):
    run_experiment(exp_id)


RUNNING EXPERIMENT 1: Baseline_N3
Train Strands: [3] -> Test Strands: [3]

--- Starting Training Phase (train) ---
Strands: [3]
>>> Training on train_n3_c8_m10.txt (33333 steps)...
[Exp1_Baseline_N3_pretrained] Training for 33333 steps...
>>> Training on train_n3_c8_m50.txt (33333 steps)...
[Exp1_Baseline_N3_pretrained] Training for 33333 steps...
>>> Training on train_n3_c8_m100.txt (33333 steps)...
[Exp1_Baseline_N3_pretrained] Training for 33333 steps...
>>> Training on train_n3_c16_m10.txt (33333 steps)...
[Exp1_Baseline_N3_pretrained] Training for 33333 steps...
>>> Training on train_n3_c16_m50.txt (33333 steps)...
[Exp1_Baseline_N3_pretrained] Training for 33333 steps...
>>> Training on train_n3_c16_m100.txt (33333 steps)...
[Exp1_Baseline_N3_pretrained] Training for 33333 steps...
>>> Training on train_n3_c24_m10.txt (33333 steps)...
[Exp1_Baseline_N3_pretrained] Training for 33333 steps...
>>> Training on train_n3_c24_m50.txt (33333 steps)...
[Exp1_Baseline_N3_pretrained] Trai