In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pathlib import Path
import albumentations as A
from tqdm import tqdm
import cv2
import matplotlib.pyplot as plt
from typing import Tuple, Dict, List
import pydicom 

In [2]:
# VerSe Dataset (Vertèbres)
# https://verse2020.grand-challenge.org/

# MURA Dataset (Stanford)
# https://stanfordmlgroup.github.io/competitions/mura/

In [3]:
class CTDataLoader:
    def __init__(self, data_dir: Path):
        self.data_dir = data_dir

    def load_dicom_series(self, series_dir: Path) -> Tuple[np.ndarray, Dict]:
        """Charge une série DICOM"""
        # Charger tous les fichiers DICOM
        dicom_files = sorted(series_dir.glob('*.dcm'))
        
        # Lire le premier pour les métadonnées
        first_slice = pydicom.dcmread(str(dicom_files[0]))
        
        # Préparer le volume
        volume = np.zeros((first_slice.Rows, first_slice.Columns, len(dicom_files)))
        
        # Charger chaque tranche
        for i, dcm_file in enumerate(dicom_files):
            slice_data = pydicom.dcmread(str(dcm_file))
            volume[:, :, i] = slice_data.pixel_array
        
        # Métadonnées importantes
        metadata = {
            'spacing': (
                float(first_slice.PixelSpacing[0]),
                float(first_slice.PixelSpacing[1]),
                float(first_slice.SliceThickness)
            ),
            'origin': first_slice.ImagePositionPatient,
            'direction': first_slice.ImageOrientationPatient
        }
        
        return volume, metadata

    def load_nifti(self, file_path: Path) -> Tuple[np.ndarray, Dict]:
        """Charge un fichier NIfTI"""
        nifti_img = nib.load(str(file_path))
        volume = nifti_img.get_fdata()
        
        metadata = {
            'affine': nifti_img.affine,
            'header': dict(nifti_img.header)
        }
        
        return volume, metadata

