In [None]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
import io

class TomographySimulator:
    """Simulate a parallel-beam CT scan of a 2-D grayscale image and reconstruct it.

    Pipeline: ``compute_sinogram`` (forward projection along Bresenham rays),
    ``apply_ramp_filter`` (|f| filter applied in the Fourier domain per
    projection) and ``reconstruct`` (back-projection of the filtered
    projections along exactly the same rays). ``iterative_reconstruction``
    replays the back-projection using only the first N angles.

    Attributes set by the constructor:
        image: 2-D array indexed ``image[y, x]``.
        angles / angles_rad: projection angles in [0, 180) degrees / radians.
        R: half of the shorter image side; every ray is 2*R long.
        l_span: total width of the detector array.
        n_detectors: number of rays (detectors) per angle.
        detector_offsets: signed offset of each ray from the image centre.
    """

    def __init__(self, image, angle_step=1, n_detectors=None, l_span=None):
        """Set up the scan geometry.

        Args:
            image: 2-D grayscale image (``image.shape == (height, width)``).
            angle_step: angular step in degrees; must be positive.
            n_detectors: detectors per projection; defaults to ``int(l_span)``.
            l_span: width of the detector array; defaults to the shorter
                image side (``2 * R``).

        Raises:
            ValueError: if ``angle_step`` is not positive or fewer than one
                detector would be used.
        """
        if not angle_step > 0:
            # A zero/negative step cannot give a usable angle set from np.arange.
            raise ValueError("Krok kątowy musi być dodatni.")
        self.image = image
        self.height, self.width = image.shape
        self.cx, self.cy = self.width / 2, self.height / 2
        self.angle_step = angle_step
        self.angles = np.arange(0, 180, angle_step)
        self.angles_rad = np.deg2rad(self.angles)
        self.R = min(self.width, self.height) / 2
        if l_span is None:
            l_span = 2 * self.R
        self.l_span = l_span
        if n_detectors is None:
            n_detectors = int(l_span)
        if n_detectors < 1:
            raise ValueError("Liczba detektorów musi wynosić co najmniej 1.")
        self.n_detectors = n_detectors
        self.detector_offsets = np.linspace(-l_span/2, l_span/2, n_detectors)
        self.sinogram = None
        self.rays = None              # rays[i][j]: list of (x, y) pixels of ray j at angle i
        self.filtered_sinogram = None
        self.reconstruction = None
        self.bp_angles_contrib = []   # per-angle back-projection images (unscaled)

    def compute_sinogram(self):
        """Forward-project the image; return an array of shape (n_angles, n_detectors).

        Each entry is the mean intensity of the in-image pixels on a
        Bresenham line of length 2*R through the point offset from the
        centre along the detector axis; rays that miss the image get 0.
        The pixel lists are stored in ``self.rays`` so that back-projection
        retraces identical paths. Recomputing invalidates any previously
        filtered sinogram and reconstruction.
        """
        n_angles = len(self.angles_rad)
        sinogram = np.zeros((n_angles, self.n_detectors))
        rays = []
        for i, theta in enumerate(self.angles_rad):
            # Direction along the ray and along the detector axis (perpendicular).
            dir_cos, dir_sin = np.cos(theta), np.sin(theta)
            perp_cos, perp_sin = np.cos(theta + np.pi / 2), np.sin(theta + np.pi / 2)
            rays_angle = []
            for j, offset in enumerate(self.detector_offsets):
                base_x = self.cx + offset * perp_cos
                base_y = self.cy + offset * perp_sin
                line_pixels = bresenham_line(base_x - self.R * dir_cos,
                                             base_y - self.R * dir_sin,
                                             base_x + self.R * dir_cos,
                                             base_y + self.R * dir_sin)
                valid_pixels = [(x, y) for (x, y) in line_pixels
                                if 0 <= x < self.width and 0 <= y < self.height]
                if valid_pixels:
                    xs, ys = zip(*valid_pixels)
                    sinogram[i, j] = np.mean(self.image[list(ys), list(xs)])
                rays_angle.append(valid_pixels)
            rays.append(rays_angle)
        self.sinogram = sinogram
        self.rays = rays
        # Anything derived from a previous sinogram is now stale.
        self.filtered_sinogram = None
        self.reconstruction = None
        self.bp_angles_contrib = []
        return sinogram

    def apply_ramp_filter(self):
        """Apply the |f| ramp filter to every projection; return the filtered sinogram.

        Frequencies come from ``np.fft.fftfreq`` with unit detector spacing,
        so the DC component is removed. No zero-padding is applied.

        Raises:
            ValueError: if the sinogram has not been computed yet.
        """
        if self.sinogram is None:
            raise ValueError("Najpierw oblicz sinogram!")
        # The ramp depends only on the projection length: build it once.
        ramp = np.abs(np.fft.fftfreq(self.sinogram.shape[1], d=1.0))
        spectrum = np.fft.fft(self.sinogram, axis=1)
        filtered = np.fft.ifft(spectrum * ramp, axis=1).real
        self.filtered_sinogram = filtered
        # Back-projections of the old filtered data are stale.
        self.reconstruction = None
        self.bp_angles_contrib = []
        return filtered

    def reconstruct(self):
        """Back-project the filtered sinogram along the stored rays; return the image.

        Filters the sinogram first if needed. Each angle's contribution is
        kept in ``self.bp_angles_contrib``; the sum is scaled by
        pi / n_angles.

        Raises:
            ValueError: if no rays are available (sinogram not computed).
        """
        if self.rays is None:
            raise ValueError("Najpierw oblicz sinogram!")
        if self.filtered_sinogram is None:
            self.apply_ramp_filter()
        n_angles = len(self.angles_rad)
        reconstruction = np.zeros((self.height, self.width))
        self.bp_angles_contrib = []
        for i in range(n_angles):
            contrib = np.zeros((self.height, self.width))
            for p_val, ray in zip(self.filtered_sinogram[i], self.rays[i]):
                if ray:
                    xs, ys = zip(*ray)
                    # Unbuffered add: correct even if a pixel repeats in the index list.
                    np.add.at(contrib, (list(ys), list(xs)), p_val)
            reconstruction += contrib
            self.bp_angles_contrib.append(contrib)
        reconstruction = reconstruction * (np.pi / n_angles)
        self.reconstruction = reconstruction
        return reconstruction

    def iterative_reconstruction(self, upto_angle):
        """Return the back-projection using only the first ``upto_angle`` angles.

        ``upto_angle`` is clamped to [0, n_angles]. Uses the same
        pi / n_angles scaling as ``reconstruct``, so the full count
        reproduces the final reconstruction.
        """
        if not self.bp_angles_contrib:
            self.reconstruct()
        count = max(0, min(int(upto_angle), len(self.bp_angles_contrib)))
        iter_recon = np.zeros((self.height, self.width))
        for contrib in self.bp_angles_contrib[:count]:
            iter_recon += contrib
        return iter_recon * (np.pi / len(self.angles_rad))

