In [None]:
import sys
sys.path.append("..")

import io
import os
import random
import math
import time
import json
import shutil
from io import BytesIO
from pathlib import Path
from collections import OrderedDict
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.datasets import *
from src.util.image import *
from src.util import *
from src.algo import *
from src.models.decoder import *
from src.models.transform import *
from src.models.util import *
from experiments import datasets
from experiments.denoise.resconv import ResConv

import yaml
import ipywidgets
from src.clipig.clipig_task import ClipigTask

def resize(img, scale: float, mode: VF.InterpolationMode = VF.InterpolationMode.NEAREST):
    """
    Rescale a PIL image or a tensor by a uniform factor.

    The height and width (the last two dimensions for a tensor) are each
    multiplied by `scale`, truncated to an integer and kept at 1 or above.
    Antialiasing is always disabled.

    :param img: a `PIL.Image.Image`, or an object exposing `.shape`
    :param scale: multiplier applied to both spatial extents
    :param mode: interpolation mode handed to `VF.resize`
    :return: whatever `VF.resize` returns for the given input
    """
    dims = (img.height, img.width) if isinstance(img, PIL.Image.Image) else img.shape[-2:]

    scaled_dims = []
    for extent in dims:
        # never let a dimension collapse to zero
        scaled_dims.append(max(1, int(extent * scale)))

    return VF.resize(img, scaled_dims, mode, antialias=False)

class ImageWidget(ipywidgets.Image):
    """An `ipywidgets.Image` that can be updated from a PIL image or a torch tensor."""

    def set_pil(self, image: PIL.Image.Image):
        """Encode `image` as PNG in memory and make it the widget's content."""
        with io.BytesIO() as buffer:
            image.save(buffer, "png")
            encoded = buffer.getvalue()
        # set the format before the payload, in the same order as the data is produced
        self.format = "png"
        self.value = encoded

    def set_torch(self, image: torch.Tensor):
        """Convert `image` with `VF.to_pil_image` and display the result."""
        self.set_pil(VF.to_pil_image(image))

In [None]:
# Load the reference image and convert it to a tensor with torchvision's `to_tensor`.
# NOTE(review): this is an absolute, machine-specific path -- the notebook will not
# run on another machine without editing it.
image = VF.to_tensor(PIL.Image.open(
    "/home/bergi/Pictures/diffusion/cthulhu-16.jpeg"
    #"/home/bergi/Pictures/diffusion/cables-14.jpeg"
))

In [None]:
def downsample(image, factor=.5):
    """
    Degrade `image` by a resize round trip while keeping its spatial size.

    The image is first scaled by `factor` through `resize` (using that
    function's default interpolation) and then brought back to the original
    height and width with antialiased bilinear interpolation.

    :param image: object exposing `.shape`; the last two entries are used as the target size
    :param factor: scale of the intermediate image
    :return: the result of the second `VF.resize` call
    """
    target_hw = image.shape[-2:]
    shrunk = resize(image, factor)
    return VF.resize(
        shrunk,
        target_hw,
        interpolation=VF.InterpolationMode.BILINEAR,
        antialias=True,
    )

# Compare a 32x32 crop (rows/columns 100..131) of the original image with the
# same crop after the `downsample` round trip; the third tile of the grid is
# the absolute difference between the two.
dimage = downsample(image)
s1, s2 = slice(100, 132), slice(100, 132)
patch = image[:, s1, s2]
dpatch = dimage[:, s1, s2]
grid = make_grid([patch, dpatch, (patch - dpatch).abs()])
# Enlarge the grid 5x for display; the cell's last expression is rendered as an image.
VF.to_pil_image(resize(grid, 5))

In [None]:
# CLIPig task configuration, kept as YAML text.
# `run_config` parses this string with `yaml.safe_load` and then overrides
# `num_iterations`, `initialize`, `input_image`, `source_model.params.size`
# and the `image` of the first target feature, so the values given here for
# those keys are only placeholders.
# NOTE(review): the exact meaning of each transformation parameter is defined
# by the clipig package, not by this notebook -- check there before tuning.
config = """
clip_model_name: ViT-B/32
device: auto
initialize: random
num_iterations: 10000
source_model:
  name: pixels
  params:
    channels: RGB
    size:
    - 224
    - 224
targets:
- batch_size: 2
  optimizer:
    betas:
    - 0.9
    - 0.999
    learnrate: 0.01
    optimizer: Adam
    weight_decay: 0
  target_features:
  - image: ''
    text: fisheye view of a cthulhu fractal
    type: image
    weight: 1.0
  transformations:
  - name: padding
    params:
      active: true
      pad_left: 100
      pad_right: 100
      pad_top: 100
      pad_bottom: 100
      padding_mode: symmetric
  - name: random_scale
    params:
      active: true
      scale_min_xy: [.4, .4]
      scale_max_xy: [1., 1.]
  - name: random_affine
    params:
      active: true
      degrees_min_max:
      - -5.6
      - 5.0
      interpolation: bilinear
      scale_min_max:
      - 0.9
      - 1.1
      shear_min_max:
      - -15.0
      - 15.0
      translate_xy:
      - 0.01
      - 0.01
  - name: random_crop
    params:
      active: true
      pad_if_needed: true
      padding_mode: constant
      size: 224
  - name: multiplication
    params:
      active: false
      add: 0.0
      multiply: 0.1
  - name: blur
    params:
      active: false
      kernel_size:
      - 3
      - 3
      mix: 0.7
      sigma:
      - 1.0
      - 1.0
"""

