In [29]:
%load_ext autoreload
%autoreload 2

import os

os.chdir('/home/tobias/Projects/bernoulli-mri')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [30]:
import random

import matplotlib.pyplot as plt
from torch.optim import Adam
from tqdm import tqdm
import torch
from torch import nn
from torchvision.utils import save_image
from torch.utils.data import DataLoader

from src.config import get_configuration
from src.constraint import ScoreConstrainer
from src.dense_rate import DenseRateScheduler
from src.optimization import get_mask_handler
from src.utils import (
    get_temperature,
    ifft2c,
    plot_heatmap,
    min_max_normalize,
)
from src.datasets import ACDCDataset

In [31]:
# Experiment configuration. `get_configuration` turns it into the `cfg` object
# whose attributes (cfg.steps, cfg.device, ...) are used throughout the notebook.
BASE_CONFIG = dict(
    # data
    file_path='data/ACDC',
    slice_idx=17,
    coil_idx=None,
    cropping=(320, 320),
    # optimisation of the mask scores
    steps=2500,
    learning_rate=1e-2,
    bern_samples=4,
    mask_style='f',
    # dense-rate schedule; dense_start / dense_end are fractions of `steps`
    dense_target=1 / 8,
    dense_start=0.10,
    dense_end=0.85,
    # runtime and logging
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    log_dir='logs/acdc',
    log_imgs=10,
)

cfg = get_configuration(BASE_CONFIG)

In [32]:
# Unweighted pixel-wise MSE, averaged over all elements. A per-pixel weighted
# variant (e.g. using `seg`) would need reduction='none' plus an explicit mean.
loss_func = nn.MSELoss(reduction='mean')

In [33]:
# The schedule's start and stop steps are the `dense_start` / `dense_end`
# fractions of the total number of optimisation steps.
_dense_start_step = int(cfg.dense_start * cfg.steps)
_dense_stop_step = int(cfg.dense_end * cfg.steps)

dense_scheduler = DenseRateScheduler(
    target=cfg.dense_target,
    start_epoch=_dense_start_step,
    stop_epoch=_dense_stop_step,
)

In [34]:
# Seed torch's global RNG so the shuffled batch drawn below is reproducible.
torch.manual_seed(42)
num_samples = 64
acdc = ACDCDataset(cfg.file_path, train=True)

samples = next(iter(DataLoader(acdc, batch_size=num_samples, shuffle=True)))
img = samples['img']
img_k = samples['k_space']

# Per-pixel loss weights: pixels with a non-zero segmentation label get
# `seg_weight`, all other pixels get 1. The foreground mask is taken once from
# the raw labels, so the result is correct for any `seg_weight` (the previous
# two-step in-place relabelling turned every pixel into 1 when seg_weight == 0).
# A separate float tensor is built, so `samples['seg']` keeps the raw labels.
seg_weight = 50
foreground = samples['seg'] != 0
seg = torch.ones_like(samples['seg'], dtype=torch.float32)
seg[foreground] = seg_weight
# Repeat along the batch dim to line up with the `bern_samples` copies of each image.
seg = torch.repeat_interleave(seg, repeats=cfg.bern_samples, dim=0).to(cfg.device)

In [None]:
img = img.to(cfg.device)
img_k = img_k.to(cfg.device)

# Repeat every image `bern_samples` times so each copy can be paired with its
# own sampled mask in a single batched forward pass.
img_k_batch = torch.repeat_interleave(img_k, repeats=cfg.bern_samples, dim=0)
img_batch = torch.repeat_interleave(img, repeats=cfg.bern_samples, dim=0)

# 1st/99th percentile intensities, used to normalise the logged images.
img_low_q = torch.quantile(img, 0.01)
img_high_q = torch.quantile(img, 0.99)

# The dense-rate scheduler is stateful (`advance()` is called once per training
# step), so one left over from a previous run of this cell would presumably
# resume at the end of its schedule. Rebuild it so every run starts fresh.
dense_scheduler = DenseRateScheduler(
    target=cfg.dense_target,
    start_epoch=int(cfg.dense_start * cfg.steps),
    stop_epoch=int(cfg.dense_end * cfg.steps),
)

