In [3]:
# =========================
# Single-Band -> Rest HSI Recon (PyTorch)
# Daten: HSI (H,W,C) = (260,1500,49), wn (49,1) = [760..1240] nm
# =========================
import math, numpy as np, scipy.io
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader

torch.manual_seed(0)

# ---------- 1) Daten laden ----------
mat = scipy.io.loadmat("data/heatcube_0001.mat")
HSI = mat["HSI"]             # (H,W,C) float64
H, W, C = HSI.shape
HSI = torch.from_numpy(HSI.astype(np.float32))  # -> float32
# nach (C,H,W) transponieren
cube = HSI.permute(2,0,1).contiguous()          # (C,H,W)

# Wellenlängen laden (optional, für Auswahl per nm)
# Falls in anderer MAT-Datei:
# matLibWn = scipy.io.loadmat("data/wavelengths.mat")  # Beispiel
# wn = matLibWn["wn"].reshape(-1)  # (49,)
wn = mat.get("wn", None)
if wn is not None:
    wn = torch.from_numpy(wn.astype(np.int32)).view(-1)  # (49,)
else:
    # Fallback: gleichmäßige 10-nm Schritte ab 760
    wn = torch.arange(760, 760 + 10*C, 10, dtype=torch.int32)

# ---------- 2) Eingabeband wählen ----------
def band_index_from_nm(wn_vector, target_nm):
    # nimmt nächstgelegene Wellenlänge
    idx = int(torch.argmin(torch.abs(wn_vector - int(target_nm))).item())
    return idx

target_wavelength_nm = 1000  # <- hier ändern, z.B. 940, 1100, etc.
band_idx = band_index_from_nm(wn, target_wavelength_nm)
print(f"Eingabeband: {target_wavelength_nm} nm -> index {band_idx}, wn[idx]={int(wn[band_idx])}nm")

