In [None]:
import sys
sys.path.append("..")

import random
import math
import time
import json
from io import BytesIO
from pathlib import Path
from collections import OrderedDict
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union

import PIL.Image
import PIL.ImageDraw

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.datasets import *
from src.util.image import *
from src.util import *
from src.algo import *
from src.models.decoder import *
from src.models.transform import *
from src.models.util import *

def resize(img, scale: float, mode: VF.InterpolationMode = VF.InterpolationMode.NEAREST):
    """
    Scale the two trailing (spatial) dimensions of `img` by `scale`, without antialiasing.

    Every target side length is truncated to an int and never drops below 1.
    Mainly used to enlarge small images for display (nearest neighbour by default).
    """
    target_size = [
        max(1, int(side * scale))
        for side in img.shape[-2:]
    ]
    return VF.resize(img, target_size, interpolation=mode, antialias=False)

In [None]:
from experiments.datasets import all_image_patch_dataset
# Patch dataset used for the visual checks below; (3, 64, 64) is presumably the
# (C, H, W) patch shape — see experiments.datasets.all_image_patch_dataset.
ds = all_image_patch_dataset((3, 64, 64))

In [None]:
# Take the first batch of 64 items from `ds` as a fixed sample for the visual checks.
patches = next(iter(DataLoader(ds, batch_size=64)))
patches.shape

In [None]:
# Show the sampled batch as a single image grid.
VF.to_pil_image(make_grid(patches))

In [None]:
class RandomGradientTransform(nn.Module):
    """
    Quantize pixel values with a randomly drawn step size.

    On every call a step ``q`` is drawn uniformly from
    ``[min_quantization, max_quantization]`` (floored at 1e-6), and every value
    is snapped to the nearest multiple of ``q``. This produces banded "gradient" artifacts.

    :param min_quantization: lower bound for the quantization step.
    :param max_quantization: upper bound for the quantization step.
    :param clamp_output: ``True`` clamps the result to the input's own min/max.
        A ``(low, high)`` tuple clamps to that range. A falsy value disables clamping.
    """

    def __init__(
            self,
            min_quantization: float = 0.1,
            max_quantization: float = 0.3,
            clamp_output: Union[bool, Tuple[float, float]] = (0., 1.),
    ):
        super().__init__()
        self.min_quantization = min_quantization
        self.max_quantization = max_quantization
        self.clamp_output = clamp_output

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        step = random.uniform(self.min_quantization, self.max_quantization)
        step = max(0.000001, step)  # avoid division by zero
        quantized = torch.round(image / step) * step

        clamp = self.clamp_output
        if clamp is True:
            # keep the value range of the original image
            return quantized.clamp(image.min(), image.max())
        if clamp:
            return quantized.clamp(*clamp)
        return quantized

# Visual check: the sampled patches after quantization with a random step in [0.2, 0.4]
# (output clamped to the default range (0, 1)).
VF.to_pil_image(make_grid(
    RandomGradientTransform(min_quantization=.2, max_quantization=.4)(patches)
))

In [None]:
# Experiment: max-pool a random image (keeping the indices), convolve the pooled
# map, and scatter the result back to full resolution with MaxUnpool2d.
ks, st = 3, 2
pool = nn.MaxPool2d(kernel_size=ks, stride=st, return_indices=True)
unpool = nn.MaxUnpool2d(kernel_size=ks, stride=st)

input = torch.rand(1, 3, 32, 32)
pooled, idx = pool(input)
print("pooled:", pooled.shape)

pooled_conv = nn.Conv2d(3, 3, 3, padding=1, padding_mode="reflect")(pooled)
# top row: pooled maps, bottom row: convolved maps (each normalized separately)
display(VF.to_pil_image(resize(make_grid([*pooled, *pooled_conv], normalize=True, scale_each=True), 5)))

output = unpool(pooled_conv, idx, output_size=input.shape[-2:])
print(output.shape)
VF.to_pil_image(resize(make_grid([*input, *output]), 5))