In [4]:
def load_dicom_volume(dicom_dir: Path) -> Tuple[np.ndarray, Dict]:
    """Charge une série DICOM de phalanges avec détection automatique de la taille"""
    print(f"Loading DICOM series from {dicom_dir}...")
    
    # Trouver tous les fichiers phalanx
    dicom_files = sorted(dicom_dir.glob('phalanx*.dcm'))
    if not dicom_files:
        raise FileNotFoundError(f"No phalanx DICOM files found in {dicom_dir}")
    
    print(f"Found {len(dicom_files)} DICOM files")
    
    try:
        # Lire le premier fichier pour déterminer la taille
        first_file = dicom_files[0]
        with open(str(first_file), 'rb') as f:
            raw_data = np.fromfile(f, dtype=np.uint16)
            
        # Calculer les dimensions possibles
        total_pixels = len(raw_data)
        print(f"Total pixels in first file: {total_pixels}")
        
        # Trouver les facteurs pour déterminer les dimensions possibles
        factors = []
        for i in range(1, int(np.sqrt(total_pixels)) + 1):
            if total_pixels % i == 0:
                factors.append((i, total_pixels // i))
        
        print("Possible dimensions:", factors)
        
        # Choisir les dimensions les plus proches d'un carré
        rows, cols = min(factors, key=lambda x: abs(x[0] - x[1]))
        print(f"Selected dimensions: {rows}x{cols}")
        
        # Initialiser le volume
        volume = np.zeros((rows, cols, len(dicom_files)))
        
        # Charger chaque tranche
        for i, dcm_file in enumerate(tqdm(dicom_files, desc="Loading slices")):
            try:
                with open(str(dcm_file), 'rb') as f:
                    raw_data = np.fromfile(f, dtype=np.uint16)
                    img_array = raw_data.reshape(rows, cols)
                volume[:, :, i] = img_array
                
                # Debug: afficher la première tranche
                if i == 0:
                    plt.figure(figsize=(10, 10))
                    plt.imshow(img_array, cmap='bone')
                    plt.title("First slice")
                    plt.colorbar()
                    plt.show()
                
            except Exception as e:
                print(f"\nWarning: Error loading {dcm_file.name}: {e}")
                if i > 0:
                    volume[:, :, i] = volume[:, :, i-1]
                continue
        
        # Normalisation du volume
        volume = volume.astype(float)
        volume = (volume - volume.min()) / (volume.max() - volume.min())
        
        print(f"\nVolume loaded successfully!")
        print(f"Shape: {volume.shape}")
        print(f"Value range: [{volume.min():.2f}, {volume.max():.2f}]")
        
        # Visualiser quelques tranches
        visualize_slices(volume)
        
        return volume, {'spacing': [1.0, 1.0, 1.0]}
        
    except Exception as e:
        print(f"Error in load_dicom_volume: {e}")
        raise

def visualize_slices(volume: np.ndarray, n_slices: int = 4):
    """Visualise quelques tranches du volume"""
    fig, axes = plt.subplots(1, n_slices, figsize=(20, 5))
    step = volume.shape[2] // n_slices
    
    for i in range(n_slices):
        idx = i * step
        axes[i].imshow(volume[:, :, idx], cmap='bone')
        axes[i].set_title(f'Slice {idx}')
        axes[i].axis('off')
        plt.colorbar(axes[i].imshow(volume[:, :, idx], cmap='bone'), ax=axes[i])
    
    plt.tight_layout()
    plt.show()


In [5]:
def train_model(model, train_loader, val_loader, epochs=100):
    """Entraînement du modèle"""
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5)
    
    best_val_loss = float('inf')
    
    for epoch in range(epochs):
        # Mode entraînement
        model.train()
        train_loss = 0
        
        for batch_idx, (data, target) in enumerate(tqdm(train_loader, desc=f'Epoch {epoch+1}')):
            data, target = data.to(Config.DEVICE), target.to(Config.DEVICE)
            
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            
            # Visualisation périodique
            if batch_idx % 10 == 0:
                visualize_batch(data, target, output, epoch, batch_idx)
        
        train_loss /= len(train_loader)
        
        # Mode validation
        model.eval()
        val_loss = 0
        dice_scores = []
        
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(Config.DEVICE), target.to(Config.DEVICE)
                output = model(data)
                val_loss += criterion(output, target).item()
                
                # Calculer le score Dice
                pred = (torch.sigmoid(output) > 0.5).float()
                dice = (2. * (pred * target).sum()) / (pred.sum() + target.sum() + 1e-6)
                dice_scores.append(dice.item())
        
        val_loss /= len(val_loader)
        mean_dice = np.mean(dice_scores)
        
        # Mise à jour du scheduler
        scheduler.step(val_loss)
        
        # Sauvegarder le meilleur modèle
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': val_loss,
            }, Config.MODELS_DIR / 'best_model.pth')
        
        print(f'Epoch {epoch+1}:')
        print(f'Train Loss: {train_loss:.4f}')
        print(f'Val Loss: {val_loss:.4f}')
        print(f'Dice Score: {mean_dice:.4f}')

def visualize_batch(data, target, output, epoch, batch_idx):
    """Visualise les résultats pendant l'entraînement"""
    with torch.no_grad():
        # Prendre le premier exemple du batch
        img = data[0, 0].cpu().numpy()
        mask = target[0, 0].cpu().numpy()
        pred = torch.sigmoid(output[0, 0]).cpu().numpy() > 0.5
        
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        
        axes[0].imshow(img, cmap='bone')
        axes[0].set_title('Input')
        
        axes[1].imshow(mask, cmap='bone')
        axes[1].set_title('Target')
        
        axes[2].imshow(pred, cmap='bone')
        axes[2].set_title('Prediction')
        
        plt.suptitle(f'Epoch {epoch+1}, Batch {batch_idx}')
        plt.show()


In [6]:
class Config:
    DATA_DIR = Path('./data')
    OUTPUT_DIR = Path('./output')
    MODELS_DIR = Path('./models')
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    @classmethod
    def create_directories(cls):
        for dir_path in [cls.DATA_DIR, cls.OUTPUT_DIR, cls.MODELS_DIR]:
            dir_path.mkdir(parents=True, exist_ok=True)


