## Convolution Filter

In [None]:
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Synthetic 8x8 test image containing one vertical edge:
# columns 0-3 hold 1.0 (left half, shown white), columns 4-7 hold 0.0
# (right half, shown black).
image = np.hstack((np.ones((8, 4)), np.zeros((8, 4))))

# Vertical-edge detection kernel: every row is [-1, 0, 1], so it responds to
# intensity changes from left to right.
# NOTE: `filter` shadows the Python builtin; the name is kept because the rest
# of the notebook refers to it.
filter = np.array([[-1, 0, 1]] * 3)

# Filter response ("convolution") at a single position
def apply_filter(image, filter, pos_x, pos_y):
    """Return the filter response of ``image`` at pixel (pos_x, pos_y).

    Computes the sum of ``image[r, c] * filter[i, j]`` over the kernel window
    around the pixel. As is usual in CNN-style "convolutions", the kernel is
    not flipped (i.e. this is a cross-correlation). Pixels that fall outside
    the image contribute nothing (zero padding).

    Parameters
    ----------
    image : 2-D array
        Input image.
    filter : 2-D array-like
        Kernel of any shape. For odd sizes the window is centred on the
        pixel; the window starts ``kernel.shape // 2`` pixels before it.
    pos_x : int
        Row index of the window centre.
    pos_y : int
        Column index of the window centre.

    Returns
    -------
    The accumulated response (``0`` if the window lies entirely outside).
    """
    image = np.asarray(image)
    kernel = np.asarray(filter)
    k_rows, k_cols = kernel.shape
    n_rows, n_cols = image.shape[0], image.shape[1]
    # Offset from the centre pixel to the window's top-left corner
    # (1 for a 3x3 kernel).
    off_r, off_c = k_rows // 2, k_cols // 2

    result = 0
    for i in range(k_rows):
        r = pos_x + i - off_r
        if not 0 <= r < n_rows:
            continue  # zero padding: rows outside the image add nothing
        for j in range(k_cols):
            c = pos_y + j - off_c
            if 0 <= c < n_cols:
                result += image[r, c] * kernel[i, j]
    return result

# Compute the complete output image. Only interior pixels (where the 3x3
# kernel fits entirely inside the image) are evaluated; the one-pixel border
# of `result_image` is left at 0.
result_image = np.zeros_like(image)
n_rows, n_cols = image.shape
for row in range(1, n_rows - 1):
    for col in range(1, n_cols - 1):
        result_image[row, col] = apply_filter(image, filter, row, col)

# Visualization
def plot_convolution(image, filter, pos_x, pos_y):
    """Show input image, kernel and filter response side by side.

    The response panel is computed here from ``image`` and ``filter`` (via
    ``apply_filter``), so all three panels always belong to the same inputs.
    Only pixels where the whole kernel fits inside the image are evaluated;
    the border of the response stays 0.

    Parameters
    ----------
    image : 2-D array
        Input image (left panel).
    filter : 2-D array
        Kernel (middle panel).
    pos_x, pos_y : int or None
        Row (``pos_x``) and column (``pos_y``) of the window centre to
        highlight. If either is None, no window is drawn.
    """
    image = np.asarray(image)
    kernel = np.asarray(filter)
    half_r, half_c = kernel.shape[0] // 2, kernel.shape[1] // 2

    # Filter response for these exact inputs (not a module-level global).
    result = np.zeros(image.shape, dtype=float)
    for r in range(half_r, image.shape[0] - half_r):
        for c in range(half_c, image.shape[1] - half_c):
            result[r, c] = apply_filter(image, kernel, r, c)

    has_position = pos_x is not None and pos_y is not None

    fig = make_subplots(
        rows=1, cols=3,
        subplot_titles=('Eingabebild', 'Filter', 'Ergebnisbild'),
        column_widths=[0.33, 0.33, 0.33]
    )

    # Input image
    fig.add_trace(
        go.Heatmap(
            z=image,
            colorscale='gray',
            showscale=False,
            text=np.round(image, 2),
            texttemplate='%{text}',
            textfont={"size": 20},
        ),
        row=1, col=1
    )

    # Outline of the kernel window on the input. Heatmap cells are centred on
    # integer coordinates, so a window of half-size h spans centre ± (h + 0.5).
    if has_position:
        x0, x1 = pos_y - half_c - 0.5, pos_y + half_c + 0.5
        y0, y1 = pos_x - half_r - 0.5, pos_x + half_r + 0.5
        fig.add_trace(
            go.Scatter(
                x=[x0, x1, x1, x0, x0],
                y=[y0, y0, y1, y1, y0],
                mode='lines',
                line=dict(color='red', width=2),
                showlegend=False
            ),
            row=1, col=1
        )

    # Kernel; zmid=0 keeps the diverging colour scale centred on zero.
    fig.add_trace(
        go.Heatmap(
            z=kernel,
            colorscale='RdBu',
            zmid=0,
            text=kernel,
            texttemplate='%{text}',
            textfont={"size": 20},
            showscale=False
        ),
        row=1, col=2
    )

    # Filter response; zmid=0 so that 0 maps to the neutral middle colour.
    fig.add_trace(
        go.Heatmap(
            z=result,
            colorscale='RdBu',
            zmid=0,
            text=np.round(result, 2),
            texttemplate='%{text}',
            textfont={"size": 20},
        ),
        row=1, col=3
    )

    # Mark the output pixel produced by the highlighted window.
    if has_position:
        fig.add_trace(
            go.Scatter(
                x=[pos_y - 0.5, pos_y + 0.5, pos_y + 0.5, pos_y - 0.5, pos_y - 0.5],
                y=[pos_x - 0.5, pos_x - 0.5, pos_x + 0.5, pos_x + 0.5, pos_x - 0.5],
                mode='lines',
                line=dict(color='red', width=2),
                showlegend=False
            ),
            row=1, col=3
        )

    fig.update_layout(
        height=400,
        width=1200,
        title_text=f"Convolution Demonstration (Position: {pos_x}, {pos_y})",
        showlegend=False
    )

    # Hide grid, zero lines and tick labels on all three panels.
    for col in (1, 2, 3):
        fig.update_xaxes(showgrid=False, zeroline=False, showticklabels=False, row=1, col=col)
        fig.update_yaxes(showgrid=False, zeroline=False, showticklabels=False, row=1, col=col)

    fig.show()

