# AutoHDR Kaggle Training Notebook

Trains an edge-aware MicroUNet model for automatic lens correction.

If no training data can be found under `/kaggle/input`, this notebook falls back to downloading individual training pairs through the Kaggle API into `/kaggle/working/autohdr_competition_data`.


In [None]:
import os
import sys
import json
import time
import glob
import zipfile
import argparse
import subprocess
import shutil
from datetime import datetime, timezone
from typing import Optional, Tuple, List

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

# Optional token override; leave empty unless explicitly injected via notebook environment.
# The value is read once, when this cell runs, from the KAGGLE_API_TOKEN environment
# variable (surrounding whitespace stripped). ensure_kaggle_cli_data() uses it only when
# the variable is empty at the time the fallback download starts.
FALLBACK_KAGGLE_TOKEN = os.getenv("KAGGLE_API_TOKEN", "").strip()


def candidate_roots(root_hint: str) -> List[str]:
    """Return the existing directories worth scanning for data, in priority order.

    The order is ``root_hint``, then ``/kaggle/input``, then every entry directly
    under ``/kaggle/input`` (sorted by name). Empty strings, paths that are not
    directories, and duplicates are dropped.
    """
    kaggle_input = '/kaggle/input'

    proposals = [root_hint, kaggle_input]
    if os.path.isdir(kaggle_input):
        proposals.extend(
            os.path.join(kaggle_input, entry)
            for entry in sorted(os.listdir(kaggle_input))
        )

    found: List[str] = []
    for path in proposals:
        if not path or not os.path.isdir(path):
            continue
        if path not in found:
            found.append(path)
    return found


def extract_zip_if_needed(root: str, extract_root: str = '/kaggle/working/autohdr_input_cache') -> str:
    """Extract zip archives found under ``root`` into a cache and return where to scan.

    Args:
        root: Directory that is walked recursively looking for ``*.zip`` files.
        extract_root: Cache directory; each archive is unpacked into a sub-directory
            named after the archive (without its extension).

    Returns:
        ``extract_root`` when at least one archive is available there (freshly
        extracted or already cached by an earlier run); otherwise ``root`` unchanged.
    """
    # Sorted so that the "first 8" cap below picks the same archives on every run.
    zip_files = sorted(
        os.path.join(dirpath, fname)
        for dirpath, _, filenames in os.walk(root)
        for fname in filenames
        if fname.lower().endswith('.zip')
    )
    if not zip_files:
        return root

    os.makedirs(extract_root, exist_ok=True)

    cache_usable = False
    for zpath in zip_files[:8]:
        target = os.path.join(extract_root, os.path.splitext(os.path.basename(zpath))[0])

        already_extracted = False
        if os.path.isdir(target):
            with os.scandir(target) as entries:
                already_extracted = next(entries, None) is not None
        if already_extracted:
            # A populated target from a previous run counts as available data;
            # without this the cache would be ignored on every re-run.
            cache_usable = True
            continue

        print(f'[auto-detect] Extracting {zpath} -> {target}')
        os.makedirs(target, exist_ok=True)
        try:
            with zipfile.ZipFile(zpath, 'r') as zf:
                zf.extractall(target)
        except Exception as exc:
            # Deliberately best-effort: one bad archive must not abort discovery.
            print(f'[auto-detect] Skipping unreadable zip {zpath}: {exc}')
            # Drop the partial output so a later run does not mistake it for a valid cache.
            shutil.rmtree(target, ignore_errors=True)
        else:
            cache_usable = True

    return extract_root if cache_usable else root


def _iter_scan_roots(base: str):
    """Yield ``base``, then its zip-extraction directory only if the caller keeps iterating."""
    yield base
    # Reached only when scanning ``base`` itself did not locate the training data,
    # so archives are not unpacked when the data is already mounted uncompressed.
    extracted = extract_zip_if_needed(base)
    if extracted != base:
        yield extracted


