In [None]:
import os
import pandas as pd
import torch.nn as nn
from src.config import Configuration
from src.braid_env import BraidEnv
from src.braid_agent import BraidAgent
from src.braid_generator import BraidGenerator
from src.callbacks import BraidCallback

In [211]:
config = Configuration()
gen = BraidGenerator(n_strands=7, config=config, seed=42)

crossings_list = [4, 6, 8]
moves_list = [1, 5, 10]
count = 1000

for cr in crossings_list:
    for mv in moves_list:
        filename = f"train_demo_n7_c{cr}_m{mv}"
        path = os.path.join(config.DATA_DIR, "train_demo", f"{filename}.txt")
        
        gen.generate_dataset(count, cr, mv, path, compute_optimal=False)

Generating dataset: 1000 braids, 4 crossings (Optimal=False)...
Done. Saved to ./data/train_demo/train_demo_n7_c4_m1.txt
Generating dataset: 1000 braids, 4 crossings (Optimal=False)...
Done. Saved to ./data/train_demo/train_demo_n7_c4_m5.txt
Generating dataset: 1000 braids, 4 crossings (Optimal=False)...
Done. Saved to ./data/train_demo/train_demo_n7_c4_m10.txt
Generating dataset: 1000 braids, 6 crossings (Optimal=False)...
Done. Saved to ./data/train_demo/train_demo_n7_c6_m1.txt
Generating dataset: 1000 braids, 6 crossings (Optimal=False)...
Done. Saved to ./data/train_demo/train_demo_n7_c6_m5.txt
Generating dataset: 1000 braids, 6 crossings (Optimal=False)...
Done. Saved to ./data/train_demo/train_demo_n7_c6_m10.txt
Generating dataset: 1000 braids, 8 crossings (Optimal=False)...
Done. Saved to ./data/train_demo/train_demo_n7_c8_m1.txt
Generating dataset: 1000 braids, 8 crossings (Optimal=False)...
Done. Saved to ./data/train_demo/train_demo_n7_c8_m5.txt
Generating dataset: 1000 braid

In [212]:
config = Configuration()
gen = BraidGenerator(n_strands=7, config=config, seed=42)

crossings_list = [4, 6, 8]
moves_list = [1, 5, 10]
count = 100

for cr in crossings_list:
    for mv in moves_list:
        filename = f"test_demo_n7_c{cr}_m{mv}"
        path = os.path.join(config.DATA_DIR, "test_demo", f"{filename}.txt")
        
        gen.generate_dataset(count, cr, mv, path, compute_optimal=False)

