<div dir=ltr align=center>
    <font color=0F5298 size=7>Neurosymbolic VQA Program Generator</font><br>
    <br/>
    <font color=4169E1 size=5>Part 2: REINFORCE Fine-Tuning</font><br>
</div>

<br/>

---

## **Goal: REINFORCE Fine-Tuning**

Supervised learning (Part 1) has a major weakness: **exposure bias**.

1.  During training, the model *always* conditions on the correct ground-truth program prefix (teacher forcing).
2.  During inference, it must generate the program from its *own* previous predictions. A single mistake produces a prefix the model never saw during training, and the error can cascade into a completely wrong program.
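To make the cascade concrete, here is a toy sketch. The "model" is a hypothetical lookup table (`TOY_MODEL`) that maps the previous token to a predicted next token, and the token names are illustrative CLEVR-style operations, not the project's actual vocabulary. The table is wrong in exactly one place. Under teacher forcing that costs a single token, but when the model feeds on its own predictions, the wrong token also derails the step after it.

```python
# Toy illustration of exposure bias. TOY_MODEL is a hypothetical lookup table
# mapping the previous token to a predicted next token. It is correct
# everywhere except after "filter_color[red]".
GROUND_TRUTH = ["scene", "filter_color[red]", "filter_shape[cube]", "count"]

TOY_MODEL = {
    "scene": "filter_color[red]",
    "filter_color[red]": "filter_shape[sphere]",  # the single mistake
    "filter_shape[cube]": "count",
    "filter_shape[sphere]": "exist",  # a plausible continuation of the wrong prefix
}


def teacher_forced_predictions(program):
    """Predict each token from the *ground-truth* previous token (training)."""
    return [TOY_MODEL[prev] for prev in program[:-1]]


def free_running_predictions(start, length):
    """Predict each token from the model's *own* previous token (inference)."""
    tokens, prev = [], start
    for _ in range(length):
        prev = TOY_MODEL[prev]
        tokens.append(prev)
    return tokens


targets = GROUND_TRUTH[1:]
tf = teacher_forced_predictions(GROUND_TRUTH)
fr = free_running_predictions(GROUND_TRUTH[0], len(targets))

tf_errors = sum(p != t for p, t in zip(tf, targets))
fr_errors = sum(p != t for p, t in zip(fr, targets))
print("teacher forcing:", tf, "errors:", tf_errors)
print("free running:   ", fr, "errors:", fr_errors)
```

Teacher forcing hides the damage because the ground-truth token `filter_shape[cube]` is fed back in after the mistake. At inference time nothing corrects the prefix, so the final operation is wrong as well.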

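**Solution:** Use Reinforcement Learning (RL) to fine-tune the model. Instead of penalizing the model for *syntax* (i.e., not matching the ground-truth program token by token), we reward it for *semantics* (i.e., generating a program that produces the **correct final answer**). Because the training programs are now sampled from the model itself, it learns from the same kind of prefixes it will produce at inference time.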
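The difference between the two training signals can be sketched with a hypothetical three-object scene and a minimal executor. `execute`, `SCENE`, and the `(op, arg)` program format are illustrative stand-ins, not the project's `ClevrExecutor`. Filtering by shape before color reaches the same answer as the ground-truth program, so a token-matching loss calls the sampled program wrong while an answer-based reward accepts it.

```python
# Hypothetical mini-scene and mini-executor; NOT the project's ClevrExecutor.
SCENE = [
    {"color": "red", "shape": "cube"},
    {"color": "red", "shape": "sphere"},
    {"color": "blue", "shape": "cube"},
]


def execute(program, scene):
    """Run a list of (op, arg) steps over the scene and return the answer."""
    objects = list(scene)
    for op, arg in program:
        if op == "filter_color":
            objects = [o for o in objects if o["color"] == arg]
        elif op == "filter_shape":
            objects = [o for o in objects if o["shape"] == arg]
        elif op == "count":
            return len(objects)
        else:
            raise ValueError(f"unknown op: {op}")
    raise ValueError("program did not end with an answer op")


ground_truth = [("filter_color", "red"), ("filter_shape", "cube"), ("count", None)]
sampled = [("filter_shape", "cube"), ("filter_color", "red"), ("count", None)]
ground_truth_answer = execute(ground_truth, SCENE)

# Supervised (token-matching) view: the sampled program differs at 2 of 3 steps.
token_mismatches = sum(a != b for a, b in zip(sampled, ground_truth))
# Semantic view: the sampled program returns the ground-truth answer, so R = 1.0.
reward = 1.0 if execute(sampled, SCENE) == ground_truth_answer else 0.0
print(token_mismatches, reward)
```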

## **The REINFORCE Algorithm**

We use a simple policy gradient algorithm called **REINFORCE**.

1.  **Policy ($\pi_\theta$)**: Our Seq2Seq model. It defines a probability distribution over programs given a question.
2.  **Action ($a$)**: The program $a = (a_1, \dots, a_T)$ generated by *sampling* from the model's output distribution, one token $a_t$ at a time.
3.  **Reward ($R$)**: We run the sampled program through the symbolic `ClevrExecutor`. 
    - **`R = 1.0`** if `executor_answer == ground_truth_answer`
    - **`R = 0.0`** otherwise
4.  **Baseline ($b$)**: To reduce the variance of the gradient estimate, we subtract a baseline, a moving average of past rewards.
5.  **Advantage ($A$)**: `A = R - b`. 
    - If `A > 0` (the program earned more reward than the baseline), we *increase* the probability of generating it.
    - If `A < 0` (the program earned less reward than the baseline), we *decrease* its probability.
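The reward, baseline, and advantage steps above might look like the following minimal sketch. It assumes an exponential moving average with a hypothetical `decay=0.9`, updates the baseline one sample at a time, and uses `None` as a hypothetical marker for a program that failed to execute; the actual `train_reinforce.py` may average differently or update once per batch.

```python
class MovingAverageBaseline:
    """Exponential moving average of past rewards (decay is an assumed value)."""

    def __init__(self, decay=0.9):
        self.decay = decay
        self.value = 0.0  # b starts at 0 before any reward is seen

    def advantage_and_update(self, reward):
        # A = R - b, using the baseline *before* it absorbs the current reward.
        advantage = reward - self.value
        self.value = self.decay * self.value + (1.0 - self.decay) * reward
        return advantage


def compute_reward(executor_answer, ground_truth_answer):
    """R = 1.0 if the executed program's answer matches the ground truth, else 0.0."""
    return 1.0 if executor_answer == ground_truth_answer else 0.0


baseline = MovingAverageBaseline(decay=0.9)
advantages = []
# (executor_answer, ground_truth_answer) pairs; None stands for a failed execution.
for answer, gt in [("yes", "yes"), (None, "3"), ("cube", "cube")]:
    reward = compute_reward(answer, gt)
    advantages.append(baseline.advantage_and_update(reward))
    print(f"R={reward:.1f}  A={advantages[-1]:+.3f}  b={baseline.value:.3f}")
```

Computing the advantage before updating the baseline keeps each sample from being compared against a baseline that already includes its own reward.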

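The loss function for a sampled program of $T$ tokens is:

$$L = - \sum_{t=1}^{T} \log \pi_\theta(a_t \mid a_{<t}, q) \cdot A$$

where $q$ is the question and the advantage $A$ is treated as a constant: gradients flow only through the log-probabilities.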

This logic is implemented in `src/training/train_reinforce.py`.
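A pure-Python sketch of the loss, and of why the sign of $A$ matters. `reinforce_loss` and `reinforce_step` are hypothetical helpers, not functions from `train_reinforce.py`. `reinforce_step` takes one gradient-descent step on $-A \log \pi_\theta(a)$ for a single softmax over three tokens: with $A > 0$ the sampled token's probability rises, and with $A < 0$ it falls. The per-token log-probabilities passed to `reinforce_loss` are illustrative values, not model outputs.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce_loss(token_log_probs, advantage):
    """L = -sum_t log pi(a_t | a_<t, q) * A, with A treated as a constant."""
    return -sum(token_log_probs) * advantage


def reinforce_step(logits, action, advantage, lr=0.5):
    """One gradient-descent step on L = -A * log softmax(logits)[action].

    dL/dz_k = -A * (1[k == action] - p_k), so descending the gradient moves
    the sampled action's logit up when A > 0 and down when A < 0.
    """
    probs = softmax(logits)
    grads = [-advantage * ((1.0 if k == action else 0.0) - p) for k, p in enumerate(probs)]
    return [z - lr * g for z, g in zip(logits, grads)]


logits = [0.0, 0.0, 0.0]  # uniform policy over 3 hypothetical tokens
p_before = softmax(logits)[1]
p_up = softmax(reinforce_step(logits, action=1, advantage=+0.5))[1]
p_down = softmax(reinforce_step(logits, action=1, advantage=-0.5))[1]
print(f"{p_before:.3f} -> A>0: {p_up:.3f}, A<0: {p_down:.3f}")

# Sequence-level loss for a sampled 3-token program (illustrative log-probs).
loss = reinforce_loss([math.log(0.9), math.log(0.8), math.log(0.5)], advantage=0.4)
print(loss)
```

In PyTorch the same loss is typically written as `-(log_probs.sum() * advantage)`, with the advantage computed from the reward and baseline outside the autograd graph (or detached) so that gradients flow only through the log-probabilities.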

In [None]:
import sys
import os

# Add the project root to the Python path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

# Import config for file paths
import src.config as config

## Strategy 1: **Fine-Tune LSTM with REINFORCE**

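We load the weights from our best *supervised* LSTM model and use them as the starting point for fine-tuning. This is **critical**: a randomly initialized model rarely samples a program that even executes, let alone returns the correct answer, so nearly every reward would be 0 and training from scratch with RL would be extremely difficult.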

In [None]:
print("--- Starting REINFORCE Fine-Tuning for LSTM ---")
!python ../scripts/train.py \
    --model_type lstm \
    --train_mode reinforce \
    --load_model ../models/supervised_lstm.pth \
    --model_save_path ../models/reinforce_lstm.pth \
    --num_iters 50000 \
    --learning_rate 1e-5 \
    --batch_size 64

print("--- LSTM REINFORCE Training Complete ---")

## Strategy 2: **Fine-Tune Transformer with REINFORCE**

Similarly, we load the supervised Transformer as our starting policy.

In [None]:
print("--- Starting REINFORCE Fine-Tuning for Transformer ---")
!python ../scripts/train.py \
    --model_type transformer \
    --train_mode reinforce \
    --load_model ../models/supervised_transformer.pth \
    --model_save_path ../models/reinforce_transformer.pth \
    --num_iters 50000 \
    --learning_rate 1e-5 \
    --batch_size 64

print("--- Transformer REINFORCE Training Complete ---")

## Step 3: **Analyze Results**

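Let's parse the log file again, this time comparing the supervised and REINFORCE-tuned models. We hope to see REINFORCE improve the final validation accuracy. Keep in mind that each REINFORCE run starts from the saved supervised checkpoint, so step 0 of a REINFORCE curve corresponds to the end of supervised training, not its beginning.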

In [None]:
import re
import matplotlib.pyplot as plt
import os
import src.config as config

# Note: This cell assumes the first code cell in this notebook was run to add the project root to sys.path

def parse_validation_accuracy(log_file_path, run_name):
    """Extract validation accuracies for `run_name` from the log file.

    If the same run name appears more than once (e.g., the notebook was re-run
    and the log was appended to), only the most recent run is returned.
    """
    accuracies = []
    in_run = False
    start_regex = re.compile(f"Starting training run: {re.escape(run_name)}")
    acc_regex = re.compile(r"Validation Accuracy: (\d+(?:\.\d+)?)%")
    any_run_regex = re.compile(r"Starting training run:")

    try:
        with open(log_file_path, 'r') as f:
            for line in f:
                if any_run_regex.search(line):
                    # A new run begins: start fresh if it is ours, otherwise stop collecting.
                    in_run = bool(start_regex.search(line))
                    if in_run:
                        accuracies = []
                    continue
                if in_run:
                    match = acc_regex.search(line)
                    if match:
                        accuracies.append(float(match.group(1)))
    except FileNotFoundError:
        print(f"Log file not found at: {log_file_path}")
        return []

    return accuracies

# --- Parse Logs ---
log_path = os.path.join(config.LOG_DIR, "program_generator.log")

lstm_sup = parse_validation_accuracy(log_path, "lstm (supervised)")
lstm_rl = parse_validation_accuracy(log_path, "lstm (reinforce)")
trans_sup = parse_validation_accuracy(log_path, "transformer (supervised)")
trans_rl = parse_validation_accuracy(log_path, "transformer (reinforce)")

print(f"Found {len(lstm_sup)} supervised LSTM points.")
print(f"Found {len(lstm_rl)} REINFORCE LSTM points.")
print(f"Found {len(trans_sup)} supervised Transformer points.")
print(f"Found {len(trans_rl)} REINFORCE Transformer points.")

# --- Plot Results ---
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6), sharey=True)

fig.suptitle('Supervised vs. REINFORCE Fine-Tuning Validation Accuracy')

# LSTM Plot
ax1.set_title("LSTM Model")
ax1.set_xlabel(f"Validation Step (x {config.VAL_INTERVAL} iterations)")
ax1.set_ylabel("Program Execution Accuracy (%)")
if lstm_sup:
    ax1.plot(lstm_sup, label="Supervised", marker='o', linestyle='--')
if lstm_rl:
    ax1.plot(lstm_rl, label="REINFORCE", marker='o')
ax1.legend()
ax1.grid(True, linestyle='--', alpha=0.6)

# Transformer Plot
ax2.set_title("Transformer Model")
ax2.set_xlabel(f"Validation Step (x {config.VAL_INTERVAL} iterations)")
if trans_sup:
    ax2.plot(trans_sup, label="Supervised", marker='x', linestyle='--')
if trans_rl:
    ax2.plot(trans_rl, label="REINFORCE", marker='x')
ax2.legend()
ax2.grid(True, linestyle='--', alpha=0.6)

plt.show()
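To put a number on the comparison, a small hypothetical helper can report the best validation accuracy of each run and the change from supervised to REINFORCE. It is a sketch that only assumes the lists of percentages returned by `parse_validation_accuracy`.

```python
def summarize(name, supervised, reinforce):
    """Compare the best validation accuracy (%) of a supervised and a REINFORCE run.

    Returns the change in percentage points, or None if either run has no data.
    """
    if not supervised or not reinforce:
        print(f"{name}: missing data (supervised={len(supervised)}, "
              f"reinforce={len(reinforce)} points)")
        return None
    best_sup, best_rl = max(supervised), max(reinforce)
    delta = best_rl - best_sup
    print(f"{name}: supervised best {best_sup:.2f}%, "
          f"REINFORCE best {best_rl:.2f}%, change {delta:+.2f} pts")
    return delta
```

In the notebook, call it with the parsed lists, e.g. `summarize("LSTM", lstm_sup, lstm_rl)` and `summarize("Transformer", trans_sup, trans_rl)`. Using the best point rather than the last one guards against a single noisy validation step, at the cost of some optimism if the saved checkpoint is not the best one.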