def find_dataset_paths(root_hint: str) -> Tuple[str, Optional[str]]:
    """Locate the training directory and, if possible, the test directory.

    Each candidate root is scanned first as-is and then, only if no training
    directory was found, through the directory its zip archives were extracted to.
    Directories are matched by name (``lens-correction-train-cleaned`` /
    ``test-originals``) and, failing that, by file-count heuristics.

    Args:
        root_hint: Preferred directory to scan before the default Kaggle input roots.

    Returns:
        ``(train_dir, test_dir)``; ``test_dir`` is ``None`` when no test images were found.

    Raises:
        RuntimeError: If no training directory is found under any candidate root.
    """
    roots = candidate_roots(root_hint)
    print('Candidate roots:')
    for r in roots:
        print(f'  - {r}')

    for base in roots:
        train_dir = None
        test_dir = None

        for root in _iter_scan_roots(base):
            # Pass 1: exact directory names used by the competition layout.
            for dirpath, _, _ in os.walk(root):
                base_name = os.path.basename(dirpath)
                if base_name == 'lens-correction-train-cleaned':
                    train_dir = dirpath
                elif base_name == 'test-originals':
                    test_dir = dirpath
                if train_dir and test_dir:
                    return train_dir, test_dir

            # Pass 2: any directory holding more than 100 '*_original.jpg' files.
            if not train_dir:
                for dirpath, _, filenames in os.walk(root):
                    originals = [f for f in filenames if f.endswith('_original.jpg')]
                    if len(originals) > 100:
                        train_dir = dirpath
                        print(f'[auto-detect] Found {len(originals)} training originals in {dirpath}')
                        break

            # Pass 3: a directory with 501-4999 jpgs, none of whose first 200 names
            # look like training pairs, is taken to be the test set.
            if not test_dir:
                for dirpath, _, filenames in os.walk(root):
                    jpgs = [f for f in filenames if f.endswith('.jpg')]
                    if 500 < len(jpgs) < 5000:
                        has_pairs = any(
                            f.endswith('_original.jpg') or f.endswith('_generated.jpg')
                            for f in jpgs[:200]
                        )
                        if not has_pairs:
                            test_dir = dirpath
                            print(f'[auto-detect] Found {len(jpgs)} test images in {dirpath}')
                            break

            # Training data is sufficient to proceed; the test directory is optional.
            if train_dir:
                return train_dir, test_dir

    raise RuntimeError('Could not find training data under scanned roots.')


def _write_private_file(path: str, text: str) -> None:
    """Write ``text`` to ``path`` so the file is owner-only (0o600) from creation."""
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, 'w') as f:
        f.write(text)
    # The mode given to os.open only applies to newly created files; tighten
    # a pre-existing file as well.
    os.chmod(path, 0o600)


