# Import Library

In [1]:
import os
import glob

import torch
import torch.nn as nn
import torchvision.models as models

import pandas as pd
import pydicom
import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader

from skimage.transform import resize
from typing import Dict, Tuple, List

# Config

In [2]:
DATASET_DIR = '/kaggle/input/mosmeddata-ct-hemorrhage-type-viii/MosMedData-CT-HEMORRHAGE-type VIII/' 
IMG_SIZE = (512, 512)
BATCH_SIZE = 8
MAX_IMAGES_PER_SERIES = 512
START_IDX = 0
NUM_CLASSES = 2  # 0: normal, 1: hemorrhage
NUM_CHANNELS = 1  # Assuming grayscale images

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Read Image

In [4]:
def find_clinical_sheet(excel_file, keywords):
    """
    Finds the sheet containing clinical data in an Excel file.

    Parameters:
    excel_file (pd.ExcelFile): The Excel file object.
    keywords (list): List of keywords to identify the clinical sheet.

    Returns:
    str: The name of the sheet containing clinical data.

    Raises:
    ValueError: If no sheet with clinical data is found.
    """
    for sheet in excel_file.sheet_names:
        if any(keyword in sheet.lower() for keyword in keywords):
            return sheet
    
    raise ValueError('No clinical sheet found in the Excel file.')

def read_clinical_data(file_path):
    """
    Reads clinical data from an Excel file.

    Parameters:
    file_path (str): The path to the Excel file.

    Returns:
    pandas.DataFrame: A DataFrame containing the clinical data.

    Raises:
    ValueError: If no sheet with clinical data is found in the Excel file.
    """
    clinical_keywords = ['clinical', 'клинические']
    
    try:
        xls = pd.ExcelFile(file_path)
        clinical_sheet = find_clinical_sheet(xls, clinical_keywords)
        return pd.read_excel(xls, clinical_sheet)
    except ValueError as e:
        raise ValueError(f"Error reading clinical data: {str(e)}")

In [5]:
def preprocess_images(dicom_images, img_size=IMG_SIZE):
    """
    Preprocesses DICOM images for machine learning model input.

    Parameters:
    dicom_images (dict): A dictionary with a single key, where the value is a list of pydicom.dataset.FileDataset objects.
    img_size (tuple): Target image size (height, width). Default is (512, 512).

    Returns:
    numpy.ndarray: 4D array of preprocessed images (n_images, height, width, 3).
    """
    # Get the list of DICOM images from the dictionary
    images_list = list(dicom_images.values())[0]
    
    # Apply preprocess_single_image to each image in the list
    return np.array([preprocess_single_image(img, img_size) for img in images_list])

def preprocess_single_image(dicom_image, img_size):
    """
    Preprocesses a single DICOM image.

    Parameters:
    dicom_image (pydicom.dataset.FileDataset): DICOM image object.
    img_size (tuple): Target image size (height, width).

    Returns:
    numpy.ndarray: Preprocessed image as a 3D array (height, width, 3).
    """
    image_array = dicom_image.pixel_array
    normalized_image = normalize_image(image_array)
    resized_image = resize_image(normalized_image, img_size)
    return convert_to_rgb(resized_image)

def normalize_image(image):
    """
    Normalizes image pixel values to [0, 1] range.

    Parameters:
    image (numpy.ndarray): Input image array.

    Returns:
    numpy.ndarray: Normalized image array.
    """
    max_value = np.max(image)
    return image / max_value if max_value > 0 else np.zeros_like(image)

def resize_image(image, target_size):
    """
    Resizes image to target size using anti-aliasing.

    Parameters:
    image (numpy.ndarray): Input image array.
    target_size (tuple): Target image size (height, width).

    Returns:
    numpy.ndarray: Resized image array.
    """
    return resize(image, target_size, anti_aliasing=True)

def convert_to_rgb(image):
    """
    Converts grayscale image to RGB if necessary.

    Parameters:
    image (numpy.ndarray): Input image array.

    Returns:
    numpy.ndarray: RGB image array.
    """
    if len(image.shape) == 2:
        return np.stack((image,) * 3, axis=-1)
    return image