def bresenham_line(x0, y0, x1, y1):
    """Return the integer pixels on the segment (x0, y0) -> (x1, y1), endpoints included.

    Endpoint coordinates are rounded with Python's ``round`` (ties to even)
    before rasterising with the integer Bresenham algorithm. Pixels are
    returned as ``(x, y)`` tuples in order from the start point to the end point.
    """
    cur_x, cur_y = int(round(x0)), int(round(y0))
    dst_x, dst_y = int(round(x1)), int(round(y1))
    delta_x = abs(dst_x - cur_x)
    delta_y = abs(dst_y - cur_y)
    step_x = 1 if cur_x < dst_x else -1
    step_y = 1 if cur_y < dst_y else -1
    error = delta_x - delta_y
    pixels = [(cur_x, cur_y)]
    while (cur_x, cur_y) != (dst_x, dst_y):
        doubled = 2 * error
        # Step in x and/or y depending on which side of the ideal line we are.
        if doubled > -delta_y:
            error -= delta_y
            cur_x += step_x
        if doubled < delta_x:
            error += delta_x
            cur_y += step_y
        pixels.append((cur_x, cur_y))
    return pixels

def _iter_uploaded_files(value):
    """Yield ``(name, content)`` for each file in a ``FileUpload.value``.

    Accepts both the tuple-of-dicts layout (``'name'``/``'content'`` keys)
    and the older dict layout mapping file name -> ``{'content': ...}``.
    """
    if isinstance(value, dict):
        for name, info in value.items():
            yield name, info.get('content')
    else:
        for info in value:
            yield info.get('name', 'unknown'), info.get('content')


