In [None]:
from init_notebook import *
from src.train.experiment import load_experiment_trainer
from functools import partial

In [None]:
from experiments.datasets import *

In [None]:
def plot(ds, count=16*16):
    """Display one batch of ``ds`` as an image grid.

    A single batch of ``count`` items is pulled through a ``DataLoader``. If the
    batch is a tuple/list (e.g. ``(source, target)``), its first element is the
    base image tensor and every further element that is a tensor with the same
    trailing three dimensions is appended along dim 0, so paired images show up
    in the same grid. Each grid row holds ``int(sqrt(count))`` images.
    """
    first_batch = next(iter(DataLoader(ds, batch_size=count)))

    if not isinstance(first_batch, (tuple, list)):
        grid_images = first_batch
    else:
        grid_images, *extras = first_batch
        for extra in extras:
            is_compatible = (
                isinstance(extra, torch.Tensor)
                and extra.shape[-3:] == grid_images.shape[-3:]
            )
            if is_compatible:
                grid_images = torch.cat([grid_images, extra], dim=0)

    per_row = int(math.sqrt(count))
    display(VF.to_pil_image(make_grid(grid_images, nrow=per_row)))


In [None]:
from src.datasets.base_dataset import BaseDataset
from torchvision.datasets.folder import is_image_file

class ImageSourceTargetDataset(BaseDataset):
    """Paired ``(source, target)`` image tensors loaded from two sibling folders.

    Every file in ``path / source_subpath`` and ``path / target_subpath`` that
    torchvision's ``is_image_file`` accepts is loaded eagerly with
    ``VF.to_tensor``. Both folders must contain the same set of image filenames.
    Pairs are matched by filename and indexed in sorted-path order.

    :param path: root directory containing the two sub-folders.
    :param source_subpath: name of the source sub-folder.
    :param target_subpath: name of the target sub-folder.
    :raises RuntimeError: if the two folders do not hold the same image filenames.
    """
    def __init__(
            self,
            path: Union[str, Path],
            source_subpath: str = "source",
            target_subpath: str = "target",
    ):
        path = Path(path)
        self._source_path = path / source_subpath
        self._target_path = path / target_subpath
        self._source_images: Dict[str, torch.Tensor] = self._load_image_folder(self._source_path)
        self._target_images: Dict[str, torch.Tensor] = self._load_image_folder(self._target_path)

        source_names = set(self._source_images)
        target_names = set(self._target_images)
        if source_names != target_names:
            raise RuntimeError(
                f"Source and target filenames are not identical"
                f" (only in source: {sorted(source_names - target_names)},"
                f" only in target: {sorted(target_names - source_names)})"
            )

        # Position -> filename. A list (not a dict) so that out-of-range indices
        # raise IndexError rather than KeyError, and negative indices work.
        self._index = list(self._source_images)

    @staticmethod
    def _load_image_folder(directory: Path) -> Dict[str, torch.Tensor]:
        """Load every image file in ``directory`` (sorted by path) as a tensor keyed by filename."""
        images = {}
        for filename in sorted(directory.glob("*")):
            if is_image_file(str(filename)):
                # Context manager releases the file handle deterministically.
                with PIL.Image.open(filename) as pil_image:
                    images[filename.name] = VF.to_tensor(pil_image)
        return images

    def __len__(self):
        return len(self._source_images)

    def __getitem__(self, idx: int):
        """Return the ``(source, target)`` tensor pair at position ``idx``."""
        key = self._index[idx]
        return self._source_images[key], self._target_images[key]


