# In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import os
import argparse
from net.VIT.mae import VisionTransfromers as MAEFinetune
from get_dat import get_dataset
from scipy.io import loadmat
import h5py
def visualize_predictions(args, model, all_loader,
                          hsi_path="data/tlse/processed_hsi.h5",
                          index_path="data/tlse/tlse_index.mat",
                          index_key="tlse_all",
                          device=None):
    """Predict a class for every sample in ``all_loader`` and plot the label map.

    Args:
        args: Run configuration. Only ``args.device`` is read, and only as a
            fallback when no ``device`` is given and the model has no
            parameters to infer one from.
        model: Network called as ``model(hsi, lidar, hsi_pca)``. The first
            element of its return value is argmax-ed over dim 1 to get classes.
        all_loader: Iterable yielding ``(hsi, lidar, labels, hsi_pca)``
            batches. Predictions are matched to the rows of the index file
            in iteration order; ``labels`` is ignored.
        hsi_path: HDF5 file whose ``'hyperspectral_matrix'`` dataset supplies
            the (height, width, bands) shape of the output image.
        index_path: ``.mat`` file holding the pixel coordinates.
        index_key: Variable in ``index_path`` with one ``(row, col)`` pair per
            predicted sample.
        device: Device for inference. Defaults to the device of the model's
            first parameter.

    Returns:
        np.ndarray of shape (height, width) and dtype uint8. Indexed pixels
        hold ``predicted class + 1``; all other pixels are 0.

    Raises:
        ValueError: If the loader produced fewer predictions than there are
            indexed pixels.
    """
    if device is None:
        # Follow the model's own placement instead of relying on a global.
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device(args.device)

    model.eval()
    chunks = []
    with torch.no_grad():
        for hsi, lidar, _labels, hsi_pca in all_loader:
            outputs, _ = model(hsi.to(device), lidar.to(device),
                               hsi_pca.to(device))
            chunks.append(torch.argmax(outputs, dim=1).cpu().numpy())
    predictions = (np.concatenate(chunks) if chunks
                   else np.empty(0, dtype=np.int64))

    # Only the shape is needed: reading ``.shape`` from the h5py dataset
    # avoids loading the whole hyperspectral cube into memory.
    with h5py.File(hsi_path, 'r') as h5_file:
        height, width, _ = h5_file['hyperspectral_matrix'].shape

    all_index = np.asarray(loadmat(index_path)[index_key])
    n_pixels = len(all_index)
    if len(predictions) < n_pixels:
        raise ValueError(
            f"got {len(predictions)} predictions for {n_pixels} indexed pixels"
        )

    pred_image = np.zeros((height, width), dtype=np.uint8)
    # Classes are 0-based; +1 keeps 0 free for pixels without a prediction.
    pred_image[all_index[:, 0], all_index[:, 1]] = predictions[:n_pixels] + 1

    plt.figure(figsize=(12, 6))
    # Left panel of a 1x2 layout; the right panel stays empty because
    # ground-truth plotting is not implemented here.
    plt.subplot(1, 2, 1)
    plt.imshow(pred_image, cmap='jet')
    plt.title("Predicted Labels")
    plt.colorbar()
    plt.tight_layout()
    plt.show()

    return pred_image

# Script entry point: build the model, load the best checkpoint and visualize
# predictions over the full dataset.
if __name__ == '__main__':
    # Hard-coded configuration for the checkpoint being visualized. It is
    # passed to both the model constructor and get_dataset.
    args = argparse.Namespace(
        is_train=0,
        is_load_pretrain=0,
        is_pretrain=1,
        is_test=0,
        model_file='model',
        size_SA=49,
        channel_number=291,
        epoch=500,
        pca_num=30,
        mask_ratio=0.7,
        crop_size=7,
        device="cuda:0",
        dataset='Tlse',
        num_classes=13,
        pretrain_num=400000,
        patch_size=1,
        finetune=0,
        mae_pretrain=1,
        depth=2,
        head=8,
        dim=256,
        model_name=None,
        warmup_epochs=5,
        test_interval=5,
        optimizer_name="adamw",
        lr=1e-4,
        cosine=0,
        weight_decay=5e-2,
        batch_size=256,
    )

    model = MAEFinetune(
        channel_number=args.channel_number,
        img_size=args.crop_size,
        patch_size=args.patch_size,
        embed_dim=args.dim,
        depth=args.depth,
        num_heads=args.head,
        num_classes=args.num_classes,
        args=args,
    )

    # Honour args.device, but fall back to CPU so the script can still run
    # without a GPU. Kept as a module-level name because visualize_predictions
    # may read ``device`` as a global.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    save_dir = os.path.join('model', 'train', '20250120_235027')
    model_path = os.path.join(save_dir, 'best_model.pth')
    # Load onto the CPU first; the weights are copied into the model's
    # parameters and then moved to the target device once.
    checkpoint = torch.load(model_path, map_location="cpu")
    model.load_state_dict(checkpoint['state_dict'])
    model = model.to(device)

    # get_dataset returns five loaders; only the one covering all points is
    # needed here.
    _, _, _, _, all_loader = get_dataset(args)

    visualize_predictions(args, model, all_loader)