In [None]:
import sys
sys.path.append("..")

import random
import math
import time
from io import BytesIO
from pathlib import Path
from collections import OrderedDict
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union

import PIL.Image
import PIL.ImageDraw

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.datasets import *
from src.util.image import *
from src.util import *
from src.algo import *
from src.models.decoder import *
from src.models.util import *

In [None]:
def resize(img, scale: float, mode: VF.InterpolationMode = VF.InterpolationMode.NEAREST):
    """
    Rescale the last two (spatial) dimensions of `img` by the factor `scale`.

    Every side becomes ``max(1, int(side * scale))``: truncated towards zero, but never
    smaller than one pixel. Antialiasing is always disabled.

    :param img: tensor whose last two dims are resized
    :param scale: multiplicative size factor
    :param mode: torchvision interpolation mode, NEAREST by default
    :return: the resized tensor
    """
    target_size = [max(1, int(side * scale)) for side in img.shape[-2:]]
    return VF.resize(img, target_size, interpolation=mode, antialias=False)

def plot_samples(
        iterable, 
        total: int = 32, 
        nrow: int = 8, 
        return_image: bool = False, 
        show_compression_ratio: bool = False,
        label: Optional[Callable] = None,
):
    """
    Collect up to `total` images from `iterable` and show them as a single grid image.

    :param iterable: yields image tensors, or lists/tuples whose first element is the image.
        4-dim images get ``squeeze(0)`` applied (removes a leading dim of size 1).
    :param total: maximum number of images to collect; also passed to tqdm as the total
    :param nrow: number of images per grid row
    :param return_image: if True, return the PIL grid image instead of displaying it
    :param show_compression_ratio: label each image with
        ``round(ImageFilter().calc_compression_ratio(image), 3)``; takes precedence over `label`
    :param label: if callable, each image is labeled with ``label(entry)``; any other
        non-None value labels each image with its index in `iterable`
    :return: the PIL image when `return_image` is True, otherwise None.
        Also None (with a printed note) when no image was collected.

    Pressing interrupt (KeyboardInterrupt) stops collection early and plots what was
    gathered so far.
    """
    samples = []
    labels = []
    use_labels = show_compression_ratio or label is not None
    # the filter is only needed for compression-ratio labels
    ratio_filter = ImageFilter() if show_compression_ratio else None
    try:
        for idx, entry in enumerate(tqdm(iterable, total=total)):
            image = entry[0] if isinstance(entry, (list, tuple)) else entry
            if image.ndim == 4:
                image = image.squeeze(0)

            # Compute the label *before* storing anything: the ratio computation can be slow,
            # and an interrupt raised inside it must not leave `samples` and `labels`
            # with different lengths (make_grid_labeled would then fail).
            if show_compression_ratio:
                entry_label = round(ratio_filter.calc_compression_ratio(image), 3)
            elif label is not None:
                entry_label = label(entry) if callable(label) else idx

            samples.append(image)
            if use_labels:
                labels.append(entry_label)

            if len(samples) >= total:
                break
    except KeyboardInterrupt:
        # deliberate best-effort: stop collecting and plot what we have
        pass

    if not samples:
        # make_grid cannot stack an empty list; report instead of crashing
        print("plot_samples: no images collected")
        return None

    if labels:
        grid = make_grid_labeled(samples, nrow=nrow, labels=labels)
    else:
        grid = make_grid(samples, nrow=nrow)
    image = VF.to_pil_image(grid)
    if return_image:
        return image
    display(image)

In [None]:
SHAPE = (3, 64, 64)
# image tensor stored on disk (uint8 according to the filename), limited to the first 1000 entries
raw_images = torch.load(f"../datasets/kali-uint8-{SHAPE[-2]}x{SHAPE[-1]}.pt")[:1000]
# convert to float and scale by 1/255
dataset = TransformDataset(TensorDataset(raw_images), dtype=torch.float, multiply=1./255.)
print(len(dataset))
# preview grid of the first 8x8 samples (first element of each dataset tuple)
preview = [sample[0] for sample, _ in zip(dataset, range(8 * 8))]
VF.to_pil_image(make_grid_labeled(preview))

