# In [3]:
import cv2
import numpy as np
from sklearn.cluster import KMeans, OPTICS
from scipy.spatial.distance import euclidean
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import pandas as pd
import os
from IPython.display import display, HTML, FileLink
import ipywidgets as widgets
from io import BytesIO
import base64
import zipfile

# Directory that receives every saved image, plot and ZIP archive below.
# Created up front (exist_ok=True: no error if it is already there) so the
# later writes can assume it exists.
output_dir = 'spinal_analysis_output'
os.makedirs(output_dir, exist_ok=True)

# Function to check font availability
def check_fonts():
    """Probe whether OpenCV can draw text with its built-in Hershey font.

    A short label is rendered onto a throwaway 100x100 black canvas.

    Returns:
        True if ``cv2.putText`` completed, False if it raised.
    """
    canvas = np.zeros((100, 100, 3), dtype=np.uint8)
    try:
        cv2.putText(canvas, 'Test', (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
    except Exception:
        # Best-effort capability probe: any ordinary failure means "fall back
        # to Pillow". Unlike a bare ``except:`` this lets KeyboardInterrupt
        # and SystemExit propagate.
        return False
    return True


# Evaluated once at import time; overlay_spaces() picks its renderer from it.
FONT_AVAILABLE = check_fonts()

# Function to resize large images
def resize_image(image, max_size=1000):
    """Downscale *image* so that its longer side is at most *max_size* pixels.

    Args:
        image: Array whose first two dimensions are (height, width).
        max_size: Upper bound for the longer side, in pixels.

    Returns:
        The input object itself when it already fits; otherwise a resized
        copy produced with ``cv2.INTER_AREA`` that keeps the aspect ratio
        (each side truncated to an integer, but never below 1 pixel).
    """
    height, width = image.shape[:2]
    longest = max(height, width)
    if longest <= max_size:
        return image

    scale = max_size / longest
    # Clamp to 1 px: for a very elongated image the short side would
    # otherwise truncate to 0, a target size cv2.resize cannot produce.
    new_w = max(1, int(width * scale))
    new_h = max(1, int(height * scale))
    return cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)

# CLAHE preprocessing
def preprocess_image(image):
    """Return *image* after CLAHE contrast equalisation.

    The equaliser is created with a clip limit of 3.0 and an 8x8 tile grid.
    """
    equalizer = cv2.createCLAHE(tileGridSize=(8, 8), clipLimit=3.0)
    equalized = equalizer.apply(image)
    return equalized

# K-Means enhancement
def enhance_image_kmeans(image, n_clusters=8):
    """Quantise pixel intensities into *n_clusters* levels with K-Means.

    Every pixel is treated as a one-feature sample, replaced by the centre of
    the cluster it was assigned to, and the result is min-max normalised to
    0..255.

    Returns:
        A tuple ``(quantised uint8 image, flat label array, cluster centres)``.
    """
    samples = image.reshape(-1, 1)
    model = KMeans(n_clusters=n_clusters, random_state=42).fit(samples)
    labels, centers = model.labels_, model.cluster_centers_

    # Map each pixel to its cluster centre, then restore the image layout.
    quantised = centers[labels].reshape(image.shape)
    rescaled = cv2.normalize(quantised, None, 0, 255, cv2.NORM_MINMAX)
    return rescaled.astype(np.uint8), labels, centers

# Image blending
def blend_images(original, clustered, alpha=0.7):
    """Alpha-blend *clustered* over *original* and rescale to 0..255.

    *alpha* is the weight of *clustered*; *original* receives ``1 - alpha``.
    Both inputs are converted to float32 before mixing, and the mix is
    min-max normalised and returned as uint8.
    """
    base = original.astype(np.float32)
    overlay = clustered.astype(np.float32)
    mixed = alpha * overlay + (1 - alpha) * base
    stretched = cv2.normalize(mixed, None, 0, 255, cv2.NORM_MINMAX)
    return stretched.astype(np.uint8)

# Robust text overlay using Pillow as fallback
def overlay_spaces(image, spaces):
    """Draw each detected space as a coloured line with a text label.

    Args:
        image: Single-channel image; it is converted with
            ``cv2.COLOR_GRAY2RGB`` before drawing.
        spaces: Iterable of dicts providing ``'type'``, ``'height'``,
            ``'x_start'``, ``'x_end'`` and ``'y'``.

    Returns:
        A three-channel array with the annotations drawn on it.

    Rendering uses OpenCV when ``FONT_AVAILABLE`` is true and Pillow
    otherwise. Types missing from the colour table fall back to (0, 255, 255).
    """
    img_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    colors = {
        'Normal': (0, 255, 0),
        'Narrowed': (255, 0, 0),
        'Osteophytes': (255, 165, 0),
        'Sclerotic': (0, 255, 255),
    }
    fallback_color = (0, 255, 255)

    if FONT_AVAILABLE:
        font = cv2.FONT_HERSHEY_SIMPLEX
        for space in spaces:
            color = colors.get(space['type'], fallback_color)
            cv2.line(img_rgb, (space['x_start'], space['y']),
                     (space['x_end'], space['y']), color, 3)

            label = f"{space['type']} {space['height']:.1f}px"
            text_w, text_h = cv2.getTextSize(label, font, 0.5, 2)[0]
            text_x = space['x_start'] + 10
            text_y = space['y'] - 10

            # Keep the label on the canvas: flip it below the line when it
            # would poke out of the top, pull it left when it would run past
            # the right edge.
            if text_y - text_h < 0:
                text_y = space['y'] + text_h + 10
            if text_x + text_w > img_rgb.shape[1]:
                text_x = img_rgb.shape[1] - text_w - 10

            cv2.putText(img_rgb, label, (text_x, text_y), font, 0.5, color, 2)
        return img_rgb

    img_pil = Image.fromarray(img_rgb)
    draw = ImageDraw.Draw(img_pil)
    try:
        font = ImageFont.truetype('arial.ttf', 15)
    except OSError:
        # Pillow raises OSError when the font file cannot be opened; use the
        # built-in bitmap font instead. Other exceptions are not swallowed.
        font = ImageFont.load_default()

    for space in spaces:
        color = colors.get(space['type'], fallback_color)
        draw.line([(space['x_start'], space['y']),
                   (space['x_end'], space['y'])],
                  fill=color, width=3)
        draw.text((space['x_start'] + 10, space['y'] - 15),
                  f"{space['type']} {space['height']:.1f}px",
                  fill=color, font=font)

    return np.array(img_pil)

# Enhanced disc space detection (original method)
def detect_disc_spaces(image):
    """Find vertical gaps between consecutive contours and classify them.

    The image is blurred, adaptively thresholded (inverted) and morphologically
    closed; the external contours' bounding boxes are then ordered top to
    bottom. For each pair of neighbouring boxes that are roughly aligned
    horizontally and do not overlap vertically, a gap strictly between 5 and
    40 pixels tall is recorded.

    Returns:
        A list of dicts with keys ``'y'``, ``'x_start'``, ``'x_end'``,
        ``'height'``, ``'width'`` and ``'type'`` (from ``classify_space``).
    """
    smoothed = cv2.GaussianBlur(image, (5, 5), 0)
    mask = cv2.adaptiveThreshold(smoothed, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                 cv2.THRESH_BINARY_INV, 21, 5)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8), iterations=3)
    found, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Bounding boxes ordered by their top edge (stable for equal tops).
    boxes = sorted((cv2.boundingRect(c) for c in found), key=lambda box: box[1])

    spaces = []
    for (x1, y1, w1, h1), (x2, y2, w2, h2) in zip(boxes, boxes[1:]):
        gap_top = y1 + h1
        aligned = abs(x1 - x2) < max(w1, w2) * 1.5
        if not (aligned and y2 > gap_top):
            continue

        gap = y2 - gap_top
        if not 5 < gap < 40:
            continue

        left = min(x1, x2)
        right = max(x1 + w1, x2 + w2)
        spaces.append({
            'y': gap_top + gap // 2,
            'x_start': left,
            'x_end': right,
            'height': gap,
            'width': right - left,
            'type': classify_space(image, gap_top, y2, left, right),
        })
    return spaces

