# 🚀 ARAI Saliency Model Training on Google Colab

This notebook trains a U-Net saliency prediction model for the ARAI system.

**⏱️ Estimated Time:**
- Synthetic data: 30 minutes
- Full SALICON: 2-4 hours

**📋 Before Starting:**
1. Click **Runtime → Change runtime type**
2. Select **GPU** as Hardware accelerator
3. Click **Save**
4. Run cells in order

---

## 1️⃣ Setup: Check GPU and Mount Drive

In [None]:
# Check GPU availability
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"✅ GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️ WARNING: No GPU detected!")
    print("   Go to: Runtime → Change runtime type → GPU")

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Create directories for saving models
import os
os.makedirs('/content/drive/MyDrive/ARAI/models', exist_ok=True)
os.makedirs('/content/drive/MyDrive/ARAI/training_logs', exist_ok=True)

print("\n✅ Google Drive mounted successfully!")
print("📁 Models will be saved to: /content/drive/MyDrive/ARAI/models")

## 2️⃣ Clone Repository and Install Dependencies

In [None]:
# Clone ARAI repository
!git clone https://github.com/kavishaniy/ARAI-System.git /content/arai
%cd /content/arai/backend

print("✅ Repository cloned successfully!")

In [None]:
# Install required packages (most are pre-installed in Colab)
!pip install -q pillow scipy tqdm matplotlib scikit-image

print("✅ Dependencies installed!")

## 3️⃣ Choose Training Option

### Option A: Quick Test with Synthetic Data (30 minutes) ⚡
Run this cell for a quick test:

In [None]:
# Create synthetic dataset for quick testing
import numpy as np
import cv2
from PIL import Image
from tqdm import tqdm
import os

# Create directories
os.makedirs('/content/synthetic_data/images', exist_ok=True)
os.makedirs('/content/synthetic_data/maps', exist_ok=True)

print("Creating 200 synthetic training samples...")
for i in tqdm(range(200)):
    # Create UI-like image
    img = np.random.randint(200, 255, (256, 256, 3), dtype=np.uint8)
    
    # Add UI elements (header, buttons, text areas)
    cv2.rectangle(img, (20, 20), (236, 60), (66, 133, 244), -1)  # Header
    cv2.putText(img, 'Title', (80, 45), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255,255,255), 2)
    cv2.rectangle(img, (20, 80), (236, 180), (52, 168, 83), 3)  # CTA Button
    cv2.putText(img, 'Click Me', (70, 135), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0,0,0), 2)
    cv2.rectangle(img, (20, 190), (236, 240), (200, 200, 200), 1)  # Text box
    
    # Generate saliency map
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, 50, 150)
    
    # Add center bias (humans look at center more)
    y, x = np.ogrid[:256, :256]
    center_bias = np.exp(-((x - 128)**2 + (y - 128)**2) / (2 * 80**2))
    
    # Combine edge detection and center bias
    saliency = (edges / 255.0 * 0.6 + center_bias * 0.4) * 255
    saliency = saliency.astype(np.uint8)
    
    # Save
    Image.fromarray(img).save(f'/content/synthetic_data/images/img_{i:04d}.jpg')
    Image.fromarray(saliency).save(f'/content/synthetic_data/maps/map_{i:04d}.png')

print("\n✅ Created 200 synthetic training samples!")
print("📁 Location: /content/synthetic_data/")

### Option B: Download SALICON Dataset (for production) 🎯

**Note:** This downloads ~2.5GB of data and takes 30-60 minutes.

Skip this if you're using synthetic data above.

In [None]:
# Download SALICON dataset
!pip install -q gdown
import gdown

os.makedirs('/content/salicon/images', exist_ok=True)
os.makedirs('/content/salicon/maps', exist_ok=True)

print("Downloading SALICON dataset...")
print("This may take 30-60 minutes for ~2.5GB of data\n")

# SALICON Training Images (2GB)
print("1/2 Downloading training images...")
gdown.download(
    'https://drive.google.com/uc?id=1g8j-hTT2exMGdiN84XDHPJJgDXq8YCDH',
    '/content/salicon_images.zip',
    quiet=False
)

# SALICON Training Maps (500MB)
print("\n2/2 Downloading saliency maps...")
gdown.download(
    'https://drive.google.com/uc?id=1jHbhwlMXXFvvLM0dAb0qC2vvXBkF8ZdL',
    '/content/salicon_maps.zip',
    quiet=False
)

# Extract files
print("\nExtracting files...")
!unzip -q /content/salicon_images.zip -d /content/salicon/images/
!unzip -q /content/salicon_maps.zip -d /content/salicon/maps/

print("\n✅ SALICON dataset ready!")
print(f"   Images: {len(os.listdir('/content/salicon/images'))}")
print(f"   Maps: {len(os.listdir('/content/salicon/maps'))}")

## 4️⃣ Train the Model 🎓

Choose one of the training cells below based on your dataset choice.
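
Before starting a long run, it can help to confirm that every image has a matching saliency map. The sketch below is one way to do that. It assumes files are paired by a shared ID after an optional prefix (`img_0000.jpg` with `map_0000.png`, as in the synthetic data); how `train_saliency.py` actually pairs files is not shown in this notebook, and `unpaired_files` is a hypothetical helper, not part of the ARAI repository.

```python
from pathlib import Path


def unpaired_files(image_dir, map_dir, image_prefix="img_", map_prefix="map_"):
    """Return (image IDs without a map, map IDs without an image).

    Assumption: a file's ID is its stem with the given prefix removed,
    e.g. img_0007.jpg and map_0007.png both have the ID "0007".
    """
    def ids(directory, prefix):
        found = set()
        for path in Path(directory).iterdir():
            if path.is_file():
                stem = path.stem
                found.add(stem[len(prefix):] if stem.startswith(prefix) else stem)
        return found

    image_ids = ids(image_dir, image_prefix)
    map_ids = ids(map_dir, map_prefix)
    return sorted(image_ids - map_ids), sorted(map_ids - image_ids)


# Example call for the synthetic dataset (uncomment in Colab):
# missing_maps, missing_images = unpaired_files(
#     "/content/synthetic_data/images", "/content/synthetic_data/maps")
# print(f"Images without maps: {len(missing_maps)}, maps without images: {len(missing_images)}")
```

Two empty lists mean the directories line up under this naming assumption. For SALICON, the prefixes may need adjusting to match the extracted file names.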

In [None]:
# Train with SYNTHETIC DATA (quick test - 30 minutes)
%cd /content/arai/backend/training

!python train_saliency.py \
    --image_dir /content/synthetic_data/images \
    --saliency_dir /content/synthetic_data/maps \
    --batch_size 16 \
    --num_epochs 20 \
    --learning_rate 1e-4 \
    --save_dir /content/drive/MyDrive/ARAI/models \
    --device cuda

print("\n" + "="*60)
print("✅ Training complete!")
print("📁 Model saved to: /content/drive/MyDrive/ARAI/models/saliency_model.pth")
print("="*60)

In [None]:
# Train with SALICON DATASET (production - 2-4 hours)
%cd /content/arai/backend/training

!python train_saliency.py \
    --image_dir /content/salicon/images \
    --saliency_dir /content/salicon/maps \
    --batch_size 16 \
    --num_epochs 50 \
    --learning_rate 1e-4 \
    --save_dir /content/drive/MyDrive/ARAI/models \
    --device cuda

print("\n" + "="*60)
print("✅ Training complete!")
print("📁 Model saved to: /content/drive/MyDrive/ARAI/models/saliency_model.pth")
print("="*60)

## 5️⃣ Download Trained Model 📥

In [None]:
# Download model to your computer
from google.colab import files

# Copy from Drive to Colab workspace (for faster download)
!cp /content/drive/MyDrive/ARAI/models/saliency_model.pth /content/

# Check file size
import os
size_mb = os.path.getsize('/content/saliency_model.pth') / 1e6
print(f"Model size: {size_mb:.1f} MB")

# Download to your computer
print("\nDownloading model to your computer...")
files.download('/content/saliency_model.pth')

print("\n✅ Download complete!")
print("\n📋 Next steps:")
print("   1. Move the downloaded file to: backend/models/saliency_model.pth")
print("   2. Restart your backend server")
print("   3. The model will be automatically detected and used!")

## 6️⃣ View Training Results 📊

In [None]:
# Display training curves
from IPython.display import Image as IPImage
import matplotlib.pyplot as plt

curves_path = '/content/drive/MyDrive/ARAI/models/training_curves.png'

if os.path.exists(curves_path):
    display(IPImage(filename=curves_path))
    print("\n✅ Training curves displayed above")
else:
    print("⚠️ Training curves not found. Did training complete?")

## 7️⃣ Test the Model (Optional) 🧪

In [None]:
# Test the trained model on a sample image
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import sys

sys.path.append('/content/arai/backend')
from app.ai_modules.comprehensive_attention_analyzer import SaliencyModel
from training.dataset import SaliencyDataset

# Load model; map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only runtime
model = SaliencyModel()
model.load_state_dict(torch.load('/content/drive/MyDrive/ARAI/models/saliency_model.pth', map_location='cpu'))
model.eval()

if torch.cuda.is_available():
    model = model.cuda()

# Load a test image
test_img_path = '/content/synthetic_data/images/img_0000.jpg'  # Change this path
test_img = Image.open(test_img_path).convert('RGB')

# Preprocess
from torchvision import transforms
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

img_tensor = transform(test_img).unsqueeze(0)
if torch.cuda.is_available():
    img_tensor = img_tensor.cuda()

# Predict
with torch.no_grad():
    saliency_pred = model(img_tensor)
    saliency_pred = saliency_pred.squeeze().cpu().numpy()

# Display results
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

axes[0].imshow(test_img)
axes[0].set_title('Input Image')
axes[0].axis('off')

axes[1].imshow(saliency_pred, cmap='hot')
axes[1].set_title('Predicted Saliency Map')
axes[1].axis('off')

plt.tight_layout()
plt.show()

print("\n✅ Model test complete!")
print(f"   Saliency range: [{saliency_pred.min():.3f}, {saliency_pred.max():.3f}]")

## 🎉 Done!

Your saliency model is trained and saved to Google Drive!

### Next Steps:

1. **Download the model** (if you haven't already):
   - Go to your Google Drive: `My Drive → ARAI → models`
   - Download `saliency_model.pth` (~45MB)

2. **Place in your project**:
   ```bash
   cd /Users/kavishani/Documents/FYP/arai-system/backend
   mkdir -p models
   # Move downloaded file here
   ```

3. **Restart your backend**:
   ```bash
   cd backend
   python -m uvicorn app.main:app --reload
   ```

4. **Test in your web app**:
   - Upload a design
   - Check the saliency heatmap
   - Compare with previous heuristic-based results
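
To confirm that the copy of `saliency_model.pth` on your machine matches the one in Drive, one option is to compare SHA-256 digests. This is a minimal sketch using only the standard library; `sha256_of` is a hypothetical helper name. Run it once in Colab and once locally, then compare the two hex strings.

```python
import hashlib


def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel stops when read() returns empty bytes (end of file)
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Example (uncomment and adjust the path for each environment):
# print(sha256_of("/content/drive/MyDrive/ARAI/models/saliency_model.pth"))
# print(sha256_of("backend/models/saliency_model.pth"))
```

Reading in chunks keeps memory use flat regardless of the checkpoint size.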

**Expected Improvement:**
- Accuracy: 70-80% → 85-95%
- More realistic attention patterns
- Better prediction of user focus areas
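
The accuracy figures above are not tied to a specific metric in this notebook. To measure the improvement on your own data, one widely used score for saliency maps is the Pearson correlation coefficient (CC) between a predicted map and a ground-truth map. The sketch below assumes both maps are 2D arrays of the same shape; `pearson_cc` is a hypothetical helper, not a function from the ARAI codebase.

```python
import numpy as np


def pearson_cc(pred, target, eps=1e-8):
    """Pearson correlation between two saliency maps of identical shape.

    Returns a value in roughly [-1, 1]; values near 1 mean the predicted map
    rises and falls with the ground truth. eps guards against division by
    zero for constant maps.
    """
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    if pred.shape != target.shape:
        raise ValueError("pred and target must have the same shape")
    # Standardize each map to zero mean and unit variance
    pred = (pred - pred.mean()) / (pred.std() + eps)
    target = (target - target.mean()) / (target.std() + eps)
    return float((pred * target).mean())
```

Averaging this score over a held-out set, once for the heuristic output and once for the trained model, gives a like-for-like comparison.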

---

**Need help?** Check `GOOGLE_COLAB_TRAINING_GUIDE.md` for troubleshooting!