In [7]:
def create_synthetic_masks(volume: np.ndarray) -> np.ndarray:
    """Crée des masques synthétiques pour l'entraînement initial"""
    print("Creating synthetic masks...")
    print(f"Input volume shape: {volume.shape}")
    
    # Calculer les dimensions optimales
    total_pixels = volume.shape[1]  # 262697
    factors = []
    for i in range(1, int(np.sqrt(total_pixels)) + 1):
        if total_pixels % i == 0:
            factors.append((i, total_pixels // i))
    
    print("Possible dimensions:", factors)
    
    # Trouver les dimensions les plus proches de 512x512
    target_size = 512
    optimal_dims = min(factors, key=lambda x: abs(x[0] - target_size) + abs(x[1] - target_size))
    height, width = optimal_dims
    print(f"Original dimensions: {height}x{width}")
    
    # Redimensionner en 512x512
    n_slices = volume.shape[2]
    volume_reshaped = np.zeros((512, 512, n_slices))
    
    for i in range(n_slices):
        # Redimensionner chaque tranche
        slice_data = volume[0, :, i].reshape(height, width)
        # Interpolation pour obtenir une image 512x512
        slice_2d = cv2.resize(slice_data, (512, 512))
        volume_reshaped[:, :, i] = slice_2d
        
        # Debug: afficher la première tranche
        if i == 0:
            plt.figure(figsize=(15, 5))
            plt.subplot(121)
            plt.imshow(slice_data, cmap='bone')
            plt.title(f'Original ({height}x{width})')
            plt.colorbar()
            
            plt.subplot(122)
            plt.imshow(slice_2d, cmap='bone')
            plt.title('Resized (512x512)')
            plt.colorbar()
            
            plt.tight_layout()
            plt.show()
    
    print(f"Reshaped volume: {volume_reshaped.shape}")
    
    # Créer les masques
    masks = np.zeros_like(volume_reshaped)
    
    for i in range(n_slices):
        slice_data = volume_reshaped[:, :, i]
        
        # 1. Seuillage adaptatif
        threshold = np.percentile(slice_data, 95)
        binary = (slice_data > threshold).astype(np.uint8)
        
        # 2. Opérations morphologiques
        kernel = np.ones((5,5), np.uint8)
        binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)
        binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
        
        # 3. Trouver les contours
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        
        # 4. Sélectionner le plus grand contour
        if contours:
            largest_contour = max(contours, key=cv2.contourArea)
            
            # 5. Créer un masque avec le contour
            mask = np.zeros_like(slice_data)
            cv2.drawContours(mask, [largest_contour], -1, 1, 2)
            
            masks[:, :, i] = mask
            
        # Debug: afficher quelques tranches
        if i % 20 == 0:
            plt.figure(figsize=(15, 5))
            plt.subplot(131)
            plt.imshow(slice_data, cmap='bone')
            plt.title(f'Original Slice {i}')
            plt.colorbar()
            
            plt.subplot(132)
            plt.imshow(binary, cmap='bone')
            plt.title('Binary Mask')
            plt.colorbar()
            
            plt.subplot(133)
            plt.imshow(masks[:, :, i], cmap='bone')
            plt.title('Final Mask')
            plt.colorbar()
            
            plt.tight_layout()
            plt.show()
    
    return volume_reshaped, masks
def visualize_masks(volume: np.ndarray, masks: np.ndarray, n_slices: int = 4):
    """Visualise les masques avec les images originales"""
    fig, axes = plt.subplots(2, n_slices, figsize=(20, 10))
    step = volume.shape[2] // n_slices
    
    for i in range(n_slices):
        idx = i * step
        
        # Image originale
        axes[0, i].imshow(volume[:, :, idx], cmap='bone')
        axes[0, i].set_title(f'Original Slice {idx}')
        axes[0, i].axis('off')
        
        # Masque
        axes[1, i].imshow(masks[:, :, idx], cmap='bone')
        axes[1, i].set_title(f'Mask Slice {idx}')
        axes[1, i].axis('off')
    
    plt.tight_layout()
    plt.show()

class BoneDataset(Dataset):
    """Dataset pour les images CT et leurs masques"""
    def __init__(self, volume: np.ndarray, masks: np.ndarray, transform=None):
        self.volume = volume
        self.masks = masks
        self.transform = transform
        
        print(f"Dataset initialized with volume shape: {volume.shape}")
        print(f"Masks shape: {masks.shape}")

    def __len__(self):
        return self.volume.shape[2]

    def __getitem__(self, idx):
        image = self.volume[:, :, idx]
        mask = self.masks[:, :, idx]

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return (torch.from_numpy(image).float().unsqueeze(0),
                torch.from_numpy(mask).float().unsqueeze(0))
                
def create_dataloaders(volume: np.ndarray, masks: np.ndarray, 
                      train_ratio: float = 0.8, batch_size: int = 8):
    """Crée les dataloaders pour l'entraînement"""
    
    # Augmentation pour l'entraînement
    train_transform = A.Compose([
        A.RandomRotate90(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=45, p=0.5),
        A.OneOf([
            A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.5),
            A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.5),
            A.OpticalDistortion(distort_limit=0.3, p=0.5),
        ], p=0.3),
        A.OneOf([
            A.GaussNoise(p=0.5),
            A.RandomBrightnessContrast(p=0.5),
            A.RandomGamma(p=0.5),
        ], p=0.3),
    ])

    # Diviser en train/val
    n_slices = volume.shape[2]
    n_train = int(n_slices * train_ratio)
    
    train_volume = volume[:, :, :n_train]
    train_masks = masks[:, :, :n_train]
    val_volume = volume[:, :, n_train:]
    val_masks = masks[:, :, n_train:]

    print(f"\nTrain volume shape: {train_volume.shape}")
    print(f"Val volume shape: {val_volume.shape}")

    # Créer les datasets
    train_dataset = BoneDataset(train_volume, train_masks, transform=train_transform)
    val_dataset = BoneDataset(val_volume, val_masks)

    # Créer les dataloaders
    train_loader = DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset, 
        batch_size=batch_size,
        num_workers=4,
        pin_memory=True
    )

    return train_loader, val_loader