In [None]:
def run_config(
        config: str,
        image: torch.Tensor,
        num_iterations: int = 1000,
        max_mae: float = 0.1,
        preview: bool = True,
):
    """
    Run a CLIPig task that starts from `image` and uses `image` as its target.

    The YAML text in `config` is parsed and these entries are overridden before
    the task is created: `num_iterations`, `pixel_yield_delay_sec` (0),
    `initialize` ("input"), `input_image`, `source_model.params.size`
    (width, height of `image`) and the `image` of the first feature of the
    first target.

    The event loop ends when the task's event stream is exhausted, when the
    mean absolute error between the clamped pixels and `image` reaches
    `max_mae`, or when the user interrupts the kernel. On interruption the
    most recent pixels are still returned.

    :param config: YAML document describing the CLIPig task
    :param image: input/target image; it is moved to the CUDA device.
        Assumes a (C, H, W) float tensor in [0, 1] -- TODO confirm against ClipigTask
    :param num_iterations: value written to the task config and used as the
        progress bar total
    :param max_mae: stop as soon as the mean absolute error is at or above this value
    :param preview: if True, show every received image in an `ImageWidget`
    :return: the last received pixels, clamped to [0, 1], detached and moved to the CPU
    :raises RuntimeError: if the run ended before any "pixels" event was received
    """
    image = image.cuda()

    # Keep the parsed dict in its own name instead of rebinding the `config` string.
    task_config = yaml.safe_load(config)
    task_config["num_iterations"] = num_iterations
    task_config["pixel_yield_delay_sec"] = 0.
    task_config["initialize"] = "input"
    task_config["input_image"] = image
    # shape[-2:] is (height, width); the config entry is filled with (width, height)
    task_config["source_model"]["params"]["size"] = tuple(reversed(image.shape[-2:]))
    task_config["targets"][0]["target_features"][0]["image"] = image

    image_widget = None
    if preview:
        image_widget = ImageWidget()
        display(image_widget)
    status_widget = ipywidgets.Text()
    display(status_widget)

    task = ClipigTask(task_config)
    status = "requested"
    pixels = None  # stays None until the first "pixels" event arrives
    mae = 0

    try:
        with tqdm(total=num_iterations) as progress:
            for event in task.run():
                if "status" in event:
                    status = event["status"]

                if "pixels" in event:
                    pixels = event["pixels"].clamp(0, 1)
                    progress.update(1)
                    if image_widget is not None:
                        image_widget.set_torch(resize(pixels, 1))

                    # Plain float: readable in the status line, and no tensor is kept alive.
                    mae = (pixels - image).abs().mean().item()
                    if mae >= max_mae:
                        break

                status_widget.value = (
                    f"status: {status}, mae={mae}"
                )

    except KeyboardInterrupt:
        # Deliberate best-effort: fall through and return what has been produced so far.
        print("stopped")

    if pixels is None:
        raise RuntimeError(
            f"CLIPig run ended with status '{status}' before producing any pixels"
        )

    return pixels.detach().cpu()

# Reload the reference image so this cell does not depend on the earlier load cell.
# NOTE(review): absolute, machine-specific path.
image = VF.to_tensor(PIL.Image.open(
    "/home/bergi/Pictures/diffusion/cthulhu-16.jpeg"
    #"/home/bergi/Pictures/diffusion/cables-14.jpeg"
    #"../datasets/declipig/original/100_1571.jpg"
))


# Run CLIPig on the image with the live preview enabled (the default) until the
# mean absolute error to the original reaches 0.07 or the run ends otherwise.
noisy_image = run_config(
    config,
    image,
    #downsample(image),
    #num_iterations=10,
    max_mae=0.07,
)

In [None]:
# Show the absolute per-pixel difference between the original and the CLIPig
# result, amplified 3x so small deviations become visible.
# The amplified values can exceed 1.0; `to_pil_image` scales float input by 255
# and casts to 8 bit, so without the clamp large differences would overflow
# instead of saturating at white.
VF.to_pil_image(
    ((image - noisy_image).abs() * 3).clamp(0, 1)
)

