In [1]:
# ==============================================
# Fastai-Training: 1 Band -> restliche (C-1) Bänder
# Daten-Layout: HSI (H,W,C)= (260,1500,49)   wn = [760..1240]nm in 10nm-Schritten
# ==============================================

from fastai.vision.all import *
import torch, numpy as np, scipy.io, glob, random
from pathlib import Path

# -----------------------------
# 0) Hardware-Optimierung
# -----------------------------
torch.backends.cudnn.benchmark = True                 # schnellere CuDNN-Auswahl für konstante Input-Shapes
if hasattr(torch, "set_float32_matmul_precision"):
    torch.set_float32_matmul_precision("high")        # bessere (und oft schnellere) GEMMs
device = default_device()                             # fastai-Gerät (cuda, wenn vorhanden)
print("Device:", device)

# -----------------------------
# 1) Utils: MAT-Laden & WL-Mapping
# -----------------------------
def load_cube_mat(path:Path):
    "Lädt .mat -> (cube: (C,H,W) float32, wn: (C,) int)"
    mat = scipy.io.loadmat(path)
    HSI = mat["HSI"].astype(np.float32)              # (H,W,C)
    H, W, C = HSI.shape
    cube = torch.from_numpy(HSI).permute(2,0,1).contiguous()  # -> (C,H,W)
    wn = mat.get("wn", None)
    if wn is not None:
        wn = torch.from_numpy(wn.reshape(-1).astype(np.int32))  # (C,)
    else:
        wn = torch.arange(760, 760+10*C, 10, dtype=torch.int32) # Fallback
    return cube, wn

def band_index_from_nm(wn_vec:torch.Tensor, nm:int) -> int:
    "Nimmt das **nächstgelegene** Band zur gewünschten Wellenlänge"
    return int(torch.argmin(torch.abs(wn_vec - int(nm))).item())

# -----------------------------
# 2) Daten einsammeln
# -----------------------------
data_dir = Path("data")
mat_paths = sorted(glob.glob(str(data_dir/"*.mat")))
assert len(mat_paths)>0, "Keine .mat-Dateien in data/ gefunden."

# wir laden alle Cubes in Liste (für größere Datensätze könntest du lazy-laden)
cubes, wn_all = [], None
for p in mat_paths:
    cube, wn = load_cube_mat(Path(p))
    cubes.append(cube)               # (C,H,W)
    if wn_all is None: wn_all = wn   # nimm die erste WL-Liste als Referenz

C, H, W = cubes[0].shape
print(f"Gelesen: {len(cubes)} Cubes, Shape je (C,H,W)=({C},{H},{W})")
print("Wellenlängen (erste/letzte):", int(wn_all[0]), int(wn_all[-1]))

# -----------------------------
# 3) Bandauswahl (deine Präferenzen)
#    Wunschliste: 550nm (Grün), 800nm (NIR), 1550nm (SWIR)
#    -> wir mappen auf das nächste im Datensatz verfügbare Band
# -----------------------------
desired_nms = [550, 800, 1550]
mapped = [(nm, int(wn_all[band_index_from_nm(wn_all, nm)])) for nm in desired_nms]
print("Band-Mapping (gewünscht -> genutzt):", mapped)

# setze das aktive Eingabeband (nimm eins der gemappten)
input_nm_desired = 800                 # <- hier kannst du 550/800/1550 testen
band_idx = band_index_from_nm(wn_all, input_nm_desired)
print(f"Gewünschtes Band {input_nm_desired}nm -> benutze {int(wn_all[band_idx])}nm (Index {band_idx})")

# -----------------------------
# 4) Bandweise Normalisierung bestimmen (nur aus Train-Cubes)
#     - robust: Schätzung über zufällige Patches
# -----------------------------
def estimate_minmax(cubes, ps=128, draws=400):
    mins, maxs = [], []
    for _ in range(draws):
        cube = random.choice(cubes)   # (C,H,W)
        _, Hc, Wc = cube.shape
        if Hc<ps or Wc<ps: continue
        y0 = random.randint(0, Hc-ps)
        x0 = random.randint(0, Wc-ps)
        patch = cube[:, y0:y0+ps, x0:x0+ps]
        mins.append(patch.view(C,-1).min(dim=1).values)
        maxs.append(patch.view(C,-1).max(dim=1).values)
    mins = torch.stack(mins,0).min(0).values
    maxs = torch.stack(maxs,0).max(0).values
    return mins, maxs

mins, maxs = estimate_minmax(cubes, ps=128, draws=400)
eps = 1e-8

def norm_cube(cube):
    "Bandweise Min-Max Normalisierung mit stabilen Schätzern"
    return (cube - mins.view(-1,1,1)) / (maxs.view(-1,1,1) - mins.view(-1,1,1) + eps)

# -----------------------------
# 5) Dataset: zufällige Patches (x=1 Band, y=C-1 Bänder)
# -----------------------------
class HSIPatchDS(torch.utils.data.Dataset):
    def __init__(self, cubes, band_idx:int, ps=128, n_patches=4000):
        self.cubes, self.k = cubes, band_idx
        self.ps, self.n = ps, n_patches

    def __len__(self): return self.n

    def __getitem__(self, i):
        cube = random.choice(self.cubes)               # (C,H,W)
        C,H,W = cube.shape
        y0 = random.randint(0, H-self.ps)
        x0 = random.randint(0, W-self.ps)
        patch = norm_cube(cube)[:, y0:y0+self.ps, x0:x0+self.ps] # normiert
        x = patch[self.k:self.k+1]                     # (1,ps,ps)  -> Eingabeband
        y = torch.cat([patch[:self.k], patch[self.k+1:]], dim=0) # (C-1,ps,ps)
        return x, y

