# Fusion Model & PM2.5 Predictions

This notebook instantiates the **FusionModel** (Vim + SocialPointEncoder + SparseFusion + regression head), runs a forward pass on dummy inputs (real satellite and social data can be substituted), and visualizes the resulting **PM2.5 prediction map** over Delhi NCR.

## 1. Imports and model setup

In [ ]:
import sys
from pathlib import Path

# Locate the project root so that the `src` package can be imported: when the
# kernel runs from inside a `notebooks/` directory the root is its parent,
# otherwise the current working directory is taken as the root.
ROOT = Path.cwd()
if ROOT.name == "notebooks":
    ROOT = ROOT.parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import torch
import numpy as np
import matplotlib.pyplot as plt

from src.model.fusion_model import FusionModel

# Use CUDA when PyTorch reports it as available, else fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Constructor arguments collected in one mapping so they are easy to tweak.
model_kwargs = dict(
    in_channels=3,
    satellite_dim=768,
    patch_size=4,
    vim_depth=4,
    grid_height=64,
    grid_width=64,
    clip_dim=512,
    social_dim=768,
    num_heads=8,
)
# NOTE(review): no checkpoint is restored in this cell, so the weights are
# whatever the constructor initialises — confirm whether trained weights
# should be loaded before interpreting the predictions.
model = FusionModel(**model_kwargs).to(device)
# Presumably switches off training-only behaviour (assumes FusionModel is a
# torch.nn.Module — TODO confirm).
model.eval()
print("FusionModel loaded.")

## 2. Dummy inputs (or load real satellite + social)

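For inference you need: **satellite** (B, 3, 64, 64), **social_points** (B, N, 512), **social_lat_lon** (B, N, 2), and **roi** (a dict with `min_lon`, `min_lat`, `max_lon`, `max_lat` bounds). Note that the cell below fills the last axis of `social_lat_lon` as (longitude, latitude).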

In [ ]:
# Batch size and number of social points per sample for the dummy forward pass.
B, N = 2, 20
# Bounding box (degrees) used both to place the dummy points and as the
# model's `roi` argument.
roi = {"min_lon": 76.8, "min_lat": 28.4, "max_lon": 77.5, "max_lat": 28.9}

# A dedicated, seeded generator makes the dummy inputs identical every time
# this cell is run, without altering PyTorch's global RNG state.
INPUT_SEED = 0
input_rng = torch.Generator(device=device).manual_seed(INPUT_SEED)

# Dummy satellite tensor: Gaussian noise scaled by 0.1 and shifted to 0.5.
satellite = torch.randn(B, 3, 64, 64, device=device, generator=input_rng) * 0.1 + 0.5
# Dummy per-point social features: Gaussian noise scaled by 0.1.
social_points = torch.randn(B, N, 512, device=device, generator=input_rng) * 0.1

# Point coordinates drawn uniformly inside the ROI.
# NOTE(review): channel 0 holds longitude and channel 1 latitude, despite the
# `lat_lon` name — confirm this is the order FusionModel expects.
social_lat_lon = torch.zeros(B, N, 2, device=device)
social_lat_lon[..., 0] = roi["min_lon"] + (roi["max_lon"] - roi["min_lon"]) * torch.rand(
    B, N, device=device, generator=input_rng
)
social_lat_lon[..., 1] = roi["min_lat"] + (roi["max_lat"] - roi["min_lat"]) * torch.rand(
    B, N, device=device, generator=input_rng
)

# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    pm25_map = model(satellite, social_points, social_lat_lon, roi)

print("PM2.5 map shape:", pm25_map.shape)

## 3. Visualize PM2.5 prediction map

In [ ]:
# First sample of the batch, copied to host memory as a NumPy array.
# Assumes pm25_map[0] is a 2-D array that imshow can draw directly — TODO confirm.
pred = pm25_map[0].cpu().numpy()

fig, ax = plt.subplots(figsize=(10, 8))
# origin="lower" puts row 0 of the array at the bottom of the image.
im = ax.imshow(pred, origin="lower", aspect="equal", cmap="YlOrRd")
ax.set_title("PM2.5 prediction map (sample 0) — Delhi NCR 64×64 grid")
cbar = fig.colorbar(im, ax=ax, label="PM2.5 (predicted)")
fig.tight_layout()
plt.show()

## 4. Overlay on ROI (geographic view)

In [ ]:
from matplotlib import cm  # not used in this cell; retained in case other cells rely on `cm`

# Dark-themed figure: axes and figure share the same navy background.
fig, ax = plt.subplots(1, 1, figsize=(10, 8))
ax.set_facecolor("#1a1a2e")
fig.patch.set_facecolor("#1a1a2e")

# Map array columns/rows onto the ROI's longitude/latitude range so the axes
# read in degrees; origin="lower" puts row 0 at min_lat.
extent = [roi["min_lon"], roi["max_lon"], roi["min_lat"], roi["max_lat"]]
im = ax.imshow(pred, extent=extent, cmap="YlOrRd", aspect="auto", origin="lower", alpha=0.85)
ax.set_xlim(roi["min_lon"], roi["max_lon"])
ax.set_ylim(roi["min_lat"], roi["max_lat"])
ax.set_xlabel("Longitude", color="#eee")
ax.set_ylabel("Latitude", color="#eee")
ax.tick_params(colors="#aaa")
ax.set_title("PM2.5 prediction over Delhi NCR", color="white", fontsize=14)

# The colorbar label would otherwise use the default (dark) text colour and
# vanish against the dark figure background, so colour it like the axis labels.
roi_cbar = fig.colorbar(im, ax=ax, shrink=0.7)
roi_cbar.set_label("PM2.5", color="#eee")
roi_cbar.ax.tick_params(colors="#aaa")

fig.tight_layout()
plt.show()

## 5. Batch view (multiple samples)

In [ ]:
# One panel per batch sample, laid out in a single row.
# squeeze=False always yields a 2-D array of Axes, so a batch of one needs no
# special case; row 0 is the only row.
fig, axes = plt.subplots(1, B, figsize=(5 * B, 5), squeeze=False)
axes = axes[0]
for i, ax in enumerate(axes):
    p = pm25_map[i].cpu().numpy()
    # origin="lower" keeps the orientation consistent with the single-sample
    # and ROI views, which both draw row 0 at the bottom.
    ax.imshow(p, cmap="inferno", aspect="equal", origin="lower")
    ax.set_title(f"Sample {i}")
fig.suptitle("PM2.5 prediction maps (batch)")
fig.tight_layout()
plt.show()