In [None]:
import sys
sys.path.append("..")

import random
import math
import time
from io import BytesIO
from pathlib import Path
from collections import OrderedDict
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union

import PIL.Image
import PIL.ImageDraw

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"
import pandas as pd

from src.datasets import *
from src.util.image import *
from src.util import *
from src.algo import *
from src.models.decoder import *
from src.models.transform import *
from src.models.loss import *

def resize(img, scale: float, mode: VF.InterpolationMode = VF.InterpolationMode.NEAREST):
    """
    Rescale the last two (spatial) dimensions of `img` by `scale`.

    Each target extent is ``int(extent * scale)`` (truncated) but never smaller than 1.
    Antialiasing is disabled; nearest-neighbour interpolation is the default mode.
    """
    target_size = []
    for extent in img.shape[-2:]:
        target_size.append(max(1, int(extent * scale)))
    return VF.resize(img, target_size, mode, antialias=False)

def plot_samples(
        iterable, 
        total: int = 32, 
        nrow: int = 8, 
        return_image: bool = False, 
        show_compression_ratio: bool = False,
        label: Optional[Callable] = None,
):
    """
    Collect up to `total` images from `iterable` and show them as one grid.

    Entries may be image tensors or lists/tuples whose first element is the image.
    4-D images have their leading dimension squeezed away.

    :param iterable: source of images (e.g. a Dataset or DataLoader).
    :param total: maximum number of images to collect.
    :param nrow: number of images per grid row.
    :param return_image: if True, return the PIL grid image instead of displaying it.
    :param show_compression_ratio: label every tile with
        ``ImageFilter().calc_compression_ratio(image)`` rounded to 3 decimals.
    :param label: used only when `show_compression_ratio` is False. If callable,
        it maps each raw entry to its tile label; otherwise the running index is used.
    :return: the PIL grid image if `return_image` is True, else None.
        None is also returned when no image could be collected.

    Pressing Ctrl-C (KeyboardInterrupt) stops collecting early and plots what was gathered so far.
    """
    samples = []
    labels = []
    # the filter is only needed for the compression-ratio labels
    image_filter = ImageFilter() if show_compression_ratio else None
    try:
        for idx, entry in enumerate(tqdm(iterable, total=total)):
            image = entry[0] if isinstance(entry, (list, tuple)) else entry
            if image.ndim == 4:
                image = image.squeeze(0)
            samples.append(image)
            if image_filter is not None:
                labels.append(round(image_filter.calc_compression_ratio(image), 3))
            elif label is not None:
                labels.append(label(entry) if callable(label) else idx)

            if len(samples) >= total:
                break
    except KeyboardInterrupt:
        # deliberate: allow stopping a slow iterable and still plot the partial result
        pass

    if not samples:
        # make_grid cannot stack an empty list, so there is nothing to plot
        print("plot_samples: no samples to plot")
        return None

    if labels:
        grid = make_grid_labeled(samples, nrow=nrow, labels=labels)
    else:
        grid = make_grid(samples, nrow=nrow)
    image = VF.to_pil_image(grid)
    if return_image:
        return image
    display(image)

In [None]:
def histogram(x: torch.Tensor, bins: int, range: Optional[Tuple[int, int]] = None):
    """
    Compute one ``torch.histc``-style histogram per row of `x`.

    A 1-D `x` is treated as a single row. Any other input is flattened to
    ``(x.shape[0], -1)``, i.e. one histogram per entry of the first dimension.

    :param x: floating point tensor (``torch.histc`` does not support integer input).
    :param bins: number of equal-width bins.
    :param range: (min, max) of the histogram. Defaults to the global min/max of `x`,
        so all rows share the same bin edges.
    :return: tensor of shape (rows, bins) with the counts, on the device and dtype of `x`.
    """
    if x.ndim == 1:
        x = x.unsqueeze(0)
    else:
        x = x.flatten(1)

    if range is None:
        range = (x.min(), x.max())
    # histc takes python scalars for its bounds (range may hold 0-dim tensors)
    low, high = float(range[0]), float(range[1])

    hist = torch.zeros(x.shape[0], bins, dtype=x.dtype, device=x.device)
    for row_hist, row in zip(hist, x):
        row_hist.copy_(torch.histc(row, bins, low, high))
    return hist



