# Covid-19 Segmentation

Jesus Ramirez Delgado A01274723

Grant Nathaniel Keegan A01700753

Luis Adrian Uribe Cruz A01783129

Fidel Alexander Bonilla Montalvo A01798199

## 1. Introduction

The COVID-19 pandemic posed one of the most significant recent challenges to healthcare systems worldwide. One of the most widely used methods for diagnosing and monitoring patients was chest computed tomography (CT), which enables the identification of disease-specific patterns. However, manually interpreting hundreds of tomographic slices is a slow, radiologist-dependent, and variable process, opening the door to automatic decision-support tools.

In this context, the objective was to develop machine-learning models capable of automatically segmenting COVID-19–affected regions in axial CT slices. The core task is semantic segmentation—that is, classifying every pixel in the image as belonging to a specific clinical class. The challenge is to evaluate the models’ ability to accurately identify these lesions while optimizing metrics such as the Dice coefficient and Intersection over Union (IoU), which are especially relevant under severe class imbalance.
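To make the two metrics concrete, here is a minimal NumPy sketch for a single binary mask, with Dice = 2|A∩B| / (|A|+|B|) and IoU = |A∩B| / |A∪B|. The `eps` smoothing term and the convention that two empty masks score 1 are assumptions of this sketch, not necessarily the competition's official definition. The second half shows why these metrics matter under class imbalance: a model that predicts only background gets near-perfect pixel accuracy but a Dice of about zero.

```python
import numpy as np

def dice_and_iou(pred: np.ndarray, target: np.ndarray, eps: float = 1e-7):
    """Dice and IoU for two binary masks of the same shape.

    eps keeps both ratios defined when both masks are empty
    (in that case both scores are 1 under this convention).
    """
    pred = pred.astype(bool)
    target = target.astype(bool)
    inter = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    total = pred.sum() + target.sum()
    dice = (2.0 * inter + eps) / (total + eps)
    iou = (inter + eps) / (union + eps)
    return float(dice), float(iou)

# Severe imbalance: a 100-pixel lesion on a 512x512 slice, predicted as all background.
target = np.zeros((512, 512), dtype=np.uint8)
target[:10, :10] = 1
pred = np.zeros_like(target)

pixel_accuracy = float((pred == target).mean())
dice, iou = dice_and_iou(pred, target)
# accuracy is about 0.9996, while Dice and IoU are about 0
print(f"accuracy={pixel_accuracy:.4f}  dice={dice:.4f}  iou={iou:.4f}")
```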

This task is relevant in two ways: on the one hand, it contributes a potentially useful tool for clinical practice, helping physicians quickly detect critical patterns in patients with COVID-19; on the other, it strengthens research in medical computer vision by providing standardized datasets and fostering open comparisons of methodologies.

To address this problem, we used a reproducible, Jupyter-notebook-based workflow in which we developed an end-to-end pipeline that includes:

- Exploratory and preliminary analysis of the datasets.
- An ETL process that normalizes images, selects relevant classes, and prevents information leakage when splitting the data.
- Implementation and training of a 2D U-Net model in PyTorch, tuned for medical segmentation with limited data.
- Visualization and validation of results, both quantitative and qualitative.

## 2. Data Analysis - Grant 

### 2.1 Exploratory Data Analysis (EDA)

The dataset consists of four NumPy arrays (.npy) used to train the model:
- images_medseg.npy
- images_radiopedia.npy
- masks_medseg.npy
- masks_radiopedia.npy

And one for testing:
- test_images_medseg.npy

The datasets can be grouped by source:
- medseg
- radiopedia

Categories

Each source is split into two categories:

- images: axial CT slices stored as grayscale intensities in Hounsfield Units (HU).
  Typical shape: (N, H, W, 1) or (N, H, W); dtype usually float32/float64.
  Common lung window for visualization: WL ≈ -600 HU, WW ≈ 1500 HU.

- masks: per-pixel segmentation labels aligned slice-by-slice with the images from the same source.
  The format can be either:

    - Label map: (N, H, W) with integer classes {0, …, K-1}, or
    - One-hot: (N, H, W, K), where the last axis (of size K) indexes the classes.

  In MedSeg, the canonical classes are:

    - 0: ground glass
    - 1: consolidation
    - 2: lungs other
    - 3: background
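The two conventions above (HU windowing for display and the two mask formats) can be illustrated with a small NumPy sketch. The function names are hypothetical helpers for this illustration, and the class order simply follows the MedSeg list above. Note that converting one-hot to a label map with `argmax` assumes every pixel has exactly one active channel; an all-zero pixel would silently map to class 0 (ground glass).

```python
import numpy as np

# Class order as listed above for MedSeg (channel index -> name).
CLASSES = ["ground glass", "consolidation", "lungs other", "background"]

def lung_window(hu: np.ndarray, wl: float = -600.0, ww: float = 1500.0) -> np.ndarray:
    """Clip HU values to [wl - ww/2, wl + ww/2] and rescale to [0, 1] for display."""
    low, high = wl - ww / 2, wl + ww / 2  # -1350 HU and 150 HU for the default lung window
    return (np.clip(hu, low, high) - low) / (high - low)

def onehot_to_labels(onehot: np.ndarray) -> np.ndarray:
    """(N, H, W, K) one-hot -> (N, H, W) integer label map via argmax over the last axis."""
    return np.argmax(onehot, axis=-1)

def labels_to_onehot(labels: np.ndarray, num_classes: int) -> np.ndarray:
    """(N, H, W) integer labels in {0..K-1} -> (N, H, W, K) one-hot uint8."""
    return np.eye(num_classes, dtype=np.uint8)[labels]

# Round trip on a single 2x2 slice containing all four classes.
labels = np.array([[[0, 1], [2, 3]]])               # shape (1, 2, 2)
onehot = labels_to_onehot(labels, len(CLASSES))     # shape (1, 2, 2, 4)
assert np.array_equal(onehot_to_labels(onehot), labels)

# Values below/above the window saturate at 0 and 1; WL itself maps to 0.5.
print(lung_window(np.array([-2000.0, -600.0, 150.0, 500.0])))  # values: 0.0, 0.5, 1.0, 1.0
```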

In [None]:
import os
import numpy as np   # linear algebra
import pandas as pd  # data processing

# List the input files (Kaggle mounts datasets read-only under /kaggle/input)
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# Outputs written to /kaggle/working/ are preserved when a notebook version is saved;
# /kaggle/temp/ is scratch space that is not kept after the session.

# Global flag
RUN_TRAIN = False  # set to True when you actually want to train

# Cell magic to skip cells conditionally: %%skip_if <expression>
from IPython.core.magic import register_cell_magic

@register_cell_magic
def skip_if(line, cell):
    # Evaluate the condition and run the cell in the notebook's global namespace,
    # so that variables defined by the cell remain available afterwards.
    if eval(line, globals()):
        print(f"Skipped block: {line}")
    else:
        exec(cell, globals())


In [None]:
prefix = '/kaggle/input/covid-segmentation/'

images_radiopedia = np.load(os.path.join(prefix, 'images_radiopedia.npy')).astype(np.float32)
masks_radiopedia = np.load(os.path.join(prefix, 'masks_radiopedia.npy')).astype(np.int8)
images_medseg = np.load(os.path.join(prefix, 'images_medseg.npy')).astype(np.float32)
masks_medseg = np.load(os.path.join(prefix, 'masks_medseg.npy')).astype(np.int8)

test_images_medseg = np.load(os.path.join(prefix, 'test_images_medseg.npy')).astype(np.float32)

In [None]:
import pandas as pd
import numpy as np

def shape_NHWK(arr, is_image: bool):
    """Return (N,H,W,K) for images/masks from common shapes."""
    if arr.ndim == 4:
        N, H, W, K = arr.shape
        return N, H, W, K
    if arr.ndim == 3:
        N, H, W = arr.shape
        K = 1  # no channel axis: single-channel image or single-label mask
        return N, H, W, K
    raise ValueError(f"Unexpected shape {arr.shape} (is_image={is_image})")

def describe_K(array_kind: str, K: int):
    if array_kind == "Images":
        return f"{K} (grayscale image)" if K == 1 else str(K)
    # Masks
    if K == 4:
        return "4 (0=GGO, 1=Consolidation, 2=Lungs-other, 3=Background)"
    if K == 2:
        return "2 (0=GGO, 1=Consolidation)"
    if K == 1:
        return "1 (single-label mask)"
    return str(K)

rows = []

# MedSeg
N,H,W,K = shape_NHWK(images_medseg, is_image=True)
rows.append({
    "Split / Source": "MedSeg",
    "Array": "Images",
    "Shape (N, H, W, K)": f"({N}, {H}, {W}, {K})",
    "N": f"{N} (slices)",
    "H×W": f"{H}×{W}",
    "K (channels/classes)": describe_K("Images", K),
})

N,H,W,K = shape_NHWK(masks_medseg, is_image=False)
rows.append({
    "Split / Source": "MedSeg",
    "Array": "Masks",
    "Shape (N, H, W, K)": f"({N}, {H}, {W}, {K})",
    "N": f"{N} (slices)",
    "H×W": f"{H}×{W}",
    "K (channels/classes)": describe_K("Masks", K),
})

# MedSeg (test)
N,H,W,K = shape_NHWK(test_images_medseg, is_image=True)
rows.append({
    "Split / Source": "MedSeg (Test)",
    "Array": "Images",
    "Shape (N, H, W, K)": f"({N}, {H}, {W}, {K})",
    "N": f"{N} (slices)",
    "H×W": f"{H}×{W}",
    "K (channels/classes)": describe_K("Images", K),
})

# Radiopaedia
N,H,W,K = shape_NHWK(images_radiopedia, is_image=True)
rows.append({
    "Split / Source": "Radiopaedia",
    "Array": "Images",
    "Shape (N, H, W, K)": f"({N}, {H}, {W}, {K})",
    "N": f"{N} (slices)",
    "H×W": f"{H}×{W}",
    "K (channels/classes)": describe_K("Images", K),
})

N,H,W,K = shape_NHWK(masks_radiopedia, is_image=False)
rows.append({
    "Split / Source": "Radiopaedia",
    "Array": "Masks",
    "Shape (N, H, W, K)": f"({N}, {H}, {W}, {K})",
    "N": f"{N} (slices)",
    "H×W": f"{H}×{W}",
    "K (channels/classes)": describe_K("Masks", K),
})

# Optional note row
rows.append({
    "Split / Source": "Notes",
    "Array": "Meaning of dimensions",
    "Shape (N, H, W, K)": "",
    "N": "Number of slices",
    "H×W": "Height × Width (pixels)",
    "K (channels/classes)": "Images: 1 channel; Masks: 4 classes (as listed)",
})

df = pd.DataFrame(rows, columns=[
    "Split / Source", "Array", "Shape (N, H, W, K)", "N", "H×W", "K (channels/classes)"
])

display(df)

To better visualize the axial CT slices, we took the whole Radiopaedia dataset as an example and rendered an animated slice-by-slice scan of:
- Images
- Masks
- Overlay (masks superimposed on the images)

To save resources in this notebook, the GIFs were generated in a separate Python environment with the following code:

<details>
  <summary> Click to see the code </summary>

    from pathlib import Path
    import numpy as np
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import imageio.v2 as imageio
    
    # ====== CONFIG ======
    IMG_PATH = Path("images_radiopedia.npy")
    MSK_PATH = Path("masks_radiopedia.npy")
    OUT_DIR  = Path("out")
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    
    # ====== HELPERS ======
    def load_npy(p: Path) -> np.ndarray:
        if not p.exists():
            raise FileNotFoundError(p)
        arr = np.load(p)
        return arr
    
    def to_hw(img_array: np.ndarray) -> np.ndarray:
        """(N,H,W,1)->(N,H,W), (N,H,W)->(N,H,W)."""
        if img_array.ndim == 4 and img_array.shape[-1] == 1:
            return img_array[..., 0]
        elif img_array.ndim == 3:
            return img_array
        else:
            raise ValueError(f"Expected (N,H,W,1) or (N,H,W), got: {img_array.shape}")
    
    def masks_to_labels(msk: np.ndarray) -> tuple[np.ndarray, int]:
        """(N,H,W,K) -> labels (N,H,W) via argmax; (N,H,W) -> labels; returns (labels, K)."""
        if msk.ndim == 4:
            labels = np.argmax(msk, axis=-1)
            K = msk.shape[-1]
            return labels, K
        elif msk.ndim == 3:
            labels = msk
            K = int(labels.max()) + 1
            return labels, K
        else:
            raise ValueError(f"Mask unexpected shape: {msk.shape}")
    
    def window_hu(img: np.ndarray, wl=-600, ww=1500) -> np.ndarray:
        low, high = wl - ww/2, wl + ww/2
        img_w = np.clip(img, low, high)
        img_w = (img_w - low) / (high - low)  # 0..1
        return img_w
    
    def make_palette(K: int) -> np.ndarray:
        base = np.array([
            [230,  25,  75], [ 60, 180,  75], [  0, 130, 200], [245, 130,  48],
            [145,  30, 180], [ 70, 240, 240], [240,  50, 230], [210, 245,  60],
            [250, 190, 190], [  0, 128, 128],
        ], dtype=np.float32) / 255.0
        if K <= len(base):
            return base[:K]
    
        rng = np.random.default_rng(42)
        extra = rng.random((K - len(base), 3))
        return np.vstack([base, extra])
    
    def labels_to_rgb(labels: np.ndarray, palette: np.ndarray) -> np.ndarray:
        """labels (H,W) -> rgb (H,W,3) using palette[K,3]."""
        return palette[labels]
    
    def overlay_rgb(base_gray01: np.ndarray, label_map: np.ndarray, palette: np.ndarray, alpha=0.35) -> np.ndarray:
        """base: (H,W) in [0,1], label_map: (H,W) ints; returns RGB (H,W,3)."""
        base_rgb = np.dstack([base_gray01]*3)
        color = labels_to_rgb(label_map, palette)
        # Every class is colored here, including background. To leave e.g. the
        # background class (3) uncolored, use: mask_fg = (label_map != 3)[..., None]
        mask_fg = np.ones_like(color, dtype=bool)
        blended = (1 - alpha)*base_rgb + alpha*color
        return np.where(mask_fg, blended, base_rgb)
    
    def save_gif(frames: list[np.ndarray], path: Path, fps=15):
        duration = 1.0 / fps
        imageio.mimsave(path, [f if f.dtype == np.uint8 else (np.clip(f, 0, 1)*255).astype(np.uint8) for f in frames], duration=duration)
    
    def save_mp4_matplotlib(frames: list[np.ndarray], path: Path, fps=15):
        # Uses matplotlib + ffmpeg (ffmpeg must be installed)
        import matplotlib.animation as animation
        h, w = frames[0].shape[:2]
        fig = plt.figure(figsize=(w/100, h/100), dpi=100)
        ax = plt.axes([0,0,1,1]); ax.axis('off')
        ani = animation.ArtistAnimation(fig, [[ax.imshow(fr, animated=True)] for fr in frames], interval=1000//fps, blit=True, repeat=False)
        ani.save(path, fps=fps, dpi=100)
        plt.close(fig)
    
    # ====== Load data ======
    imgs_raw = load_npy(IMG_PATH)    # (N,H,W,1) or (N,H,W)
    imgs = to_hw(imgs_raw)           # -> (N,H,W)
    N_img = imgs.shape[0]
    print("IMG:", imgs.shape, imgs.min(), imgs.max(), imgs.mean())
    
    labels = None; K = None
    if MSK_PATH.exists():
        masks_raw = load_npy(MSK_PATH)    # (N,H,W,K) or (N,H,W)
        labels, K = masks_to_labels(masks_raw)  # -> (N,H,W), K
        print("MSK labels:", labels.shape, "K=", K)
    else:
        print("Masks not found: only the image scan will be generated.")
    
    # ====== (1) Image-only scan ======
    frames_img = []
    for i in range(N_img):
        g = window_hu(imgs[i])        # 0..1
        rgb = np.dstack([g]*3)        # grayscale to RGB
        frames_img.append(rgb)
    save_gif(frames_img, OUT_DIR/"scanning_imagen.gif", fps=15)
    print("Saved:", OUT_DIR/"scanning_imagen.gif")
    
    # ====== (2) Mask-only scan ======
    if labels is not None:
        N_msk = labels.shape[0]
        pal = make_palette(K)
        frames_msk = [labels_to_rgb(labels[i], pal) for i in range(N_msk)]
        save_gif(frames_msk, OUT_DIR/"scanning_masks.gif", fps=15)
        print("Saved:", OUT_DIR/"scanning_masks.gif")
    
    # ====== (3) Overlay scan (uses the common number of slices) ======
    if labels is not None:
        N = min(N_img, labels.shape[0])
        if N < N_img or N < labels.shape[0]:
            print(f"Warning: N mismatch (IMG={N_img}, MSK={labels.shape[0]}). Using N={N}.")
        pal = make_palette(K)
        frames_overlay = []
        for i in range(N):
            base = window_hu(imgs[i])               # 0..1
            over = overlay_rgb(base, labels[i], pal, alpha=0.35)
            frames_overlay.append(np.clip(over, 0, 1))
        save_gif(frames_overlay, OUT_DIR/"scanning_overlay.gif", fps=15)
        print("Saved:", OUT_DIR/"scanning_overlay.gif")
        
</details>

The resulting overlay scan is displayed below (as an MP4 video):

In [None]:
from IPython.display import Video, display

video_path = "/kaggle/input/scanning-video/scanning.mp4"
display(Video(video_path, embed=True))


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

CLASS_NAMES = {0: "GGO", 1: "Consolidation", 2: "Lungs-other", 3: "Background"}

def ensure_4ch_masks(m):
    """Return masks as (N,H,W,4) one-hot/binary per channel.
       Accepts (N,H,W,4) or (N,H,W,2) or (N,H,W) with labels in {0..3}.
    """
    if m.ndim == 4 and m.shape[-1] == 4:
        return (m > 0).astype(np.uint8)
    if m.ndim == 4 and m.shape[-1] == 2:
        # pad to 4 channels (keep first two, zeros for 2 & 3)
        m2 = (m > 0).astype(np.uint8)
        z = np.zeros(m2.shape[:-1] + (2,), dtype=np.uint8)
        return np.concatenate([m2, z], axis=-1)
    if m.ndim == 3:
        # single-channel labels {0..3}
        N,H,W = m.shape
        out = np.zeros((N,H,W,4), dtype=np.uint8)
        for k in range(4):
            out[..., k] = (m == k).astype(np.uint8)
        return out
    raise ValueError(f"Unexpected mask shape: {m.shape}")

def per_split_stats(split_name, masks_4ch):
    """Compute a tidy DataFrame with per-class stats for one split."""
    m = ensure_4ch_masks(masks_4ch)
    N, H, W, C = m.shape
    assert C == 4, "Expecting 4 channels (0..3)"
    total_pix = H * W

    # Per-image positives per class
    pos_img = m.reshape(N, -1, C).sum(axis=1)                 # (N,C)
    any_img = pos_img > 0                                      # (N,C)

    rows = []
    for k in range(C):
        pix_share = m[..., k].mean() * 100.0                   # global pixel share %
        pct_imgs_with = any_img[:, k].mean() * 100.0           # % images with class
        pct_empty = (pos_img[:, k] == 0).mean() * 100.0        # % images empty for class

        # Per-image area % for that class
        area_pct = (pos_img[:, k] / total_pix) * 100.0         # (N,)
        area_nonzero = area_pct[area_pct > 0]
        mean_area = float(area_pct.mean()) if N else np.nan
        median_area = float(np.median(area_nonzero)) if area_nonzero.size else 0.0
        p90_area = float(np.percentile(area_nonzero, 90)) if area_nonzero.size else 0.0

        rows.append({
            "Split": split_name,
            "Class": CLASS_NAMES.get(k, f"ch{k}"),
            "Pixel Share % (global)": round(pix_share, 4),
            "% Images with class": round(pct_imgs_with, 2),
            "% Empty masks": round(pct_empty, 2),
            "Mean area % (per image)": round(mean_area, 4),
            "Median area % (non-zero imgs)": round(median_area, 4),
            "P90 area % (non-zero imgs)": round(p90_area, 4),
        })
    return pd.DataFrame(rows)

def plot_pixel_share(split_name, masks_4ch):
    m = ensure_4ch_masks(masks_4ch)
    fracs = m.mean(axis=(0,1,2)) * 100.0
    labels = [CLASS_NAMES.get(i, f"ch{i}") for i in range(m.shape[-1])]
    plt.figure()
    plt.bar(range(len(fracs)), fracs)
    plt.xticks(range(len(fracs)), labels, rotation=10)
    plt.ylabel("Pixel Share (%)")
    plt.title(f"Global pixel share per class — {split_name}")
    plt.tight_layout()
    plt.show()

def plot_images_with_class(split_name, masks_4ch):
    m = ensure_4ch_masks(masks_4ch)
    pos_img = m.reshape(m.shape[0], -1, m.shape[-1]).sum(axis=1)
    pct = (pos_img > 0).mean(axis=0) * 100.0
    labels = [CLASS_NAMES.get(i, f"ch{i}") for i in range(m.shape[-1])]
    plt.figure()
    plt.bar(range(len(pct)), pct)
    plt.xticks(range(len(pct)), labels, rotation=10)
    plt.ylabel("% of images with class")
    plt.title(f"Presence of classes — {split_name}")
    plt.tight_layout()
    plt.show()

def plot_area_histograms(split_name, masks_4ch):
    """Histograms of per-image area % for GGO(0) and Cons(1), non-zero images only."""
    m = ensure_4ch_masks(masks_4ch)
    N, H, W, _ = m.shape
    total_pix = H * W
    pos_img = m.reshape(N, -1, 4).sum(axis=1)  # (N,4)

    for k in [0, 1]:  # GGO & Consolidation
        area = (pos_img[:, k] / total_pix) * 100.0
        nz = area[area > 0]
        plt.figure()
        plt.hist(nz, bins=30)
        plt.xlabel("% of image pixels")
        plt.ylabel("Count")
        plt.title(f"Lesion area % — {split_name} — {CLASS_NAMES[k]} (non-zero images)")
        plt.tight_layout()
        plt.show()

# ---------- Build the combined stats table ----------
all_stats = []

if 'masks_medseg' in globals():
    df_med = per_split_stats("MedSeg (train)", masks_medseg)
    all_stats.append(df_med)
    plot_pixel_share("MedSeg (train)", masks_medseg)
    plot_images_with_class("MedSeg (train)", masks_medseg)
    plot_area_histograms("MedSeg (train)", masks_medseg)

if 'masks_radiopedia' in globals():
    df_rad = per_split_stats("Radiopaedia (train)", masks_radiopedia)
    all_stats.append(df_rad)
    plot_pixel_share("Radiopaedia (train)", masks_radiopedia)
    plot_images_with_class("Radiopaedia (train)", masks_radiopedia)
    plot_area_histograms("Radiopaedia (train)", masks_radiopedia)

if all_stats:
    stats_df = pd.concat(all_stats, ignore_index=True)
    display(stats_df)
    
else:
    print("No masks found. Make sure masks_medseg and/or masks_radiopedia are loaded.")


### 2.2 ETL Process and Data Decisions

Based on the data analysis, and in accordance with the competition guidelines, we made the following decisions for our ETL process:

1) Standardized images

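- What: Converted all images to float32, added a trailing channel axis so that every array has shape (N, 512, 512, 1), and min-max scaled each source array to [0,1] using its global minimum and maximum.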

- Why: Models train better when inputs share the same scale and shape. This avoids instability from mixed ranges (e.g., 0–255 vs 0–1) and simplifies batching.

2) Prepared masks for the 2 competition classes

- What: From the original 4-channel masks (GGO, Consolidation, Lungs-other, Background), we kept only GGO (0) and Consolidation (1) → masks shaped (N, 512, 512, 2), binary {0,1}.

- Why: The challenge evaluates only these two classes. Dropping the others simplifies the task and metric computation.

3) Prevented “leakage” when splitting Radiopaedia

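- What: Radiopaedia contains nine 3D studies (volumes). Since no volume IDs are provided, we inferred the boundaries from the mean absolute difference between consecutive slices, treating the 8 largest jumps as the boundaries between the 9 volumes (with a fallback to equal-sized chunks if any inferred volume has fewer than 10 slices). Then we split by complete volumes (some to train, others to validation; ~80/20).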

- Why: Slices from the same volume are very similar. Mixing them across train and validation yields overly optimistic validation (the model “recognizes” the study). Volume-wise splitting prevents information leakage.

4) Balanced MedSeg split

- What: For MedSeg, we performed an 80/20 split stratified by has_any_lesion (presence of any GGO/Consolidation). If stratification was impossible because of extreme imbalance, we used a fallback that still produces a valid validation set: a single minority-class slice is forced into validation, and if only one class exists the split is simply random.

- Why: Ensures validation is informative (not nearly all-empty or all-positive) and reflects the test distribution (the test comes from MedSeg).

5) Combined sources for training and validation

- What: Merged the resulting subsets from MedSeg and Radiopaedia to form X_train/Y_train and X_val/Y_val.

- Why: Training with diverse data and validating on a set that includes MedSeg improves generalization toward the official test set.

6) Data augmentation (on-the-fly, not saved)

- What: Defined an on-the-fly augmentation function (flips, small rotation, mild contrast/gamma) applied during training only.

- Why: Shows the model realistic variations, boosting robustness and reducing overfitting—without increasing disk usage (augmented samples are not saved).

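7) Saved artifacts

- What: Saved the train/validation arrays and the normalized MedSeg test images as compressed .npz files, the split indices as CSV files, and an info.json summary, all under /kaggle/working/prepared.

- Why: Makes the workflow reproducible and lets the model start training immediately by loading ready-to-use arrays.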
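The leakage argument in item 3 can be checked mechanically. The following is a self-contained sketch on synthetic data (the volume count, slice count, noise levels and the one-volume validation split are made-up assumptions) that applies the same "largest jumps between consecutive slices" rule used for Radiopaedia, and then verifies that a volume-wise split puts no volume on both sides. It uses plain NumPy to pick validation volumes instead of scikit-learn's `GroupShuffleSplit`, only to stay dependency-free, and it omits the small-chunk fallback.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical synthetic data: 3 volumes of 20 slices each. Slices within a volume
# drift slowly, while each volume starts from an unrelated base image.
n_volumes, n_slices = 3, 20
true_boundaries = [n_slices * v for v in range(1, n_volumes)]   # [20, 40]
slices = []
for _vol in range(n_volumes):
    base = rng.normal(0.0, 1.0, size=(32, 32))
    for _s in range(n_slices):
        base = base + rng.normal(0.0, 0.01, size=base.shape)     # small slice-to-slice change
        slices.append(base.copy())
X = np.stack(slices)                                            # (60, 32, 32)

# Same rule as the ETL: mean absolute difference between consecutive slices,
# with the K-1 largest jumps taken as volume boundaries.
K = n_volumes
diffs = np.mean(np.abs(X[1:] - X[:-1]), axis=(1, 2))             # (59,)
boundaries = np.sort(np.argsort(diffs)[-(K - 1):] + 1)
print("inferred boundaries:", boundaries.tolist())
assert boundaries.tolist() == true_boundaries

# Volume id per slice: number of boundaries at or before each slice index.
volume_id = np.searchsorted(boundaries, np.arange(len(X)), side="right")

# Volume-wise split: whole volumes go to validation, never individual slices.
val_volumes = rng.choice(K, size=1, replace=False)
is_val = np.isin(volume_id, val_volumes)
train_idx, val_idx = np.where(~is_val)[0], np.where(is_val)[0]

# Leakage check: no volume id appears on both sides of the split.
assert set(volume_id[train_idx].tolist()).isdisjoint(volume_id[val_idx].tolist())
```

The same disjointness assertion can be run on the real `volume_id`, `rad_tr_idx` and `rad_val_idx` produced by the ETL cell.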

In [None]:
# ============================================================
# COVID-CT Segmentation — Full ETL (saves to /kaggle/working/)
# ============================================================
# Steps:
#  1) Load arrays (if not already loaded) from /kaggle/input/covid-segmentation
#  2) Select classes (0=GGO, 1=Consolidation) -> masks (N,H,W,2) in {0,1}
#  3) Normalize images to float32 in [0,1], enforce (N,H,W,1)
#  4) Infer 9 Radiopaedia volumes (no metadata) via slice-to-slice diffs
#  5) Split Radiopaedia by VOLUME (~80/20)
#  6) Split MedSeg with robust stratifier by has_any_lesion (fallbacks if needed)
#  7) Combine splits (MedSeg + Radiopaedia) into train/val
#  8) Save to /kaggle/working/prepared
#  9) Provide on-the-fly augmentation function
# ============================================================

import os, json
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, GroupShuffleSplit


OUT_DIR = '/kaggle/working/prepared'
os.makedirs(OUT_DIR, exist_ok=True)

assert images_radiopedia is not None and masks_radiopedia is not None, "Radiopaedia arrays not found."
assert images_medseg is not None and masks_medseg is not None, "MedSeg arrays not found."
assert test_images_medseg is not None, "MedSeg test array not found."

# -----------------
# Helpers
# -----------------
def ensure_ch1(x: np.ndarray) -> np.ndarray:
    """(N,H,W) or (N,H,W,1) -> float32 (N,H,W,1) in [0,1]."""
    x = x.astype(np.float32)
    if x.ndim == 3:
        x = x[..., None]
    xmin, xmax = float(x.min()), float(x.max())
    if xmin < 0.0 or xmax > 1.0:
        x = (x - xmin) / (xmax - xmin + 1e-8)
    return x

def masks_to_two_classes(m: np.ndarray) -> np.ndarray:
    """Keep only channels 0=GGO and 1=Consolidation -> (N,H,W,2) in {0,1}."""
    if m.ndim == 4:
        m = (m > 0).astype(np.uint8)
        if m.shape[-1] >= 2:
            return m[..., :2]
        raise ValueError(f"Mask has {m.shape[-1]} channels; expected >=2.")
    if m.ndim == 3:  # label-encoded {0..3}
        N,H,W = m.shape
        out = np.zeros((N,H,W,2), dtype=np.uint8)
        out[...,0] = (m == 0).astype(np.uint8)  # GGO
        out[...,1] = (m == 1).astype(np.uint8)  # Consolidation
        return out
    raise ValueError(f"Unexpected mask shape: {m.shape}")

def has_any_lesion(m2: np.ndarray) -> np.ndarray:
    """Per-slice boolean: any positive pixel in either class."""
    return (m2.reshape(m2.shape[0], -1, 2).sum(axis=1).sum(axis=1) > 0).astype(np.uint8)

def robust_medseg_split(has_lesion: np.ndarray, test_size=0.2, random_state=42):
    """
    Returns train_idx, val_idx for MedSeg even if stratification is impossible.
    Strategy:
      1) If both classes exist and each has >=2 samples -> stratified split.
      2) If both classes exist but one has only 1 sample -> force that sample in val; fill rest from majority.
      3) If only one class exists -> plain random split.
    """
    N = len(has_lesion)
    idx = np.arange(N)
    rng = np.random.default_rng(random_state)

    classes, counts = np.unique(has_lesion, return_counts=True)
    count_map = dict(zip(classes.tolist(), counts.tolist()))
    n_pos = count_map.get(1, 0)
    n_neg = count_map.get(0, 0)
    n_val = max(1, int(round(test_size * N)))

    # Case 1: stratified possible
    if n_pos >= 2 and n_neg >= 2:
        tr_idx, va_idx = train_test_split(
            idx, test_size=test_size, random_state=random_state, stratify=has_lesion
        )
        return tr_idx, va_idx

    pos_idx = idx[has_lesion == 1]
    neg_idx = idx[has_lesion == 0]

    # Case 2: both classes exist but one has only 1 sample
    if (n_pos == 1 and n_neg >= 1) or (n_neg == 1 and n_pos >= 1):
        if n_neg == 1:
            must_have = neg_idx
            pool = pos_idx
        else:
            must_have = pos_idx
            pool = neg_idx
        rest = max(0, n_val - len(must_have))
        take = rng.choice(pool, size=min(rest, len(pool)), replace=False) if rest > 0 else np.array([], dtype=int)
        va_idx = np.unique(np.concatenate([must_have, take]))
        if len(va_idx) < n_val:
            remaining = np.setdiff1d(idx, va_idx, assume_unique=False)
            top_up = rng.choice(remaining, size=n_val - len(va_idx), replace=False)
            va_idx = np.unique(np.concatenate([va_idx, top_up]))
        tr_idx = np.setdiff1d(idx, va_idx, assume_unique=False)
        return tr_idx, va_idx

    # Case 3: only one class exists -> random split
    va_idx = rng.choice(idx, size=n_val, replace=False)
    tr_idx = np.setdiff1d(idx, va_idx, assume_unique=False)
    return tr_idx, va_idx

def infer_radiopaedia_volume_ids(images_: np.ndarray, K: int = 9):
    """
    Infer K volume blocks from slice-to-slice diffs (no metadata).
    Returns:
      volume_id: (N,) int32 in {0..K-1}
      ranges: list of (start, end) indices per volume
      diffs: (N-1,) mean abs diffs used for boundary detection
    """
    X = images_
    if X.ndim == 4: X = X[...,0]
    X = X.astype(np.float32)
    N = len(X)
    diffs = np.array([np.mean(np.abs(X[i] - X[i-1])) for i in range(1, N)], dtype=np.float32)
    num_boundaries = K - 1
    # If N is small or diffs flat, fallback to equal chunks
    if N <= K or np.allclose(diffs, diffs[0]):
        chunks = np.array_split(np.arange(N), K)
        ranges = [(int(c[0]), int(c[-1])+1) for c in chunks]
    else:
        boundary_pos = np.argsort(diffs)[-num_boundaries:] + 1
        boundary_pos = np.sort(boundary_pos)
        starts = np.concatenate([[0], boundary_pos])
        ends   = np.concatenate([boundary_pos, [N]])
        ranges = list(zip(starts, ends))
        # guard against tiny chunks
        if any((e - s) < 10 for s, e in ranges):
            chunks = np.array_split(np.arange(N), K)
            ranges = [(int(c[0]), int(c[-1])+1) for c in chunks]
    vol_id = np.empty(N, dtype=np.int32)
    for vid, (s, e) in enumerate(ranges):
        vol_id[s:e] = vid
    return vol_id, ranges, diffs

# -----------------
# 1) Normalize & select classes
# -----------------
X_med = ensure_ch1(images_medseg)               # (100,512,512,1) in [0,1]
Y_med = masks_to_two_classes(masks_medseg)      # (100,512,512,2)

X_rad = ensure_ch1(images_radiopedia)           # (829,512,512,1)
Y_rad = masks_to_two_classes(masks_radiopedia)  # (829,512,512,2)

X_test = ensure_ch1(test_images_medseg)         # (10,512,512,1)

# -----------------
# 2) Radiopaedia: split by inferred volumes (K=9)
# -----------------
volume_id, vol_ranges, _diffs = infer_radiopaedia_volume_ids(images_radiopedia, K=9)
gss = GroupShuffleSplit(test_size=0.2, n_splits=1, random_state=42)
(rad_tr_idx, rad_val_idx), = gss.split(np.arange(len(X_rad)), groups=volume_id)

# -----------------
# 3) MedSeg: robust stratified split by has_any_lesion (80/20)
# -----------------
y_med_has = has_any_lesion(Y_med)
med_tr_idx, med_val_idx = robust_medseg_split(y_med_has, test_size=0.2, random_state=42)

# -----------------
# 4) Combine splits
# -----------------
X_train = np.concatenate([X_med[med_tr_idx], X_rad[rad_tr_idx]], axis=0)
Y_train = np.concatenate([Y_med[med_tr_idx], Y_rad[rad_tr_idx]], axis=0)

X_val   = np.concatenate([X_med[med_val_idx], X_rad[rad_val_idx]], axis=0)
Y_val   = np.concatenate([Y_med[med_val_idx], Y_rad[rad_val_idx]], axis=0)

print("Train shapes:", X_train.shape, Y_train.shape)
print("Val   shapes:", X_val.shape,   Y_val.shape)
val_sources = (["MedSeg"] * len(med_val_idx)) + (["Radiopaedia"] * len(rad_val_idx))
print("Validation source composition:\n", pd.Series(val_sources).value_counts(normalize=True))

# -----------------
# 5) Save prepared artifacts to /kaggle/working/prepared
# -----------------
np.savez_compressed(os.path.join(OUT_DIR, "train_arrays.npz"), X=X_train, Y=Y_train)
np.savez_compressed(os.path.join(OUT_DIR, "val_arrays.npz"),   X=X_val,   Y=Y_val)
np.savez_compressed(os.path.join(OUT_DIR, "test_medseg.npz"),  X=X_test)

pd.DataFrame({"rad_train_idx": rad_tr_idx}).to_csv(os.path.join(OUT_DIR, "radiopaedia_train_idx.csv"), index=False)
pd.DataFrame({"rad_val_idx":   rad_val_idx}).to_csv(os.path.join(OUT_DIR, "radiopaedia_val_idx.csv"),   index=False)
pd.DataFrame({"med_train_idx": med_tr_idx}).to_csv(os.path.join(OUT_DIR, "medseg_train_idx.csv"),       index=False)
pd.DataFrame({"med_val_idx":   med_val_idx}).to_csv(os.path.join(OUT_DIR, "medseg_val_idx.csv"),        index=False)

with open(os.path.join(OUT_DIR, "info.json"), "w") as f:
    json.dump({
        "images_shape": list(X_train.shape[1:]),
        "masks_shape":  list(Y_train.shape[1:]),
        "train_size":   int(len(X_train)),
        "val_size":     int(len(X_val)),
        "class_map":    {"0": "GGO", "1": "Consolidation"},
        "notes": "Images normalized to [0,1]; masks are binary (2 channels). Radiopaedia split by inferred volumes (9)."
    }, f, indent=2)

print(f"Saved artifacts to: {OUT_DIR}")
print("Files:", sorted(os.listdir(OUT_DIR)))

# -----------------
# 6) On-the-fly augmentations (use inside your training loop)
# -----------------
import cv2
_rng = np.random.default_rng(123)

def augment_pair(img: np.ndarray, mask: np.ndarray):
    """
    img:  (H,W,1) float32 in [0,1]
    mask: (H,W,2) uint8 {0,1}
    returns augmented (img, mask)
    """
    H, W = img.shape[:2]

    # Horizontal flip (p=0.5)
    if _rng.random() < 0.5:
        img  = np.flip(img, axis=1).copy()
        mask = np.flip(mask, axis=1).copy()

    # Vertical flip (p=0.2)
    if _rng.random() < 0.2:
        img  = np.flip(img, axis=0).copy()
        mask = np.flip(mask, axis=0).copy()

    # Small rotation (-12..+12 deg)
    angle = float(_rng.uniform(-12, 12))
    M = cv2.getRotationMatrix2D((W/2, H/2), angle, 1.0)
    img  = cv2.warpAffine(img, M, (W, H), flags=cv2.INTER_LINEAR,  borderMode=cv2.BORDER_REFLECT)
    m0   = cv2.warpAffine(mask[...,0].astype(np.uint8), M, (W, H), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_REFLECT)
    m1   = cv2.warpAffine(mask[...,1].astype(np.uint8), M, (W, H), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_REFLECT)
    mask = np.stack([m0, m1], axis=-1)

    # Gamma jitter (a mild, nonlinear contrast change)
    gamma = float(_rng.uniform(0.85, 1.2))
    img = np.clip(img, 0, 1) ** gamma

    # Ensure types and dims
    if img.ndim == 2:
        img = img[..., None]
    return img.astype(np.float32), (mask > 0.5).astype(np.uint8)

print("augment_pair(img, mask) ready.")


In [None]:
import numpy as np
import pandas as pd

def has_any_lesion_bin(Y):
    # Y: (N,H,W,2) uint8 {0,1}
    return (Y.reshape(Y.shape[0], -1, 2).sum(axis=1).sum(axis=1) > 0)

print("=== SHAPES & DTYPES ===")
print(f"X_train: {X_train.shape}  dtype={X_train.dtype}  range=({X_train.min():.3f},{X_train.max():.3f})")
print(f"Y_train: {Y_train.shape}  dtype={Y_train.dtype}  values={{0,1}}")
print(f"X_val  : {X_val.shape}    dtype={X_val.dtype}    range=({X_val.min():.3f},{X_val.max():.3f})")
print(f"Y_val  : {Y_val.shape}    dtype={Y_val.dtype}    values={{0,1}}")

print("\n=== PER-SET SLICE COUNTS ===")
print(f"Train slices: {len(X_train)}")
print(f"Val   slices: {len(X_val)}")

print("\n=== LESION PRESENCE (has_any_lesion) ===")
train_has = has_any_lesion_bin(Y_train)
val_has   = has_any_lesion_bin(Y_val)
print("Train: ",
      pd.Series(train_has.astype(int)).map({0:"no-lesion",1:"lesion"}).value_counts().to_dict())
print("Val  : ",
      pd.Series(val_has.astype(int)).map({0:"no-lesion",1:"lesion"}).value_counts().to_dict())

try:
    train_sources = (["MedSeg"]*len(med_tr_idx)) + (["Radiopaedia"]*len(rad_tr_idx))
    val_sources   = (["MedSeg"]*len(med_val_idx)) + (["Radiopaedia"]*len(rad_val_idx))
    print("\n=== SOURCE COMPOSITION ===")
    print("Train:", pd.Series(train_sources).value_counts().to_dict())
    print("Val  :", pd.Series(val_sources).value_counts().to_dict())
except NameError:
    print("\n(Source indices not available in this scope.)")

print("\n=== QUICK SANITY CHECKS ===")
# Check a few random slices for non-empty masks and per-class pixel share
rng = np.random.default_rng(0)
for name, Xs, Ys in [("TRAIN", X_train, Y_train), ("VAL", X_val, Y_val)]:
    idx = rng.choice(len(Xs), size=min(3, len(Xs)), replace=False)
    for i in idx:
        ggo_px  = int(Ys[i,...,0].sum())
        cons_px = int(Ys[i,...,1].sum())
        total   = Ys[i].shape[0]*Ys[i].shape[1]
        print(f"{name} idx={i:4d} | GGO={ggo_px/total*100:.2f}%  Cons={cons_px/total*100:.2f}%  any={ggo_px+cons_px>0}")
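The lesion-presence check and the per-class pixel shares printed above can also be computed in a fully vectorized way. The sketch below reuses the notebook's `has_any_lesion_bin` logic and compares it with a boolean `any` formulation on a tiny synthetic mask array; `has_any_lesion_any` and `class_pixel_share` are hypothetical helpers, not part of the notebook.

```python
import numpy as np

def has_any_lesion_bin(Y):
    # Same logic as the notebook helper: sum all pixels of both channels per slice.
    return Y.reshape(Y.shape[0], -1, 2).sum(axis=1).sum(axis=1) > 0

def has_any_lesion_any(Y):
    # Alternative: True if any pixel in any channel of the slice is set.
    return Y.astype(bool).any(axis=(1, 2, 3))

def class_pixel_share(Y):
    # Fraction of lesion pixels per slice and per class, shape (N, 2).
    return Y.reshape(Y.shape[0], -1, Y.shape[-1]).mean(axis=1)

# Synthetic (N, H, W, 2) masks: slice 1 has 4 GGO pixels, slice 3 one consolidation pixel.
Y = np.zeros((5, 8, 8, 2), dtype=np.uint8)
Y[1, 2:4, 2:4, 0] = 1
Y[3, 0, 0, 1] = 1

print(has_any_lesion_bin(Y))
print(class_pixel_share(Y)[1])
```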


In [None]:
import os, numpy as np, matplotlib.pyplot as plt

def _get_arrays():
    if "X_train" in globals() and "Y_train" in globals() and "X_val" in globals() and "Y_val" in globals():
        return X_train, Y_train, X_val, Y_val
    prep = "/kaggle/working/prepared"
    tr = np.load(os.path.join(prep, "train_arrays.npz"))
    va = np.load(os.path.join(prep, "val_arrays.npz"))
    return tr["X"], tr["Y"], va["X"], va["Y"]

X_tr, Y_tr, X_va, Y_va = _get_arrays()

def overlay_masks(img_hw1, mask_hw2, alpha=0.35):
    """
    img_hw1: (H,W,1) float32 [0,1]
    mask_hw2: (H,W,2) uint8 {0,1}  [0]=GGO (green), [1]=Consolidation (red)
    return: (H,W,3) float32 [0,1]
    """
    img = np.clip(img_hw1[..., 0], 0, 1)
    H, W = img.shape
    rgb = np.stack([img, img, img], axis=-1)

    ggo  = mask_hw2[..., 0] > 0
    cons = mask_hw2[..., 1] > 0

    out = rgb.copy()
    out[ggo]  = out[ggo]  * (1 - alpha) + np.array([0.0, 1.0, 0.0]) * alpha
    out[cons] = out[cons] * (1 - alpha) + np.array([1.0, 0.0, 0.0]) * alpha
    return np.clip(out, 0, 1)

def show_grid(X, Y, title="SET", n=8, seed=0, cols=4):
    n = min(n, len(X))
    idxs = np.random.default_rng(seed).choice(len(X), size=n, replace=False)
    rows = int(np.ceil(n / cols))
    plt.figure(figsize=(4*cols, 4*rows))
    for k, i in enumerate(idxs, 1):
        ov = overlay_masks(X[i], Y[i], alpha=0.35)
        ggo_px  = int(Y[i, ..., 0].sum())
        cons_px = int(Y[i, ..., 1].sum())
        total   = Y[i].shape[0] * Y[i].shape[1]
        plt.subplot(rows, cols, k)
        plt.imshow(ov)
        plt.title(f"{title} idx={i} | GGO {100*ggo_px/total:.1f}%  Cons {100*cons_px/total:.1f}%")
        plt.axis("off")
    plt.tight_layout()
    plt.show()

show_grid(X_tr, Y_tr, title="TRAIN", n=8, seed=0, cols=4)
show_grid(X_va, Y_va, title="VAL",   n=8, seed=1, cols=4)
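`overlay_masks` colors lesions with a convex combination, `(1 - alpha) * pixel + alpha * color`. Because GGO and consolidation are not mutually exclusive, a pixel labeled with both classes is blended twice (green first, then red), so red dominates where they overlap. The small sketch below checks this arithmetic on a single mid-gray pixel; `blend`, `GREEN`, and `RED` are illustrative names, not notebook code.

```python
import numpy as np

GREEN = np.array([0.0, 1.0, 0.0])
RED = np.array([1.0, 0.0, 0.0])

def blend(rgb, color, alpha=0.35):
    # Convex combination used by overlay_masks: (1 - alpha) * pixel + alpha * color
    return rgb * (1 - alpha) + color * alpha

gray = np.full(3, 0.5)
ggo_only = blend(gray, GREEN)
both = blend(blend(gray, GREEN), RED)   # GGO is painted first, then consolidation on top
print(ggo_only, both)
```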


During ETL we never change the spatial grid or the pairing logic; only the mask channels are modified:

- **Same index, same slice:** We construct X_* (images) and Y_* (masks) using the same slice indices for each split/source. Thus X_train[i] and Y_train[i] (and likewise for validation) come from the exact same CT slice.

- **Same resolution and geometry:** Both images and masks keep the original 512×512 grid. We do not resample, crop, or pad in the ETL; the only operation on masks is channel selection (keep 0=GGO, 1=Consolidation; drop the other two). This preserves all pixel coordinates.

- **No independent shuffling:** We split by indices (and, for Radiopaedia, by inferred volume groups) and then apply those indices simultaneously to images and masks. There is never a reordering of images without applying the same reordering to masks.

- **Identical preprocessing path:** Images are normalized to [0,1] and coerced to (H,W,1); masks are cast to binary (H,W,2) by channel slicing. No geometric transforms occur in ETL, so there is no chance to misalign image and mask.
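The pairing guarantees above boil down to two rules: split by group (volume) rather than by slice, and index images and masks with the same index array. A minimal sketch under those assumptions follows; the ETL's actual grouping logic for Radiopaedia volumes is not reproduced here, and `group_split` and the toy arrays are hypothetical.

```python
import numpy as np

def group_split(groups, val_fraction=0.2, seed=0):
    """Hypothetical helper: return (train_idx, val_idx) such that no group
    (e.g. one CT volume) contributes slices to both splits."""
    groups = np.asarray(groups)
    uniq = np.unique(groups)
    rng = np.random.default_rng(seed)
    n_val = max(1, int(round(val_fraction * len(uniq))))
    val_groups = rng.choice(uniq, size=n_val, replace=False)
    is_val = np.isin(groups, val_groups)
    return np.flatnonzero(~is_val), np.flatnonzero(is_val)

# Toy data: 10 slices from 4 volumes. Each image and mask stores its own slice
# index as its value, so any misalignment after indexing would be visible.
groups = np.array([0, 0, 0, 1, 1, 2, 2, 2, 3, 3])
X = np.arange(10, dtype=np.float32).reshape(10, 1, 1, 1)
Y = np.repeat(np.arange(10, dtype=np.uint8).reshape(10, 1, 1, 1), 2, axis=-1)

tr_idx, va_idx = group_split(groups, val_fraction=0.25, seed=0)
img_tr, msk_tr = X[tr_idx], Y[tr_idx]   # one index array drives both images and masks
img_va, msk_va = X[va_idx], Y[va_idx]
print("train volumes:", np.unique(groups[tr_idx]), "val volumes:", np.unique(groups[va_idx]))
```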

## 3. Model

### 3.1 Why U-Net for this task

After considering three segmentation models (U-Net, DeepLab, and SegNet), we decided to use U-Net, designed in 2015 by O. Ronneberger, P. Fischer, and T. Brox at the University of Freiburg, Germany. It follows a symmetric encoder–decoder (“U-shaped”) design with skip connections that copy high-resolution features from the encoder to the decoder, enabling precise boundary localization even with limited training data. U-Net was originally proposed for biomedical image segmentation, a problem similar to the one in this competition, and it offers several advantages:

- **Built for medical imaging:** U-Net was designed for biomedical segmentation and is known to work well with small labeled datasets by leveraging strong skip connections and data augmentation. 

- **Simple, fast, reliable:** It’s easy to train end-to-end, delivers sharp boundaries via encoder–decoder skips, and has a very large, supportive community/ecosystem. 

- **Good fit for our data scale:** Compared with heavier designs (e.g., DeepLab), plain U-Net typically attains strong performance on limited medical data with modest compute.


### 3.2 Our architecture configuration

1) Input & preprocessing

    - What: CT input (1, H, W) scaled to [0,1]; train at 256×256 (from ETL). Validation/inference can be 512×512 if the model/VRAM allow.
    
    - Why: 256×256 speeds training and fits VRAM; keep a consistent pipeline across train/val/test.
    
    - Alternatives: Train directly at 512×512 (more detail, smaller batch) or use patching.

2) Depth (4-down / 4-up)

    - What: 4 encoder stages with 2×2 max-pool (down to 1/16 res) and 4 decoder stages with skip connections.
    
    - Why: Good trade-off between context and detail for 256×256 / 512×512 with moderate VRAM.
    
    - Alternatives: 3-down (lighter, less context) or 5-down (more capacity/VRAM).

3) Channel width (64 → 1024 → 64)

    - What: Stage widths: 64, 128, 256, 512; bottleneck 1024; mirrored in decoder.
    
    - Why: The classic and stable recipe for small medical datasets.
    
    - Alternatives: Halved widths (32–512) to save VRAM; larger widths if data/compute allow.

4) Basic block: Conv3×3 + BatchNorm + ReLU (×2 per stage)

    - What: Two 3×3, stride=1, padding=1 convolutions with BN and ReLU at each level.
    
    - Why: Proven pattern—stable training and sharp boundaries.
    
    - Alternatives: Instance/GroupNorm (very small batches), LeakyReLU/SiLU, residual (ResU-Net).

5) Downsampling: MaxPool 2×2 (stride 2)

    - What: Resolution reduction after each encoder block (except the bottom).
    
    - Why: Simple, robust; preserves scale semantics.
    
    - Alternatives: Strided conv (s=2) or AvgPool.

6) Upsampling: Bilinear ×2, concat skip, then DoubleConv (two Conv3×3)

    - What: ×2 interpolation, concatenate encoder skip, then two Conv+BN+ReLU.
    
    - Why: Avoids checkerboard artifacts, lightweight, works very well.
    
    - Alternatives: TransposedConv 2×2 (learned upsampling; slightly heavier).

7) Output head: Conv1×1 → 2 logits, Sigmoid at inference

    - What: Map 64 channels to 2 (GGO, Consolidation); apply per-channel Sigmoid at inference.
    
    - Why: This is a per-pixel multilabel problem (classes are not mutually exclusive).
    
    - Alternatives: Softmax for mutually exclusive classes (not our case).

8) Loss function: BCEWithLogits + Dice (mean over channels)

    - What: Combine BCE (on logits) with Dice.
    
    - Why: Dice copes with class imbalance; BCE stabilizes gradients.
    
    - Alternatives: Tversky (α,β) to bias FN/FP, Focal (often with Dice) for many negatives.

9) Optimizer & learning rate

    - What: AdamW (lr 1e-3, weight decay 1e-4), cosine annealing LR (T_max = epochs).
    
    - Why: Fast convergence and good generalization; cosine is simple and effective.
    
    - Alternatives: Adam (no WD), ReduceLROnPlateau (patience 5, factor 0.5), SGD+momentum (better with lots of data).
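The depth and width choices above fix the feature-map sizes at every stage. The hypothetical helper below walks through them for a given input size, plus the channel count entering each decoder block after concatenation. It assumes the input side is divisible by 2**depth; the notebook's `Up` module pads instead, so that restriction belongs to this sketch only.

```python
def unet_shapes(h, w, base=64, depth=4):
    """(stage, channels, height, width) per stage for a U-Net with 2x2 max-pooling
    and channel doubling, as configured above. Assumes h and w divisible by 2**depth."""
    factor = 2 ** depth
    if h % factor or w % factor:
        raise ValueError(f"input size must be divisible by {factor}")
    enc = [(f"enc{i}", base * 2 ** i, h // 2 ** i, w // 2 ** i) for i in range(depth)]
    mid = [("bottleneck", base * factor, h // factor, w // factor)]
    dec = [(f"dec{i}", base * 2 ** i, h // 2 ** i, w // 2 ** i) for i in reversed(range(depth))]
    return enc + mid + dec

def decoder_in_channels(base=64, depth=4):
    """Channels entering each decoder DoubleConv after concatenation: upsampled + skip."""
    return [base * 2 ** (i + 1) + base * 2 ** i for i in reversed(range(depth))]

for stage in unet_shapes(256, 256):
    print(stage)
print("decoder inputs:", decoder_in_channels())
```

For a 256×256 input the bottleneck is 16×16 with 1024 channels (1/16 resolution), and the decoder inputs 1536, 768, 384, 192 match `Up(1024+512, ...)` through `Up(128+64, ...)` in the model code.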
  
### 3.3 Regularization Techniques

In our implementation, we applied weight decay through PyTorch's AdamW optimizer:

`optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)`

Here lr = 1e-3 is the (initial) learning rate and weight_decay = 1e-4 is the strength of the weight penalty. Weight decay discourages large weights, keeping the parameters small and stable and reducing the risk of overfitting. Note that AdamW uses *decoupled* weight decay: instead of adding an L2 (ridge-style) penalty (λ/2)‖w‖² to the loss, it shrinks each weight directly by a factor of (1 − lr·λ) at every step. For plain SGD the two are equivalent, but under Adam's adaptive gradient scaling they are not, which is the motivation behind AdamW.

The learning rate controls the size of each weight update. lr = 1e-3 is the default value in PyTorch's AdamW, and we kept it; the cosine schedule then decays it toward zero during training. A higher learning rate would speed up training but risk making it unstable, while a lower one would be more stable but slower to converge.

We chose weight_decay = 1e-4; this parameter controls the strength of the weight penalty. A smaller value (1e-5) would give weaker regularization with little visible change, while a larger value (1e-3) would apply stronger penalties that negatively affected the model. We tuned this parameter so that it reduced overfitting without over-penalizing the weights; this balance helped stabilize training and improve validation performance in our U-Net.
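The difference between an L2 penalty folded into the gradient and AdamW-style decoupled weight decay can be made concrete with a single scalar update. The sketch below is an illustration, not PyTorch's implementation. It assumes a zero loss gradient so only the regularizer acts, and applies the decay before the Adam step, analogous to PyTorch's AdamW. `adam_step` is a hypothetical helper.

```python
import math

def adam_step(w, g, lr, beta1=0.9, beta2=0.999, eps=1e-8, t=1, m=0.0, v=0.0):
    # One scalar Adam update with bias correction; returns (new_w, m, v).
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    return w - lr * m_hat / (math.sqrt(v_hat) + eps), m, v

lr, wd, w0 = 1e-3, 1e-4, 1.0
loss_grad = 0.0   # isolate the regularizer: assume the loss gradient is zero

# (a) L2 penalty folded into the gradient, then normalized by Adam's sqrt(v_hat)
w_l2, _, _ = adam_step(w0, loss_grad + wd * w0, lr)

# (b) Decoupled decay: shrink the weight first, then apply Adam to the loss gradient only
w_adamw, _, _ = adam_step(w0 * (1 - lr * wd), loss_grad, lr)

print(f"L2-in-Adam: {w_l2:.7f}   decoupled (AdamW-style): {w_adamw:.7f}")
```

With L2 inside Adam, the tiny penalty gradient gets normalized to a step of roughly `lr`, so the weight shrinks far more than the intended `lr * wd`. Decoupled decay keeps the shrink proportional to `lr * wd`, independent of gradient scale.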

We also use the data augmentation shown in the previous blocks (`augment_pair`), which acts as a regularizer: by generating artificial variants of the training slices, it increases data diversity and helps prevent overfitting. Our augmentations are horizontal flips (p = 0.5), vertical flips (p = 0.2), small rotations (±12°), and gamma (contrast) jitter in [0.85, 1.2]. The geometric transforms are applied jointly to image and mask, so the labels stay aligned.

## 4. Training

### 4.1 Hardware

Our model was trained with the built-in Kaggle resources:
- CPU: Intel Xeon 2.20 GHz, 4 vCPU cores.
- RAM: 32 GB.
- GPU: NVIDIA Tesla P100, 3584 CUDA cores, 16 GB of memory.

With these resources, a full training run took approximately 2 hours.

### 4.2 U-Net configuration


- SEED: 42

- Expected shapes: img: (H,W,1), mask: (H,W,2)

- Augmentations: augment_pair method (declared above).

- Batch size: 8

- num_workers: 2

- Input channels: 1

- Classes: 2

- Depth: classic 4 downs / 4 ups

- Down path: 64 → 128 → 256 → 512 → bottleneck 1024

- Up path with skip-concat: (1024+512)→512, (512+256)→256, (256+128)→128, (128+64)→64

- Upsampling: bilinear upsample (align_corners=False)

- Blocks: Conv(3×3, bias=False) + BatchNorm2d + ReLU ×2 per block

- Training loss: BCEWithLogitsLoss + soft Dice (per-class, eps 1e-7)

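- Optimizer: AdamW(lr=1e-3, weight_decay=1e-4); betas and eps left at the PyTorch defaults, (0.9, 0.999) and 1e-8

- LR schedule: CosineAnnealingLR(T_max=60), stepped once per epoch

- Mixed precision: torch.amp autocast (float16) with GradScaler, CUDA only

- Epochs: 60

- Early stop: not enabled; the model trains for all 60 epochs and the checkpoint with the best validation Dice is kept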
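The cosine schedule used with AdamW has a simple closed form. PyTorch's `CosineAnnealingLR` documents it (with `eta_min` defaulting to 0). It computes the values recursively, but for an unmodified run they should match the sketch below. `cosine_lr` is a hypothetical helper with this configuration's base rate and `T_max`.

```python
import math

def cosine_lr(epoch, base_lr=1e-3, t_max=60, eta_min=0.0):
    # Cosine annealing without restarts:
    # lr_t = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

for e in (0, 15, 30, 45, 59, 60):
    print(f"after {e:2d} scheduler steps: lr = {cosine_lr(e):.2e}")
```

Because the training loop calls `scheduler.step()` before printing, the `lr=` value logged at epoch k is the rate that will be used for epoch k + 1.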

In [None]:
#%%skip_if True
# ============================
# U-Net (2D) + Training & Eval
# ============================
import os, math, time, random, contextlib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# --------------------
# Reproducibility
# --------------------
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# --------------------
# Device & AMP (new API; only if CUDA)
# --------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
USE_CUDA = (device.type == "cuda")
print("Device:", device)

if USE_CUDA:
    scaler = torch.amp.GradScaler('cuda')
    autocast_ctx = lambda: torch.amp.autocast('cuda', dtype=torch.float16)
else:
    scaler = None
    autocast_ctx = contextlib.nullcontext  # no-op on CPU

# --------------------
# Dataset (uses your augment_pair from ETL if provided)
# --------------------
class NumpySegDataset(Dataset):
    def __init__(self, X, Y=None, train=False, apply_aug_fn=None):
        self.X = X; self.Y = Y
        self.train = train; self.apply_aug_fn = apply_aug_fn
    def __len__(self): return len(self.X)
    def __getitem__(self, i):
        img = self.X[i]  # (H,W,1)
        if self.Y is None:
            if self.apply_aug_fn is not None and self.train:
                img, _ = self.apply_aug_fn(img, np.zeros((*img.shape[:2], 2), dtype=np.uint8))
            x = torch.from_numpy(img).permute(2,0,1).float()
            return x
        mask = self.Y[i] # (H,W,2)
        if self.apply_aug_fn is not None and self.train:
            img, mask = self.apply_aug_fn(img, mask)
        x = torch.from_numpy(img).permute(2,0,1).float()
        y = torch.from_numpy(mask).permute(2,0,1).float()
        return x, y

# --------------------
# U-Net (classic, 4-down/4-up, 64..1024)
# --------------------
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
    def forward(self, x): return self.block(x)

class Down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.conv = DoubleConv(in_ch, out_ch)
    def forward(self, x): return self.conv(self.pool(x))

class Up(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
            self.conv = DoubleConv(in_ch, out_ch)
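        else:
            # Learned upsampling. In UNet2D's channel layout the skip has out_ch channels,
            # so the tensor being upsampled (x1) has in_ch - out_ch channels.
            up_ch = in_ch - out_ch
            self.up = nn.ConvTranspose2d(up_ch, up_ch, 2, stride=2)
            self.conv = DoubleConv(in_ch, out_ch)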
    def forward(self, x1, x2):
        x1 = self.up(x1)
        dy = x2.size(2) - x1.size(2)
        dx = x2.size(3) - x1.size(3)
        x1 = F.pad(x1, [dx//2, dx - dx//2, dy//2, dy - dy//2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)
    def forward(self, x): return self.conv(x)

class UNet2D(nn.Module):
    def __init__(self, n_channels=1, n_classes=2, bilinear=True):
        super().__init__()
        self.inc   = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        self.up1   = Up(1024+512, 512, bilinear)
        self.up2   = Up(512+256, 256, bilinear)
        self.up3   = Up(256+128, 128, bilinear)
        self.up4   = Up(128+64,  64,  bilinear)
        self.outc  = OutConv(64, n_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight); nn.init.zeros_(m.bias)
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x  = self.up1(x5, x4)
        x  = self.up2(x,  x3)
        x  = self.up3(x,  x2)
        x  = self.up4(x,  x1)
        return self.outc(x)  # (B,2,H,W)

# --------------------
# Loss: BCEWithLogits + Dice
# --------------------
class BCEDiceLoss(nn.Module):
    def __init__(self, eps=1e-7):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.eps = eps
    def dice_soft(self, logits, targets):
        probs = torch.sigmoid(logits)
        dims = (0,2,3)
        inter = (probs * targets).sum(dims)
        den   = (probs + targets).sum(dims).clamp_min(self.eps)
        dice  = (2*inter) / den
        return 1 - dice.mean()
    def forward(self, logits, targets):
        return self.bce(logits, targets) + self.dice_soft(logits, targets)

# --------------------
# DataLoaders (hooks your arrays & augment_pair)
# --------------------
try:
    X_train, Y_train, X_val, Y_val, X_test  # noqa
except NameError:
    prep = "/kaggle/working/prepared"
    tr = np.load(os.path.join(prep, "train_arrays.npz"))
    va = np.load(os.path.join(prep, "val_arrays.npz"))
    te = np.load(os.path.join(prep, "test_medseg.npz"))
    X_train, Y_train = tr["X"], tr["Y"]
    X_val,   Y_val   = va["X"], va["Y"]
    X_test          = te["X"]
    print("Loaded arrays from prepared/.")

if "augment_pair" not in globals():
    def augment_pair(img, mask): return img, mask  # identity

BATCH = 8
train_ds = NumpySegDataset(X_train, Y_train, train=True,  apply_aug_fn=augment_pair)
val_ds   = NumpySegDataset(X_val,   Y_val,   train=False, apply_aug_fn=None)
pin = USE_CUDA  # only pin on CUDA
train_dl = DataLoader(train_ds, batch_size=BATCH, shuffle=True,  num_workers=2, pin_memory=pin)
val_dl   = DataLoader(val_ds,   batch_size=BATCH, shuffle=False, num_workers=2, pin_memory=pin)

# --------------------
# Model, optimizer, schedule
# --------------------
model = UNet2D(n_channels=1, n_classes=2, bilinear=True).to(device)
criterion = BCEDiceLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
EPOCHS = 60
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

# --------------------
# Metrics
# --------------------
@torch.no_grad()
def dice_score(logits, targets, thr=0.5, eps=1e-7):
    probs = torch.sigmoid(logits)
    preds = (probs > thr).float()
    dims = (0,2,3)
    inter = (preds * targets).sum(dims)
    den   = (preds + targets).sum(dims).clamp_min(eps)
    dice_per_ch = (2*inter) / den
    return dice_per_ch.detach().cpu().numpy(), float(dice_per_ch.mean().item())

# =====================
# Metrics + Viz helpers
# =====================
import copy, matplotlib.pyplot as plt, pandas as pd

THR = 0.45  # threshold for binarizing probs (tune later)

@torch.no_grad()
def evaluate_epoch(model, loader, criterion, thr=0.5, device=device, use_cuda=USE_CUDA):
    """
    Returns:
      loss_avg, dice_mean, acc_mean, conf (dict with 2x2 matrices per class)
    Conf matrices are in sklearn order: [[TN, FP],[FN, TP]]
    """
    model.eval()
    n_pix = 0
    loss_sum = 0.0

    # Aggregates for Dice
    inter_sum = np.zeros(2, dtype=np.float64)
    den_sum   = np.zeros(2, dtype=np.float64)

    # Aggregates for accuracy & confusion
    TP = np.zeros(2, dtype=np.float64)
    FP = np.zeros(2, dtype=np.float64)
    FN = np.zeros(2, dtype=np.float64)
    TN = np.zeros(2, dtype=np.float64)

    for xb, yb in loader:
        xb = xb.to(device, non_blocking=use_cuda)
        yb = yb.to(device, non_blocking=use_cuda)  # (B,2,H,W)

        logits = model(xb)
        loss = criterion(logits, yb)
        loss_sum += loss.item() * xb.size(0)

        probs = torch.sigmoid(logits)
        preds = (probs > thr).float()

        # Dice aggregations
        inter = (preds * yb).sum(dim=(0,2,3))       # (2,)
        den   = (preds + yb).sum(dim=(0,2,3))       # (2,)
        inter_sum += inter.detach().cpu().numpy()
        den_sum   += den.detach().cpu().numpy()

        # Accuracy + confusion per class (binary, per pixel)
        for c in range(2):
            p = preds[:, c]
            t = yb[:, c]
            TP[c] += (p * t).sum().item()
            FP[c] += (p * (1 - t)).sum().item()
            FN[c] += ((1 - p) * t).sum().item()
            TN[c] += ((1 - p) * (1 - t)).sum().item()

        n_pix += yb.numel() // 2  # total pixels per class summed over batch

    # Dice per-class and mean
    dice_pc = (2 * inter_sum + 1e-7) / (den_sum + 1e-7)
    dice_mean = float(np.nanmean(dice_pc))

    # Pixel accuracy per class and mean
    acc_pc = (TP + TN) / (TP + FP + FN + TN + 1e-7)
    acc_mean = float(np.nanmean(acc_pc))

    loss_avg = loss_sum / len(loader.dataset)

    conf = {
        "GGO": np.array([[TN[0], FP[0]], [FN[0], TP[0]]], dtype=np.float64),
        "CONS": np.array([[TN[1], FP[1]], [FN[1], TP[1]]], dtype=np.float64),
        "dice_pc": dice_pc, "acc_pc": acc_pc,
    }
    return loss_avg, dice_mean, acc_mean, conf

def plot_training_curves(history, out_dir="/kaggle/working"):
    epochs = history["epoch"]
    # Accuracy
    plt.figure(figsize=(7,5))
    plt.plot(epochs, history["train_acc"], label="Train acc")
    plt.plot(epochs, history["val_acc"],   label="Val acc")
    plt.xlabel("Epoch"); plt.ylabel("Pixel accuracy (mean over classes)")
    plt.title("Accuracy — train vs val")
    plt.legend(); plt.tight_layout()
    plt.savefig(f"{out_dir}/accuracy.png", dpi=150)
    plt.show()

    # Loss
    plt.figure(figsize=(7,5))
    plt.plot(epochs, history["train_loss"], label="Train loss")
    plt.plot(epochs, history["val_loss"],   label="Val loss")
    plt.xlabel("Epoch"); plt.ylabel("Loss")
    plt.title("Loss — train vs val")
    plt.legend(); plt.tight_layout()
    plt.savefig(f"{out_dir}/loss.png", dpi=150)
    plt.show()

def plot_confusions(conf, out_path="/kaggle/working/confusion.png"):
    fig, axs = plt.subplots(1, 2, figsize=(10,4))
    titles = ["GGO", "Consolidation"]
    mats = [conf["GGO"], conf["CONS"]]
    for ax, title, M in zip(axs, titles, mats):
        im = ax.imshow(M, interpolation="nearest")
        ax.set_title(f"{title} — Confusion")
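        ax.set_xticks([0,1]); ax.set_yticks([0,1])
        # rows = ground truth, columns = prediction: [[TN, FP], [FN, TP]]
        ax.set_xticklabels(["pred 0","pred 1"]); ax.set_yticklabels(["true 0","true 1"])
        ax.set_xlabel("Predicted"); ax.set_ylabel("Ground truth")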
        # annotate
        for i in range(2):
            for j in range(2):
                ax.text(j, i, f"{int(M[i,j])}", ha="center", va="center")
    fig.tight_layout()
    plt.savefig(out_path, dpi=150)
    plt.show()

# ======================================
# Train / Validate
# ======================================
best_val = -1.0
ckpt_path = "/kaggle/working/best_unet.pt"

# extra loader to evaluate TRAIN without augmentations
train_eval_ds = NumpySegDataset(X_train, Y_train, train=False, apply_aug_fn=None)
train_eval_dl = DataLoader(train_eval_ds, batch_size=BATCH, shuffle=False,
                           num_workers=2, pin_memory=USE_CUDA)

history = {"epoch": [], "train_loss": [], "val_loss": [],
           "train_dice": [], "val_dice": [],
           "train_acc": [],  "val_acc": []}

for epoch in range(1, EPOCHS+1):
    # ------- TRAIN -------
    model.train()
    tr_loss = 0.0
    for xb, yb in train_dl:
        xb = xb.to(device, non_blocking=USE_CUDA)
        yb = yb.to(device, non_blocking=USE_CUDA)
        optimizer.zero_grad(set_to_none=True)
        with autocast_ctx():
            logits = model(xb)
            loss = criterion(logits, yb)
        if USE_CUDA:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        tr_loss += loss.item() * xb.size(0)
    tr_loss /= len(train_ds)

    # ------- EVAL (train & val) -------
    train_loss_eval, train_dice, train_acc, _      = evaluate_epoch(model, train_eval_dl, criterion, thr=THR)
    val_loss_eval,   val_dice,   val_acc,   val_c  = evaluate_epoch(model, val_dl,       criterion, thr=THR)

    # step the cosine LR schedule once per epoch
    scheduler.step()

    # track best on val Dice
    if val_dice > best_val:
        best_val = val_dice
        torch.save(model.state_dict(), ckpt_path)

    # ---- record & print ----
    history["epoch"].append(epoch)
    history["train_loss"].append(tr_loss)
    history["val_loss"].append(val_loss_eval)
    history["train_dice"].append(train_dice)
    history["val_dice"].append(val_dice)
    history["train_acc"].append(train_acc)
    history["val_acc"].append(val_acc)

    print(f"Epoch {epoch:03d}/{EPOCHS} | lr={scheduler.get_last_lr()[0]:.2e} | "
          f"train_loss={tr_loss:.4f} val_loss={val_loss_eval:.4f} "
          f"train_acc={train_acc:.4f} val_acc={val_acc:.4f} "
          f"train_dice={train_dice:.4f} val_dice={val_dice:.4f} best={best_val:.4f}")

# Save final history & plots
hist_df = pd.DataFrame(history)
hist_csv = "/kaggle/working/history.csv"
hist_df.to_csv(hist_csv, index=False)
print("Saved training history to:", hist_csv)
plot_training_curves(history)
best_state = torch.load(ckpt_path, map_location=device)
model.load_state_dict(best_state)
_, _, _, val_conf = evaluate_epoch(model, val_dl, criterion, thr=THR)
plot_confusions(val_conf, out_path="/kaggle/working/confusion.png")


print("Best val Dice:", best_val)
print("Saved best weights to:", ckpt_path)
print("Figures saved to: /kaggle/working/accuracy.png, /kaggle/working/loss.png, /kaggle/working/confusion.png")
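To sanity-check `BCEDiceLoss` by hand, here is a NumPy re-implementation (a sketch, not the notebook's code). It uses the numerically stable BCE-with-logits form, and its Dice is batch-level like `dice_soft`: intersections and sums are pooled over batch and pixels before dividing, which differs from averaging per-image Dice. All names are illustrative.

```python
import numpy as np

def bce_with_logits(x, t):
    # Stable BCE on logits: max(x, 0) - x*t + log(1 + exp(-|x|)), averaged over all elements
    return float(np.mean(np.maximum(x, 0) - x * t + np.log1p(np.exp(-np.abs(x)))))

def soft_dice_loss(x, t, eps=1e-7):
    # Soft Dice per channel, pooled over axes (batch, H, W), then averaged over channels
    p = 1.0 / (1.0 + np.exp(-x))
    inter = (p * t).sum(axis=(0, 2, 3))
    den = np.maximum((p + t).sum(axis=(0, 2, 3)), eps)
    return float(1.0 - np.mean(2 * inter / den))

def bce_dice(x, t):
    return bce_with_logits(x, t) + soft_dice_loss(x, t)

# All-zero logits (p = 0.5 everywhere) against all-positive targets, shape (B, C, H, W)
x = np.zeros((2, 2, 4, 4))
t = np.ones_like(x)
print(bce_with_logits(x, t), soft_dice_loss(x, t), bce_dice(x, t))
```

With p = 0.5 everywhere the BCE term is log 2 and the soft Dice is 2·0.5/(0.5 + 1) = 2/3, so the Dice term contributes 1/3.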

## 5. Results

In [None]:
#%%skip_if True

import numpy as np
import torch

@torch.no_grad()
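def eval_metrics(val_loader, thr=0.45, eps=1e-7, count_empty_as=1.0):
    """Pooled (micro) Dice, IoU and pixel accuracy per class over the whole loader.
    count_empty_as=None skips images where prediction and target are both empty;
    any non-None value keeps them (they add nothing to the pooled sums)."""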
    model.eval()

    inter = np.zeros(2, dtype=np.float64)
    union = np.zeros(2, dtype=np.float64)
    den_dice = np.zeros(2, dtype=np.float64)
    tp = np.zeros(2, dtype=np.float64)
    fp = np.zeros(2, dtype=np.float64)
    fn = np.zeros(2, dtype=np.float64)
    tn = np.zeros(2, dtype=np.float64)

    empty_cases = np.zeros(2, dtype=np.int64)
    used_cases  = np.zeros(2, dtype=np.int64)

    for xb, yb in val_loader:
        xb = xb.to(device); yb = yb.to(device)           # yb: (B,2,H,W) {0,1}
        logits = model(xb)
        probs = torch.sigmoid(logits)
        preds = (probs > thr).float()

        for c in range(2):
            p = preds[:, c]
            t = yb[:, c]

            tp[c] += (p * t).sum().item()
            fp[c] += (p * (1 - t)).sum().item()
            fn[c] += ((1 - p) * t).sum().item()
            tn[c] += ((1 - p) * (1 - t)).sum().item()

            inter_c = (p * t).sum(dim=(1,2))
            sum_c   = (p + t).sum(dim=(1,2))
            union_c = ((p + t) > 0).sum(dim=(1,2))
            both_empty = (sum_c == 0)

            if count_empty_as is None:
                mask = ~both_empty
                inter[c]   += inter_c[mask].sum().item()
                den_dice[c] += (sum_c[mask]).sum().item()
                union[c]   += ((p[mask] + t[mask]) > 0).sum().item()
                used_cases[c] += int(mask.sum().item())
            else:
                inter[c]   += inter_c.sum().item()
                den_dice[c] += sum_c.sum().item()
                union[c]   += ((p + t) > 0).sum().item()
                empty_cases[c] += int(both_empty.sum().item())
                used_cases[c]  += int(len(both_empty))

    # Metrics per class
    dice = np.zeros(2, dtype=np.float64)
    iou  = np.zeros(2, dtype=np.float64)
    acc  = np.zeros(2, dtype=np.float64)

    for c in range(2):
        if count_empty_as is None:
            dice[c] = (2*inter[c] + eps) / (den_dice[c] + eps) if used_cases[c] > 0 else np.nan
        else:
            dice[c] = (2*inter[c] + eps) / (den_dice[c] + eps)

        iou[c] = dice[c] / (2 - dice[c]) if np.isfinite(dice[c]) else np.nan

        # Per-pixel accuracy
        total = tp[c] + fp[c] + fn[c] + tn[c]
        acc[c] = (tp[c] + tn[c]) / total if total > 0 else np.nan

    metrics = {
        "dice_ggo": float(dice[0]), "dice_cons": float(dice[1]),
        "iou_ggo": float(iou[0]),   "iou_cons":  float(iou[1]),
        "acc_ggo": float(acc[0]),   "acc_cons":  float(acc[1]),
        "dice_mean": float(np.nanmean(dice)),
        "iou_mean":  float(np.nanmean(iou)),
        "acc_mean":  float(np.nanmean(acc)),
        "empty_cases_per_class": {"ggo": int(empty_cases[0]), "cons": int(empty_cases[1])},
        "used_images_per_class": {"ggo": int(used_cases[0]), "cons": int(used_cases[1])},
        "thr": float(thr),
        "confusion": {
            "ggo": {"tp": int(tp[0]), "fp": int(fp[0]), "fn": int(fn[0]), "tn": int(tn[0])},
            "cons":{"tp": int(tp[1]), "fp": int(fp[1]), "fn": int(fn[1]), "tn": int(tn[1])},
        }
    }
    return metrics

def print_metrics_table(metrics, save_csv="/kaggle/working/val_metrics_summary.csv"):
    # Rows per class
    rows = []
    rows.append({
        "class": "ggo",
        "dice": metrics["dice_ggo"],
        "iou":  metrics["iou_ggo"],
        "acc":  metrics["acc_ggo"],
    })
    rows.append({
        "class": "cons",
        "dice": metrics["dice_cons"],
        "iou":  metrics["iou_cons"],
        "acc":  metrics["acc_cons"],
    })
    # Mean row
    rows.append({
        "class": "mean",
        "dice": metrics["dice_mean"],
        "iou":  metrics["iou_mean"],
        "acc":  metrics["acc_mean"],
    })

    df = pd.DataFrame(rows, columns=["class","dice","iou","acc"])

    df_fmt = df.copy()
    for col in ["dice","iou","acc"]:
        df_fmt[col] = df_fmt[col].apply(lambda x: np.nan if x is None else float(x))
        df_fmt[col] = df_fmt[col].map(lambda v: f"{v:.3f}" if pd.notnull(v) else "nan")

    thr = metrics.get("thr", None)
    print("="*46)
    print(f" Validation metrics table (thr={thr}) ".center(46, "="))
    print("="*46)
    print(df_fmt.to_string(index=False))
    print("="*46)

    df.to_csv(save_csv, index=False)
    print("Saved table CSV ->", save_csv)

metrics = eval_metrics(val_dl, thr=0.45, count_empty_as=1.0)
print_metrics_table(metrics)
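`eval_metrics` derives IoU from the pooled Dice via IoU = Dice / (2 − Dice). With I = TP and S = 2TP + FP + FN (the pooled `den_dice`), Dice = 2I/S and IoU = I/(S − I), so the identity is exact for the same pooled counts. It does not hold between *averages* of per-image scores. The hypothetical helper below checks it from confusion counts.

```python
def dice_iou_from_counts(tp, fp, fn, eps=1e-7):
    # Dice = 2TP / (2TP + FP + FN); IoU (Jaccard) = TP / (TP + FP + FN)
    dice = (2 * tp + eps) / (2 * tp + fp + fn + eps)
    iou = (tp + eps) / (tp + fp + fn + eps)
    return dice, iou

d, j = dice_iou_from_counts(tp=50, fp=10, fn=40)
print(f"Dice={d:.4f}  IoU={j:.4f}  Dice/(2-Dice)={d / (2 - d):.4f}")
```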


In [None]:
#%%skip_if True

import os, cv2, matplotlib.pyplot as plt
import numpy as np
import torch

best_state = torch.load(ckpt_path, map_location=device)
model.load_state_dict(best_state)
model.eval()


def render_examples(val_loader, out_dir="/kaggle/working/figs", n_samples=8, thr=0.45, show=True):
    os.makedirs(out_dir, exist_ok=True)
    model.eval()
    saved = 0

    with torch.no_grad():
        for xb, yb in val_loader:
            xb = xb.to(device); yb = yb.to(device)
            probs = torch.sigmoid(model(xb))
            preds = (probs > thr).float()

            for i in range(xb.size(0)):
                if saved >= n_samples:
                    return
                img = xb[i,0].cpu().numpy()                       # (H,W)
                gt  = yb[i].cpu().numpy().astype(np.uint8)        # (2,H,W)
                pr  = preds[i].cpu().numpy().astype(np.uint8)     # (2,H,W)

                # Dice per image/class: 2|P∩T| / (|P| + |T|); both empty counts as 1.0
                def dice_np(p, t, eps=1e-7):
                    inter = int((p & t).sum())
                    den   = int(p.sum()) + int(t.sum())
                    return (2*inter + eps)/(den + eps) if den > 0 else 1.0
                d_ggo  = dice_np(pr[0].astype(bool), gt[0].astype(bool))
                d_cons = dice_np(pr[1].astype(bool), gt[1].astype(bool))

                # Overlays in BGR (green = GGO, red = Cons); flipped to RGB for display below
                img8 = (img*255).clip(0,255).astype(np.uint8)
                gt_ov = cv2.cvtColor(img8, cv2.COLOR_GRAY2BGR)
                pr_ov = gt_ov.copy()
                ggo_gt, cons_gt = gt[0].astype(bool),  gt[1].astype(bool)
                ggo_pr, cons_pr = pr[0].astype(bool),  pr[1].astype(bool)

                gt_ov[ggo_gt]  = (0.6*gt_ov[ggo_gt]  + 0.4*np.array([0,255,0])).astype(np.uint8)
                gt_ov[cons_gt] = (0.6*gt_ov[cons_gt] + 0.4*np.array([0,0,255])).astype(np.uint8)
                pr_ov[ggo_pr]  = (0.6*pr_ov[ggo_pr]  + 0.4*np.array([0,255,0])).astype(np.uint8)
                pr_ov[cons_pr] = (0.6*pr_ov[cons_pr] + 0.4*np.array([0,0,255])).astype(np.uint8)

                # Panel 1x3
                fig, axs = plt.subplots(1,3, figsize=(12,4))
                axs[0].imshow(img, cmap='gray'); axs[0].set_title('Input'); axs[0].axis('off')
                axs[1].imshow(gt_ov[..., ::-1]); axs[1].set_title('Ground Truth'); axs[1].axis('off')
                axs[2].imshow(pr_ov[..., ::-1])
                axs[2].set_title(f'Prediction\nDice GGO {d_ggo:.2f} | Cons {d_cons:.2f}')
                axs[2].axis('off')
                plt.tight_layout()

                # Save + show
                fig.savefig(os.path.join(out_dir, f"val_example_{saved:02d}.png"), dpi=150)
                if show:
                    plt.show()
                plt.close(fig)

                saved += 1

render_examples(val_dl, n_samples=12, thr=0.45, show=True)
print("Saved in /kaggle/working/figs")

In [None]:
#%%skip_if True

import numpy as np
import pandas as pd

@torch.no_grad()
def per_image_report(val_loader, thr=0.45, eps=1e-7, count_empty_as=1.0):
    rows = []
    model.eval()
    for xb, yb in val_loader:
        xb = xb.to(device); yb = yb.to(device)
        logits = model(xb)
        probs = torch.sigmoid(logits)
        preds = (probs > thr).float()
        B = xb.size(0)
        for i in range(B):
            row = {"thr": thr}
            for c, cname in enumerate(["ggo", "cons"]):
                p = preds[i,c]; t = yb[i,c]
                inter = (p*t).sum().item()
                den   = (p + t).sum().item()
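                if den == 0:
                    # Both masks empty: score as count_empty_as, or NaN (skipped by nanmean) if None
                    dice = float(count_empty_as) if count_empty_as is not None else np.nan
                    iou  = float(count_empty_as) if count_empty_as is not None else np.nan
                else:
                    dice = (2*inter + eps)/(den + eps)
                    iou  = dice/(2 - dice)   # IoU = Dice / (2 - Dice) for the same pair of masks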
                row[f"dice_{cname}"] = dice
                row[f"iou_{cname}"]  = iou
            row["dice_mean"] = np.nanmean([row["dice_ggo"], row["dice_cons"]])
            row["iou_mean"]  = np.nanmean([row["iou_ggo"], row["iou_cons"]])
            rows.append(row)
    return pd.DataFrame(rows)

df_report = per_image_report(val_dl, thr=0.45, count_empty_as=1.0)
df_report.to_csv("/kaggle/working/val_metrics_per_image.csv", index=False)

summary = {
    "dice_ggo_mean":  float(df_report["dice_ggo"].mean()),
    "dice_cons_mean": float(df_report["dice_cons"].mean()),
    "dice_mean":      float(df_report["dice_mean"].mean()),
    "iou_ggo_mean":   float(df_report["iou_ggo"].mean()),
    "iou_cons_mean":  float(df_report["iou_cons"].mean()),
    "iou_mean":       float(df_report["iou_mean"].mean()),
    "n_images":       int(len(df_report)),
    "thr":            0.45
}
print("Summary:", summary)
print("CSV:", "/kaggle/working/val_metrics_per_image.csv")


## 6. Conclusions

In [None]:
# Mini viewer (2x2) for: Accuracy, Confusion, Loss, Results
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

ROOT = "/kaggle/input/result-training-images"

# Fixed panel order, with candidate filenames per title
entries = [
    ("Accuracy",  ["accuarcy.png", "accuracy.png"]),  # try both spellings of the filename
    ("Confusion", ["confusion.png"]),
    ("Loss",      ["loss.png"]),
    ("Results",   ["results.png"]),
]

def first_existing(root, candidates):
    for name in candidates:
        p = os.path.join(root, name)
        if os.path.exists(p) and os.path.getsize(p) > 0:
            return p
    return None

# 2x2 figure
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.ravel()

for ax, (title, cands) in zip(axes, entries):
    ax.axis("off")
    path = first_existing(ROOT, cands)
    if path:
        ax.imshow(mpimg.imread(path))
        ax.set_title(title)
    else:
        ax.text(0.5, 0.5, f"{title}\n(file not found)", ha="center", va="center", fontsize=12)

fig.suptitle("Training Metrics Snapshots", y=0.98)
plt.tight_layout(rect=[0, 0, 1, 0.96])

out_path = "/kaggle/working/training_images_grid.png"
plt.savefig(out_path, dpi=150)
print(f"[OK] Saved grid to {out_path}")

plt.show()


### 6.1 Accuracy (per pixel)

Although pixel accuracy reaches ~99%, it does not by itself indicate good segmentation. The model does not target the background and lung classes, yet every one of their pixels that it correctly leaves unlabeled counts as a true negative, which artificially inflates accuracy. A model that classified the entire image as background would still exceed 95% accuracy.

The ground-glass opacity (GGO) class reaches an F1 (Dice) coefficient of ~0.83, indicating that most predicted pixels do match the ground-truth mask, but its IoU of ~0.64 shows that there are still many false positives and false negatives, which makes lesion edges hard to delineate.

The consolidation (CONS) class, on the other hand, is the minority class in each image, so it carries little weight during training and is hard to identify correctly. This is reflected in its F1 of ~0.51 and IoU of ~0.36: with a precision of about 0.49 (Section 6.2), only about half of the predicted CONS pixels are correct, and nearly half of the true CONS pixels are missed.

Combined with the gap between the training and validation losses, all of this indicates an overfitting problem with moderate bias and moderate variance.
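To make the accuracy argument concrete, the sketch below computes the pixel accuracy of a trivial predictor that labels every pixel as background, using the pooled validation confusion counts listed in Section 6.2. The helper `all_background_accuracy` is a hypothetical name of ours, not part of the notebook.

```python
def all_background_accuracy(tn, fp, fn, tp):
    """Pixel accuracy of a trivial model that predicts background everywhere.

    Every pixel that is background in the ground truth (TN + FP of the real
    model) is then counted as correct; every lesion pixel (FN + TP) is missed.
    """
    return (tn + fp) / (tn + fp + fn + tp)

# Pooled validation counts from Section 6.2
ggo_acc  = all_background_accuracy(tn=38_712_350, fp=217_945, fn=386_092, tp=1_315_933)
cons_acc = all_background_accuracy(tn=40_377_709, fp=91_471, fn=74_901, tp=88_239)
print(f"All-background accuracy | GGO: {ggo_acc:.3f} | CONS: {cons_acc:.3f}")
```

Both values exceed 95% even though this predictor finds no lesions at all, which is why Dice and IoU are our primary metrics.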

### 6.2 Confusion Matrix

GGO

- TN=38,712,350, FP=217,945, FN=386,092, TP=1,315,933

- Precision 0.858, Recall 0.773, F1/Dice 0.813, Acc 0.985

Good balance between precision and recall: there are some misses (FN), but most lesion pixels are correctly detected (TP).

Consolidation

- TN=40,377,709, FP=91,471, FN=74,901, TP=88,239

- Precision 0.491, Recall 0.541, F1/Dice 0.515, Acc 0.996

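The model does not under-predict CONS: FP (91,471) and FN (74,901) are both comparable to TP (88,239), so errors occur in both directions and roughly half of the predicted CONS pixels are wrong. High accuracy is, again, mostly due to background.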
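The derived metrics above follow directly from the four counts. The following sketch recomputes them, plus IoU, which the confusion matrix also determines; `confusion_metrics` is a hypothetical helper, and the counts are copied from the lists above.

```python
def confusion_metrics(tn, fp, fn, tp):
    """Pixel-level metrics for one binary class from pooled confusion counts."""
    precision = tp / (tp + fp)
    recall    = tp / (tp + fn)
    dice      = 2 * tp / (2 * tp + fp + fn)   # identical to F1
    iou       = tp / (tp + fp + fn)           # Jaccard index
    accuracy  = (tp + tn) / (tn + fp + fn + tp)
    return {"precision": precision, "recall": recall, "dice": dice,
            "iou": iou, "accuracy": accuracy}

counts = {
    "GGO":  dict(tn=38_712_350, fp=217_945, fn=386_092, tp=1_315_933),
    "CONS": dict(tn=40_377_709, fp=91_471, fn=74_901, tp=88_239),
}
results = {name: confusion_metrics(**c) for name, c in counts.items()}
for name, m in results.items():
    print(name, " ".join(f"{k}={v:.3f}" for k, v in m.items()))
```

From these pooled counts, IoU comes out at about 0.685 for GGO and 0.347 for CONS. The IoU values quoted in Section 6.1 may have been aggregated differently (for example, averaged per image), which would explain the discrepancy.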

### 6.3 Loss train vs val

- Train loss drops smoothly as training progresses (1.15 → 0.37).

- Val loss improves early (0.90 → 0.76) but then plateaus, with small bumps.

- Persistent gap (≈ 0.39 = 0.76 − 0.37) ⇒ overfitting: the model probably fits training patterns better than it generalizes.

- From roughly the 30th–35th epoch onward, the results improve only marginally or stay within a narrow range.
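One way to quantify the gap and the plateau from logged curves is sketched below. It assumes a history dictionary with the `train_loss`/`val_loss` keys that `train_one_model` builds in Section 7.1; the helper name and its `patience`/`min_delta` parameters are illustrative choices, and the toy history contains made-up numbers shaped like our curves, not the logged run.

```python
def loss_gap_and_plateau(hist, patience=5, min_delta=1e-3):
    """Return (final val - train loss gap, epoch at which val loss stopped improving).

    The plateau epoch is the last (1-based) epoch whose val loss beat the best
    previous value by more than `min_delta`, provided that at least `patience`
    later epochs failed to do so; otherwise None (still improving).
    Assumes a non-empty history.
    """
    train, val = hist["train_loss"], hist["val_loss"]
    gap = val[-1] - train[-1]
    best, best_epoch = float("inf"), None
    for epoch, v in enumerate(val, start=1):
        if v < best - min_delta:
            best, best_epoch = v, epoch
    stalled = len(val) - best_epoch
    plateau_epoch = best_epoch if stalled >= patience else None
    return gap, plateau_epoch

# Toy history (illustrative numbers only, not our logged run)
toy = {"train_loss": [1.15, 0.80, 0.60, 0.50, 0.45, 0.42, 0.40, 0.38, 0.37],
       "val_loss":   [0.90, 0.82, 0.78, 0.76, 0.765, 0.762, 0.761, 0.764, 0.76]}
gap, plateau = loss_gap_and_plateau(toy, patience=3)
print(f"gap={gap:.2f}, plateau from epoch {plateau}")
```

Here the toy history reports a final gap of 0.39 and a plateau from epoch 4. The same best-so-far logic, applied to validation Dice with the comparison reversed, is what early stopping on val Dice would need.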

### 6.4 Underfitting or Overfitting?

Under/overfitting verdict

There is no evidence of underfitting in our results: training loss is low and Dice reaches 0.70. However, there is probably mild overfitting, since validation loss does not improve as much as training loss, and validation Dice is stable but lower than training Dice. The near-perfect validation accuracy is mainly due to the huge TN counts (background), while foreground performance, especially for CONS (Dice ~0.52), lags behind; this weak generalization on the hardest class is consistent with the overfitting diagnosis.




## 7. Hyperparameter adjustment and comparison with another model

To broaden our view, we took two steps:

1) Modify the hyperparameters of our current implementation.
2) Use another model to compare the results.

Note: These results only serve to better understand the performance of our current implementation; they are not used for the test set or the submission.

### 7.1 Adjust hyperparameters

#### Hyperparameter changes: what and why

- Loss variants:

    - BCEWithLogits + Soft Dice → balances calibration (BCE) with overlap quality (Dice); stabilizes training and optimizes a differentiable (soft) Dice directly.

    - Weighted BCE+Dice (pos_weight) → counteracts class imbalance by up-weighting positive pixels per class.

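    - Tversky (α=0.7, β=0.3) → trades off FP vs FN; in our `TverskyLoss`, α weights FN and β weights FP, so this setting penalizes missed lesion pixels more than false alarms, which helps with small/rare lesions.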

- Upsampling head: Bilinear ↔ Transposed Conv → test whether learned upsampling sharpens boundaries vs smoother bilinear.

- Learning rate: try 5e-4 instead of 1e-3 with AdamW (wd=1e-4) → aims to reduce oscillations/overfitting; the LR still decays with cosine annealing.

- Decision threshold: tune around 0.45 → better Dice/accuracy trade-off for this dataset.

- Epoch budget for ablations: ~20–30 epochs → quick feedback; keep full training for the main model.
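To see which error type the α=0.7, β=0.3 setting penalizes, here is a small pure-Python sketch of the Tversky index on hard pixel counts (the `TverskyLoss` class below computes the soft version from probabilities). It follows that class's convention, where `alpha` multiplies FN and `beta` multiplies FP; the counts are made-up illustrative values.

```python
def tversky_index(tp, fp, fn, alpha=0.7, beta=0.3, eps=1e-7):
    """Tversky index on hard counts, same weighting as TverskyLoss:
    alpha multiplies FN, beta multiplies FP; alpha = beta = 0.5 gives Dice."""
    return (tp + eps) / (tp + alpha * fn + beta * fp + eps)

# Two hypothetical predictions with the same number of wrong pixels (20)
fn_heavy = dict(tp=80, fp=0, fn=20)   # misses lesion pixels
fp_heavy = dict(tp=80, fp=20, fn=0)   # over-segments

for name, counts in [("FN-heavy", fn_heavy), ("FP-heavy", fp_heavy)]:
    tversky_loss = 1 - tversky_index(**counts)
    dice_loss = 1 - tversky_index(**counts, alpha=0.5, beta=0.5)
    print(f"{name}: Tversky loss {tversky_loss:.3f} | Dice loss {dice_loss:.3f}")
```

With the same 20 wrong pixels, the FN-heavy prediction gets the larger Tversky loss (about 0.149 vs 0.070), whereas plain Dice (α = β = 0.5) scores both at about 0.111.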

Everything else is held constant (same data splits, augmentations, batch size, and metrics), so any gain can be attributed to the change being tested.

In [None]:
import time, copy, math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn

# ---------- Extra losses ----------
class BCEDiceLossWeighted(nn.Module):
    """BCEWithLogits (with pos_weight) + Soft Dice."""
    def __init__(self, pos_weight=None, eps=1e-7):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        self.eps = eps
    def dice_soft(self, logits, targets):
        probs = torch.sigmoid(logits)
        inter = (probs * targets).sum(dim=(0,2,3))
        den   = (probs + targets).sum(dim=(0,2,3)).clamp_min(self.eps)
        dice  = (2*inter) / den
        return 1 - dice.mean()
    def forward(self, logits, targets):
        return self.bce(logits, targets) + self.dice_soft(logits, targets)

class TverskyLoss(nn.Module):
    """Tversky loss (generalized Dice): alpha penalizes FN, beta penalizes FP."""
    def __init__(self, alpha=0.7, beta=0.3, eps=1e-7):
        super().__init__()
        self.alpha, self.beta, self.eps = float(alpha), float(beta), float(eps)
    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)
        TP = (probs * targets).sum(dim=(0,2,3))
        FP = (probs * (1 - targets)).sum(dim=(0,2,3))
        FN = ((1 - probs) * targets).sum(dim=(0,2,3))
        tversky = (TP + self.eps) / (TP + self.alpha*FN + self.beta*FP + self.eps)
        return 1 - tversky.mean()

# ---------- Class imbalance → pos_weight for BCE ----------
def _compute_pos_weight_from_Y(Y_np):
    # Y_np: (N,H,W,2) uint8
    tot = Y_np.shape[0] * Y_np.shape[1] * Y_np.shape[2]
    pos = Y_np.reshape(-1, 2).sum(axis=0).astype(np.float64)
    pos = np.maximum(pos, 1.0)  # avoid div by zero
    neg = tot - pos
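    pw = neg / pos  # per-channel pos_weight
    # Shape (C,1,1) so it broadcasts over the channel axis of (B,C,H,W) logits;
    # a flat (C,) tensor would broadcast against the width axis instead.
    return torch.tensor(pw, dtype=torch.float32, device=device).view(-1, 1, 1)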

_pos_weight = _compute_pos_weight_from_Y(Y_train)

# ---------- Local training helper (reads notebook globals but does not modify them) ----------
def train_one_model(build_model_fn, criterion, lr=1e-3, wd=1e-4, epochs=20,
                    name="exp", thr=None, use_cosine=True):
    model_local = build_model_fn().to(device)
    opt = torch.optim.AdamW(model_local.parameters(), lr=lr, weight_decay=wd)
    sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) if use_cosine else None

    best_val = -1.0
    best_state = None
    hist = {"epoch": [], "train_dice": [], "val_dice": [], "train_loss": [], "val_loss": []}
    run_thr = float(thr if thr is not None else globals().get("THR", 0.45))

    # Loader for evaluating on the training set without augmentations
    _train_eval_ds = NumpySegDataset(X_train, Y_train, train=False, apply_aug_fn=None)
    _train_eval_dl = DataLoader(_train_eval_ds, batch_size=globals().get("BATCH", 8),
                                shuffle=False, num_workers=2, pin_memory=USE_CUDA)

    start = time.time()
    for ep in range(1, epochs+1):
        model_local.train()
        tr_loss_sum, n_items = 0.0, 0
        for xb, yb in train_dl:
            xb = xb.to(device, non_blocking=USE_CUDA); yb = yb.to(device, non_blocking=USE_CUDA)
            opt.zero_grad(set_to_none=True)
            with autocast_ctx():
                logits = model_local(xb)
                loss = criterion(logits, yb)
            if USE_CUDA and scaler is not None:
                scaler.scale(loss).backward()
                scaler.step(opt); scaler.update()
            else:
                loss.backward(); opt.step()
            tr_loss_sum += loss.item() * xb.size(0)
            n_items += xb.size(0)

        # eval
        tr_loss_eval, tr_dice, _, _ = evaluate_epoch(model_local, _train_eval_dl, criterion, thr=run_thr)
        va_loss_eval, va_dice, _, _ = evaluate_epoch(model_local, val_dl,        criterion, thr=run_thr)
        if sch is not None: sch.step()

        if va_dice > best_val:
            best_val = va_dice
            best_state = copy.deepcopy(model_local.state_dict())

        hist["epoch"].append(ep)
        hist["train_loss"].append(tr_loss_sum / max(n_items, 1))
        hist["val_loss"].append(va_loss_eval)
        hist["train_dice"].append(tr_dice)
        hist["val_dice"].append(va_dice)

        print(f"[{name}] Ep {ep:03d}/{epochs} | lr={opt.param_groups[0]['lr']:.2e} | "
              f"tr_loss={hist['train_loss'][-1]:.4f} val_loss={va_loss_eval:.4f} "
              f"tr_dice={tr_dice:.4f} val_dice={va_dice:.4f} best={best_val:.4f}")

    dur = time.time() - start
    ckpt = f"/kaggle/working/{name}.pt"
    if best_state is not None:
        torch.save(best_state, ckpt)
    return {"name": name, "best_val_dice": best_val, "epochs": epochs, "lr": lr,
            "wd": wd, "thr": run_thr, "seconds": round(dur, 1), "ckpt": ckpt}, hist

def build_unet(bilinear=True):
    return UNet2D(n_channels=1, n_classes=2, bilinear=bilinear)

HP_EPOCHS = min(30, int(globals().get("EPOCHS", 60)//2))  # short runs
configs = [
    {"name": "unet_baseline_lr5e-4",     "bilinear": True,  "loss": "bce_dice",          "lr": 5e-4, "wd": 1e-4},
    {"name": "unet_tversky_a0.7_b0.3",   "bilinear": True,  "loss": "tversky",           "lr": 1e-3, "wd": 1e-4},
    {"name": "unet_weighted_bce_dice",   "bilinear": True,  "loss": "bce_dice_weighted", "lr": 1e-3, "wd": 1e-4},
    {"name": "unet_deconv_upsample",     "bilinear": False, "loss": "bce_dice",          "lr": 1e-3, "wd": 1e-4},
]

CONFIGS_TO_RUN = [c["name"] for c in configs]  # or e.g. ["unet_tversky_a0.7_b0.3"]

xp_rows, xp_hists = [], {}
for cfg in configs:
    if cfg["name"] not in CONFIGS_TO_RUN:
        continue
    # pick criterion
    if cfg["loss"] == "tversky":
        criterion_hp = TverskyLoss(alpha=0.7, beta=0.3)
    elif cfg["loss"] == "bce_dice_weighted":
        criterion_hp = BCEDiceLossWeighted(pos_weight=_pos_weight)
    else:
        criterion_hp = BCEDiceLoss()  # original BCE + Dice loss from the main model

    res, hist = train_one_model(
        build_model_fn=lambda: build_unet(bilinear=cfg["bilinear"]),
        criterion=criterion_hp, lr=cfg["lr"], wd=cfg["wd"],
        epochs=HP_EPOCHS, name=cfg["name"], thr=globals().get("THR", 0.45)
    )
    xp_rows.append(res); xp_hists[cfg["name"]] = hist

# Save & display summary
if xp_rows:
    xp_df = pd.DataFrame(xp_rows).sort_values("best_val_dice", ascending=False)
    xp_csv = "/kaggle/working/hparam_experiments.csv"
    xp_df.to_csv(xp_csv, index=False)
    print("\n=== Hyperparam Experiments (summary) ===")
    display(xp_df)
    print(f"Saved: {xp_csv}")
else:
    print("No experiments were run. Edit CONFIGS_TO_RUN.")


### 7.2 DeepLabV3 vs. our U-Net

#### Why DeepLab?

- It brings multi-scale context via ASPP and dilated convolutions, which helps with lesions that vary from tiny GGOs to larger consolidations.

- It provides a strong, widely used segmentation baseline with a different inductive bias than U-Net (context-heavy vs skip-heavy), giving a fair test of whether global context is our current bottleneck.

- Low engineering overhead (available in torchvision) and a like-for-like evaluation (same data splits, augmentations, and metrics). Note that the run below trains DeepLab with weighted BCE+Dice, while our main U-Net uses the unweighted BCE+Dice loss.

- If DeepLab outperforms U-Net, that would suggest adding context modules or receptive-field expansions to U-Net; if they tie, the task may be data-limited, and further gains are more likely to come from ETL/augmentation.
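The receptive-field argument behind dilated (atrous) convolutions can be checked with a little arithmetic: a k×k kernel with dilation d spans k + (k − 1)(d − 1) positions per side while keeping k² weights. The sketch below uses illustrative dilation rates; it is not torchvision's ASPP implementation, and the stacking formula only holds for stride-1 layers.

```python
def effective_kernel(k, dilation):
    """Positions spanned per side by a k x k kernel with the given dilation rate."""
    return k + (k - 1) * (dilation - 1)

def receptive_field(layers):
    """Receptive field of stacked stride-1 convolutions, given (kernel, dilation) pairs."""
    rf = 1
    for k, d in layers:
        rf += effective_kernel(k, d) - 1
    return rf

# Parallel 3x3 branches with illustrative dilation rates, ASPP-style
for rate in (1, 6, 12, 18):
    span = effective_kernel(3, rate)
    print(f"3x3 conv, dilation {rate:2d}: spans {span}x{span} positions with 9 weights")

# Two stacked 3x3 convs: plain vs dilated (rates 2 and 4)
print(receptive_field([(3, 1), (3, 1)]), receptive_field([(3, 2), (3, 4)]))
```

In DeepLab these branches run on a downsampled feature map, so the span measured in input pixels is larger still.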

#### DeepLabV3 (comparison): hyperparameter/config list

- Backbone: ResNet-50 (torchvision DeepLabV3), aux_loss=False.

- Input: grayscale repeated to 3 channels.

- Output: num_classes=2.

- Loss: Weighted BCE+Dice (and Tversky as an alternative ablation).

- Optimizer: AdamW, lr=3e-4, weight_decay=1e-4.

- LR schedule: CosineAnnealingLR with T_max = epochs.

- Epochs: up to 40 (or a short run for quick comparison).

- Batch size: 8 (same as our U-Net).

- Threshold for metrics: 0.45.

- Augmentations & loaders: reuse exactly the same as U-Net.

- Checkpoint: save best Val Dice to /kaggle/working/best_deeplabv3.pt.
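Since the schedule is CosineAnnealingLR with T_max = epochs, the learning rate at each epoch can be computed in closed form, η_t = η_min + (η_0 − η_min)(1 + cos(π t / T_max)) / 2, which is the formula given in the PyTorch documentation (without restarts). The sketch assumes lr=3e-4, T_max=40, the default eta_min=0, and one scheduler step per epoch as in `train_deeplab`; `cosine_lr` is a hypothetical helper.

```python
import math

def cosine_lr(epoch, base_lr=3e-4, t_max=40, eta_min=0.0):
    """Learning rate after `epoch` scheduler steps under cosine annealing (no restarts)."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

for epoch in (0, 10, 20, 30, 40):
    print(f"after {epoch:2d} steps: lr = {cosine_lr(epoch):.2e}")
```

The LR halves by step 20 and reaches 0 at the last step, so the final epochs make only small updates.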

In [None]:
# ============================================================
# - Adapts 1-channel input by repeating to 3
# - Sets num_classes=2
# - Uses metrics/eval helpers
# - Saves /kaggle/working/best_deeplabv3.pt
# ============================================================
import torch
import torch.nn as nn

try:
    from torchvision.models.segmentation import deeplabv3_resnet50
    _has_tv = True
except Exception as e:
    print("torchvision not available for DeepLabV3:", e)
    _has_tv = False

class DeepLabV3Wrap(nn.Module):
    """Wraps torchvision DeepLabV3 to:
       - Accept (B,1,H,W) by repeating channel to 3
       - Return logits tensor (B,2,H,W) directly
    """
    def __init__(self, num_classes=2):
        super().__init__()
        if not _has_tv:
            raise RuntimeError("torchvision.deeplabv3_resnet50 not available.")

        # Modern torchvision API (>= 0.13). weights=None means no pretrained segmentation
        # weights; weights_backbone keeps its default (ImageNet weights for ResNet-50).
        try:
            self.net = deeplabv3_resnet50(weights=None, num_classes=num_classes, aux_loss=False)
        except TypeError:
            # Older torchvision API (pretrained=...): request num_classes directly,
            # which sizes the final 1x1 classifier conv without manual replacement
            self.net = deeplabv3_resnet50(pretrained=False, progress=True,
                                          num_classes=num_classes, aux_loss=False)

    def forward(self, x):
        # x: (B,1,H,W) → (B,3,H,W)
        if x.size(1) == 1:
            x = x.repeat(1, 3, 1, 1)
        out = self.net(x)
        # return raw logits tensor (torchvision returns dict in modern versions)
        if isinstance(out, dict) and "out" in out:
            return out["out"]
        return out  # assume tensor

def train_deeplab(epochs=None, lr=3e-4, wd=1e-4, thr=None, loss_name="bce_dice_weighted"):
    if not _has_tv:
        print("Skipping DeepLabV3: torchvision unavailable.")
        return None, None

    dlv3 = DeepLabV3Wrap(num_classes=2).to(device)

    if loss_name == "tversky":
        criterion_dl = TverskyLoss(alpha=0.7, beta=0.3)
    elif loss_name == "bce_dice_weighted":
        criterion_dl = BCEDiceLossWeighted(pos_weight=_pos_weight)
    else:
        criterion_dl = BCEDiceLoss()

    opt = torch.optim.AdamW(dlv3.parameters(), lr=lr, weight_decay=wd)
    ep = int(epochs if epochs is not None else min(40, int(globals().get("EPOCHS", 60))))
    run_thr = float(thr if thr is not None else globals().get("THR", 0.45))
    sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=ep)

    best_val, best_state = -1.0, None
    hist = {"epoch": [], "train_dice": [], "val_dice": [], "train_loss": [], "val_loss": []}

    # train-eval loader for TRAIN w/o augs
    _train_eval_ds = NumpySegDataset(X_train, Y_train, train=False, apply_aug_fn=None)
    _train_eval_dl = DataLoader(_train_eval_ds, batch_size=globals().get("BATCH", 8),
                                shuffle=False, num_workers=2, pin_memory=USE_CUDA)

    for e in range(1, ep+1):
        dlv3.train()
        tr_loss_sum, n_items = 0.0, 0
        for xb, yb in train_dl:  # reuse the existing augmented train_dl
            xb = xb.to(device, non_blocking=USE_CUDA); yb = yb.to(device, non_blocking=USE_CUDA)
            opt.zero_grad(set_to_none=True)
            with autocast_ctx():
                logits = dlv3(xb)
                loss = criterion_dl(logits, yb)
            if USE_CUDA and scaler is not None:
                scaler.scale(loss).backward()
                scaler.step(opt); scaler.update()
            else:
                loss.backward(); opt.step()
            tr_loss_sum += loss.item() * xb.size(0)
            n_items += xb.size(0)

        tr_loss_eval, tr_dice, _, _ = evaluate_epoch(dlv3, _train_eval_dl, criterion_dl, thr=run_thr)
        va_loss_eval, va_dice, _, _ = evaluate_epoch(dlv3, val_dl,        criterion_dl, thr=run_thr)
        sch.step()

        if va_dice > best_val:
            best_val = va_dice
            best_state = copy.deepcopy(dlv3.state_dict())

        hist["epoch"].append(e)
        hist["train_loss"].append(tr_loss_sum / max(n_items, 1))
        hist["val_loss"].append(va_loss_eval)
        hist["train_dice"].append(tr_dice)
        hist["val_dice"].append(va_dice)

        print(f"[deeplabv3] Ep {e:03d}/{ep} | lr={opt.param_groups[0]['lr']:.2e} | "
              f"tr_loss={hist['train_loss'][-1]:.4f} val_loss={va_loss_eval:.4f} "
              f"tr_dice={tr_dice:.4f} val_dice={va_dice:.4f} best={best_val:.4f}")

    ckpt = "/kaggle/working/best_deeplabv3.pt"
    if best_state is not None:
        torch.save(best_state, ckpt)
    return {"name": "deeplabv3_resnet50", "best_val_dice": best_val, "epochs": ep,
            "lr": lr, "wd": wd, "thr": run_thr, "ckpt": ckpt}, hist

# ----  DeepLabV3 comparison ----
dl_summary, dl_hist = train_deeplab(epochs=min(40, int(globals().get("EPOCHS", 60))),
                                    lr=3e-4, wd=1e-4, loss_name="bce_dice_weighted")

# ---- Compare U-Net best vs DeepLabV3 ----
rows = []
try:
    rows.append({"model": "unet_yours_best", "best_val_dice": float(globals().get("best_val", -1.0)),
                 "ckpt": "/kaggle/working/best_unet.pt"})
except Exception:
    pass
if dl_summary is not None:
    rows.append({"model": dl_summary["name"], "best_val_dice": dl_summary["best_val_dice"],
                 "ckpt": dl_summary["ckpt"]})
if rows:
    cmp_df = pd.DataFrame(rows).sort_values("best_val_dice", ascending=False)
    print("\n=== Model comparison (val Dice) ===")
    display(cmp_df)
    cmp_df.to_csv("/kaggle/working/model_comparison.csv", index=False)
    print("Saved: /kaggle/working/model_comparison.csv")
else:
    print("No models to compare.")


## 8. Overall Conclusions


Overall, our results are more than acceptable, and we identified the following points that support the good performance of our first baseline implementation:

- Stable training: no divergence; train loss ↓ 1.15 → 0.37, val metrics trend upward and stabilize.

- Meaningful segmentation: the model does not simply predict all background; a validation Dice ≈ 0.664, with GGO Dice ≈ 0.813, shows that it really finds lesions.

- Modest gaps: the train vs val Dice gap of ~4 pp indicates mild, not severe, overfitting.

- Confusion matrices make sense: high TN counts (class imbalance), a reasonable GGO TP/FP trade-off, and CONS is harder (as is typical).

For our purposes, this is a significant achievement for a first baseline: the results are more than acceptable and provide a solid platform for targeted improvements, especially for CONS and generalization.

### Pros 

Our model has strengths that give us a solid base for further improvement:

- It builds on the standard U-Net implementation, so it is neither a very technical nor a difficult variation.

- Convergent and stable training with a sensible schedule (AdamW + cosine).

- Strong GGO segmentation (Dice ~0.81) with balanced precision/recall.

- Reproducibility: fixed seed, deterministic cuDNN, clear artifact logging (CSV + figures + best checkpoint).

- Operating point already serviceable: a single global threshold (τ=0.45) yields decent mean Dice without heavy tuning.

- Good specificity / low FP consistent with medical expectations.

### Cons
However, there is also evidence of mild overfitting and, consequently, areas of opportunity such as:

- Mild overfitting: persistent train–val gaps; val loss plateaus around 0.76.

- CONS underperformance: small and low-contrast regions yield lower Dice/recall than GGO.

- Accuracy is misleading: dominated by background TNs; Dice/IoU and balanced accuracy are more informative.

- Augmentation & sampling not class-aware: generic aug helps, but without lesion-aware sampling, CONS remains rare in batches.
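As a sketch of the more informative alternatives, the function below computes balanced accuracy and Cohen's κ from the same pooled confusion counts as Section 6.2; `balanced_accuracy_and_kappa` is a hypothetical helper, not part of the notebook.

```python
def balanced_accuracy_and_kappa(tn, fp, fn, tp):
    """Balanced accuracy and Cohen's kappa for one binary class from confusion counts."""
    n = tn + fp + fn + tp
    sensitivity = tp / (tp + fn)                  # recall on lesion pixels
    specificity = tn / (tn + fp)                  # recall on background pixels
    balanced_acc = (sensitivity + specificity) / 2
    p_observed = (tp + tn) / n                    # plain pixel accuracy
    # Agreement expected by chance, from predicted and true class frequencies
    p_expected = ((tp + fp) * (tp + fn) + (tn + fn) * (tn + fp)) / n**2
    kappa = (p_observed - p_expected) / (1 - p_expected)
    return balanced_acc, kappa

ba_ggo, kappa_ggo = balanced_accuracy_and_kappa(tn=38_712_350, fp=217_945, fn=386_092, tp=1_315_933)
ba_cons, kappa_cons = balanced_accuracy_and_kappa(tn=40_377_709, fp=91_471, fn=74_901, tp=88_239)
print(f"GGO:  balanced acc {ba_ggo:.3f}, kappa {kappa_ggo:.3f}")
print(f"CONS: balanced acc {ba_cons:.3f}, kappa {kappa_cons:.3f}")
```

Balanced accuracy drops to about 0.88 (GGO) and 0.77 (CONS), and κ to about 0.81 and 0.51, which tracks the Dice values far more closely than the ~99% pixel accuracy.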

### Future work and improvements

- Openness to feedback: first of all, we are open to feedback from experts and more experienced people in the field, such as our professors and the community, to broaden our view and improve our model.

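- Loss tilted toward FN: try Tversky/Focal-Tversky weighted to penalize FN more than FP. Conventions differ: when α multiplies FP and β multiplies FN, this means β > α (e.g. α=0.3, β=0.7); our `TverskyLoss` puts `alpha` on FN, so there the equivalent setting is alpha=0.7, beta=0.3. Alternatively, keep BCE+Dice but set pos_weight>1 only for CONS.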

- Early stopping on val Dice.

- Lesion-aware sampling: ensure that a fraction of training crops contain CONS pixels (min-area filter or oversampling of CONS-positive slices).

- Better diagnostic reporting: alongside pixel accuracy, include balanced accuracy, Cohen’s κ, per-image median Dice, and PR curves.

- More 
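A minimal sketch of lesion-aware sampling at the slice level (rather than the crop level), assuming masks shaped like `Y_train`, i.e. (N, H, W, 2) with CONS in channel 1. The function name, `min_area=20`, and `cons_fraction=0.5` are hypothetical choices for illustration.

```python
import numpy as np

def cons_oversampling_weights(Y, cons_channel=1, min_area=20, cons_fraction=0.5):
    """Per-slice sampling weights (summing to 1) so that about `cons_fraction`
    of the draws come from slices with at least `min_area` CONS pixels.

    Y: binary masks shaped (N, H, W, C), the same layout as Y_train.
    """
    cons_area = Y[..., cons_channel].reshape(len(Y), -1).sum(axis=1)
    has_cons = cons_area >= min_area
    n_pos = int(has_cons.sum())
    n_neg = len(Y) - n_pos
    if n_pos == 0 or n_neg == 0:
        return np.full(len(Y), 1.0 / len(Y))   # nothing to rebalance
    return np.where(has_cons, cons_fraction / n_pos, (1 - cons_fraction) / n_neg)

# Toy masks: only 2 of 10 slices contain CONS (25 pixels each)
Y = np.zeros((10, 8, 8, 2), dtype=np.uint8)
Y[:2, :5, :5, 1] = 1
w = cons_oversampling_weights(Y)
print(w[:2].sum(), w[2:].sum())   # CONS-positive slices now get half of the draws
```

In PyTorch, such weights could be passed to `torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)` and handed to the training `DataLoader` through `sampler=` (which cannot be combined with `shuffle=True`).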

## Test

This section runs our model (U-Net) on the test set to generate our predictions and submit them to the competition.

In [None]:
# ============================
# TEST → SUBMISSION (Id,Predicted)
# ============================
import os, numpy as np, pandas as pd, torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1) Load the best checkpoint if it exists
ckpt_path = "/kaggle/working/best_unet.pt"
if os.path.exists(ckpt_path):
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state)
model.eval()

# 2) Load the test set if it is not already in memory
try:
    X_test  # (N,512,512,1) float32 [0,1]
except NameError:
    prep = "/kaggle/working/prepared"
    X_test = np.load(os.path.join(prep, "test_medseg.npz"))["X"]

# 3) Prediction
THR   = float(globals().get("THR", 0.45))
BATCH = int(globals().get("BATCH", 8))

@torch.no_grad()
def predict_binary(model, X, thr=THR, batch=BATCH, device=device):
    outs = []
    for i in range(0, len(X), batch):
        xb = torch.from_numpy(X[i:i+batch]).permute(0,3,1,2).float().to(device)  # (B,1,H,W)
        logits = model(xb)                               # (B,2,H,W)
        probs  = torch.sigmoid(logits).cpu().numpy()     # (B,2,H,W)
        preds  = (probs >= thr).astype(np.uint8)         # binary
        outs.append(preds)
    return np.concatenate(outs, axis=0)                  # (N,2,H,W)

preds = predict_binary(model, X_test)
N, C, H, W = preds.shape
assert C == 2, f"Expected 2 classes, got {C}"

# 4) Convert to the required (N,H,W,2) layout before ravel()
if preds.shape == (N, C, H, W):                         # NCHW → NHWC
    preds_nhwc = np.transpose(preds, (0, 2, 3, 1)).copy()
else:
    preds_nhwc = preds

# Sanity check 
u = np.unique(preds_nhwc)
assert set(u.tolist()).issubset({0,1}), f"No binary predictions: {u}"

# 5) Build CSV
flat = preds_nhwc.ravel(order='C').astype(int)          # (N*H*W*2,)
ids  = np.arange(flat.size, dtype=int)

sub_df = pd.DataFrame({"Id": ids, "Predicted": flat}).set_index("Id")
out_path = "/kaggle/working/submission.csv"
sub_df.to_csv(out_path)

print(f"[OK] submission.csv saved to {out_path}")
print("Expected size:", N*H*W*2, "| rows:", len(sub_df))
print(sub_df.head())
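The submission flattens predictions in (N, H, W, 2) C order, so consecutive Ids alternate between GGO (c=0) and CONS (c=1) for the same pixel. The sketch below makes the Id ↔ (image, row, column, class) mapping explicit on a toy array; `submission_id` is a hypothetical helper, and `np.ravel_multi_index`/`np.unravel_index` are used to cross-check it.

```python
import numpy as np

def submission_id(n, h, w, c, H=512, W=512, C=2):
    """Row Id of pixel (n, h, w), class c, after a C-order ravel of an (N, H, W, C) array."""
    return ((n * H + h) * W + w) * C + c

# Toy predictions: small shape so the mapping is easy to inspect
N, H, W, C = 2, 3, 4, 2
preds_nhwc = np.arange(N * H * W * C).reshape(N, H, W, C)
flat = preds_nhwc.ravel(order="C")

n, h, w, c = 1, 2, 3, 1                      # last image, last pixel, CONS channel
idx = submission_id(n, h, w, c, H=H, W=W, C=C)
print(idx, flat[idx] == preds_nhwc[n, h, w, c])
print(np.ravel_multi_index((n, h, w, c), (N, H, W, C)))   # same Id via NumPy
print(np.unravel_index(idx, (N, H, W, C)))                # back to (n, h, w, c)
```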


## References

Ronneberger, O., Fischer, P., & Brox, T. (2015). U-Net: Convolutional networks for biomedical image segmentation. In N. Navab, J. Hornegger, W. M. Wells, & A. F. Frangi (Eds.), Medical Image Computing and Computer-Assisted Intervention – MICCAI 2015 (Lecture Notes in Computer Science, Vol. 9351, pp. 234–241). Springer. https://doi.org/10.1007/978-3-319-24574-4_28

Chen, L.-C., Papandreou, G., Kokkinos, I., Murphy, K., & Yuille, A. L. (2017). DeepLab: Semantic image segmentation with deep convolutional nets, atrous convolution, and fully connected CRFs (Version 2). arXiv. https://doi.org/10.48550/arXiv.1606.00915

Maftouni, M. (2020). PyTorch baseline for semantic segmentation. Kaggle. https://www.kaggle.com/code/maedemaftouni/pytorch-baseline-for-semantic-segmentation