<div dir=ltr align=center>
    <font color=0F5298 size=7>Neurosymbolic VQA Program Generator</font><br>
    <br>
    <font color=D2691E size=5>Part 1: Supervised Training (Seq2Seq)</font><br>
</div>

<br/>

---

## **Goal: Supervised Training**

Our first approach is **supervised learning**, the sequence-modeling counterpart of "behavioral cloning." We train a sequence-to-sequence (Seq2Seq) model to reproduce the ground-truth program annotated for each question, token by token.
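
The loss itself is defined in `scripts/train.py`, which is not reproduced in this notebook. Assuming the usual teacher-forced, token-level cross-entropy for Seq2Seq models, the objective for one question $x$ with ground-truth program $y = (y_1, \dots, y_T)$ would be:

$$
\mathcal{L}(\theta) = -\sum_{t=1}^{T} \log p_\theta\left(y_t \mid y_{<t},\, x\right)
$$

Here $\theta$ denotes the model parameters and $y_{<t}$ is the ground-truth prefix fed to the decoder at step $t$. Minimizing this quantity pushes the model toward reproducing each annotated program exactly.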

The model will take a tokenized question as input and be trained to output the *exact* tokenized program from the dataset.

**Input:** `['<START>', 'Is', 'there', 'a', 'large', 'sphere', '...', '<END>']` \
**Target:** `['<START>', 'scene', 'filter_shape[sphere]', 'filter_size[large]', 'exist', '<END>']`
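
Before either model sees these sequences, the tokens have to be mapped to integer indices. The following is a minimal sketch of that step, not the project's actual preprocessing: the helper names `build_vocab` and `encode`, the `<NULL>` and `<UNK>` tokens, and the padding lengths are all assumptions made for illustration, and the question is shortened because the original example elides part of it.

```python
# Hypothetical special tokens and helper names, for illustration only.
SPECIAL_TOKENS = ["<NULL>", "<START>", "<END>", "<UNK>"]


def build_vocab(sequences):
    """Map every token seen in `sequences` to an integer index."""
    vocab = {tok: idx for idx, tok in enumerate(SPECIAL_TOKENS)}
    for seq in sequences:
        for tok in seq:
            # New tokens get the next free index; known tokens are left alone.
            vocab.setdefault(tok, len(vocab))
    return vocab


def encode(tokens, vocab, max_len):
    """Wrap tokens in <START>/<END>, convert to indices, pad with <NULL>."""
    wrapped = ["<START>"] + list(tokens) + ["<END>"]
    if len(wrapped) > max_len:
        raise ValueError("sequence longer than max_len")
    ids = [vocab.get(tok, vocab["<UNK>"]) for tok in wrapped]
    return ids + [vocab["<NULL>"]] * (max_len - len(ids))


# Shortened version of the example question (the original elides some tokens).
question = ["Is", "there", "a", "large", "sphere"]
program = ["scene", "filter_shape[sphere]", "filter_size[large]", "exist"]

question_vocab = build_vocab([question])
program_vocab = build_vocab([program])

encoded_question = encode(question, question_vocab, max_len=10)
encoded_program = encode(program, program_vocab, max_len=8)
print(encoded_question)
print(encoded_program)
```

Questions and programs get separate vocabularies here because they draw on different token sets. A shared vocabulary would also work.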

We will train two different architectures for this task:
1.  **LSTM-based Seq2Seq** with Attention
2.  **Transformer-based Seq2Seq**

The training logic lives in `scripts/train.py`, which we invoke from this notebook as a shell command.
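
The training cells in this notebook pass six command-line flags to the script. As a rough guide to how such an interface could be declared, here is a hypothetical `argparse` sketch. The flag names are taken from those invocations, but the types, defaults, and `choices` are assumptions, and the real `scripts/train.py` may differ.

```python
import argparse


def build_parser():
    """Hypothetical CLI mirroring the flags this notebook passes to scripts/train.py."""
    parser = argparse.ArgumentParser(description="Train a program generator (sketch).")
    parser.add_argument("--model_type", choices=["lstm", "transformer"], required=True)
    parser.add_argument("--train_mode", default="supervised")
    parser.add_argument("--model_save_path", required=True)
    parser.add_argument("--num_iters", type=int, default=100000)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--batch_size", type=int, default=64)
    return parser


# Parse the same arguments used for the LSTM run in this notebook.
args = build_parser().parse_args([
    "--model_type", "lstm",
    "--train_mode", "supervised",
    "--model_save_path", "../models/supervised_lstm.pth",
    "--num_iters", "100000",
    "--learning_rate", "1e-4",
    "--batch_size", "64",
])
print(vars(args))
```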

In [None]:
import sys
import os
import re
import matplotlib.pyplot as plt

# Add the project root to the Python path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

# Import config for file paths
import src.config as config

## Strategy 1: **LSTM-based Seq2Seq Model**

This model uses a bidirectional LSTM as the encoder and a unidirectional LSTM with attention as the decoder. The full implementation is in `src/models/lstm_seq2seq.py`.
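
To make the attention step concrete, here is a minimal pure-Python sketch of how a decoder state can be turned into a context vector over encoder states. It assumes plain dot-product scoring, which may not be the variant used in `src/models/lstm_seq2seq.py`; the function names and the toy vectors are illustrative only.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(decoder_state, encoder_states):
    """Return (attention weights, context vector) using dot-product scoring."""
    # One score per encoder position: similarity between decoder and encoder state.
    scores = [sum(d * e for d, e in zip(decoder_state, enc)) for enc in encoder_states]
    weights = softmax(scores)
    dim = len(encoder_states[0])
    # Context vector: attention-weighted average of the encoder states.
    context = [sum(w * enc[i] for w, enc in zip(weights, encoder_states)) for i in range(dim)]
    return weights, context


# Three encoder positions with toy 2-dimensional states.
encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
decoder_state = [2.0, 0.0]

weights, context = attend(decoder_state, encoder_states)
print([round(w, 3) for w in weights])
print([round(c, 3) for c in context])
```

Dot-product scoring requires the decoder and encoder states to have the same dimensionality. With a bidirectional encoder, whose states concatenate both directions, a learned projection is typically applied first.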

In [None]:
print("--- Starting Supervised Training for LSTM ---")
!python ../scripts/train.py \
    --model_type lstm \
    --train_mode supervised \
    --model_save_path ../models/supervised_lstm.pth \
    --num_iters 100000 \
    --learning_rate 1e-4 \
    --batch_size 64

print("--- LSTM Training Command Finished (check the output above for errors) ---")

## Strategy 2: **Transformer-based Seq2Seq Model**

This model uses the standard Transformer architecture introduced in "Attention Is All You Need." The full implementation is in `src/models/transformer_seq2seq.py`.
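
Unlike an LSTM, a Transformer has no built-in notion of token order, so position information must be injected explicitly. The sketch below computes the sinusoidal positional encoding described in that paper, in pure Python. Whether `src/models/transformer_seq2seq.py` uses this fixed encoding or a learned one is an assumption, and the function name is illustrative.

```python
import math


def sinusoidal_positional_encoding(max_len, d_model):
    """Return a max_len x d_model table of sinusoidal position encodings."""
    if d_model % 2 != 0:
        raise ValueError("d_model must be even")
    table = []
    for pos in range(max_len):
        row = []
        # i walks over the even dimension indices 0, 2, 4, ...
        for i in range(0, d_model, 2):
            angle = pos / (10000 ** (i / d_model))
            row.append(math.sin(angle))  # even dimension i
            row.append(math.cos(angle))  # odd dimension i + 1
        table.append(row)
    return table


pe = sinusoidal_positional_encoding(max_len=4, d_model=4)
for pos, row in enumerate(pe):
    print(pos, [round(v, 4) for v in row])
```

Each row is added to the token embedding at the corresponding position before the first attention layer.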

In [None]:
print("--- Starting Supervised Training for Transformer ---")
!python ../scripts/train.py \
    --model_type transformer \
    --train_mode supervised \
    --model_save_path ../models/supervised_transformer.pth \
    --num_iters 100000 \
    --learning_rate 1e-4 \
    --batch_size 64

print("--- Transformer Training Command Finished (check the output above for errors) ---")

## Step 3: **Analyze Results**

The training script saves logs to `logs/program_generator.log`. We can write a simple parser to extract the validation accuracy at each checkpoint and plot the learning curves for both models.

*(Note: In a real-world scenario, you would use a dedicated tool like TensorBoard or Weights & Biases for logging, which would make plotting much easier.)*
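
The parser in the next cell relies on two line patterns: a `Starting training run: <name>` marker and `Validation Accuracy: <value>%` entries. As a quick illustration of that assumed format, the sketch below groups accuracies by run in a single pass. The log lines are synthetic, written only for this example, and are not real training output.

```python
import re

# Synthetic log lines, written only to illustrate the assumed format.
sample_log = """\
INFO Starting training run: lstm (supervised)
INFO Iter 1000 | Loss: 1.2345
INFO Validation Accuracy: 41.50%
INFO Validation Accuracy: 78.25%
INFO Starting training run: transformer (supervised)
INFO Validation Accuracy: 55.00%
"""

start_pattern = re.compile(r"Starting training run: (.+)")
# The decimal part is optional so that a value like "100%" also matches.
acc_pattern = re.compile(r"Validation Accuracy: (\d+(?:\.\d+)?)%")

curves = {}          # run name -> list of validation accuracies
current_run = None
for line in sample_log.splitlines():
    start = start_pattern.search(line)
    if start:
        # A new run header resets the curve for that run name.
        current_run = start.group(1).strip()
        curves[current_run] = []
        continue
    acc = acc_pattern.search(line)
    if acc and current_run is not None:
        curves[current_run].append(float(acc.group(1)))

print(curves)
```

Because a repeated run header resets its curve, this variant keeps only the most recent run of each name when the log file is appended to across sessions.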

In [None]:
def parse_validation_accuracy(log_file_path, run_name):
    """A simple parser to extract validation accuracies from the log file."""
    accuracies = []
    run_started = False
    # Regex to find "Starting training run: [run_name]"
    start_regex = re.compile(f"Starting training run: {re.escape(run_name)}")
    # Regex to find "Validation Accuracy: [accuracy]%"; the decimal part is optional so "100%" also matches
    acc_regex = re.compile(r"Validation Accuracy: (\d+(?:\.\d+)?)%")
    # Regex to find the start of any *other* run
    other_run_regex = re.compile(r"Starting training run:")

    try:
        with open(log_file_path, 'r') as f:
            for line in f:
                if not run_started:
                    if start_regex.search(line):
                        run_started = True
                else:
                    # If we find the start of another run, stop
                    if other_run_regex.search(line) and not start_regex.search(line):
                        break
                    
                    # Find accuracy lines
                    match = acc_regex.search(line)
                    if match:
                        accuracies.append(float(match.group(1)))
    except FileNotFoundError:
        print(f"Log file not found at: {log_file_path}")
        return []
    
    return accuracies

# --- Parse Logs ---
log_path = os.path.join(config.LOG_DIR, "program_generator.log")

lstm_acc = parse_validation_accuracy(log_path, "lstm (supervised)")
transformer_acc = parse_validation_accuracy(log_path, "transformer (supervised)")

print(f"Found {len(lstm_acc)} validation points for LSTM.")
print(f"Found {len(transformer_acc)} validation points for Transformer.")

# --- Plot Results ---
# Only create the figure when there is data, so an empty plot is never rendered.
if lstm_acc or transformer_acc:
    plt.figure(figsize=(12, 6))
    plt.title("Supervised Training: Validation Accuracy")
    plt.xlabel(f"Validation Step (x {config.VAL_INTERVAL} iterations)")
    plt.ylabel("Program Execution Accuracy (%)")

    if lstm_acc:
        plt.plot(lstm_acc, label="LSTM (Supervised)", marker='o')
    if transformer_acc:
        plt.plot(transformer_acc, label="Transformer (Supervised)", marker='x')

    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.show()
else:
    print("\nNo data to plot. Did the training scripts run correctly and log validation?")