def ensure_kaggle_cli_data(download_root: str, max_pairs: int = 3000) -> str:
    """Download a subset of training pairs through the Kaggle API.

    Files are stored under ``download_root`` using the relative paths reported by
    the competition file listing; files that already exist locally are skipped.

    Args:
        download_root: Directory that receives the downloaded files (created if needed).
        max_pairs: Number of original/generated pairs to fetch; values below 1
            are treated as 1.

    Returns:
        ``download_root``.

    Raises:
        RuntimeError: If Kaggle authentication fails, the file listing has an
            unexpected type, or the listing contains no training originals.
    """
    competition = 'automatic-lens-correction'
    original_suffix = '_original.jpg'
    generated_suffix = '_generated.jpg'

    os.makedirs(download_root, exist_ok=True)
    # Normalised once so the cache check and the selection below agree.
    wanted_pairs = max(1, int(max_pairs))

    token = os.getenv('KAGGLE_API_TOKEN', '').strip() or FALLBACK_KAGGLE_TOKEN

    # Configure credentials when an explicit token is available.
    kaggle_dir = '/root/.kaggle'
    os.makedirs(kaggle_dir, exist_ok=True)
    if token:
        _write_private_file(os.path.join(kaggle_dir, 'access_token'), token)
        _write_private_file(
            os.path.join(kaggle_dir, 'kaggle.json'),
            json.dumps({'username': 'token', 'key': token.replace('KGAT_', '')}),
        )

    if shutil.which('kaggle') is None:
        subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', 'kaggle'], check=False)

    try:
        from kaggle.api.kaggle_api_extended import KaggleApi
    except Exception:
        # Second attempt is strict: if the package still cannot be installed, fail loudly.
        subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', 'kaggle'], check=True)
        from kaggle.api.kaggle_api_extended import KaggleApi

    api = KaggleApi()
    try:
        api.authenticate()
    except Exception as exc:
        message = (
            "Fallback Kaggle API auth failed. "
            "Attach competition data source or set KAGGLE_API_TOKEN.\n"
            f"Auth error: {exc}"
        )
        raise RuntimeError(message) from exc

    train_rel = 'lens-correction-train-cleaned'
    train_dir = os.path.join(download_root, train_rel)
    os.makedirs(train_dir, exist_ok=True)

    # Only originals whose generated partner is also present count as cached;
    # an interrupted download must not be mistaken for a complete cache.
    existing_originals = [
        f for f in os.listdir(train_dir)
        if f.endswith(original_suffix)
        and os.path.exists(os.path.join(train_dir, f[:-len(original_suffix)] + generated_suffix))
    ]
    if len(existing_originals) >= wanted_pairs:
        print(f'[fallback] Using existing cached training files: {len(existing_originals)} originals')
        return download_root

    print('[fallback] Listing competition files...')
    files_response = api.competition_list_files(competition)
    # The response shape is handled defensively: an object with ``.files``,
    # a plain list, or any other iterable.
    if hasattr(files_response, 'files'):
        files_iter = files_response.files
    elif isinstance(files_response, list):
        files_iter = files_response
    else:
        try:
            files_iter = list(files_response)
        except TypeError as exc:
            raise RuntimeError(f'Unexpected list-files response type: {type(files_response)}') from exc

    originals = []
    for item in files_iter:
        name = getattr(item, 'name', None)
        if name is None and isinstance(item, dict):
            name = item.get('name')
        if not name:
            continue
        if name.startswith(train_rel + '/') and name.endswith(original_suffix):
            originals.append(name)
    originals.sort()

    if not originals:
        raise RuntimeError('No training originals found in competition file listing.')

    selected_originals = originals[:wanted_pairs]
    required = set(selected_originals)
    required.update(name.replace('_original.jpg', '_generated.jpg') for name in selected_originals)

    to_download = [
        rel_path
        for rel_path in sorted(required)
        if not os.path.exists(os.path.join(download_root, rel_path))
    ]

    print(
        f'[fallback] Downloading {len(to_download)} files '
        f'for {len(selected_originals)} training pairs...'
    )

    for i, rel_path in enumerate(to_download, start=1):
        local_dir = os.path.join(download_root, os.path.dirname(rel_path))
        os.makedirs(local_dir, exist_ok=True)
        file_name = os.path.basename(rel_path)

        api.competition_download_file(
            competition=competition,
            file_name=rel_path,
            path=local_dir,
            quiet=True,
        )

        # If the download arrived as '<file>.zip', unpack it in place and drop the archive.
        downloaded_path = os.path.join(local_dir, file_name)
        zip_candidate = downloaded_path + '.zip'
        if os.path.exists(zip_candidate):
            with zipfile.ZipFile(zip_candidate, 'r') as zf:
                zf.extractall(local_dir)
            os.remove(zip_candidate)

        if i % 200 == 0 or i == len(to_download):
            print(f'[fallback] Downloaded {i}/{len(to_download)} files')

    return download_root


class AutoHDRDataset(Dataset):
    """Dataset of image records for training, validation or inference.

    Each entry of ``samples`` is a dict with an ``'original'`` image path plus
    either a ``'generated'`` path (modes ``'train'`` / ``'val'``) or a
    ``'filename'`` (any other mode).

    In paired modes the transform is applied to input and target with the same
    torch random state, so random augmentations stay pixel-aligned between the two.
    """

    def __init__(self, samples: List[dict], mode: str, transform=None):
        self.samples = samples
        self.mode = mode
        self.transform = transform

    def __len__(self) -> int:
        return len(self.samples)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the underlying file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def _transform_pair(self, original, generated):
        """Apply ``self.transform`` to both images using identical random draws.

        The torch CPU RNG state is captured before the first call and restored
        before the second, so both calls see the same random numbers; afterwards
        the generator has advanced exactly as if one transform had run.
        NOTE(review): this only synchronises randomness drawn from torch's
        default generator - confirm the transforms in use do not rely on another RNG.
        """
        if not self.transform:
            return original, generated
        rng_state = torch.get_rng_state()
        original = self.transform(original)
        torch.set_rng_state(rng_state)
        generated = self.transform(generated)
        return original, generated

    def __getitem__(self, idx: int) -> dict:
        item = self.samples[idx]
        original = self._load_rgb(item['original'])

        if self.mode in ('train', 'val'):
            generated = self._load_rgb(item['generated'])
            original, generated = self._transform_pair(original, generated)
            return {'original': original, 'generated': generated}

        if self.transform:
            original = self.transform(original)
        return {'original': original, 'filename': item['filename']}


