In [None]:
import torch
import pywt


class MultimodalWaveletAligner:
    """Align and fuse image/text embeddings in the wavelet domain."""

    def multimodal_multiscale_wavelet_align(self, image_embedding, text_embedding,
                                            wavelet='db4', level=3):
        '''
        Multimodal, multi-scale wavelet alignment of two embedding matrices.

        Each embedding is decomposed along its feature axis (axis 1) with a
        multi-level discrete wavelet transform. The coefficients of the two
        modalities are averaged level by level, and all three coefficient
        sets (image, text, fused) are transformed back to the embedding
        space.

        Args:
            image_embedding: 2-D torch tensor of image embeddings, one row per
                item (e.g. torch.Size([7050, 64])).
            text_embedding: 2-D torch tensor of text embeddings with the same
                shape as ``image_embedding``.
            wavelet: Wavelet name passed to PyWavelets. Defaults to 'db4'.
            level: Number of decomposition levels. Defaults to 3.

        Returns:
            Tuple ``(image_embedding_wave, text_embedding_wave, fusion_wave)``
            of float32 tensors. The first two have the shape of their
            respective inputs; ``fusion_wave`` has the shape of
            ``image_embedding`` and lives on its device.
        '''
        # PyWavelets works on numpy arrays; detach first so tensors that
        # require grad can be converted as well.
        image_np = image_embedding.detach().cpu().numpy()
        text_np = text_embedding.detach().cpu().numpy()

        # Multi-level decomposition along the feature axis. The result is a
        # list [cA_n, cD_n, ..., cD_1] of arrays, not a single array.
        image_coeffs = pywt.wavedec(image_np, wavelet, level=level, axis=1)
        text_coeffs = pywt.wavedec(text_np, wavelet, level=level, axis=1)

        # A list has no ``.shape``; report the shape of every level instead.
        print("image_coeffs:", [coeff.shape for coeff in image_coeffs])

        # Simple average fusion of the coefficients at every level.
        fused_coeffs = [
            (image_coeff + text_coeff) / 2
            for image_coeff, text_coeff in zip(image_coeffs, text_coeffs)
        ]

        # Inverse transform. For an odd-length axis the reconstruction comes
        # back one sample longer than the input, so trim to the original
        # width to keep the documented output shapes.
        image_dim = image_np.shape[1]
        text_dim = text_np.shape[1]
        image_embedding_wave_np = pywt.waverec(image_coeffs, wavelet, axis=1)[:, :image_dim]
        text_embedding_wave_np = pywt.waverec(text_coeffs, wavelet, axis=1)[:, :text_dim]
        fusion_wave_np = pywt.waverec(fused_coeffs, wavelet, axis=1)[:, :image_dim]

        # Convert back to torch tensors on the devices of the inputs.
        image_embedding_wave = torch.tensor(image_embedding_wave_np, dtype=torch.float32, device=image_embedding.device)
        text_embedding_wave = torch.tensor(text_embedding_wave_np, dtype=torch.float32, device=text_embedding.device)
        fusion_wave = torch.tensor(fusion_wave_np, dtype=torch.float32, device=image_embedding.device)

        return image_embedding_wave, text_embedding_wave, fusion_wave


# Example data: two random embedding matrices of identical shape.
image_embedding = torch.randn(7050, 64)
text_embedding = torch.randn(7050, 64)

# Run the alignment once and report the shape of each returned tensor.
aligner = MultimodalWaveletAligner()
(
    image_embedding_wave,
    text_embedding_wave,
    fusion_wave,
) = aligner.multimodal_multiscale_wavelet_align(image_embedding, text_embedding)
print(
    image_embedding_wave.shape,
    text_embedding_wave.shape,
    fusion_wave.shape,
    sep="\n",
)
    

AttributeError: 'list' object has no attribute 'shape'