In [None]:
def differentiable_histogram(x, bins=255, min=0.0, max=1.0):
    """
    Differentiable histogram of all values in `x` using a triangular kernel.

    The range [min, max] is split into `bins` equal bins of width
    ``delta = (max - min) / bins`` with centers at ``min + delta * (i + 0.5)``.
    A value adds ``max(0, 1 - |value - center| / delta)`` to every bin.
    A value sitting exactly on a bin midpoint therefore counts as 1 for that bin.
    A value between two midpoints is split linearly between those two bins.

    The result approximates ``torch.histc(x, bins, min, max)`` while staying
    differentiable w.r.t. `x`. It differs from ``histc`` at the edges:
    - Values in the outer half of the first/last bin contribute less than 1.
    - Values up to half a bin outside [min, max] still contribute partially.

    :param x: tensor of any shape; all elements are pooled into one histogram.
    :param bins: number of bins.
    :param min: lower bound of the histogram range.
    :param max: upper bound of the histogram range.
    :return: 1-D tensor of shape (bins,) on the device of `x`.
    """
    delta = (max - min) / bins
    values = x.flatten()

    counts = []
    for index in range(bins):
        center = min + delta * (index + .5)
        # triangular kernel: 1 at this bin's midpoint, 0 at the neighbouring midpoints
        weight = (1. - (values - center).abs() / delta).clamp(min=0.)
        counts.append(weight.sum())
    # stacking (instead of in-place writes into a buffer) keeps the autograd graph simple
    return torch.stack(counts)

# Compare differentiable_histogram ("d") with torch.histc ("h") on a few uniform samples (no seed: varies per run).
data = torch.rand(1, 10)
BINS, RANGE = 10, (0, 1)
h = differentiable_histogram(data, BINS, RANGE[0], RANGE[1])
h2 = torch.histc(data, BINS, *RANGE)
# both are expected to be 1-D tensors of length BINS
print(h.shape, h2.shape)
px.bar(pd.DataFrame({"d": h, "h": h2}), barmode="group")

In [None]:
# Sample image for the histogram comparisons below; adjust IMAGE_PATH for your machine.
IMAGE_PATH = Path.home() / "Pictures" / "there_is_no_threat.jpeg"
# open the file in a `with` block so the handle is closed; to_tensor reads the pixels eagerly
with PIL.Image.open(IMAGE_PATH) as pil_image:
    img = VF.to_tensor(pil_image)
BINS, RANGE = 100, (0, 1)
# exact counts vs. the differentiable approximation, plus their difference
th = torch.histc(img, BINS, *RANGE)
td = differentiable_histogram(img, BINS, *RANGE)
px.line(pd.DataFrame({
    "torch": th,
    "error": (th - td),
    "diff": td,
}))

In [None]:
def soft_histogram_flat(x: torch.Tensor, bins: int, min: float, max, sigma: float = 100.):
    """
    Soft (sigmoid-window) histogram of all values in `x`.

    Bin ``i`` has center ``min + delta * (i + 0.5)`` with ``delta = (max - min) / bins``.
    A value at distance ``d`` from a center contributes
    ``sigmoid(sigma * (d + delta / 2)) - sigmoid(sigma * (d - delta / 2))``,
    i.e. a smoothed box of width `delta`. A larger `sigma` gives sharper edges and
    a result closer to ``torch.histc``.

    :param x: tensor of any shape; it is flattened into a single histogram.
    :param bins: number of bins.
    :param min: lower bound of the range.
    :param max: upper bound of the range.
    :param sigma: sharpness of the sigmoid window edges.
    :return: 1-D tensor of shape (bins,).

    Memory use is O(bins * x.numel()) because every value/center difference is materialized.
    """
    if x.ndim > 1:
        x = x.flatten(0)

    delta = (max - min) / bins
    # build the centers on x's device/dtype (as soft_histogram does) so CUDA inputs work
    centers = min + delta * (torch.arange(bins, device=x.device, dtype=x.dtype) + 0.5)

    # (1, N) - (bins, 1) -> (bins, N)
    x = torch.unsqueeze(x, 0) - torch.unsqueeze(centers, 1)
    x = torch.sigmoid(sigma * (x + delta / 2)) - torch.sigmoid(sigma * (x - delta / 2))
    x = x.sum(dim=-1)
    return x
        
soft = soft_histogram_flat(img.flatten(0), BINS, *RANGE, sigma=10000)
# overlay the exact torch.histc counts (`th`) with the soft estimate and their difference
display(px.line(pd.DataFrame({
    "torch": th,
    "error": (th - soft),
    "soft": soft,
})))