def build_sample_lists(train_dir: str, test_dir: Optional[str], max_train=None, max_val=None):
    """Build the train, validation and test sample records.

    Training pairs are the ``*_original.jpg`` files in ``train_dir`` that have a
    matching ``*_generated.jpg``; sorted by path, the first 95% become the
    training split and the rest the validation split. ``max_train`` / ``max_val``
    truncate the splits when truthy. Test records are the ``.jpg`` files of
    ``test_dir`` (sorted), or an empty list when it is missing.

    Returns:
        ``(train_samples, val_samples, test_samples)`` as lists of dicts.
    """
    original_paths = sorted(glob.glob(os.path.join(train_dir, '*_original.jpg')))
    candidates = (
        (src, src.replace('_original.jpg', '_generated.jpg'))
        for src in original_paths
    )
    pairs = [
        {'original': src, 'generated': dst}
        for src, dst in candidates
        if os.path.exists(dst)
    ]

    cut = int(len(pairs) * 0.95)
    train_samples, val_samples = pairs[:cut], pairs[cut:]
    if max_train:
        train_samples = train_samples[:max_train]
    if max_val:
        val_samples = val_samples[:max_val]

    if test_dir and os.path.isdir(test_dir):
        test_samples = [
            {'original': os.path.join(test_dir, name), 'filename': name}
            for name in sorted(os.listdir(test_dir))
            if name.endswith('.jpg')
        ]
    else:
        test_samples = []

    return train_samples, val_samples, test_samples


def make_dataloaders(train_dir: str, test_dir: Optional[str], batch_size=8, img_size=256, num_workers=2, max_train=None, max_val=None):
    """Create the training and validation data loaders.

    Args:
        train_dir: Directory containing the ``*_original.jpg`` / ``*_generated.jpg`` pairs.
        test_dir: Optional test directory; only used to report the test sample count.
        batch_size: Batch size for both loaders.
        img_size: Images are resized to ``img_size`` x ``img_size``.
        num_workers: Worker processes per loader.
        max_train: Optional cap on the number of training samples.
        max_val: Optional cap on the number of validation samples.

    Returns:
        ``(train_loader, val_loader)``; the training loader shuffles, the
        validation loader does not.

    Raises:
        ValueError: If no training samples are available.
    """
    train_tfm = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.02),
        transforms.ToTensor(),
    ])
    eval_tfm = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
    ])

    train_samples, val_samples, test_samples = build_sample_lists(train_dir, test_dir, max_train=max_train, max_val=max_val)
    print(f'Train samples: {len(train_samples):,}')
    print(f'Val samples: {len(val_samples):,}')
    print(f'Test samples: {len(test_samples):,}')

    if not train_samples:
        # A shuffling DataLoader over an empty dataset fails with a message that
        # does not mention the data directory; fail here with an actionable one.
        raise ValueError(
            f'No training pairs found in {train_dir!r}: expected *_original.jpg files '
            'with matching *_generated.jpg files.'
        )

    kwargs = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        # Pinned host memory only helps when batches are copied to a CUDA device.
        'pin_memory': torch.cuda.is_available(),
    }

    train_loader = DataLoader(AutoHDRDataset(train_samples, 'train', train_tfm), shuffle=True, **kwargs)
    val_loader = DataLoader(AutoHDRDataset(val_samples, 'val', eval_tfm), shuffle=False, **kwargs)
    return train_loader, val_loader