def classify_space(image, top, bottom, left, right):
    """Label the region ``image[top:bottom, left:right]``.

    Checks are applied in order and the first match wins:
    Canny edge density above 0.2 -> ``'Osteophytes'``; intensity standard
    deviation below 25 -> ``'Sclerotic'``; height below 15 rows ->
    ``'Narrowed'``; otherwise ``'Normal'``.
    """
    region = image[top:bottom, left:right]
    edge_map = cv2.Canny(region, 50, 150)
    # Edge pixels are 255, so this is the fraction of pixels marked as edges.
    density = np.sum(edge_map) / (region.size * 255)

    if density > 0.2:
        return 'Osteophytes'
    if np.std(region) < 25:
        return 'Sclerotic'
    return 'Narrowed' if (bottom - top) < 15 else 'Normal'

# OPTICS-based disc space detection
def detect_disc_spaces_optics(image):
    """Cluster contour centroids with OPTICS and measure centre-to-centre gaps.

    Steps: Canny edges -> external contours -> one centroid per contour with
    non-zero area -> OPTICS clustering -> cluster centres sorted top to
    bottom -> Euclidean distance between consecutive centres.

    Args:
        image: Either a 2-D grayscale array or a 3-channel array. A 3-channel
            input is converted to gray with ``cv2.COLOR_BGR2GRAY``.

    Returns:
        A 4-tuple ``(spaces, output, points, labels)``:
        * ``spaces`` - list of dicts with keys ``'Space'``, ``'Type'``,
          ``'Height'`` (distance between consecutive cluster centres) and
          ``'Width'``;
        * ``output`` - 3-channel image annotated with the clustered points,
          centres, connecting lines and distances;
        * ``points`` - ``(N, 2)`` integer array of centroids as ``[x, y]``;
        * ``labels`` - length-``N`` integer array of cluster labels, ``-1``
          meaning noise. With fewer than two centroids no clustering is run
          and every point is labelled ``-1``.
    """
    if len(image.shape) == 2:
        color_img = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        gray = image
    else:
        color_img = image.copy()
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    edges = cv2.Canny(gray, 50, 150)
    contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    centroids = []
    for cnt in contours:
        moments = cv2.moments(cnt)
        if moments['m00'] > 0:  # skip zero-area contours (centroid undefined)
            centroids.append([int(moments['m10'] / moments['m00']),
                              int(moments['m01'] / moments['m00'])])
    # reshape keeps the (N, 2) layout even when no centroid was found.
    points = np.array(centroids, dtype=int).reshape(-1, 2)

    if len(points) < 2:
        # Too few samples to cluster; keep labels aligned with points so
        # callers can index one with the other.
        return [], color_img, points, np.full(len(points), -1, dtype=int)

    optics = OPTICS(min_samples=2, xi=0.05, min_cluster_size=0.05)
    labels = optics.fit_predict(points)

    output = color_img.copy()
    base_cmap = plt.colormaps.get_cmap('tab10')

    cluster_centers = []
    for label in sorted(set(labels)):
        if label == -1:
            continue
        cluster_points = points[labels == label]
        # Plain Python ints: OpenCV drawing calls may reject NumPy integer
        # scalars as point coordinates on some builds.
        center = tuple(int(v) for v in np.mean(cluster_points, axis=0).astype(int))
        cluster_centers.append(center)

        rgba = base_cmap(label % 10)
        dot_color = (int(rgba[0] * 255), int(rgba[1] * 255), int(rgba[2] * 255))
        for px, py in cluster_points.tolist():
            cv2.circle(output, (px, py), 3, dot_color, -1)
        cv2.circle(output, center, 6, (255, 255, 255), 2)

    cluster_centers.sort(key=lambda c: c[1])
    spaces = []
    for i, (pt1, pt2) in enumerate(zip(cluster_centers, cluster_centers[1:])):
        dist = euclidean(pt1, pt2)
        mid_x = (pt1[0] + pt2[0]) // 2
        mid_y = (pt1[1] + pt2[1]) // 2
        cv2.line(output, pt1, pt2, (0, 0, 255), 1)
        cv2.putText(output, f'{int(dist)}px', (mid_x, mid_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 0), 1)
        spaces.append({
            'Space': f'Space {i+1}',
            'Type': 'Cluster',
            'Height': f'{dist:.1f}px',
            'Width': '-'
        })
    return spaces, output, points, labels