In [6]:
class BaseDatasetGenerator(Dataset):
    def __init__(self, dataset_dir: str, start_idx: int = 0, batch_size: int = BATCH_SIZE, max_images_per_series: int = MAX_IMAGES_PER_SERIES):
        self.dataset_dir = dataset_dir
        self.start_idx = start_idx
        self.batch_size = batch_size
        self.max_images_per_series = max_images_per_series
        self.studies_folders = self._get_studies_folders()
        self.clinical_data = self._load_clinical_data()
        self.clinical_data = self._filter_clinical_data()
        self.study_uids = self._get_study_uids()
        self.current_index = self.start_idx

    def _get_studies_folders(self) -> List[str]:
        return [f for f in os.listdir(self.dataset_dir) if os.path.isdir(os.path.join(self.dataset_dir, f))]

    def _load_clinical_data(self) -> pd.DataFrame:
        all_clinical_data = []
        for folder in self.studies_folders:
            folder_path = os.path.join(self.dataset_dir, folder, folder)
            excel_files = glob.glob(os.path.join(folder_path, '*.xlsx'))
            
            if not excel_files:
                print(f"Warning: No Excel file found in {folder_path}")
                continue
            
            if len(excel_files) > 1:
                print(f"Warning: Multiple Excel files found in {folder_path}. Using the first one.")
            
            excel_path = excel_files[0]
            df = pd.read_excel(excel_path)
            df['folder'] = folder
            all_clinical_data.append(df)
        
        if not all_clinical_data:
            raise FileNotFoundError(f"No Excel files found in any of the studies folders in {self.dataset_dir}")
        
        return pd.concat(all_clinical_data, ignore_index=True)

    def _filter_clinical_data(self) -> pd.DataFrame:
        hemorrhage_columns = [
            'epidural hemorrhage', 'subarachnoid hemorrhage', 
            'subdural hemorrhage', 'intracerebral hemorrhage', 
            'multiple hemorrhages', 'skull fracture'
        ]

        filtered_data = self.clinical_data[self.clinical_data['Comment'] != "Study without report"]

        for column in hemorrhage_columns:
            filtered_data = filtered_data[filtered_data[column].isin([0.0, 1.0])]

        # Add ICH column
        filtered_data['ICH'] = filtered_data[hemorrhage_columns].max(axis=1)

        return filtered_data

    def _get_study_uids(self) -> np.ndarray:
        return self.clinical_data['study_uid'].unique()

    def __len__(self) -> int:
        return len(self.study_uids)

    def __getitem__(self, idx: int) -> Tuple[np.ndarray, Dict]:
        study_uid = self.study_uids[idx]
        dicom_series = self._load_dicom_images(study_uid)
        labels = self._get_labels(study_uid)
        processed_images = self.preprocess_images(dicom_series, img_size=IMG_SIZE)
        return processed_images, labels

    def _load_dicom_images(self, study_uid: str) -> Dict[str, List[pydicom.dataset.FileDataset]]:
        dicom_series = {}
        study_folder = self.clinical_data[self.clinical_data['study_uid'] == study_uid]['folder'].iloc[0]
        study_path = os.path.join(self.dataset_dir, study_folder, study_folder, study_uid)
        
        for root, _, files in os.walk(study_path):
            for file in files:
                if file.endswith('.dcm'):
                    file_path = os.path.join(root, file)
                    dicom_image = pydicom.dcmread(file_path, force=True)
                    series_uid = dicom_image.SeriesInstanceUID
                    if series_uid not in dicom_series:
                        dicom_series[series_uid] = []
                    dicom_series[series_uid].append(dicom_image)
        return dicom_series

    def _get_labels(self, study_uid: str) -> Dict:
        study_clinical_data = self.clinical_data[self.clinical_data['study_uid'] == study_uid]
        return study_clinical_data.iloc[0].to_dict() if not study_clinical_data.empty else None

    @staticmethod
    def preprocess_images(dicom_series: Dict[str, List[pydicom.dataset.FileDataset]], img_size: Tuple[int, int]) -> np.ndarray:
        processed_images = []
        for series in dicom_series.values():
            for dicom_image in series[:MAX_IMAGES_PER_SERIES]:
                pixel_array = dicom_image.pixel_array
                resized_image = resize(pixel_array, img_size, anti_aliasing=True)
                processed_images.append(resized_image)
        
        # Pad if necessary
        if len(processed_images) < MAX_IMAGES_PER_SERIES:
            pad_width = ((0, MAX_IMAGES_PER_SERIES - len(processed_images)), (0, 0), (0, 0))
            processed_images = np.pad(processed_images, pad_width, mode='constant', constant_values=0)
        
        return np.array(processed_images)

