In [None]:
from init_notebook import *

In [None]:
def _load_and_show_rgb(path):
    """Load `path` as an RGB tensor via VF.to_tensor, display it and its shape, and return it."""
    with PIL.Image.open(path) as source:
        tensor = VF.to_tensor(source.convert("RGB"))
    display(VF.to_pil_image(tensor))
    display(tensor.shape)
    return tensor

image1 = _load_and_show_rgb("/home/bergi/Pictures/__diverse/_1983018_orson_150.jpg")
# alternative: "/home/bergi/Pictures/__diverse/keinmenschistillegal.jpg"
image2 = _load_and_show_rgb("/home/bergi/Pictures/__diverse/keinmenschistillegal.jpg")
image3 = _load_and_show_rgb("../datasets/MarsMarken.png")
image4 = _load_and_show_rgb("../datasets/pixilart.png")

In [None]:
def generalized_mean_image(
        target_images: torch.Tensor,  # B,C,H,W
        perceptual_model: Callable,
        steps: int = 20000,
        batch_size: int = 32,
        learnrate: float = 0.005,
        loss_function: Callable = F.mse_loss,
        device: str = "auto",
        ret_image: bool = False,
):
    """
    Optimize one image whose `perceptual_model` features minimize `loss_function`
    against the features of all `target_images`.

    This is a "mean" of the targets in the feature space of `perceptual_model`.
    For example, with nn.Identity(), mse_loss approximates the per-pixel mean and
    l1_loss a per-pixel median.

    :param target_images: tensor of shape (B, C, H, W); the optimized image has shape (C, H, W).
    :param perceptual_model: nn.Module or plain callable mapping a (N, C, H, W) batch to features.
        If it has a `.to()` method, it is moved to `device`.
    :param steps: number of samples to process; the loop runs `steps // batch_size` optimizer steps.
    :param batch_size: number of (source, target) pairs per optimizer step. The targets are tiled
        cyclically to fill the batch. If `batch_size` is not a multiple of B, the first targets
        appear once more than the others and get slightly more weight.
    :param learnrate: Adam learning rate.
    :param loss_function: callable(source_features, target_features) -> scalar tensor.
    :param device: passed to `to_torch_device`.
    :param ret_image: if True, return the result as a PIL image instead of displaying it.
    :return: the PIL image if `ret_image`, otherwise None.
    :raises ValueError: if `target_images` is empty or `batch_size < 1`.
    """
    num_targets = target_images.shape[0]
    if num_targets < 1:
        raise ValueError("target_images must contain at least one image")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    torch.cuda.empty_cache()

    device = to_torch_device(device)
    if callable(getattr(perceptual_model, "to", None)):
        perceptual_model.to(device)

    # the image being optimized, starting from all zeros
    source_image = nn.Parameter(torch.zeros(target_images.shape[1:]).to(device))
    optimizer = torch.optim.Adam([source_image], lr=learnrate)

    # Tile the targets until there are at least `batch_size` rows (ceil division), then cut.
    # Floor division left the target batch shorter than the source batch (or empty when
    # batch_size < num_targets), which broke the loss computation.
    num_repeats = -(-batch_size // num_targets)
    target_batch = target_images.repeat(num_repeats, 1, 1, 1)[:batch_size].to(device)
    with torch.no_grad():
        # targets never change, so their features are computed only once
        p_target_batch = perceptual_model(target_batch)

    num_iterations = steps // batch_size
    try:
        with tqdm(total=num_iterations * batch_size) as progress:
            for _ in range(num_iterations):
                progress.update(batch_size)

                source_batch = source_image.unsqueeze(0).repeat(batch_size, 1, 1, 1)
                p_source_batch = perceptual_model(source_batch)

                loss = loss_function(p_source_batch, p_target_batch)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                progress.set_postfix({"loss": float(loss)})
    except KeyboardInterrupt:
        # deliberate: interrupting the kernel stops early and keeps the current result
        pass

    result = VF.to_pil_image(source_image.detach().cpu().clamp(0, 1))
    torch.cuda.empty_cache()
    if ret_image:
        return result
    display(result)
    
def create_target_images(image: torch.Tensor, o: int = 2):
    """
    Build four shifted crops of `image` for use as targets of `generalized_mean_image`.

    Each crop is `o` pixels smaller in height and width. The crops use (row, column)
    offsets (0, 0), (o, 0), (o, o) and (0, o), in that order.

    :param image: image tensor indexed as (C, H, W).
    :param o: shift in pixels.
    :return: tensor of shape (4, C, H - o, W - o) for 0 <= o < H, W.
    """
    height, width = image.shape[-2:]
    crop_h, crop_w = height - o, width - o
    offsets = ((0, 0), (o, 0), (o, o), (0, o))
    crops = [image[:, dy:dy + crop_h, dx:dx + crop_w] for dy, dx in offsets]
    return torch.stack(crops)
    
# Pixel-space L1 "mean" (approximately a per-pixel median) of four 3-pixel-shifted crops
# of image1; kept as a PIL image so later cells can display it for comparison.
image1_l1 = generalized_mean_image(
    target_images=create_target_images(image1, o=3),
    perceptual_model=nn.Identity(),
    loss_function=F.l1_loss,
    ret_image=True,
)

In [None]:
# Identity features + MSE: the optimum is the per-pixel mean of image4's shifted crops.
generalized_mean_image(
    target_images=create_target_images(image4),
    perceptual_model=nn.Identity(),
    loss_function=F.mse_loss,
)

In [None]:
# Features from a small, freshly (randomly) initialized 2-layer CNN, compared with L1 loss.
generalized_mean_image(
    create_target_images(image3),
    perceptual_model=nn.Sequential(
        nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3),
        nn.ReLU(),
        nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1),
        nn.ReLU(),
    ),
    loss_function=F.l1_loss,
)

