In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import os

import torch
from matplotlib import pyplot as plt
from torchvision import models, transforms

In [3]:
from zse.datamodules.leishmania import LeishmaniaDataModule
from zse.style_transfer.style_transfer import style_transfer

In [9]:
# Root of the personal scratch space. Set the ZSE_HOME environment variable to
# run this notebook from another account/machine; the default keeps the
# original location.
home = os.environ.get("ZSE_HOME", "/p/fastdata/bigbrains/personal/crijnen1")
data_root = f"{home}/data"
zse_path = f"{home}/Z-Stack-Enhancement"
fig_path = f"{zse_path}/reports/figures/paper"
exp_path = f"{zse_path}/logs/experiments/runs"
# Output directory for the style-transferred predictions.
dest = f"{data_root}/COMI/Leishmania/predictions/gatys_style"
# Download/cache torch hub (and torchvision pretrained) weights under `home`
# instead of the default torch cache directory.
torch.hub.set_dir(f"{home}/models")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
def dict_to(d: dict, device):
    """Move every tensor inside a (possibly nested) dict to ``device``.

    The dict is updated in place and also returned, so both
    ``dict_to(batch, device)`` and ``batch = dict_to(batch, device)`` work.
    Nested dicts are processed recursively (also in place). Tensors held in
    lists (updated in place) or tuples (rebuilt, since tuples are immutable)
    are moved as well. Every other value -- strings, numbers, ... -- is left
    untouched.

    Args:
        d: Dict whose tensor values should be moved.
        device: Target accepted by ``torch.Tensor.to(device=...)``, e.g. a
            ``torch.device`` or a string such as ``"cuda:0"``.

    Returns:
        The same dict object ``d``.
    """
    for key, value in d.items():
        # Re-assigning existing keys does not change the dict's size, so this
        # is safe while iterating.
        d[key] = _move_to_device(value, device)
    return d


def _move_to_device(value, device):
    """Return ``value`` with any tensors it (recursively) contains moved to ``device``."""
    if isinstance(value, torch.Tensor):
        return value.to(device=device)
    if isinstance(value, dict):
        return dict_to(value, device)
    if isinstance(value, list):
        # Update in place, mirroring how dicts are handled.
        value[:] = [_move_to_device(item, device) for item in value]
        return value
    if isinstance(value, tuple):
        moved = [_move_to_device(item, device) for item in value]
        # namedtuples take their fields positionally; plain tuples take an iterable.
        return type(value)(*moved) if hasattr(value, "_fields") else tuple(moved)
    return value

In [6]:
# Convolutional part of a pretrained VGG19, used as the fixed feature/loss
# network for Gatys-style transfer. Freeze it so that the backward passes of the
# optimiser do not accumulate gradients into the VGG weights; presumably only
# the generated image is optimised by `style_transfer` -- confirm there.
# NOTE: torchvision>=0.13 deprecates `pretrained=True` in favour of
# `weights=models.VGG19_Weights.IMAGENET1K_V1`; kept for older installs.
vgg19 = models.vgg19(pretrained=True).features.eval().requires_grad_(False)
# ImageNet per-channel mean/std that torchvision's pretrained models expect.
norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

In [7]:
# Glob of the blurred Leishmania training images that are fed to style transfer.
path = f"{data_root}/COMI/Leishmania/Leishmania_blurred_train/*.jpg"
# batch_size=1: the prediction loop reads batch['path'][0], i.e. one image per
# batch. imsize=512 is presumably the resize target in pixels -- confirm in
# LeishmaniaDataModule.
data_module = LeishmaniaDataModule(path, imsize=512, batch_size=1)
data_module.setup()
# NOTE(review): the glob points at the *_train folder but the *test* dataloader
# is iterated -- confirm LeishmaniaDataModule serves these files for every
# split, or that this is intentional.
loader = data_module.test_dataloader()

In [10]:
# Gatys style-transfer hyper-parameters.
STYLE_WEIGHT = 1e4
NUM_ITER = 200  # optimiser iterations per image

# plt.imsave does not create missing parent directories.
os.makedirs(dest, exist_ok=True)

for batch in loader:
    fname: str = os.path.basename(batch["path"][0])
    out_path = f"{dest}/{fname}"
    # Check for an existing prediction *before* the expensive optimisation, so a
    # re-run (or a resumed, interrupted run) only processes missing images.
    if os.path.exists(out_path):
        continue

    content: torch.Tensor = batch["content"]
    style: torch.Tensor = batch["style"]
    out = style_transfer(
        vgg19,
        content,
        style,
        style_weights=STYLE_WEIGHT,
        normalization=norm,
        optimizer=torch.optim.LBFGS,
        num_iter=NUM_ITER,
        device=device,
    )
    # Assumes `out` is (1, C, H, W) like the input batch (batch_size=1): drop
    # the batch dim and reorder CHW -> HWC for matplotlib. Clamp because
    # plt.imsave raises on float RGB values outside [0, 1] (vmin/vmax only
    # affect colormapped 2-D data, they do not rescale RGB).
    out = out.detach().squeeze(0).clamp(0, 1).permute(1, 2, 0).cpu().numpy()
    plt.imsave(out_path, out, vmin=0, vmax=1)