In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import sklearn
from tqdm.notebook import tqdm

In [2]:
# Inspect the top-level layout of the dataset directory.
!ls data

testImages  trainImages  trainMasks  trainSet.csv


## EDA

In [3]:
# Root folder of the dataset; the training metadata CSV sits directly inside it.
data_dir = "data/"
train_csv_path = f"{data_dir}trainSet.csv"
train_df = pd.read_csv(train_csv_path)
train_df

Unnamed: 0,imageID,status,mask
0,1164,1,16165 16166 16167 16168 16169 16678 16679 1668...
1,1169,0,-100
2,1171,1,58682 58683 58684 58685 58686 59194 59195 5919...
3,1177,1,125642 125643 125644 125645 125646 126155 1261...
4,1178,1,53951 53952 53953 53954 53955 54463 54464 5446...
...,...,...,...
500,20408,1,61293 61294 61295 61296 61297 61804 61805 6180...
501,20410,1,78295 78805 78806 78807 79316 79317 79318 7931...
502,20594,1,61197 61198 61199 61200 61201 61709 61710 6171...
503,20605,0,-100


In [4]:
# Class balance of the `status` label column.
status_counts = train_df["status"].value_counts()
status_counts

status
1    352
0    153
Name: count, dtype: int64

In [5]:
!ls data/trainImages/trainImages

ls: cannot access 'data/trainImages/trainImages': No such file or directory


## Dataloader

In [6]:
from PIL import Image

image_path = "data/trainImages/4826.jpg"

# Open the image inside a context manager so the underlying file handle is
# released as soon as the size has been read, instead of lingering until the
# image object is garbage-collected.
with Image.open(image_path) as img:
    # PIL reports size as (width, height)
    width, height = img.size

print(f"Original size: width = {width}, height = {height}")


Original size: width = 512, height = 512


In [7]:
import os
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class LungCTNeedleDatasetV2(Dataset):
    """Dataset yielding a grayscale image, its label and a pixel mask per CSV row.

    Each CSV row supplies an ``imageID`` (file stem of a ``.jpg`` inside
    ``image_dir``), a ``status`` label and a ``mask`` string. The mask string
    is either ``"-100"`` (no annotation) or whitespace-separated flat pixel
    indices into an ``H * W`` grid, where ``(H, W)`` is ``image_size``.

    Args:
        csv_path: Path of the CSV file describing the samples.
        image_dir: Directory containing the ``<imageID>.jpg`` files.
        image_size: Target ``(H, W)``; images are resized to it and masks are
            built on a grid of exactly this size.
        use_ignore_index: When True, samples without a mask get a mask filled
            with ``-100.0``; when False the fill value is ``0.0``.
    """

    def __init__(self, csv_path, image_dir, image_size=(512, 512), use_ignore_index=True):
        self.df = pd.read_csv(csv_path)
        self.image_dir = image_dir
        self.image_size = image_size  # (H, W)
        self.use_ignore_index = use_ignore_index

        # Resize first, then convert the PIL image into a tensor.
        resize_then_tensor = [transforms.Resize(image_size), transforms.ToTensor()]
        self.image_transform = transforms.Compose(resize_then_tensor)

    def _parse_mask(self, mask_str, label):
        """Turn a CSV mask string into a ``[1, H, W]`` float32 tensor.

        Returns a constant-filled mask when ``label`` is 0 or the string is
        ``"-100"``. Otherwise sets the listed flat indices to 1.0, silently
        dropping indices outside ``[0, H * W)``. If parsing fails, a warning
        is printed and an all-zero mask is returned.
        """
        height, width = self.image_size
        num_pixels = height * width

        has_no_mask = label == 0 or str(mask_str).strip() == "-100"
        if has_no_mask:
            fill_value = 0.0
            if self.use_ignore_index:
                fill_value = -100.0
            return torch.full((1, height, width), fill_value, dtype=torch.float32)

        flat_mask = torch.zeros(num_pixels, dtype=torch.float32)
        try:
            # Parse every token before touching the mask so a bad token
            # leaves the mask entirely zero.
            indices = [
                pixel
                for pixel in map(int, mask_str.strip().split())
                if 0 <= pixel < num_pixels
            ]
            flat_mask[indices] = 1.0
        except Exception as e:
            print(f"[Warning] Failed parsing mask: {mask_str} — {e}")
        return flat_mask.view(1, height, width)

    def __len__(self):
        """Number of rows in the CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label, mask, patient_id)`` for row ``idx``."""
        record = self.df.iloc[idx]
        patient_id = str(record['imageID'])
        label = int(record['status'])
        mask_str = record['mask']

        # Read the JPEG as single-channel ("L") and apply resize + ToTensor.
        jpg_path = os.path.join(self.image_dir, f"{patient_id}.jpg")
        image = self.image_transform(Image.open(jpg_path).convert("L"))

        mask = self._parse_mask(mask_str, label)
        label_tensor = torch.tensor(label, dtype=torch.float32)

        return image, label_tensor, mask, patient_id


