# In [4]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import Button, Slider
from scipy.ndimage import distance_transform_edt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib
# NOTE(review): the backend is selected after `matplotlib.pyplot` was imported;
# this is only safe while no figure exists yet -- consider moving it above the
# pyplot import.
matplotlib.use('TkAgg')

# Drawing surface: 1.0 = white paper, 0.0 = black ink.
canvas = np.ones((500, 500))

# 2D drawing window. vmin/vmax pin the gray colormap so 0 always renders black
# and 1 always white, whatever the canvas currently contains.
fig, ax = plt.subplots()
# Reserve the bottom of the figure for the sliders and buttons (placed at
# y <= 0.125 in figure coordinates); with the default bottom margin of 0.11
# the buttons would overlap the drawing area.
fig.subplots_adjust(bottom=0.2)
im = ax.imshow(canvas, cmap="gray", vmin=0, vmax=1)
ax.axis("off")

# Interaction state shared by the mouse and widget callbacks.
drawing = False  # True while the left button is held (painting ink)
erasing = False  # True while the right button is held (painting white)
brush_size = 5  # brush radius in pixels
inflation_factor = 50  # peak height of the 3D surface

# Function to draw on the canvas
def draw_circle(x, y, radius, color=0):
    """Stamp a filled disc of ``color`` onto ``canvas`` and refresh the view.

    Args:
        x, y: Disc centre in image data coordinates (column, row), e.g. a
            mouse event's ``xdata``/``ydata``.
        radius: Disc radius in pixels (callers pass the global brush_size).
        color: Value written to covered pixels: 0 paints black ink, 1 erases
            back to white.

    Pixels that fall outside the canvas are clipped, so strokes near the
    border become partial discs.
    """
    height, width = canvas.shape
    # imshow centres pixel (row r, col c) on data coordinate (c, r), so the
    # pixel under the cursor is the *nearest* integer; int() truncation picked
    # the neighbouring pixel whenever the fractional part was above .5.
    cx, cy = int(round(x)), int(round(y))

    # Bounding box of the disc, clipped to the canvas.
    col0, col1 = max(cx - radius, 0), min(cx + radius + 1, width)
    row0, row1 = max(cy - radius, 0), min(cy + radius + 1, height)

    # Vectorised disc mask over the box instead of a per-pixel Python loop.
    rows, cols = np.ogrid[row0:row1, col0:col1]
    inside = (cols - cx) ** 2 + (rows - cy) ** 2 <= radius ** 2
    # Basic slicing returns a view, so the masked assignment writes into canvas.
    canvas[row0:row1, col0:col1][inside] = color

    im.set_data(canvas)
    fig.canvas.draw_idle()

# Mouse press handler
def on_press(event):
    """Start a paint (left button) or erase (right button) stroke.

    Presses outside the drawing axes -- on a slider or button -- are ignored.
    Their ``xdata``/``ydata`` are in that widget's own coordinates and would
    otherwise stamp a blob onto the canvas (e.g. near the top-left corner
    when dragging a slider).
    """
    global drawing, erasing
    if event.inaxes is not ax or event.xdata is None or event.ydata is None:
        return
    if event.button == 1:  # Left mouse button for drawing
        drawing = True
        draw_circle(event.xdata, event.ydata, brush_size)
    elif event.button == 3:  # Right mouse button for erasing
        erasing = True
        draw_circle(event.xdata, event.ydata, brush_size, color=1)

# Mouse release handler
def on_release(event):
    """End any paint or erase stroke in progress.

    The ``event`` payload is not inspected: any button release stops both
    modes.
    """
    global drawing, erasing
    drawing = erasing = False

# Mouse motion handler
def on_motion(event):
    """Extend the active stroke while the mouse moves over the drawing axes.

    Motion over other axes (sliders, buttons) is ignored. Their data
    coordinates do not map to canvas pixels, so dragging a slider, or sliding
    off the canvas mid-stroke, must not paint.
    """
    if event.inaxes is not ax or event.xdata is None or event.ydata is None:
        return
    if drawing:
        draw_circle(event.xdata, event.ydata, brush_size)
    elif erasing:
        draw_circle(event.xdata, event.ydata, brush_size, color=1)

def clear_canvas(event):
    """Callback for the "Clear" button: wipe the drawing back to all white.

    ``event`` (the click) is ignored. ``canvas`` is overwritten in place
    rather than rebound, so code holding a reference to the same array sees
    the cleared contents. The image is then refreshed.
    """
    canvas[...] = 1  # every pixel back to white paper
    im.set_data(canvas)
    fig.canvas.draw_idle()

