### Filter Demonstration

Import useful packages

In [9]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
from ipywidgets import IntSlider, FloatSlider, Dropdown, FileUpload, VBox, HBox, Button, Output
import io
import warnings
warnings.filterwarnings('ignore')

In [None]:
class ImageFilterDemo:
    """Interactive notebook demo that applies a square convolution kernel to an image.

    The demo keeps the loaded image as a channels-first float tensor, filters
    every channel with the same kernel via ``torch.nn.functional.conv2d`` and
    renders the original and the filtered image side by side inside an
    ``ipywidgets.Output`` widget.
    """

    def __init__(self):
        # Currently displayed (possibly filtered) image tensor, or None.
        self.image = None
        # Unfiltered image tensor as loaded; every filter starts from this one.
        self.original_image = None
        self.filter_size = 3
        self.custom_filter = None
        # Widget that captures the matplotlib figure drawn by display_images().
        self.output = Output()

        # Preset 3x3 kernels, keyed by the name shown in the dropdown.
        self.preset_filters = self._default_presets()

    @staticmethod
    def _default_presets():
        """Return the built-in 3x3 kernels as nested lists."""
        return {
            # A pass-through kernel has a single 1 in the centre. (An identity
            # *matrix* would smear the image along the diagonal with gain 3.)
            "Identity": [[0, 0, 0], [0, 1, 0], [0, 0, 0]],
            # Smoothing kernels are normalised to sum to 1 so that the overall
            # brightness is preserved instead of saturating to white.
            "Blur": (np.ones((3, 3)) / 9.0).tolist(),
            "Sharpen": [[0, -1, 0], [-1, 5, -1], [0, -1, 0]],
            "Edge Detection": [[1, 2, 1], [0, 0, 0], [-1, -2, -1]],
            "Gaussian Blur": (
                np.array([[1, 2, 1], [2, 4, 2], [1, 2, 1]]) / 16.0
            ).tolist(),
            "Emboss": [[-2, -1, 0], [-1, 1, 1], [0, 1, 2]],
            # NOTE: the key spelling is kept as-is so existing lookups keep working.
            "Horizonal Gradient": [[0, 0, 0], [-1, 0, 1], [0, 0, 0]],
            "Vertical Gradient": [[0, 1, 0], [0, 0, 0], [0, -1, 0]],
        }

    @staticmethod
    def _identity_kernel(size):
        """Return a ``size`` x ``size`` array that is all zeros except a centre 1."""
        kernel = np.zeros((size, size))
        kernel[size // 2, size // 2] = 1.0
        return kernel

    @staticmethod
    def _resize_kernel(kernel, size):
        """Resample a float kernel to ``size`` rows, keeping its total weight."""
        from scipy.ndimage import zoom

        resized = zoom(kernel, size / kernel.shape[0], order=1)
        target_sum = kernel.sum()
        actual_sum = resized.sum()
        # Rescale so that the resized kernel has the same gain as the preset.
        # Zero-sum kernels (edge/gradient filters) are left untouched.
        if abs(target_sum) > 1e-12 and abs(actual_sum) > 1e-12:
            resized = resized * (target_sum / actual_sum)
        return resized

    def _set_image(self, image):
        """Store a PIL image as the original/current tensor and redraw."""
        self.original_image = transforms.ToTensor()(image)
        self.image = self.original_image.clone()
        self.display_images()

    def load_image(self, upload_data):
        """Load the first uploaded file from a ``FileUpload`` widget value.

        Args:
            upload_data: Either a dict mapping file names to dicts with a
                ``'content'`` entry, or a tuple/list of dicts with a
                ``'content'`` entry. Empty or ``None`` input is ignored.
        """
        if not upload_data:
            return

        if isinstance(upload_data, dict):
            first_file = next(iter(upload_data.values()))
        else:
            first_file = upload_data[0]

        # bytes() also accepts a memoryview, so both payload types work.
        content = bytes(first_file['content'])
        image = Image.open(io.BytesIO(content)).convert('RGB')
        self._set_image(image)

    def load_image_from_path(self, image_path):
        """Load an image from ``image_path``; print a message if that fails."""
        try:
            image = Image.open(image_path).convert('RGB')
        except FileNotFoundError:
            print(f"Error: Image not found at {image_path}")
            return
        except (OSError, ValueError) as e:
            print(f"Error loading image: {e}")
            return
        self._set_image(image)

    def create_filter_matrix(self, filter_type, size=3, custom_values=None):
        """Build the flattened kernel for ``filter_type``.

        Args:
            filter_type: A key of ``self.preset_filters`` or ``"Custom"``.
            size: Side length of the square kernel.
            custom_values: Flat sequence of ``size * size`` weights, used only
                when ``filter_type`` is ``"Custom"``.

        Returns:
            A flat list of ``size * size`` kernel weights in row-major order.
            A pass-through kernel is returned when the filter name is unknown,
            when the custom values do not match ``size``, or when a preset
            cannot be resized to ``size``.
        """
        identity = self._identity_kernel(size).flatten().tolist()

        if filter_type == "Custom" and custom_values is not None:
            if len(custom_values) != size * size:
                return identity
            return list(custom_values)

        preset_values = self.preset_filters.get(filter_type)
        if preset_values is None:
            return identity

        # Work in floats: resampling an integer array would truncate weights.
        preset = np.array(preset_values, dtype=float)
        if preset.shape[0] != size:
            if filter_type == "Identity":
                # Interpolating a single-pixel kernel would turn it into a blur.
                return identity
            preset = self._resize_kernel(preset, size)
            if preset.shape != (size, size):
                return identity
        return preset.flatten().tolist()

    def _convolve(self, filter_values, size):
        """Filter every channel of ``self.original_image`` with one kernel.

        Returns a tensor with the same channel count as the original image.
        """
        kernel = torch.tensor(filter_values, dtype=torch.float32).view(1, 1, size, size)
        # Treat each channel as its own single-channel batch item so that the
        # same kernel is applied to all channels in one call.
        planes = self.original_image.unsqueeze(1)
        return F.conv2d(planes, kernel, padding=size // 2).squeeze(1)

    def apply_filter(self, filter_type, size, custom_values=None):
        """Apply the selected filter to the original image and redraw.

        Prints a hint and does nothing when no image has been loaded yet.
        """
        if self.image is None:
            print("请先上传图像")
            return

        filter_values = self.create_filter_matrix(filter_type, size, custom_values)
        self.image = self._convolve(filter_values, size)
        self.display_images()

    def display_images(self):
        """Draw the original and the processed image side by side."""
        if self.original_image is None or self.image is None:
            return

        with self.output:
            self.output.clear_output(wait=True)

            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

            # Tensors are channels-first; imshow expects channels-last.
            original = self.original_image.permute(1, 2, 0).numpy()
            ax1.imshow(original)
            ax1.set_title('Original Image')
            ax1.axis('off')

            processed = self.image.permute(1, 2, 0).numpy()
            # Filtering can push values outside the displayable 0-1 range.
            processed = np.clip(processed, 0, 1)
            ax2.imshow(processed)
            ax2.set_title('Filtered Image')
            ax2.axis('off')

            plt.tight_layout()
            plt.show()

    def create_ui(self):
        """Build and return the widget tree that drives the demo."""
        load_button = Button(description="Load Image")

        def on_load_button_clicked(b):
            # Imported here because the notebook's import cell does not import os.
            import os

            # Replace with the path of your own image folder.
            image_folder = "./images"
            try:
                entries = sorted(os.listdir(image_folder))
            except OSError as e:
                with self.output:
                    print(f"Error reading image folder {image_folder}: {e}")
                return

            image_files = [
                f for f in entries
                if f.lower().endswith(('.png', '.jpg', '.jpeg'))
            ]
            if image_files:
                # Load the first image (in sorted order, so the pick is stable).
                self.load_image_from_path(os.path.join(image_folder, image_files[0]))
            else:
                with self.output:
                    print("No images found in the specified folder.")

        load_button.on_click(on_load_button_clicked)

        # Filter selection.
        filter_selector = Dropdown(
            options=list(self.preset_filters.keys()) + ["Custom"],
            value="Identity",
            description="Filter Type:"
        )

        # Kernel size; step=2 starting at 1 keeps the size odd.
        size_slider = IntSlider(
            value=3,
            min=1,
            max=15,
            step=2,
            description="Filter Size:"
        )

        # Custom kernel editor: pick one of nine cells, then set its weight.
        custom_filter_label = Dropdown(
            options=[f"Value {i+1}" for i in range(9)],
            value="Value 1",
            description="Custom Value:"
        )

        custom_filter_value = FloatSlider(
            value=1.0,
            min=-5.0,
            max=5.0,
            step=0.1,
            description="Value:"
        )

        # Weights of the custom 3x3 kernel; starts as a pass-through kernel.
        custom_values = {f"Value {i+1}": 1.0 if i == 4 else 0.0 for i in range(9)}

        def update_custom_value(change):
            custom_values[custom_filter_label.value] = change.new

        custom_filter_value.observe(update_custom_value, names='value')

        def update_custom_slider(change):
            # Show the stored weight of the newly selected cell.
            custom_filter_value.value = custom_values[change.new]

        custom_filter_label.observe(update_custom_slider, names='value')

        def on_apply_clicked(_):
            size = size_slider.value
            # Only nine custom weights exist; for sizes above 3 the length will
            # not match and create_filter_matrix falls back to pass-through.
            self.apply_filter(
                filter_selector.value,
                size,
                list(custom_values.values())[:size ** 2]
            )

        apply_button = Button(description="Apply Filter")
        apply_button.on_click(on_apply_clicked)

        reset_button = Button(description="Reset Image")
        reset_button.on_click(lambda _: self.reset_image())

        # Overall layout.
        ui = VBox([
            load_button,
            HBox([filter_selector, size_slider]),
            HBox([custom_filter_label, custom_filter_value]),
            HBox([apply_button, reset_button]),
            self.output
        ])

        return ui

    def reset_image(self):
        """Restore the displayed image to the unfiltered original."""
        if self.original_image is not None:
            self.image = self.original_image.clone()
            self.display_images()

# Create the demo and build its widget UI. The bare `ui` expression on the last
# line of the cell is what makes the notebook render the widget tree.
demo = ImageFilterDemo()
ui = demo.create_ui()
ui

VBox(children=(Button(description='Load Image', style=ButtonStyle()), HBox(children=(Dropdown(description='Fil…