class ImageSourceTargetCropDataset(BaseDataset):
    """Fixed-size crops taken at the same position from paired source/target images.

    Wraps ``ImageSourceTargetDataset`` and exposes ``num_crops`` crops per image
    pair, i.e. ``len(dataset) * num_crops`` items.

    :param path: root directory passed to ``ImageSourceTargetDataset``.
    :param shape: crop size as ``(height, width)``.
    :param num_crops: number of crops per image pair.
    :param source_subpath: source sub-folder name.
    :param target_subpath: target sub-folder name.
    :param random: if True, each ``__getitem__`` ignores ``idx`` and picks a random
        pair and crop position using the global ``random`` module. If False, crop
        positions are drawn once in ``__init__`` from ``random.Random(23)``. Items
        are then deterministic and cycle through the image pairs round-robin.
    """
    def __init__(
            self,
            path: Union[str, Path],
            shape: Tuple[int, int],
            num_crops: int,  # per image
            source_subpath: str = "source",
            target_subpath: str = "target",
            random: bool = False,
    ):
        # The ``random`` parameter shadows the ``random`` module in this method,
        # so the module is imported under a distinct local name.
        import random as random_module

        self._dataset = ImageSourceTargetDataset(path=path, source_subpath=source_subpath, target_subpath=target_subpath)
        self._shape = shape
        self._num_crops = num_crops
        self._random = random
        if not self._random:
            # Fixed seed: identical crop positions every time the dataset is built.
            rng = random_module.Random(23)
            self._crop_positions = []
            for idx in range(len(self._dataset) * self._num_crops):
                image_idx = idx % len(self._dataset)
                source_image, _ = self._dataset[image_idx]
                self._crop_positions.append((image_idx, *self._get_crop_pos(source_image, rng)))

    def __len__(self):
        return len(self._dataset) * self._num_crops

    def __getitem__(self, idx: int):
        """Return the ``(source_crop, target_crop)`` pair for ``idx``."""
        if self._random:
            image_idx = random.randrange(len(self._dataset))

            source_image, target_image = self._dataset[image_idx]
            assert source_image.shape == target_image.shape

            x, y = self._get_crop_pos(source_image, random)
        else:
            image_idx, x, y = self._crop_positions[idx]
            source_image, target_image = self._dataset[image_idx]
            assert source_image.shape == target_image.shape

        crop_h, crop_w = self._shape[0], self._shape[1]
        return (
            source_image[..., y: y + crop_h, x: x + crop_w],
            target_image[..., y: y + crop_h, x: x + crop_w],
        )

    def _get_crop_pos(self, image: torch.Tensor, rng: random.Random) -> Tuple[int, int]:
        """Return a random top-left ``(x, y)`` so the crop lies fully inside ``image``.

        :param image: tensor whose last two dimensions are ``(H, W)``.
        :param rng: object with ``randrange`` (a ``random.Random`` or the ``random`` module).
        :raises RuntimeError: if the crop shape exceeds the image size.
        """
        H, W = image.shape[-2:]
        crop_h, crop_w = self._shape[0], self._shape[1]
        if crop_h > H or crop_w > W:
            raise RuntimeError(f"Crop shape {self._shape} is too large for image {image.shape}")
        # ``+ 1`` because ``randrange`` excludes its stop value. Offset W - crop_w
        # (resp. H - crop_h) is valid. A crop exactly as large as the image must
        # give offset 0 instead of ``randrange(0)`` raising ValueError.
        x = rng.randrange(W - crop_w + 1)
        y = rng.randrange(H - crop_h + 1)
        return x, y

# Preview deterministic 32x32 crops (5 per image pair) from the training set;
# ``plot`` stacks source and target crops into one grid.
ds = ImageSourceTargetCropDataset(
    path="../datasets/shiny-tubes/train",
    shape=(32, 32),
    num_crops=5,
    random=False,
)
plot(ds)

## Play with the model

In [None]:
# Build the trainer for the shinytubes-spikes-gate experiment on CPU and restore its "snapshot" checkpoint.
trainer = load_experiment_trainer("../experiments/img2img/shinytubes-spikes-gate.yml", device="cpu")
# Keep the side-effecting call outside of ``assert``: assertions are stripped
# under ``python -O``, which would silently skip loading the checkpoint.
checkpoint_loaded = trainer.load_checkpoint("snapshot")
if not checkpoint_loaded:
    raise RuntimeError("could not load the 'snapshot' checkpoint")