class ConvBlock(nn.Module):
    """Two consecutive 3x3 convolution -> batch norm -> ReLU stages.

    The first stage maps ``in_ch`` to ``out_ch`` channels, the second keeps
    ``out_ch``. Padding of 1 preserves the spatial size; the convolutions have no bias.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        stages = []
        for stage_in in (in_ch, out_ch):
            stages.append(nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1, bias=False))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        features = self.block(x)
        return features


class MicroUNet(nn.Module):
    """Small U-Net with four down/up levels, channel widths 16-32-64-128-256.

    The encoder halves the resolution with 2x2 max pooling between blocks; the
    decoder doubles it with nearest-neighbour upsampling followed by a 3x3
    convolution, then concatenates the matching encoder features. The output
    goes through a sigmoid, so values lie in (0, 1).

    Inputs whose height or width is not a multiple of 16 are supported: when
    an upsampled map and its skip connection differ in size (pooling floors odd
    sizes), the upsampled map is resized to the skip connection's size.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        self.enc1 = ConvBlock(in_channels, 16)
        self.enc2 = ConvBlock(16, 32)
        self.enc3 = ConvBlock(32, 64)
        self.enc4 = ConvBlock(64, 128)
        self.bottleneck = ConvBlock(128, 256)
        # Each decoder ConvBlock takes twice the up-conv's channels because of the skip concat.
        self.up4 = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(256, 128, 3, padding=1))
        self.dec4 = ConvBlock(256, 128)
        self.up3 = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(128, 64, 3, padding=1))
        self.dec3 = ConvBlock(128, 64)
        self.up2 = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(64, 32, 3, padding=1))
        self.dec2 = ConvBlock(64, 32)
        self.up1 = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(32, 16, 3, padding=1))
        self.dec1 = ConvBlock(32, 16)
        self.out_conv = nn.Conv2d(16, out_channels, 1)
        self.pool = nn.MaxPool2d(2)

    @staticmethod
    def _join(upsampled, skip):
        """Concatenate decoder and encoder features, aligning spatial size first."""
        if upsampled.shape[-2:] != skip.shape[-2:]:
            upsampled = nn.functional.interpolate(upsampled, size=skip.shape[-2:], mode='nearest')
        return torch.cat([upsampled, skip], 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        b = self.bottleneck(self.pool(e4))
        d4 = self.dec4(self._join(self.up4(b), e4))
        d3 = self.dec3(self._join(self.up3(d4), e3))
        d2 = self.dec2(self._join(self.up2(d3), e2))
        d1 = self.dec1(self._join(self.up1(d2), e1))
        return torch.sigmoid(self.out_conv(d1))


class SobelFilter(nn.Module):
    """Gradient magnitude of the luminance of an RGB batch using 3x3 Sobel kernels.

    ``forward`` reads channels 0, 1 and 2 of ``img`` as R, G and B, converts
    them to a single grey channel, and returns ``sqrt(gx^2 + gy^2 + 1e-6)``
    with one output channel and the same spatial size (zero padding of 1).
    """

    def __init__(self):
        super().__init__()
        gx = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]) / 4.0
        gy = torch.tensor([[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]]) / 4.0
        # Buffers follow the module through .to()/.cuda(); persistent=False keeps
        # them out of state_dict(), which stays empty as before.
        self.register_buffer('gx', gx.view(1, 1, 3, 3), persistent=False)
        self.register_buffer('gy', gy.view(1, 1, 3, 3), persistent=False)

    def forward(self, img):
        # No-op when the module already lives on the input's device with its dtype;
        # otherwise this avoids a device/dtype mismatch inside conv2d.
        gx = self.gx.to(device=img.device, dtype=img.dtype)
        gy = self.gy.to(device=img.device, dtype=img.dtype)
        gray = 0.2989 * img[:, 0:1] + 0.5870 * img[:, 1:2] + 0.1140 * img[:, 2:3]
        grad_x = nn.functional.conv2d(gray, gx, padding=1)
        grad_y = nn.functional.conv2d(gray, gy, padding=1)
        # The epsilon keeps the square root differentiable where the gradient is zero.
        return torch.sqrt(grad_x**2 + grad_y**2 + 1e-6)


