In [None]:
import torch
from medcam import medcam
from monai import transforms
import numpy as np
import nibabel as nib
from model import VoxResNet, VGG3D

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VGG3D().to(device)
# Alternative backbone:
# model = VoxResNet().to(device)

# Optional training checkpoint to restore. It is expected to be a dict with the keys
# 'model' (state dict), 'accuracy', 'loss' and 'epoch', which are read below.
# Leave as None to run Grad-CAM on the freshly initialised (untrained) model.
CHECKPOINT_PATH = None  # e.g. r"C:\Users\17993\Downloads\best_epoch17.pth"

if CHECKPOINT_PATH is not None:
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

    # Restore the model weights.
    model.load_state_dict(checkpoint['model'])

    # Report the accuracy, loss and epoch saved alongside the weights.
    accuracy = checkpoint['accuracy']
    validation_loss = checkpoint['loss']
    epoch = checkpoint['epoch']
    print(f"Model loaded successfully with accuracy: {100.*accuracy:.2f}%, loss: {validation_loss:.6f}, at epoch: {epoch}")

# Wrap the model so each forward pass also returns a Grad-CAM attention map
# (return_attention=True); maps are not written to disk (save_maps=False).
model = medcam.inject(model, output_dir='attention_maps', backend='gcam', save_maps=False, return_attention=True, layer='auto')

In [None]:
# Preprocessing applied to a channel-first volume:
#   1. resample it to a fixed 110 x 110 x 110 grid;
#   2. z-score normalise each channel using only its non-zero voxels.
resize_step = transforms.Resize(spatial_size=[110, 110, 110])
normalize_step = transforms.NormalizeIntensity(nonzero=True, channel_wise=True)
transform = transforms.Compose([resize_step, normalize_step])

In [None]:
# Load one preprocessed NIfTI volume and turn it into a batched model input.
img_path = r"C:\Custom\DataSet\ADNI_预处理后\Image\brain_adni_0021_I196077_fsld.nii.gz"

# Raw (un-resized, un-normalised) volume, kept for side-by-side plotting later.
original_img = nib.load(img_path).get_fdata().astype(np.float32)

# Add a leading channel axis before the MONAI transforms, then a batch axis after.
nii_img = transform(torch.from_numpy(original_img).unsqueeze(0))
# Strip the MONAI MetaTensor wrapper and move the input to the model's device;
# without .to(device) the forward pass fails when the model is on CUDA.
nii_img = nii_img.unsqueeze(0).as_tensor().to(device)
print(nii_img.shape)

In [None]:
# Run the medcam-wrapped model on the single input volume.
# NOTE(review): Grad-CAM needs gradients; this relies on medcam re-enabling autograd
# internally despite the surrounding no_grad context — confirm for the installed version.
model.eval()
with torch.no_grad():
    # The injected model returns the model output together with the attention map.
    outputs, attention_map = model(nii_img)
    # Index of the highest-scoring class for each sample in the batch.
    predicted = torch.max(outputs, dim=1).indices
    attention_tensor = attention_map.detach().cpu()

# Drop singleton axes (e.g. batch/channel) so the map can be indexed as a plain volume.
attention_map = np.squeeze(attention_tensor.numpy())
print('predicted:', predicted)
print('attention_map shape:', attention_map.shape)

In [None]:
import os

import matplotlib.pyplot as plt

# plt.imsave does not create missing directories, and nothing earlier guarantees this
# folder exists (medcam was injected with save_maps=False), so create it if needed.
os.makedirs('./attention_maps', exist_ok=True)
# Save index 40 along the FIRST axis of the attention volume as a jet-coloured image.
plt.imsave('./attention_maps/test.jpg', attention_map[40], cmap="jet")

In [None]:
# Display one slice (taken along the last axis) of the raw, un-resized input volume.
slice_idx = 50
raw_slice = original_img[:, :, slice_idx]
print(raw_slice.shape)

fig, ax = plt.subplots()
# Transpose + vertical flip: the first array axis runs horizontally and the second
# axis increases upwards (row 0 ends up at the bottom of the image).
image = ax.imshow(np.flipud(raw_slice.T), cmap='gray')
fig.colorbar(image, ax=ax)
plt.show()
plt.close(fig)

In [None]:
# Count zero-valued voxels in the attention map. Vectorised over the whole array,
# whatever its shape, instead of a per-voxel Python loop over a hard-coded 110^3 grid
# that printed a debug line for every non-zero voxel.
cnt = int(np.count_nonzero(attention_map == 0))
print(f"non-zero voxels: {attention_map.size - cnt}")
print(cnt)

In [None]:
# `slice_idx` indexes the raw volume, but the attention map lives on the resized grid
# produced by `transform`. Rescale the index so both plots show the same relative
# position along the last axis (unchanged when the depths already match).
attn_slice_idx = min(
    round(slice_idx * attention_map.shape[2] / original_img.shape[2]),
    attention_map.shape[2] - 1,
)
# Same transpose + vertical flip as the raw-slice plot, so orientations match.
plt.imshow(np.flipud(attention_map[:, :, attn_slice_idx].T), cmap='jet')
plt.title(f"Grad-CAM attention, slice {attn_slice_idx} (raw slice {slice_idx})")
plt.colorbar()
plt.show()
plt.close()