In [None]:
def image_filter(image: torch.Tensor) -> torch.Tensor:
    """
    Subtract the separable "row mean + column mean" structure from an image.

    The 2-D spectrum over the last two dims is zeroed everywhere except its first row
    and first column (zero frequency along at least one axis). The inverse transform of
    what remains equals column-mean(x) + row-mean(y) - global mean. Its real part is
    subtracted from `image`, and negative values are clipped to 0.

    :param image: tensor with at least two dims; the last two are treated as (H, W)
    :return: non-negative residual with the same shape as `image`
    """
    spectrum = torch.fft.fft2(image)
    # drop every coefficient that has a non-zero frequency on both axes
    spectrum[..., 1:, 1:] = 0
    separable_part = torch.fft.ifft2(spectrum).real
    residual = image - separable_part
    return residual.clamp_min(0)

images = next(iter(DataLoader(dataset, batch_size=8)))[0]

# Spectrum magnitudes, computed once: top row |real part|, bottom row |imaginary part|
spectrum = torch.fft.fft2(images)
display(VF.to_pil_image(make_grid(
    list(spectrum.real.abs()) + list(spectrum.imag.abs()),
    nrow=len(images), normalize=True,
)))

output = image_filter(images)
print(output.shape)
# top row: originals, bottom row: filter residual clipped to [0, 1]
display(VF.to_pil_image(resize(make_grid(list(images) + list(output.clamp(0, 1)), nrow=len(images)), 2)))
print("l1", (images - output).abs().mean())
# per-image normalized absolute difference between original and residual
display(VF.to_pil_image(resize(make_grid(list((images - output).abs()), normalize=True, scale_each=True, nrow=len(images)), 2)))

In [None]:
# Frequency-domain masking experiment.
# `fft2` and `ifft2` are an exact inverse pair that keep the (H, W) shape, so `y` lines up
# with `images` below. `hfft2` would instead return a *real* tensor of width 2 * (W - 1),
# which neither `ifft2` inverts nor can be subtracted from `images`.
SPECTRUM_CUTOFF = 6
x = torch.fft.fft2(images)
print("x", type(x), x.shape)
# zero every coefficient whose row AND column index are >= SPECTRUM_CUTOFF;
# the first SPECTRUM_CUTOFF rows and columns of the spectrum are kept in full
x[..., SPECTRUM_CUTOFF:, SPECTRUM_CUTOFF:] = 0
y = torch.fft.ifft2(x)
print("y", type(y), y.shape)
# the masked spectrum is no longer Hermitian, so keep only the real part for display
display(VF.to_pil_image(resize(make_grid(y.real.clamp(0, 1), nrow=len(images)), 2)))
display(VF.to_pil_image(resize(make_grid((images - y).real.clamp(0, 1), nrow=len(images)), 2)))

In [None]:
# 3x3 gaussian blur (sigma 1): first show the blurred images, then what the blur removed
y = VF.gaussian_blur(images, 3, sigma=1.)
print("y", type(y), y.shape)
for grid_source in (y, images - y):
    display(VF.to_pil_image(resize(make_grid(grid_source.real.clamp(0, 1), nrow=len(images)), 2)))

