In [1]:
import os
import glob
import torch
import numpy as np
import tifffile
import shutil
import torch.nn as nn

# Input/output locations (relative paths, resolved against the working directory).
IMAGES_DIR = "min_results"     # Input folder: every *.tiff here is segmented
LABELS_MANUAL = "min_mask_1"   # Manually labeled masks, same file names as the images (0 = unlabeled, non-zero = keep; values 1/3)
PRED_MASKS = "pred_mask_1"     # Output folder for the predicted / merged masks
MODEL_PATH = "model_iter_inout.pt"  # Checkpoint passed to load_state_dict for Simple2ClassNet

# Create the output folder up front; exist_ok avoids an error on re-runs.
os.makedirs(PRED_MASKS, exist_ok=True)

# -- Two-class per-pixel network: output channel 0 scores "in focus",
#    output channel 1 scores "out of focus".
class Simple2ClassNet(nn.Module):
    """Tiny fully-convolutional per-pixel classifier.

    A 3x3 convolution to 16 features, a ReLU, then a 3x3 convolution to 2
    class scores. Both convolutions use padding=1, so the spatial size is
    preserved: input [B, 1, H, W] -> raw logits [B, 2, H, W].
    """

    def __init__(self):
        super().__init__()
        # The state_dict keys (conv1.*, conv2.*) come from these attribute
        # names, so they must match the checkpoint loaded with load_state_dict.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=2, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        features = self.relu(self.conv1(x))
        logits = self.conv2(features)
        return logits  # unnormalized class scores, [B, 2, H, W]

# Run on Apple's MPS backend when PyTorch reports it available, else on CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Build the network directly on the chosen device, restore the trained
# weights (mapped onto that same device) and switch to evaluation mode.
# NOTE(review): torch.load unpickles the file; if MODEL_PATH could ever come
# from an untrusted source, consider weights_only=True (torch >= 1.13).
model = Simple2ClassNet()
model.to(device)  # nn.Module.to moves parameters in place
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()

def infer_mask(img_path):
    """Predict a focus mask for one TIFF image using the global ``model``.

    Preprocessing: the image is read with tifffile; if it is 3-D, index 0
    of the last axis is kept. Values are cast to float32 and divided by
    255 when the image maximum exceeds 1.

    The network's two output channels are mapped per pixel to labels:
      - channel 0 => "in focus"     => output label 1
      - channel 1 => "out of focus" => output label 3

    Args:
        img_path: path of the TIFF image to segment.

    Returns:
        np.ndarray of dtype uint8 and shape (H, W) with values in {1, 3}.

    Raises:
        ValueError: if the image is not 2-D after the channel reduction.
    """
    img = tifffile.imread(img_path)
    if img.ndim == 3:
        # NOTE(review): assumes channels-last (H, W, C); a multi-page stack
        # shaped (N, H, W) would be sliced along W here — confirm input format.
        img = img[..., 0]
    if img.ndim != 2:
        # Fail early with a readable message instead of an opaque conv2d error.
        raise ValueError(
            f"{img_path}: expected a 2-D image (or H x W x C), got shape {img.shape}"
        )
    img = img.astype(np.float32)
    # NOTE(review): content-based scaling; it must mirror the training-time
    # preprocessing. An 8-bit image whose max is <= 1 is left unscaled, and a
    # 16-bit image is divided by 255 rather than by 65535.
    if img.max() > 1:
        img /= 255.0

    img_tensor = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).to(device)  # [1,1,H,W]
    with torch.no_grad():
        logits = model(img_tensor)  # [1,2,H,W]
        # softmax is strictly monotonic per pixel, so argmax over the raw
        # logits picks the same class without an extra exp/normalize pass.
        pred_class = logits.argmax(dim=1).squeeze(0).cpu().numpy()  # (H,W) in {0,1}

    # Map class 0 -> label 1 (in focus), class 1 -> label 3 (out of focus).
    return np.where(pred_class == 0, 1, 3).astype(np.uint8)

def main():
    """Segment every ``*.tiff`` in IMAGES_DIR and write one mask per image to PRED_MASKS.

    The model prediction for each image comes from ``infer_mask``. When a
    manual mask with the same file name exists in LABELS_MANUAL, its
    non-zero pixels override the prediction (0 means "unlabeled, use the
    model"). Otherwise the prediction is written as-is.

    Raises:
        ValueError: if a manual mask's shape differs from the predicted
            mask's. Without this check ``np.where`` would either fail with an
            opaque broadcast error or silently broadcast (e.g. an (H, H, 1)
            mask on a square image yields an (H, H, H) array).
    """
    all_imgs = sorted(glob.glob(os.path.join(IMAGES_DIR, "*.tiff")))
    print(f"Finding {len(all_imgs)} images in '{IMAGES_DIR}'...")

    for img_path in all_imgs:
        fname = os.path.basename(img_path)
        label_manual = os.path.join(LABELS_MANUAL, fname)
        label_pred = os.path.join(PRED_MASKS, fname)

        # Both branches need the model prediction, so compute it once.
        model_mask = infer_mask(img_path)

        if os.path.exists(label_manual):
            manual_mask = tifffile.imread(label_manual).astype(np.uint8)
            if manual_mask.shape != model_mask.shape:
                raise ValueError(
                    f"{fname}: manual mask shape {manual_mask.shape} does not "
                    f"match predicted mask shape {model_mask.shape}"
                )

            # Composing the full mask: keep manual labels where manual_mask is
            # non-zero, fill the unlabeled (0) pixels from the model.
            final_mask = np.where(manual_mask == 0, model_mask, manual_mask)

            tifffile.imwrite(label_pred, final_mask)
            print(f"[DONE] {fname} -> {PRED_MASKS}/ (manual + model)")
        else:
            # No manual labels: write the pure model prediction.
            tifffile.imwrite(label_pred, model_mask)
            print(f"[INF] {fname} -> {PRED_MASKS}/ (predicted mask)")

# Entry point: run the batch when this code is executed as the main module.
if __name__ == "__main__":
    main()


Finding 100 images in 'min_results'...
[INF] variance_image_0_min_grayscale.tiff -> pred_mask_1/ (predicted mask)
[INF] variance_image_10_min_grayscale.tiff -> pred_mask_1/ (predicted mask)
[INF] variance_image_11_min_grayscale.tiff -> pred_mask_1/ (predicted mask)
[INF] variance_image_12_min_grayscale.tiff -> pred_mask_1/ (predicted mask)
[INF] variance_image_13_min_grayscale.tiff -> pred_mask_1/ (predicted mask)
[INF] variance_image_14_min_grayscale.tiff -> pred_mask_1/ (predicted mask)
[INF] variance_image_15_min_grayscale.tiff -> pred_mask_1/ (predicted mask)
[INF] variance_image_16_min_grayscale.tiff -> pred_mask_1/ (predicted mask)
[INF] variance_image_17_min_grayscale.tiff -> pred_mask_1/ (predicted mask)
[INF] variance_image_18_min_grayscale.tiff -> pred_mask_1/ (predicted mask)
[INF] variance_image_19_min_grayscale.tiff -> pred_mask_1/ (predicted mask)
[INF] variance_image_1_min_grayscale.tiff -> pred_mask_1/ (predicted mask)
[DONE] variance_image_20_min_grayscale.tiff -> pred