In [None]:
# Same random 2-layer CNN features as the L1 experiment, but compared with MSE.
generalized_mean_image(
    create_target_images(image3),
    perceptual_model=nn.Sequential(
        nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3),
        nn.ReLU(),
        nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1),
        nn.ReLU(),
    ),
    loss_function=F.mse_loss,
)

In [None]:
# Random 2-layer CNN followed by coarse average pooling (32x32 windows, stride 16), L1 loss.
generalized_mean_image(
    create_target_images(image3),
    perceptual_model=nn.Sequential(
        nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3),
        nn.ReLU(),
        nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1),
        nn.ReLU(),
        nn.AvgPool2d(kernel_size=32, stride=16),
    ),
    loss_function=F.l1_loss,
)

# Special kernel: a fixed 3×3 zero-sum (high-pass) convolution as the perceptual model

In [None]:
# Fixed zero-sum 3x3 kernel (center 4, diagonal corners -1), scaled by sqrt(2)/2.
perceptual_model = nn.Conv2d(3, 3, 3, padding=1)
w = torch.Tensor([
    [-1, 0, -1],
    [0, 4, 0],
    [-1, 0, -1],
]) * math.sqrt(2) / 2
with torch.no_grad():
    # Conv2d weights are laid out as (out_channels, in_channels, kH, kW). Output channel k
    # should apply `w` to input channel k only, so the identity mask has to broadcast over
    # the (out, in) axes: eye(3) viewed as (3, 3, 1, 1).
    # The previous `(3, 3, 3) * torch.Tensor([1, 0, 0])` broadcast the mask over the kernel's
    # LAST axis. That zeroed kernel columns instead of selecting an input channel, so every
    # output channel mixed all three input channels through a single kernel column.
    perceptual_model.weight[:] = torch.eye(3).view(3, 3, 1, 1) * w
# The bias keeps its random init. It is added equally to source and target features,
# so it cancels inside the L1 difference.

generalized_mean_image(
    create_target_images(image2),
    perceptual_model,
    loss_function=F.l1_loss,
)

In [None]:
import pywt
from src.models.wavelet.util import create_wavelet_filter
# Wavelet exploration. pywt is only used by the commented-out lines below.
#pywt.wavelist()
#wl = pywt.Wavelet("haar")
#wl.dec_hi
# Show the shape of element [0] of what the project helper create_wavelet_filter returns for "haar".
# NOTE(review): the meaning of the two integer arguments (3, 3) and the structure of the
# returned value are not visible here — confirm in src/models/wavelet/util.py.
create_wavelet_filter("haar", 3, 3)[0].shape

# "sobel": input minus half its Gaussian blur (despite the name, not a Sobel operator)

In [None]:
# IPython help: show the signature/docstring of VF.gaussian_blur, which `sobel` below calls.
VF.gaussian_blur?