In [None]:
def soft_histogram(x: torch.Tensor, bins: int, min: float, max, sigma: float = 100.):
    """
    Batched soft histogram: one sigmoid-window histogram per row of `x`.

    A 1-D `x` becomes a single row; a 2-D `x` is used as (rows, values).
    Anything else is flattened to ``(x.shape[0], -1)``.
    Bin centers lie at ``min + delta * (i + 0.5)`` with ``delta = (max - min) / bins``.
    Each value contributes
    ``sigmoid(sigma * (d + delta / 2)) - sigmoid(sigma * (d - delta / 2))``
    for its distance ``d`` to a center.

    :return: tensor of shape (rows, bins).
    """
    # bring the input into (rows, values) layout
    if x.ndim == 1:
        x = x[None]
    elif x.ndim != 2:
        x = x.flatten(1)

    width = (max - min) / bins
    offsets = torch.arange(bins, device=x.device, dtype=x.dtype) + 0.5
    centers = min + width * offsets

    # (rows, 1, values) - (bins, 1) -> (rows, bins, values)
    diff = x[:, None, :] - centers[:, None]
    half = width / 2
    window = torch.sigmoid(sigma * (diff + half)) - torch.sigmoid(sigma * (diff - half))
    return window.sum(dim=-1)

# per-channel soft histograms: soft_histogram flattens img to (channels, pixels) and returns (channels, BINS)
soft = soft_histogram(img, BINS, *RANGE, 100000)
print(soft.shape)
# transpose so that each channel becomes one line over the bin axis
display(px.line(soft.T))


In [None]:
# broadcasting check: (2, 1, 3) - (4, 1) -> (2, 4, 3), the same pattern soft_histogram uses for values vs. bin centers
torch.Tensor([[[1, 2, 3]], [[4, 5, 6]]]) - torch.Tensor([[0], [1.5], [1.6], [1.7]])

In [None]:
def phi_k(x, L = 1. / 256., W = 1. / 256. / 2.5):
    """
    Smoothed box window of width `L` centred at 0, evaluated element-wise on `x`.

    The edges are sigmoids with temperature `W`: a smaller `W` gives sharper edges.
    The output has the same shape as `x`.
    """
    half_width = L / 2
    upper = torch.sigmoid((x + half_width) / W)
    lower = torch.sigmoid((x - half_width) / W)
    return upper - lower

# phi_k works element-wise, so the shape of the image is preserved
phi_k(img).shape

In [None]:
def differentiable_histogram_2(x: torch.Tensor, bins: int, *, range: Optional[Tuple[float, float]] = None):
    """
    Hard-count histogram of all values in `x`, mirroring ``torch.histogram``.

    Despite the name, the counts come from boolean masks and are not
    differentiable w.r.t. `x`. Like ``torch.histogram``, every bin is half-open
    ``[lower, upper)`` except the last one, which also includes its right edge.

    :param x: tensor of any shape; all elements are pooled into one histogram.
    :param bins: number of equal-width bins (int), or a 1-D tensor of increasing
        bin edges that includes the rightmost edge.
    :param range: (min, max) used with an integer `bins`; defaults to the min/max of `x`.
        Ignored when `bins` is a tensor.
    :return: tuple ``(hist, bin_edges)``. `hist` is a float32 tensor of length
        ``len(bin_edges) - 1``. Both tensors are on the device of `x`.
    :raises TypeError: if `bins` is neither an int nor a tensor.
    :raises ValueError: if a tensor `bins` is not 1-D or has fewer than two edges.
    """
    if isinstance(bins, int):
        if range is None:
            range = (x.min().item(), x.max().item())
        # python scalars: linspace does not reliably accept 0-dim tensor bounds
        bin_table = torch.linspace(float(range[0]), float(range[1]), bins + 1, device=x.device)
    else:
        if not isinstance(bins, torch.Tensor):
            raise TypeError(f"Expected `bins` to be of type int or Tensor, got {type(bins).__name__}")
        if bins.ndim != 1:
            raise ValueError(f"Expected `bins` to be one-dimensional, got {bins.shape}")
        if bins.numel() < 2:
            raise ValueError(f"Expected `bins` to contain at least two edges, got {bins.numel()}")
        bin_table = bins.to(x.device)

    # the number of bins follows from the edges (also correct for tensor `bins`)
    num_bins = bin_table.numel() - 1
    hist = torch.zeros(num_bins, device=x.device)

    for index, (lower, upper) in enumerate(zip(bin_table[:-1], bin_table[1:])):
        if index == num_bins - 1:
            # last bin is closed on the right, so x == max is counted
            mask = (x >= lower) & (x <= upper)
        else:
            mask = (x >= lower) & (x < upper)
        hist[index] = mask.sum()

    return hist, bin_table