def visualize_augmentations(dataset, n_samples=3):
    """Visualise les augmentations de données"""
    fig, axes = plt.subplots(n_samples, 3, figsize=(15, 5*n_samples))
    
    for i in range(n_samples):
        # Obtenir un échantillon original
        image, mask = dataset[i]
        image = image.squeeze().numpy()
        mask = mask.squeeze().numpy()
        
        # Appliquer l'augmentation
        if dataset.transform:
            augmented = dataset.transform(image=image, mask=mask)
            aug_image = augmented['image']
            aug_mask = augmented['mask']
        else:
            aug_image = image
            aug_mask = mask
        
        # Afficher
        axes[i, 0].imshow(image, cmap='bone')
        axes[i, 0].set_title('Original')
        axes[i, 0].axis('off')
        
        axes[i, 1].imshow(aug_image, cmap='bone')
        axes[i, 1].set_title('Augmented')
        axes[i, 1].axis('off')
        
        axes[i, 2].imshow(aug_mask, cmap='bone')
        axes[i, 2].set_title('Augmented Mask')
        axes[i, 2].axis('off')
    
    plt.tight_layout()
    plt.show()

In [8]:
class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super(UNet, self).__init__()
        
        # Encoder
        self.enc1 = self._block(in_channels, 64)
        self.enc2 = self._block(64, 128)
        self.enc3 = self._block(128, 256)
        self.enc4 = self._block(256, 512)
        
        # Bottleneck
        self.bottleneck = self._block(512, 1024)
        
        # Decoder
        self.dec4 = self._block(1024, 512)
        self.dec3 = self._block(512, 256)
        self.dec2 = self._block(256, 128)
        self.dec1 = self._block(128, 64)
        
        # Final conv
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
        
        # Max pooling
        self.pool = nn.MaxPool2d(2)
        
        # Upsampling
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)

    def forward(self, x):
        # Encoder
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))
        
        # Bottleneck
        bottleneck = self.bottleneck(self.pool(enc4))
        
        # Decoder
        dec4 = self.dec4(torch.cat([self.up4(bottleneck), enc4], 1))
        dec3 = self.dec3(torch.cat([self.up3(dec4), enc3], 1))
        dec2 = self.dec2(torch.cat([self.up2(dec3), enc2], 1))
        dec1 = self.dec1(torch.cat([self.up1(dec2), enc1], 1))
        
        return self.final_conv(dec1)

    def _block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

