In [4]:
import os
import glob
import torch
import numpy as np
import tifffile
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# ==============================================
# 1) Config
# ==============================================
# Directory with the source images (*.tiff).
IMAGES_DIR = "min_results"
# Directory with the per-pixel masks; a mask is paired with the image that
# has the same file name. Mask values: 1 = in focus, 3 = out of focus,
# 0 = unlabeled.
LABELS_DIR = "min_mask_1"
# Where the trained model's state_dict is written.
MODEL_SAVE_PATH = "model_iter_inout.pt"


# ==============================================
# 2) Dataset
# ==============================================
class TwoClassDataset(Dataset):
    """
    Pairs of (image, mask) read from two directories.

    Images are listed in ``images_dir`` with the glob ``pattern``
    (``"*.tiff"`` by default, sorted by path). An image is kept only if a
    file with the same base name exists in ``labels_dir``.

    Mask values:
      - 0 => unlabeled (then ignore)
      - 1 => in focus => class 0
      - 3 => out of focus => class 1

    Each item is ``(img_t, lbl_t)``:
      - ``img_t``: float32 tensor with a leading channel axis, (1, H, W) for
        2-D / single-channel-selected inputs.
      - ``lbl_t``: float32 tensor (H, W) holding the raw mask values.
    """
    def __init__(self, images_dir, labels_dir, pattern="*.tiff"):
        super().__init__()
        self.imgs = []
        self.labels = []

        for img_path in sorted(glob.glob(os.path.join(images_dir, pattern))):
            label_path = os.path.join(labels_dir, os.path.basename(img_path))
            if os.path.exists(label_path):
                self.imgs.append(img_path)
                self.labels.append(label_path)

        print(f"Dataset trovato: {len(self.imgs)} coppie (img, mask)")

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img_path = self.imgs[idx]
        lbl_path = self.labels[idx]

        img = tifffile.imread(img_path).astype(np.float32)
        lbl = tifffile.imread(lbl_path).astype(np.float32)

        # Multi-channel files: keep only the first channel of the last axis.
        if img.ndim == 3:
            img = img[..., 0]
        if lbl.ndim == 3:
            lbl = lbl[..., 0]

        # Image and mask must align pixel by pixel. Fail here, naming both
        # files, rather than with an obscure error (or silently misaligned
        # labels) further down the pipeline.
        if img.shape != lbl.shape:
            raise ValueError(
                f"Image/mask shape mismatch: {img_path} {img.shape} "
                f"vs {lbl_path} {lbl.shape}"
            )

        # Scale to [0, 1] when the data is not already normalized.
        # NOTE(review): dividing by 255 assumes 8-bit data; 16-bit images
        # would end up far above 1. Confirm the bit depth before changing it,
        # since saved models were trained with this scaling.
        if img.max() > 1:
            img /= 255.0

        # Add the channel axis expected by Conv2d: (H, W) -> (1, H, W).
        img_t = torch.from_numpy(np.expand_dims(img, axis=0))
        # The mask stays float with its raw values (0 / 1 / 3).
        lbl_t = torch.from_numpy(lbl)
        return img_t, lbl_t


# ==============================================
# 3) Model with 2 classes (2 channels out)
# ==============================================
class Simple2ClassNet(nn.Module):
    """
    Two-layer fully convolutional per-pixel classifier.

    conv3x3(in_channels -> hidden_channels) -> ReLU
    -> conv3x3(hidden_channels -> num_classes).
    Both convolutions use padding=1, so the spatial size is preserved: an
    input of shape [B, in_channels, H, W] yields logits of shape
    [B, num_classes, H, W].

    With the default num_classes=2:
      channel 0 => "in focus"
      channel 1 => "out of focus"

    Args:
        in_channels: number of input image channels (default 1).
        hidden_channels: width of the intermediate feature map (default 16).
        num_classes: number of output logit channels (default 2).
    """
    def __init__(self, in_channels=1, hidden_channels=16, num_classes=2):
        super().__init__()
        # Keep these attribute names: they are the state_dict keys written
        # by torch.save(model.state_dict(), ...).
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_channels, num_classes, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        return self.conv2(x)  # raw logits, shape [B, num_classes, H, W]