# Show several window positions, given as (row, column) of the window centre.
# The edge lies between columns 3 and 4:
#   column 2 -> window entirely left of the edge (response 0)
#   column 3 -> window straddles the edge (response -3)
#   column 4 -> window still straddles the edge (response -3)
#   column 5 -> window entirely right of the edge (response 0)
positions = [(4, 2), (4, 3), (4, 4), (4, 5)]
for pos_x, pos_y in positions:
    plot_convolution(image, filter, pos_x, pos_y)

# Additionally: animate the filter window sliding over the image.
# (A separate frame list was previously built here and then discarded; the
# animation frames are built once, after the base figure below.)
fig = go.Figure()

# Base data shown before the animation starts; it matches the first frame
# (window centred on row 1, column 1).
fig.add_trace(
    go.Heatmap(
        z=image,
        colorscale='gray',
        showscale=False,
        name='image'
    )
)

# Red outline of the 3x3 window centred on (1, 1). Heatmap cells are centred
# on integer coordinates, so the window edges lie at 1 ± 1.5, i.e. -0.5 and
# 2.5. (Previously this was a 2x2 box at 1.5..3.5 that matched no frame.)
fig.add_trace(
    go.Scatter(
        x=[-0.5, 2.5, 2.5, -0.5, -0.5],
        y=[-0.5, -0.5, 2.5, 2.5, -0.5],
        mode='lines',
        line=dict(color='red', width=2),
        name='filter'
    )
)

# One animation frame per interior window centre (rows and columns 1..6,
# row-major order). Frame data is matched to the figure's traces by position,
# so every frame re-sends the unchanged image as item 0 and the moved red
# window outline as item 1.
frames = [
    go.Frame(
        data=[
            go.Heatmap(z=image, colorscale='gray', showscale=False, name='image'),
            go.Scatter(
                x=[col - 1.5, col + 1.5, col + 1.5, col - 1.5, col - 1.5],
                y=[row - 1.5, row - 1.5, row + 1.5, row + 1.5, row - 1.5],
                mode='lines',
                line=dict(color='red', width=2),
                name='filter',
            ),
        ],
        name=f'frame_{row}_{col}',
    )
    for row in range(1, 7)
    for col in range(1, 7)
]

fig.frames = frames

# Layout: title, fixed size, and Play/Pause controls below the plot.
fig.update_layout(
    title="Filterverschiebung über das Bild",
    width=500,
    height=500,
    updatemenus=[{
        'type': 'buttons',
        'direction': 'left',
        'showactive': False,
        'y': 0,
        'x': 0,
        'xanchor': 'left',
        'yanchor': 'top',
        'buttons': [
            {
                'label': 'Play',
                'method': 'animate',
                'args': [None, {
                    'frame': {'duration': 300, 'redraw': True},
                    'fromcurrent': True,
                    'mode': 'immediate',
                }],
            },
            {
                # Animating to [None] (a list containing None) with zero
                # duration stops a running animation.
                'label': 'Pause',
                'method': 'animate',
                'args': [[None], {
                    'frame': {'duration': 0, 'redraw': False},
                    'mode': 'immediate',
                    'transition': {'duration': 0},
                }],
            },
        ],
    }]
)

# Hide grid/ticks; fix the view to the 8x8 image (cell centres 0..7).
fig.update_xaxes(showgrid=False, zeroline=False, showticklabels=False, range=[-0.5, 7.5])
fig.update_yaxes(showgrid=False, zeroline=False, showticklabels=False, range=[-0.5, 7.5])

fig.show()

: 