# Encode an image array as PNG bytes for download (raw bytes, not base64, despite the function name)
def image_to_base64(image_array):
    """Encode *image_array* as PNG and return the encoded bytes.

    NOTE: despite the name, no base64 encoding is applied; the return value
    is the raw PNG byte string.
    """
    with BytesIO() as stream:
        picture = Image.fromarray(image_array)
        picture.save(stream, format='PNG')
        return stream.getvalue()

# Visualization functions
def plot_kmeans_histogram(image, labels, centers, n_clusters, filename):
    """Save overlaid per-cluster intensity histograms to ``output_dir/filename``.

    *labels* is reshaped to the shape of *image* and used as a mask, so each
    histogram holds the pixels of *image* assigned to one cluster. The legend
    shows the first component of each cluster centre.
    """
    plt.figure(figsize=(8, 6))
    label_map = labels.reshape(image.shape)
    for idx in range(n_clusters):
        legend_text = f'Cluster {idx} (center: {centers[idx][0]:.1f})'
        members = image[label_map == idx]
        plt.hist(members, bins=50, alpha=0.5, label=legend_text)
    plt.title('K-Means Cluster Pixel Intensity Distribution')
    plt.xlabel('Pixel Intensity')
    plt.ylabel('Frequency')
    plt.legend()
    target = os.path.join(output_dir, filename)
    plt.savefig(target)
    plt.close()

