In [None]:
# ---------------------------------------------
# IMPORTS
# Core PyTorch + numerical and imaging utilities for Gaussian Splatting demo.
# ---------------------------------------------
import torch
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import math
import numpy as np
from torchvision import transforms
from PIL import Image
from MS_SSIM_L1_loss import MS_SSIM_L1_LOSS

In [None]:
# ---------------------------------------------
# DEVICE SELECTION
# Select GPU if available for faster splatting + optimization.
# ---------------------------------------------
# Prefer CUDA when PyTorch can see a GPU; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Bare expression: echoes the chosen device as the notebook cell output.
device

# Loading data
We load a reference target image from a remote URL and convert it to a normalized 100x100 grayscale tensor. This target acts as supervision for our Gaussian parameter optimization. Keeping dimensions small accelerates iterations and illustrates the principle without heavy compute.

In [None]:
# ---------------------------------------------
# DATA LOADING FROM URL
# Robustly resolve ibb.co page URLs to a direct image.
# Tries common meta tags (og:image, twitter:image) and <img src> fallbacks.
# ---------------------------------------------
import re
import requests
from io import BytesIO

image_url = 'https://ibb.co/wNd95BZn'  # Page that hosts the target image; fetch_image_from_url resolves it to the image itself

def _first_match(html: str, patterns):
    for pat in patterns:
        m = re.search(pat, html, flags=re.IGNORECASE)
        if m:
            # Return first non-empty captured group
            for g in m.groups():
                if g:
                    return g
    return None

def fetch_image_from_url(url: str) -> Image.Image:
    """Download an image, resolving an HTML landing page to its image if needed.

    If the response's Content-Type mentions ``image`` the body is opened
    directly. Otherwise the body is treated as HTML and searched, in order of
    preference, for an Open Graph image, a Twitter card image, and an
    ``<img>`` tag pointing at ``i.ibb.co``; the first hit is downloaded.

    Args:
        url: Direct image URL, or URL of an HTML page that embeds the image.

    Returns:
        The image as returned by ``PIL.Image.open``.

    Raises:
        requests.RequestException: If a request fails or returns an error
            status (via ``raise_for_status``).
        RuntimeError: If no image URL can be extracted from the page, or the
            extracted URL does not answer with an image Content-Type.
    """
    # Function-scope imports so the notebook's import cell stays untouched.
    from html import unescape
    from urllib.parse import urljoin

    resp = requests.get(url, timeout=30)
    resp.raise_for_status()
    content_type = resp.headers.get('content-type', '').lower()

    # Case 1: the URL already points at an image.
    if 'image' in content_type:
        return Image.open(BytesIO(resp.content))

    # Case 2: HTML page (e.g., ibb.co) -> look for a direct image URL.
    page_html = resp.text
    # Pattern groups in order of preference; each group covers both
    # attribute orders of the tag it targets.
    pattern_groups = (
        # Open Graph image
        (
            r'<meta[^>]+property=["\']og:image["\'][^>]+content=["\']([^"\']+)["\']',
            r'<meta[^>]+content=["\']([^"\']+)["\'][^>]+property=["\']og:image["\']',
        ),
        # Twitter card image
        (
            r'<meta[^>]+name=["\']twitter:image["\'][^>]+content=["\']([^"\']+)["\']',
            r'<meta[^>]+content=["\']([^"\']+)["\'][^>]+name=["\']twitter:image["\']',
        ),
        # Fallback: any <img> referencing i.ibb.co (the direct CDN)
        (
            r'<img[^>]+src=["\'](https?://i\.ibb\.co/[^"\']+)["\']',
            r'<img[^>]+data-src=["\'](https?://i\.ibb\.co/[^"\']+)["\']',
        ),
    )
    direct_img_url = None
    for patterns in pattern_groups:
        direct_img_url = _first_match(page_html, patterns)
        if direct_img_url:
            break  # Stop at the most preferred source that yields a URL
    if not direct_img_url:
        raise RuntimeError('Could not resolve a direct image URL from the provided page.')

    # Attribute values are HTML-escaped (e.g. "&amp;") and may be relative or
    # protocol-relative ("//host/path"). Normalise against the final page URL
    # so requests always receives an absolute URL; absolute, unescaped URLs
    # pass through unchanged.
    direct_img_url = urljoin(resp.url, unescape(direct_img_url))

    img_resp = requests.get(direct_img_url, timeout=30)
    img_resp.raise_for_status()
    if 'image' not in img_resp.headers.get('content-type', '').lower():
        raise RuntimeError('Resolved URL did not return an image content-type.')
    return Image.open(BytesIO(img_resp.content))

