In [1]:
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from sklearn.decomposition import DictionaryLearning

# 1. 加载并转换为灰度图
target = 1
data_name = ['0618', '0854', '1066'][target - 1]
image = cv2.imread(f'../input_data/{data_name}.png')  # 加载RGB图像

# 检查图像是否正确加载
if image is None:
    raise ValueError(f"Error loading image: ../input_data/{data_name}.png")

# 转换为灰度图像
gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

# 调整图像大小以确保一致性（假设图像大小为500x250）
gray_image = cv2.resize(gray_image, (250, 500))

# 2. 使用字典学习提取特征
patch_size = 8  # 假设每个块为8x8像素
patches = []

# 将图像分成多个8x8小块
for y in range(0, gray_image.shape[0] - patch_size, patch_size):
    for x in range(0, gray_image.shape[1] - patch_size, patch_size):
        patch = gray_image[y:y + patch_size, x:x + patch_size]
        patches.append(patch.flatten())  # 将块展平为一个特征向量

# 转换为numpy数组
patches = np.array(patches)

# 将数据转换为PyTorch张量并移到GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
patches_tensor = torch.tensor(patches).float().to(device)

# 3. 使用GPU加速的字典学习（基于SVD）
def dictionary_learning_torch(X, n_components=64, n_iter=100):
    # 使用PyTorch实现字典学习（基于SVD）
    m, n = X.shape
    D = torch.randn(n_components, m, device=X.device)  # 字典D的大小应为 (n_components, m)，而不是 (n_components, n)
    D = D / D.norm(dim=1, keepdim=True)  # 正则化字典

    for _ in range(n_iter):
        # 计算系数矩阵 H
        H = torch.matmul(D, X.T)  # D^T * X
        H = H / (H.norm(dim=0, keepdim=True) + 1e-6)  # 正则化系数矩阵

        # 更新字典 D
        X_hat = torch.matmul(D.T, H)  # D * H
        D = X_hat / X_hat.norm(dim=1, keepdim=True)

    return D

# 执行字典学习（计算字典D）
D = dictionary_learning_torch(patches_tensor.T, n_components=64)

# 4. 可视化字典集合（学习到的字典原型）
dictionary = D.cpu().detach().numpy()  # 转移到CPU并转换为numpy数组

# 将字典中的每个原型reshape为图像块大小（8x8）
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
for i, ax in enumerate(axes.flat):
    ax.imshow(dictionary[i].reshape(patch_size, patch_size), cmap='gray')
    ax.axis('off')
plt.suptitle("Dictionary Components (Learned Patches)")
plt.show()

# 输出字典特征矩阵
print("Dictionary Learning Results:")
print("Components (Shape):", dictionary.shape)

# 最终输出是1922x64的特征矩阵
print("Output Feature Matrix Shape: ", patches.shape[0], "x", dictionary.shape[0])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x64 and 1922x64)