In [1]:
%env CUDA_VISIBLE_DEVICES=3
import sys
sys.path.append('/home/a_razumov/projects/k-space-mri')
import numpy as np
import pylab as plt
import pickle
from tqdm.notebook import tqdm

import torch
import torch.nn.functional as F

# from k_space_reconstruction.utils.metrics import pt_msssim, pt_ssim
from k_space_reconstruction.datasets.acdc import ACDCSet, ACDCTransform, RandomMaskFunc
from k_space_reconstruction.datasets.fastmri import FastMRIh5Dataset, FastMRITransform, LegacyFastMRIh5Dataset
from k_space_reconstruction.utils.kspace import EquispacedMaskFunc, RandomMaskFunc
from k_space_reconstruction.utils.kspace import pt_spatial2kspace as Ft
from k_space_reconstruction.utils.kspace import pt_kspace2spatial as IFt

import os
import sys
from k_space_reconstruction.nets.unet import Unet
from k_space_reconstruction.nets.enet import ENet
from k_space_reconstruction.nets.mwcnn import MWCNN
import datetime
import torch
import torchvision
import numpy as np
import pylab as plt
plt.style.use('bmh')
import albumentations
import numpy as np
import h5py
import pylab as plt
import torch
import torch.nn.functional as F

import os
import re
import numpy as np
import pandas as pd
import nibabel

import torch
import torch.utils.data
import torchvision.transforms as transforms

from os.path import isdir, join
from typing import Callable, Dict, List, Any

env: CUDA_VISIBLE_DEVICES=3


In [2]:
def fig_bezzeless(nc, nr):
    """Create an ``nr`` x ``nc`` grid of frameless, tick-free subplots with no spacing.

    The figure is 4 inches per column and 4 inches per row at 120 dpi, so panels
    stay square for any number of rows (previously the height was fixed at 4
    inches regardless of ``nr``).

    Returns:
        ``(fig, axes)`` exactly as returned by ``plt.subplots``.
    """
    return plt.subplots(ncols=nc, nrows=nr, figsize=(4 * nc, 4 * nr), dpi=120,
                        subplot_kw=dict(frameon=False, xticks=[], yticks=[]),
                        gridspec_kw=dict(wspace=0.0, hspace=0.0))
    