Generating dataset: 100 braids, 4 crossings (Optimal=False)...
Done. Saved to ./data/test_demo/test_demo_n7_c4_m1.txt
Generating dataset: 100 braids, 4 crossings (Optimal=False)...
Done. Saved to ./data/test_demo/test_demo_n7_c4_m5.txt
Generating dataset: 100 braids, 4 crossings (Optimal=False)...
Done. Saved to ./data/test_demo/test_demo_n7_c4_m10.txt
Generating dataset: 100 braids, 6 crossings (Optimal=False)...
Done. Saved to ./data/test_demo/test_demo_n7_c6_m1.txt
Generating dataset: 100 braids, 6 crossings (Optimal=False)...
Done. Saved to ./data/test_demo/test_demo_n7_c6_m5.txt
Generating dataset: 100 braids, 6 crossings (Optimal=False)...
Done. Saved to ./data/test_demo/test_demo_n7_c6_m10.txt
Generating dataset: 100 braids, 8 crossings (Optimal=False)...
Done. Saved to ./data/test_demo/test_demo_n7_c8_m1.txt
Generating dataset: 100 braids, 8 crossings (Optimal=False)...
Done. Saved to ./data/test_demo/test_demo_n7_c8_m5.txt
Generating dataset: 100 braids, 8 crossings (Optimal=F

In [213]:
def get_filenames(prefix, strands):
    if isinstance(strands, int):
        strands = [strands]
        
    files = []
    n_crossings = [4, 6, 8]
    n_moves = [1, 5, 10]
    
    for n in strands:
        for c in n_crossings:
            for m in n_moves:
                files.append(f"{prefix}_n{n}_c{c}_m{m}.txt")
    return files

In [214]:
def evaluate(agent, files, config, phase_name):
    print(f"\n>>> Evaluation Phase: {phase_name}")
    total_success = 0
    total_episodes = 0
    
    for filename in files:
        full_path = os.path.join(config.DATA_DIR, "test_demo", filename)
        if not os.path.exists(full_path):
             continue

        try:
            env = BraidEnv(full_path, config.N_STRANDS, config.MAX_LEN, config)
            
            file_success = 0
            episodes = 20 
            
            for _ in range(episodes):
                solved, _ = agent.solve(env, max_steps=200)
                if solved: file_success += 1
            
            total_success += file_success
            total_episodes += episodes
            
        except Exception as e: 
            print(f"  Error testing {filename}: {e}")

    if total_episodes > 0:
        score = (total_success / total_episodes) * 100
        print(f"  Overall Success Rate: {score:.1f}% ({total_success}/{total_episodes})")
    else:
        print("  No episodes run.")

In [229]:
TOTAL_BUDGET = 300000
MAX_STRANDS = 10

config = Configuration(n_strands=MAX_STRANDS, max_len=100, total_timesteps=TOTAL_BUDGET)

hyperparams = {
    "learning_rate": config.LEARNING_RATE,
    "ent_coef": config.ENTROPY_COEF,
    "tensorboard_log": os.path.join(config.LOG_DIR, "Exp_7_SMALL"),
}

base_name = "Exp7_SMALL"
agent = BraidAgent(config, hyperparams, name=base_name)

pretrained_path = config.get_model_path(base_name + "_pretrained")

train_files = get_filenames("train_demo", [7])
test_files = get_filenames("test_demo", [7])

if train_files:
    steps_per_file = max(int(TOTAL_BUDGET / len(train_files)), 2048)
    print(f"Phase 1: Training on {len(train_files)} files ({steps_per_file} steps each)...")

    for filename in train_files:
        full_path = os.path.join(config.DATA_DIR, "train_demo", filename)
        
        if not os.path.exists(full_path):
            print(f"  [SKIP] Missing: {filename}")
            continue
            
        try:
            env = BraidEnv(full_path, config.N_STRANDS, config.MAX_LEN, config)
            callback = BraidCallback()
            
            log_label = "train_phase" 
            
            agent.train(env, steps_per_file, pretrained_path, callback, log_label)
            
        except Exception as e:
            print(f"  [FAIL] {filename}: {e}")
else:
    print("CRITICAL: No training files found check your get_filenames parameters.")

evaluate(agent, test_files, config, "PRETRAINED")

Phase 1: Training on 9 files (33333 steps each)...
[Exp7_SMALL] Training for 33333 steps...
[Exp7_SMALL_pretrained] Training for 33333 steps...
[Exp7_SMALL_pretrained] Training for 33333 steps...
[Exp7_SMALL_pretrained] Training for 33333 steps...
[Exp7_SMALL_pretrained] Training for 33333 steps...
[Exp7_SMALL_pretrained] Training for 33333 steps...
[Exp7_SMALL_pretrained] Training for 33333 steps...
[Exp7_SMALL_pretrained] Training for 33333 steps...
[Exp7_SMALL_pretrained] Training for 33333 steps...

>>> Evaluation Phase: PRETRAINED
  Overall Success Rate: 93.3% (168/180)


In [230]:
test_file = "test_demo_n7_c8_m10.txt"
full_path = os.path.join(config.DATA_DIR, "test_demo", test_file)

env = BraidEnv(full_path, config.N_STRANDS, config.MAX_LEN, config)

obs, _ = env.reset() 
start_word = env.current_braid.word 

print(f"\n--- Solving Braid from {test_file} ---")
print(f"Start: {start_word}")

move_names = ["Commute", "R3", "Remove", "Insert"]

for i in range(50):
    action = agent.predict(env._get_obs(), env)
    
    move_type = action // config.MAX_LEN
    index = action % config.MAX_LEN
    
    next_obs, reward, done, truncated, info = env.step(action)
    
    current_word = [x for x in next_obs[:-1] if x != 0]

    status = "Success" if info['success'] else "Failed"
    print(f"Step {i+1}: {move_names[move_type]} @ {index} -> {status}")
    print(f"       Braid: {current_word}")

    if done or len(current_word) == 0:
        print("\nSOLVED!")
        break
        
    if truncated:
        print("\nFAILED (Max length reached)")
        break


--- Solving Braid from test_demo_n7_c8_m10.txt ---
Start: [6, -6, 2, 6, -6, 3, -3, -2]
Step 1: Remove @ 5 -> Success
       Braid: [np.int32(6), np.int32(-6), np.int32(2), np.int32(6), np.int32(-6), np.int32(-2)]
Step 2: Remove @ 0 -> Success
       Braid: [np.int32(2), np.int32(6), np.int32(-6), np.int32(-2)]
Step 3: Remove @ 1 -> Success
       Braid: [np.int32(2), np.int32(-2)]
Step 4: Remove @ 0 -> Success
       Braid: []

SOLVED!