def load_and_display_ct(dicom_dir: Path, slice_idx: int = 0):
    """Charge et affiche une tranche du CT scan"""
    print(f"Loading DICOM series from {dicom_dir}...")
    
    # Trouver tous les fichiers DICOM
    dicom_files = sorted(dicom_dir.glob('phalanx*.dcm'))
    if not dicom_files:
        raise FileNotFoundError(f"No DICOM files found in {dicom_dir}")
    
    print(f"Found {len(dicom_files)} files")
    
    # Charger la tranche spécifiée
    slice_data = pydicom.dcmread(str(dicom_files[slice_idx]), force=True)
    raw_data = slice_data.pixel_array
    
    print(f"\nSlice {slice_idx} info:")
    print(f"Shape: {raw_data.shape}")
    print(f"Data type: {raw_data.dtype}")
    print(f"Value range: [{raw_data.min()}, {raw_data.max()}]")
    
    # Normaliser les données pour l'affichage
    normalized_data = (raw_data - raw_data.min()) / (raw_data.max() - raw_data.min())
    
    # Créer une figure avec plusieurs vues
    fig, axes = plt.subplots(2, 2, figsize=(20, 20))
    
    # 1. Image originale
    im1 = axes[0, 0].imshow(raw_data, cmap='bone')
    axes[0, 0].set_title('Image originale')
    plt.colorbar(im1, ax=axes[0, 0])
    
    # 2. Image normalisée
    im2 = axes[0, 1].imshow(normalized_data, cmap='bone')
    axes[0, 1].set_title('Image normalisée')
    plt.colorbar(im2, ax=axes[0, 1])
    
    # 3. Histogramme des valeurs originales
    axes[1, 0].hist(raw_data.ravel(), bins=100, color='blue', alpha=0.7)
    axes[1, 0].set_title('Distribution des valeurs originales')
    axes[1, 0].set_xlabel('Intensité')
    axes[1, 0].set_ylabel('Fréquence')
    axes[1, 0].grid(True)
    
    # 4. Histogramme des valeurs normalisées
    axes[1, 1].hist(normalized_data.ravel(), bins=100, color='green', alpha=0.7)
    axes[1, 1].set_title('Distribution des valeurs normalisées')
    axes[1, 1].set_xlabel('Intensité normalisée')
    axes[1, 1].set_ylabel('Fréquence')
    axes[1, 1].grid(True)
    
    plt.tight_layout()
    plt.show()
    
    # Afficher les profils d'intensité
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 5))
    
    # Profil horizontal (milieu de l'image)
    middle_row = raw_data.shape[0] // 2
    ax1.plot(raw_data[middle_row, :])
    ax1.set_title(f'Profil horizontal (ligne {middle_row})')
    ax1.set_xlabel('Position X')
    ax1.set_ylabel('Intensité')
    ax1.grid(True)
    
    # Profil vertical (milieu de l'image)
    middle_col = raw_data.shape[1] // 2
    ax2.plot(raw_data[:, middle_col])
    ax2.set_title(f'Profil vertical (colonne {middle_col})')
    ax2.set_xlabel('Position Y')
    ax2.set_ylabel('Intensité')
    ax2.grid(True)
    
    plt.tight_layout()
    plt.show()
    
    return raw_data, normalized_data

In [9]:
def enhance_image(image):
    """Améliore la qualité de l'image avec plusieurs techniques"""
    # Normalisation initiale
    normalized = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX)
    
    # CLAHE (Contrast Limited Adaptive Histogram Equalization)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    clahe_img = clahe.apply(normalized.astype(np.uint8))
    
    # Débruitage
    denoised = cv2.fastNlMeansDenoising(clahe_img)
    
    # Amélioration des bords
    kernel = np.array([[-1,-1,-1],
                      [-1, 9,-1],
                      [-1,-1,-1]])
    sharpened = cv2.filter2D(denoised, -1, kernel)
    
    return normalized, clahe_img, denoised, sharpened

