In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import os

import cv2
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch
import torch.nn.functional as functional
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.utils import save_image

In [3]:
from zse.datamodules.components.brain_datasets import ZStackDataset3D
from zse.models.adain_module import AdaINLitModule2D
from zse.models.components.adain_unet import AdaINUNet
from zse.style_transfer.style_transfer import style_transfer
from zse.utils.data_utils import read_h5

In [4]:
def save_nib(data, path, fname):
    os.makedirs(path, exist_ok=True)
    if not os.path.exists(f"{path}/{fname}.nii"):
        nib_image = nib.Nifti1Image(data.mul(255).byte().squeeze().permute(1, 2, 0).numpy(), np.identity(4))
        nib.nifti1.save(nib_image, f"{path}/{fname}.nii")

In [5]:
def make_border(img, size, color):
    """Turn a 2-D image into a 3-channel image surrounded by a constant-color frame.

    Args:
        img: 2-D array; it is replicated into three identical channels.
        size: frame thickness in pixels, used for all four sides.
        color: per-channel fill value of the frame.

    Returns:
        The framed array produced by ``cv2.copyMakeBorder`` with
        ``BORDER_CONSTANT``.
    """
    three_channel = np.stack((img,) * 3, axis=2)
    return cv2.copyMakeBorder(three_channel, size, size, size, size, cv2.BORDER_CONSTANT, value=color)

In [6]:
def plot(img, title, path=None, **border_kwargs):
    if not len(border_kwargs) == 0:
        img = img.mul(255).byte()
    img = img.squeeze().detach().cpu().numpy()
    if not len(border_kwargs) == 0:
        img = make_border(img, **border_kwargs)
    if path is not None:
        if not len(border_kwargs) == 0:
            plt.imsave(path, img)
        else:
            plt.imsave(path, img, vmin=0, vmax=1, cmap="gray")
    else:
        plt.figure(dpi=100)
        plt.title(title)
        plt.imshow(img, vmin=0, vmax=1, cmap="gray")
        plt.axis("off")
        plt.show()

In [7]:
# Root of the personal data/model storage. Set the ZSE_HOME environment variable to
# point elsewhere instead of editing the notebook; the default is the original location.
home = os.environ.get("ZSE_HOME", "/p/fastdata/bigbrains/personal/crijnen1")
data_root = f"{home}/data"
zse_path = ".."  # repository root, relative to the working directory (presumably the notebook's folder)
model_path = f"{zse_path}/models/brain/adain_unet_3d"
dest = f"{zse_path}/reports/introduction"  # where figures / stylized images are written
torch.hub.set_dir(f"{home}/models")  # torch.hub download cache (used for the pretrained VGG19 weights)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [8]:
# Standard ImageNet channel mean/std expected by torchvision's pretrained VGG19.
norm = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
# Pretrained VGG19 (downloaded into the torch.hub cache) handed to the project's
# AdaINUNet together with the input normalization.
vgg19 = models.vgg19(pretrained=True)
adain_unet = AdaINUNet(vgg19, norm)

# Z-Stack

In [9]:
z = 29
# All blurry test stacks of this depth; batch_size=1 yields one dataset item
# (presumably one .hdf5 z-stack) per batch.
stack_glob = f"{data_root}/bigbrain_1micron/{z}/test/blurry/*.hdf5"
data_test = ZStackDataset3D(stack_glob, transform=transforms.ToTensor())
loader = DataLoader(data_test, batch_size=1, num_workers=32)

In [None]:
# Optional NIfTI export of every blurry test stack. Off by default: as written, the
# loop only derived file names but still loaded the entire test set through `loader`.
EXPORT_NIFTI = False
if EXPORT_NIFTI:
    for batch in loader:
        # ".../<name>.hdf5" -> "<name>"
        stack_name = os.path.splitext(os.path.basename(batch["path"][0]))[0]
        # TODO(review): the earlier commented-out call saved `content`, which this loop
        # never defines; pass the batch entry that holds the z-stack tensor instead.
        # save_nib(<stack tensor from batch>, dest, stack_name)

In [10]:
to_tensor = transforms.ToTensor()
# Tile (HDF5 file stem) used for the single-stack figures in the following cells.
fname = "B21_0228_y25225_x17625"