In [8]:
# Build the training dataset. With use_ignore_index=True, samples without a
# mask get a mask filled with -100; pass False to zero-fill them instead.
dataset = LungCTNeedleDatasetV2(
    "data/trainSet.csv",
    "data/trainImages",
    image_size=(512, 512),
    use_ignore_index=True,
)

# Shuffled mini-batches of four samples.
dataloader = torch.utils.data.DataLoader(dataset=dataset, shuffle=True, batch_size=4)


In [9]:
# Pull a single batch and print, in order: image shape, mask shape,
# labels, and the patient IDs, as a quick sanity check of the loader.
for batch in dataloader:
    image, label, mask, patient_id = batch
    for item in (image.shape, mask.shape, label, patient_id):
        print(item)
    break


torch.Size([4, 1, 512, 512])
torch.Size([4, 1, 512, 512])
tensor([1., 1., 1., 0.])
('2603', '2321', '2826', '19285')


In [10]:
# Look up the metadata row for image 16042.
train_df.loc[train_df["imageID"].eq(16042)]

Unnamed: 0,imageID,status,mask
378,16042,0,-100


## Model Architecture

In [11]:
import torch
# Load the 'radimagenet_resnet50' entry point from the "Warvito/radimagenet-models"
# hub repository (served from the local hub cache when it is already present).
# NOTE(review): torch.hub.load runs code from the fetched repository, so only
# point it at trusted sources. `pretrained_backbone` is not referenced again in
# the visible cells — confirm whether this cell is still required.
pretrained_backbone = torch.hub.load("Warvito/radimagenet-models", 'radimagenet_resnet50')

Using cache found in /home/iadam/.cache/torch/hub/Warvito_radimagenet-models_main


In [12]:
import torch
import torch.nn as nn

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from radimagenet_models.models.resnet import radimagenet_resnet50