In [None]:
def sobel(x):
    """
    Feature map computed as `x - 0.5 * gaussian_blur(x)`.

    Despite the name, this is not a Sobel operator. The blur is VF.gaussian_blur with
    positional arguments 5 and 2.0 (presumably kernel size 5 and sigma 2, as in
    torchvision's signature).
    """
    half_blurred = VF.gaussian_blur(x, 5, 2.) * .5
    return x - half_blurred

# Match image4's shifted crops in the `sobel` feature space with L1 loss.
generalized_mean_image(
    target_images=create_target_images(image4),
    perceptual_model=sobel,
    loss_function=F.l1_loss,
)


# FFT

In [None]:
# Complex 2D FFT as the feature map (nn.Identity() would be the pixel-space alternative).
generalized_mean_image(
    target_images=create_target_images(image3),
    perceptual_model=torch.fft.fft2,
    loss_function=F.l1_loss,
    steps=100_000,
)

In [None]:
def fft_func(x):
    """2D FFT of `x`, returned as a real tensor: real parts then imaginary parts, concatenated on dim -3."""
    spectrum = torch.fft.fft2(x)
    return torch.cat((spectrum.real, spectrum.imag), dim=-3)
    
# Real/imaginary FFT features (fft_func as currently defined) compared with MSE.
generalized_mean_image(
    target_images=create_target_images(image3),
    perceptual_model=fft_func,
    loss_function=F.mse_loss,  # alternative: F.l1_loss
    steps=200_000,
)

In [None]:
def fft_func(x):
    """
    2D FFT of `x`, keeping only the [0:10, 0:10] corner of the last two axes.

    The spectrum is not fftshift-ed, so this corner holds the DC term and the lowest
    non-negative frequencies.
    """
    spectrum = torch.fft.fft2(x)
    corner = spectrum[..., :10, :10]
    return corner
    
# Low-frequency FFT corner (fft_func as currently defined) compared with L1 loss.
generalized_mean_image(
    target_images=create_target_images(image3),
    perceptual_model=fft_func,
    loss_function=F.l1_loss,
    steps=100_000,
)
# for comparison: pixel-space L1 result computed earlier
display(image1_l1)

In [None]:
def fft_func(x):
    """
    Plain 2D FFT of `x` over its last two axes (complex output, no cropping or re-weighting).

    The previous version computed unused locals (h, w) and carried commented-out
    re-weighting experiments; both were removed, and the returned value is unchanged.
    """
    return torch.fft.fft2(x)
    
# Raw complex FFT features (fft_func as currently defined) of image1 compared with L1 loss.
generalized_mean_image(
    target_images=create_target_images(image1),
    perceptual_model=fft_func,
    loss_function=F.l1_loss,
    steps=100_000,
)
# for comparison: pixel-space L1 result computed earlier
display(image1_l1)

In [None]:
# Reconstruct image1 after zeroing part of its (unshifted) 2D spectrum.
with torch.no_grad():
    x = torch.fft.fft2(image1)
    #x = x[..., :30, :30] / 20
    # Zero the coefficients with last-axis frequency index 1..19; index 0 (DC column) is kept.
    x[..., 1:20] = 0
    #f2 = x[..., 10:, 10:].flatten(-2)
    # Only positive frequencies were removed, so the result is no longer real; keep the real part.
    filtered = torch.fft.ifft2(x).real.clamp(0, 1)
    display(VF.to_pil_image(filtered))

In [None]:
# Sweep the kernel size of a random 3-layer CNN (no padding). Each setting is optimized
# 4 times, with fresh random weights each time, and the results are shown as one grid.
CH = 16
ACT = nn.GELU()
KS = 1  # only relevant when the kernel-size loop is swapped for the commented channel loop
for KS in (1, 2, 3, 4, 5, 6, 7):
#for CH in [4, 8, 16, 32, 64, 128]:
    label = f"CH={CH}, KS={KS}, ACT={ACT}"
    print(label)
    grid = [
        VF.to_tensor(generalized_mean_image(
            create_target_images(image1),
            nn.Sequential(
                nn.Conv2d(3, CH, kernel_size=KS),  # dilation=3 was also tried
                ACT,
                nn.Conv2d(CH, CH, kernel_size=KS),  # dilation=5 was also tried
                ACT,
                nn.Conv2d(CH, CH, kernel_size=KS),
                ACT,
            ),
            loss_function=F.l1_loss,
            batch_size=16,
            steps=6000,
            ret_image=True,
        ))
        for _run in range(4)
    ]
    # printed again, presumably so the label sits next to the image after the progress output
    print(label)
    display(VF.to_pil_image(grid[0] if len(grid) == 1 else make_grid(grid)))
                   

