# SARATR-X Training Experiments on Google Colab

This notebook trains the SARATR-X model with MAE-HiViT tiny backbone using 4 different reconstruction techniques:
1. **Pixel-reconstruction**: SAR → SAR reconstruction
2. **MGF-reconstruction**: SAR → Multi-scale Gradient Features
3. **RGB-reconstruction**: SAR → RGB optical images
4. **Greyscale-reconstruction**: SAR → Greyscale optical images

## Important Notes:
- **GPU**: Optimized for T4 GPU (free tier)
- **Runtime**: Free tier has a ~1.5 hour limit, so checkpoints are saved to Google Drive
- **Dataset**: Sentinel-1 & Sentinel-2 from Kaggle
- **Storage**: All results saved to Google Drive (persists after session ends)
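
The runtime limit noted above determines how many sessions a full run takes. The following is a rough planning sketch, not a measurement: `minutes_per_epoch` is a hypothetical figure you would need to time yourself, and setup time (cloning, dataset download, conversion) is ignored.

```python
import math


def sessions_needed(total_epochs: int, minutes_per_epoch: float,
                    session_minutes: float = 90.0) -> int:
    """Estimate how many sessions a run needs, counting only whole epochs."""
    if minutes_per_epoch <= 0:
        raise ValueError("minutes_per_epoch must be positive")
    epochs_per_session = math.floor(session_minutes / minutes_per_epoch)
    if epochs_per_session < 1:
        raise ValueError("a single epoch does not fit in one session")
    return math.ceil(total_epochs / epochs_per_session)


# Hypothetical timing: 6 minutes per epoch for the 50-epoch runs in this notebook
print(sessions_needed(50, 6.0))  # 15 epochs fit per session, so 4 sessions
```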

---

## 1. Setup: Mount Google Drive for Persistent Storage

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Create directories for experiments
import os
os.makedirs('/content/drive/MyDrive/SARATRX_experiments', exist_ok=True)
os.makedirs('/content/drive/MyDrive/SARATRX_experiments/checkpoints', exist_ok=True)
os.makedirs('/content/drive/MyDrive/SARATRX_experiments/logs', exist_ok=True)

print("✓ Google Drive mounted successfully")
print("✓ Experiment directories created in: /content/drive/MyDrive/SARATRX_experiments/")

## 2. Check GPU Availability

In [None]:
!nvidia-smi
import torch
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")

## 3. Clone Repository and Install Dependencies

In [None]:
# Clone the repository
!git clone https://github.com/jmaxrdgz/SARATR-X.git
%cd SARATR-X

print("\n✓ Repository cloned successfully")

In [None]:
# Install dependencies
!pip install -q -r requirements.txt
!pip install -q 'urllib3<2.0' kaggle  # For Kaggle dataset download

print("\n✓ Dependencies installed")

## 4. Download Pretrained Weights

Download MAE-HiViT pretrained weights for initialization.

In [None]:
import os
os.makedirs('checkpoints', exist_ok=True)

# Download pretrained MAE-HiViT weights
# Note: Update this URL with the actual location of pretrained weights
!wget -O checkpoints/mae_hivit_base_1600ep.pth https://github.com/zhangxiaosong18/hivit/releases/download/v1.0/mae_hivit_base_1600ep.pth

print("\n✓ Pretrained weights downloaded")

## 5. Setup Kaggle Credentials

To download the Sentinel dataset:
1. Go to https://www.kaggle.com/account
2. Scroll to "API" section
3. Click "Create New API Token" to download `kaggle.json`
4. Upload it in the cell below
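
Before uploading, it can help to confirm that the token file is well formed. A Kaggle API token is a small JSON object with `username` and `key` fields; the helper name `check_kaggle_token` below is an illustrative assumption, not part of the Kaggle package.

```python
import json
import os
import tempfile


def check_kaggle_token(path: str) -> bool:
    """Return True if `path` is JSON with non-empty 'username' and 'key' strings."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            token = json.load(f)
    except (OSError, ValueError):
        # Missing file, unreadable file, or invalid JSON
        return False
    if not isinstance(token, dict):
        return False
    return all(isinstance(token.get(k), str) and bool(token[k])
               for k in ("username", "key"))


# Demonstration with a throwaway file and placeholder values (not real credentials)
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "kaggle.json")
    with open(demo_path, "w", encoding="utf-8") as f:
        json.dump({"username": "example-user", "key": "0000"}, f)
    print(check_kaggle_token(demo_path))  # True
```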

In [None]:
from google.colab import files
import os

print("Please upload your kaggle.json file:")
uploaded = files.upload()

# Setup Kaggle credentials
os.makedirs('/root/.kaggle', exist_ok=True)
!cp kaggle.json /root/.kaggle/
!chmod 600 /root/.kaggle/kaggle.json

print("\n✓ Kaggle credentials configured")

## 6. Download Sentinel-1 & Sentinel-2 Dataset from Kaggle

In [None]:
import kaggle
import os

# Create dataset directory
os.makedirs('dataset/sentinel12', exist_ok=True)

# Download the dataset
dataset_name = 'requiemonk/sentinel12-image-pairs-segregated-by-terrain'
print(f"Downloading {dataset_name}...")
print("This may take 5-10 minutes...\n")

kaggle.api.dataset_download_files(
    dataset_name,
    path='dataset/sentinel12',
    unzip=True
)

print("\n✓ Dataset downloaded successfully")

# Check dataset structure
!ls -lh dataset/sentinel12/v_2/

## 7. Preprocess Dataset: Convert PNG to NPY

The model expects `.npy` files, but the Kaggle dataset provides PNG images, so we convert them first.
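
Because `.npy` files store arrays uncompressed, the float32 copies take noticeably more disk space than the PNGs. The sketch below estimates that cost; the 256×256 three-channel shape and the image count are assumptions for illustration, so check the real shapes in the verification step.

```python
def float32_payload_bytes(height: int, width: int, channels: int = 3) -> int:
    """Raw array bytes for one float32 image (the .npy header adds a little more)."""
    bytes_per_value = 4  # float32
    return height * width * channels * bytes_per_value


def dataset_payload_gib(num_images: int, height: int, width: int,
                        channels: int = 3) -> float:
    """Total raw array size in GiB for `num_images` images of the same shape."""
    return num_images * float32_payload_bytes(height, width, channels) / 1024 ** 3


print(float32_payload_bytes(256, 256, 3))  # 786432 bytes (768 KiB) per image
# Hypothetical image count, for illustration only
print(round(dataset_payload_gib(16000, 256, 256), 2))  # 11.72
```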

In [None]:
import numpy as np
from pathlib import Path
from PIL import Image
from tqdm import tqdm

def convert_png_to_npy(data_path):
    """Convert PNG images to NPY format for faster loading."""
    data_path = Path(data_path)
    
    # Find all terrain directories
    terrain_dirs = [d for d in data_path.iterdir() if d.is_dir()]
    
    total_converted = 0
    
    for terrain in tqdm(terrain_dirs, desc="Processing terrains"):
        # Process S1 (SAR) images
        s1_dir = terrain / "s1"
        if s1_dir.exists():
            for img_file in s1_dir.glob("*.png"):
                npy_file = img_file.with_suffix('.npy')
                if not npy_file.exists():
                    img = np.array(Image.open(img_file)).astype(np.float32) / 255.0
                    np.save(npy_file, img)
                    total_converted += 1
        
        # Process S2 (Optical) images
        s2_dir = terrain / "s2"
        if s2_dir.exists():
            for img_file in s2_dir.glob("*.png"):
                npy_file = img_file.with_suffix('.npy')
                if not npy_file.exists():
                    img = np.array(Image.open(img_file)).astype(np.float32) / 255.0
                    np.save(npy_file, img)
                    total_converted += 1
    
    return total_converted

# Convert images
print("Converting PNG images to NPY format...")
num_converted = convert_png_to_npy('dataset/sentinel12/v_2')
print(f"\n✓ Converted {num_converted} images to NPY format")

## 8. Verify Dataset

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

data_path = Path('dataset/sentinel12/v_2/')

# Find sample images
terrain_dirs = [d for d in data_path.iterdir() if d.is_dir()]
sample_terrain = terrain_dirs[0]

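# Sort both lists so that index i refers to the same scene in s1 and s2
# (assumes paired files sort into the same order)
sar_files = sorted((sample_terrain / "s1").glob("*.npy"))
opt_files = sorted((sample_terrain / "s2").glob("*.npy"))

print("Dataset statistics:")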
print(f"  - Terrain types: {len(terrain_dirs)}")
print(f"  - SAR images: {len(sar_files)} in {sample_terrain.name}")
print(f"  - Optical images: {len(opt_files)} in {sample_terrain.name}")

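# Visualize up to two samples (fewer if the terrain folder has fewer files)
num_samples = min(2, len(sar_files), len(opt_files))
if num_samples:
    fig, axes = plt.subplots(num_samples, 2, figsize=(10, 5 * num_samples), squeeze=False)
    
    for i in range(num_samples):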
        sar_img = np.load(sar_files[i])
        opt_img = np.load(opt_files[i])
        
        axes[i, 0].imshow(sar_img)
        axes[i, 0].set_title(f'SAR: {sar_files[i].name}')
        axes[i, 0].axis('off')
        
        axes[i, 1].imshow(opt_img)
        axes[i, 1].set_title(f'Optical: {opt_files[i].name}')
        axes[i, 1].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    print(f"\nImage shapes: SAR={sar_img.shape}, Optical={opt_img.shape}")

## 9. Create Experiment Configurations

We'll create 4 different configurations for our experiments.
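
When deriving several configurations from one base dictionary, note that `dict.copy()` is shallow: nested sections such as `data` or `train` stay shared unless they are copied too. A minimal sketch of an alternative using `copy.deepcopy` follows; the helper `make_experiment_config` is hypothetical and not part of the SARATR-X repository.

```python
import copy


def make_experiment_config(base: dict, name: str, overrides: dict) -> dict:
    """Return an independent copy of `base` with per-section overrides applied."""
    cfg = copy.deepcopy(base)  # nested dicts are copied, not shared
    for section, values in overrides.items():
        cfg.setdefault(section, {}).update(values)
    cfg["experiment_name"] = name
    return cfg


base = {"model": {"target_mode": None, "in_chans": 3}, "data": {"img_size": 256}}
exp = make_experiment_config(
    base,
    "exp4_grey_sar_to_grey",
    {"model": {"target_mode": "optical"}, "data": {"greyscale_target": True}},
)

print(exp["data"])   # {'img_size': 256, 'greyscale_target': True}
print(base["data"])  # {'img_size': 256}  (base is untouched)
```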

In [None]:
import yaml
import os

# Base configuration for T4 GPU and free tier constraints
base_config = {
    'model': {
        'name': 'SARATRX',
        'in_chans': 3,
        'mgf_kens': [9, 13, 17],
        'resume': None,
        'norm_pix_loss': False
    },
    'data': {
        'dataset_name': 'sentinel',
        'train_data': 'dataset/sentinel12/v_2',
        'img_size': 256,
        'dataset_std_dev': 1.57,
        'num_workers': 2  # Reduced for Colab
    },
    'train': {
        'init_weights': 'checkpoints/mae_hivit_base_1600ep.pth',
        'seed': 42,
        'n_gpu': 'auto',
        'epochs': 50,  # Reduced for T4 GPU time constraints
        'warmup_epochs': 5,
        'batch_size': 32,  # Reduced from 64 for T4 GPU memory
        'lr': 1.5e-4,
        'weight_decay': 0.05,
        'optimizer_momentum': {
            'beta1': 0.9,
            'beta2': 0.95
        },
        'clip_grad': 5.0
    }
}

# Experiment 1: Pixel-reconstruction SAR-to-SAR
exp1_config = base_config.copy()
exp1_config['model'] = base_config['model'].copy()
exp1_config['model']['target_mode'] = 'optical'  # Pixel target; the training script substitutes SAR for the optical image
exp1_config['experiment_name'] = 'exp1_pixel_sar'

# Experiment 2: MGF-reconstruction SAR
exp2_config = base_config.copy()
exp2_config['model'] = base_config['model'].copy()
exp2_config['model']['target_mode'] = 'mgf'
exp2_config['experiment_name'] = 'exp2_mgf_sar'

# Experiment 3: RGB-reconstruction SAR-to-RGB
exp3_config = base_config.copy()
exp3_config['model'] = base_config['model'].copy()
exp3_config['model']['target_mode'] = 'optical'
exp3_config['experiment_name'] = 'exp3_rgb_sar_to_rgb'

# Experiment 4: Greyscale-reconstruction SAR-to-Greyscale
exp4_config = base_config.copy()
exp4_config['model'] = base_config['model'].copy()
exp4_config['model']['target_mode'] = 'optical'  # Will convert optical to greyscale
exp4_config['experiment_name'] = 'exp4_grey_sar_to_grey'
exp4_config['data'] = base_config['data'].copy()
exp4_config['data']['greyscale_target'] = True  # Custom flag for bookkeeping; the training script selects greyscale by experiment name

# Save configurations
os.makedirs('config/experiments', exist_ok=True)

configs = {
    'exp1_pixel_sar.yaml': exp1_config,
    'exp2_mgf_sar.yaml': exp2_config,
    'exp3_rgb_sar_to_rgb.yaml': exp3_config,
    'exp4_grey_sar_to_grey.yaml': exp4_config
}

for filename, config in configs.items():
    with open(f'config/experiments/{filename}', 'w') as f:
        yaml.dump(config, f, default_flow_style=False)

print("✓ Created 4 experiment configurations:")
for filename in configs.keys():
    print(f"  - {filename}")

## 10. Create Custom Training Scripts for Experiments

We need custom scripts to handle the different reconstruction targets.
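
One of these targets, the greyscale one for Experiment 4, is a weighted average of the R, G, and B channels using the ITU-R BT.601 luma coefficients (0.299, 0.587, 0.114), repeated across three channels. The sketch below shows the arithmetic on a single pixel in plain Python; the function names are illustrative only.

```python
def luma(r: float, g: float, b: float) -> float:
    """Weighted greyscale value using the ITU-R BT.601 luma coefficients."""
    return 0.299 * r + 0.587 * g + 0.114 * b


def to_three_channel_grey(pixel):
    """Repeat the greyscale value so the target keeps three channels."""
    y = luma(*pixel)
    return (y, y, y)


# Green contributes most to perceived brightness, blue the least
for name, pixel in [("red", (1.0, 0.0, 0.0)),
                    ("green", (0.0, 1.0, 0.0)),
                    ("blue", (0.0, 0.0, 1.0))]:
    print(name, to_three_channel_grey(pixel))
```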

In [None]:
%%writefile train_experiments.py
"""Custom training script for SARATR-X experiments."""
import os
import sys
import yaml
import torch
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

# Import SARATR-X modules
from model.saratrx import SARATRX
from data.data_pretrain import build_loader
import config as config_module

def load_config(config_path):
    """Load configuration from YAML file."""
    with open(config_path, 'r') as f:
        cfg = yaml.safe_load(f)
    
    # Update the global config object
    for key, value in cfg.items():
        if hasattr(config_module.config, key):
            if isinstance(value, dict):
                for subkey, subvalue in value.items():
                    setattr(getattr(config_module.config, key), subkey, subvalue)
            else:
                setattr(config_module.config, key, value)
    
    return cfg

def train_experiment(config_path, checkpoint_dir, log_dir):
    """Train a single experiment."""
    # Load configuration
    cfg = load_config(config_path)
    exp_name = cfg.get('experiment_name', 'experiment')
    
    print(f"\n{'='*60}")
    print(f"Starting Experiment: {exp_name}")
    print(f"{'='*60}\n")
    
    # Build data loader
    print("Building data loader...")
    
    # Handle special case for SAR-to-SAR (Experiment 1)
    if exp_name == 'exp1_pixel_sar':
        # For pixel SAR reconstruction, we use SAR as both input and target
        # We'll modify the dataset to return SAR twice
        from data.dataset_sentinel import SentinelDataset
        from data.data_pretrain import PairedTransform
        from torchvision import transforms
        from torch.utils.data import DataLoader
        
        class SARtoSARDataset(SentinelDataset):
            def __getitem__(self, idx):
                sar_img, _ = super().__getitem__(idx)
                return sar_img, sar_img  # Return SAR as both input and target
        
        base_transform = transforms.Compose([
            transforms.RandomResizedCrop(config_module.config.data.img_size, scale=(0.2, 1.0), interpolation=3),
            transforms.Resize((config_module.config.data.img_size, config_module.config.data.img_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(contrast=0.5),
        ])
        paired_transform = PairedTransform(base_transform)
        
        train_dataset = SARtoSARDataset(
            data_path=config_module.config.data.train_data,
            transform=paired_transform
        )
        train_loader = DataLoader(
            train_dataset,
            batch_size=config_module.config.train.batch_size,
            shuffle=True,
            num_workers=config_module.config.data.num_workers,
            drop_last=True
        )
    
    # Handle greyscale conversion (Experiment 4)
    elif exp_name == 'exp4_grey_sar_to_grey':
        from data.dataset_sentinel import SentinelDataset
        from data.data_pretrain import PairedTransform
        from torchvision import transforms
        from torch.utils.data import DataLoader
        
        class GreyscaleOpticalDataset(SentinelDataset):
            def __getitem__(self, idx):
                sar_img, opt_img = super().__getitem__(idx)
                # Convert optical to greyscale (weighted average)
                grey_img = 0.299 * opt_img[0:1] + 0.587 * opt_img[1:2] + 0.114 * opt_img[2:3]
                grey_img = grey_img.repeat(3, 1, 1)  # Repeat to 3 channels
                return sar_img, grey_img
        
        base_transform = transforms.Compose([
            transforms.RandomResizedCrop(config_module.config.data.img_size, scale=(0.2, 1.0), interpolation=3),
            transforms.Resize((config_module.config.data.img_size, config_module.config.data.img_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(contrast=0.5),
        ])
        paired_transform = PairedTransform(base_transform)
        
        train_dataset = GreyscaleOpticalDataset(
            data_path=config_module.config.data.train_data,
            transform=paired_transform
        )
        train_loader = DataLoader(
            train_dataset,
            batch_size=config_module.config.train.batch_size,
            shuffle=True,
            num_workers=config_module.config.data.num_workers,
            drop_last=True
        )
    
    else:
        # Standard data loader for MGF and RGB experiments
        train_loader = build_loader(dataset_name='sentinel')
    
    print(f"  Dataset size: {len(train_loader.dataset)}")
    print(f"  Batch size: {config_module.config.train.batch_size}")
    print(f"  Number of batches: {len(train_loader)}")
    
    # Create model
    print("\nInitializing model...")
    model = SARATRX(
        img_size=config_module.config.data.img_size,
        in_chans=config_module.config.model.in_chans,
        mgf_kens=config_module.config.model.mgf_kens,
        target_mode=config_module.config.model.target_mode,
        norm_pix_loss=config_module.config.model.norm_pix_loss
    )
    
    # Setup checkpointing to Google Drive
    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(checkpoint_dir, exp_name),
        filename='epoch{epoch:02d}-loss{train_loss:.4f}',
        monitor='train_loss',
        mode='min',
        save_top_k=3,
        save_last=True,
        every_n_epochs=5,  # Save every 5 epochs
        auto_insert_metric_name=False
    )
    
    # Setup logging
    logger = TensorBoardLogger(
        save_dir=log_dir,
        name=exp_name
    )
    
    # Create trainer
    trainer = L.Trainer(
        max_epochs=config_module.config.train.epochs,
        accelerator='auto',
        devices=1,
        precision='16-mixed',  # Mixed precision for T4 GPU
        gradient_clip_val=config_module.config.train.clip_grad,
        callbacks=[checkpoint_callback],
        logger=logger,
        log_every_n_steps=10,
        enable_progress_bar=True
    )
    
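    # Train (resumes from the checkpoint in model.resume when the config sets one)
    print("\nStarting training...")
    resume_path = (cfg.get('model') or {}).get('resume')
    trainer.fit(model, train_loader, ckpt_path=resume_path)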
    
    print(f"\n✓ Experiment {exp_name} completed!")
    print(f"  Checkpoints saved to: {os.path.join(checkpoint_dir, exp_name)}")
    print(f"  Logs saved to: {os.path.join(log_dir, exp_name)}")
    
    return exp_name

if __name__ == '__main__':
    import argparse
    
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True, help='Path to config file')
    parser.add_argument('--checkpoint_dir', type=str, default='/content/drive/MyDrive/SARATRX_experiments/checkpoints')
    parser.add_argument('--log_dir', type=str, default='/content/drive/MyDrive/SARATRX_experiments/logs')
    
    args = parser.parse_args()
    
    train_experiment(args.config, args.checkpoint_dir, args.log_dir)

In [None]:
# Kept in its own cell so this line is not written into train_experiments.py
print("✓ Training script created")

## 11. Run Training Experiments

Now we'll run each experiment sequentially. Each experiment will:
- Train for 50 epochs (reduced for T4 GPU constraints)
- Save checkpoints every 5 epochs to Google Drive
- Log training metrics to TensorBoard

**Note**: Because of the 1.5-hour free-tier limit, you may need to run the experiments across multiple sessions. Checkpoints are saved to Google Drive and persist after the session ends.
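
At the start of a new session it is useful to know which experiments already have a checkpoint to continue from. With `save_last=True`, Lightning's `ModelCheckpoint` writes a `last.ckpt` file in the experiment's checkpoint folder. The helper below is a hypothetical sketch that looks for that file; it is demonstrated on a temporary directory rather than the real Drive path.

```python
import os
import tempfile
from typing import Optional


def find_resume_checkpoint(checkpoint_root: str, exp_name: str) -> Optional[str]:
    """Return the path to <root>/<exp_name>/last.ckpt, or None if it is absent."""
    candidate = os.path.join(checkpoint_root, exp_name, "last.ckpt")
    return candidate if os.path.isfile(candidate) else None


# Demonstration on a temporary directory standing in for the Drive folder
with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "exp1_pixel_sar"))
    open(os.path.join(root, "exp1_pixel_sar", "last.ckpt"), "wb").close()

    for name in ("exp1_pixel_sar", "exp2_mgf_sar"):
        status = "resume" if find_resume_checkpoint(root, name) else "start fresh"
        print(f"{name}: {status}")
```

On this demo layout the loop reports `exp1_pixel_sar: resume` and `exp2_mgf_sar: start fresh`.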

### Experiment 1: Pixel-Reconstruction (SAR → SAR)

In [None]:
!python train_experiments.py \
    --config config/experiments/exp1_pixel_sar.yaml \
    --checkpoint_dir /content/drive/MyDrive/SARATRX_experiments/checkpoints \
    --log_dir /content/drive/MyDrive/SARATRX_experiments/logs

### Experiment 2: MGF-Reconstruction (SAR → Multi-scale Gradient Features)

In [None]:
!python train_experiments.py \
    --config config/experiments/exp2_mgf_sar.yaml \
    --checkpoint_dir /content/drive/MyDrive/SARATRX_experiments/checkpoints \
    --log_dir /content/drive/MyDrive/SARATRX_experiments/logs

### Experiment 3: RGB-Reconstruction (SAR → RGB Optical)

In [None]:
!python train_experiments.py \
    --config config/experiments/exp3_rgb_sar_to_rgb.yaml \
    --checkpoint_dir /content/drive/MyDrive/SARATRX_experiments/checkpoints \
    --log_dir /content/drive/MyDrive/SARATRX_experiments/logs

### Experiment 4: Greyscale-Reconstruction (SAR → Greyscale Optical)

In [None]:
!python train_experiments.py \
    --config config/experiments/exp4_grey_sar_to_grey.yaml \
    --checkpoint_dir /content/drive/MyDrive/SARATRX_experiments/checkpoints \
    --log_dir /content/drive/MyDrive/SARATRX_experiments/logs

## 12. Monitor Training with TensorBoard

In [None]:
%load_ext tensorboard
%tensorboard --logdir /content/drive/MyDrive/SARATRX_experiments/logs

## 13. View Experiment Results

After training, checkpoints and logs are saved to:
```
/content/drive/MyDrive/SARATRX_experiments/
├── checkpoints/
│   ├── exp1_pixel_sar/
│   ├── exp2_mgf_sar/
│   ├── exp3_rgb_sar_to_rgb/
│   └── exp4_grey_sar_to_grey/
└── logs/
    ├── exp1_pixel_sar/
    ├── exp2_mgf_sar/
    ├── exp3_rgb_sar_to_rgb/
    └── exp4_grey_sar_to_grey/
```

In [None]:
import os

checkpoint_dir = '/content/drive/MyDrive/SARATRX_experiments/checkpoints'

print("Experiment Results:\n")
for exp in ['exp1_pixel_sar', 'exp2_mgf_sar', 'exp3_rgb_sar_to_rgb', 'exp4_grey_sar_to_grey']:
    exp_path = os.path.join(checkpoint_dir, exp)
    if os.path.exists(exp_path):
        ckpt_files = [f for f in os.listdir(exp_path) if f.endswith('.ckpt')]
        print(f"✓ {exp}:")
        print(f"  Checkpoints: {len(ckpt_files)}")
        if ckpt_files:
            for ckpt in sorted(ckpt_files)[-3:]:  # Show last 3
                size_mb = os.path.getsize(os.path.join(exp_path, ckpt)) / (1024*1024)
                print(f"    - {ckpt} ({size_mb:.1f} MB)")
    else:
        print(f"✗ {exp}: Not started")
    print()

## 14. Resume Training (if session expires)

If your Colab session expires before training completes, you can resume from the last checkpoint:

In [None]:
# To resume training from a checkpoint, modify the experiment config:
import yaml

# Example: Resume experiment 1
exp_name = 'exp1_pixel_sar'
config_path = f'config/experiments/{exp_name}.yaml'
checkpoint_path = f'/content/drive/MyDrive/SARATRX_experiments/checkpoints/{exp_name}/last.ckpt'

# Load config
with open(config_path, 'r') as f:
    cfg = yaml.safe_load(f)

# Update resume path
cfg['model']['resume'] = checkpoint_path

# Save updated config
with open(config_path, 'w') as f:
    yaml.dump(cfg, f)

print(f"Updated config to resume from: {checkpoint_path}")
print("\nNow run the experiment cell again to resume training.")

## Summary

This notebook:
1. ✓ Mounts Google Drive for persistent storage
2. ✓ Clones SARATR-X repository
3. ✓ Downloads Sentinel-1&2 dataset from Kaggle
4. ✓ Converts images to NPY format
5. ✓ Creates 4 experiment configurations:
   - Pixel-reconstruction (SAR → SAR)
   - MGF-reconstruction (SAR → MGF)
   - RGB-reconstruction (SAR → RGB)
   - Greyscale-reconstruction (SAR → Greyscale)
6. ✓ Trains each experiment with:
   - T4 GPU optimization (batch size 32, mixed precision)
   - Regular checkpointing to Google Drive
   - TensorBoard logging
7. ✓ Handles session expiry with resume capability

All results persist in Google Drive after session ends!