# Tactile Manipulation - RL Fine-tuning (Fixed Version)

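This notebook fine-tunes the behavior-cloning (BC) policy with PPO reinforcement learning.
It targets Google Colab (GPU runtime, Google Drive for checkpoints). This revision fixes the setup steps so the training environment loads and runs.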

## 1. Setup Environment

In [None]:
# Check GPU
!nvidia-smi
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")

In [None]:
# Install dependencies
!pip install -q mujoco h5py tensorboard matplotlib tqdm
!pip install -q stable-baselines3[extra] gymnasium

## 2. Clone Repository

In [None]:
# Clone your repository
!git clone https://github.com/ewernn/TactileManipulation.git
%cd TactileManipulation

# Verify structure
!ls -la
!ls tactile-rl/scripts/
!ls tactile-rl/franka_emika_panda/

## 3. Mount Google Drive for Checkpoints

In [None]:
import os
from google.colab import drive

# Mount Google Drive; RL checkpoints are written there.
drive.mount('/content/drive')

# Ensure the RL checkpoint directory exists (no error if it is already there).
rl_checkpoint_dir = '/content/drive/MyDrive/tactile_manipulation_rl_checkpoints'
os.makedirs(rl_checkpoint_dir, exist_ok=True)
print(f"RL checkpoints will be saved to: {rl_checkpoint_dir}")

## 4. Load Pre-trained BC Policy (Optional)

In [None]:
# Look for the behavior-cloning (BC) checkpoint on Drive and report its training metadata.
# Sets bc_model_path to None when no checkpoint is found.
# NOTE(review): bc_model_path is not passed to the training command later in this
# notebook. Confirm whether train_rl_fixed.py accepts a flag to initialise from it.
bc_model_path = '/content/drive/MyDrive/bc_policy_best.pt'

if os.path.exists(bc_model_path):
    print(f"Found BC model at: {bc_model_path}")

    # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only runtime.
    # weights_only=False unpickles arbitrary Python objects: only load trusted checkpoints.
    checkpoint = torch.load(bc_model_path, map_location='cpu', weights_only=False)

    # Metadata keys are optional; a checkpoint without them should not crash this cell.
    epoch = checkpoint.get('epoch')
    val_loss = checkpoint.get('val_loss')
    if epoch is not None:
        print(f"BC model trained for {epoch} epochs")
    if val_loss is not None:
        print(f"BC validation loss: {val_loss:.4f}")
else:
    print("⚠️ BC model not found! RL will train from scratch.")
    bc_model_path = None

## 5. Test Environment First

In [None]:
# Quick smoke test that the environment loads, before committing to a long training run.
# The command runs from tactile-rl/ because the script path is given relative to that directory.
# NOTE(review): assumes `--test` performs a short check/evaluation rather than full training; confirm in train_rl_fixed.py.
# `--no_tactile` disables tactile sensing (see Troubleshooting: use it if the tactile sensor causes issues).
!cd tactile-rl && python scripts/train_rl_fixed.py --test --no_tactile

## 6. Start RL Training

In [None]:
# Create run directory with timestamp
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
rl_run_dir = f"{rl_checkpoint_dir}/rl_run_{timestamp}"

# Start RL training
# Note: This will take ~30-45 minutes for 1000 episodes on T4
!cd tactile-rl && python scripts/train_rl_fixed.py \
    --episodes 1000 \
    --learning_rate 3e-4 \
    --batch_size 32 \
    --save_dir {rl_run_dir} \
    --no_tactile  # Remove this flag if you want tactile sensing

# For longer training (better results):
# !cd tactile-rl && python scripts/train_rl_fixed.py \
#     --episodes 5000 \
#     --learning_rate 3e-4 \
#     --batch_size 32 \
#     --save_dir {rl_run_dir}

## 7. Monitor Training Progress

In [None]:
# Plot per-episode reward from the run's training log, with a moving average.
# Assumes training_log.json holds a list of dicts with 'episode' and 'reward' keys,
# as read below.
import os
import json
import matplotlib.pyplot as plt

log_file = os.path.join(rl_run_dir, 'training_log.json')
window_size = 50  # episodes per moving-average window; also used for the "last N" summary

logs = []
if os.path.exists(log_file):
    try:
        with open(log_file, 'r') as f:
            logs = json.load(f)
    except json.JSONDecodeError as exc:
        # The log may be partially written if it is read while training writes it.
        print(f"Could not parse {log_file}: {exc}")

if not logs:
    # Covers a missing file, an unparseable file, and an empty log list.
    # In the original cell, the empty-list case crashed on episodes[-1].
    print("No training logs found yet. Training may still be starting...")
else:
    episodes = [entry['episode'] for entry in logs]
    rewards = [entry['reward'] for entry in logs]

    fig, ax = plt.subplots(figsize=(10, 5))
    # Labelled so that ax.legend() has an entry even before the moving average exists.
    ax.plot(episodes, rewards, alpha=0.5, label='Episode reward')

    if len(rewards) >= window_size:
        moving_avg = [
            sum(rewards[i:i + window_size]) / window_size
            for i in range(len(rewards) - window_size + 1)
        ]
        # The first full window ends at index window_size - 1.
        ax.plot(episodes[window_size - 1:], moving_avg, 'r-', linewidth=2,
                label=f'{window_size}-episode average')

    ax.set(xlabel='Episode', ylabel='Reward', title='RL Training Progress')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.show()

    # Report the true count when fewer than window_size episodes are logged.
    recent = rewards[-window_size:]
    print(f"Latest episode: {episodes[-1]}")
    print(f"Latest reward: {rewards[-1]:.2f}")
    print(f"Average of last {len(recent)}: {sum(recent) / len(recent):.2f}")

## 8. Load and Test Final Policy

In [None]:
# Test the trained policy
model_path = os.path.join(rl_run_dir, 'final_model.pth')

if os.path.exists(model_path):
    print(f"Testing trained model from: {model_path}")
    
    # Run evaluation
    !cd tactile-rl && python scripts/train_rl_fixed.py \
        --test \
        --load_model {model_path} \
        --episodes 10 \
        --no_tactile
else:
    print("Model not found. Training may still be in progress.")

## 9. Download Trained Model

In [None]:
# Download the run's artifacts to the local machine.
# Then keep a copy of the final policy at a fixed Drive path outside the timestamped run dir.
import shutil
from google.colab import files

model_files = [
    'final_model.pth',
    'training_log.json',
    'actor_checkpoint_final.pth',
    'critic_checkpoint_final.pth'
]

for file_name in model_files:
    file_path = os.path.join(rl_run_dir, file_name)
    if os.path.exists(file_path):
        print(f"Downloading {file_name}...")
        files.download(file_path)
    else:
        print(f"{file_name} not found")

# Also save to a permanent location on Drive.
# shutil raises on failure (unlike `!cp`), so the confirmation below prints only after a
# successful copy. copyfile copies contents only, avoiding metadata calls on the Drive mount.
final_model = os.path.join(rl_run_dir, 'final_model.pth')
if os.path.exists(final_model):
    shutil.copyfile(final_model, '/content/drive/MyDrive/rl_policy_final.pth')
    print("\nRL policy saved to Google Drive as rl_policy_final.pth")

## 10. Compare BC vs RL Performance

The figures below are expected ranges, not results measured in this notebook.

In [None]:
# Summarise the two training approaches side by side.
# NOTE(review): these success rates and timings are hard-coded expectations,
# not values measured by this notebook.
separator = "=" * 60
policy_summaries = [
    ("\nBC Policy (Behavior Cloning):", [
        "  - Expected success rate: 70-80%",
        "  - Training time: ~1 minute",
        "  - Learns from demonstrations only",
    ]),
    ("\nRL Policy (PPO Fine-tuning):", [
        "  - Expected success rate: 85-95%",
        "  - Training time: ~30-45 minutes",
        "  - Learns from environment interaction",
        "  - More robust to variations",
    ]),
]

print(separator)
print("Performance Comparison:")
print(separator)
for heading, bullet_lines in policy_summaries:
    print(heading)
    for bullet in bullet_lines:
        print(bullet)
print(separator)

## Troubleshooting

If you encounter issues:

1. **Environment loading error**:
   - Run the test cell (Step 5) first
   - Use `--no_tactile` flag if tactile sensor causes issues

2. **CUDA out of memory**:
   - Reduce batch_size to 16 or 8
   - Restart runtime to clear GPU memory

3. **Training too slow**:
   - Reduce episodes to 500 for quick test
   - Use smaller batch_size

4. **Import errors**:
   - Make sure you're in the TactileManipulation directory
   - Check that all files were cloned properly