In [None]:
from init_notebook import *
from src.train.experiment import load_experiment_trainer
from functools import partial

In [None]:
def plot(ds, count=16*16):
    """Render the first batch of a dataset as a single image grid.

    One batch is drawn from ``ds`` through a ``DataLoader`` created with
    ``batch_size=count``. If that batch is a tuple or list, its first entry
    provides the images, and every later entry that is a tensor whose last
    three dimensions equal those of the images is concatenated to them along
    dimension 0. Otherwise the batch itself is used as the images.

    The grid is built with ``nrow=int(sqrt(count))`` and shown via
    ``display``; nothing is returned.
    """
    loader = DataLoader(ds, batch_size=count)
    first_batch = next(iter(loader))

    if not isinstance(first_batch, (tuple, list)):
        images = first_batch
    else:
        images = first_batch[0]
        for extra in first_batch[1:]:
            # Non-tensor entries are never shown.
            if not isinstance(extra, torch.Tensor):
                continue
            # Only entries shaped like the images are added to the grid.
            if extra.shape[-3:] == images.shape[-3:]:
                images = torch.cat([images, extra], dim=0)

    per_row = int(math.sqrt(count))
    grid = make_grid(images, nrow=per_row)
    display(VF.to_pil_image(grid))


## Play with the model

In [None]:
# Build the trainer described by the experiment YAML file, on the CPU.
trainer = load_experiment_trainer(
    "../experiments/img2img/extrusion/extrusion-simple.yml",
    device="cpu",
)
# Load the "snapshot" checkpoint outside of an `assert` statement: assertions
# are stripped under `python -O`, which would silently skip the load itself.
# A falsy return value is treated as a failed load, as the original assert did.
if not trainer.load_checkpoint("snapshot"):
    raise RuntimeError('checkpoint "snapshot" could not be loaded')
model = trainer.model

In [None]:
from PIL import ImageDraw, ImageFont

In [None]:
# Read one validation source image. The context manager closes the underlying
# file handle as soon as the pixel data has been converted to a tensor.
with PIL.Image.open("../datasets/extrusion/validation/source/008.png") as image_file:
    # Keep indices 200..399 along the last two axes (a crop of the picture).
    image_v = VF.to_tensor(image_file)[:, 200:400, 200:400]
VF.to_pil_image(image_v)

In [None]:
# Font files (with base sizes) used to render the sample text.
# NOTE(review): these are absolute, machine-specific paths. Fonts that cannot
# be loaded are skipped below instead of aborting the whole cell.
fonts_and_sizes = [
    ("/home/bergi/.local/share/fonts/LEMONMILK-LIGHTITALIC.OTF", 20),
    ("/home/bergi/.local/share/fonts/LEMONMILK-MEDIUMITALIC.OTF", 20),
    #("/home/bergi/.local/share/fonts/unscii-16-full.ttf", 25),
    ("/usr/share/fonts/truetype/dejavu/DejaVuSerif.ttf", 25),
    ("/usr/share/fonts/truetype/open-sans/OpenSans-ExtraBold.ttf", 25),
    ("/usr/share/fonts/truetype/open-sans/OpenSans-ExtraBold.ttf", 40),
]

fonts = []
for file, size in fonts_and_sizes:
    try:
        # Every base size is scaled by 1.3 before the font is loaded.
        fonts.append(ImageFont.truetype(file, int(size * 1.3)))
    except OSError:
        # ImageFont.truetype raises OSError when the file cannot be read.
        print(f"font not available, skipping: {file}")
if not fonts:
    # Nothing could be loaded: fall back to Pillow's built-in font so the
    # test image still contains text.
    fonts = [ImageFont.load_default()]

# RGB canvas, 200 px wide, with one 40 px row per font plus 24 px of extra height.
image = PIL.Image.new("RGB", (200, 40 * len(fonts) + 24))
draw = ImageDraw.ImageDraw(image)
for i, font in enumerate(fonts):
    # White text, one line per font, each starting 40 px below the previous one.
    draw.text((6, 40 * i + 2), "hello world", font=font, fill=(255, 255, 255))
