In [None]:
import re

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import monai.transforms as mt
from monai.transforms import (
    EnsureChannelFirst,
    ScaleIntensity,
    ToTensor
)

class HEDNet(nn.Module):
    """HED-style (holistically-nested edge detection) network for CT slices.

    Three VGG-like convolution stages extract features at full, 1/2 and 1/4
    of the input resolution. Each stage feeds a 1x1 "side output" that
    projects it to a single channel; the side outputs are upsampled back to
    the input size, concatenated and fused by a learned 1x1 convolution.

    Args:
        in_channels: Number of channels of the input image. Defaults to 1,
            the value that used to be hard-coded (single-channel CT slice).

    Shapes:
        input:  (N, in_channels, H, W)
        output: (N, 1, H, W), sigmoid-activated fused edge map in [0, 1].

    Note:
        Only the fused map is returned; the individual side outputs are not
        exposed, so deep supervision on them is not possible as written.
    """

    def __init__(self, in_channels: int = 1):
        super().__init__()

        # The layer order inside every nn.Sequential matches the original
        # hand-written layout, so state_dict keys (e.g. "conv3.8.running_var")
        # and pretrained checkpoints remain compatible.

        # Stage 1: full resolution, 64 channels.
        self.conv1 = nn.Sequential(
            *self._conv_bn_relu(in_channels, 64),
            *self._conv_bn_relu(64, 64),
        )

        # Stage 2: 1/2 resolution, 128 channels.
        self.conv2 = nn.Sequential(
            nn.MaxPool2d(2, stride=2),
            *self._conv_bn_relu(64, 128),
            *self._conv_bn_relu(128, 128),
        )

        # Stage 3: 1/4 resolution, 256 channels.
        self.conv3 = nn.Sequential(
            nn.MaxPool2d(2, stride=2),
            *self._conv_bn_relu(128, 256),
            *self._conv_bn_relu(256, 256),
            *self._conv_bn_relu(256, 256),
        )

        # Side outputs: plain 1x1 projections to one channel (no attention
        # mechanism is involved).
        self.side1 = self._project_to_one(64)
        self.side2 = self._project_to_one(128)
        self.side3 = self._project_to_one(256)

        # Fusion: a learned 1x1 convolution weights the three side outputs.
        self.fuse = self._project_to_one(3)

    @staticmethod
    def _conv_bn_relu(in_ch: int, out_ch: int) -> list:
        """Return [3x3 same-padding Conv2d, BatchNorm2d, in-place ReLU]."""
        return [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]

    @staticmethod
    def _project_to_one(in_ch: int) -> nn.Sequential:
        """Return a 1x1 Conv2d down to one channel followed by BatchNorm2d."""
        return nn.Sequential(nn.Conv2d(in_ch, 1, 1), nn.BatchNorm2d(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the fused edge probability map.

        Args:
            x: Batch of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, 1, H, W) with values in [0, 1].
        """
        size = x.shape[2:]

        c1 = self.conv1(x)   # (N, 64, H, W)
        c2 = self.conv2(c1)  # (N, 128, H//2, W//2)
        c3 = self.conv3(c2)  # (N, 256, H//4, W//4)

        # Stage 1 is already at input resolution; deeper stages are upsampled
        # back to (H, W) so the three maps can be concatenated.
        s1 = self.side1(c1)
        s2 = F.interpolate(self.side2(c2), size=size, mode='bilinear', align_corners=True)
        s3 = F.interpolate(self.side3(c3), size=size, mode='bilinear', align_corners=True)

        fused = self.fuse(torch.cat([s1, s2, s3], dim=1))
        return torch.sigmoid(fused)

In [None]:
class HumerusDataset:
    """Load, inspect and preprocess the humerus CT scan DICOM series.

    Attributes:
        data_dir: Directory searched for the slice files.
        series_pattern: Glob pattern selecting the slice files.
        n_expected_slices: Slice count the series should have; a mismatch
            only prints a warning.
    """

    def __init__(self, data_dir: Path = Config.DATA_DIR / 'humerus'):
        # NOTE: the default is evaluated once, when the class is defined, so
        # `Config` must already exist at that point.
        self.data_dir = data_dir
        self.series_pattern = "hum*.dcm"
        self.n_expected_slices = 102

    @staticmethod
    def _natural_sort_key(path: Path) -> list:
        """Sort key ordering embedded numbers numerically (hum2 < hum10)."""
        # re.split with a capturing group alternates text / digit-run chunks,
        # so odd positions are always digit runs and list comparison never
        # mixes str with int at the same index.
        parts = re.split(r'(\d+)', path.name)
        return [int(part) if i % 2 else part for i, part in enumerate(parts)]

    @staticmethod
    def _as_2d_slice(array: np.ndarray, file_path: Path) -> np.ndarray:
        """Return the single 2D slice held in `array`.

        Accepts arrays of shape (1, H, W) or (H, W); anything else (e.g. a
        multi-frame file) raises ValueError instead of silently dropping or
        mangling data.
        """
        if array.ndim == 3 and array.shape[0] == 1:
            return array[0]
        if array.ndim == 2:
            return array
        raise ValueError(
            f"Expected a single 2D slice in {file_path}, got array of shape {array.shape}"
        )

    def load_series(self) -> Tuple[np.ndarray, Dict]:
        """
        Load the complete humerus DICOM series as a 3D volume.

        Files matching `series_pattern` in `data_dir` are ordered by a
        natural (numeric-aware) sort of their file names and stacked as
        slices. NOTE(review): ordering relies on file names; if they do not
        follow anatomical order, sort by DICOM slice position instead.

        Returns:
            Tuple containing:
            - float32 array of shape (height, width, n_slices)
            - metadata dict with 'spacing', 'origin', 'direction' and 'size'
              read from the FIRST slice file only

        Raises:
            FileNotFoundError: no file matches the pattern.
            ValueError: a file does not hold a single 2D slice, or its
                dimensions differ from those of the first slice.
        """
        # Natural sort: a plain lexicographic sort would load hum10 before hum2
        # for un-padded file names. Zero-padded names sort the same either way.
        dicom_files = sorted(
            self.data_dir.glob(self.series_pattern), key=self._natural_sort_key
        )

        if not dicom_files:
            raise FileNotFoundError(f"No DICOM files found in {self.data_dir}")

        if len(dicom_files) != self.n_expected_slices:
            print(f"Warning: Found {len(dicom_files)} slices, expected {self.n_expected_slices}")

        # The first slice fixes the in-plane dimensions and provides metadata.
        first_slice = sitk.ReadImage(str(dicom_files[0]))
        first_array = self._as_2d_slice(sitk.GetArrayFromImage(first_slice), dicom_files[0])

        # Stack along axis 0 first: (n_slices, height, width).
        volume = np.zeros((len(dicom_files), *first_array.shape), dtype=np.float32)
        volume[0] = first_array  # reuse instead of reading the file twice

        print("Loading DICOM series...")
        for idx, file_path in enumerate(dicom_files[1:], start=1):
            image = sitk.ReadImage(str(file_path))
            slice_array = self._as_2d_slice(sitk.GetArrayFromImage(image), file_path)
            # Explicit check: numpy would silently broadcast e.g. a (1, W)
            # slice into the (H, W) target instead of failing.
            if slice_array.shape != first_array.shape:
                raise ValueError(
                    f"Slice {file_path.name} has shape {slice_array.shape}, "
                    f"expected {first_array.shape}"
                )
            volume[idx] = slice_array

        # Metadata describes the first file only: 'size' is that single
        # slice's size, not the stacked volume's; spacing[2] likewise comes
        # from one file (may be slice thickness rather than inter-slice
        # distance — verify before using it for 3D geometry).
        metadata = {
            'spacing': first_slice.GetSpacing(),
            'origin': first_slice.GetOrigin(),
            'direction': first_slice.GetDirection(),
            'size': first_slice.GetSize()
        }

        # Reorder to (height, width, n_slices).
        volume = np.transpose(volume, (1, 2, 0))

        return volume, metadata

    def visualize_slices(self, volume: np.ndarray, n_samples: int = 4):
        """
        Show `n_samples` evenly spaced axial slices side by side.

        Args:
            volume: 3D array of shape (height, width, n_slices), as returned
                by `load_series`.
            n_samples: Number of slices to display (must be >= 1).
        """
        # squeeze=False keeps `axes` 2D even when n_samples == 1; otherwise
        # plt.subplots returns a bare Axes that cannot be indexed.
        _, axes = plt.subplots(1, n_samples, figsize=(20, 5), squeeze=False)
        slice_indices = np.linspace(0, volume.shape[2] - 1, n_samples, dtype=int)

        for ax, idx in zip(axes[0], slice_indices):
            ax.imshow(volume[:, :, idx], cmap='bone')
            ax.set_title(f'Slice {idx}')
            ax.axis('off')

        plt.tight_layout()
        plt.show()

    def preprocess_volume(self, volume: np.ndarray) -> np.ndarray:
        """
        Clip the volume to the [Config.HU_MIN, Config.HU_MAX] window and
        rescale it linearly to [0, 1].

        Args:
            volume: CT volume whose values are used as-is (no unit
                conversion is performed here).

        Returns:
            float64 array of the same shape with values in [0, 1].

        Raises:
            ValueError: Config.HU_MAX is not greater than Config.HU_MIN.
        """
        hu_min, hu_max = Config.HU_MIN, Config.HU_MAX
        if hu_max <= hu_min:
            # Equal bounds would divide by zero; inverted bounds would make
            # np.clip collapse everything to a single value.
            raise ValueError(
                f"Config.HU_MAX ({hu_max}) must be greater than Config.HU_MIN ({hu_min})"
            )

        # Values are assumed to already be Hounsfield units. SimpleITK's DICOM
        # reader normally applies RescaleSlope/Intercept — verify for this series.
        volume_hu = np.clip(volume.astype(float), hu_min, hu_max)

        return (volume_hu - hu_min) / (hu_max - hu_min)