class AttentionGatedUNet(nn.Module):
    """ResNet-50 encoder with an additive-skip decoder and two output heads.

    The encoder is whatever ``radimagenet_resnet50()`` returns. A 1x1
    convolution followed by a sigmoid produces a single-channel gate that
    rescales the deepest feature map. The gated features feed both a
    five-stage upsampling decoder (segmentation logits) and a pooled linear
    classifier (one logit per image).

    Args:
        in_channels: Number of input channels. When 1, the encoder's first
            convolution is swapped for a single-channel one whose kernel is
            the mean of the original kernel over its input channels. Any
            other value leaves the encoder stem untouched.
        feature_dim: Channel count of the deepest encoder feature map, used
            by the attention gate and the classifier.

    NOTE(review): the decoder is sized for encoder stages emitting
    64/256/512/1024/2048 channels. Spatial sizes of the encoder stages are
    not verified here — ``_align_skip`` resamples decoder features whenever
    they disagree with the skip connection. TODO confirm against the encoder.
    """

    def __init__(self, in_channels=1, feature_dim=2048):
        super().__init__()

        # --- Encoder ---
        self.encoder = radimagenet_resnet50()

        if in_channels == 1:
            # Grayscale input: collapse the stem kernel's input channels by averaging.
            original_stem = self.encoder.conv1
            gray_stem = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
            with torch.no_grad():
                averaged_kernel = original_stem.weight.mean(dim=1, keepdim=True)
                gray_stem.weight = nn.Parameter(averaged_kernel)
            self.encoder.conv1 = gray_stem

        # Encoder stages reused as skip-connection sources (64 → 2048 channels).
        stem_layers = (self.encoder.conv1, self.encoder.bn1, self.encoder.relu)
        self.enc1 = nn.Sequential(*stem_layers)
        self.enc2 = self.encoder.layer1
        self.enc3 = self.encoder.layer2
        self.enc4 = self.encoder.layer3
        self.enc5 = self.encoder.layer4

        # --- Attention gate: one sigmoid weight per spatial location ---
        gate_layers = [nn.Conv2d(feature_dim, 1, kernel_size=1), nn.Sigmoid()]
        self.attn_gate = nn.Sequential(*gate_layers)

        # --- Decoder: each block doubles resolution and reduces channels ---
        self.up4 = self._upblock(2048, 1024)
        self.up3 = self._upblock(1024, 512)
        self.up2 = self._upblock(512, 256)
        self.up1 = self._upblock(256, 64)
        self.up0 = self._upblock(64, 32)  # final learnable upsample, no skip

        # --- Heads ---
        self.segmentation_head = nn.Conv2d(32, 1, kernel_size=1)
        classifier_layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(feature_dim, 1),
        ]
        self.classifier = nn.Sequential(*classifier_layers)

    def _upblock(self, in_ch, out_ch):
        """2x bilinear upsample followed by 3x3 conv, batch norm and ReLU."""
        layers = [
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def _align_skip(self, x, skip):
        """Add ``skip`` to ``x``, bilinearly resizing ``x`` first if their H, W differ."""
        target_hw = skip.shape[-2:]
        if x.shape[-2:] == target_hw:
            return x + skip
        resized = F.interpolate(x, size=target_hw, mode='bilinear', align_corners=True)
        return resized + skip

    def forward(self, x):
        """Run the network.

        Returns a dict with:
            "segmentation": per-pixel logits from the decoder head,
            "attention": the sigmoid gate computed on the deepest features,
            "classification": one logit per sample (last dim squeezed).
        """
        # --- Encoder: remember each stage's output for the skip connections ---
        skips = []
        features = x
        for stage in (self.enc1, self.enc2, self.enc3, self.enc4):
            features = stage(features)
            skips.append(features)
        bottleneck = self.enc5(features)

        # --- Attention gating of the deepest features ---
        attn = self.attn_gate(bottleneck)
        gated = bottleneck * attn

        # --- Decoder: upsample, then add the matching encoder output ---
        decoded = gated
        up_stages = (self.up4, self.up3, self.up2, self.up1)
        for up_stage, skip in zip(up_stages, reversed(skips)):
            decoded = self._align_skip(up_stage(decoded), skip)
        decoded = self.up0(decoded)

        # --- Outputs ---
        seg_logits = self.segmentation_head(decoded)
        class_logits = self.classifier(gated).squeeze(-1)

        return {
            "segmentation": seg_logits,
            "attention": attn,
            "classification": class_logits
        }



In [14]:
def masked_bce_loss(pred, target, ignore_val=-100.0):
    """Mean binary cross-entropy over the pixels that are not marked as ignored.

    Args:
        pred: logits with the same shape as ``target`` (e.g. [B, 1, H, W]).
        target: binary mask; pixels equal to ``ignore_val`` are excluded.
        ignore_val: sentinel value marking pixels to leave out of the loss.

    Returns:
        Scalar tensor: the per-pixel BCE summed over valid pixels, divided by
        their count plus 1e-6 (so a fully ignored target yields 0).
    """
    valid = target.ne(ignore_val).float()
    # Ignored pixels carry the sentinel; clamp so BCE sees targets in [0, 1].
    bce_target = target.clamp(min=0.0, max=1.0)

    per_pixel = F.binary_cross_entropy_with_logits(pred, bce_target, reduction='none')
    return (per_pixel * valid).sum() / (valid.sum() + 1e-6)


In [15]:
from torch.cuda.amp import autocast, GradScaler

def train_one_epoch_with_eval(model, dataloader, optimizer, device, scaler, lambda_attn=1.0, lambda_cls=1.0, threshold=0.5):
    """Train ``model`` for one pass over ``dataloader`` and report training metrics.

    The per-batch loss is
    ``masked_bce_loss(seg, masks) + lambda_cls * BCE(cls, labels)
    + lambda_attn * mean(attention on label-0 samples)``.
    Dice and sensitivity are measured on the same training batches (model in
    train mode), and only for samples whose label equals 1.

    Args:
        model: module returning a dict with "segmentation", "attention" and
            "classification" entries.
        dataloader: iterable of ``(images, labels, masks, ids)`` batches.
        optimizer: optimizer stepping ``model``'s parameters.
        device: device the batch tensors are moved to.
        scaler: gradient scaler exposing ``scale`` / ``step`` / ``update``.
        lambda_attn: weight of the attention penalty on label-0 samples.
        lambda_cls: weight of the classification loss.
        threshold: probability cut-off used to binarize the predicted masks.

    Returns:
        ``(avg_loss, avg_dice, avg_sens)``: the loss averaged over batches and
        Dice / sensitivity averaged over label-1 samples. Each value is 0.0
        when there was nothing to average over.
    """
    model.train()
    cls_loss_fn = nn.BCEWithLogitsLoss()
    total_loss = 0.0
    # Count batches explicitly so an empty loader (or one without __len__)
    # cannot break the final average.
    num_batches = 0

    dice_total = 0.0
    sens_total = 0.0
    count = 0

    for batch in dataloader:
        images, labels, masks, _ = batch
        images = images.to(device)
        labels = labels.to(device).float()
        masks = masks.to(device)

        optimizer.zero_grad(set_to_none=True)

        with autocast():
            outputs = model(images)
            seg_pred = outputs['segmentation']
            attn_map = outputs['attention']
            class_logits = outputs['classification']

            # Tolerate models that return [B, 1] instead of [B].
            if class_logits.ndim == 2:
                class_logits = class_logits.squeeze(-1)

            # --- Losses ---
            loss_seg = masked_bce_loss(seg_pred, masks)
            loss_cls = cls_loss_fn(class_logits, labels)
            # Penalize any attention placed on samples labelled 0.
            attn_neg = attn_map[labels == 0]
            loss_attn = attn_neg.mean() if attn_neg.numel() > 0 else torch.tensor(0.0, device=device)
            loss = loss_seg + lambda_cls * loss_cls + lambda_attn * loss_attn

        # --- Backward + Optimizer ---
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        num_batches += 1

        # --- Metrics (only compute for positives) ---
        needle_present = labels == 1
        if needle_present.sum() > 0:
            # Metrics are reporting-only: keep them out of the autograd graph.
            with torch.no_grad():
                seg_pred_bin = torch.sigmoid(seg_pred[needle_present]) > threshold
                gt_mask = masks[needle_present]

                preds_flat = seg_pred_bin.view(seg_pred_bin.size(0), -1)
                masks_flat = gt_mask.view(gt_mask.size(0), -1)

                intersection = (preds_flat * masks_flat).sum(dim=1)
                dice = (2. * intersection) / (preds_flat.sum(dim=1) + masks_flat.sum(dim=1) + 1e-8)
                TP = intersection
                FN = ((~preds_flat) * masks_flat.bool()).sum(dim=1)
                sens = TP / (TP + FN + 1e-8)

            dice_total += dice.sum().item()
            sens_total += sens.sum().item()
            count += preds_flat.size(0)

    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    avg_dice = dice_total / count if count > 0 else 0.0
    avg_sens = sens_total / count if count > 0 else 0.0

    return avg_loss, avg_dice, avg_sens



In [16]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim import AdamW

# Use the first GPU when one is visible; otherwise fall back to the CPU
# instead of failing when the model is moved to "cuda:0".
device = "cuda:0" if torch.cuda.is_available() else "cpu"
num_epochs = 200
train_loader = dataloader

#v1: epochs = 100, no scheduler

model = AttentionGatedUNet(in_channels=1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Gradient scaler used by the mixed-precision training step.
scaler = GradScaler()

for epoch in tqdm(range(num_epochs)):
    loss, dice, sens = train_one_epoch_with_eval(model, train_loader, optimizer, device, scaler)
    print(f"Epoch {epoch+1} | Loss: {loss:.4f} | Dice: {dice:.4f} | Sensitivity: {sens:.4f}")




  0%|          | 0/200 [00:00<?, ?it/s]

  return F.conv2d(input, weight, bias, self.stride,


Epoch 1 | Loss: 1.3450 | Dice: 0.0104 | Sensitivity: 0.1584
Epoch 2 | Loss: 1.1028 | Dice: 0.0099 | Sensitivity: 0.0091
Epoch 3 | Loss: 1.0294 | Dice: 0.0024 | Sensitivity: 0.0017
Epoch 4 | Loss: 0.9256 | Dice: 0.0028 | Sensitivity: 0.0023


KeyboardInterrupt: 

In [None]:
# Lower the learning rate when the training Dice stops improving.
# Dice is a higher-is-better metric, so the scheduler must run in "max" mode;
# the default "min" mode would treat a rising Dice as a lack of progress.
scheduler = ReduceLROnPlateau(optimizer, mode="max")

for epoch in tqdm(range(num_epochs)):
    loss, dice, sens = train_one_epoch_with_eval(model, train_loader, optimizer, device, scaler)
    print(f"Epoch {epoch+1} | Loss: {loss:.4f} | Dice: {dice:.4f} | Sensitivity: {sens:.4f}")
    # Feed the monitored metric so the scheduler can detect a plateau.
    scheduler.step(dice)