# Download the supervision image (a PIL image) from the configured URL
image = fetch_image_from_url(image_url)

# Preprocessing pipeline: single channel -> fixed 100x100 -> float tensor
transform = transforms.Compose(
    [
        transforms.Grayscale(),          # RGB/RGBA -> 1 channel
        transforms.Resize((100, 100)),   # fixed demo resolution
        transforms.ToTensor(),           # -> [C,H,W] float in [0,1]
    ]
)
target_image = transform(image)  # Shape: [1,100,100]

# Preview the target without axes
plt.figure(figsize=(3, 3))
plt.imshow(target_image.squeeze(0), cmap='gray')
plt.axis('off')
plt.show()

# Ejercicio 1: Motor de render con Gaussian Splatting
En este práctico vas a implementar partes clave del motor de render 2D basado en la suma aditiva de gaussianas anisotrópicas.

Cada gaussiana ("splat") tiene:
- Centro `(x, y)` en espacio normalizado `[0,1]^2`.
- Escalas `(sigma_x, sigma_y)` que controlan el spread principal.
- Rotación `theta` (radianes) que orienta la anisotropía.
- Alpha (peso de contribución u "opacidad" relativa).

La matriz de covarianza combina escalas + rotación para definir el óvalo (elipse) de cada gaussiana. El render final surge de evaluar las PDFs multivariadas sobre una grilla de píxeles y acumularlas.

Tu objetivo hoy:
1. Completar `build_covariance_matrix` para obtener `cov_mat = R @ S @ S^T @ R^T`.
2. Completar `create_gaussian_image`:
   - Construir matrices de rotación.
   - Crear grilla de coordenadas y vectorizar evaluación.
   - Instanciar la distribución multivariada y acumular contribuciones.
   - (Opcional) Normalizar a `[0,1]` al final.
3. Completar el loop de entrenamiento con cálculo de pérdidas, backward y snapshots.

Recomendaciones:
- Evitá loops Python cuando puedas vectorizar (excepto el loop por gaussiana ya provisto si decidís mantenerlo para claridad).
- Usá `torch.distributions.MultivariateNormal` correctamente (mean en coordenadas de píxel, covarianza 2x2).
- Revisá que no haya gradientes bloqueados (tensores deben tener `requires_grad` donde corresponde).

Cuando termines cada sección, probá render inicial antes de pasar al entrenamiento.
¡Adelante!

In [None]:
# ---------------------------------------------
# GAUSSIAN PRIMITIVE + RENDERING UTILITIES (EJERCICIO)
# En esta versión, varias partes están como "placeholders" para completar.
# Líneas clave quedan pre-completadas; seguí las instrucciones TODO.
# ---------------------------------------------


def build_covariance_matrix(S, R):
    """
    EXERCISE: build the 2x2 covariance matrix of each Gaussian from its
    diagonal scales and its rotation.

    Hints:
    - Use torch.diag_embed to go from [N,2] -> [N,2,2]
    - Remember that, if S is diagonal, S @ S^T = diag(sigma_x^2, sigma_y^2)
    - Target formula: cov = R @ S @ S^T @ R^T
    - Return a tensor of shape [N,2,2]

    Args:
        S: per-Gaussian scales; the hints imply shape [N,2] — confirm when implementing.
        R: per-Gaussian rotation matrices; presumably shape [N,2,2] — confirm when implementing.

    Raises:
        NotImplementedError: always, until the exercise is completed.
    """

    # Placeholder: replace with the computation of 'cov_mat' described above.
    raise NotImplementedError("Completar: calcular 'cov_mat' y retornarlo")
    # return cov_mat



