In [24]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import glob
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset

# Set device for GPU acceleration if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [25]:
DATA_DIR = r"D:/ml_med_data/luna"
PATCH_SIZE = 64

In [None]:
ann = pd.read_csv("D:\\ml_med_data\\luna\\annotations.csv")
ann.head()

In [26]:
def load_scan(mhd_path):
    """Loads a .mhd/.raw pair and returns the volumetric data."""
    meta = {}
    with open(mhd_path, 'r') as f:
        for line in f:
            if '=' in line:
                k, v = line.strip().split(' = ')
                meta[k] = v

    dims = list(map(int, meta['DimSize'].split()))
    base_dir = os.path.dirname(mhd_path)
    raw_file_path = os.path.join(base_dir, meta['ElementDataFile'])
    
    with open(raw_file_path, 'rb') as f:
        data = np.fromfile(f, dtype=np.int16)

    # Reshape to (Z, Y, X)
    volume = data.reshape(dims[::-1])
    return volume, meta

def normalize_hu(img):
    """Clips Hounsfield Units and scales to [0, 1]."""
    img = np.clip(img, -1000, 400)
    img = (img + 1000) / 1400
    return img

def extract_3d_patch(volume, center_zyx, patch_size=64):
    """Extracts a 3D cube from the volume centered at specific coordinates."""
    z, y, x = map(int, center_zyx)
    half = patch_size // 2
    
    # Pad volume to handle edge cases
    padded_vol = np.pad(volume, half, mode='constant', constant_values=0)
    
    # Adjust for padding
    z, y, x = z + half, y + half, x + half
    patch = padded_vol[z-half:z+half, y-half:y+half, x-half:x+half]
    return patch

In [27]:
# Load annotations
ann = pd.read_csv(os.path.join(DATA_DIR, "annotations.csv"))
mhd_files = glob.glob(os.path.join(DATA_DIR, "*.mhd"))

X, y = [], []

for fpath in tqdm(mhd_files, desc="Processing Scans"):
    uid = os.path.basename(fpath).replace('.mhd', '')
    volume, _ = load_scan(fpath)
    volume = normalize_hu(volume)
    
    # Positive Patches: Centered on known nodules
    nodules = ann[ann['seriesuid'] == uid]
    for _, row in nodules.iterrows():
        # LUNA coords (X,Y,Z) -> Voxel indices usually requires origin/spacing conversion
        # This assumes your coords are already mapped or close to voxel space
        coord = (row['coordZ'], row['coordY'], row['coordX']) 
        patch = extract_3d_patch(volume, coord, PATCH_SIZE)
        if patch.shape == (PATCH_SIZE, PATCH_SIZE, PATCH_SIZE):
            X.append(patch)
            y.append(1)

    # Negative Patches: Random background crops
    for _ in range(len(nodules)): 
        rand_coord = [np.random.randint(0, s) for s in volume.shape]
        patch = extract_3d_patch(volume, rand_coord, PATCH_SIZE)
        if patch.shape == (PATCH_SIZE, PATCH_SIZE, PATCH_SIZE):
            X.append(patch)
            y.append(0)

X = np.expand_dims(np.array(X), axis=1) # Add channel dim
y = np.array(y)

# Create PyTorch Dataset
dataset = TensorDataset(torch.tensor(X).float(), torch.tensor(y).float().view(-1, 1))
loader = DataLoader(dataset, batch_size=4, shuffle=True)


[A


[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Processing Scans: 100%|██████████| 12/12 [00:31<00:00,  2.59s/it]


In [28]:
class NoduleNet3D(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv3d(1, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(2),
            nn.Conv3d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(2),
            nn.Conv3d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool3d(1),
            nn.Flatten(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.net(x)

model = NoduleNet3D().to(device)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

Exception ignored in: <function tqdm.__del__ at 0x000001E868450720>
Traceback (most recent call last):
  File "c:\Users\ngtru\AppData\Local\Programs\Python\Python312\Lib\site-packages\tqdm\std.py", line 1148, in __del__
    self.close()
  File "c:\Users\ngtru\AppData\Local\Programs\Python\Python312\Lib\site-packages\tqdm\notebook.py", line 277, in close
    self.disp(bar_style='danger', check_delay=False)
    ^^^^^^^^^
AttributeError: 'tqdm_notebook' object has no attribute 'disp'


In [29]:
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss/len(loader):.4f}")

Epoch 1/10, Loss: 0.6880
Epoch 2/10, Loss: 0.6713
Epoch 3/10, Loss: 0.6580
Epoch 4/10, Loss: 0.6501
Epoch 5/10, Loss: 0.6446
Epoch 6/10, Loss: 0.6363
Epoch 7/10, Loss: 0.6315
Epoch 8/10, Loss: 0.6272
Epoch 9/10, Loss: 0.6257
Epoch 10/10, Loss: 0.6197


In [None]:
def nms_3d(detections, threshold_dist=20):
    """Filters overlapping 3D detections (Non-Maximum Suppression)."""
    if not detections: return []
    detections = sorted(detections, key=lambda x: x[3], reverse=True)
    keep = []
    while detections:
        best = detections.pop(0)
        keep.append(best)
        detections = [d for d in detections if np.linalg.norm(np.array(best[:3]) - np.array(d[:3])) > threshold_dist]
    return keep

def detect_nodules(volume, model, stride=32):
    """Applies model over volume using sliding window."""
    model.eval()
    detections = []
    sz = PATCH_SIZE
    with torch.no_grad():
        for z in range(0, volume.shape[0] - sz, stride):
            for y in range(0, volume.shape[1] - sz, stride):
                for x in range(0, volume.shape[2] - sz, stride):
                    patch = volume[z:z+sz, y:y+sz, x:x+sz]
                    patch_t = torch.from_numpy(patch).unsqueeze(0).unsqueeze(0).float().to(device)
                    prob = model(patch_t).item()
                    if prob > 0.8: # Confidence threshold
                        detections.append((z + sz//2, y + sz//2, x + sz//2, prob))
    return nms_3d(detections)