def train_sampling_pattern(train_generator, n=14, loss_fn=F.l1_loss, batch_size=32):
    """Greedily build a 1-D k-space sampling mask by switching on one column per step.

    Only the first batch of ``train_generator`` is used. The mask ``w`` is
    multiplied onto that k-space along its last axis, starting with only the
    centre column set. At every step the gradient of ``loss_fn`` with respect to
    ``w`` is computed and summed over mini-batches. The loss compares
    ``IFt(masked).abs()`` with ``IFt(full).abs()``. The not-yet-selected column
    with the smallest gradient is then switched on.

    Args:
        train_generator: iterable of k-space batches; only the first one is used.
        n: number of greedy steps (columns to add).
        loss_fn: reconstruction loss, called as ``loss_fn(recs, gt)``.
        batch_size: mini-batch size for the gradient pass. Every sample is used,
            including a final partial mini-batch.

    Returns:
        List of CUDA float masks, one per step. Entry ``k`` has ``k + 2`` ones,
        as long as unselected columns remain.

    Raises:
        ValueError: if the first batch is empty (no gradient could be computed).
    """
    bks = next(iter(train_generator))
    if bks.shape[0] == 0:
        raise ValueError('train_generator yielded an empty batch')
    # w is broadcast against the last k-space axis, so its length must match it.
    n_cols = bks.shape[-1]
    w = torch.zeros(n_cols).cuda().float()
    w[n_cols // 2] = 1
    w_list = []
    pbar = tqdm(range(n))
    for _ in pbar:
        # Fresh leaf each step so .grad starts empty; it then accumulates over mini-batches.
        w = w.detach().requires_grad_(True)
        for start in range(0, bks.shape[0], batch_size):
            bbks = bks[start:start + batch_size].cuda()
            bbgt = IFt(bbks).abs()
            recs = IFt(bbks * w).abs()
            loss = loss_fn(recs, bbgt)
            loss.backward()
        # Walk columns from the most negative gradient upwards; take the first unused one.
        for i in torch.topk(w.grad, n_cols, largest=False).indices:
            if w[i] == 0:
                w = w.detach()
                w[i] = 1.
                w_list.append(w.clone())
                pbar.set_description('select: %d, loss: %.6f' % (i.item(), loss.item()))
                break
    return w_list

def test_sampling_pattern(sampling, val_generator, score_fn=F.l1_loss, batch_size=32):
    """Score a fixed k-space sampling mask on the first batch of ``val_generator``.

    For each mini-batch, ``IFt(k * sampling).abs()`` is compared with
    ``IFt(k).abs()`` using ``score_fn``.

    Args:
        sampling: mask broadcast-multiplied onto the k-space batch.
        val_generator: iterable of k-space batches; only the first one is used.
        score_fn: metric called as ``score_fn(recs, gt)``; must return a scalar tensor.
        batch_size: mini-batch size. A final partial mini-batch is scored too.

    Returns:
        List of per-mini-batch scores as Python floats.
    """
    # Use the generator passed in, not a notebook global, so each caller's
    # validation data is actually what gets scored.
    bks = next(iter(val_generator))
    scores = []
    with torch.no_grad():
        for start in tqdm(range(0, bks.shape[0], batch_size)):
            bbks = bks[start:start + batch_size].cuda()
            bbgt = IFt(bbks).abs()
            recs = IFt(bbks * sampling).abs()
            scores.append(score_fn(recs, bbgt).item())
    return scores


class BraTS3dDataset(torch.utils.data.Dataset):
    """Volume-level dataset read from one HDF5 file; each item is returned in k-space.

    Each HDF5 entry (one per key) is indexed as a 4-D array. Along its first axis,
    all entries but the last are image channels (``[:-1]``) and the last one is a
    mask channel (``[-1:]``) in which label 4 is remapped to 3.

    Args:
        hf_path: path to the HDF5 file; opened read-only.
        aug: optional albumentations-style callable ``aug(image=..., mask=...)``
            returning a dict with ``'image'`` and ``'mask'``.
        indexes: optional subset of keys to expose (each converted with ``str``).
            Defaults to every key in the file.

    NOTE(review): the h5py handle is opened eagerly in ``__init__``. With a
    DataLoader using ``num_workers > 0`` it is inherited by forked workers, which
    h5py does not guarantee to be safe; consider opening lazily per worker.
    """

    def __init__(self, hf_path: str, aug=None, indexes=None):
        super().__init__()
        # Explicit read-only mode: h5py < 3.0 defaulted to 'a', which needs write
        # access and silently creates an empty file when the path is wrong.
        self.hf = h5py.File(hf_path, 'r')
        self.aug = aug
        if indexes is not None:
            self.indexes = [str(i) for i in indexes]
        else:
            self.indexes = list(self.hf.keys())

    def __len__(self) -> int:
        return len(self.indexes)

    def __getitem__(self, item: int):
        """Return ``Ft`` of the image channels of volume ``item``.

        The mask is loaded only so that it can be passed through ``aug``; it is
        not part of the return value.
        """
        key = self.indexes[item]
        img = self.hf[key][:-1, :, :, :]
        mask = self.hf[key][-1:, :, :, :]
        mask[mask == 4] = 3
        if self.aug:
            aug = self.aug(image=img, mask=mask)
            img = aug['image']
            mask = aug['mask']
        img = torch.tensor(img).float()
        mask = torch.tensor(mask).long()
        # Move the last axis next to the channel axis: (C, A, B, D) -> (C, D, A, B).
        img = img.movedim(3, 1)
        mask = mask.movedim(3, 1)
        # Per-(channel, index-along-D) statistics over the two remaining axes.
        mean = img.mean(dim=(2, 3)).unsqueeze(2).unsqueeze(2)
        std = img.std(dim=(2, 3)).unsqueeze(2).unsqueeze(2)
        img = (img - mean) / (std + 1e-11)
        # De-normalise again before the transform; the round trip is identity up
        # to the 1e-11 epsilon and float rounding.
        return Ft(img * std + mean)


class BraTS2dDataset(torch.utils.data.Dataset):
    """Slice-level dataset read from one HDF5 file; each item is returned in k-space.

    Every item is one index along the last axis of a 4-D HDF5 array. Along the
    first axis of that array, all entries but the last are image channels and
    the last one is a mask channel (label 4 is remapped to 3).

    Args:
        hf_path: path to the HDF5 file; opened read-only.
        aug1: optional albumentations-style callable ``aug1(image=..., mask=...)``
            applied to the NumPy slice and returning a dict.
        aug2: optional callable applied to the image tensor with a leading axis
            added (``aug2(img.unsqueeze(0))[0]``).
        slices: optional precomputed list of ``(key, slice_index)`` pairs. When
            falsy (``None`` or empty), the file is scanned and every slice whose
            maximum is non-zero is kept.
        indexes: optional collection of volume ids (anything ``int()`` accepts).
            Only slices whose ``int(key)`` is among them are kept.

    NOTE(review): the h5py handle is opened eagerly in ``__init__``. With a
    DataLoader using ``num_workers > 0`` it is inherited by forked workers, which
    h5py does not guarantee to be safe; consider opening lazily per worker.
    """

    def __init__(self, hf_path: str, aug1=None, aug2=None, slices=None, indexes=None):
        super().__init__()
        # Explicit read-only mode: h5py < 3.0 defaulted to 'a' (write access,
        # creates an empty file when the path is wrong).
        self.hf = h5py.File(hf_path, 'r')
        self.aug1 = aug1
        self.aug2 = aug2
        if slices:
            self.slices = slices
        else:
            self.slices = []
            for k in tqdm(self.hf.keys()):
                # One read per volume instead of one strided HDF5 read per slice.
                vol = self.hf[k][()]
                per_slice_max = vol.reshape(-1, vol.shape[-1]).max(axis=0)
                self.slices.extend((k, int(j)) for j in np.flatnonzero(per_slice_max != 0.))
        if indexes is not None:
            # Set lookup: O(1) per slice instead of a scan of `indexes` per slice.
            wanted = {int(i) for i in indexes}
            self.slices = [s for s in self.slices if int(s[0]) in wanted]

    def __len__(self) -> int:
        return len(self.slices)

    def __getitem__(self, item: int):
        """Return ``Ft`` of the image channels of slice ``item``.

        The mask is loaded only so that it can be passed through ``aug1``; it is
        not part of the return value.
        """
        key, idx = self.slices[item]
        img = self.hf[key][:-1, :, :, idx]
        mask = self.hf[key][-1:, :, :, idx]
        mask[mask == 4] = 3
        if self.aug1:
            aug = self.aug1(image=img, mask=mask)
            img = aug['image']
            mask = aug['mask']
        img = torch.tensor(img).float()
        if self.aug2:
            img = self.aug2(img.unsqueeze(0))[0]
        mask = torch.tensor(mask).long()
        # Whole-slice statistics, shaped (1, 1, 1) to broadcast over the 3-D image.
        mean = img.mean().unsqueeze(0).unsqueeze(0).unsqueeze(0)
        std = img.std().unsqueeze(0).unsqueeze(0).unsqueeze(0)
        img = (img - mean) / (std + 1e-11) + 1e-11
        # De-normalise again before the transform; the round trip is identity up
        # to the 1e-11 epsilons and float rounding.
        return Ft(img * std + mean)

In [3]:
def t2i(x):
    """Min-max rescale a tensor to the range [0, 255] (returns a new tensor).

    A constant input has zero span. Previously this produced 0/0 = NaN, which
    silently propagated into SSIM scores; it now maps to all zeros.
    """
    shifted = x - x.min()
    span = shifted.max()
    if span > 0:
        shifted = shifted / span
    return shifted * 255.

def pt_ssim(pred, gt):
    """SSIM between ``pred`` and ``gt`` after each is rescaled to [0, 255] with ``t2i``.

    A leading axis is prepended to both inputs before calling
    ``pytorch_msssim.ssim``. That function's default ``data_range`` is 255,
    which matches the rescaling.
    """
    from pytorch_msssim import ssim
    pred_img, gt_img = (t2i(t)[None] for t in (pred, gt))
    return ssim(pred_img, gt_img)

def pt_msssim(pred, gt):
    """Multi-scale SSIM between ``pred`` and ``gt`` after each is rescaled to [0, 255] with ``t2i``.

    A leading axis is prepended to both inputs before calling
    ``pytorch_msssim.ms_ssim``.
    """
    from pytorch_msssim import ms_ssim
    pred_img = t2i(pred)[None]
    gt_img = t2i(gt)[None]
    return ms_ssim(pred_img, gt_img)

def pt_nmse(pred, gt):
    """Normalised MSE: ||gt - pred||_2^2 / ||gt||_2^2 over all elements."""
    residual = torch.norm(gt - pred, p=2)
    reference = torch.norm(gt, p=2)
    return residual ** 2 / reference ** 2

def pt_psnr(pred, gt):
    """PSNR in dB, using the maximum of ``gt`` as the peak value.

    Returns ``inf`` when ``pred`` equals ``gt`` exactly.
    """
    peak = gt.max()
    rmse = torch.sqrt(torch.mean((pred - gt) ** 2))
    return 20 * torch.log10(peak / rmse)

In [4]:
# Fix the global NumPy / PyTorch seeds before any dataset or loader is built.
np.random.seed(42)
torch.manual_seed(42)


def _read_pickle(path):
    """Load one pickled object from ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point this at
    trusted, locally generated cache files.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# Cached (key, slice_index) lists, passed as ``slices`` so the 2-D datasets
# do not rescan every volume on construction.
train_slices_cache = _read_pickle('brats_train.cache')
val_slices_cache = _read_pickle('brats_val.cache')

train_dataset = BraTS2dDataset('/home/a_razumov/small_datasets/brats_h5/train.h5', slices=train_slices_cache)
val_dataset = BraTS2dDataset('/home/a_razumov/small_datasets/brats_h5/val.h5', slices=val_slices_cache)
train_3d_dataset = BraTS3dDataset('/home/a_razumov/small_datasets/brats_h5/train.h5')
val_3d_dataset = BraTS3dDataset('/home/a_razumov/small_datasets/brats_h5/val.h5')

train_generator = torch.utils.data.DataLoader(
    train_dataset, batch_size=32, shuffle=True, num_workers=32)
val_generator = torch.utils.data.DataLoader(
    val_dataset, batch_size=32, shuffle=False, num_workers=32)

In [5]:
tuple(len(ds) for ds in (train_dataset, val_dataset, train_3d_dataset, val_3d_dataset))

(35508, 15391, 258, 111)

# 5-fold cross-validation of sampling patterns at 16x acceleration

In [6]:
from collections import defaultdict
from sklearn.model_selection import KFold

# Baselines: an equispaced fastMRI-style mask (center fraction 0.04, acceleration 16)
# and a contiguous block of the 20 central columns out of 320 (320 / 20 = 16).
fastmri_mask_x16 = torch.tensor(EquispacedMaskFunc([0.04], [16])((640, 320))[0]).cuda().float()
zm = torch.zeros(320).cuda().float()
zm[320 // 2 - 20 // 2 : 320 // 2 + 20 // 2] = 1
fm = torch.ones(320).cuda().float()  # fully sampled mask; not used in this cell

fold_scores = defaultdict(dict)

# KFold defaults to 5 splits; with shuffle=False the folds are contiguous index ranges.
folds = KFold(shuffle=False).split(range(len(train_dataset)))
for i, (train_id, val_id) in enumerate(folds):
    td = torch.utils.data.Subset(train_dataset, train_id)
    vd = torch.utils.data.Subset(train_dataset, val_id)
    # batch_size=len(subset): each loader yields the whole fold as one batch;
    # the helpers split it into mini-batches themselves.
    tg = torch.utils.data.DataLoader(td, batch_size=len(td), shuffle=True)
    vg = torch.utils.data.DataLoader(vd, batch_size=len(vd), shuffle=False)
    # 19 greedy steps on top of the single centre column -> w_list[18] has 20 columns.
    w_list = train_sampling_pattern(tg, n=19, loss_fn=F.l1_loss)
    candidates = {
        'ours': w_list[18],
        'fastmri': fastmri_mask_x16,
        'center': zm,
    }
    fold_scores[i] = {name: test_sampling_pattern(mask, vg, score_fn=pt_ssim)
                      for name, mask in candidates.items()}

# Convert to a plain list of {method: per-batch scores} dicts, one per fold.
# NOTE(review): the file name says "l2" but the masks above are trained with
# F.l1_loss; confirm which is intended.
fold_scores = [{name: list(scores) for name, scores in per_fold.items()}
               for per_fold in fold_scores.values()]
with open('fold_scores_l2.pkl', mode='wb') as f:
    pickle.dump(fold_scores, f)

KeyboardInterrupt: 