class CombinedLoss(nn.Module):
    """Weighted sum of a pixel-wise L1 loss and an L1 loss on Sobel edge maps.

    ``loss = l1_weight * L1(pred, target)
    + edge_weight * L1(sobel(pred), sobel(target))``
    """

    def __init__(self, l1_weight=1.0, edge_weight=0.5):
        super().__init__()
        self.l1_weight = l1_weight
        self.edge_weight = edge_weight
        self.l1 = nn.L1Loss()
        self.sobel = SobelFilter()

    def forward(self, pred, target):
        pixel_term = self.l1(pred, target)
        pred_edges = self.sobel(pred)
        target_edges = self.sobel(target)
        edge_term = self.l1(pred_edges, target_edges)
        return self.l1_weight * pixel_term + self.edge_weight * edge_term


def get_device():
    """Pick the compute device: CUDA if available, else Apple MPS, else CPU."""
    if torch.cuda.is_available():
        name = 'cuda'
    else:
        # torch builds without MPS support have no ``torch.backends.mps`` attribute.
        mps_backend = getattr(torch.backends, 'mps', None)
        name = 'mps' if mps_backend is not None and mps_backend.is_available() else 'cpu'
    return torch.device(name)


def train_one_epoch(model, loader, criterion, optimizer, device, scaler=None):
    """Run one training pass over ``loader`` and return the mean loss per sample.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable of batches, each a dict with ``'original'`` (input)
            and ``'generated'`` (target) tensors.
        criterion: Callable ``(pred, target) -> scalar loss``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        scaler: Optional gradient scaler; when given, the forward pass runs
            under CUDA autocast and the step goes through the scaler.

    Returns:
        Loss averaged over all samples (each batch weighted by its size, so a
        shorter final batch is not over-represented); ``0.0`` for an empty loader.
    """
    model.train()
    loss_sum = 0.0
    seen = 0
    for batch in loader:
        x = batch['original'].to(device, non_blocking=True)
        y = batch['generated'].to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        if scaler is not None:
            with torch.amp.autocast('cuda'):
                pred = model(x)
                loss = criterion(pred, y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()
            optimizer.step()
        batch_size = x.size(0)
        loss_sum += loss.item() * batch_size
        seen += batch_size
    return loss_sum / max(seen, 1)


@torch.no_grad()
def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        loader: Iterable of batches, each a dict with ``'original'`` (input)
            and ``'generated'`` (target) tensors.
        criterion: Callable ``(pred, target) -> scalar loss``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(mean_loss, mean_abs_error_255)`` averaged over all samples, where the
        second value is the mean absolute difference scaled by 255. Each batch is
        weighted by its size so a shorter final batch does not skew the result.
        Both values are ``0.0`` for an empty loader.
    """
    model.eval()
    loss_sum = 0.0
    mae_sum = 0.0
    seen = 0
    for batch in loader:
        x = batch['original'].to(device, non_blocking=True)
        y = batch['generated'].to(device, non_blocking=True)
        pred = model(x)
        loss = criterion(pred, y)
        mae_255 = (pred - y).abs().mean().item() * 255.0
        batch_size = x.size(0)
        loss_sum += loss.item() * batch_size
        mae_sum += mae_255 * batch_size
        seen += batch_size
    denom = max(seen, 1)
    return loss_sum / denom, mae_sum / denom


def main():
    """Train MicroUNet end to end and write artifacts to the output directory.

    Steps: parse command-line options (unknown arguments are ignored so the
    function also works inside a notebook kernel), locate the dataset (falling
    back to a Kaggle API download if discovery fails), build the data loaders,
    train for ``--epochs`` epochs with AdamW and cosine learning-rate annealing,
    and save ``best_model.pt``, periodic ``checkpoint_epoch{N}.pt`` files and
    ``training_history.json``.
    """
    parser = argparse.ArgumentParser(description='AutoHDR Kaggle MicroUNet training')
    parser.add_argument('--data-root', default='/kaggle/input')
    parser.add_argument('--output-dir', default='/kaggle/working')
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--img-size', type=int, default=256)
    parser.add_argument('--num-workers', type=int, default=2)
    parser.add_argument('--max-train', type=int, default=None)
    parser.add_argument('--max-val', type=int, default=None)
    parser.add_argument('--save-every', type=int, default=2)
    parser.add_argument('--edge-weight', type=float, default=0.5)
    parser.add_argument('--no-amp', action='store_true')
    parser.add_argument('--fallback-max-pairs', type=int, default=3000)
    args, _ = parser.parse_known_args()

    os.makedirs(args.output_dir, exist_ok=True)

    device = get_device()
    print(f'Device: {device}')
    print(f'PyTorch: {torch.__version__}')
    if device.type == 'cuda':
        print(f'GPU: {torch.cuda.get_device_name(0)}')

    print('Locating dataset...')
    try:
        train_dir, test_dir = find_dataset_paths(args.data_root)
    except Exception as exc:
        # Any discovery failure triggers the download fallback; the error is
        # reported rather than swallowed.
        print(f'[fallback] Primary discovery failed: {exc}')
        dl_root = ensure_kaggle_cli_data(
            '/kaggle/working/autohdr_competition_data',
            max_pairs=args.fallback_max_pairs,
        )
        train_dir, test_dir = find_dataset_paths(dl_root)

    print(f'Train dir: {train_dir}')
    print(f'Test dir: {test_dir}')

    train_loader, val_loader = make_dataloaders(
        train_dir=train_dir,
        test_dir=test_dir,
        batch_size=args.batch_size,
        img_size=args.img_size,
        num_workers=args.num_workers,
        max_train=args.max_train,
        max_val=args.max_val,
    )

    model = MicroUNet().to(device)
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Model: MicroUNet ({params:,} params)')

    criterion = CombinedLoss(l1_weight=1.0, edge_weight=args.edge_weight).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    # Mixed precision is only enabled on CUDA and can be disabled with --no-amp.
    use_amp = device.type == 'cuda' and not args.no_amp
    scaler = torch.amp.GradScaler('cuda') if use_amp else None
    if use_amp:
        print('Mixed precision: ON')

    history = []
    best_val_loss = float('inf')
    history_path = os.path.join(args.output_dir, 'training_history.json')

    def save_history():
        with open(history_path, 'w') as f:
            json.dump(history, f, indent=2)

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        # Read the learning rate before stepping the scheduler so the log and the
        # history record the rate this epoch actually trained with.
        lr = optimizer.param_groups[0]['lr']
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device, scaler=scaler)
        val_loss, val_mae = validate(model, val_loader, criterion, device)
        scheduler.step()
        elapsed = time.time() - t0

        improved = val_loss < best_val_loss
        marker = ' ★' if improved else ''
        print(
            f'Epoch {epoch:3d}/{args.epochs} | '
            f'Train: {train_loss:.5f} | '
            f'Val: {val_loss:.5f} | '
            f'MAE(255): {val_mae:.2f} | '
            f'LR: {lr:.6f} | '
            f'{elapsed:.1f}s{marker}'
        )

        entry = {
            'epoch': epoch,
            'train_loss': round(train_loss, 6),
            'val_loss': round(val_loss, 6),
            'val_mae_255': round(val_mae, 4),
            'lr': round(lr, 10),
            'time_s': round(elapsed, 3),
            'timestamp_utc': datetime.now(timezone.utc).isoformat(),
        }
        history.append(entry)
        # Persist after every epoch so an interrupted session keeps its history.
        save_history()

        ckpt = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'val_mae_255': val_mae,
            'model_name': 'micro_unet',
            'img_size': args.img_size,
            'normalize': False,
            'created_at': datetime.now(timezone.utc).isoformat(),
            'args': vars(args),
        }

        if improved:
            best_val_loss = val_loss
            torch.save(ckpt, os.path.join(args.output_dir, 'best_model.pt'))

        # max(..., 1) guards against --save-every 0 causing a modulo-by-zero.
        if epoch % max(args.save_every, 1) == 0:
            torch.save(ckpt, os.path.join(args.output_dir, f'checkpoint_epoch{epoch}.pt'))

    # Also written when no epoch ran, so the file always exists afterwards.
    save_history()

    print('=' * 70)
    print(f'Training complete. Best val loss: {best_val_loss:.6f}')
    print(f'Artifacts in: {args.output_dir}')
    print('=' * 70)


# Start training only when this code is executed directly, not when it is imported.
if __name__ == '__main__':
    main()
