In [None]:
import sys
sys.path.append("..")

import random
import math
import itertools
from copy import deepcopy
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union, Dict

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
import plotly.graph_objects as go
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd
pd.options.plotting.backend = "plotly"

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset, RandomSampler
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.datasets import *
from src.algo import GreedyLibrary
from src.util.image import *
from src.util import to_torch_device

In [None]:
# Optional: inspect a previously trained library instead of training a new one.
#lib = GreedyLibrary.from_torch("../models/greedylib-1x31x31-388-photos-6M.pt")
#lib.plot_entries(signed=True)

In [None]:
# List the dataset files available on disk (IPython shell magic).
!ls -l ../datasets

In [None]:
# Tensor files used for training, relative to the notebook directory.
# Uncomment file names to include more data.
DATASETS = [
    f"../datasets/{file_name}"
    for file_name in (
        # "ca-64x64-i10-p05.pt",
        "diverse-32x32-std01.pt",
        # "fonts-regular-32x32.pt",
        "ifs-1x128x128-uint8-1000x16.pt",
        "kali-uint8-128x128.pt",
        "pattern-1x128x128-uint.pt",
        "photos-64x64-bcr03.pt",
    )
]

# Shape (channels, height, width) of the library patches.
SHAPE = (1, 31, 31)

In [None]:
# Load every tensor file and wrap it as a dataset.
# uint8 data is converted via TransformDataset(dtype=float, multiply=1/255),
# presumably to [0, 1] floats. Everything is converted to grayscale and
# chained into one dataset with `+`.
ds_all = None
for name in DATASETS:
    tensor = torch.load(name)
    ds = TensorDataset(tensor)
    if tensor.dtype == torch.uint8:
        ds = TransformDataset(ds, dtype=torch.float, multiply=1. / 255.)
    ds = TransformDataset(ds, transforms=[VT.Grayscale()])
    # sanity output: mean of the first sample of each file
    print(ds[0][0].mean(), name)
    ds_all = ds if ds_all is None else ds_all + ds

print(len(ds_all), "images")

In [None]:
# ALTERNATIVE DATASET
# Replaces `ds_all` from the previous cell with a stream of all images below a
# photo folder. Images whose dtype is not float32 are cast and divided by 255,
# then converted to grayscale.
# NOTE(review): hard-coded absolute path; only valid on the author's machine.
# NOTE(review): dividing every non-float32 dtype by 255 assumes integer
#   (uint8) image data — confirm against the loader's output dtype.
            
ds_all = TransformIterableDataset(
    ImageFolderIterableDataset("/home/bergi/Pictures/photos/", recursive=True),
    transforms=[
        lambda x: x.to(torch.float) / 255. if x.dtype != torch.float else x,
        VT.Grayscale()
    ]
)
# preview the first image of the stream
VF.to_pil_image(next(iter(ds_all))[0])

