In [1]:
import os
import numpy as np
import pandas as pd
import pydicom
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18
from sklearn.model_selection import train_test_split
from tqdm import tqdm

print(torch.cuda.is_available())
print(torch.__version__)
print(torch.version.cuda)
print(torch.backends.cudnn.version())

True
2.5.1+cu121
12.1
90100


In [2]:
# --- Check Device ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
# --- Load labels ---
labels = pd.read_csv("../../dataset/rsna/rsna_detailed_class_info.csv") 

labels['Target'] = labels['class'].map({
    'Lung Opacity': 1,
    'No Lung Opacity / Not Normal': 0,
    'Normal': 0
})


print(labels['class'].value_counts())
print(labels['Target'].value_counts())

class
No Lung Opacity / Not Normal    11821
Lung Opacity                     9555
Normal                           8851
Name: count, dtype: int64
Target
0    20672
1     9555
Name: count, dtype: int64


In [4]:
# --- Custom Dataset ---
class RSNADataset(Dataset):
    def __init__(self, df, root_dir, transform=None):
        self.df = df
        self.root_dir = root_dir
        self.transform = transform

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        dcm_path = os.path.join(self.root_dir, row['patientId'] + '.dcm')
        dcm = pydicom.dcmread(dcm_path)
        image = dcm.pixel_array.astype(np.float32)
        image = (image - np.min(image)) / (np.max(image) - np.min(image))
        image = np.stack([image] * 3, axis=-1)

        if self.transform:
            image = self.transform(image)

        return image, row['Target']

    def __len__(self):
        return len(self.df)

In [5]:
# --- Transforms ---
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
])

# --- Data Split ---
train_df, val_df = train_test_split(labels, test_size=0.2, stratify=labels['Target'])

# --- Datasets and Loaders ---
train_ds = RSNADataset(train_df, "../../dataset/rsna/train_dicom", transform)
val_ds = RSNADataset(val_df, "../../dataset/rsna/train_dicom", transform)

train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=16)

In [6]:
# --- Model Setup ---
model = resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 1)
model = model.to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)



In [7]:
# --- Training Loop ---
for epoch in range(5):
    model.train()
    total_loss = 0.0
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        images = images.to(device)
        labels = labels.float().unsqueeze(1).to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1} complete. Avg Loss: {total_loss / len(train_loader):.4f}")

Epoch 1: 100%|██████████| 1512/1512 [11:02<00:00,  2.28it/s]


Epoch 1 complete. Avg Loss: 0.4077


Epoch 2: 100%|██████████| 1512/1512 [09:38<00:00,  2.62it/s]


Epoch 2 complete. Avg Loss: 0.3558


Epoch 3: 100%|██████████| 1512/1512 [09:27<00:00,  2.67it/s]


Epoch 3 complete. Avg Loss: 0.2994


Epoch 4: 100%|██████████| 1512/1512 [09:02<00:00,  2.79it/s]


Epoch 4 complete. Avg Loss: 0.2037


Epoch 5: 100%|██████████| 1512/1512 [09:03<00:00,  2.78it/s]

Epoch 5 complete. Avg Loss: 0.1126





In [8]:
# --- Save Model ---
# os.makedirs("train/models", exist_ok=True)
torch.save(model.state_dict(), "models/dicom_model.pth")

In [9]:
# --- Validation Block ---
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in tqdm(val_loader, desc="Validating"):
        images = images.to(device)
        labels = labels.to(device).float().unsqueeze(1)

        outputs = model(images)
        preds = (torch.sigmoid(outputs) >= 0.5).float()

        correct += (preds == labels).sum().item()
        total += labels.size(0)

accuracy = correct / total
print(f"\n✅ Validation Accuracy: {accuracy * 100:.2f}%")

Validating: 100%|██████████| 378/378 [02:31<00:00,  2.49it/s]


✅ Validation Accuracy: 85.84%