def on_upload_change(change):
    """FileUpload observer: decode the uploaded file into the global ``loaded_image``.

    The image is decoded as grayscale and, if its longer side exceeds 256
    pixels, downscaled with the aspect ratio kept (presumably to keep the
    per-pixel Python scan tractable). Any previous ``simulator`` is
    discarded because it was built from the old image.

    Args:
        change: traitlets change notification (unused; the widget's current
            value is read directly).
    """
    global loaded_image, simulator
    output_area.clear_output()
    files = uploader.value
    if not files:
        return
    for name, bytes_data in _iter_uploaded_files(files):
        if not bytes_data:
            # Missing/empty payload: nothing that cv2 could decode.
            with output_area:
                print("Nie udało się wczytać obrazu.")
            return
        np_arr = np.frombuffer(bytes_data, np.uint8)
        img = cv2.imdecode(np_arr, cv2.IMREAD_GRAYSCALE)
        if img is None:
            with output_area:
                print("Nie udało się wczytać obrazu.")
            return
        max_dim = 256
        h, w = img.shape
        if max(h, w) > max_dim:
            scaling = max_dim / float(max(h, w))
            # Never let a very elongated image shrink to a 0-pixel side.
            new_size = (max(1, int(w * scaling)), max(1, int(h * scaling)))
            img = cv2.resize(img, new_size, interpolation=cv2.INTER_AREA)
        loaded_image = img
        simulator = None  # results of the previous image are no longer valid
        with output_area:
            print(f"Wczytano obraz: {name}, rozmiar: {img.shape}")

def _normalize01(arr):
    """Return ``arr`` as floats rescaled to [0, 1]; a constant array maps to zeros."""
    arr = np.asarray(arr, dtype=float)
    lo, hi = arr.min(), arr.max()
    if hi == lo:
        return np.zeros_like(arr)
    return (arr - lo) / (hi - lo)


def run_simulation(b):
    """Button callback: scan the loaded image, reconstruct it and plot the results.

    Reads the scan parameters from the widgets and runs sinogram -> ramp
    filter -> back-projection. It prints the RMSE between the min-max
    normalised original and reconstruction, then plots the original,
    sinogram and reconstruction in ``out_full``. The global ``simulator`` is
    replaced only after the whole pipeline succeeds.

    Args:
        b: the clicked button (unused; required by ``Button.on_click``).
    """
    global simulator, loaded_image
    output_area.clear_output()
    out_full.clear_output(wait=True)
    if loaded_image is None:
        with output_area:
            print("Najpierw wczytaj obraz!")
        return
    try:
        angle_step = float(angle_step_widget.value)
        n_detectors = int(detectors_widget.value)
        l_span = float(lspan_widget.value)
        # `not (... > 0)` also rejects NaN.
        if not (angle_step > 0 and l_span > 0 and n_detectors >= 1):
            raise ValueError("wszystkie parametry muszą być dodatnie")
    except (TypeError, ValueError) as e:
        with output_area:
            print("Błędne parametry!", e)
        return
    sim = TomographySimulator(loaded_image, angle_step=angle_step,
                              n_detectors=n_detectors, l_span=l_span)
    with output_area:
        print("Generowanie sinogramu...")
    sinogram = sim.compute_sinogram()
    with output_area:
        print("Sinogram wygenerowany.\nStosowanie filtru rampowego...")
    sim.apply_ramp_filter()
    with output_area:
        print("Rekonstruowanie obrazu...")
    reconstruction = sim.reconstruct()
    simulator = sim  # publish only a fully computed simulator
    # The reconstruction is in arbitrary units (mean-intensity projections,
    # unscaled ramp filter), so compare both images on a common [0, 1] scale.
    diff = _normalize01(loaded_image) - _normalize01(reconstruction)
    rmse = np.sqrt(np.mean(diff ** 2))
    with output_area:
        print("RMSE:", rmse)

    with out_full:
        clear_output(wait=True)
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        axes[0].imshow(loaded_image, cmap='gray')
        axes[0].set_title("Obraz oryginalny")
        axes[0].axis('off')

        im1 = axes[1].imshow(sinogram, cmap='gray', aspect='auto')
        axes[1].set_title("Sinogram")
        axes[1].set_xlabel("Detektor")
        axes[1].set_ylabel("Kąt")
        fig.colorbar(im1, ax=axes[1])

        axes[2].imshow(reconstruction, cmap='gray')
        axes[2].set_title("Rekonstrukcja")
        axes[2].axis('off')

        plt.show()