def analyze_enhanced_image(image):
    """Analyse détaillée de l'image avec différentes améliorations"""
    # Appliquer les améliorations
    normalized, clahe_img, denoised, sharpened = enhance_image(image)
    
    # Créer une grande figure
    plt.figure(figsize=(20, 15))
    
    # 1. Image originale
    plt.subplot(331)
    plt.imshow(image, cmap='bone')
    plt.title('Image originale')
    plt.colorbar()
    
    # 2. Image normalisée
    plt.subplot(332)
    plt.imshow(normalized, cmap='bone')
    plt.title('Normalisée')
    plt.colorbar()
    
    # 3. CLAHE
    plt.subplot(333)
    plt.imshow(clahe_img, cmap='bone')
    plt.title('CLAHE')
    plt.colorbar()
    
    # 4. Débruitage
    plt.subplot(334)
    plt.imshow(denoised, cmap='bone')
    plt.title('Débruitée')
    plt.colorbar()
    
    # 5. Accentuation des bords
    plt.subplot(335)
    plt.imshow(sharpened, cmap='bone')
    plt.title('Bords accentués')
    plt.colorbar()
    
    # 6. Détection de contours (Sobel)
    sobelx = cv2.Sobel(normalized, cv2.CV_64F, 1, 0, ksize=3)
    sobely = cv2.Sobel(normalized, cv2.CV_64F, 0, 1, ksize=3)
    sobel = np.sqrt(sobelx**2 + sobely**2)
    plt.subplot(336)
    plt.imshow(sobel, cmap='bone')
    plt.title('Sobel')
    plt.colorbar()
    
    # 7. Seuillage adaptatif
    thresh = cv2.adaptiveThreshold(normalized.astype(np.uint8), 255,
                                 cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                 cv2.THRESH_BINARY, 11, 2)
    plt.subplot(337)
    plt.imshow(thresh, cmap='bone')
    plt.title('Seuillage adaptatif')
    
    # 8. Gradient morphologique
    kernel = np.ones((3,3), np.uint8)
    gradient = cv2.morphologyEx(normalized.astype(np.uint8), 
                              cv2.MORPH_GRADIENT, kernel)
    plt.subplot(338)
    plt.imshow(gradient, cmap='bone')
    plt.title('Gradient morphologique')
    plt.colorbar()
    
    # 9. Superposition contours sur original
    overlay = cv2.addWeighted(normalized.astype(np.uint8), 0.7,
                            gradient, 0.3, 0)
    plt.subplot(339)
    plt.imshow(overlay, cmap='bone')
    plt.title('Superposition')
    plt.colorbar()
    
    plt.tight_layout()
    plt.show()
    
    return normalized, clahe_img, denoised, sharpened

if __name__ == "__main__":
    data_dir = Path('./data/dicom_series')
    
    try:
        # 1. Afficher quelques images brutes
        print("Affichage des images brutes...")
        first_image = display_raw_images(data_dir, n_slices=4)
        
        # 2. Analyser l'image améliorée
        print("\nAnalyse détaillée avec améliorations...")
        normalized, clahe_img, denoised, sharpened = analyze_enhanced_image(first_image)
        
    except Exception as e:
        print(f"\nErreur: {e}")
        import traceback
        traceback.print_exc()

Affichage des images brutes...

Erreur: name 'display_raw_images' is not defined


Traceback (most recent call last):
  File "/tmp/ipykernel_46734/866377687.py", line 104, in <module>
    first_image = display_raw_images(data_dir, n_slices=4)
                  ^^^^^^^^^^^^^^^^^^
NameError: name 'display_raw_images' is not defined


In [12]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import cv2