In [None]:
# Build the patch pipeline:
#   - cut every image into SHAPE-sized patches with a stride of 1/3 patch width,
#   - keep patches passing ImageFilter(min_std=0.03) (presumably a minimum std),
#   - shuffle through IterableShuffle with max_shuffle=100_000.
ds_crop = ImagePatchIterableDataset(ds_all, shape=SHAPE[-2:], stride=SHAPE[-1] // 3)
ds_crop = IterableImageFilterDataset(ds_crop, ImageFilter(min_std=0.03))
ds_crop = IterableShuffle(ds_crop, max_shuffle=100_000)

# Preview the first 30x30 patches as a grid, upscaled 2x with nearest neighbour.
batch, = next(iter(DataLoader(ds_crop, batch_size=30 * 30)))
print(batch.shape)
img = VF.to_pil_image(make_grid(batch, nrow=30))
img = VF.resize(img, (img.height * 2, img.width * 2), VF.InterpolationMode.NEAREST)
img

In [None]:
# Optional: measure the throughput of the patch pipeline (uncomment to run).
#for batch, in tqdm(DataLoader(ds_crop, batch_size=100, num_workers=2)):
#    pass

In [None]:
# New library for SHAPE patches on the GPU. The first argument is presumably the
# initial number of entries (1); TODO confirm against GreedyLibrary.
lib = GreedyLibrary(1, SHAPE, std=.001, mean=0., device="cuda")

In [None]:
# Optionally initialise `lib` with hand-crafted entries, zero-meaned and scaled
# by .3 (disabled). NOTE: `create_entries_2d` is defined further down in the
# notebook, so its cell must have been run before enabling this.
USE_HANDCRAFTED_ENTRIES = False
if USE_HANDCRAFTED_ENTRIES:
    handcrafted = create_entries_2d(lib.shape, rotation_steps=8).to(lib.device)
    lib.entries = (handcrafted - handcrafted.mean()) * .3
    lib.n_entries = lib.entries.shape[0]
    lib.hits = [0] * lib.n_entries

In [None]:
# --------- TRAIN ---------
# Fit `lib` on the patch stream `ds_crop`. Interrupt the kernel to stop early:
# the KeyboardInterrupt is caught, so the plots at the end are still shown.

try:
    count, last_count = 0, 0
    # passed as `grow_if_distance_above`; scaled by the number of elements per patch
    max_dist = math.prod(lib.shape) * .21
    #max_dist = -.001
    
    # iterable datasets have no len(); tqdm then runs without a total
    try:
        total = len(ds_crop)
    except TypeError:
        total = None
    with tqdm(total=total) as progress:    
        # DataLoader only accepts shuffle=True for map-style datasets
        for batch, in DataLoader(ds_crop, batch_size=100, shuffle=not isinstance(ds_crop, IterableDataset), num_workers=2):
            # add a channel dim when patches come as (B, H, W) — presumably grayscale without channel
            if batch.ndim == 3:
                batch = batch.unsqueeze(1)
            lib.fit(batch, lr=1., skip_top_entries=0, zero_mean=True, grow_if_distance_above=max_dist, max_entries=1000)
            #lib.fit(batch, lr=.1, skip_top_entries=0, zero_mean=True, grow_if_distance_above=max_dist, max_entries=1000, metric="corr")
            #lib.fit(batch, lr=1., skip_top_entries=0, zero_mean=True)
            progress.desc = f"{lib.n_entries} entries"
            progress.update(batch.shape[0])
            
            # every ~200k patches, prune entries with few hits
            count += batch.shape[0]
            if count > last_count + 200000:
                last_count = count

                n_entries = lib.n_entries
                lib.drop_entries(hits_lt=2, inplace=True)
                if lib.n_entries < 1000:
                    print(f"entries: {lib.n_entries}, dropped {n_entries - lib.n_entries}")

except KeyboardInterrupt:
    pass
# sort the entries in place (reverse order), then plot them by hits and by t-SNE order
(lib
 #.drop_entries(hits_lt=10, inplace=True)
 .sort_entries(inplace=True, reverse=True)
)
display(lib.plot_entries(min_size=600, sort_by="hits"))
print(sorted(lib.hits, reverse=True))
display(lib.plot_entries(min_size=600, sort_by="tsne"))

In [None]:
# Quick inspection: number of entries and value range of the learned entries.
lib.n_entries, lib.entries.min(), lib.entries.max()

In [None]:
# Save a sorted copy of the library; the file name encodes the patch shape and entry count.
# NOTE(review): the "-photos-6M" suffix is hard-coded and may not match the
#   dataset / number of samples actually used for training.
lib.sort_entries(reverse=True).save_torch(f"../models/greedylib-{lib.shape[0]}x{lib.shape[1]}x{lib.shape[2]}-{lib.n_entries}-photos-6M.pt")

In [None]:
# RELOAD
# Replace `lib` with a previously saved library, moved to the GPU.

lib = GreedyLibrary.from_torch("../models/greedylib-1x31x31-388-photos-6M.pt").cuda()
# lib.plot_entries(signed=True)

In [None]:
# Convolve one zero-mean sample with every library entry (stride = patch width)
# and show each entry's response map as a tile of a grid, clamped to [0, 1].
image = next(iter(ds_all))[0]
# not in-place: the dataset may hand out a tensor it still owns
image = image - image.mean()
VF.to_pil_image(make_grid(lib.convolve(image, stride=SHAPE[-1]).unsqueeze(1)).clamp(0, 1))

In [None]:
def image_labels(image):
    """
    Label windows of ``image`` with the index of the strongest-responding entry.

    Uses the global ``lib``. The image is mean-subtracted and convolved with all
    entries (stride 1). The responses are then max-pooled over windows of the
    entries' spatial size, and the argmax over entries gives one label per window.

    :param image: tensor accepted by ``lib.convolve``, presumably (C, H, W); TODO confirm
    :return: int64 tensor of shape (rows, cols) with entry indices
    """
    # NOTE(review): padding of (patch_size - 2) looks ad hoc — confirm intent
    c = lib.convolve(image - image.mean(), padding=max(0, lib.shape[-1] - 2), stride=1)
    c = F.max_pool2d(c, lib.shape[-2:])
    #print(c.shape)
    # (entries, rows, cols) -> (rows, cols, entries) -> best entry per window
    return c.permute(1, 2, 0).argmax(dim=2)
    
# Reconstruct the first 8x8 images from their best-matching library entries and
# blend each reconstruction (30%) with the original image (70%) where they overlap.
images = []
for image, in itertools.islice(ds_all, 8 * 8):
    labels = image_labels(image)
    repro = signed_to_image(
        make_grid([lib.entries[idx] for idx in labels.flatten()], nrow=labels.shape[-1], normalize=False)
    )
    # overlapping height / width of original and reconstruction
    s = (min(image.shape[-2], repro.shape[-2]), min(image.shape[-1], repro.shape[-1]))
    blended = repro[:, :s[0], :s[1]] * .3 + .7 * image[:, :s[0], :s[1]].to(lib.device)
    repro[:, :s[0], :s[1]] = blended
    images.append(repro)

VF.to_pil_image(make_grid(images, nrow=8))

In [None]:
# Load a single full-resolution photo as [0, 1] floats and downsample it 4x.
# NOTE(review): hard-coded absolute path; adjust for your machine.
# The context manager closes the file handle that PIL.Image.open keeps open.
with PIL.Image.open(
    "/home/bergi/Pictures/photos/katjacam/101MSDCF/DSC00012.JPG"
) as pil_image:
    image = VF.pil_to_tensor(pil_image).to(torch.float) / 255.
print(image.shape)
image = VF.resize(image, (image.shape[-2] // 4, image.shape[-1] // 4))
print(image.shape)
VF.to_pil_image(image)

In [None]:
# Reconstruct `image` from the entries of a pruned CPU copy of the library
# (entries with few hits removed). The labels index into the *pruned* library,
# so the tiles must come from that same library, not from the full `lib`.
lib_pruned = lib.cpu().drop_entries(hits_lt=10)
c = lib_pruned.convolve(VT.Grayscale()(image), stride=10)
c = F.max_pool2d(c, 3)
# best entry per pooled window: (entries, rows, cols) -> (rows, cols)
labels = c.permute(1, 2, 0).argmax(dim=2)
repro = make_grid([lib_pruned.entries[i] for i in labels.flatten()], nrow=labels.shape[-1], normalize=False)
repro = signed_to_image(repro)
VF.to_pil_image(repro)

In [None]:
# Two-layer experiment: convolve the mean-subtracted, grayscale image with all
# entries (stride 3) and max-pool. Then convolve the first response map with
# the library again and show three of the resulting maps as RGB channels.
# NOTE(review): `c2[100:103]` requires at least 103 library entries.
c = lib.cpu().convolve(VT.Grayscale()(image - image.mean()), stride=3)
c = F.max_pool2d(c, 3)
c2 = lib.convolve(c[:1])
VF.to_pil_image(c2[100:103].clamp(0, 1))

In [None]:
# Second-layer library with 1x7x7 patches. It is trained below on the
# response maps produced by `lib`.
SHAPE2 = [1, 7, 7]
lib2 = GreedyLibrary(1, SHAPE2, std=.001, mean=0., device="auto")

In [None]:
# --------- TRAIN ---------
# Second layer: convolve each image with all `lib` entries and fit `lib2` on
# SHAPE2-sized patches cut from every entry's response map.
# Interrupting the kernel stops training; the plots at the end are still shown.

try:
    count, last_count = 0, 0
    max_dist = math.prod(lib2.shape) * 2.
    #max_dist = -1

    print("max_dist", max_dist)
    # The loader below iterates `ds_all`, so its length and whether it may be
    # shuffled (map-style only) must be derived from `ds_all`, not `ds_crop`.
    try:
        total = len(ds_all)
    except TypeError:
        total = None
    with tqdm(total=total) as progress:
        for image, in DataLoader(ds_all, batch_size=1, shuffle=not isinstance(ds_all, IterableDataset), num_workers=2):
            image = image[0]  # drop the batch dim

            conv = lib.convolve(image - image.mean(), stride=SHAPE[-1])
            for conv_layer in conv:
                for patches in iter_image_patches(conv_layer.unsqueeze(0), shape=SHAPE2[-2:], stride=SHAPE2[-1], batch_size=1000):
                    lib2.fit(patches, lr=1., skip_top_entries=0, zero_mean=False, grow_if_distance_above=max_dist, max_entries=1000, metric="l2")

            progress.desc = f"{lib2.n_entries} entries"
            progress.update(1)

            # every ~100 images, prune entries with few hits
            count += 1
            if count > last_count + 100:
                last_count = count

                n_entries = lib2.n_entries
                lib2.drop_entries(hits_lt=2, inplace=True)
                if lib2.n_entries < 1000:
                    print(f"entries: {lib2.n_entries}, dropped {n_entries - lib2.n_entries}")

except KeyboardInterrupt:
    pass
(lib2
 #.drop_entries(hits_lt=10, inplace=True)
 .sort_entries(inplace=True, reverse=True)
)
display(lib2.plot_entries(min_size=600, sort_by="hits", signed=True))
print(sorted(lib2.hits, reverse=True))
display(lib2.plot_entries(min_size=600, sort_by="hits"))

In [None]:
# Show only the `lib2` entries with many hits (drop threshold 100), moved to the CPU.
display(lib2.cpu().drop_entries(hits_lt=100).plot_entries(min_size=600, sort_by="hits", signed=True))

In [None]:
from src.algo import Space2d

def create_entries_2d(
        shape: Iterable[int],
        rotation_steps: int = 4,
):
    """
    Create a tensor of hand-crafted 2d library entries.

    The entries are:
      - 8 constant patches with values evenly spaced in [-1, 1],
      - rotated "edge", "circle" and "bar" shapes,
      - every pairwise combination of those shapes, mixed by element-wise
        maximum, element-wise minimum and average.

    All non-constant entries are clamped to [0, 1] and then made zero-mean.
    Candidates whose mean absolute difference to an existing entry is below
    0.001 are skipped.

    :param shape: entry shape, must have ndim=3 (channels, height, width)
    :param rotation_steps: number of rotations over a full circle
        ("bar" shapes are symmetric and use only half as many)
    :return: tensor of shape (n_entries, *shape)
    """
    shape = tuple(shape)
    assert len(shape) == 3, f"Expected shape of ndim=3, got {len(shape)}"

    entries = None

    def _add_entry(entry, zero_mean=True):
        nonlocal entries
        if zero_mean:
            entry = entry.clamp(0, 1)  # new tensor, so the in-place sub below is safe
            entry -= entry.mean()

        entry = entry.reshape(1, *shape)

        if entries is None:
            entries = entry
        else:
            # mean absolute difference to every existing entry
            dist = (entries - entry).abs().flatten(1).mean(dim=1)
            if not torch.any(dist < 0.001):
                entries = torch.concat([entries, entry])

    def _shape(mode, space):
        if mode == "bar":
            return 1. - space[0].abs()
        elif mode == "edge":
            return 1. - (space[0] + 0.5)
        elif mode == "circle":
            return 1. - ((space[0] + 1.) ** 2 + space[1] ** 2).sqrt() * .5

    def _iter_spaces(mode):
        n_rotations = rotation_steps // (2 if mode == "bar" else 1)
        for rotation_idx in range(n_rotations):
            rotation = np.pi * 2. * rotation_idx / rotation_steps
            yield Space2d((2, *shape[1:]), rotate_2d=rotation).space()

    # constant patches
    for value in torch.linspace(-1, 1, 8):
        _add_entry(torch.ones(shape) * value, zero_mean=False)

    # every rotated base shape, computed once and reused for the pairwise mixes
    base_shapes = [
        _shape(mode, space)
        for mode in ("edge", "circle", "bar")
        for space in _iter_spaces(mode)
    ]
    for entry in base_shapes:
        _add_entry(entry)

    for entry in base_shapes:
        for entry2 in base_shapes:
            _add_entry(torch.maximum(entry, entry2))
            _add_entry(torch.minimum(entry, entry2))
            _add_entry((entry + entry2) * .5)

    # Outside all loops: returning inside the outer loop used to stop after
    # the "edge" shapes and skip every "circle"/"bar" combination.
    return entries
        
# Preview the hand-crafted entries for 1x31x31 patches with 8 rotation steps.
VF.to_pil_image(signed_to_image(make_grid(create_entries_2d((1, 31, 31), rotation_steps=8), nrow=32)))

In [None]:
# NOTE(review): identical to the preview at the end of the previous cell.
VF.to_pil_image(signed_to_image(make_grid(create_entries_2d((1, 31, 31), rotation_steps=8), nrow=32)))

In [None]:
class LocalGreedyLibrary:
    """
    Collection of library patches (entries) of ndim>=1.

    Learns by moving the best matching entry (L1 distance) towards each input
    patch. An entry's update weight is ``lr / (1 + hits)``, so with ``lr=1`` an
    entry becomes the running mean of the patches it matched.

    Notebook-local variant of ``GreedyLibrary`` for experimenting with the
    fitting / growing logic.
    """
    def __init__(
            self,
            n_entries: int,
            shape: Iterable[int],
            mean: float = 0.,
            std: float = 0.01,
            device: Union[None, str, torch.device] = "cpu",
    ):
        self.device = to_torch_device(device)
        self.shape = tuple(shape)
        self.n_entries = n_entries
        # (n_entries, *shape), randomly initialised around `mean`
        self.entries = mean + std * torch.randn(n_entries, *self.shape).to(self.device)
        # number of patches each entry has been fitted to
        self.hits = [0] * n_entries

    def __repr__(self):
        return f"{self.__class__.__name__}({self.n_entries}, {self.shape})"

    def __copy__(self):
        return self.copy()

    @property
    def ndim(self) -> int:
        return len(self.shape)

    @property
    def max_hits(self) -> int:
        """Maximum number of hits of all entries (0 for an empty library)"""
        # max(self.hits), not max(*self.hits): the latter fails for a single entry
        return max(self.hits) if self.hits else 0

    def copy(self) -> "LocalGreedyLibrary":
        d = self.__class__(0, self.shape, device=self.device)
        d.n_entries = self.n_entries
        d.entries = deepcopy(self.entries)
        d.hits = self.hits.copy()
        return d

    def to(self, device: Union[str, torch.device], inplace: bool = False) -> "LocalGreedyLibrary":
        """Move entries to `device`; returns a moved copy unless `inplace` is True."""
        if not inplace:
            lib = self.copy()
            return lib.to(device, inplace=True)

        self.device = to_torch_device(device)
        self.entries = self.entries.to(self.device)
        return self

    def cpu(self, inplace: bool = False) -> "LocalGreedyLibrary":
        return self.to("cpu", inplace=inplace)

    def cuda(self, inplace: bool = False) -> "LocalGreedyLibrary":
        return self.to("cuda", inplace=inplace)

    def save_torch(self, f: "torch.serialization.FILE_LIKE", **kwargs):
        torch.save(self._save_data(), f, **kwargs)

    def load_torch(self, f: "torch.serialization.FILE_LIKE", **kwargs):
        data = torch.load(f, **kwargs)
        self._load_data(data)

    @classmethod
    def from_torch(cls, f: "torch.serialization.FILE_LIKE", device: Union[None, str, torch.device] = "cpu") -> "LocalGreedyLibrary":
        lib = cls(n_entries=0, shape=tuple(), device=device)
        lib.load_torch(f)
        return lib

    def _save_data(self) -> dict:
        return {
            "entries": self.entries,
            "hits": self.hits,
        }

    def _load_data(self, data: dict):
        self.entries = data["entries"]
        self.hits = data["hits"]
        self.n_entries = len(self.hits)
        self.shape = tuple(self.entries.shape[1:])
        # must be in-place: the default `to()` returns a moved *copy*
        self.to(self.device, inplace=True)

    def top_entry_index(self) -> Optional[int]:
        """Returns index of entry with most hits (first one on ties), None if empty"""
        if not self.hits:
            return None
        return max(range(len(self.hits)), key=self.hits.__getitem__)

    def entry_ranks(self, reverse: bool = False) -> List[int]:
        """
        Returns a list of ranks for each entry,
        where rank means the index sorted by number of hits.
        """
        entry_ids = sorted(range(self.n_entries), key=lambda i: self.hits[i], reverse=reverse)
        # invert the permutation in O(n) instead of calling list.index per entry
        ranks = [0] * self.n_entries
        for rank, entry_id in enumerate(entry_ids):
            ranks[entry_id] = rank
        return ranks

    def entry_hits(self, reverse: bool = False) -> Dict[int, int]:
        """
        Returns a dict of `entry-index` -> `number-of-hits`.

        Sorted by number of hits.
        """
        entry_ids = list(range(self.n_entries))
        entry_ids.sort(key=lambda i: self.hits[i], reverse=reverse)
        return {
            i: self.hits[i]
            for i in entry_ids
        }

    def sort_entries(
            self,
            by: str = "hits",
            reverse: bool = False,
            inplace: bool = False,
    ):
        """Reorder entries (and hits) by `by`; returns a sorted copy unless `inplace`."""
        if not inplace:
            lib = self.copy()
            return lib.sort_entries(by=by, reverse=reverse, inplace=True)

        sorted_ids = self.sorted_entry_indices(by=by, reverse=reverse)
        self.entries = self.entries[sorted_ids]
        self.hits = [self.hits[sorted_ids[i]] for i in range(len(self.hits))]
        return self

    def sorted_entry_indices(
            self,
            by: str = "hits",
            reverse: bool = False,
    ) -> List[int]:
        """
        Entry indices sorted by "hits" or by a 1-d t-SNE embedding ("tsne").

        :raises ValueError: for an unsupported `by`
        """
        if not self.n_entries:
            return []
        entry_ids = list(range(self.n_entries))
        if self.n_entries < 2:
            return entry_ids

        if by == "hits":
            entry_ids.sort(key=lambda i: self.hits[i], reverse=reverse)

        elif by == "tsne":
            from sklearn.manifold import TSNE
            tsne = TSNE(1, perplexity=min(30, self.n_entries - 1))
            positions = tsne.fit_transform(self.entries.reshape(self.entries.shape[0], -1).cpu().numpy())
            # positions has shape (n_entries, 1); sort by the scalar coordinate
            entry_ids.sort(key=lambda i: positions[i, 0], reverse=reverse)

        else:
            raise ValueError(f"Unsupported sort by '{by}'")

        return entry_ids

    def fit(
            self,
            batch: torch.Tensor,
            lr: float = 1.,
            zero_mean: bool = False,
            skip_top_entries: Union[bool, int] = False,
            grow_if_distance_above: Optional[float] = None,
            max_entries: int = 1000,
    ) -> None:
        """
        Partially fit a batch of patches.

        :param batch: Tensor of N patches of shape matching the library's shape
        :param lr: learning rate, range [0, 1]
        :param zero_mean: True to subtract the mean from each patch in the batch
        :param skip_top_entries: bool or int,
            Do not match the top N entries (1 for True), sorted by number of hits
        :param grow_if_distance_above: if not None, a patch whose best-match
            distance exceeds this value gets a new entry instead
        :param max_entries: no new entries are grown once the library has this many
        """
        batch = batch.to(self.device)

        if zero_mean:
            # mean over all non-batch dims, kept broadcastable
            batch_mean = batch
            for i in range(self.ndim):
                batch_mean = batch_mean.mean(dim=i+1, keepdim=True)
            batch = batch - batch_mean

        # best matches are computed once for the whole batch (before any update)
        best_entry_ids, distances = self.best_entries_for(batch, skip_top_entries=skip_top_entries)

        for i in range(batch.shape[0]):
            entry_id = int(best_entry_ids[i])

            if (
                    grow_if_distance_above is not None
                    and self.n_entries < max_entries
                    and distances[i] > grow_if_distance_above
            ):
                new_entry = torch.randn(1, *self.shape).to(self.device) * 0.001 + self.entries.mean()
                self.entries = torch.concat([self.entries, new_entry])
                self.hits.append(0)
                entry_id = self.n_entries
                self.n_entries += 1

            weight = 1. / (1 + self.hits[entry_id])
            self.entries[entry_id] += lr * weight * (batch[i] - self.entries[entry_id])
            self.hits[entry_id] += 1

    def convolve(
            self,
            x: torch.Tensor,
            stride: Union[int, Iterable[int]] = 1,
            padding: Union[int, Iterable[int]] = 0,
    ) -> torch.Tensor:
        """Convolve `x` with all entries using `F.conv{ndim-1}d`."""
        func = getattr(F, f"conv{self.ndim - 1}d", None)
        if not callable(func):
            raise NotImplementedError(f"{self.ndim - 1}-d convolution not supported")

        return func(x.to(self.device), self.entries, stride=stride, padding=padding)

    def best_entries_for(
            self,
            batch: torch.Tensor,
            skip_top_entries: Union[bool, int] = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns the index of the best matching entry for each patch in the batch.

        :param batch: Tensor of N patches of shape matching the library's shape
        :param skip_top_entries: bool or int,
            Do not match the top N entries (1 for True), sorted by number of hits
        :return: tuple of
            - Tensor of int64: entry indices
            - Tensor of float: L1 distances of each patch to its chosen entry
        """
        assert batch.ndim == len(self.shape) + 1, f"Got {batch.shape}"
        assert batch.shape[1:] == self.shape, f"Got {batch.shape}"
        batch = batch.to(self.device)

        # L1 distance of every patch to every entry -> (n_batch, n_entries)
        dist = (
            (batch.unsqueeze(1) - self.entries.unsqueeze(0))
            .abs()
            .flatten(2)
            .sum(dim=-1)
        )

        if not skip_top_entries:
            indices = torch.argmin(dist, 1)
        else:
            skip_top_entries = int(skip_top_entries)
            entry_ranks = self.entry_ranks(reverse=True)
            best_entries = []
            for candidates in torch.argsort(dist, 1).tolist():
                idx = 0
                # skip entries ranked among the top N, but never past the last candidate
                while idx + 1 < len(candidates) and entry_ranks[candidates[idx]] < skip_top_entries:
                    idx += 1
                best_entries.append(candidates[idx])
            indices = torch.tensor(best_entries, dtype=torch.int64, device=dist.device)

        # dist[row, indices[row]] for every row, on the same device as `dist`
        distances = dist.gather(1, indices.unsqueeze(1)).squeeze(1)
        return indices, distances

    def drop_unused(self, inplace: bool = False) -> "LocalGreedyLibrary":
        """Drop entries that were never hit."""
        return self.drop_entries(hits_lt=1, inplace=inplace)

    def drop_entries(
            self,
            hits_lt: Optional[int] = None,
            inplace: bool = False,
    ) -> "LocalGreedyLibrary":
        """
        Drop all entries with fewer than `hits_lt` hits (strictly less than).

        :param hits_lt: hit threshold, None to drop nothing
        :param inplace: modify this library instead of returning a modified copy
        """
        if not inplace:
            lib = self.copy()
            return lib.drop_entries(
                hits_lt=hits_lt,
                inplace=True,
            )

        if hits_lt is None:
            return self

        keep_ids = [i for i, hits in enumerate(self.hits) if hits >= hits_lt]
        if len(keep_ids) < len(self.hits):
            # index_select keeps dtype, device and entry shape, even if nothing is kept
            index = torch.tensor(keep_ids, dtype=torch.int64, device=self.entries.device)
            self.entries = self.entries.index_select(0, index)
            self.hits = [self.hits[i] for i in keep_ids]
            self.n_entries = len(self.hits)

        return self

    def plot_entries(
            self,
            min_size: int = 300,
            with_hits: bool = True,
            sort_by: Optional[str] = None,
    ) -> "PIL.Image.Image":
        """
        Render all entries as an image grid, normalized to [0, 1] over all entries.

        :param min_size: minimum width of the grid image (nearest-neighbour upscaled)
        :param with_hits: draw a green border whose brightness is hits / max_hits
        :param sort_by: optional ordering passed to `sorted_entry_indices` (descending)
        """
        if len(self.shape) == 1:
            entries = self.entries.reshape(-1, 1, 1, *self.shape)
        elif len(self.shape) == 2:
            entries = self.entries.reshape(-1, 1, *self.shape)
        elif len(self.shape) == 3:
            entries = self.entries
        else:
            raise RuntimeError(f"Can't plot entries with shape {self.shape} (ndim>3)")
        if entries.shape[0]:

            e_min, e_max = entries.min(), entries.max()
            if e_min != e_max:
                entries = (entries - e_min) / (e_max - e_min)

            if with_hits:
                max_hits = max(1, self.max_hits)
                entry_list = []
                for entry, hits in zip(entries, self.hits):
                    if entry.shape[0] == 1:
                        entry = entry.repeat(3, *(1 for _ in entry.shape[1:]))
                    elif entry.shape[0] == 3:
                        pass
                    else:
                        raise ValueError(f"Can't plot entries with {entry.shape[0]} channels")

                    background = torch.Tensor([0, hits / max_hits, 0])
                    background = background.reshape(3, *((1,) * (len(entry.shape) - 1)))
                    background = background.repeat(1, *(s + 2 for s in entry.shape[1:]))
                    background[:, 1:-1, 1:-1] = entry
                    entry_list.append(background)
                entries = entry_list

            if sort_by:
                if not isinstance(entries, list):
                    entries = list(entries)
                entry_ids = self.sorted_entry_indices(by=sort_by, reverse=True)
                entries = [entries[i] for i in entry_ids]

            grid = make_grid(entries, nrow=max(1, int(np.sqrt(self.n_entries))), normalize=False)
            if grid.shape[-1] < min_size:
                grid = VF.resize(
                    grid,
                    [
                        int(grid.shape[-2] * min_size / grid.shape[-1]),
                        min_size,
                    ],
                    VF.InterpolationMode.NEAREST
                )
        else:
            grid = torch.zeros(1, min_size, min_size)
        return VF.to_pil_image(grid.cpu())

# Smoke test on a tiny 1-d library: three scalar entries, two fit passes.
l = LocalGreedyLibrary(3, (1,))
l.entries = torch.Tensor([[1], [2], [3]])

l.fit(torch.Tensor([[0], [1.7], [3.1], [3.2]]))
# each patch updates its nearest entry: 0 -> #0, 1.7 -> #1, 3.1 and 3.2 -> #2
assert l.hits == [1, 1, 2], l.hits

l.fit(torch.Tensor([[0], [1.7], [3.1], [3.2]]), skip_top_entries=1)
# entry #2 has the most hits and is skipped, so 3.1 and 3.2 fall back to #1
assert l.hits == [2, 4, 2], l.hits
assert torch.allclose(l.entries.flatten(), torch.tensor([0., 2.425, 3.15]), atol=1e-5), l.entries

In [None]:
# Check of the row-wise selection used in `best_entries_for`: pick
# dist[row, indices[row]] for every row. `gather` does this directly; the
# flat-index trick needs the row length, taken here from `dist` itself.
indices = torch.tensor([0, 1, 2, 2])
dist = torch.tensor(
       [[1.0000, 2.0000, 3.0000],
        [0.7000, 0.3000, 1.3000],
        [2.1000, 1.1000, 0.1000],
        [2.2000, 1.2000, 0.2000]])

selected = dist.gather(1, indices.unsqueeze(1)).squeeze(1)
flat_selected = dist.flatten()[indices + torch.arange(dist.shape[0]) * dist.shape[1]]
assert torch.equal(selected, flat_selected)
selected