# compare differentiable_histogram_2 ("d") with torch.histogram ("h"); both pool all elements into one histogram
data = torch.rand(1, 10, 30)
h = differentiable_histogram_2(data, 4, range=(0, 1))[0]
h2 = torch.histogram(data, 4, range=(0, 1))[0]
print(h.shape, h2.shape)
px.bar(pd.DataFrame({"d": h, "h": h2}), barmode="group")

In [None]:
# IPython help syntax: shows the torch.histogram docstring (interactive only, not plain Python)
torch.histogram?

In [None]:
# edge case: a single value and no explicit range
torch.histogram(torch.Tensor([0]), bins=10)

In [None]:
# uniform [0, 1) samples histogrammed over the wider range [0, 2): the upper half of the bins stays (nearly) empty
h = differentiable_histogram(torch.rand(1, 10000), 10, 0, 2)
print(h.shape)
# differentiable_histogram returns a 1-D (bins,) tensor, so plot it directly (h[0, 0] would raise IndexError)
px.bar(h)

In [None]:
# Toy optimisation: train a linear layer so that the soft histogram of each output row matches a target.
input_data = torch.rand(32, 128)
targets = torch.rand(32, 10)

model = nn.Linear(128, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for i in range(30):
    output = model(input_data)
    # one 10-bin histogram per sample -> (32, 10), the same shape as `targets`
    # (differentiable_histogram pools its whole input, so it is applied row by row)
    h = torch.stack([differentiable_histogram(row, 10, -1, 1) for row in output])
    loss = F.mse_loss(h, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(loss.item())

In [None]:
def random_data(shape: Tuple[int, int]):
    #data = torch.randn(shape)
    t = torch.linspace(0, 2 * torch.pi * shape[-1], shape[-1])
    t = t.unsqueeze(0).expand(shape[0], -1)
    data = torch.randn_like(t) * .3
    
    for i in range(4):
        f = torch.randn(3, shape[0], 1) 
        f *= torch.randn_like(f)
        data += torch.tanh(torch.sin(t * 7. * f[0]) * f[1] * .3 + f[2])
    return data

# exact per-row torch.histc counts vs. the batched soft histogram on synthetic signals, then their difference
data = random_data((3, 1000))
print("data", data.shape)
# one exact histogram per row, stacked to (rows, 100)
th = torch.concat([torch.histc(d, 100, -1, 1).unsqueeze(0) for d in data])
sh = soft_histogram(data, 100, -1, 1, 100)
display(px.line(th.T))
display(px.line(sh.T))
px.line((th - sh).T)

In [None]:
# Search for a sample where soft_histogram_flat deviates strongly from torch.histc.
# Both histograms are normalized by the largest exact bin count. The loop stops and
# plots the case once the mean absolute error exceeds 1.
bins = 10
for i in tqdm(range(1000)):
    while True:
        data = random_data((1, 10000))[0]
    
        t_hist = torch.histc(data, bins, -1, 1)
        # resample until some bin holds at least 10 values (avoids normalizing by a tiny maximum)
        if t_hist.max() >= 10:
            break
    
    s_hist = soft_histogram_flat(data, bins, -1, 1, sigma=100_000)
    # print(t_hist.shape, s_hist.shape)
    
    ma = t_hist.max()
    error = F.l1_loss(s_hist / ma, t_hist / ma)
    if error > 1.:
        print(error)
        display(px.line(pd.DataFrame({
            "torch": t_hist,
            "error": (t_hist - s_hist),
            "soft": s_hist,
        })))
        break

In [None]:
# row-normalization scratch: the all-zero row has norm 0, so dividing by it yields NaN in that row
vec = torch.Tensor([[0, 0, 0], [0, 1, 2]])
n = torch.linalg.norm(vec, dim=-1, keepdim=True)
print(n)
(vec / n).pow(2)

In [None]:
# HistogramLayer comes from the project's src modules (star-imported at the top).
# NOTE(review): constructor args presumably are (bins, min, max, sigma) and the output is (channels, bins) - confirm in src
hl = HistogramLayer(100, 0, 1, sigma=50)
px.line(hl(img).T)