In [1]:
%env CUDA_VISIBLE_DEVICES=3
import sys
sys.path.append('/home/a_razumov/projects/k-space-mri')
import numpy as np
import pylab as plt
import pickle
from tqdm.notebook import tqdm

import torch
import torch.nn.functional as F

from k_space_reconstruction.utils.metrics import pt_msssim, pt_ssim
from k_space_reconstruction.datasets.acdc import ACDCSet, ACDCTransform, RandomMaskFunc
from k_space_reconstruction.datasets.fastmri import FastMRIh5Dataset, FastMRITransform, LegacyFastMRIh5Dataset
from k_space_reconstruction.utils.kspace import EquispacedMaskFunc, RandomMaskFunc
from k_space_reconstruction.utils.kspace import pt_spatial2kspace as Ft
from k_space_reconstruction.utils.kspace import pt_kspace2spatial as IFt

import os
import sys
from k_space_reconstruction.nets.unet import Unet
from k_space_reconstruction.nets.enet import ENet
from k_space_reconstruction.nets.mwcnn import MWCNN
import datetime
import torch
import torchvision
import numpy as np
import pylab as plt
plt.style.use('bmh')
import albumentations
import numpy as np
import h5py
import pylab as plt
import torch
import torch.nn.functional as F

import os
import re
import numpy as np
import pandas as pd
import nibabel

import torch
import torch.utils.data
import torchvision.transforms as transforms

from os.path import isdir, join
from typing import Callable, Dict, List, Any

env: CUDA_VISIBLE_DEVICES=3


In [2]:
fig_bezzeless = lambda nc, nr : plt.subplots(ncols=nc, nrows=nr, figsize=(4 * nc, 4), dpi=120, 
                                             subplot_kw=dict(frameon=False, xticks=[], yticks=[]), 
                                             gridspec_kw=dict(wspace=0.0, hspace=0.0))


def ce_loss(true, logits, weights, ignore=255):
    torch.nn.CrossEntropyLoss
    ce_loss = torch.nn.functional.cross_entropy(
        logits.float(),
        true.long(),
        ignore_index=ignore,
        weight=weights,
    )
    return ce_loss


def dice_loss(true, logits, eps=1e-7):
    num_classes = logits.shape[1]
    if num_classes == 1:
        true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        true_1_hot_f = true_1_hot[:, 0:1, :, :]
        true_1_hot_s = true_1_hot[:, 1:2, :, :]
        true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
        pos_prob = torch.sigmoid(logits)
        neg_prob = 1 - pos_prob
        probas = torch.cat([pos_prob, neg_prob], dim=1)
    else:
        true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        probas = torch.nn.functional.softmax(logits, dim=1)
    true_1_hot = true_1_hot.type(logits.type())
    dims = (0,) + tuple(range(2, true.ndimension()))
    intersection = torch.sum(probas * true_1_hot, dims)
    cardinality = torch.sum(probas + true_1_hot, dims)
    dice_loss = (2. * intersection / (cardinality + eps)).mean()
    return (1 - dice_loss)

def dice_coeffs(true, logits):
    num_classes = logits.shape[1]
    probas = F.softmax(logits, dim=1)
    probas[probas > 0.5] = 1; probas[probas <= 0.5] = 0
    pmask = torch.zeros_like(true).float()
    for i in range(1, num_classes):
        pmask[:,0] += i * probas[:,i]
    dice_ls = []
    for i in range(1, num_classes):
        yt = (true==i).float().flatten()
        yp = (pmask==i).float().flatten()
        intersection = torch.sum(yt * yp)
        cardinality = torch.sum(yt + yp)
        dice_ls.append((2. * intersection / (cardinality + 1e-7)).item())
    return dice_ls   