In [None]:
class FFTLayer(nn.Module):
    """
    Converts an input tensor to Fourier space, or back with ``inverse=True``.

    The transform is ``torch.fft.fftn``, ``rfftn`` or ``hfftn`` (chosen by ``type``), or the
    matching ``i*fftn`` for the inverse, applied over ``dim``. With the default ``dim=None``
    *every* dimension is transformed, including batch and channel dims, so samples in a
    batch influence each other's output. Pass ``dim=(-2, -1)`` to transform only the
    spatial dims of images.

    if `allow_complex==False`, the output shape for images (B, C, H, W) (with the last
    dim among the transformed dims) will be:

        type   concat_dim  output shape
        fft    -1          B, C, H, W * 2
        rfft   -1          B, C, H, (W // 2 + 1) * 2
        hfft   -1          B, C, H, W * 2 - 2

        fft    -2          B, C, H * 2, W
        rfft   -2          B, C, H * 2, W // 2 + 1
        hfft   -2          B, C, H, W * 2 - 2        # hfft does not produce complex data so `concat_dim` is unused

        fft    -3          B, C * 2, H, W
        rfft   -3          B, C * 2, H, W // 2 + 1
        hfft   -3          B, C, H, W * 2 - 2

    if `allow_complex==True`, the output might be complex data and shapes are:

        type   output shape          is complex
        fft    B, C, H, W            yes
        rfft   B, C, H, W // 2 + 1   yes
        hfft   B, C, H, W * 2 - 2    no

    :param type: "fft", "rfft" or "hfft"; the names of the ``torch.fft.*fftn`` functions
        without the trailing "n"
    :param allow_complex: if False, complex forward results are returned as real tensors
        with real and imaginary parts concatenated along `concat_dim`, and real inverse
        inputs are split back in half along `concat_dim`
    :param concat_dim: dimension used to concatenate / split real and imaginary parts
    :param norm: "forward", "backward" or "ortho"; passed to both the forward and the
        inverse torch function, which makes a forward/inverse pair round-trip exactly
    :param inverse: if True, apply the inverse transform and return its real part.
        Note that ``irfftn`` produces an even last size ``2 * (n - 1)`` by default, so odd
        original sizes are not restored.
    :param dim: dims to transform (int or tuple of ints); None transforms all dims
    :raises ValueError: on an unsupported `type` or `norm`
    """
    def __init__(
            self,
            type: str = "fft",
            allow_complex: bool = False,
            concat_dim: int = -1,
            norm: str = "forward",
            inverse: bool = False,
            dim: Optional[Union[int, Tuple[int, ...]]] = None,
    ):
        super().__init__()
        # -> ["fft", "hfft", "rfft"] (from fftn, hfftn, rfftn; inverse functions excluded)
        supported_types = [
            name[:-1] for name in dir(torch.fft)
            if name.endswith("fftn") and not name.startswith("i")
        ]
        if type not in supported_types:
            raise ValueError(f"Expected `type` to be one of {', '.join(supported_types)}, got '{type}'")

        supported_norm = ("forward", "backward", "ortho")
        if norm not in supported_norm:
            raise ValueError(f"Expected `norm` to be one of {', '.join(supported_norm)}, got '{norm}'")

        if isinstance(dim, int):
            dim = (dim,)

        self.type = type
        self.norm = norm
        self.allow_complex = allow_complex
        self.concat_dim = concat_dim
        self.inverse = inverse
        self.dim = None if dim is None else tuple(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the configured (inverse) transform to `x`.

        :param x: input tensor; for the inverse with ``allow_complex=False`` either a complex
            tensor or a real tensor holding [real, imag] halves along `concat_dim`
        :return: transformed tensor (see the class docstring for shapes)
        """
        if self.inverse and not self.allow_complex and not torch.is_complex(x) and self.type != "hfft":
            # re-assemble the complex tensor that the forward pass split into [real, imag]
            size = x.shape[self.concat_dim]
            half = size // 2
            x = torch.complex(
                x.narrow(self.concat_dim, 0, half),
                x.narrow(self.concat_dim, half, size - half),
            )

        func_name = f"{'i' if self.inverse else ''}{self.type}n"
        output = getattr(torch.fft, func_name)(x, dim=self.dim, norm=self.norm)

        if not self.inverse:
            if not self.allow_complex and torch.is_complex(output):
                output = torch.concat([output.real, output.imag], dim=self.concat_dim)
        else:
            output = output.real

        return output

    
# NOTE: `input` shadows the builtin; kept because later cells refer to `input` / `output`
input = torch.rand(1, 3, 24, 32)
fft_layer = FFTLayer("fft", False, -1)
output = fft_layer(input)
real_part = output.real
print(input.shape, "->", output.shape, torch.is_complex(output), real_part.min(), real_part.max(), real_part.sum())

In [None]:
# `output` holds the real and imaginary parts of an all-dims `fftn` side by side along the
# last dim, so a 1-D `torch.fft.ifft` of it does not invert anything. The inverse FFTLayer
# re-assembles the complex tensor and applies the matching `ifftn`.
reconstruction = FFTLayer("fft", False, -1, inverse=True)(output)
print("max abs reconstruction error:", (reconstruction - input).abs().max())
reconstruction