In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import sklearn
from tqdm.notebook import tqdm

In [2]:
!ls data

testImages  trainImages  trainMasks  trainSet.csv


## EDA

In [3]:
data_dir = "data/"
train_df = pd.read_csv(data_dir + "trainSet.csv")
train_df

Unnamed: 0,imageID,status,mask
0,1164,1,16165 16166 16167 16168 16169 16678 16679 1668...
1,1169,0,-100
2,1171,1,58682 58683 58684 58685 58686 59194 59195 5919...
3,1177,1,125642 125643 125644 125645 125646 126155 1261...
4,1178,1,53951 53952 53953 53954 53955 54463 54464 5446...
...,...,...,...
500,20408,1,61293 61294 61295 61296 61297 61804 61805 6180...
501,20410,1,78295 78805 78806 78807 79316 79317 79318 7931...
502,20594,1,61197 61198 61199 61200 61201 61709 61710 6171...
503,20605,0,-100


In [4]:
train_df["status"].value_counts()

status
1    352
0    153
Name: count, dtype: int64

In [5]:
!ls data/trainImages/trainImages

ls: cannot access 'data/trainImages/trainImages': No such file or directory


## Dataloader

In [6]:
from PIL import Image

image_path = "data/trainImages/4826.jpg"

# Open image
img = Image.open(image_path)

# Get original size (width, height)
width, height = img.size

print(f"Original size: width = {width}, height = {height}")


Original size: width = 512, height = 512


In [7]:
import os
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class LungCTNeedleDatasetV2(Dataset):
    def __init__(self, csv_path, image_dir, image_size=(512, 512), use_ignore_index=True):
        self.df = pd.read_csv(csv_path)
        self.image_dir = image_dir
        self.image_size = image_size  # (H, W)
        self.use_ignore_index = use_ignore_index  # True → fill mask with -100 when label == 0


        self.image_transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),  # [1, H, W]
        ])


    def _parse_mask(self, mask_str, label):
        H, W = self.image_size
        if str(mask_str).strip() == "-100" or label == 0:
            fill_value = -100.0 if self.use_ignore_index else 0.0
            return torch.full((1, H, W), fill_value, dtype=torch.float32)

        mask = torch.zeros(H * W, dtype=torch.float32)
        try:
            indices = list(map(int, mask_str.strip().split()))
            indices = [i for i in indices if 0 <= i < H * W]
            mask[indices] = 1.0
        except Exception as e:
            print(f"[Warning] Failed parsing mask: {mask_str} — {e}")
        return mask.view(1, H, W)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        patient_id = str(row['imageID'])
        label = int(row['status'])
        mask_str = row['mask']

        # Load image
        image_path = os.path.join(self.image_dir, f"{patient_id}.jpg")
        image = Image.open(image_path).convert("L")
        image = self.image_transform(image)  # [1, H, W]

        # Create mask
        mask = self._parse_mask(mask_str, label)

        return image, torch.tensor(label, dtype=torch.float32), mask, patient_id


In [8]:
dataset = LungCTNeedleDatasetV2(
    csv_path="data/trainSet.csv",
    image_dir="data/trainImages",
    image_size=(512, 512),
    use_ignore_index=True  # set to False if you want zero-filled instead
)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)


In [9]:
for image, label, mask, patient_id in dataloader:
    print(image.shape)     # [B, 1, 512, 512]
    print(mask.shape)      # [B, 1, 512, 512]
    print(label) 
    print(patient_id)      # List[str]
    break


torch.Size([4, 1, 512, 512])
torch.Size([4, 1, 512, 512])
tensor([1., 0., 1., 1.])
('9574', '6373', '2319', '16184')


In [10]:
train_df[train_df["imageID"] == 16042]

Unnamed: 0,imageID,status,mask
378,16042,0,-100


## Model Architecture

In [11]:
import torch
pretrained_backbone = torch.hub.load("Warvito/radimagenet-models", 'radimagenet_resnet50')

Using cache found in /home/iadam/.cache/torch/hub/Warvito_radimagenet-models_main


In [12]:
import torch
import torch.nn as nn

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from radimagenet_models.models.resnet import radimagenet_resnet50

