In [1]:
# %%
import os
import random
from PIL import Image, ImageEnhance, ImageFilter
import numpy as np
from skimage.exposure import match_histograms

# %%
# Paths
ct_dir = "Dataset/CT-MRI/CT"
mri_dir = "Dataset/CT-MRI/MRI"
output_dir = "DatasetProcessed"

for split in ["train", "val", "test"]:
    os.makedirs(os.path.join(output_dir, split, "CT"), exist_ok=True)
    os.makedirs(os.path.join(output_dir, split, "MRI"), exist_ok=True)

# %%
# Load paired image names
ct_images = sorted(os.listdir(ct_dir))
mri_images = sorted(os.listdir(mri_dir))
pairs = list(zip(ct_images, mri_images))
random.shuffle(pairs)

# Split: 70% train, 20% val, 10% test
total = len(pairs)
train_count = int(0.7 * total)
val_count = int(0.2 * total)
test_count = total - train_count - val_count

train_pairs = pairs[:train_count]
val_pairs = pairs[train_count:train_count+val_count]
test_pairs = pairs[train_count+val_count:]

print(f"Total pairs: {total}")
print(f"Train pairs: {len(train_pairs)}, Val pairs: {len(val_pairs)}, Test pairs: {len(test_pairs)}")

# %%
# Contrast matching function: match CT histogram to MRI
def match_ct_to_mri(ct_img, mri_img):
    ct_array = np.array(ct_img)
    mri_array = np.array(mri_img)
    matched = match_histograms(ct_array, mri_array)
    return Image.fromarray(np.uint8(np.clip(matched, 0, 255)))

# %%
# Augmentation: rotation, flip, scaling
def augment_pair(ct_img, mri_img):
    # Random rotation
    angle = random.choice([0, 90, 180, 270])
    ct_img = ct_img.rotate(angle)
    mri_img = mri_img.rotate(angle)
    
    # Random flips
    if random.random() > 0.5:
        ct_img = ct_img.transpose(Image.FLIP_LEFT_RIGHT)
        mri_img = mri_img.transpose(Image.FLIP_LEFT_RIGHT)
    if random.random() > 0.5:
        ct_img = ct_img.transpose(Image.FLIP_TOP_BOTTOM)
        mri_img = mri_img.transpose(Image.FLIP_TOP_BOTTOM)
    
    # Random scaling (0.9 to 1.1)
    scale = random.uniform(0.9, 1.1)
    w, h = ct_img.size
    new_w, new_h = int(w*scale), int(h*scale)
    
    ct_img = ct_img.resize((new_w, new_h), Image.BILINEAR)
    mri_img = mri_img.resize((new_w, new_h), Image.BILINEAR)
    
    if scale > 1.0:
        ct_img = ct_img.crop((0,0,w,h))
        mri_img = mri_img.crop((0,0,w,h))
    else:
        new_ct = Image.new('L', (w,h))
        new_ct.paste(ct_img, (0,0))
        ct_img = new_ct
        new_mri = Image.new('L', (w,h))
        new_mri.paste(mri_img, (0,0))
        mri_img = new_mri
    
    return ct_img, mri_img

# %%
# Apply filters to a pair
def apply_filters(ct_img, mri_img):
    filters = [
        ImageFilter.GaussianBlur(radius=1),
        ImageFilter.EDGE_ENHANCE,
        ImageFilter.EMBOSS,
        ImageFilter.SHARPEN,
        ImageFilter.SMOOTH
    ]
    
    filtered_pairs = [(ct_img, mri_img)]  # include original matched pair
    
    for f in filters:
        filtered_pairs.append((ct_img.filter(f), mri_img.filter(f)))
    
    return filtered_pairs

# %%
# Save function
def save_dataset(pairs_list, split_name, apply_aug=True):
    counter = 1
    for ct_name, mri_name in pairs_list:
        ct_img = Image.open(os.path.join(ct_dir, ct_name)).convert("L")
        mri_img = Image.open(os.path.join(mri_dir, mri_name)).convert("L")
        
        # Match CT histogram to MRI
        ct_img = match_ct_to_mri(ct_img, mri_img)
        
        # Generate augmented pairs
        paired_variants = [(ct_img, mri_img)]
        if apply_aug:
            # Standard augmentation
            aug_ct, aug_mri = augment_pair(ct_img, mri_img)
            paired_variants.append((aug_ct, aug_mri))
            # Filter-based augmentation
            filter_pairs = apply_filters(ct_img, mri_img)
            paired_variants.extend(filter_pairs)
        
        # Save all variants
        for ct_var, mri_var in paired_variants:
            ct_path = os.path.join(output_dir, split_name, "CT", f"{counter:05d}.png")
            mri_path = os.path.join(output_dir, split_name, "MRI", f"{counter:05d}.png")
            ct_var.save(ct_path)
            mri_var.save(mri_path)
            counter += 1
    
    print(f"{split_name.capitalize()} set saved with {counter-1} images.")

# %%
# Save all datasets
save_dataset(train_pairs, "train", apply_aug=True)
save_dataset(val_pairs, "val", apply_aug=True)
save_dataset(test_pairs, "test", apply_aug=False)

print("Dataset preparation complete!")


Total pairs: 573
Train pairs: 401, Val pairs: 114, Test pairs: 58
Train set saved with 3208 images.
Val set saved with 912 images.
Test set saved with 58 images.
Dataset preparation complete!