# Train/Valid Datasets & DataLoaders (fastai)
train_ds = HSIPatchDS(cubes, band_idx, ps=128, n_patches=6000)
valid_ds = HSIPatchDS(cubes, band_idx, ps=128, n_patches=800)

# DataLoaders mit pin_memory, num_workers und channels_last
train_dl = DataLoader(train_ds, batch_size=8, shuffle=True,  num_workers=4, pin_memory=True)
valid_dl = DataLoader(valid_ds, batch_size=8, shuffle=False, num_workers=4, pin_memory=True)
dls = DataLoaders(train_dl, valid_dl).cuda()

# -----------------------------
# 6) Loss & Metriken: MSE + SAM
# -----------------------------
class SAMLoss(Module):
    "Spectral Angle Mapper als Loss (je Pixel), gemittelt."
    def __init__(self, eps=1e-8): self.eps = eps
    def forward(self, pred, targ):
        # pred, targ: (B,C,H,W)
        B,C,H,W = pred.shape
        P = rearrange(pred, 'b c h w -> (b h w) c')
        T = rearrange(targ, 'b c h w -> (b h w) c')
        num = (P*T).sum(dim=1)
        den = (P.norm(dim=1)*T.norm(dim=1)).clamp_min(self.eps)
        cos = (num/den).clamp(-1+1e-6, 1-1e-6)
        ang = torch.acos(cos)
        return ang.mean()

class CombinedLoss(Module):
    "0.7*MSE + 0.3*SAM (Startwerte; gerne tunen)"
    def __init__(self, w_mse=0.7, w_sam=0.3):
        self.mse = MSELossFlat()
        self.sam = SAMLoss()
        self.w_mse, self.w_sam = w_mse, w_sam
    def forward(self, pred, targ):
        return self.w_mse*self.mse(pred, targ) + self.w_sam*self.sam(pred, targ)

# Metriken für Validation
def sam_metric(pred, targ): return SAMLoss()(pred, targ)
def rmse_metric(pred, targ): return torch.sqrt(F.mse_loss(pred, targ))
def psnr_metric(pred, targ, data_range=1.0, eps=1e-8):
    mse = F.mse_loss(pred, targ)
    return 20. * torch.log10(data_range / torch.sqrt(mse + eps))

# -----------------------------
# 7) Modell: fastai Unet mit ResNet-Encoder
#     - n_in=1 (ein Eingangskanal: das gewählte Band)
#     - n_out=C-1 (alle anderen Bänder)
#     - y_range=(0,1): SigmoidRange-Klammerung für normierte Targets
# -----------------------------
arch = resnet34
learn = unet_learner(
    dls, arch,
    n_in=1,                    # wir geben 1 Kanal rein (gewähltes Band)
    n_out=C-1,                 # wir wollen C-1 Bänder schätzen
    y_range=(0,1),             # Targets sind [0,1] nach Normierung
    loss_func=CombinedLoss(),
    metrics=[sam_metric, rmse_metric, psnr_metric]
)

# Channels-last für effiziente Tensor-Layouts (oft schneller auf Ampere+)
for m in learn.model.modules():
    if isinstance(m, torch.nn.Conv2d):
        m.weight.data = m.weight.data.to(memory_format=torch.channels_last)
learn.model = learn.model.to(memory_format=torch.channels_last)

# Mixed Precision (FP16) für Speed auf GPU
learn = learn.to_fp16()

# PyTorch 2.x Compiler (bei CUDA 11.8+/PyTorch>=2.0)
try:
    learn.model = torch.compile(learn.model, mode="max-autotune")
    print("torch.compile aktiviert.")
except Exception as e:
    print("torch.compile nicht verfügbar:", e)

# -----------------------------
# 8) Training
# -----------------------------
# 1) Optional: LR-Finder
# learn.lr_find()

# 2) Haupttraining
learn.fine_tune(
    20,                    # Epochen
    base_lr=3e-4,          # Start-LR (tunen!)
    cbs=[
        SaveModelCallback(monitor='valid_loss', fname='best'),
        EarlyStoppingCallback(monitor='valid_loss', patience=5)
    ]
)

# -----------------------------
# 9) Export & Beispiel-Inferenz (ein großer Crop)
# -----------------------------
learn.export("hsi_unet_1toRest_fastai.pkl")
print("Model exported: hsi_unet_1toRest_fastai.pkl")

# Inferenz-Demo (großer Crop aus erstem Cube)
with torch.no_grad():
    cube0 = cubes[0]              # (C,H,W)
    ps = min(256, cube0.shape[1], cube0.shape[2])
    y0 = max((cube0.shape[1]-ps)//2, 0)
    x0 = max((cube0.shape[2]-ps)//2, 0)
    big = norm_cube(cube0)[:, y0:y0+ps, x0:x0+ps]
    x_big = big[band_idx:band_idx+1].unsqueeze(0).to(device)        # (1,1,ps,ps)
    y_true = torch.cat([big[:band_idx], big[band_idx+1:]],0)[None].to(device)  # (1,C-1,ps,ps)

    pred = learn.model(x_big.half())                                 # FP16 fwd
    # Metriken:
    sam = sam_metric(pred.float(), y_true.float()).item()
    rmse = rmse_metric(pred.float(), y_true.float()).item()
    psnr = psnr_metric(pred.float(), y_true.float()).item()
    print(f"Demo-Crop @ {int(wn_all[band_idx])}nm | SAM={sam:.4f} rad | RMSE={rmse:.4f} | PSNR={psnr:.2f} dB")


Device: cpu


  return torch._C._cuda_getDeviceCount() > 0


KeyError: 'HSI'