# Create mask handler that also contains the learnable scores
mask_handler = get_mask_handler(
    name=cfg.mask_style,
    height=img.shape[-2],
    width=img.shape[-1],
    device=cfg.device,
)

# Initialize optimizer with mask scores
optimizer = Adam([mask_handler.get_scores()], lr=cfg.learning_rate)

# Create constrainer for projection of scores to dense_rate
constrainer = ScoreConstrainer(mask_handler.get_scores())

# Snapshots (reconstruction, sampled mask) collected during training
log_recs = []
log_masks = []

# Log `log_imgs` snapshots spread over training. The interval is clamped to at
# least 1: when steps < log_imgs, `cfg.steps // cfg.log_imgs` is 0 and the
# modulo check would raise ZeroDivisionError. 0 disables logging.
num_imgs = cfg.log_imgs
log_every = max(1, cfg.steps // num_imgs) if num_imgs > 0 else 0

for step in (pbar := tqdm(range(1, cfg.steps + 1))):

    # Get temperature for categorical "softness"
    temperature = get_temperature(step, cfg.steps)

    # Project scores to the current dense rate, then advance the schedule
    dense_rate = dense_scheduler.get_dense_rate()
    constrainer.constrain(dense_rate=dense_rate)
    dense_scheduler.advance()

    # One sampled mask per repeated image
    mask = mask_handler.sample_mask(
        temperature=temperature, num_samples=cfg.bern_samples * img.shape[0]
    )

    # Reconstruct from the masked k-space and take the magnitude
    img_pred = ifft2c(img_k_batch * mask + 0.0)
    img_mag = torch.abs(img_pred)

    # Compute loss between full and undersampled image
    loss = loss_func(img_mag, img_batch)
    # TODO: `seg` (per-pixel weights) is not used yet; weighting requires an
    # unreduced MSE, e.g. (nn.MSELoss(reduction='none')(img_mag, img_batch) * seg).mean()

    # Optimize scores
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 20 == 0:
        pbar.set_description(
            'L: {:.2E} | D: {:.3f}'.format(
                loss.item(), float(torch.mean(mask_handler.get_scores()))
            )
        )

    if log_every and step % log_every == 0:
        # Reuse the magnitude computed above instead of recomputing abs()
        img_rec = min_max_normalize(img_mag[0], img_low_q, img_high_q)

        log_recs.append(img_rec.detach().cpu())
        log_masks.append(mask[0].detach().cpu())

# ------------------------------------------------------------------------------

# Construct training progress grid
os.makedirs(cfg.log_dir, exist_ok=True)

# Append the fully sampled reference as the last column; its mask slot is an
# all-zero placeholder so both rows have the same number of images.
img_org = min_max_normalize(img[0], img_low_q, img_high_q)
log_recs.append(img_org.cpu())
log_masks.append(torch.zeros_like(img[0]).cpu())

# nrow == len(log_recs): first row = reconstructions (+ reference), second row = masks
save_image(
    log_recs + log_masks,
    fp=os.path.join(cfg.log_dir, 'progress.png'),
    nrow=len(log_recs),
)

# Construct heatmap
plot_heatmap(
    mask_handler.get_mask_distribution(),
    save_path=os.path.join(cfg.log_dir, 'heatmap.png'),
)

# Histogram of scores on a dedicated figure: implicit pyplot calls draw on the
# *current* figure, which could be one possibly left open by plot_heatmap.
scores = mask_handler.get_scores().detach().cpu().flatten().numpy()
fig, ax = plt.subplots()
ax.hist(scores, bins=20)
ax.set(title='Mask score distribution', xlabel='score', ylabel='count')
fig.savefig(os.path.join(cfg.log_dir, 'histogram.png'))

L: 2.16E-02 | D: 0.557:   1%|▏         | 37/2500 [00:03<03:35, 11.41it/s]