In [None]:
# Batch job: run CLIPig on every file in the "original" directory and store the
# resulting noise (CLIPig output minus original) as "<original file name>.pt".
# Files whose noise tensor already exists are skipped, so the cell can be
# re-run to resume an interrupted batch.
# NOTE(review): `run_config` catches KeyboardInterrupt and returns the partial
# result, so interrupting the kernel during a run still stores that partial
# noise and moves on to the next file -- confirm this is acceptable.
STORAGE_PATH = Path("../datasets/declipig/clipig-noise-07")
os.makedirs(STORAGE_PATH, exist_ok=True)
for filename in sorted(Path("../datasets/declipig/original/").glob("*.*")):
    target_filename = STORAGE_PATH / f"{filename.name}.pt"
    if target_filename.exists() or not filename.is_file():
        continue

    print(filename)
    # The task config declares an RGB pixel model, so grayscale / palette /
    # RGBA files are converted to three channels before use.
    image = VF.to_tensor(PIL.Image.open(filename).convert("RGB"))

    noisy_image = run_config(
        config,
        image,
        max_mae=0.07,
        preview=False,
    )
    noise = noisy_image - image

    # Write to a temporary name and rename afterwards: the resume check above
    # only tests for existence, so a save that is cut short must never leave a
    # truncated file under the final name.
    partial_filename = target_filename.with_name(target_filename.name + ".tmp")
    torch.save(noise, partial_filename)
    os.replace(partial_filename, target_filename)
    

In [None]:
from src.datasets.base_iterable import BaseIterableDataset

In [None]:
class ClipNoiseDataset(BaseIterableDataset):
    """
    Iterable dataset of aligned random crops `(patch_orig, patch_noisy)`.

    Noisy images are the `*.jpeg` files in `../datasets/diffusion-clip-noised/`;
    the original of each one is the file with the same name in
    `~/Pictures/diffusion`. Up to `interleave_images` image pairs are held in
    memory at a time and visited round-robin. Each pair contributes three times
    as many crops as non-overlapping patches fit into the original image, after
    which it is dropped and replaced by the next pair from disk.

    Crop positions come from the global `random` module, so results depend on
    its seed.

    NOTE(review): the crop position is derived from the original image only;
    both images of a pair are assumed to have the same size -- not checked here.
    """

    def __init__(
            self,
            patch_size: Tuple[int, int],
            interleave_images: int = 1,
    ):
        """
        :param patch_size: (height, width) of the crops
        :param interleave_images: number of image pairs kept in rotation
        """
        self._patch_size = patch_size
        self._interleave_images = interleave_images
        self._directory_orig = Path("~/Pictures/diffusion").expanduser()
        # relative to the notebook's working directory
        self._directory_noisy = Path("../datasets/diffusion-clip-noised/")

    def __iter__(self):
        """Yield `(patch_orig, patch_noisy)` tuples until every image pair is used up."""
        patch_h, patch_w = self._patch_size[-2], self._patch_size[-1]
        image_pairs = []
        image_pair_iterable = self._iter_image_pairs()
        iter_count = -1
        while True:
            iter_count += 1

            # Top up the rotation with fresh image pairs from disk.
            while len(image_pairs) < self._interleave_images:
                try:
                    image_orig, image_noisy = next(image_pair_iterable)
                except StopIteration:
                    break

                height, width = image_orig.shape[-2:]
                count = (height // patch_h) * (width // patch_w) * 3
                if count <= 0:
                    # The image is smaller than one patch: no valid crop exists.
                    continue

                image_pairs.append({"count": count, "images": (image_orig, image_noisy)})

            if not image_pairs:
                break

            pair_index = iter_count % len(image_pairs)
            pair = image_pairs[pair_index]
            image_orig, image_noisy = pair["images"]

            height, width = image_orig.shape[-2:]

            # `randrange` excludes its upper bound, hence the +1: the last valid
            # offset is `extent - patch`. The horizontal offset must be drawn
            # from the width, not from the height.
            top = random.randrange(0, height - patch_h + 1)
            left = random.randrange(0, width - patch_w + 1)

            patch_orig = image_orig[:, top: top + patch_h, left: left + patch_w]
            patch_noisy = image_noisy[:, top: top + patch_h, left: left + patch_w]

            yield patch_orig, patch_noisy

            pair["count"] -= 1
            if pair["count"] <= 0:
                image_pairs.pop(pair_index)

    def _iter_image_pairs(self):
        """Yield `(image_orig, image_noisy)` tensors for each noisy file, in sorted file order."""
        for filename in sorted(self._directory_noisy.glob("*.jpeg")):
            image_noisy = VF.to_tensor(PIL.Image.open(filename))
            image_orig = VF.to_tensor(PIL.Image.open(
                self._directory_orig / filename.name
            ))
            yield image_orig, image_noisy

# Smoke test: build the dataset with 64x64 patches, rotating over 3 image pairs,
# and display a grid made from 64 sampled patches.
# NOTE(review): `sample` comes from BaseIterableDataset; presumably it returns a
# tuple of stacked batches and `[0]` selects the original patches -- confirm.
ds = ClipNoiseDataset((64, 64), interleave_images=3)
#next(iter(ds))
VF.to_pil_image(make_grid(ds.sample(8*8)[0]))
#ds.sample(8*8)[0].shape

In [None]:
# Iterate over the whole dataset once without using the patches; the tqdm
# counter shows how many items it yields and at what rate.
for p in tqdm(ds):
    pass