# --------------------------
# ASPP Module for multi-scale context
# --------------------------
class ASPP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv6 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=6, dilation=6)
        self.conv12 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=12, dilation=12)
        self.conv18 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=18, dilation=18)
        self.out_conv = nn.Conv2d(out_channels * 4, out_channels, kernel_size=1)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv6(x)
        x3 = self.conv12(x)
        x4 = self.conv18(x)
        x_cat = torch.cat([x1, x2, x3, x4], dim=1)
        return self.out_conv(x_cat)

# --------------------------
# Improved AttentionGatedUNet with ASPP and deep supervision
# --------------------------
class NeedleSegmentationNet(nn.Module):
    def __init__(self, in_channels=1, feature_dim=2048):
        super().__init__()
        self.encoder = radimagenet_resnet50()

        if in_channels == 1:
            old_conv = self.encoder.conv1
            new_conv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
            with torch.no_grad():
                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
            self.encoder.conv1 = new_conv

        self.enc1 = nn.Sequential(self.encoder.conv1, self.encoder.bn1, self.encoder.relu)
        self.enc2 = self.encoder.layer1
        self.enc3 = self.encoder.layer2
        self.enc4 = self.encoder.layer3
        self.enc5 = self.encoder.layer4

        # ASPP on the deepest features
        self.aspp = ASPP(2048, 256)

        self.up4 = self._upblock(256, 128)
        self.up3 = self._upblock(128, 64)
        self.up2 = self._upblock(64, 32)
        self.up1 = self._upblock(32, 16)

        self.segmentation_head = nn.Conv2d(16, 1, kernel_size=1)

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(2048, 1)
        )

    def _upblock(self, in_ch, out_ch):
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def _align_skip(self, x, skip):
        if x.shape[-2:] != skip.shape[-2:]:
            x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=True)
        return x + skip

    def forward(self, x):
        e1 = self.enc1(x)       # [B, 64, 256, 256]
        e2 = self.enc2(e1)      # [B, 256, 128, 128]
        e3 = self.enc3(e2)      # [B, 512, 64, 64]
        e4 = self.enc4(e3)      # [B, 1024, 32, 32]
        e5 = self.enc5(e4)      # [B, 2048, 16, 16]

        gated = self.aspp(e5)   # ASPP context [B, 256, 16, 16]

        d4 = self._align_skip(self.up4(gated), e4)
        d3 = self._align_skip(self.up3(d4), e3)
        d2 = self._align_skip(self.up2(d3), e2)
        d1 = self._align_skip(self.up1(d2), e1)

        seg_mask = self.segmentation_head(d1)  # [B, 1, 512, 512]
        class_logits = self.classifier(e5).squeeze(-1)

        return {
            "segmentation": seg_mask,
            "classification": class_logits
        }


In [14]:
def focal_tversky_loss(pred, target, alpha=0.3, beta=0.7, gamma=0.75, epsilon=1e-6):
    pred = torch.sigmoid(pred)
    tp = (pred * target).sum(dim=(1,2,3))
    fp = ((1 - target) * pred).sum(dim=(1,2,3))
    fn = (target * (1 - pred)).sum(dim=(1,2,3))
    tversky = (tp + epsilon) / (tp + alpha * fp + beta * fn + epsilon)
    return (1 - tversky) ** gamma

In [15]:
from torch.cuda.amp import autocast, GradScaler
import torch.nn as nn