def plot_optics_clusters(points, labels, filename):
    """Save a scatter plot of *points* coloured by cluster label.

    Label ``-1`` is drawn as black crosses named 'Noise'; every other label
    gets a 'tab10' colour and circle markers. The y axis is inverted and the
    figure is written to ``output_dir/filename``.
    """
    plt.figure(figsize=(8, 6))
    palette = plt.colormaps.get_cmap('tab10')
    for label in set(labels):
        is_noise = label == -1
        color = 'black' if is_noise else palette(label % 10)
        marker = 'x' if is_noise else 'o'
        label_name = 'Noise' if is_noise else f'Cluster {label}'
        members = points[labels == label]
        xs, ys = members[:, 0], members[:, 1]
        plt.scatter(xs, ys, c=[color], marker=marker, label=label_name, s=50)
    plt.title('OPTICS Clustering of Contour Centroids')
    plt.xlabel('X Coordinate')
    plt.ylabel('Y Coordinate')
    plt.legend()
    plt.gca().invert_yaxis()
    target = os.path.join(output_dir, filename)
    plt.savefig(target)
    plt.close()

def plot_pipeline_summary(original, preprocessed, enhanced, analysis, filename):
    """Save a one-row, four-panel overview of the pipeline stages.

    The first three panels are rendered with the gray colormap; *analysis* is
    shown without one. The figure is written to ``output_dir/filename``.
    """
    panels = (
        (original, 'Original Image', {'cmap': 'gray'}),
        (preprocessed, 'Preprocessed (CLAHE)', {'cmap': 'gray'}),
        (enhanced, 'Enhanced (K-Means)', {'cmap': 'gray'}),
        (analysis, 'Disc Space Analysis (OPTICS)', {}),
    )
    _, axes = plt.subplots(1, 4, figsize=(20, 5))
    for ax, (picture, title, options) in zip(axes, panels):
        ax.imshow(picture, **options)
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, filename))
    plt.close()

# Widgets for interaction
uploader = widgets.FileUpload(
    description='Upload spine image',
    accept='.jpg,.png,.jpeg',
    multiple=False,
)
enable_detection = widgets.Checkbox(
    description='Enable disc space detection (OPTICS)',
    value=False,
)
show_team = widgets.Checkbox(description='Show Team Info', value=False)
# All results produced by process_image are rendered into this area.
output = widgets.Output()

# Page heading, then the controls and the result area, top to bottom.
display(HTML('<h2>Spinal Cord Image Clustering</h2>'))
display(uploader, enable_detection, show_team, output)

