### Load Images and run Napari

In [1]:
import napari
import numpy as np
from skimage.io import imread_collection
import glob

# Collect the paths of all RGB integral images; sorted() gives a deterministic order.
# NOTE(review): sorted() is lexicographic, so unpadded indices (e.g. layer_10 before
# layer_2) would be mis-ordered — confirm the file names are zero-padded.
rgb_pattern = "../data/data1/RGBintegrals/layer_*.png"
rgb_image_paths = sorted(glob.glob(rgb_pattern))
if not rgb_image_paths:
    # Without this guard, imread_collection(...).concatenate() fails with an opaque error.
    raise FileNotFoundError(f"No RGB images found matching {rgb_pattern!r}")

# Load the RGB images and stack them into a single array (one entry per file).
rgb_images = imread_collection(rgb_image_paths).concatenate()

# Create the napari viewer showing the RGB stack.
viewer = napari.Viewer()
viewer.add_image(rgb_images, name="RGB Integrals", rgb=True)

# Empty labels layer for manual annotation; the trailing colour axis is dropped
# so the labels share the stack's remaining dimensions.
labels_layer = viewer.add_labels(np.zeros(rgb_images.shape[:-1], dtype=int), name="Manual Labels")

napari.run()


In [None]:
import os
import glob
import torch
import numpy as np
import tifffile
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from skimage.io import imread

# Directory containing the per-layer variance images used for training/inference.
IMAGES_DIR = "../data/data1/min_results"
# Directory where per-layer model checkpoints (layer_<idx>.pt) are looked up.
MODEL_DIR = "models_iterative"
os.makedirs(MODEL_DIR, exist_ok=True)

# Pick the best available backend: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