def train_one_epoch_with_eval(model, dataloader, optimizer, device, scaler,
                               lambda_cls=1.0, threshold=0.5, alpha=0.3, beta=0.7, gamma=0.75):
    model.train()
    cls_loss_fn = nn.BCEWithLogitsLoss()
    total_loss = 0.0

    dice_total = 0.0
    sens_total = 0.0
    count = 0

    for images, labels, masks, _ in dataloader:
        images = images.to(device)
        labels = labels.to(device).float()
        masks = masks.to(device)

        optimizer.zero_grad(set_to_none=True)

        with autocast():
            outputs = model(images)
            seg_pred = outputs['segmentation']
            class_logits = outputs['classification']

            if class_logits.ndim == 2:
                class_logits = class_logits.squeeze(-1)

            # Apply ignore index masking
            valid_mask = (masks != -100)
            pred_valid = seg_pred[valid_mask]
            target_valid = masks[valid_mask]

            # Focal Tversky Loss
            loss_seg = focal_tversky_loss(pred_valid, target_valid,
                                          alpha=alpha, beta=beta, gamma=gamma).mean()

            # Classification loss
            loss_cls = cls_loss_fn(class_logits, labels)
            loss = loss_seg + lambda_cls * loss_cls

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

        # --- Metrics (only for positives) ---
        needle_present = labels == 1
        if needle_present.sum() > 0:
            seg_pred_pos = torch.sigmoid(seg_pred[needle_present]) > threshold
            gt_mask_pos = masks[needle_present]

            preds_flat = seg_pred_pos.view(seg_pred_pos.size(0), -1)
            masks_flat = gt_mask_pos.view(gt_mask_pos.size(0), -1)

            intersection = (preds_flat * masks_flat).sum(dim=1)
            dice = (2. * intersection) / (preds_flat.sum(dim=1) + masks_flat.sum(dim=1) + 1e-8)
            TP = intersection
            FN = ((~preds_flat.bool()) * masks_flat.bool()).sum(dim=1)
            sens = TP / (TP + FN + 1e-8)

            dice_total += dice.sum().item()
            sens_total += sens.sum().item()
            count += preds_flat.size(0)

    avg_loss = total_loss / len(dataloader)
    avg_dice = dice_total / count if count > 0 else 0.0
    avg_sens = sens_total / count if count > 0 else 0.0

    return avg_loss, avg_dice, avg_sens


In [16]:
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.cuda.amp import GradScaler
from tqdm import tqdm

device = "cuda:0" if torch.cuda.is_available() else "cpu"
num_epochs = 200
train_loader = dataloader  # your dataloader for train set

# Initialize model, optimizer, scaler
model = NeedleSegmentationNet(in_channels=1).to(device)
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scaler = GradScaler()

# Optional: scheduler (adaptive LR if loss plateaus)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)

# Start training
for epoch in tqdm(range(num_epochs), desc='Training Progress'):
    loss, dice, sens = train_one_epoch_with_eval(
        model, 
        train_loader, 
        optimizer, 
        device, 
        scaler,
        lambda_cls=1.0,         # classification loss weight
        threshold=0.3,           # use your best threshold
        alpha=0.3, beta=0.7, gamma=0.75  # focal tversky params
    )
    scheduler.step(loss)  # reduce LR on plateau

    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {loss:.4f} | Dice: {dice:.4f} | Sensitivity: {sens:.4f}")




Training Progress:   0%|                                                                                                                                             | 0/200 [00:00<?, ?it/s]


RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [17]:
# Save model weights
torch.save(model.state_dict(), "attn_gated_0.73DICE_0.72SENSE.pth")


In [19]:
!ls data

testImages  trainImages  trainMasks  trainSet.csv


## Test

In [21]:
class LungCTTestDataset(torch.utils.data.Dataset):
    def __init__(self, image_dir, image_size=(512, 512)):
        self.image_dir = image_dir
        self.image_size = image_size
        self.image_paths = sorted([
            f for f in os.listdir(image_dir) if f.endswith(".jpg")
        ], key=lambda x: int(x.split('.')[0]))  # assumes filenames like 1164.jpg

        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        filename = self.image_paths[idx]
        image_path = os.path.join(self.image_dir, filename)
        image = Image.open(image_path).convert("L")
        image = self.transform(image)
        image_id = int(filename.split('.')[0])  # extract numeric ID
        return image, image_id


In [23]:
device = torch.device("cuda:0")
final_model = AttentionGatedUNet()
final_model.load_state_dict(torch.load("attn_gated_0.73DICE_0.72SENSE.pth", map_location=device))  # replace with your model path
final_model.to(device)
final_model.eval()