# Team information
# One entry per member; process_image renders these when "Show Team Info"
# is ticked. Portraits are hot-linked from randomuser.me.
team = [
    dict(
        name='Alice Smith',
        role='Lead Developer',
        img='https://randomuser.me/api/portraits/women/44.jpg',
        bio='Expert in medical imaging and AI.',
    ),
    dict(
        name='Bob Lee',
        role='Backend Engineer',
        img='https://randomuser.me/api/portraits/men/32.jpg',
        bio='Loves Python, FastAPI, and scalable systems.',
    ),
    dict(
        name='Carol Tan',
        role='UI/UX Designer',
        img='https://randomuser.me/api/portraits/women/68.jpg',
        bio='Passionate about beautiful, accessible design.',
    ),
]

# Process image and update display
def _first_upload(value):
    """Return ``(filename, content_bytes)`` for the first file in a FileUpload value.

    Two layouts are accepted: a dict keyed by filename whose entries carry
    ``['metadata']['name']`` and ``['content']`` (ipywidgets 7.x), and a
    sequence of entries carrying ``['name']`` and ``['content']``
    (ipywidgets 8.x).
    """
    if isinstance(value, dict):
        entry = next(iter(value.values()))
        name = entry['metadata']['name']
    else:
        entry = value[0]
        name = entry['name']
    # The content may arrive as a memoryview; normalise it to bytes.
    return name, bytes(entry['content'])