class Simple2ClassNet(nn.Module):
    """Tiny two-layer fully-convolutional network producing 2 logits per pixel.

    Input is a single-channel batch (N, 1, H, W); output is (N, 2, H, W) logits.
    The attribute names ``conv1``/``conv2`` define the ``state_dict`` keys of
    saved checkpoints, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # 3x3 kernels with padding=1 keep the spatial size unchanged.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=2, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)

def partial_cross_entropy(logits, lbls):
    """Cross-entropy computed only over annotated pixels (partial supervision).

    Args:
        logits: (N, C, H, W) class logits (C == 2 for Simple2ClassNet).
        lbls: (N, H, W) annotation map. 0 = unlabeled (ignored); 2 = in-focus
            (target class 1); any other non-zero value = out-of-focus (target class 0).

    Returns:
        Scalar tensor with the mean cross-entropy over labeled pixels. When no
        pixel is labeled, a zero tensor with ``requires_grad=True`` is returned
        so that ``loss.backward()`` still works.
    """
    # reshape (not view) so non-contiguous label tensors are handled too.
    target = torch.where(lbls == 2, 1, 0).reshape(-1)  # in-focus=1, out-of-focus=0
    labeled = (lbls != 0).reshape(-1)

    if not labeled.any():
        return torch.zeros((), dtype=logits.dtype, device=logits.device, requires_grad=True)

    # (N, C, H, W) -> (N*H*W, C): one row of class logits per pixel, row-major like `target`.
    flat_logits = logits.permute(0, 2, 3, 1).reshape(-1, logits.shape[1])
    # Mean over labeled pixels == sum / number of annotated pixels.
    return F.cross_entropy(flat_logits[labeled], target[labeled], reduction="mean")

# Minimal in-memory dataset wrapping a single image/annotation pair.
class NapariDataset(Dataset):
    """Dataset of exactly one item: a normalised image and its annotation mask.

    The image is cast to float32 and divided by 255 when its maximum exceeds 1;
    the mask is cast to float32 unchanged. Items are returned as torch tensors,
    with a leading channel axis added to the image.
    """

    def __init__(self, image, mask):
        super().__init__()
        scale = 255.0 if image.max() > 1 else 1.0
        self.image = image.astype(np.float32) / scale
        self.mask = mask.astype(np.float32)

    def __len__(self):
        return 1  # a single image

    def __getitem__(self, idx):
        # idx is ignored: the dataset holds only one sample.
        channel_first = self.image[np.newaxis, ...]  # (1, H, W)
        return torch.from_numpy(channel_first), torch.from_numpy(self.mask)

def train_model(layer_idx, img, mask, model_path, epochs=30, lr=1e-3):
    """Train (or fine-tune) the per-layer focus classifier on sparse annotations.

    If a checkpoint exists at ``model_path`` it is loaded and training continues
    from it; otherwise a freshly initialised ``Simple2ClassNet`` is used. After
    training, the weights are saved back to ``model_path`` so later runs (and
    un-annotated layers that fall back to the nearest trained layer) can reuse them.

    Args:
        layer_idx: Index of the layer being trained (used for logging only).
        img: 2-D image for this layer (NapariDataset adds the channel axis).
        mask: Annotation array expected to match ``img``'s spatial shape;
            0 = unlabeled, 2 = in-focus, any other non-zero value = out-of-focus.
        model_path: Checkpoint path to load from (if present) and save to.
        epochs: Number of passes over the single-image dataset.
        lr: Adam learning rate.

    Returns:
        The trained model on ``device``, switched to eval mode.
    """
    model = Simple2ClassNet().to(device)

    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path, map_location=device))
        print(f"Modello caricato e aggiornato per layer {layer_idx}")
    else:
        print(f"Nuovo modello creato per layer {layer_idx}")

    ds = NapariDataset(img, mask)
    loader = DataLoader(ds, batch_size=1, shuffle=False)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    model.train()
    print(f"*******Training Layer {layer_idx}*******")
    for epoch in range(epochs):
        epoch_loss = 0.0
        num_batches = 0

        for imgs, lbls in loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            optimizer.zero_grad()
            logits = model(imgs)
            loss = partial_cross_entropy(logits, lbls)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        # Guard against an empty loader instead of dividing by zero.
        epoch_loss /= max(num_batches, 1)
        print(f"Layer {layer_idx} | Epoch [{epoch+1}/{epochs}] - Normalized Loss: {epoch_loss:.6f}")

    # Persist the weights: without this, checkpoint lookups by path never see
    # layers trained in this session.
    torch.save(model.state_dict(), model_path)
    model.eval()
    # The caller assigns the result (`model = train_model(...)`), so it must be returned.
    return model


def infer_mask(model, img):
    """Predict a per-pixel class map for a single 2-D image.

    Args:
        model: Network taking an (N, 1, H, W) float tensor and returning
            (N, C, H, W) class logits.
        img: 2-D image array (H, W); converted to float32 before inference.

    Returns:
        NumPy array (H, W) of predicted class indices. With the training targets
        of partial_cross_entropy, 1 = in-focus and 0 = out-of-focus.
    """
    # Inference mode: a no-op for the current net, required if BN/dropout are ever added.
    model.eval()
    img_tensor = torch.as_tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(img_tensor)
        # argmax over raw logits equals argmax over softmax probabilities.
        pred_class = logits.argmax(dim=1).squeeze(0).cpu().numpy()
    return pred_class

# Main: for each layer, train on its manual annotations (or reuse the nearest
# layer's checkpoint), predict a class map, and write it into the labels layer.
labels_data = labels_layer.data  # napari labels, unpacked below as (num_layers, H, W)
variance_image_paths = sorted(glob.glob(f"{IMAGES_DIR}/variance_image_*_min_grayscale.tiff"))
num_layers, H, W = labels_data.shape
final_prediction = np.zeros_like(labels_data)

# Layer k > 0 reads variance image k - 1 and layer 0 reuses image 0 (see img_path
# below), so at least max(num_layers - 1, 1) images are needed. Fail with a clear
# message up front instead of an IndexError halfway through the loop.
required_images = max(num_layers - 1, 1) if num_layers > 0 else 0
if len(variance_image_paths) < required_images:
    raise FileNotFoundError(
        f"Found {len(variance_image_paths)} variance images in {IMAGES_DIR!r}, "
        f"but {required_images} are needed for {num_layers} layers"
    )

for layer_idx in range(num_layers):
    mask = labels_data[layer_idx]

    # NOTE(review): one-index offset between label layers and variance images —
    # confirm this alignment is intended.
    img_path = variance_image_paths[layer_idx - 1] if layer_idx > 0 else variance_image_paths[0]
    variance_img = imread(img_path).astype(np.float32)
    # Divide by 255 when values exceed 1; data already within [0, 1] is left as is.
    variance_img /= (255.0 if variance_img.max() > 1 else 1.0)

    model_path = os.path.join(MODEL_DIR, f"layer_{layer_idx}.pt")

    if np.any(mask > 0):
        # This layer has manual annotations: train (or fine-tune) its own model.
        model = train_model(layer_idx, variance_img, mask, model_path)
    else:
        # No annotations: fall back to the checkpoint of the closest layer that has one.
        existing_models = [
            idx for idx in range(num_layers)
            if os.path.exists(os.path.join(MODEL_DIR, f"layer_{idx}.pt"))
        ]
        if not existing_models:
            print(f"[WARN] Nessun modello disponibile per il layer {layer_idx}, skipping.")
            continue

        closest_model_idx = min(existing_models, key=lambda idx: abs(layer_idx - idx))
        model_path = os.path.join(MODEL_DIR, f"layer_{closest_model_idx}.pt")
        model = Simple2ClassNet().to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))
        print(f"Nessuna annotazione manuale layer {layer_idx}, uso modello del layer {closest_model_idx}")

    pred_mask = infer_mask(model, variance_img)
    final_prediction[layer_idx] = pred_mask

# Map manual labels onto the prediction encoding: 2 (in-focus) -> 1, any other value -> 0,
# matching the targets used by partial_cross_entropy.
manual_mapped = np.where(labels_layer.data == 2, 1, 0)

# Keep the (remapped) manual annotations and fill only unannotated pixels (== 0) with predictions.
# NOTE(review): after this overwrite, in-focus pixels hold label 1, which the training loss treats
# as an out-of-focus annotation if this cell is re-run; consider writing the predictions to a
# separate layer for iterative use.
labels_layer.data = np.where(labels_layer.data == 0, final_prediction, manual_mapped)
labels_layer.refresh()



Nessuna annotazione manuale layer 0, uso modello del layer 0
Nessuna annotazione manuale layer 1, uso modello del layer 0
Nessuna annotazione manuale layer 2, uso modello del layer 0
Nessuna annotazione manuale layer 3, uso modello del layer 4
Nessuna annotazione manuale layer 4, uso modello del layer 4
Nessuna annotazione manuale layer 5, uso modello del layer 4
Nessuna annotazione manuale layer 6, uso modello del layer 4
Nessuna annotazione manuale layer 7, uso modello del layer 4
Nessuna annotazione manuale layer 8, uso modello del layer 4
Nessuna annotazione manuale layer 9, uso modello del layer 4
Nessuna annotazione manuale layer 10, uso modello del layer 4
Nessuna annotazione manuale layer 11, uso modello del layer 17
Nessuna annotazione manuale layer 12, uso modello del layer 17
Nessuna annotazione manuale layer 13, uso modello del layer 17
Nessuna annotazione manuale layer 14, uso modello del layer 17
Nessuna annotazione manuale layer 15, uso modello del layer 17
Nessuna annot