# From here on `image` is a tensor; the last line converts it back for display.
image = VF.to_tensor(image)
VF.to_pil_image(image)

In [None]:
# Run the model on the rendered text image and on a noise-perturbed copy,
# then show both inputs and both outputs in one grid with two images per row.
model.eval()
with torch.no_grad():
    # A single noise plane (shaped like the first channel) is broadcast over
    # all channels; it scales each pixel's own value by 0.2 before subtraction.
    noise = torch.randn_like(image[:1])
    noisy_image = (image - image * noise * .2).clamp(0, 1)

    # The model is called on a batch of one; the batch axis is removed again
    # and the result is clamped to [0, 1].
    output1 = model(image[None]).squeeze(0).clamp(0, 1)
    output2 = model(noisy_image[None]).squeeze(0).clamp(0, 1)

    tiles = [image, noisy_image, output1, output2]
    grid = make_grid(tiles, nrow=2).clamp(0, 1)
    # NOTE(review): resize(grid, 3) presumably enlarges by a factor of 3 - confirm.
    display(VF.to_pil_image(resize(grid, 3)))

In [None]:
# NOTE(review): this import sits in the middle of the notebook; consider moving
# it to the import cell at the top.
from scripts.extrusion_dataset.render import get_light_map
# `output1` was clamped to [0, 1] when it was computed, so this rescales it to
# [-1, 1]. Presumably the channels are the components of a normal map - confirm.
normals = output1.numpy() * 2 - 1
# Feed the rescaled output to get_light_map; the second argument is presumably
# a light direction - verify against scripts/extrusion_dataset/render.py.
# The result gets a new leading axis that is repeated three times
# (assumes get_light_map returns a 2-D map, giving 3 identical channels - TODO confirm).
light = torch.Tensor(get_light_map(normals, [-1, 2, 3])).unsqueeze(0).repeat(3, 1, 1)
# Disabled experiments, kept for reference:
#light *= image
#light = (light * 255).to(torch.int)
# NOTE(review): resize(..., 3) presumably enlarges the picture by a factor of 3 - confirm.
resize(VF.to_pil_image(light), 3)

In [None]:
# Load a photograph as a tensor. The context manager closes the file handle
# once the pixel data has been converted.
# NOTE(review): absolute, machine-specific path - this cell only runs where
# that file exists.
with PIL.Image.open("/home/bergi/Pictures/eisenach/wartburg.jpg") as photo_file:
    image2 = VF.to_tensor(photo_file)
# NOTE(review): .25 is presumably a scale factor (quarter size) - confirm in resize().
image2 = resize(image2, .25, VF.InterpolationMode.BICUBIC)
# Invert the picture and clamp the result back into [0, 1].
image2 = (1. - image2).clamp(0, 1)
VF.to_pil_image(image2)

In [None]:
# Run the model on the inverted photograph and show input and (clamped) output
# in a grid with one image per row.
with torch.no_grad():
    # The model is called on a batch of one; the batch axis is removed again.
    output = model(image2.unsqueeze(0)).squeeze(0)
    # Only the displayed copy is clamped; `output` itself keeps the raw values.
    comparison = make_grid([image2, output.clamp(0, 1)], nrow=1)
    # NOTE(review): resize(..., 2) presumably doubles the size - confirm.
    display(VF.to_pil_image(resize(comparison, 2)))

In [None]:
# Load the source and target picture of one validation sample and show them
# in one grid with 10 px padding.
images = []
for p in (
    "../datasets/extrusion/validation/source/tubes-01.png",
    "../datasets/extrusion/validation/target/tubes-01.png",
):
    # Close each file handle as soon as its pixels have been converted.
    with PIL.Image.open(p) as image_file:
        # Keep at most the first 256 indices along the last two axes.
        i1 = VF.to_tensor(image_file)[:, :256, :256]
    images.append(i1)
VF.to_pil_image(make_grid(images, padding=10))