def finetune_model_on_sampling(train_generator, model, sampling, epochs=5, return_losses=False):
    if not os.path.exists('acdc_unet_checkpoints'):
        os.makedirs('acdc_unet_checkpoints')
    losses = np.zeros(epochs)
    checkpoints = []
    criterion = lambda p,t : dice_loss(t, p) * .75 + ce_loss(t.squeeze(1), p, weights=None) * .25
    metric = lambda p,t : 1 - dice_loss(t, p)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    pbar = tqdm(range(epochs))
    for epoch in pbar:
        model = model.train()
        for _, targets, images, means, stds in train_generator:
            images = images.cuda(); targets = targets.cuda().long(); means = means.cuda(); stds = stds.cuda()
            optimizer.zero_grad()
            # backpropagate
            images = (IFt(Ft(images * stds + means) * sampling).abs() - means) / (stds + 1e-11)
            pred = model(images)
            loss = criterion(pred, targets)
            loss.backward()
            optimizer.step()
            losses[epoch] += loss.item() / len(train_generator)
        model = model.train(False).eval()
        checkpoint_path = join('acdc_unet_checkpoints', 'epoch%d.pth' % epoch)
        torch.save(model.state_dict(), checkpoint_path)
        checkpoints.append(checkpoint_path)
    del optimizer
    best_checkpoint = checkpoints[np.argmin(losses)]
    model.load_state_dict(torch.load(best_checkpoint))
    if return_losses:
        return model, losses
    else:
        return model