# "Clear" button in the bottom-right corner. The module-level names keep the
# widget objects referenced, which matplotlib widgets need to stay responsive.
clear_ax = fig.add_axes([0.8, 0.05, 0.1, 0.075])
clear_button = Button(clear_ax, "Clear")
clear_button.on_clicked(clear_canvas)

# Brush slider: integer radii 1..20, starting from the current brush_size.
brush_slider_ax = fig.add_axes([0.2, 0.01, 0.4, 0.03])
brush_slider = Slider(
    brush_slider_ax,
    "Brush Size",
    valmin=1,
    valmax=20,
    valinit=brush_size,
    valstep=1,
)

def update_brush_size(val):
    """Slider callback: adopt ``val``, truncated to an int, as the brush radius."""
    global brush_size
    new_radius = int(val)
    brush_size = new_radius

# Keep brush_size in sync with the brush slider.
brush_slider.on_changed(update_brush_size)

# Inflation slider: 10..100 in steps of 5, seeded with inflation_factor,
# placed just above the brush slider.
inflation_slider_ax = fig.add_axes([0.2, 0.05, 0.4, 0.03])
inflation_slider = Slider(
    inflation_slider_ax,
    "Inflation",
    valmin=10,
    valmax=100,
    valinit=inflation_factor,
    valstep=5,
)

def update_inflation_factor(val):
    """Slider callback: store ``val``, truncated to an int, as the 3D peak height."""
    global inflation_factor
    new_height = int(val)
    inflation_factor = new_height

# Keep inflation_factor in sync with the slider; plot_3d_with_inflation reads
# it each time the 3D view is opened.
inflation_slider.on_changed(update_inflation_factor)

# 3D Plot Function with Inflation
def plot_3d_with_inflation():
    """Open a new 3D window showing the drawing "inflated" into a surface.

    Every ink pixel (0 in ``canvas``) is raised by its Euclidean distance to
    the nearest white pixel, so thick strokes bulge in the middle and taper
    toward their edges. Heights are rescaled so the tallest point equals the
    current ``inflation_factor``. A blank canvas produces a flat surface
    instead of NaN heights.
    """
    fig_3d = plt.figure(figsize=(8, 6))
    ax_3d = fig_3d.add_subplot(111, projection='3d')

    # Pixel-index grid: x is the column, y the row, matching canvas[y, x].
    x = np.arange(0, canvas.shape[1])
    y = np.arange(0, canvas.shape[0])
    x, y = np.meshgrid(x, y)

    # distance_transform_edt gives each non-zero pixel its distance to the
    # nearest zero pixel; invert first so ink (0) becomes the non-zero input.
    inverted_canvas = 1 - canvas
    distance_map = distance_transform_edt(inverted_canvas)

    # Normalise to a peak of inflation_factor. With nothing drawn the map is
    # all zeros, and 0/0 would yield NaN heights (plus a RuntimeWarning), so
    # fall back to a flat surface instead.
    peak = distance_map.max()
    if peak > 0:
        z = distance_map / peak * inflation_factor
    else:
        z = np.zeros_like(distance_map)

    ax_3d.plot_surface(x, y, z, cmap='viridis', edgecolor='none', antialiased=True)
    ax_3d.set_title("3D Inflated Model of the Drawing")
    ax_3d.set_xlabel("X-axis")
    ax_3d.set_ylabel("Y-axis")
    ax_3d.set_zlabel("Height")
    ax_3d.view_init(elev=30, azim=45)  # Adjust viewing angle

    plt.show()

# "3D View" button, just left of "Clear": each click opens a fresh 3D window
# built from the current drawing.
view_3d_ax = fig.add_axes([0.65, 0.05, 0.1, 0.075])
view_3d_button = Button(view_3d_ax, "3D View")
view_3d_button.on_clicked(lambda event: plot_3d_with_inflation())

# Route raw mouse events on the 2D figure to the stroke handlers:
# a press starts a stroke, motion extends it, and a release ends it.
fig.canvas.mpl_connect("button_press_event", on_press)
fig.canvas.mpl_connect("motion_notify_event", on_motion)
fig.canvas.mpl_connect("button_release_event", on_release)

# Display the 2D drawing window and hand control to matplotlib's event loop.
plt.show()