# ==============================================
# 4) Training with partial cross-entropy
# ==============================================
def _select_device(device=None):
    """Return ``device`` as a torch.device, or pick one if it is None.

    Preference order when ``device`` is None: CUDA, Apple MPS, then CPU.
    """
    if device is not None:
        return torch.device(device)
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _partial_cross_entropy(logits, lbls, in_focus_value=1.0, out_focus_value=3.0):
    """Cross-entropy averaged over labeled pixels only.

    Args:
        logits: tensor of shape [B, 2, H, W].
        lbls: raw mask values, shape [B, H, W] ([B, 1, H, W] is also
            accepted). ``in_focus_value`` maps to class 0 and
            ``out_focus_value`` to class 1. Every other value (0 = unlabeled,
            or anything unexpected) is ignored.
        in_focus_value: mask value for "in focus" pixels.
        out_focus_value: mask value for "out of focus" pixels.

    Returns:
        Scalar mean cross-entropy over the labeled pixels. If the batch has
        no labeled pixel, returns a zero tensor that still supports
        backward().

    Raises:
        ValueError: if the label map does not match the logits' [B, H, W].
    """
    b, c, h, w = logits.shape
    if lbls.dim() == 4 and lbls.shape[1] == 1:
        lbls = lbls[:, 0]
    # An explicit check instead of view(): view() would happily reinterpret
    # e.g. a transposed mask and misalign labels with pixels.
    if tuple(lbls.shape) != (b, h, w):
        raise ValueError(
            f"Label shape {tuple(lbls.shape)} does not match logits {(b, h, w)}"
        )

    # One row of logits per pixel: (B*H*W, C), labels (B*H*W,).
    logits_flat = logits.permute(0, 2, 3, 1).reshape(-1, c)
    lbls_flat = lbls.reshape(-1)

    in_focus = lbls_flat == in_focus_value
    out_focus = lbls_flat == out_focus_value
    labeled = in_focus | out_focus
    if not labeled.any():
        return torch.tensor(0.0, device=logits.device, requires_grad=True)

    # Among labeled pixels: out of focus -> class 1, in focus -> class 0.
    target = out_focus[labeled].long()
    return F.cross_entropy(logits_flat[labeled], target)


def train_inout(epochs=10, lr=1e-3, batch_size=2, device=None):
    """Train Simple2ClassNet on (image, mask) pairs and save its weights.

    Pairs are read from IMAGES_DIR / LABELS_DIR. The partial cross-entropy
    (labeled pixels only) is minimized with Adam, the mean batch loss is
    printed once per epoch, and the model's state_dict is saved to
    MODEL_SAVE_PATH.

    Args:
        epochs: number of passes over the dataset.
        lr: Adam learning rate.
        batch_size: images per batch.
        device: torch device (or device name) to train on. ``None`` picks
            CUDA if available, then Apple MPS, then CPU.

    Returns:
        None. Returns early, after printing "No data!", when no pairs are
        found.
    """
    device = _select_device(device)

    ds = TwoClassDataset(IMAGES_DIR, LABELS_DIR)
    if len(ds) == 0:
        print("No data!")
        return

    loader = DataLoader(ds, batch_size=batch_size, shuffle=True)
    model = Simple2ClassNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for imgs, lbls in loader:
            imgs = imgs.to(device)
            lbls = lbls.to(device)

            optimizer.zero_grad()
            logits = model(imgs)
            loss = _partial_cross_entropy(logits, lbls)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Mean of the per-batch losses (each batch loss is itself a mean).
        epoch_loss = running_loss / len(loader)
        print(f"Epoch [{epoch+1}/{epochs}] Loss: {epoch_loss:.6f}")

    # Save
    torch.save(model.state_dict(), MODEL_SAVE_PATH)
    print(f"Model saved in {MODEL_SAVE_PATH}")

# =================================================
if __name__ == "__main__":
    # Entry point: 30 epochs with a learning rate of 1e-2 (other settings
    # keep train_inout's defaults).
    train_inout(lr=1e-2, epochs=30)


Dataset trovato: 4 coppie (img, mask)
Epoch [1/30] Loss: 0.514876
Epoch [2/30] Loss: 0.205419
Epoch [3/30] Loss: 0.088131
Epoch [4/30] Loss: 0.197348
Epoch [5/30] Loss: 0.209756
Epoch [6/30] Loss: 0.055685
Epoch [7/30] Loss: 0.047094
Epoch [8/30] Loss: 0.133145
Epoch [9/30] Loss: 0.092881
Epoch [10/30] Loss: 0.083119
Epoch [11/30] Loss: 0.059478
Epoch [12/30] Loss: 0.057341
Epoch [13/30] Loss: 0.034911
Epoch [14/30] Loss: 0.018312
Epoch [15/30] Loss: 0.011621
Epoch [16/30] Loss: 0.012398
Epoch [17/30] Loss: 0.012488
Epoch [18/30] Loss: 0.012598
Epoch [19/30] Loss: 0.046125
Epoch [20/30] Loss: 0.035092
Epoch [21/30] Loss: 0.022733
Epoch [22/30] Loss: 0.014331
Epoch [23/30] Loss: 0.011216
Epoch [24/30] Loss: 0.011062
Epoch [25/30] Loss: 0.012562
Epoch [26/30] Loss: 0.009163
Epoch [27/30] Loss: 0.005524
Epoch [28/30] Loss: 0.005950
Epoch [29/30] Loss: 0.002939
Epoch [30/30] Loss: 0.007011
Model saved in model_iter_inout.pt