def display_raw_images(dicom_dir: Path, n_slices: int = 4):
    """Affiche quelques images brutes du dossier"""
    print(f"Loading images from {dicom_dir}...")
    
    # Trouver tous les fichiers humerus
    image_files = sorted(dicom_dir.glob('humerus*.dcm'))
    if not image_files:
        raise FileNotFoundError(f"No humerus DICOM files found in {dicom_dir}")
    
    print(f"Found {len(image_files)} files")
    
    # Sélectionner quelques images réparties uniformément
    step = len(image_files) // n_slices
    selected_files = [image_files[i * step] for i in range(n_slices)]
    
    # Créer une figure avec n_slices sous-plots
    fig, axes = plt.subplots(1, n_slices, figsize=(20, 5))
    
    for i, file_path in enumerate(selected_files):
        try:
            # Lire le fichier DICOM
            dicom_data = pydicom.dcmread(str(file_path), force=True)
            
            # Extraire les données pixel
            image = dicom_data.pixel_array
            
            print(f"\nFile: {file_path.name}")
            print(f"Image shape: {image.shape}")
            print(f"Data type: {image.dtype}")
            print(f"Value range: [{image.min()}, {image.max()}]")
            print(f"Mean value: {image.mean():.2f}")
            
            # Normaliser pour l'affichage
            image_float = image.astype(float)
            # Utiliser les percentiles pour éviter les valeurs extrêmes
            p1, p99 = np.percentile(image_float, (1, 99))
            image_normalized = np.clip((image_float - p1) / (p99 - p1), 0, 1)
            
            # Convertir en uint8 pour CLAHE
            image_uint8 = (image_normalized * 255).astype(np.uint8)
            
            # Améliorer le contraste avec CLAHE
            clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
            enhanced = clahe.apply(image_uint8)
            
            # Afficher l'image
            im = axes[i].imshow(enhanced, cmap='bone')
            axes[i].set_title(f'Slice {i*step}\n{image.shape[0]}x{image.shape[1]}')
            axes[i].axis('off')
            plt.colorbar(im, ax=axes[i])
            
            # Afficher l'histogramme
            plt.figure(figsize=(10, 4))
            plt.hist(image.ravel(), bins=100, color='blue', alpha=0.7)
            plt.title(f'Distribution des intensités - Slice {i*step}')
            plt.xlabel('Intensité')
            plt.ylabel('Fréquence')
            plt.grid(True)
            plt.show()
            
        except Exception as e:
            print(f"Error processing {file_path.name}: {e}")
            continue
    
    plt.tight_layout()
    plt.show()
    
    return enhanced

def analyze_slice(image):
    """Analyse détaillée d'une tranche"""
    # Créer une figure pour l'analyse
    plt.figure(figsize=(20, 10))
    
    # 1. Image originale
    plt.subplot(231)
    plt.imshow(image, cmap='bone')
    plt.title('Original')
    plt.colorbar()
    
    # 2. Histogramme
    plt.subplot(232)
    plt.hist(image.ravel(), bins=50, color='blue', alpha=0.7)
    plt.title('Histogramme')
    plt.grid(True)
    
    # 3. CLAHE
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
    clahe_img = clahe.apply(image.astype(np.uint8))
    plt.subplot(233)
    plt.imshow(clahe_img, cmap='bone')
    plt.title('CLAHE')
    plt.colorbar()
    
    # 4. Détection de contours
    edges = cv2.Canny(image.astype(np.uint8), 50, 150)
    plt.subplot(234)
    plt.imshow(edges, cmap='gray')
    plt.title('Contours')
    
    # 5. Segmentation par seuillage adaptatif
    thresh = cv2.adaptiveThreshold(image.astype(np.uint8), 255,
                                 cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                 cv2.THRESH_BINARY, 11, 2)
    plt.subplot(235)
    plt.imshow(thresh, cmap='gray')
    plt.title('Seuillage adaptatif')
    
    # 6. Superposition contours sur original
    overlay = cv2.addWeighted(image.astype(np.uint8), 0.7,
                            edges, 0.3, 0)
    plt.subplot(236)
    plt.imshow(overlay, cmap='bone')
    plt.title('Superposition')
    
    plt.tight_layout()
    plt.show()


In [13]:

if __name__ == "__main__":
    data_dir = Path('./data/dicom_series')
    
    try:
        # 1. Afficher les tranches
        print("Affichage des images brutes...")
        enhanced_image = display_raw_images(data_dir, n_slices=4)
        
        # 2. Analyser une tranche en détail
        print("\nAnalyse détaillée d'une tranche...")
        analyze_slice(enhanced_image)
        
    except Exception as e:
        print(f"\nErreur: {e}")
        import traceback
        traceback.print_exc()

Affichage des images brutes...
Loading images from data/dicom_series...

Erreur: No humerus DICOM files found in data/dicom_series


Traceback (most recent call last):
  File "/tmp/ipykernel_46734/1986210829.py", line 7, in <module>
    enhanced_image = display_raw_images(data_dir, n_slices=4)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_46734/1238474881.py", line 13, in display_raw_images
    raise FileNotFoundError(f"No humerus DICOM files found in {dicom_dir}")
FileNotFoundError: No humerus DICOM files found in data/dicom_series