def create_gaussian_image(centers, scales, rotations, alphas, image_size):
    """
    EXERCISE: render a 2D image as a sum of anisotropic Gaussians.
    Complete the steps marked TODO. Returns a tensor of shape [1, H, W].

    Args:
        centers: Gaussian centers, assumed to lie in [0,1] (used unchanged).
        scales: raw scale values; passed through softplus before use.
        rotations: rotation angle per Gaussian, in radians.
        alphas: raw blending weights; passed through sigmoid before use.
        image_size: side length of the square output image, in pixels.

    Note:
        As provided, the function cannot run: 'rotation_matrices' (and the
        'coords' grid mentioned in the loop hints) must be defined by the
        TODO sections first.
    """
    # Pre-filled: allocate the output image on the selected device
    # Original note: created with requires_grad=True so the optimizer can backpropagate
    # NOTE(review): gradients should reach the parameters through the terms added
    # in the loop below, so requires_grad on this zero buffer looks unnecessary — confirm.
    image = torch.zeros(image_size, image_size, requires_grad=True, dtype=torch.float32).to(device)

    # Pre-filled: transforms that keep the parameters in valid ranges
    adjusted_centers = centers               # Assumed to be in [0,1]
    adjusted_scales  = F.softplus(scales)    # Softplus -> positive scales
    adjusted_alphas  = F.sigmoid(alphas)     # Sigmoid -> weights in [0,1]

    # TODO: build the 2x2 rotation matrices from the 'rotations' angle (in radians)
    #       and store them in 'rotation_matrices' (the name used just below)


    # Keep this line: combines scale + rotation into the covariance
    cov_matrices = build_covariance_matrix(adjusted_scales, rotation_matrices).to(device)

    # TODO: create the grid of pixel coordinates and vectorize it
    # Hint: the loop below expects the result in a variable named 'coords'

    # Keep: loop header that accumulates the contribution of each Gaussian
    for center, cov_mat, alpha in zip(adjusted_centers, cov_matrices, adjusted_alphas):
        # TODO:
        # 1) Define a MultivariateNormal with mean=(center * image_size) and covariance_matrix=cov_mat
        # 2) Evaluate log_prob on 'coords' and convert to probabilities with torch.exp(...)
        # 3) Reshape to [H,W] and add to 'image' with: image = image + (alpha * probs)
        raise NotImplementedError("Completar: evaluación de gaussianas y composición aditiva")

    # TODO normalize the image to [0,1] to stabilize the scale of the loss
    # image = image / image.max()

    return image.unsqueeze(0)  # Shape: [1, H, W]


# Optimizer loop
We initialize N Gaussian primitives with random parameters. During training we minimize a hybrid loss:
- L1 (pixel-wise) to encourage accurate intensity reconstruction.
- MS-SSIM-based perceptual component for structural fidelity.
Parameters optimized: centers, scales (through softplus), alphas (sigmoid), rotations (angle). Saving intermediate renders every 100 iterations illustrates convergence dynamics.

In [None]:
# ---------------------------------------------
# PARAMETER INITIALIZATION
# Randomly initialize Gaussian parameters. For teaching:
# - centers in [0,1]^2 (multiplied by image_size inside render)
# - scales (pre-softplus) roughly positive after transform
# - alphas random then sigmoid -> blending weight
# - rotations in [0,1] mapped directly to radians (could scale to 2π)
# ---------------------------------------------
gcount = 500      # How many Gaussian splats are optimized
image_size = 100  # Side length (pixels) of the square render

# One random value per splat, reused for both axes so every splat starts
# isotropic; x and y occupy separate columns of `scales` and can change
# independently during optimization.
isotropic_scales_init_x = torch.rand(gcount, requires_grad=True, device=device)
isotropic_scales_init_y = isotropic_scales_init_x