def update_iterative(change):
    """Slider observer: redraw the partial back-projection in ``out_iter``.

    Sums the first ``slider.value`` per-angle contributions of the current
    simulator. Does nothing until a simulation has been run.

    Args:
        change: traitlets change notification (unused; may be None).
    """
    if simulator is not None:
        n_used = slider.value
        partial = simulator.iterative_reconstruction(n_used)
        with out_iter:
            clear_output(wait=True)
            fig, ax = plt.subplots(figsize=(6, 6))
            ax.imshow(partial, cmap='gray')
            ax.set_title(f"Rekonstrukcja iteracyjna (kąty 0 do {n_used})")
            ax.axis('off')
            plt.show()

def run_iterative(b):
    """Button callback: prepare the angle slider and show the first partial image.

    The slider range is set to the simulator's number of angles and reset
    to 1. If no simulation has been run yet, a hint is printed instead.

    Args:
        b: the clicked button (unused; required by ``Button.on_click``).
    """
    if simulator is not None:
        n_angles = len(simulator.angles)
        slider.max = n_angles
        slider.value = 1
        with output_area:
            clear_output()
            print("Steruj rekonstrukcją iteracyjną za pomocą suwaka poniżej.")
        update_iterative(None)
    else:
        with output_area:
            print("Najpierw uruchom pełną symulację!")

# --- Shared state read and written by the callbacks -------------------------
simulator = None      # TomographySimulator from the last run, or None
loaded_image = None   # grayscale image from the last upload, or None

# --- Output panes: status text, full-simulation figure, iterative figure ----
output_area, out_full, out_iter = widgets.Output(), widgets.Output(), widgets.Output()

# --- Input widgets -----------------------------------------------------------
uploader = widgets.FileUpload(accept=".jpg, .jpeg", multiple=False)
angle_step_widget = widgets.FloatText(value=1.0, description='Krok ∆α:', step=0.1)
detectors_widget = widgets.IntText(value=180, description='Liczba detektorów:')
lspan_widget = widgets.FloatText(value=256.0, description='Rozpiętość układu:')

simulate_button = widgets.Button(description="Uruchom symulację",
                                 button_style='success')
iterative_button = widgets.Button(description="Rekonstrukcja iteracyjna",
                                  button_style='info')
slider = widgets.IntSlider(value=1, min=1, max=180, step=1,
                           description='Kąty:', continuous_update=True)

# --- Event wiring ------------------------------------------------------------
uploader.observe(on_upload_change, names='value')
simulate_button.on_click(run_simulation)
iterative_button.on_click(run_iterative)
slider.observe(update_iterative, names='value')

# --- Layout ------------------------------------------------------------------
full_simulation_box = widgets.VBox([simulate_button, out_full])
iterative_box = widgets.VBox([iterative_button, out_iter, slider])
header = widgets.HTML("<h2>Symulator Tomografii Komputerowej</h2>")
ui = widgets.VBox([header, uploader, angle_step_widget, detectors_widget,
                   lspan_widget, output_area, full_simulation_box, iterative_box])
display(ui)


VBox(children=(HTML(value='<h2>Symulator Tomografii Komputerowej</h2>'), FileUpload(value=(), accept='.jpg, .j…