def train_sampling_pattern(train_generator, model, n=14):
    c, bmasks, images, bmean, bstd = next(iter(train_generator))
    bks = Ft(images * bstd + bmean)
    bgt = IFt(bks).abs()
    w = torch.zeros(256).cuda().float()
    w[128] = 1
    bbatch = 32
    w_list = []
    pbar = tqdm(range(n))
    for count in pbar:
        w = torch.autograd.Variable(w, requires_grad=True)
        for j in range(bks.shape[0] // bbatch):
            bbks = bks[bbatch*j:bbatch*(j+1)].cuda()
            bbgt = bgt[bbatch*j:bbatch*(j+1)].cuda()
            bbmean = bmean[bbatch*j:bbatch*(j+1)].cuda()
            bbstd = bstd[bbatch*j:bbatch*(j+1)].cuda()
            bbmasks = bmasks[bbatch*j:bbatch*(j+1)].cuda()
            recs = IFt(bbks * w).abs()
            pm = model((recs - bbmean) / (bbstd + 1e-11))
            loss = dice_loss(bbmasks.long(), pm)
            loss.backward()
        for i in torch.topk(w.grad, 256, largest=False).indices:
            if w[i] == 0: 
                w = w.detach()
                w[i] = 1.
                w_list.append(w.clone())
                pbar.set_description('select: %d, loss: %.6f' % (i.item(), loss.item()))
                break
    return w_list

def test_sampling_pattern(sampling, model, val_generator):
    vc, vbmasks, vimages, vbmean, vbstd = next(iter(val_generator))
    vbks = Ft(vimages * vbstd + vbmean)
    vbgt = IFt(vbks).abs()
    dice_scores = []
    bbatch = 32
    for j in tqdm(range(vbks.shape[0] // bbatch)):
        vbbks = vbks[bbatch*j:bbatch*(j+1)].cuda()
        vbbgt = vbgt[bbatch*j:bbatch*(j+1)].cuda()
        vbbmean = vbmean[bbatch*j:bbatch*(j+1)].cuda()
        vbbstd = vbstd[bbatch*j:bbatch*(j+1)].cuda()
        vbbmasks = vbmasks[bbatch*j:bbatch*(j+1)].cuda()
        with torch.no_grad():
            # igs
            recs = IFt(vbbks * sampling).abs()
            pm = model((recs - vbbmean) / (vbbstd + 1e-11))
            for i in range(recs.shape[0]):
                dice_scores.append(1 - dice_loss(vbbmasks.long(), pm).item())
    return dice_scores

def test_on_classes_sampling_pattern(sampling, model, val_generator):
    vc, vbmasks, vimages, vbmean, vbstd = next(iter(val_generator))
    vbks = Ft(vimages * vbstd + vbmean)
    vbgt = IFt(vbks).abs()
    dice_scores = []
    bbatch = 1
    for j in tqdm(range(vbks.shape[0] // bbatch)):
        vbbks = vbks[bbatch*j:bbatch*(j+1)].cuda()
        vbbgt = vbgt[bbatch*j:bbatch*(j+1)].cuda()
        vbbmean = vbmean[bbatch*j:bbatch*(j+1)].cuda()
        vbbstd = vbstd[bbatch*j:bbatch*(j+1)].cuda()
        vbbmasks = vbmasks[bbatch*j:bbatch*(j+1)].cuda()
        with torch.no_grad():
            # igs
            recs = IFt(vbbks * sampling).abs()
            pm = model((recs - vbbmean) / (vbbstd + 1e-11))
            for i in range(recs.shape[0]):
                dice_scores.append(dice_coeffs(vbbmasks.long(), pm))
    return dice_scores

class ACDCDataset(torch.utils.data.Dataset):
    CLASSES = {0: 'NOR', 1: 'MINF', 2: 'DCM', 3: 'HCM', 4: 'RV'}

    def __init__(self, hf_path: str):
        super().__init__()
        self.hf = h5py.File(hf_path)

    def __len__(self) -> int:
        return len(self.hf)

    def __getitem__(self, item: int):
        img = self.hf[str(item)][:1]
        mask = self.hf[str(item)][1:]
        c = self.hf[str(item)].attrs['class']
        img = torch.tensor(img).float()
        mask = torch.tensor(mask)
        mean = img.mean()
        std = img.std()
        img = (img - mean) / (std + 1e-11)
        return c, mask, img, mean.unsqueeze(0).unsqueeze(0).unsqueeze(0), std.unsqueeze(0).unsqueeze(0).unsqueeze(0)

In [3]:
def pt_psnr(img1, img2, maxval):
    mse = torch.mean((img1 - img2) ** 2)
    return 20 * torch.log10(maxval / torch.sqrt(mse))

In [4]:
train_dataset = ACDCDataset('/home/a_razumov/small_datasets/acdc_seg_h5/train.h5')
val_dataset = ACDCDataset('/home/a_razumov/small_datasets/acdc_seg_h5/val.h5')
train_generator = torch.utils.data.DataLoader(train_dataset, batch_size=len(train_dataset), shuffle=True, num_workers=6)
val_generator = torch.utils.data.DataLoader(val_dataset, batch_size=len(val_dataset), shuffle=False)

In [5]:
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
model = Unet(1, 3+1).to(device).train(False).eval()
model.load_state_dict(torch.load('unet-acdc-norot.pt'))

<All keys matched successfully>

# 5-fold validation x16

In [None]:
from collections import defaultdict
from sklearn.model_selection import KFold

fastmri_mask_x16 = torch.tensor(EquispacedMaskFunc([0.04], [16])((256, 256))[0]).cuda().float()
zm = torch.zeros(256).cuda().float()
zm[256//2 - int(16)//2 : 256//2 + int(16)//2] = 1
fm = torch.ones(256).cuda().float()

model_zm = Unet(1, 3+1).to(device).train(False).eval()
model_zm.load_state_dict(torch.load('unet-acdc-norot.pt'))
model_igs = Unet(1, 3+1).to(device).train(False).eval()
model_igs.load_state_dict(torch.load('unet-acdc-norot.pt'))
model_fastmri = Unet(1, 3+1).to(device).train(False).eval()
model_fastmri.load_state_dict(torch.load('unet-acdc-norot.pt'))

dice_fold_scores = defaultdict(dict)

for i, (train_id, val_id) in enumerate(KFold(shuffle=False).split(range(len(train_dataset)))):
    td = torch.utils.data.Subset(train_dataset, train_id)
    vd = torch.utils.data.Subset(train_dataset, val_id)
    tg = torch.utils.data.DataLoader(td, batch_size=len(td), shuffle=True)
    vg = torch.utils.data.DataLoader(vd, batch_size=len(vd), shuffle=False)
    model_igs.load_state_dict(torch.load('unet-acdc-norot.pt'))
    for _ in range(1):
        tg = torch.utils.data.DataLoader(td, batch_size=len(td), shuffle=True)
        w = train_sampling_pattern(tg, model_igs, n=16)[14]
        tg = torch.utils.data.DataLoader(td, batch_size=16, shuffle=True)
        vg = torch.utils.data.DataLoader(vd, batch_size=16, shuffle=False)
        model_igs = finetune_model_on_sampling(tg, model_igs, w, epochs=25)
    model_zm.load_state_dict(torch.load('unet-acdc-norot.pt'))
    model_zm = finetune_model_on_sampling(tg, model_zm, zm, epochs=25)
    model_fastmri.load_state_dict(torch.load('unet-acdc-norot.pt'))
    model_fastmri = finetune_model_on_sampling(tg, model_fastmri, fastmri_mask_x16, epochs=25)
    tg = torch.utils.data.DataLoader(td, batch_size=len(td), shuffle=True)
    vg = torch.utils.data.DataLoader(vd, batch_size=len(vd), shuffle=False)
    dice_fold_scores[i] = dict(
        ours=test_sampling_pattern(w, model_igs, vg),
        fastmri=test_sampling_pattern(fastmri_mask_x16, model_fastmri, vg), 
        center=test_sampling_pattern(zm, model_zm, vg),
    )
fold_scores = [{k:[vv for vv in v] for k,v in dv.items()} for dv in dice_fold_scores.values()]
with open('unet_finetune_fold_scores.pkl', mode='wb') as f: pickle.dump(fold_scores, f)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=25.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=25.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=25.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=7.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=7.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=7.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16.0), HTML(value='')))

In [6]:
with open('unet_finetune_fold_scores.pkl', 'rb') as f: fold_scores = pickle.load(f)

In [9]:
import scipy

df_dice = pd.DataFrame.from_dict({
    'fastmri': [np.mean(v['fastmri']) for v in fold_scores],
    'center': [np.mean(v['center']) for v in fold_scores],
    'ours': [np.mean(v['ours']) for v in fold_scores]
})
print(scipy.stats.ttest_rel(df_dice.center, df_dice.ours))
df_dice.describe()

Ttest_relResult(statistic=-5.147740121761072, pvalue=0.006755003802498928)


Unnamed: 0,fastmri,center,ours
count,5.0,5.0,5.0
mean,0.81486,0.857171,0.87428
std,0.019786,0.008406,0.011858
min,0.79147,0.84843,0.857907
25%,0.805183,0.851854,0.869292
50%,0.809801,0.855422,0.876228
75%,0.824623,0.860074,0.877775
max,0.843222,0.870074,0.8902


# Training on full train

In [6]:
fastmri_mask_x16 = torch.tensor(EquispacedMaskFunc([0.04], [16])((256, 256))[0]).cuda().float()
zm = torch.zeros(256).cuda().float()
zm[256//2 - int(16)//2 : 256//2 + int(16)//2] = 1
fm = torch.ones(256).cuda().float()

model_zm = Unet(1, 3+1).to(device).train(False).eval()
model_zm.load_state_dict(torch.load('unet-acdc-norot.pt'))
model_igs = Unet(1, 3+1).to(device).train(False).eval()
model_igs.load_state_dict(torch.load('unet-acdc-norot.pt'))
model_fastmri = Unet(1, 3+1).to(device).train(False).eval()
model_fastmri.load_state_dict(torch.load('unet-acdc-norot.pt'))

tg = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)

# for _ in range(3):
w = train_sampling_pattern(train_generator, model_igs, n=16)[14]
model_igs = finetune_model_on_sampling(tg, model_igs, w, epochs=25)
    
model_zm = finetune_model_on_sampling(tg, model_zm, zm, epochs=25)
model_fastmri = finetune_model_on_sampling(tg, model_fastmri, fastmri_mask_x16, epochs=25)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=25.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=25.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=25.0), HTML(value='')))