# Trainable parameters, stored in float64
centers = torch.nn.Parameter(
    torch.rand((gcount, 2), requires_grad=True, device=device).double()
)
scales = torch.nn.Parameter(
    10 * torch.stack((isotropic_scales_init_x, isotropic_scales_init_y), dim=1).double()
)
alphas = torch.nn.Parameter(
    torch.randn(gcount, 1, requires_grad=True, device=device).double()
)
rotations = torch.nn.Parameter(
    torch.rand(gcount, requires_grad=True, device=device).double()
)

In [None]:
# ---------------------------------------------
# INITIAL RENDER (BEFORE OPTIMIZATION)
# Useful baseline to visualize random splat configuration.
# ---------------------------------------------
gaussian_image = create_gaussian_image(centers, scales, rotations, alphas, image_size)
# Keep a graph-free copy for the later side-by-side comparison. A plain
# clone() would keep the autograd graph of this first render alive for the
# whole session even though the copy is only ever displayed.
gaussian_image_init = gaussian_image.detach().clone()

# Show the render produced by the random initial parameters
plt.figure(figsize=(3, 3))
plt.imshow(gaussian_image.squeeze().detach().cpu().numpy(), cmap='gray')
plt.axis('off')
plt.show()

In [None]:
# ---------------------------------------------
# OPTIMIZER + LOSS SETUP
# Adam optimizer over all Gaussian parameters.
# Loss mixing: majority weight on L1, minority on MS-SSIM perceptual term.
# ---------------------------------------------
num_iterations = 2500  # Feel free to reduce for classroom demos

# Adam updates all four Gaussian parameter tensors with one learning rate
optimizer = optim.Adam(params=[centers, scales, alphas, rotations], lr=0.01)

# Pixel-wise mean absolute error between render and target.
criterion_L1 = torch.nn.L1Loss(reduction='mean')

# Project-local loss imported from MS_SSIM_L1_loss. Per the notebook's notes
# it mixes multi-scale structural similarity (MS-SSIM) with an L1 term, is
# already a loss (lower is better) and expects [B, C, H, W] inputs on the
# same device — confirm against that module.
# See: https://en.wikipedia.org/wiki/Structural_similarity_index_measure
criterion_SSIM = MS_SSIM_L1_LOSS(device=device)  # [B,C,H,W]

## Ejercicio 2: Loop de Entrenamiento

Completar la celda del loop para optimizar los parámetros de las gaussianas. El objetivo es minimizar una pérdida híbrida que combine reconstrucción (L1) y percepción/estructura (MS-SSIM) sobre la imagen objetivo.

### Pasos obligatorios
1. Calcular `loss_l1`:
   - Usar `criterion_L1(gaussian_image, target_image)` (ajustar `squeeze()` y `.to(device)` según corresponda).
2. Calcular `loss_ssim`:
   - Usar `criterion_SSIM` con tensores `[B, C, H, W]` (agregar batch dim con `unsqueeze(0)` si hace falta).
3. Combinar pérdidas:
   - Sugerido: `loss = 0.8 * loss_l1 + 0.2 * loss_ssim` (podés experimentar otros pesos y discutirlo luego).
4. Backprop y actualización:
   - `loss.backward()` seguido de `optimizer.step()`.
5. Snapshots periódicos (cada 100 iteraciones, configurable):
   - Crear carpeta `status/` si no existe.
   - Guardar la imagen: `plt.imsave('status/iter_XXXX.png', ...)` usando tensor procesado (`detach()`, `clamp(0,1)`, `cpu()`).
6. Logging:
   - `print` de iteración, `loss` total y componentes (`loss_l1`, `loss_ssim`).
7. Mientras falten implementaciones mantener el `raise NotImplementedError` para evitar correr entrenamiento incompleto.

### Extensiones opcionales
- Scheduler de LR (`torch.optim.lr_scheduler`).
- Gradient clipping (`torch.nn.utils.clip_grad_norm_`).
- Early stopping por estabilización de la pérdida.
- Métrica extra: PSNR o MSE para comparación cuantitativa.
- Guardar GIF de convergencia (ej. con `imageio`).