def process_image(change):
    """Run the full pipeline on the uploaded image and render the results.

    Registered as a traitlets observer on the upload widget and both
    checkboxes, so it re-runs whenever any of their values changes.

    Args:
        change: The change notification passed by ``observe``; unused, the
            current widget values are read directly.

    Side effects: clears and refills the ``output`` widget, and writes PNG
    files, plots and a ZIP archive into ``output_dir``.
    """
    from html import escape  # local import keeps the block self-contained

    with output:
        output.clear_output()
        if not uploader.value:
            display(HTML('<p style="color: blue;">Please upload a spinal X-ray image to begin analysis</p>'))
            return

        filename, content = _first_upload(uploader.value)
        # The filename is user-controlled: strip any directory part before it
        # is used to build paths under output_dir.
        file_base = os.path.splitext(os.path.basename(filename.replace('\\', '/')))[0] or 'upload'

        try:
            source = Image.open(BytesIO(content))
            img_array = np.array(source.convert('L'))
        except OSError as exc:
            # Unreadable or truncated upload: report it instead of raising
            # from inside the widget callback.
            display(HTML(f'<p style="color: red;">Could not read the uploaded image: {escape(str(exc))}</p>'))
            return
        img_array = resize_image(img_array)

        # Processing pipeline
        processed = preprocess_image(img_array)
        clustered, labels, centers = enhance_image_kmeans(processed, 8)
        enhanced = blend_images(processed, clustered)

        # PNG-encoded bytes, reused for the per-image files and the ZIP.
        images_dict = {
            'Original': image_to_base64(img_array),
            'Preprocessed': image_to_base64(processed),
            'Enhanced': image_to_base64(enhanced)
        }

        # Display images in three columns
        display(HTML('<h3>Image Processing Results</h3>'))
        _, axes = plt.subplots(1, 3, figsize=(15, 5))
        stages = (
            (img_array, 'Original'),
            (processed, 'Preprocessed (CLAHE)'),
            (enhanced, 'Enhanced (K-Means)'),
        )
        for ax, (picture, title) in zip(axes, stages):
            ax.imshow(picture, cmap='gray')
            ax.set_title(title)
            ax.axis('off')
        plt.tight_layout()
        plt.show()

        # Save each image once (from the PNG bytes) and offer it for download.
        display(HTML('<div style="display: flex; justify-content: space-around;">'))
        for key, img_data in images_dict.items():
            png_name = f'{file_base}_{key.lower()}.png'
            png_path = os.path.join(output_dir, png_name)
            with open(png_path, 'wb') as f:
                f.write(img_data)
            display(FileLink(png_path,
                            result_html_prefix=f'<a download="{escape(png_name, quote=True)}" style="text-align: center; display: block;">Download {key} Image</a>'))
        display(HTML('</div>'))

        # K-Means visualization. The labels and centres were fitted on the
        # CLAHE image, so the histogram is drawn from that same image.
        display(HTML('<h3>K-Means Clustering Visualization</h3>'))
        plot_kmeans_histogram(processed, labels, centers, 8, f'{file_base}_kmeans_histogram.png')
        plt.figure(figsize=(8, 6))
        plt.imshow(plt.imread(os.path.join(output_dir, f'{file_base}_kmeans_histogram.png')))
        plt.axis('off')
        plt.show()

        # OPTICS detection
        if enable_detection.value:
            display(HTML('<hr><h3>Disc Space Detection (OPTICS)</h3>'))
            orig_color = resize_image(np.array(source.convert('RGB')))
            spaces, overlaid, points, labels_optics = detect_disc_spaces_optics(orig_color)
            images_dict['Analysis'] = image_to_base64(overlaid)

            # Annotated image; the right-hand axis is an empty placeholder.
            _, (ax_image, ax_blank) = plt.subplots(1, 2, figsize=(12, 5))
            ax_image.imshow(overlaid)
            ax_image.set_title('Disc Space Analysis (OPTICS)')
            ax_image.axis('off')
            ax_blank.axis('off')
            plt.tight_layout()
            plt.show()

            # Display table
            if spaces:
                df = pd.DataFrame(spaces, columns=['Space', 'Type', 'Height', 'Width'])
                display(df)

            # Download analysis image
            analysis_name = f'{file_base}_analysis.png'
            analysis_path = os.path.join(output_dir, analysis_name)
            with open(analysis_path, 'wb') as f:
                f.write(images_dict['Analysis'])
            display(HTML('<div style="text-align: center;">'))
            display(FileLink(analysis_path,
                            result_html_prefix=f'<a download="{escape(analysis_name, quote=True)}">Download Analysis Image</a>'))
            display(HTML('</div>'))

            # OPTICS cluster visualization
            if len(points) > 0:
                display(HTML('<h3>OPTICS Clustering Visualization</h3>'))
                plot_optics_clusters(points, labels_optics, f'{file_base}_optics_clusters.png')
                plt.figure(figsize=(8, 6))
                plt.imshow(plt.imread(os.path.join(output_dir, f'{file_base}_optics_clusters.png')))
                plt.axis('off')
                plt.show()

            # Pipeline summary
            display(HTML('<h3>Pipeline Summary</h3>'))
            plot_pipeline_summary(img_array, processed, enhanced, overlaid, f'{file_base}_pipeline_summary.png')
            plt.figure(figsize=(8, 2))
            plt.imshow(plt.imread(os.path.join(output_dir, f'{file_base}_pipeline_summary.png')))
            plt.axis('off')
            plt.show()

        # Create and provide ZIP download
        zip_name = f'{file_base}_images.zip'
        zip_path = os.path.join(output_dir, zip_name)
        zip_buffer = BytesIO()
        with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
            for key, img_data in images_dict.items():
                zip_file.writestr(f'{file_base}_{key.lower()}.png', img_data)
        with open(zip_path, 'wb') as f:
            f.write(zip_buffer.getvalue())
        display(HTML('<div style="text-align: center;">'))
        display(FileLink(zip_path, result_html_prefix=f'<a download="{escape(zip_name, quote=True)}">Download All Images as ZIP</a>'))
        display(HTML('</div>'))

        # Team info
        if show_team.value:
            display(HTML('<hr><h3>Meet the Team</h3>'))
            for member in team:
                display(HTML(f'<img src="{member["img"]}" width="100"><br>'))
                display(HTML(f'<h4>{member["name"]}</h4><p><i>{member["role"]}</i><br>{member["bio"]}</p>'))

# Connect widgets to processing function: a change to the uploaded file or to
# either checkbox's value re-runs process_image.
uploader.observe(process_image, names='value')
enable_detection.observe(process_image, names='value')
show_team.observe(process_image, names='value')
print('Upload an image to begin analysis.')

# Notebook cell output (widget reprs and printed text), kept for reference:
# FileUpload(value={}, accept='.jpg,.png,.jpeg', description='Upload spine image')
#
# Checkbox(value=False, description='Enable disc space detection (OPTICS)')
#
# Checkbox(value=False, description='Show Team Info')
#
# Output()
#
# Upload an image to begin analysis.