In [23]:
z = 29
path = f"{data_root}/bigbrain_1micron/{z}/test/blurry/{fname}.hdf5"
# Crop to the top-left 128x128 tile and insert a channel axis at dim 1.
stack = to_tensor(read_h5(path))
content = stack[:, :128, :128].unsqueeze(1)
# Style reference: the middle slice (index z // 2 == 14 for z = 29), broadcast to
# all z positions as a view (expand does not copy).
style = content[[z // 2]].expand(z, -1, -1, -1)

In [24]:
# Blurry input with the channel axis squeezed out; dims presumably (z, y, x), the
# same layout the output cross-sections use (out[:, 63] is xz, out[:, :, 63] is yz).
blurry = content.squeeze()
plot(blurry[0], "xy_bot", f"{dest}/brain/{fname}_xy_bot.png")
plot(blurry[14], "xy_mid", f"{dest}/brain/{fname}_xy_mid.png")
plot(blurry[28], "xy_top", f"{dest}/brain/{fname}_xy_top.png")
# Framed copies of slices s15/s20/s25 (0-based indices 14/19/24).
plot(blurry[14], "s15", f"{dest}/brain/{fname}_s15.png", size=3, color=[0,0,255]) # Blue
plot(blurry[19], "s20", f"{dest}/brain/{fname}_s20.png", size=3, color=[0,255,255])
plot(blurry[24], "s25", f"{dest}/brain/{fname}_s25.png", size=3, color=[255,255,0])
# xz fixes dim 1 (y) and yz fixes dim 2 (x), matching the output cross-sections.
plot(blurry[:, 63, :], "xz", f"{dest}/brain/{fname}_xz.png", size=1, color=[85,255,0]) # Green
plot(blurry[:, :, 63], "yz", f"{dest}/brain/{fname}_yz.png", size=1, color=[255,0,0]) # Red

In [25]:
# Reference ("sharp") slice, reflection-padded by z // 2 so that a window spanning z
# rows (xz) or z columns (yz) can be centred on any row/column, including the edges.
sharp = style.squeeze()
half = z // 2
pad_xz = torch.nn.ReflectionPad2d((0, 0, half, half))  # pads dim -2 (rows)
pad_yz = torch.nn.ReflectionPad2d((half, half, 0, 0))  # pads dim -1 (columns)
style_xz = pad_xz(sharp)
style_yz = pad_yz(sharp)
# In padded coordinates the window centred on original index c is [c, c + win).
# Centres match the output cross-sections: first index, 63 and last index.
win = 2 * half + 1
last_row, last_col = sharp.shape[-2] - 1, sharp.shape[-1] - 1
plot(style_xz[14, 0:win], "xz_bot", f"{dest}/brain/{fname}_xz_bot.png")
plot(style_xz[14, 63:63 + win], "xz_mid", f"{dest}/brain/{fname}_xz_mid.png")
plot(style_xz[14, last_row:last_row + win], "xz_top", f"{dest}/brain/{fname}_xz_top.png")
plot(style_yz[14, :, 0:win], "yz_bot", f"{dest}/brain/{fname}_yz_bot.png")
plot(style_yz[14, :, 63:63 + win], "yz_mid", f"{dest}/brain/{fname}_yz_mid.png")
plot(style_yz[14, :, last_col:last_col + win], "yz_top", f"{dest}/brain/{fname}_yz_top.png")

In [26]:
# Run the trained AdaIN U-Net (checkpoint from the `gram` run) on the z = 29 stack.
ckpt_path = f"{model_path}/gram/lr:0.001-beta:1000000.0_best.ckpt"
module = AdaINLitModule2D.load_from_checkpoint(ckpt_path, net=adain_unet, strict=False)
module = module.to(device)
module.freeze()  # freeze parameters for inference
module.eval()
out = module({"content": content.to(device), "style": style.to(device)})
out = out.squeeze().cpu()
module.cpu();  # move the weights off the GPU again; `;` suppresses the repr

In [27]:
# Restored stack: xy slices (first / 14 / last), xz sections fixing dim 1 and yz
# sections fixing dim 2 (at 0 / 63 / last); the yz sections are transposed.
output_views = (
    ("xy_bot", out[0]),
    ("xy_mid", out[14]),
    ("xy_top", out[-1]),
    ("xz_bot", out[:, 0]),
    ("xz_mid", out[:, 63]),
    ("xz_top", out[:, -1]),
    ("yz_bot", out[:, :, 0].T),
    ("yz_mid", out[:, :, 63].T),
    ("yz_top", out[:, :, -1].T),
)
for view_name, view in output_views:
    plot(view, view_name, f"{dest}/brain/{fname}_out_{view_name}.png")

In [16]:
z = 20
path = f"{data_root}/bigbrain_1micron/{z}/test/blurry/{fname}.hdf5"
# Same 128x128 tile with a channel axis at dim 1, now from the 20-slice stack.
stack = to_tensor(read_h5(path))
content = stack[:, :128, :128].unsqueeze(1)
# NOTE(review): slice 14 is carried over from the z = 29 setup; for z = 20 the
# middle is index 9/10 (the plots below label index 9 "mid") — confirm 14 is intended.
style = content[[14]].expand(z, -1, -1, -1)

In [20]:
blurry = content.squeeze()
# Bottom / middle / top xy slices of the z = 20 stack (0-based indices 0, 9, 19).
for idx, pos in ((0, "bot"), (9, "mid"), (19, "top")):
    plot(blurry[idx], f"xy_{pos}", f"{dest}/brain/{fname}_{z}_{pos}.png")

In [18]:
# Same checkpoint (the `gram` run), applied to the z = 20 stack.
ckpt_path = f"{model_path}/gram/lr:0.001-beta:1000000.0_best.ckpt"
module = AdaINLitModule2D.load_from_checkpoint(ckpt_path, net=adain_unet, strict=False)
module = module.to(device)
module.freeze()  # freeze parameters for inference
module.eval()
out = module({"content": content.to(device), "style": style.to(device)})
out = out.squeeze().cpu()
module.cpu();  # move the weights off the GPU again; `;` suppresses the repr

In [19]:
# Restored z = 20 stack: bottom / middle / top xy slices (indices 0, 9, 19).
for idx, pos in ((0, "bot"), (9, "mid"), (19, "top")):
    plot(out[idx], f"xy_{pos}", f"{dest}/brain/{fname}_{z}_out_{pos}.png")

# Style Transfer

In [9]:
def test_transform(size, crop):
    """Build the preprocessing pipeline for style-transfer input images.

    Args:
        size: target passed to ``transforms.Resize`` (an int resizes the shorter
            edge to this length); 0 disables resizing.
        crop: if true, center-crop to ``size`` after resizing. Ignored when
            ``size`` is 0, since there is then no target size to crop to
            (``CenterCrop(0)`` would produce an empty image).

    Returns:
        A ``transforms.Compose`` that ends with ``ToTensor``.
    """
    steps = []
    if size != 0:
        steps.append(transforms.Resize(size))
        if crop:
            steps.append(transforms.CenterCrop(size))
    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)

In [10]:
# (content image, style image) pairs; the loop below stylizes each content image
# with the style image at the same position.
image_pairs = [
    ("brad_pitt.jpg", "sketch.png"),
    ("avril.jpg", "asheville.jpg"),
    ("cornell.jpg", "la_muse.jpg"),
    ("chicago.jpg", "mondrian_cropped.jpg"),
]
content_imgs = [c_name for c_name, _ in image_pairs]
style_imgs = [s_name for _, s_name in image_pairs]
trans = test_transform(512, False)  # resize with size 512, no center crop

In [18]:
os.makedirs(f"{dest}/style_transfer", exist_ok=True)  # save_image does not create directories
for c, s in zip(content_imgs, style_imgs):
    c_path = f"{zse_path}/data/images/content/{c}"
    s_path = f"{zse_path}/data/images/style/{s}"
    # convert("RGB"): VGG19 and `norm` expect 3 channels; grayscale, palette or RGBA
    # files (e.g. PNGs) would otherwise give 1- or 4-channel tensors. `with` closes
    # the file handles once the pixels have been read.
    with Image.open(c_path) as c_img:
        content = trans(c_img.convert("RGB")).unsqueeze(0)
    with Image.open(s_path) as s_img:
        style = trans(s_img.convert("RGB")).unsqueeze(0)
    out = style_transfer(vgg19.features, content, style, content_layers=['relu4_2'], content_weights=1., style_weights=1e7,
                         normalization=norm, optimizer=optim.LBFGS, loss_fn=functional.l1_loss, device=device)
    # splitext instead of [:-4] so extensions of any length (e.g. ".jpeg") are stripped.
    c_stem, s_stem = os.path.splitext(c)[0], os.path.splitext(s)[0]
    output_name = f"{dest}/style_transfer/{c_stem}_stylized_{s_stem}.jpg"
    save_image(out, output_name)