In [7]:
torch.save(w, 'sampling_igs_finetune.pt')
torch.save(model_fastmri.state_dict(), 'model_fastmri.pt')
torch.save(model_zm.state_dict(), 'model_zm.pt')
torch.save(model_igs.state_dict(), 'model_igs.pt')

In [8]:
fastmri_mask_x16 = torch.tensor(EquispacedMaskFunc([0.04], [16])((256, 256))[0]).cuda().float()
zm = torch.zeros(256).cuda().float()
zm[256//2 - int(16)//2 : 256//2 + int(16)//2] = 1
fm = torch.ones(256).cuda().float()
w = torch.load('sampling_igs_finetune.pt')

model_zm = Unet(1, 3+1).to(device).train(False).eval()
model_zm.load_state_dict(torch.load('model_zm.pt'))
model_igs = Unet(1, 3+1).to(device).train(False).eval()
model_igs.load_state_dict(torch.load('model_igs.pt'))
model_fastmri = Unet(1, 3+1).to(device).train(False).eval()
model_fastmri.load_state_dict(torch.load('model_fastmri.pt'))

<All keys matched successfully>

In [9]:
fastmri_mask_x16.sum(), zm.sum(), w.sum()

(tensor(16., device='cuda:0'),
 tensor(16., device='cuda:0'),
 tensor(16., device='cuda:0'))

In [10]:
dice_scores = dict(
    fastmri=test_sampling_pattern(fastmri_mask_x16, model_fastmri, val_generator), 
    center=test_sampling_pattern(zm, model_zm, val_generator),
    ours=test_sampling_pattern(w, model_igs, val_generator),
    full=test_sampling_pattern(fm, model, val_generator),
)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16.0), HTML(value='')))