In [None]:
class RBDNConv(nn.Module):
    """
    A single 2D convolution (or transposed convolution), followed by an optional
    activation and an optional BatchNorm2d.

    Note that batch normalization is applied *after* the activation.

    In `forward`, `output_size` (H, W) zero-pads the convolution result at the
    right/bottom to that size. Negative differences crop instead, since `F.pad`
    accepts negative padding.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            stride: int,
            padding: int,
            activation: Union[None, str, Callable],
            batch_norm: bool,
            transposed: bool = False,
    ):
        super().__init__()

        conv_class = nn.ConvTranspose2d if transposed else nn.Conv2d
        self.conv = conv_class(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        self.act = activation_to_module(activation)
        # `bn` is only registered when requested; forward checks for its presence
        if batch_norm:
            self.bn = nn.BatchNorm2d(out_channels)

    def forward(
            self,
            x: torch.Tensor,
            output_size: Union[None, Tuple[int, int]] = None,
    ) -> torch.Tensor:
        y = self.conv(x)

        if output_size is not None and tuple(y.shape[-2:]) != output_size:
            missing_h = output_size[-2] - y.shape[-2]
            missing_w = output_size[-1] - y.shape[-1]
            y = F.pad(y, (0, missing_w, 0, missing_h))

        if self.act:
            y = self.act(y)

        norm = getattr(self, "bn", None)
        if norm is not None:
            y = norm(y)

        return y


class RBDNBranch(nn.Module):
    """
    One branch of an RBDN-style network, optionally wrapping a nested sub-branch.

    Data flow of `forward`:
      1. `conv_in`: in_channels -> hidden_channels.
      2. If a `sub_branch` is given, it is applied to the step-1 features, and its
         output is concatenated to them along the channel dim (-3).
      3. Max-pool, keeping the indices.
      4. `num_hidden_layers` convolutions that keep the channel count. They must also
         keep the spatial size, otherwise the unpool indices do not match.
      5. Max-unpool back to the pre-pool size using the stored indices.
      6. Transposed `conv_out` to out_channels. It is padded/cropped to the input's H, W.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            hidden_channels: int,
            num_hidden_layers: int = 1,

            conv_kernel_size: int = 3,
            conv_stride: int = 1,
            conv_padding: int = 0,
            pool_kernel_size: int = 3,
            pool_stride: int = 1,

            hidden_kernel_size: int = 3,
            hidden_stride: int = 1,
            hidden_padding: int = 1,

            batch_norm: bool = True,
            batch_norm_last_layer: bool = False,
            activation: Union[None, str, Callable] = "relu",
            activation_last_layer: Union[None, str, Callable] = "relu",
            sub_branch: Union[None, "RBDNBranch"] = None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # hidden layers see conv_in's features plus the sub-branch's output channels
        actual_hidden_channels = hidden_channels
        if sub_branch is not None:
            actual_hidden_channels += sub_branch.out_channels

        self.conv_in = RBDNConv(
            in_channels, hidden_channels, conv_kernel_size, stride=conv_stride, padding=conv_padding,
            activation=activation, batch_norm=batch_norm
        )

        self.sub_branch = sub_branch

        self.pool = nn.MaxPool2d(kernel_size=pool_kernel_size, stride=pool_stride, return_indices=True)

        self.hidden = nn.Sequential()
        for i in range(num_hidden_layers):
            self.hidden.add_module(f"conv_{i+1}", RBDNConv(
                actual_hidden_channels, actual_hidden_channels, hidden_kernel_size, stride=hidden_stride, padding=hidden_padding,
                activation=activation, batch_norm=batch_norm
            ))

        self.unpool = nn.MaxUnpool2d(kernel_size=pool_kernel_size, stride=pool_stride)

        # transposed counterpart of conv_in (same kernel/stride/padding)
        self.conv_out = RBDNConv(
            actual_hidden_channels, out_channels, conv_kernel_size, stride=conv_stride, padding=conv_padding,
            activation=activation_last_layer, batch_norm=batch_norm_last_layer,
            transposed=True,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        x = self.conv_in(input)

        if self.sub_branch is not None:
            sub_x = self.sub_branch(x)
            x = torch.concat([x, sub_x], dim=-3)

        unpooled_size = x.shape[-2:]
        x, indices = self.pool(x)

        x = self.hidden(x)

        x = self.unpool(x, indices, output_size=unpooled_size)
        # match the spatial size of the branch input exactly
        x = self.conv_out(x, output_size=input.shape[-2:])

        return x

    @torch.no_grad()
    def get_inner_shape(self, shape: Tuple[int, int, int]) -> dict:
        """
        Return the (C, H, W) shape right after this branch's max-pool for an input
        of `shape` (C, H, W). The sub-branch's result is nested under ``"branch"``,
        or ``None`` if there is no sub-branch.

        A dummy zero batch is pushed through the layers. All submodules are switched
        to eval mode for that, so BatchNorm running statistics are not updated by the
        dummy data. Each module's original training flag is restored afterwards.
        """
        saved_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            # create the dummy input where the weights live
            reference = next(self.parameters(), None)
            if reference is None:
                dummy = torch.zeros(1, *shape)
            else:
                dummy = torch.zeros(1, *shape, device=reference.device, dtype=reference.dtype)

            x = self.conv_in(dummy)

            if self.sub_branch is not None:
                branch_shape = self.sub_branch.get_inner_shape(x.shape[-3:])
                sub_x = self.sub_branch(x)
                x = torch.concat([x, sub_x], dim=-3)
            else:
                branch_shape = None

            x, _ = self.pool(x)

            return {"shape": x.shape[-3:], "branch": branch_shape}
        finally:
            for module, was_training in saved_modes:
                module.training = was_training
            
        
class RBDN(nn.Module):
    """
    Image-to-image network built from `num_branches` nested `RBDNBranch` modules.

    The innermost branch is created first, and each following branch wraps the
    previous one as its `sub_branch`. All inner branches map
    hidden_channels -> hidden_channels and use the `branch_*` conv/pool settings.
    The outermost branch maps in_channels -> out_channels with the top-level
    conv/pool settings.

    NOTE: `num_hidden_layers`, `batch_norm_last_layer` and the last-layer activation
    are only applied to the outermost branch; inner branches use the `RBDNBranch`
    defaults for those.

    :param activaton_last_layer: (sic) activation of the outermost branch's last layer.
        Kept under this spelling for backwards compatibility.
    :param activation_last_layer: correctly spelled alias. If passed, it takes
        precedence over `activaton_last_layer`.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            hidden_channels: int,
            num_branches: int,
            num_hidden_layers: int = 1,

            conv_kernel_size: int = 3,
            conv_stride: int = 1,
            conv_padding: int = 0,
            pool_kernel_size: int = 3,
            pool_stride: int = 1,

            branch_conv_kernel_size: int = 3,
            branch_conv_stride: int = 2,
            branch_conv_padding: int = 1,
            branch_pool_kernel_size: int = 3,
            branch_pool_stride: int = 1,

            hidden_kernel_size: int = 3,
            hidden_stride: int = 1,
            hidden_padding: int = 1,

            batch_norm: bool = True,
            batch_norm_last_layer: bool = False,
            activation: Union[None, str, Callable] = "relu",
            activaton_last_layer: Union[None, str, Callable] = "sigmoid",
            activation_last_layer: Union[None, str, Callable, type(Ellipsis)] = ...,
    ):
        super().__init__()

        # `...` means "not given"; None is a valid activation value, so it cannot be the sentinel
        if activation_last_layer is not ...:
            activaton_last_layer = activation_last_layer

        # build from the inside out: every new branch wraps the previous one
        branches = None
        for _ in range(num_branches):
            branches = RBDNBranch(
                hidden_channels, hidden_channels, hidden_channels,
                conv_kernel_size=branch_conv_kernel_size, conv_stride=branch_conv_stride, conv_padding=branch_conv_padding,
                pool_kernel_size=branch_pool_kernel_size, pool_stride=branch_pool_stride,
                hidden_kernel_size=hidden_kernel_size, hidden_stride=hidden_stride, hidden_padding=hidden_padding,
                activation=activation, batch_norm=batch_norm,
                sub_branch=branches,
            )

        self.branches = RBDNBranch(
            in_channels, out_channels, hidden_channels,
            conv_kernel_size=conv_kernel_size, conv_stride=conv_stride, conv_padding=conv_padding,
            pool_kernel_size=pool_kernel_size, pool_stride=pool_stride,
            hidden_kernel_size=hidden_kernel_size, hidden_stride=hidden_stride, hidden_padding=hidden_padding,
            activation=activation,
            batch_norm=batch_norm,
            num_hidden_layers=num_hidden_layers,
            sub_branch=branches,
            batch_norm_last_layer=batch_norm_last_layer,
            activation_last_layer=activaton_last_layer,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.branches(x)

    def get_inner_shape(self, shape: Tuple[int, int, int]) -> dict:
        """Nested dict of post-pool shapes of all branches; see `RBDNBranch.get_inner_shape`."""
        return self.branches.get_inner_shape(shape)


with torch.no_grad():
    # outer 3 -> 3 branch (64 hidden channels, 5 hidden layers) wrapping two
    # nested stride-2 branches
    model = RBDN(
        3, 3, 64,
        num_branches=2,
        num_hidden_layers=5,
        conv_kernel_size=3,
        conv_padding=1,
        branch_conv_kernel_size=3,
        branch_conv_stride=2,
        branch_conv_padding=1,
    ).eval()
    print(f"params: {num_module_parameters(model):,}")

    # push random noise through the untrained model and show inputs next to outputs
    input = torch.randn(4, 3, 64, 64)
    output = model(input)
    print(output.shape)
    display(VF.to_pil_image(make_grid([*input, *output], padding=0)))

    print(json.dumps(model.get_inner_shape(input.shape[-3:]), indent=2))

    print(model)

In [None]:
class ConvDenoiser(nn.Module):
    """
    Convolutional encoder/decoder with additive skip connections, used as a patch denoiser.

    Encoder: unpadded ``Conv2d`` layers mapping ``shape[0] -> *channels -> shape[0]``.
    Decoder: mirrored ``ConvTranspose2d`` layers. Each gets an ``output_padding``
    chosen so that it reproduces exactly the spatial size of the matching encoder
    input for an input of ``shape``; other input sizes may not line up. Before every
    decoder layer except the first, the matching encoder output is added to the state.

    BatchNorm2d (if enabled) is applied to the *input* of every (transposed) conv.
    The activation follows every conv except the last decoder layer.

    :param shape: (C, H, W) of the expected input.
    :param channels: hidden channel counts.
    :param kernel_size: int for all layers, or one entry per layer (``len(channels) + 1`` entries).
    :param stride: int for all layers, or one entry per layer (``len(channels) + 1`` entries).
    :param batch_norm: add BatchNorm2d layers.
    :param activation: activation spec resolved by ``activation_to_module``; falsy disables it.
    :raises ValueError: if a per-layer ``kernel_size``/``stride`` has the wrong length.
    """

    def __init__(
            self,
            shape: Tuple[int, int, int],
            channels: Iterable[int],
            kernel_size: Union[int, Iterable[int]] = 3,
            stride: Union[int, Iterable[int]] = 1,
            batch_norm: bool = True,
            activation: Union[None, str, Callable] = "gelu",
    ):
        super().__init__()
        self.shape = shape

        self._channels = [shape[0], *channels, shape[0]]
        num_layers = len(self._channels) - 1
        kernel_sizes = self._per_layer_values(kernel_size, num_layers, "kernel_size")
        strides = self._per_layer_values(stride, num_layers, "stride")

        self.encoder = nn.ModuleDict()
        decoder_paddings = []
        with torch.no_grad():
            # dry-run the encoder on a dummy input to learn every layer's output size
            tmp_state = torch.zeros(1, *shape)

            for i, (ch, ch_next, ks, st) in enumerate(zip(self._channels, self._channels[1:], kernel_sizes, strides)):
                if batch_norm:
                    self.encoder[f"layer{i+1}_bn"] = nn.BatchNorm2d(ch)
                self.encoder[f"layer{i+1}_conv"] = nn.Conv2d(ch, ch_next, ks, stride=st)
                if activation:
                    self.encoder[f"layer{i+1}_act"] = activation_to_module(activation)

                in_shape = tmp_state.shape[-2:]
                tmp_state = self.encoder[f"layer{i+1}_conv"](tmp_state)
                # An unpadded ConvTranspose2d yields (out - 1) * stride + kernel_size.
                # The remainder dropped by the encoder's strided conv is restored via
                # `output_padding`, which is therefore always < stride.
                decoder_paddings.append(tuple(
                    size_in - ((size_out - 1) * s + k)
                    for size_in, size_out, k, s in zip(
                        in_shape, tmp_state.shape[-2:], self._pair(ks), self._pair(st),
                    )
                ))

        channels = list(reversed(self._channels))
        kernel_sizes = list(reversed(kernel_sizes))
        strides = list(reversed(strides))
        self.decoder = nn.ModuleDict()
        for i, (ch, ch_next, ks, st, pad) in enumerate(
                zip(channels, channels[1:], kernel_sizes, strides, list(reversed(decoder_paddings)))
        ):
            if batch_norm:
                self.decoder[f"layer{i+1}_bn"] = nn.BatchNorm2d(ch)

            self.decoder[f"layer{i+1}_conv"] = nn.ConvTranspose2d(
                ch, ch_next, ks, stride=st, output_padding=pad,
            )
            # no activation after the final decoder layer
            if activation and i < len(channels) - 2:
                self.decoder[f"layer{i+1}_act"] = activation_to_module(activation)

    @staticmethod
    def _per_layer_values(value, num_layers: int, name: str) -> list:
        """Expand an int to `num_layers` copies, or validate the length of a per-layer iterable."""
        if isinstance(value, int):
            return [value] * num_layers
        values = list(value)
        if len(values) != num_layers:
            raise ValueError(f"Expected {name} of length {num_layers}, got {len(values)}")
        return values

    @staticmethod
    def _pair(value) -> tuple:
        """(value, value) for an int, otherwise the value as a tuple (per spatial dim)."""
        return (value, value) if isinstance(value, int) else tuple(value)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        state_history = []
        state = x
        for i in range(len(self._channels) - 1):
            if f"layer{i+1}_bn" in self.encoder:
                state = self.encoder[f"layer{i+1}_bn"](state)
            state = self.encoder[f"layer{i+1}_conv"](state)
            if f"layer{i+1}_act" in self.encoder:
                state = self.encoder[f"layer{i+1}_act"](state)
            state_history.append(state)

        for i in range(len(self._channels) - 1):
            if i > 0:
                # skip connection from the encoder output with the same shape
                state = state + state_history[-(i+1)]
            if f"layer{i+1}_bn" in self.decoder:
                state = self.decoder[f"layer{i+1}_bn"](state)
            state = self.decoder[f"layer{i+1}_conv"](state)
            if f"layer{i+1}_act" in self.decoder:
                state = self.decoder[f"layer{i+1}_act"](state)

        return state

# Untrained 6-layer ConvDenoiser on 64x64 RGB patches: check the parameter count
# and that the output has the input's shape.
model = ConvDenoiser(
    (3, 64, 64),
    [64, 128, 128, 128, 64],
    stride=[1, 1, 2, 2, 2, 1],
).eval()
print(f"params: {num_module_parameters(model):,}")

input = torch.randn(1, *model.shape)
output = model(input)
print(output.shape)
# input and output side by side
display(VF.to_pil_image(make_grid(torch.cat([input, output]))))

print(model)

In [None]:
from src.util.params import *

def map_image_patches(
        image: torch.Tensor,
        function: Callable[[torch.Tensor], torch.Tensor],
        patch_size: Union[int, Tuple[int, int]],
        overlap: Union[int, Tuple[int, int]] = 0,
        cut_away: Union[int, Tuple[int, int]] = 0,
        batch_size: int = 64,
        auto_pad: bool = True,
        window: Union[bool, torch.Tensor, Callable] = False,
        verbose: bool = False,
) -> torch.Tensor:
    """
    Pass an image patch-wise through `function` and return the stitched result.

    The image is cut into patches of `patch_size`, stepping by `patch_size - overlap`.
    `function` is called on batches of (up to) shape [batch_size, C, *patch_size], and
    the returned patches are accumulated back into place. Where patches overlap, the
    outputs are averaged, weighted by `window` if one is given.

    :param image: tensor of shape [C, H, W].
    :param function: maps a batch of patches to a batch of the same shape.
    :param patch_size: patch (height, width), or one int for both.
    :param overlap: overlap of neighbouring patches in pixels (int or (h, w)).
    :param cut_away: pixels discarded from each border of every processed patch before
        stitching (int or (h, w)). The grid overlap is widened by ``2 * cut_away``,
        so the kept patch centers still overlap by `overlap`.
    :param batch_size: max number of patches per `function` call.
    :param auto_pad: zero-pad right/bottom so the patch grid covers the whole image.
    :param window: ``False`` for uniform weights. ``True`` uses the default window of
        `get_image_window`. A callable is passed as `window_function` to
        `get_image_window`. A tensor of shape `patch_size` is used directly.
    :param verbose: print grid/padding diagnostics, display the accumulated weight map,
        and pass `verbose` on to `iter_image_patches`.
    :return: tensor with the same shape as `image`. Pixels not covered by any patch stay 0.
    :raises ValueError: if `image` is not 3-dimensional, if `overlap`/`cut_away` do not
        fit into `patch_size`, or if a `window` tensor has the wrong shape.
    """
    if image.ndim != 3:
        raise ValueError(f"Expected image.ndim = 3, got {image.shape}")

    patch_size = param_make_tuple(patch_size, 2, "patch_size")
    overlap = param_make_tuple(overlap, 2, "overlap")
    cut_away = param_make_tuple(cut_away, 2, "cut_away")

    for d in (-1, -2):
        if cut_away[d]:
            if cut_away[d] >= patch_size[d] // 2:
                raise ValueError(f"`cut_away` must be smaller than half the patch_size {patch_size}, got {cut_away}")
            if cut_away[d] * 2 + overlap[d] >= patch_size[d]:
                raise ValueError(
                    f"2 * `cut_away` + `overlap` must be smaller than the patch_size {patch_size}"
                    f", got cut_away={cut_away}, overlap={overlap}"
                )

    for d in (-1, -2):
        if overlap[d] >= patch_size[d]:
            raise ValueError(f"`overlap` must be smaller than the patch_size {patch_size}, got {overlap}")

    # indices into `padding`, which follows the F.pad order for the last two dims
    LEFT, RIGHT, TOP, BOTTOM = range(4)
    padding = [0, 0, 0, 0]

    is_cut_away = bool(any(cut_away))
    if is_cut_away:
        # Discarded borders must not create gaps: widen the overlap by both borders,
        # and shift the grid so the image's first pixels fall into the kept center.
        overlap = tuple(o + c * 2 for o, c in zip(overlap, cut_away))
        padding[LEFT] = cut_away[-1]
        padding[TOP] = cut_away[-2]

    stride = [patch_size[0] - overlap[0], patch_size[1] - overlap[1]]

    if isinstance(window, torch.Tensor):
        if window.shape != patch_size:
            raise ValueError(
                f"`window` must match patch_size {patch_size}, got {window.shape}"
            )
    elif window is True:
        window = get_image_window(shape=patch_size)
    elif callable(window):
        window = get_image_window(shape=patch_size, window_function=window)
    else:
        window = None

    if window is not None:
        window = window.to(device=image.device, dtype=image.dtype)
        if is_cut_away:
            window = window[cut_away[-2]: window.shape[-2] - cut_away[-2], cut_away[-1]: window.shape[-1] - cut_away[-1]]

    image_shape = (image.shape[-2] + padding[TOP], image.shape[-1] + padding[LEFT])

    # number of patches per dim (height, width); only computed with auto_pad
    grid_sizes = [None, None]
    if auto_pad:
        for d, pad_pos in ((-1, RIGHT), (-2, BOTTOM)):
            grid_size = max(1, int(math.ceil(image_shape[d] / stride[d])))
            # extent covered by `grid_size` patches: (grid_size - 1) * stride + patch_size
            needed_size = grid_size * stride[d] + overlap[d]
            while needed_size < image_shape[d]:
                if verbose:
                    print(f"  INCREASED dim={d} needed={needed_size} image={image_shape[d]}")
                grid_size += 1
                needed_size = grid_size * stride[d] + overlap[d]
            grid_sizes[d] = grid_size

            if needed_size > image_shape[d]:
                padding[pad_pos] = needed_size - image_shape[d]
                if verbose:
                    print(f"  needed > image, added padding {padding[pad_pos]}")

    if any(padding):
        image = F.pad(image, padding)

    if verbose:
        print(f"stride={stride} grid={grid_sizes} image={image.shape[-2:]} pad={padding} overlap={overlap}")

    output = torch.zeros_like(image)
    output_sum = torch.zeros_like(image[0])

    for patch_batch, pos_batch in iter_image_patches(
            image=image,
            shape=patch_size,
            stride=stride,
            batch_size=batch_size,
            with_pos=True,
            verbose=verbose,
    ):
        patch_batch = function(patch_batch)
        for patch, pos in zip(patch_batch, pos_batch):
            ps = patch_size
            if is_cut_away:
                patch = patch[..., cut_away[-2]: patch.shape[-2] - cut_away[-2], cut_away[-1]: patch.shape[-1] - cut_away[-1]]
                pos = (pos[0] + cut_away[0], pos[1] + cut_away[1])
                ps = patch.shape[-2:]

            weight = 1.
            if window is not None:
                # out-of-place: never modify the tensor returned by `function`
                patch = patch * window
                weight = window

            y, x = pos[-2], pos[-1]
            output_sum[y: y + ps[-2], x: x + ps[-1]] += weight
            output[:, y: y + ps[-2], x: x + ps[-1]] += patch

    if verbose:
        # visualize how much weight every pixel received
        weight_map = output_sum.unsqueeze(0) / output_sum.max().clamp_min(1e-12)
        display(VF.to_pil_image(resize(weight_map, 4)))

    # normalize by the accumulated weights; uncovered pixels stay 0
    mask = output_sum > 0
    output[:, mask] /= output_sum[mask].unsqueeze(0)

    if any(padding):
        output = output[
            ...,
            padding[TOP]: output.shape[-2] - padding[BOTTOM],
            padding[LEFT]: output.shape[-1] - padding[RIGHT],
        ]

    return output

# Sanity check of the patch stitching: a constant 0.3 image is inverted patch-wise
# (32x32 patches, overlap 1, 5 border pixels cut from each patch). Every covered
# pixel is therefore expected to come out as 0.7.
#
# Scratch sketches of the patch layout ('#' kept, '.'/'-' overlap / cut-away):
#  0123456789
# .##.
#   .##.
#     .##.
#       .##.
#
#  0         1         2         3         4         5         6   4     7 2       8      |
#  012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789
#  ---...........................---                               |
#                    ---...........................---             |
#                                      ---...........................---
#                                                        ---...........................---
image = torch.ones(1, 64, 64) * .3
output = map_image_patches(
    image,
    function=lambda x: 1. - x,
    patch_size=(32, 32),
    overlap=1,
    cut_away=5,
)
print(output.mean())
VF.to_pil_image(resize(make_grid([image, output, output]), 4))


In [None]:
# scratch: ratio 65 / 64
1 + 1 / 64

In [None]:
# Scratch arithmetic: print x + 32 - 2 for x = 0, 17, 34, 51 (prints 30, 47, 64, 81).
# Presumably the end of a 32-pixel patch minus a 2-pixel border at a step of 17 — TODO confirm.
x = 0
for i in range(4):
    print(x + 32 - 2)
    x += 17
    #(80-2-10) / 17

In [None]:
model = ConvDenoiser(
    (3, 64, 64), [64, 128, 256],
    stride=[1, 2, 2, 1],
    kernel_size=3,
).eval()
# map_location="cpu" lets checkpoints saved from a GPU run load without CUDA.
# NOTE: depending on the torch version / `weights_only`, torch.load may unpickle
# arbitrary objects — only load trusted checkpoints.
state = torch.load(
    "../checkpoints/denoise/conv02-ks-3_chan-64,128,256_stride-1,2,2,1/best.pt",
    map_location="cpu",
)
print(f'inputs: {state["num_input_steps"]:,}')
model.load_state_dict(state["state_dict"])
print(f"params: {num_module_parameters(model):,}")

In [None]:
model = ConvDenoiser(
    (3, 64, 64),
    [64, 128, 128, 128, 64],
    stride=[1, 1, 2, 2, 2, 1],
    kernel_size=3,
).eval()
# map_location="cpu" lets checkpoints saved from a GPU run load without CUDA.
# NOTE: depending on the torch version / `weights_only`, torch.load may unpickle
# arbitrary objects — only load trusted checkpoints.
state = torch.load(
    "../checkpoints/denoise/conv04-ks-3_chan-64,128,128,128,64_stride-1,1,2,2,2,1/snapshot.pt",
    map_location="cpu",
)
print(f'inputs: {state["num_input_steps"]:,}')
model.load_state_dict(state["state_dict"])
print(f"params: {num_module_parameters(model):,}")

In [None]:
from src.models.img2img import RBDN

model = RBDN(
    3, 3, 64,
    num_branches=2,
    num_hidden_layers=1,

    conv_kernel_size=9,
    conv_padding=1,

    branch_conv_kernel_size=9,
    branch_conv_padding=1,
    branch_conv_stride=2,
).eval()
# map_location="cpu" lets checkpoints saved from a GPU run load without CUDA.
# NOTE: depending on the torch version / `weights_only`, torch.load may unpickle
# arbitrary objects — only load trusted checkpoints.
state = torch.load(
    "../checkpoints/denoise/rbdn-noise-strong-ks-9_chans-64_layers-1_blayers-1_nbranch-2_bstride-2_pad-1/snapshot.pt",
    map_location="cpu",
)
print(f'inputs: {state["num_input_steps"]:,}')
model.load_state_dict(state["state_dict"])
print(f"params: {num_module_parameters(model):,}")

In [None]:
# Test image for the denoising experiments below. It is converted to RGB so it always
# has the 3 channels the models (and the channel-repeated noise) expect.
# The context manager closes the file once the pixels are copied into the tensor.
with PIL.Image.open(
    #"/home/bergi/Pictures/there_is_no_threat.jpeg"
    "/home/bergi/Pictures/__diverse/gordon_brown.jpg",
) as pil_image:
    image = VF.to_tensor(pil_image.convert("RGB"))

VF.to_pil_image(image)

In [None]:
with torch.no_grad():
    # Corrupt a copy of the whole image with gaussian noise (std 0.3) that is
    # shared by all three channels, then run the model on it.
    part = image + 0
    noise = torch.randn_like(part)[:1].repeat(3, 1, 1)
    part = torch.clamp(part + 0.3 * noise, 0, 1)

    output = model(part[None])[0]
    # noisy input next to the model output
    display(VF.to_pil_image(make_grid([part, output])))

In [None]:
@torch.no_grad()
def denoise(
        image: torch.Tensor,
        denoiser: nn.Module,
        batch_size: int = 64,
        overlap: Union[int, Tuple[int, int]] = 0,
        auto_pad: bool = True,
        verbose: bool = True,
        patch_size: Optional[Union[int, Tuple[int, int]]] = None,
):
    """
    Denoise `image` ([C, H, W]) patch-wise with `denoiser` via `map_image_patches`.

    Each patch output is clamped to [0, 1]. When `overlap` is non-zero in any
    dimension, the patches are blended with the default window of `map_image_patches`.
    The denoiser runs in eval mode, and the original training flags of all its
    submodules are restored afterwards.

    :param image: image tensor [C, H, W].
    :param denoiser: module applied to batches of patches.
    :param batch_size: patches per forward pass.
    :param overlap: patch overlap in pixels (int or (h, w)).
    :param auto_pad: forwarded to `map_image_patches`.
    :param verbose: forwarded to `map_image_patches`.
    :param patch_size: patch size to use. Defaults to the last two entries of
        ``denoiser.shape`` (as set by `ConvDenoiser`).
    :return: the denoised image tensor.
    :raises AttributeError: if `patch_size` is None and `denoiser` has no `shape`.
    """
    if patch_size is None:
        if not hasattr(denoiser, "shape"):
            raise AttributeError(
                f"{type(denoiser).__name__} has no `shape` attribute, pass `patch_size` explicitly"
            )
        patch_size = tuple(denoiser.shape[-2:])

    # `bool((0, 0))` is True, so tuple overlaps are checked element-wise
    use_window = any(overlap) if isinstance(overlap, (tuple, list)) else bool(overlap)

    # remember every submodule's mode so the caller's model is left as it was
    training_modes = [(module, module.training) for module in denoiser.modules()]
    denoiser.eval()
    try:
        return map_image_patches(
            image=image,
            function=lambda x: denoiser(x).clamp(0, 1),
            patch_size=patch_size,
            batch_size=batch_size,
            overlap=overlap,
            auto_pad=auto_pad,
            window=use_window,
            verbose=verbose,
        )
    finally:
        for module, was_training in training_modes:
            module.training = was_training

# Denoise the test image with light extra noise (std 0.1), then apply three more
# denoising passes with a slightly increasing overlap (12, 13, 14).
image2 = image
denoised = denoise(image2 + 0.1 * torch.randn_like(image2), model, overlap=16)
for i in range(3):
    denoised = denoise(denoised, model, overlap=12 + i)

display(VF.to_pil_image(denoised))
# signed difference between the clean input and the final result
display(VF.to_pil_image(signed_to_image(image2 - denoised)))

In [None]:
@torch.no_grad()
def denoise_image(
        image: Union[torch.Tensor, str, Path],
        overlap: Union[int, Tuple[int, int]] = 0,
        denoiser: Optional[nn.Module] = None,
):
    """
    Denoise an image tensor or image file, and display [input, denoised, signed difference]
    as one grid.

    :param image: [C, H, W] tensor, or path to an image file (``~`` is expanded).
        Files are converted to RGB, so they always have 3 channels.
    :param overlap: patch overlap passed to `denoise`.
    :param denoiser: module to use. Defaults to the notebook-global `model`, which was
        the only option before this parameter existed.
    """
    if denoiser is None:
        denoiser = model

    if not isinstance(image, torch.Tensor):
        with PIL.Image.open(Path(image).expanduser()) as pil_image:
            image = VF.to_tensor(pil_image.convert("RGB"))

    denoised = denoise(image, denoiser, overlap=overlap)

    grid = make_grid([image, denoised, signed_to_image(image - denoised)])
    display(VF.to_pil_image(grid))
# Other test images used before (under ~/Pictures/clipig2/):
#   cthulhu-fractal-01.png, cthulhu-fractal-01-many-steps.png, fractal-industrial-pipes.png,
#   fractal-escher-2.png, fractal-escher.png, pixelart-fish.png, monkey-island.png,
#   rocky-surface.png, waterfall-06-c.png
denoise_image("/home/bergi/Pictures/clipig2/maze-of-pipes-01.png", overlap=3)

In [None]:
# signed residual of the denoising demo (clean image minus denoised result)
VF.to_pil_image(signed_to_image(torch.sub(image, denoised)))

In [None]:
with torch.no_grad():
    # a 3x3 conv with padding=1 keeps H and W, however often it is applied
    conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)
    x = torch.zeros(1, 32, 32)
    for i in range(14):
        x = conv(x)
        print("  ", x.shape)
    print(x.shape)