In [1]:
import os
import shutil
import tempfile
import zipfile

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np

def load_nii_files_and_extract_slices(temp_dir):
    """Collect every non-empty slice from all ``.nii.gz`` volumes under ``temp_dir``.

    The directory tree is walked recursively in sorted order, so the order of
    the returned slices is reproducible across runs and machines. Downstream
    code uses that order for the ``slice_{index}.png`` file names, and
    ``os.walk`` on its own yields entries in arbitrary order.

    Slices are taken along the third array axis (``data[:, :, k]``). A slice is
    kept when at least one voxel in it is > 0.

    macOS metadata left in archives (the ``__MACOSX`` directory and ``._*``
    AppleDouble files) is skipped. Those files carry the ``.nii.gz`` suffix but
    are not loadable volumes.

    Args:
        temp_dir: Root directory to search.

    Returns:
        A tuple ``(all_slices, white_pixels_counts)`` of equal length.
        ``all_slices[i]`` is one kept slice array. ``white_pixels_counts[i]``
        is the number of voxels > 0 in that slice, as an ``int``.
    """
    all_slices = []
    white_pixels_counts = []

    for root, dirs, files in os.walk(temp_dir):
        # Sorting in place also fixes the order in which os.walk descends.
        dirs[:] = sorted(d for d in dirs if d != '__MACOSX')
        for filename in sorted(files):
            if not filename.endswith('.nii.gz') or filename.startswith('._'):
                continue
            data = nib.load(os.path.join(root, filename)).get_fdata()

            # Boolean mask over axis 2: True where the slice has any voxel > 0.
            non_empty_slices = np.any(data > 0, axis=(0, 1))
            slices = data[:, :, non_empty_slices]

            for i in range(slices.shape[2]):
                slice_data = slices[:, :, i]
                all_slices.append(slice_data)
                white_pixels_counts.append(int(np.count_nonzero(slice_data > 0)))

    return all_slices, white_pixels_counts

def save_selected_slices(all_slices, white_pixels_counts, output_dir, std_dev_threshold=1.5):
    """Save as PNG the slices whose white-pixel count lies near the global mean.

    The mean and the population standard deviation (``np.std`` default,
    ddof=0) are computed over all counts. The slice at index ``i`` is kept when
    ``mean - t*std <= white_pixels_counts[i] <= mean + t*std``, where
    ``t = std_dev_threshold``. Each kept slice is written to
    ``output_dir/slice_{i}.png`` with a gray colormap.

    Args:
        all_slices: Sequence of slice arrays, parallel to ``white_pixels_counts``.
        white_pixels_counts: Per-slice number of voxels > 0.
        output_dir: Destination directory. It is created if missing.
        std_dev_threshold: Half-width of the acceptance band, in standard
            deviations.

    Returns:
        List of the saved slice indices, in ascending order. The list is empty
        when there are no slices.

    Raises:
        ValueError: If the two input sequences differ in length.
    """
    if len(all_slices) != len(white_pixels_counts):
        raise ValueError(
            f'all_slices ({len(all_slices)}) and white_pixels_counts '
            f'({len(white_pixels_counts)}) must have the same length'
        )

    os.makedirs(output_dir, exist_ok=True)

    # np.mean/np.std of an empty sequence warn and return NaN; nothing to save anyway.
    if len(white_pixels_counts) == 0:
        return []

    counts = np.asarray(white_pixels_counts, dtype=float)
    mean = counts.mean()
    std_dev = counts.std()

    lower_bound = mean - std_dev_threshold * std_dev
    upper_bound = mean + std_dev_threshold * std_dev
    selected_indices = np.flatnonzero((counts >= lower_bound) & (counts <= upper_bound)).tolist()

    for index in selected_indices:
        plt.imsave(os.path.join(output_dir, f'slice_{index}.png'), all_slices[index], cmap='gray')

    return selected_indices

def process_zip_file(zip_path, output_dir, std_dev_threshold=1.0):
    """Extract NIfTI volumes from a zip archive and save representative slices as PNGs.

    Only ``.nii.gz`` members are extracted. macOS metadata entries are skipped:
    anything under ``__MACOSX/`` and any ``._*`` AppleDouble file. Extraction
    goes into a uniquely named temporary directory created inside
    ``output_dir``. That directory is removed when processing finishes, even
    if loading or saving raises. A unique name also stops leftovers from an
    earlier crashed run from being mixed into this one.

    Note: this function's default ``std_dev_threshold`` (1.0) differs from
    ``save_selected_slices``'s own default (1.5). The value here is always
    passed through explicitly.

    Args:
        zip_path: Path to the zip archive.
        output_dir: Directory that receives the PNG slices. It is created if
            missing.
        std_dev_threshold: Half-width of the white-pixel-count acceptance band,
            in standard deviations.
    """
    os.makedirs(output_dir, exist_ok=True)

    with tempfile.TemporaryDirectory(dir=output_dir) as temp_dir:
        with zipfile.ZipFile(zip_path, 'r') as z:
            members = []
            for name in z.namelist():
                parts = name.split('/')
                if not name.endswith('.nii.gz'):
                    continue
                if parts[0] == '__MACOSX' or parts[-1].startswith('._'):
                    continue
                members.append(name)
            # extractall sanitizes member names (drops absolute paths and '..').
            z.extractall(temp_dir, members=members)

        # Load all nii files and extract relevant slices and pixel counts
        all_slices, white_pixels_counts = load_nii_files_and_extract_slices(temp_dir)

        # Save the slices that fall within the global standard deviation threshold
        save_selected_slices(all_slices, white_pixels_counts, output_dir, std_dev_threshold)

# Example usage: turn the segmentation archive into PNG slices. Both paths are
# relative to the kernel's current working directory. The threshold is left at
# process_zip_file's default.
zip_path = 'segmentations_2.zip'
output_dir = 'extracted_slices'
process_zip_file(zip_path=zip_path, output_dir=output_dir)


In [9]:
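# Optional cleanup: uncomment the two lines below to delete the whole
# 'extracted_slices' directory (all generated PNGs) before re-running the
# extraction cell.
# import shutil

# shutil.rmtree('extracted_slices')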