class TrainDatasetGenerator(BaseDatasetGenerator):
    def __init__(self, dataset_dir: str, studies_folder: str, 
                 start_idx: int = 0, batch_size: int = BATCH_SIZE, max_images_per_series: int = MAX_IMAGES_PER_SERIES):
        super().__init__(dataset_dir, start_idx, batch_size, max_images_per_series)
        self.studies_folder = studies_folder
        self.study_uids = self._get_study_uids_from_folder()

    def _get_study_uids_from_folder(self) -> np.ndarray:
        folder_path = os.path.join(self.dataset_dir, self.studies_folder, self.studies_folder)
        return np.array([d for d in os.listdir(folder_path) if os.path.isdir(os.path.join(folder_path, d))])

class TestDatasetGenerator(BaseDatasetGenerator):
    def __init__(self, dataset_dir: str, 
                 start_idx: int = 0, batch_size: int = BATCH_SIZE, max_images_per_series: int = MAX_IMAGES_PER_SERIES):
        super().__init__(dataset_dir, start_idx, batch_size, max_images_per_series)
        # You can add test-specific functionality here if needed


In [7]:
class ResNet503D(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES):
        super(ResNet503D, self).__init__()
        resnet = models.resnet50(pretrained=True)
        
        self.conv1 = nn.Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(2, 2, 2), padding=(3, 3, 3), bias=False)
        self.conv1.weight.data = resnet.conv1.weight.data.unsqueeze(2).repeat(1, 1, 7, 1, 1) / 7
        
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        
        self.layer1 = self._make_layer_3d(resnet.layer1, 64, 64, 3)
        self.layer2 = self._make_layer_3d(resnet.layer2, 128, 128, 4)
        self.layer3 = self._make_layer_3d(resnet.layer3, 256, 256, 6)
        self.layer4 = self._make_layer_3d(resnet.layer4, 512, 512, 3)
        
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(512 * 4, num_classes)

    def _make_layer_3d(self, layer_2d, inplanes, planes, blocks):
        layers = []
        for i in range(blocks):
            layers.append(nn.Conv3d(inplanes, planes, kernel_size=3, stride=1, padding=1, bias=False))
            layers.append(nn.BatchNorm3d(planes))
            layers.append(nn.ReLU(inplace=True))
            inplanes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [8]:
def train_model(model, train_loader, criterion, optimizer, num_epochs=10):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs = inputs.unsqueeze(1).float().to(device)  # Add channel dimension and convert to float
            labels = torch.tensor([label['ICH'] for label in labels]).long().to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100. * correct / total
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')

In [9]:
def plot_dicom_images(images, labels, num_images=4):
    fig, axes = plt.subplots(1, num_images, figsize=(20, 5))
    for i, ax in enumerate(axes):
        if i < len(images):
            ax.imshow(images[i], cmap='gray')
            ax.set_title(f"Image {i+1}")
            # ax.axis('off')
        else:
            ax.axis('off')
    plt.tight_layout()
    plt.show()
    
    print("Labels:")
    for key, value in labels.items():
        print(f"{key}: {value}")


In [10]:
if __name__ == "__main__":
    studies_folder = '400_500_studies'  # Specify the folder for training data
    train_dataset = TrainDatasetGenerator(dataset_dir=DATASET_DIR, studies_folder=studies_folder)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    test_dataset = TestDatasetGenerator(dataset_dir=DATASET_DIR)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = ResNet503D(num_classes=NUM_CLASSES)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    train_model(model, train_loader, criterion, optimizer, num_epochs=10)

    # After training, you can evaluate the model on the test set
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.unsqueeze(1).float().to(device)
            labels = torch.tensor([label['ICH'] for label in labels]).long().to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            test_total += labels.size(0)
            test_correct += predicted.eq(labels).sum().item()

    test_accuracy = 100. * test_correct / test_total
    print(f'Test Accuracy: {test_accuracy:.2f}%')

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 188MB/s] 


IndexError: single positional indexer is out-of-bounds