In [1]:
import fastmri
import torch
import numpy as np
from fastmri.data.mri_data import SliceDataset
from fastmri.models import VarNet
from fastmri.data.subsample import RandomMaskFunc,EquiSpacedMaskFunc,EquispacedMaskFractionFunc
from fastmri.data import transforms as T
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchmetrics.functional import structural_similarity_index_measure as SSIM
from torchmetrics.functional import peak_signal_noise_ratio as PSNR

## fastmri dataset
# Validation data location and k-space undersampling settings.
DATA_ROOT = '/data3/M4Raw/multicoil_val'
# 0.1171875 = 30/256: 30 fully-sampled centre lines when k-space is 256 lines wide.
CENTER_FRACTIONS = [0.1171875]
ACCELERATIONS = [2]

mask_func = EquispacedMaskFractionFunc(center_fractions=CENTER_FRACTIONS,
                                       accelerations=ACCELERATIONS)
dataset = SliceDataset(root=DATA_ROOT, challenge='multicoil',
                       transform=T.VarNetDataTransform(mask_func))
# One slice per batch, in file order, so per-slice metrics can be averaged afterwards.
val_loader = DataLoader(dataset, batch_size=1, shuffle=False,
                        num_workers=4, drop_last=False, pin_memory=True)

In [4]:
# Inspect one transformed sample: print each field's shape, or the value itself
# for fields without a `.shape` (ints, strings, tuples).
sample = dataset[0]
for field in sample:
    # Explicit attribute check instead of a bare `except:`, which would also
    # swallow KeyboardInterrupt and genuine errors raised while printing.
    print(field.shape if hasattr(field, "shape") else field)

torch.Size([4, 256, 256, 2])
torch.Size([1, 1, 256, 1])
30
torch.Size([256, 256])
2022061203_FLAIR01.h5
0
()
(256, 256)


In [4]:
# Load a pretrained VarNet and collect per-slice SSIM / PSNR over `val_loader`.
ssim_list = []
psnr_list = []

CKPT_PATH = './ckpt/varnet_M4Raw_init_v1.ckpt'
# Fall back to CPU so the evaluation still runs on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## build model load weight
# These hyper-parameters must match the checkpoint: load_state_dict raises on
# missing/unexpected keys (strict=True) and on shape mismatches.
model = VarNet(num_cascades=12, sens_chans=8, sens_pools=4, chans=18, pools=4).to(device)
# map_location="cpu" makes loading independent of the device the checkpoint was
# saved from; load_state_dict then copies the tensors onto the model's device.
# NOTE(review): torch>=2.6 defaults to torch.load(weights_only=True), which may
# reject checkpoints that pickle extra (non-tensor) objects — confirm for this env.
checkpoint = torch.load(CKPT_PATH, map_location="cpu")["state_dict"]
prefix = "varnet."
# Strip the `varnet.` prefix (presumably added by the training wrapper module)
# only where it really is a prefix; str.replace would also rewrite a mid-key match.
checkpoint = {(k[len(prefix):] if k.startswith(prefix) else k): v
              for k, v in checkpoint.items()}
# `loss.w` is not a VarNet weight (presumably a training-loss parameter);
# pop() with a default tolerates checkpoints that do not contain it.
checkpoint.pop("loss.w", None)
model.load_state_dict(checkpoint, strict=True)
del checkpoint
model.eval()

## inference
with torch.no_grad():  # entered once for the whole loop, not per batch
    for batch in tqdm(val_loader):
        # The transform yields 8 fields in this order (see the sample printout).
        (masked_kspace, mask, num_low_frequencies, target,
         _fname, _slice_num, max_value, _crop_size) = batch
        output = model(masked_kspace.to(device, non_blocking=True),
                       mask.to(device, non_blocking=True),
                       num_low_frequencies=num_low_frequencies.to(device))
        target = target.to(device, non_blocking=True)
        max_value = max_value.to(device)
        # VarNet's output is not cropped to the target size; crop both to the
        # common size so the metrics compare matching pixels (no-op when equal).
        target, output = T.center_crop_to_smallest(target, output)
        # Metrics take (N, C, H, W): insert a single-channel axis after batch.
        output_nchw = output.unsqueeze(1)
        target_nchw = target.unsqueeze(1)
        ssim_list.append(SSIM(output_nchw, target_nchw, data_range=max_value).item())
        psnr_list.append(PSNR(output_nchw, target_nchw, data_range=max_value).item())

assert ssim_list, "val_loader yielded no batches; check the dataset root"
mean_ssim = float(np.mean(ssim_list))
mean_psnr = float(np.mean(psnr_list))
# 4 decimals for SSIM: 2 decimals on a [0, 1] metric hides model differences.
print(f'SSIM:{mean_ssim:.4f}\nPSNR:{mean_psnr:.2f}')

100%|██████████| 1620/1620 [01:46<00:00, 15.23it/s]


SSIM:86.07
PSNR:34.8