model = trainer.model

In [None]:
from PIL import ImageDraw, ImageFont

In [None]:
# First 100 rows x 100 columns of a validation source image, as a (C, H, W) tensor.
with PIL.Image.open("../datasets/shiny-tubes2/validation/source/tubes-01.png") as validation_pil:
    image_v = VF.to_tensor(validation_pil)[:, :100, :100]
VF.to_pil_image(image_v)

In [None]:
# Render white "hello world" text on a black 200x40 RGB canvas as a model input.
# NOTE(review): machine-specific absolute font path. Alternatives previously tried:
#   /home/bergi/.local/share/fonts/LEMONMILK-LIGHTITALIC.OTF (size 20)
#   /home/bergi/.local/share/fonts/unscii-16-full.ttf (size 25)
#   /usr/share/fonts/truetype/open-sans/OpenSans-ExtraBold.ttf (size 25)
#   /usr/share/fonts/truetype/dejavu/DejaVuSerif.ttf (size 25)
FONT_PATH = "/home/bergi/.local/share/fonts/LEMONMILK-MEDIUMITALIC.OTF"
FONT_SIZE = 20

font = ImageFont.truetype(FONT_PATH, FONT_SIZE)
text_pil = PIL.Image.new("RGB", (200, 40))
draw = ImageDraw.Draw(text_pil)
draw.text((30, 7), "hello world", font=font, fill=(255, 255, 255))
image = VF.to_tensor(text_pil)
VF.to_pil_image(image)

In [None]:
# Compare the model on the clean text image and on a noisy copy. 2x2 grid:
# inputs (clean, noisy) on the top row, outputs on the bottom row, passed
# through the notebook's ``resize`` helper with factor 3 for display.
NOISE_AMOUNT = .4
NOISE_SEED = 23

with torch.no_grad():
    model.eval()
    # One noise channel (shape of image[:1]) broadcast over all channels,
    # drawn from a seeded generator so re-runs produce the same noisy input.
    noise_gen = torch.Generator().manual_seed(NOISE_SEED)
    noise = torch.randn(image[:1].shape, generator=noise_gen, dtype=image.dtype).to(image.device)
    noisy_image = (image - image * noise * NOISE_AMOUNT).clamp(0, 1)
    output1 = model(image.unsqueeze(0)).squeeze(0).clamp(0, 1)
    output2 = model(noisy_image.unsqueeze(0)).squeeze(0).clamp(0, 1)
    grid = make_grid([image, noisy_image, output1, output2], nrow=2).clamp(0, 1)
    display(VF.to_pil_image(resize(grid, 3)))

In [None]:
# Photo input: scale by 0.25 (bicubic), then invert. The clamp keeps values in
# [0, 1] in case the resize produced values slightly outside that range.
with PIL.Image.open("/home/bergi/Pictures/eisenach/wartburg.jpg") as photo_pil:
    photo = VF.to_tensor(photo_pil)
photo = resize(photo, .25, VF.InterpolationMode.BICUBIC)
image2 = (1. - photo).clamp(0, 1)
VF.to_pil_image(image2)

In [None]:
# Run the model on the inverted photo; show the input above the clamped output
# (one image per grid row), passed through ``resize`` with factor 2.
with torch.no_grad():
    # Set eval mode here as well so this cell does not depend on an earlier cell having done it.
    model.eval()
    output = model(image2.unsqueeze(0)).squeeze(0)
    display(VF.to_pil_image(resize(make_grid([image2, output.clamp(0, 1)], nrow=1), 2)))

In [None]:
# Side-by-side preview of the first 256x256 pixels of a validation source image and its target.
images = [
    VF.to_tensor(PIL.Image.open(image_path))[:, :256, :256]
    for image_path in (
        "../datasets/shiny-tubes3/validation/source/tubes-01.png",
        "../datasets/shiny-tubes3/validation/target/tubes-01.png",
    )
]
VF.to_pil_image(make_grid(images, padding=10))