### Checklist antes de entrenar
- [ ] `build_covariance_matrix` completo.
- [ ] `create_gaussian_image` completo.
- [ ] Loop con pérdidas, backward, step, snapshots y logging.

Cuando todo funcione: retirar el `raise NotImplementedError`.


In [None]:
# ---------------------------------------------
# TRAINING LOOP (EJERCICIO)
# Varias partes fueron reemplazadas por TODOs para que las completes.
# Se preservan llamadas clave y estructura general.
# ---------------------------------------------
for i in range(num_iterations):
    optimizer.zero_grad()  # Reset the gradients accumulated so far

    # Render the image from the current parameter values
    gaussian_image = create_gaussian_image(centers, scales, rotations, alphas, image_size)

    # TODO: compute the L1 loss between gaussian_image and target_image

    # TODO: compute the MS-SSIM perceptual loss using criterion_SSIM

    # TODO: combine the losses into a variable 'loss' (for example 0.8 * loss_l1 + 0.2 * loss_ssim)

    # TODO: call backward() to propagate gradients, then optimizer.step() to update the parameters

    # TODO: every x iterations save a snapshot in the 'status' folder (create it if missing)

    # TODO: print the metrics of the iteration (total loss and its components)

    # Placeholder: stops the loop on its first iteration until the TODOs above are done
    raise NotImplementedError("Completar loop de entrenamiento: cálculo de pérdidas, backward, snapshots y logging")

In [None]:
# ---------------------------------------------
# FINAL RENDER AFTER TRAINING
# ---------------------------------------------
# Re-render with the parameters as they stand after training
gaussian_image = create_gaussian_image(
    centers=centers,
    scales=scales,
    rotations=rotations,
    alphas=alphas,
    image_size=image_size,
)
plt.figure(figsize=(3, 3))
plt.imshow(gaussian_image.detach().squeeze().cpu().numpy(), cmap='gray')
plt.axis('off')
plt.show()

In [None]:
# Supervision target, shown for comparison with the renders
plt.figure(figsize=(3, 3))
plt.imshow(target_image.squeeze(), cmap='gray')
plt.axis('off')
plt.show()

In [None]:
# Baseline: the render captured before any optimization step
plt.figure(figsize=(3, 3))
plt.imshow(gaussian_image_init.detach().squeeze().cpu().numpy(), cmap='gray')
plt.axis('off')
plt.show()

In [None]:
# ---------------------------------------------
# SAVE TRAINED PARAMETERS
# Centers, scales (pre-softplus), alphas (pre-sigmoid), rotations.
# Re-load later to resume or visualize.
# ---------------------------------------------
# Persist the four raw parameter tensors under their own names
torch.save(
    dict(centers=centers, scales=scales, alphas=alphas, rotations=rotations),
    'gaussians.pth',
)

In [None]:
# ---------------------------------------------
# RELOAD PARAMETERS
# Useful for separate visualization pass or continued fine-tuning.
# ---------------------------------------------
# map_location remaps the stored tensors onto the device selected for this
# session, so a checkpoint written on a GPU also loads on a CPU-only machine
# and ends up on the same device the renderer allocates its image on.
checkpoint = torch.load('gaussians.pth', map_location=device)
centers   = checkpoint['centers']
scales    = checkpoint['scales']
alphas    = checkpoint['alphas']
rotations = checkpoint['rotations']

In [None]:
# ---------------------------------------------
# OPTIONAL: RESCALED SPLATS
# Example of post-processing: adjusting scales for sharper rendering.
# ---------------------------------------------
# Factor 1.0 leaves the scales unchanged; edit it to experiment. The factor
# is applied to the raw values, i.e. before the renderer's softplus.
shrinked_scales = torch.mul(scales, 1.0)
gaussian_image = create_gaussian_image(
    centers=centers,
    scales=shrinked_scales,
    rotations=rotations,
    alphas=alphas,
    image_size=image_size,
)
plt.figure(figsize=(3, 3))
plt.imshow(gaussian_image.detach().squeeze().cpu().numpy(), cmap='gray')
plt.axis('off')
plt.show()