In [12]:
df_dice = pd.DataFrame.from_dict(dice_scores)
df_dice.describe()

Unnamed: 0,fastmri,center,ours,full
count,512.0,512.0,512.0,512.0
mean,0.819452,0.85974,0.879824,0.922834
std,0.037434,0.02824,0.02053,0.015547
min,0.71624,0.794961,0.835609,0.885007
25%,0.806269,0.845228,0.867109,0.912665
50%,0.824782,0.864697,0.882555,0.925333
75%,0.84184,0.88228,0.890008,0.936478
max,0.869003,0.89681,0.913424,0.948431


# Eval dice classes scores x16

In [13]:
fastmri_mask_x16 = torch.tensor(EquispacedMaskFunc([0.04], [16])((256, 256))[0]).cuda().float()
zm = torch.zeros(256).cuda().float()
zm[256//2 - int(16)//2 : 256//2 + int(16)//2] = 1
fm = torch.ones(256).cuda().float()
w = torch.load('sampling_igs_finetune.pt')

model = Unet(1, 3+1).to(device).train(False).eval()
model.load_state_dict(torch.load('unet-acdc-norot.pt'))
model_zm = Unet(1, 3+1).to(device).train(False).eval()
model_zm.load_state_dict(torch.load('model_zm.pt'))
model_igs = Unet(1, 3+1).to(device).train(False).eval()
model_igs.load_state_dict(torch.load('model_igs.pt'))
model_fastmri = Unet(1, 3+1).to(device).train(False).eval()
model_fastmri.load_state_dict(torch.load('model_fastmri.pt'))

dice_class_scores = dict(
    fastmri=test_on_classes_sampling_pattern(fastmri_mask_x16, model_fastmri, val_generator), 
    center=test_on_classes_sampling_pattern(zm, model_zm, val_generator),
    ours=test_on_classes_sampling_pattern(w, model_igs, val_generator),
    full=test_on_classes_sampling_pattern(fm, model, val_generator),
)

class_map = {0: 'RV cavity', 1: 'LV myo', 2: 'LV cavity'}
for name in dice_class_scores.keys():
    arr = np.array(dice_class_scores[name]).T
    print('##############', name, '##############')
    print(pd.DataFrame.from_dict({class_map[i]:arr[i] for i in range(arr.shape[0])}).describe())

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=514.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=514.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=514.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=514.0), HTML(value='')))


############## fastmri ##############
        RV cavity      LV myo   LV cavity
count  514.000000  514.000000  514.000000
mean     0.558414    0.602568    0.721218
std      0.375114    0.249848    0.320781
min      0.000000    0.000000    0.000000
25%      0.000000    0.485560    0.680326
50%      0.758549    0.700854    0.873333
75%      0.870009    0.781476    0.928485
max      0.965863    0.908752    0.975955
############## center ##############
        RV cavity      LV myo   LV cavity
count  514.000000  514.000000  514.000000
mean     0.643160    0.652036    0.748662
std      0.354764    0.266722    0.323344
min      0.000000    0.000000    0.000000
25%      0.452725    0.596957    0.779150
50%      0.824435    0.765891    0.899645
75%      0.902517    0.827424    0.940635
max      0.965468    0.920667    0.984179
############## ours ##############
        RV cavity      LV myo   LV cavity
count  514.000000  514.000000  514.000000
mean     0.654976    0.725562    0.803359
std    