AttentionGatedUNet(
  (encoder): ResNet50(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1.001e-05, momentum=0.01, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
        (bn1): BatchNorm2d(64, eps=1.001e-05, momentum=0.99, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn2): BatchNorm2d(64, eps=1.001e-05, momentum=0.99, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
        (bn3): BatchNorm2d(256, eps=1.001e-05, momentum=0.99, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0)

In [24]:
test_dataset = LungCTTestDataset(image_dir="data/testImages", image_size=(512, 512))
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=2)

In [26]:
from torchvision.utils import save_image

os.makedirs("predicted_masks", exist_ok=True)

with torch.no_grad():
    for images, image_ids in test_loader:
        images = images.to(device)
        outputs = final_model(images)
        seg_logits = outputs['segmentation']  # [B, 1, 512, 512]
        seg_probs = torch.sigmoid(seg_logits)
        seg_bin = (seg_probs > 0.5).float()  # Binarize to [0,1]

        for i in range(images.size(0)):
            mask = seg_bin[i]  # [1, 512, 512]
            image_id = image_ids[i]
            save_path = os.path.join("predicted_masks", f"{image_id}_pred.png")
            save_image(mask, save_path)  # Will be scaled to 0–255


In [37]:
def processImages(imgDirectory: str, saveDirectory: str = os.getcwd(), returnDF:bool = False):
    """
    Process binarized predicted mask images saved in a directory, implement checks, 
    and create a `'submission.csv'` submission file containing image status and mask indices.

    **Do NOT modify this function.**

    Parameters
    ----------
    imgDirectory : str
        The directory containing the images to be processed. It should have exactly 127 .png files.
        When you save your model's predicted masks, make sure the pixel values 
        are either 0 or 255, and save it as a .png file (preferably using PIL).
    saveDirectory : str, optional
        The directory where the resulting DataFrame will be saved as a CSV file. 
        Defaults to the current directory.
    returnDF : bool, optional
        Whether to return the DataFrame. Defaults to False.

    Returns
    -------
    df: A DataFrame with columns 'imageID', 'status', and 'mask', indexed by 'imageID' if `returnDF` is True, else None

    Raises
    ------
    ValueError
        If the number of .png files in `imgDirectory` is not 127,
        if any image is not binary, or if any image is not 512x512 pixels.

    Example Usage
    -------------
    `processImages('path/to/img/folder', 'path/to/save/folder')`
    """

    files = [f for f in os.listdir(imgDirectory) if f.endswith('.png')]  # Get all .png files in the directory
    if len(files) != 127:
        raise ValueError("Directory must contain exactly 127 .png files")

    files.sort(key=lambda x: int(x.split('_')[0]))  # Sort the files

    data = []  # List of dictionaries to be converted to DataFrame
    for file in files:
        imgPath = os.path.join(imgDirectory, file)
        img = np.array(Image.open(imgPath).convert('L'), dtype=np.uint8)

        # Check if image is binary
        if not np.array_equal(img, img.astype(bool).astype(img.dtype) * 255):
            raise ValueError(f"Image {file} is not binary")
        # Check image size
        if img.shape != (512, 512):
            raise ValueError(f"Image {file} is not of size 512x512")

        status = 1 if np.any(img == 255) else 0  # Determine status of image
        maskIndices = ' '.join(map(str, np.nonzero(img.flatten() == 255)[0])) if status else '-100'

        data.append({'imageID': int(file.split('_')[0]), 'status': status, 'mask': maskIndices})

    df = pd.DataFrame(data).set_index('imageID')
    df.to_csv(os.path.join(saveDirectory, 'submission.csv'))

    if returnDF: return df

In [30]:
processImages(imgDirectory="predicted_masks", saveDirectory=".", returnDF=False)
print("✅ Submission saved to submission.csv")

✅ Submission saved to submission.csv


## Threshold Tuning

In [31]:
import os
import torch
from tqdm import tqdm

# Folder to save predicted raw probabilities
os.makedirs("train_raw_probs", exist_ok=True)
os.makedirs("train_gt_masks", exist_ok=True)

final_model.eval()

with torch.no_grad():
    for images, labels, masks, patient_ids in tqdm(dataloader):
        images = images.cuda()
        outputs = final_model(images)
        seg_logits = outputs['segmentation']  # [B, 1, 512, 512]
        seg_probs = torch.sigmoid(seg_logits)  # Probabilities between 0 and 1

        for i in range(images.size(0)):
            prob_map = seg_probs[i].cpu()  # [1, 512, 512]
            gt_mask = masks[i].cpu()       # [1, 512, 512]
            image_id = patient_ids[i]

            torch.save(prob_map, f"train_raw_probs/{image_id}.pt")
            torch.save(gt_mask, f"train_gt_masks/{image_id}.pt")


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 127/127 [00:23<00:00,  5.45it/s]


In [None]:
import numpy as np
import torch
import os

def dice_score(pred, target, epsilon=1e-6):
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum()
    return (2 * intersection + epsilon) / (union + epsilon)

def sensitivity_score(pred, target, epsilon=1e-6):
    # Sensitivity = TP / (TP + FN)
    tp = (pred * target).sum()
    fn = ((1 - pred) * target).sum()
    return (tp + epsilon) / (tp + fn + epsilon)

# Weights
alpha = 0.7  # You can adjust this (e.g., 0.5 if you want equal weight)

# Thresholds to try
thresholds = [round(x, 2) for x in torch.arange(0.3, 0.71, 0.05).tolist()]
results = {}

prob_files = sorted([f for f in os.listdir("train_raw_probs") if f.endswith(".pt")], key=lambda x: int(x.split('.')[0]))

for threshold in thresholds:
    all_dice = []
    all_sens = []
    for file in prob_files:
        prob_map = torch.load(os.path.join("train_raw_probs", file))  # [1, H, W]
        gt_mask = torch.load(os.path.join("train_gt_masks", file))    # [1, H, W]

        pred_mask = (prob_map > threshold).float()

        valid = (gt_mask != -100)
        pred_mask = pred_mask[valid]
        gt_mask_clean = gt_mask[valid]

        dice = dice_score(pred_mask, gt_mask_clean)
        sens = sensitivity_score(pred_mask, gt_mask_clean)

        all_dice.append(dice.item())
        all_sens.append(sens.item())

    avg_dice = np.mean(all_dice)
    avg_sens = np.mean(all_sens)
    combined_score = alpha * avg_dice + (1 - alpha) * avg_sens

    results[threshold] = {
        "dice": avg_dice,
        "sensitivity": avg_sens,
        "combined": combined_score
    }

    print(f"Threshold {threshold:.2f} | Dice: {avg_dice:.4f} | Sensitivity: {avg_sens:.4f} | Score: {combined_score:.4f}")


In [34]:
# Find best threshold based on combined score
best_threshold = max(results, key=lambda t: results[t]["combined"])
best_score = results[best_threshold]["combined"]
best_dice = results[best_threshold]["dice"]
best_sens = results[best_threshold]["sensitivity"]

print(f"\n✅ Best Threshold: {best_threshold} — Combined Score: {best_score:.4f} | Dice: {best_dice:.4f} | Sensitivity: {best_sens:.4f}")



✅ Best Threshold: 0.3 — Combined Score: 0.8385 | Dice: 0.8260 | Sensitivity: 0.8678


## Apply to Test

In [35]:
import os
import torch
from torchvision.utils import save_image

threshold = 0.3  # 🟢 Use the best threshold from train set

os.makedirs("predicted_masks_thresh_0.3", exist_ok=True)

final_model.eval()
with torch.no_grad():
    for images, image_ids in test_loader:
        images = images.to(device)
        outputs = final_model(images)
        seg_probs = torch.sigmoid(outputs['segmentation'])  # [B, 1, 512, 512]
        seg_bin = (seg_probs > threshold).float()

        for i in range(images.size(0)):
            mask = seg_bin[i]  # [1, 512, 512]
            image_id = image_ids[i]
            save_path = os.path.join("predicted_masks_thresh_0.3", f"{image_id}_pred.png")
            save_image(mask, save_path)


In [38]:

processImages(
    imgDirectory="predicted_masks_thresh_0.3",
    saveDirectory=".",
    returnDF=False
)

print("✅ submission.csv created using threshold = 0.3")


✅ submission.csv created using threshold = 0.3
