# SVD 压缩图像

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Use a Chinese-capable font so the CJK titles render, and draw minus signs
# with an ASCII "-" because such fonts commonly lack the U+2212 glyph.
# plt.tight_layout() is intentionally not called here: before any figure
# exists it only creates an empty extra figure and cannot affect later plots.
# Call fig.tight_layout() on a figure after its axes are populated instead.
plt.rcParams["font.sans-serif"] = "SimHei"
plt.rcParams["axes.unicode_minus"] = False

读取图像，分别提取出 RGB 三个通道的数据，并对它们进行 SVD 分解。

取 0.3、0.5、0.7、0.9、0.99 五个累计贡献率阈值，观察需要多少个奇异值共同贡献，累计贡献率才能达到相应阈值。

In [None]:
image = plt.imread("butterfly.bmp")

# Split the image into its three colour planes along the last axis:
# index 0 -> red, 1 -> green, 2 -> blue.
# NOTE(review): assumes imread yields an (H, W, C) array with at least three
# channels — TODO confirm if the input file changes.
red_component, green_component, blue_component = (
    image[:, :, channel_index] for channel_index in range(3)
)

# Per-channel plotting settings, keyed by channel name:
#   title - subplot title, color - line/marker colour,
#   data  - 2-D channel array, axis - index of the subplot in `axes` below.
configs = {
    name: {"title": title, "color": name, "data": plane, "axis": slot}
    for slot, (name, title, plane) in enumerate(
        (
            ("red", "红色通道", red_component),
            ("green", "绿色通道", green_component),
            ("blue", "蓝色通道", blue_component),
        )
    )
}

# Cumulative-contribution thresholds marked on each curve.
precision_list = (0.3, 0.5, 0.7, 0.9, 0.99)

# One row of three panels, one per colour channel.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
fig.suptitle("主成分数-累计贡献率")


def plot_accumulative_contribution(color):
    config = configs[color]
    title = config["title"]
    color = config["color"]
    data = config["data"]
    ax = axes[config["axis"]]

    _, s, _ = np.linalg.svd(data)
    accumulative_contribution = np.cumsum(s) / np.sum(s)

    ax.plot(accumulative_contribution, color=color)
    ax.set_title(title)

    for precision in precision_list:
        component_num = len(
            accumulative_contribution[accumulative_contribution < precision]
        )
        ax.plot(component_num, precision, color=color, marker="o")
        ax.text(component_num + 5, precision - 0.04, f"({component_num},{precision})")


# Draw the curve for every configured channel (red, green, blue in
# declaration order), then tighten the spacing now that all panels exist.
for channel in configs:
    plot_accumulative_contribution(channel)
fig.tight_layout()
plt.show()

从上图可以看到，随着保留的奇异值个数（即秩）增加，累计贡献率起初快速上升、随后逐渐趋缓，保留 170 个左右的奇异值时达到 99%。

为保证压缩质量，我们以蓝色通道的结果为准，选取 4、13、34、95、176 这五个秩，分别观察对应的压缩效果。

In [None]:
def compress_component(component, k):
    """Return the rank-``k`` truncated-SVD approximation of a 2-D array.

    Keeps the ``k`` largest singular values and their singular vectors. If
    ``k`` exceeds ``min(M, N)``, all of them are kept, which reproduces the
    input up to rounding.

    Args:
        component: 2-D array (e.g. one colour channel); integer input is
            promoted to float by ``np.linalg.svd``.
        k: number of singular values to keep; must be non-negative.

    Returns:
        Float array with the same shape as ``component``. Values are not
        clipped, so callers converting to an integer pixel type should clip.

    Raises:
        ValueError: if ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    # The reduced SVD suffices: at most min(M, N) singular vectors can be kept.
    # It also keeps the slices below shape-compatible when k > min(M, N).
    u, s, vh = np.linalg.svd(component, full_matrices=False)
    # Scaling U's columns by s via broadcasting equals U @ diag(s) without
    # materialising the diagonal matrix.
    return (u[:, :k] * s[:k]) @ vh[:k]


fig, axes = plt.subplots(2, 3, figsize=(15, 5))
fig.suptitle("压缩效果对比图")

rank_list = (4, 13, 34, 95, 176)

# Panels are filled row-major: the first len(rank_list) show rank-k
# reconstructions and the bottom-right one shows the original image. Each
# rank is labelled with the threshold at the same position in precision_list.
for index, rank in enumerate(rank_list):
    reconstructed = np.dstack(
        (
            compress_component(red_component, rank),
            compress_component(green_component, rank),
            compress_component(blue_component, rank),
        )
    )
    # A truncated-SVD reconstruction is a float array that can fall slightly
    # outside [0, 255]. Casting out-of-range floats to uint8 is undefined and
    # in practice usually wraps around, painting bright specks into dark areas
    # and vice versa, so clip to the 8-bit range first.
    # NOTE(review): assumes `image` holds 0-255 values, as the uint8 cast
    # implies — revisit if a float image is ever loaded.
    compressed = np.clip(reconstructed, 0, 255).astype(np.uint8)
    ax = axes[index // 3, index % 3]
    ax.imshow(compressed)
    ax.set_title(f"秩={rank}，精度={precision_list[index]}")

axes[1, 2].imshow(image)
axes[1, 2].set_title("原图")

# Images need no tick marks on either axis.
for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])

fig.tight_layout()
plt.show()