# ---------- 3) Patch-Datensatz ----------
class SingleBandToRestHSIPatches(Dataset):
    def __init__(self, cube_CHW, band_idx, patch_size=128, n_patches=2000, split="train", stats=None):
        """
        cube_CHW: (C,H,W) Tensor
        band_idx: int (Eingabeband)
        patch_size: Größe der Quadrate
        n_patches: Anzahl zufälliger Patches, die pro Epoch gezogen werden
        split: "train" oder "val"
        stats: (mins, maxs) für bandweise Min-Max (nur aus Train bestimmt)
        """
        self.cube = cube_CHW
        self.C, self.H, self.W = cube_CHW.shape
        self.k = band_idx
        self.ps = patch_size
        self.n_patches = n_patches if split == "train" else max(n_patches//8, 256)
        self.split = split

        if stats is None and split == "train":
            # bandweise Min/Max nur aus zufälligen Train-Patches schätzen
            mins = []; maxs = []
            with torch.no_grad():
                for _ in range(200):
                    y0 = torch.randint(0, self.H - self.ps + 1, (1,)).item()
                    x0 = torch.randint(0, self.W - self.ps + 1, (1,)).item()
                    patch = self.cube[:, y0:y0+self.ps, x0:x0+self.ps]
                    mn = patch.view(self.C,-1).min(dim=1).values
                    mx = patch.view(self.C,-1).max(dim=1).values
                    mins.append(mn); maxs.append(mx)
            mins = torch.stack(mins,0).min(0).values
            maxs = torch.stack(maxs,0).max(0).values
            self.mins, self.maxs = mins, maxs
        else:
            self.mins, self.maxs = (stats if stats is not None else (None, None))

    def __len__(self):
        return self.n_patches

    def get_stats(self):
        return self.mins, self.maxs

    def __getitem__(self, idx):
        # zufälliger Patch (Train/Val)
        y0 = torch.randint(0, self.H - self.ps + 1, (1,)).item()
        x0 = torch.randint(0, self.W - self.ps + 1, (1,)).item()
        patch = self.cube[:, y0:y0+self.ps, x0:x0+self.ps]  # (C,ps,ps)

        # bandweise Min-Max
        eps = 1e-8
        if self.mins is not None and self.maxs is not None:
            patch = (patch - self.mins.view(-1,1,1)) / (self.maxs.view(-1,1,1) - self.mins.view(-1,1,1) + eps)

        x = patch[self.k:self.k+1]  # (1,ps,ps)
        y = torch.cat([patch[:self.k], patch[self.k+1:]], dim=0)  # (C-1,ps,ps)
        return x, y

# Train/Val Datasets + Normalisierung aus Train
train_set_temp = SingleBandToRestHSIPatches(cube, band_idx, patch_size=128, n_patches=1, split="train")
mins, maxs = train_set_temp.get_stats()
train_set = SingleBandToRestHSIPatches(cube, band_idx, patch_size=128, n_patches=4000, split="train", stats=(mins,maxs))
val_set   = SingleBandToRestHSIPatches(cube, band_idx, patch_size=128, n_patches=800,  split="val",   stats=(mins,maxs))

train_loader = DataLoader(train_set, batch_size=4, shuffle=True,  num_workers=0)
val_loader   = DataLoader(val_set,   batch_size=4, shuffle=False, num_workers=0)

# ---------- 4) Modell (kleines UNet-ähnliches CNN) ----------
class SmallHSIRecon(nn.Module):
    def __init__(self, out_channels):
        super().__init__()
        ch = 48
        self.enc = nn.Sequential(
            nn.Conv2d(1, ch, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(ch, ch, 3, padding=1), nn.ReLU(inplace=True)
        )
        self.down = nn.Conv2d(ch, ch, 4, stride=2, padding=1)  # 1/2
        self.bottleneck = nn.Sequential(
            nn.Conv2d(ch, ch, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(ch, ch, 3, padding=1), nn.ReLU(inplace=True),
        )
        self.up = nn.ConvTranspose2d(ch, ch, 4, stride=2, padding=1)  # x2
        self.dec = nn.Sequential(
            nn.Conv2d(ch, ch, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(ch, out_channels, 3, padding=1),
            nn.Softplus()   # Nichtnegativität; falls deine Daten zentriert sind, ersetze durch Identity
        )

    def forward(self, x):
        e = self.enc(x)
        d = self.down(e)
        b = self.bottleneck(d)
        u = self.up(b)
        if u.shape == e.shape:
            u = u + e
        y = self.dec(u)
        return y

# ---------- 5) Metriken ----------
def spectral_angle_mapper(pred, target, eps=1e-8):
    # pred/target: (B,C,H,W)
    B, C, H, W = pred.shape
    P = pred.permute(0,2,3,1).reshape(-1, C)
    T = target.permute(0,2,3,1).reshape(-1, C)
    num = (P*T).sum(1)
    den = (P.norm(1) * T.norm(1)).clamp_min(eps)
    cos = (num/den).clamp(-1+1e-6, 1-1e-6)
    ang = torch.acos(cos)  # Radiant
    return ang.mean()

def rmse(pred, target):
    return torch.sqrt(torch.mean((pred - target) ** 2))

def psnr(pred, target, data_range=1.0, eps=1e-8):
    mse = torch.mean((pred - target) ** 2)
    return 20.0 * torch.log10(data_range / torch.sqrt(mse + eps))

# ---------- 6) Training ----------
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SmallHSIRecon(out_channels=C-1).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
mse = nn.MSELoss()

def validate():
    model.eval()
    mse_sum = sam_sum = psnr_sum = 0.0
    n = 0
    with torch.no_grad():
        for x,y in val_loader:
            x,y = x.to(device), y.to(device)
            y_hat = model(x)
            mse_sum  += mse(y_hat, y).item() * x.size(0)
            sam_sum  += spectral_angle_mapper(y_hat, y).item() * x.size(0)
            psnr_sum += psnr(y_hat, y).item() * x.size(0)
            n += x.size(0)
    return mse_sum/n, sam_sum/n, psnr_sum/n

EPOCHS = 20
for epoch in range(1, EPOCHS+1):
    model.train()
    loss_sum = 0.0
    for x,y in train_loader:
        x,y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_hat = model(x)
        loss = 0.7*mse(y_hat,y) + 0.3*spectral_angle_mapper(y_hat,y)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item() * x.size(0)

    val_mse, val_sam, val_psnr = validate()
    print(f"Epoch {epoch:02d} | train_loss={loss_sum/len(train_set):.4f} "
          f"| val_mse={val_mse:.4f} | val_sam(rad)={val_sam:.4f} | val_psnr(dB)={val_psnr:.2f}")

# ---------- 7) Demo: vollständige Rekonstruktion auf großem Crop ----------
model.eval()
with torch.no_grad():
    # großen Crop wählen (z.B. mittlerer 256x256 Bereich, falls möglich)
    ps = 256
    y0 = max((H-ps)//2, 0); x0 = max((W-ps)//2, 0)
    big = cube[:, y0:y0+ps, x0:x0+ps].clone()  # (C,ps,ps)

    # Normierung anwenden
    eps = 1e-8
    big_n = (big - mins.view(-1,1,1)) / (maxs.view(-1,1,1) - mins.view(-1,1,1) + eps)

    x_big = big_n[band_idx:band_idx+1].unsqueeze(0).to(device)  # (1,1,ps,ps)
    y_true = torch.cat([big_n[:band_idx], big_n[band_idx+1:]], dim=0).unsqueeze(0).to(device)  # (1,C-1,ps,ps)
    y_pred = model(x_big)

    demo_sam  = spectral_angle_mapper(y_pred, y_true).item()
    demo_rmse = rmse(y_pred, y_true).item()
    demo_psnr = psnr(y_pred, y_true).item()
    print(f"\nDemo Crop Metrics @ {int(wn[band_idx])}nm input -> SAM={demo_sam:.4f} rad, RMSE={demo_rmse:.4f}, PSNR={demo_psnr:.2f} dB")


Eingabeband: 1000 nm -> index 24, wn[idx]=1000nm


RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.