In [None]:
# Sweep odd kernel sizes of a single "same"-padded random conv layer with CH output channels;
# one optimization per setting.
CH = 128
ACT = nn.GELU()
KS = 1  # only relevant when the kernel-size loop is swapped for the commented channel loop
for KS in (1, 3, 5, 7, 9):
#for CH in [4, 8, 16, 32, 64, 128]:
    label = f"CH={CH}, KS={KS}, ACT={ACT}"
    print(label)
    grid = [
        VF.to_tensor(generalized_mean_image(
            create_target_images(image2),
            nn.Sequential(
                nn.Conv2d(3, CH, kernel_size=KS, padding=(KS - 1) // 2),  # dilation=3 was also tried
                # further ACT + Conv2d(CH, CH, kernel_size=KS) layers were tried here
            ),
            loss_function=F.l1_loss,
            batch_size=16,
            steps=6000,
            ret_image=True,
        ))
        for _run in range(1)
    ]
    # printed again, presumably so the label sits next to the image after the progress output
    print(label)
    display(VF.to_pil_image(grid[0] if len(grid) == 1 else make_grid(grid)))

In [None]:
# Deeper random CNN (three unpadded 3x3 conv layers with CH channels) plus coarse average pooling.
CH = 128
ACT = nn.ReLU()
generalized_mean_image(
    create_target_images(image1, o=3),
    perceptual_model=nn.Sequential(
        nn.Conv2d(3, CH, kernel_size=3),
        ACT,
        nn.Conv2d(CH, CH, kernel_size=3),
        ACT,
        nn.Conv2d(CH, CH, kernel_size=3),
        ACT,
        nn.AvgPool2d(kernel_size=32, stride=16),
        # (additional Conv2d(CH, CH, kernel_size=1) + ACT layers were tried here)
    ),
    loss_function=F.mse_loss,
    batch_size=16,
)

In [None]:
def fft_func(x):
    """Return the 2D FFT of `x` as a real tensor: real parts followed by imaginary parts along dim -3."""
    freq = torch.fft.fft2(x)
    parts = [freq.real, freq.imag]
    return torch.cat(parts, dim=-3)
    
# Pixel-space (identity) features compared with a Huber loss (delta=0.4).
# NOTE(review): fft_func defined just above is not used here (nn.Identity() is passed) — confirm intent.
generalized_mean_image(
    target_images=create_target_images(image2),
    perceptual_model=nn.Identity(),
    loss_function=lambda src, tgt: F.huber_loss(src, tgt, delta=.4),
    #loss_function=lambda s, t: -F.cosine_similarity(s, t).mean(),
    #steps=200_000,
)

In [None]:
# names in F that contain "loss" (dir() returns them sorted)
list(filter(lambda name: "loss" in name, dir(F)))

In [None]:
# IPython help: show the docstring of F.smooth_l1_loss (one of the losses listed above).
F.smooth_l1_loss?

In [None]:
import itertools

import torchvision

# STL10 is read from a local copy under ~/prog/data/datasets/;
# uncomment download=True to fetch it if it is not present yet.
ds = torchvision.datasets.STL10(
    root=Path("~/prog/data/datasets/").expanduser(),
    #download=True,
)
# Show the first 10 samples. The label is bound to `label` instead of `id`, which
# would shadow the builtin id() for the rest of the notebook session.
# `image` stays bound to the last sample and is used by the next cell.
for image, label in itertools.islice(ds, 10):
    display(image)

In [None]:
# Space-to-depth with factor 3 on the last STL10 sample: (C, H, W) -> (C*9, H/3, W/3); needs H, W divisible by 3.
F.pixel_unshuffle(VF.to_tensor(image), downscale_factor=3).shape

In [None]:
# Scratch arithmetic: 64 * 4**2 == 1024.
# NOTE(review): purpose not stated here — possibly the channel count of a 64-channel tensor
# after F.pixel_unshuffle with factor 4; confirm.
64*4**2