In [4]:
import torch
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np
from torch.utils.data import DataLoader
# from dataset.dataloader import CINE2DT
# from model.model_pytorch import CRNN_MRI
# from utils import multicoil2single, compressed_sensing as cs,IFFT2c
# from ..utils import multicoil2single, compressed_sensing as cs
# from utils.dnn_io import to_tensor_format, from_tensor_format
# from utils.fastmriBaseUtils import FFT2c,IFFT2c
# from trainer_dcrnn_test import prep_input
from torch.autograd import Variable
import os 
from PIL import Image
# import scipy.io as sio
from scipy.io import loadmat
import scipy.io as scio
import h5py

os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # 指定使用 GPU 1 和 GPU 4
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
cuda = True if torch.cuda.is_available() else False
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

In [6]:
import medutils
print(dir(medutils))  # 查看所有可用子模块

['__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__']


In [13]:
import yaml
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from matplotlib import gridspec

def center_crop(image, target_shape):
    """自定义中心裁剪函数"""
    if image.ndim == 2:
        h, w = image.shape
        target_h, target_w = target_shape
        start_h = (h - target_h) // 2
        start_w = (w - target_w) // 2
        return image[start_h:start_h+target_h, start_w:start_w+target_w]
    elif image.ndim == 3:
        t, h, w = image.shape
        target_h, target_w = target_shape
        start_h = (h - target_h) // 2
        start_w = (w - target_w) // 2
        return image[:, start_h:start_h+target_h, start_w:start_w+target_w]
    else:
        raise ValueError("Unsupported image dimensions")

def load_4d_data(file_path, config):
    """处理四维数据 (slices=118, time=18, H=192, W=192)"""
    data = np.load(file_path)
    
    # 验证数据维度
    if data.shape != (118, 18, 192, 192):
        raise ValueError(f"数据维度错误: 应为 (118,18,192,192)，实际得到 {data.shape}")
    
    # 提取指定切片和时间点
    slice_idx = np.clip(config['slice_index'], 0, 117)
    time_idx = np.clip(config['time_index'], 0, 17)
    
    # 提取并归一化图像
    image = np.abs(data[slice_idx, time_idx])
    return (image - image.min()) / (image.max() - image.min())

# import yaml
# import numpy as np
# import matplotlib.pyplot as plt
# from medutils.visualization import center_crop
# from scipy.interpolate import interp1d
# from matplotlib import gridspec

# def load_4d_data(file_path, config):
#     """处理正确维度的4D数据 (slices=118, time=18, H=192, W=192)"""
#     data = np.load(file_path)
    
#     # 验证数据维度
#     if data.shape != (118, 18, 192, 192):
#         raise ValueError(f"Invalid data shape: {data.shape}, expected (118, 18, 192, 192)")
    
#     # 提取指定切片和时间点
#     slice_idx = np.clip(config['slice_index'], 0, 117)
#     time_idx = np.clip(config['time_index'], 0, 17)
    
#     # 提取二维图像并归一化
#     image = np.abs(data[slice_idx, time_idx])
#     return (image - image.min()) / (image.max() - image.min())
def load_config(config_file):
    try:
        with open(config_file, 'r') as file:
            return yaml.safe_load(file)
    except FileNotFoundError:
        raise Exception(f"配置文件 {config_file} 未找到")
    except yaml.YAMLError as exc:
        raise Exception(f"配置文件解析错误: {str(exc)}")
def visualize_results(config_path):
    # 加载配置
    config = load_config(config_path)['my_config']
    # config = load_config(config_path)
    
    # 加载参考数据
    ref_image = load_4d_data(config['reference_file'], config) if config.get('reference_file') else None
    
    # 加载对比数据
    comp_images = []
    for fpath in config['comparison_files']:
        try:
            img = load_4d_data(fpath, config)
            comp_images.append(img)
        except Exception as e:
            print(f"Error loading {fpath}: {str(e)}")
            comp_images.append(np.zeros((192,192)))  # 用空白图占位
    
    # 统一图像尺寸
    base_shape = (192, 192)
    if ref_image is not None:
        ref_image = center_crop(ref_image, base_shape)
    comp_images = [center_crop(img, base_shape) for img in comp_images]
    
    # 创建可视化布局
    num_cols = len(comp_images) + (1 if ref_image else 0)
    fig = plt.figure(figsize=(num_cols*3, 6))  # 动态调整宽度
    
    # 主图布局
    gs_main = gridspec.GridSpec(2, num_cols, height_ratios=[3,1])
    
    # 绘制主图
    def plot_slice(ax, img, title):
        im = ax.imshow(img, cmap='gray', 
                      vmin=config['display_range'][0], 
                      vmax=config['display_range'][1])
        ax.set_title(title, fontsize=10)
        ax.axis('off')
        return im
    
    # 绘制参考图
    if ref_image is not None:
        ax_ref = plt.subplot(gs_main[0, 0])
        im_ref = plot_slice(ax_ref, ref_image, "Reference")
    
    # 绘制对比图
    for idx, (img, title) in enumerate(zip(comp_images, config['comparison_titles'])):
        col = idx + (1 if ref_image else 0)
        ax = plt.subplot(gs_main[0, col])
        plot_slice(ax, img, title)
    
    # 绘制ROI区域
    roi = config['roi_coordinates']
    ax_roi = plt.subplot(gs_main[1, :])
    ax_roi.imshow(ref_image[roi[1]:roi[3], roi[0]:roi[2]], cmap='gray')
    ax_roi.set_title("ROI Zoom", fontsize=8)
    ax_roi.axis('off')
    
    # 添加颜色条
    cax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    plt.colorbar(im_ref, cax=cax)
    
    # 保存结果
    plt.savefig(f'result_s{config["slice_index"]}_t{config["time_index"]}.png', 
                dpi=300, bbox_inches='tight')
    plt.close()



In [14]:
if __name__ == "__main__":
    visualize_results('figconfig_zzy.yml')

Error loading /data0/zhiyong/code/github/itzzy_git/k_gin_base/output/ls/ls_acc_8_merge.npy: 数据维度错误: 应为 (118,18,192,192)，实际得到 (192, 192, 18, 118)


ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()