In [3]:
# Change the working directory to a machine-specific checkout of the project.
# NOTE(review): presumably required so that the ``bwb`` imports and the relative
# ``./data`` path used in later cells resolve -- confirm and adapt to your checkout.
%cd ~/codeProjects/pythonProjects/Bayesian-Learning-with-Wasserstein-Barycenters

In [4]:
import abc
import typing as t
from bwb.utils import _DistributionT
import collections as c


class DistributionSampler(abc.ABC, t.Generic[_DistributionT]):
    r"""
    Abstract interface for a distribution whose samples are themselves distributions.

    In other words, it models a measure :math:`\Lambda(dm) \in \mathcal{P}(\mathcal{M})`,
    where :math:`\mathcal{M}` denotes the set of models. Concrete subclasses must
    implement both :meth:`draw` and :meth:`rvs`.
    """

    def __repr__(self) -> str:
        return type(self).__name__ + "()"

    @abc.abstractmethod
    def rvs(self, size=1, *args, **kwargs) -> t.Sequence[_DistributionT]:
        """Return a sequence with ``size`` sampled distributions."""

    @abc.abstractmethod
    def draw(self, *args, **kwargs) -> _DistributionT:
        """Return a single sampled distribution."""


In [5]:
class DiscreteDistributionSampler(DistributionSampler[_DistributionT]):
    r"""
    Base class for samplers over a finite set of models, i.e. :math:`|\mathcal{M}| < +\infty`.

    Because the support is discrete, the distribution can be represented as a vector of
    probabilities, so sampling reduces to drawing an index. This makes it possible to record
    which indices were drawn and how often, in order to get statistics about the sampling
    process.

    Subclasses implement :meth:`_draw`, :meth:`_rvs` and :meth:`get_model`; the public
    :meth:`draw` and :meth:`rvs` wrap them (template-method pattern) and take care of the
    optional bookkeeping.

    :param save_samples: If ``True``, every drawn index is appended to ``samples_history``
        and counted in ``samples_counter``.
    """

    def __init__(self, save_samples: bool = False):
        self.save_samples = save_samples
        # Indices of the drawn models, in drawing order (only filled if ``save_samples``).
        self.samples_history: list[int] = []
        # Number of times each index has been drawn (only filled if ``save_samples``).
        self.samples_counter: c.Counter[int] = c.Counter()
        # Mapping index -> model; this base class never fills it, it is there for subclasses.
        self._models: dict[int, _DistributionT] = {}

    @abc.abstractmethod
    def _draw(self, *args, **kwargs) -> tuple[_DistributionT, int]:
        """Draw one model and return it together with its index (hook used by :meth:`draw`)."""
        ...

    def draw(self, *args, **kwargs) -> _DistributionT:
        """
        Draw a sample.

        All arguments are forwarded to :meth:`_draw`.

        :return: The sampled model.
        """
        to_return, i = self._draw(*args, **kwargs)
        if self.save_samples:  # Register the sample
            self.samples_history.append(i)
            self.samples_counter[i] += 1
        return to_return

    @abc.abstractmethod
    def _rvs(self, size=1, *args, **kwargs) -> tuple[t.Sequence[_DistributionT], list[int]]:
        """Draw ``size`` models and return them together with their indices (hook used by :meth:`rvs`)."""
        ...

    def rvs(self, size=1, *args, **kwargs) -> t.Sequence[_DistributionT]:
        """
        Samples as many distributions as the ``size`` parameter indicates.

        ``size`` and every extra argument are forwarded to :meth:`_rvs`.

        :param size: Number of models to draw.
        :return: The sampled models.
        """
        to_return, list_indices = self._rvs(size, *args, **kwargs)
        if self.save_samples:  # Register the samples
            self.samples_history.extend(list_indices)
            self.samples_counter.update(list_indices)
        return to_return

    @abc.abstractmethod
    def get_model(self, i: int) -> _DistributionT:
        """Get the model with index i."""
        ...

    def __repr__(self) -> str:
        # Always emit the parentheses so the representation has the same ``Name(...)`` shape
        # as the one produced by ``DistributionSampler.__repr__``.
        details = f"samples={len(self.samples_history)}" if self.save_samples else ""
        return f"{self.__class__.__name__}({details})"

In [15]:
from quickdraw_dataset import QuickDraw
import torchvision.transforms as T
from pathlib import Path
from bwb.distributions import DistributionDraw
import torch
from bwb.config import config

# Dataset of "face" drawings as plain tensors: ``ToTensor`` converts each image and
# ``squeeze`` drops its singleton dimensions.
ds = QuickDraw(
    Path("./data"),
    category="face",
    download=True,
    transform=T.Compose([
        T.ToTensor(),
        T.Lambda(lambda x: x.squeeze()),
    ])
)

# Same dataset, but every image is additionally wrapped in a ``DistributionDraw``.
ds_ = QuickDraw(
    Path("./data"),
    category="face",
    download=True,
    transform=T.Compose([
        T.ToTensor(),
        T.Lambda(lambda x: x.squeeze()),
        # The constructor is passed directly; wrapping it in a lambda added nothing.
        T.Lambda(DistributionDraw.from_grayscale_weights),
    ])
)

# Use the first drawing as the "true" model and sample 100 observations from it.
first_face = ds_[0][0]
data = first_face.sample((100,)).reshape(1, -1)

# Only the last expression of a cell is displayed, so show both values as one tuple
# (a bare ``data.shape`` line before ``len(ds)`` was silently discarded).
data.shape, len(ds)

In [14]:
from bwb.utils import _ArrayLike

@t.runtime_checkable
class DiscreteModelsSet(t.Protocol[_DistributionT]):
    """
    Structural type for a finite collection of models.

    Any object providing ``__len__``, ``get`` and ``compute_probability`` satisfies it;
    no inheritance is required.
    """

    def __len__(self) -> int:
        """Return how many models the collection holds."""

    def get(self, i: int, **kwargs) -> _DistributionT:
        """Return the model stored at index ``i``."""

    def compute_probability(self, data: _ArrayLike, **kwargs) -> torch.Tensor:
        """
        Evaluate the probabilities of the data given the models.

        :param data: The data for which the probabilities are computed.
        :return: A tensor with the probabilities.
        """

In [1]:
from time import time
import multiprocessing as mp

# Number of CPUs reported by the system; the worker-count sweep in the next cell
# uses it as its upper bound.
mp.cpu_count()

In [12]:
from torch.utils.data import DataLoader
# Time two full passes over the dataset for every even number of DataLoader workers
# (0 = load in the main process) to find a sensible ``num_workers`` value.
for num_workers in range(0, mp.cpu_count() + 1, 2):
    train_loader = DataLoader(ds, shuffle=True, num_workers=num_workers, batch_size=256, pin_memory=True)
    start = time()
    for _epoch in range(1, 3):
        # The batches are discarded on purpose. They are bound to a throw-away name so that
        # the notebook-level ``data`` tensor (the observations used by later cells) is not
        # overwritten by the last batch of the loader.
        for _batch in train_loader:
            pass
    end = time()
    print("Finish with:{:.4f} second, num_workers={}".format(end - start, num_workers))

In [16]:
from torch.utils.data import DataLoader, Dataset
from bwb.config import config
from bwb.utils import _ArrayLike
import bwb.distributions as dist
import multiprocessing as mp

# Method for the class ``QuickDraw``.
def compute_probability(self, data: _ArrayLike, **kwargs) -> torch.Tensor:
    """
    Compute the probabilities of the models given the data.

    Every image of the dataset is flattened and normalised into a probability vector. The
    likelihood of a model is the product of the probabilities it assigns to the observed
    positions, accumulated in log-space. The likelihoods are finally normalised over all
    models, which is the posterior under a uniform prior.

    :param data: Integer positions in the flattened image, one per observation. The callers
        in this notebook pass a tensor of shape ``(1, n)``, which is broadcast over the batch.
    :param kwargs: ``batch_size`` (default ``1024``) and ``num_workers`` (default
        ``mp.cpu_count()``) configure the internal ``DataLoader``.
    :return: A 1-D tensor with one probability per model, in dataset order.
    """
    dataloader = DataLoader(
        self,
        batch_size=kwargs.get("batch_size", 1024),
        shuffle=False,  # keep dataset order: position k of the result is model k
        num_workers=kwargs.get("num_workers", mp.cpu_count()),
    )

    # ``take_along_dim`` requires int64 indices living on the same device as ``features``.
    indices = torch.as_tensor(data, dtype=torch.long, device=config.device)

    log_likelihoods = []

    for features, _ in dataloader:
        features = features.to(config.device)
        features = features.reshape(features.size(0), -1)
        features = features / features.sum(dim=1, keepdim=True)
        evaluations = torch.take_along_dim(features, indices, 1)
        # Likelihood = prod_i p(x_i) = exp(sum_i log p(x_i)). Taking the log first is what
        # makes the sum meaningful; a zero probability yields -inf, i.e. a zero likelihood.
        log_likelihoods.append(torch.log(evaluations).sum(dim=1))

    log_likelihood_cache = torch.cat(log_likelihoods, dim=0)

    # softmax(l)_k = exp(l_k) / sum_j exp(l_j): the normalised likelihoods, computed without
    # the underflow that exponentiating large negative log-likelihoods would cause.
    probabilities = torch.softmax(log_likelihood_cache, dim=0)

    return probabilities

# Method for the class ``QuickDraw``.
def get(self, i: int, **kwargs) -> dist.DistributionDraw:
    """
    Return the model stored at position ``i`` of the dataset.

    The first element of the item ``self[i]`` is handed to
    ``DistributionDraw.from_grayscale_weights``. Extra keyword arguments are accepted
    but not used.
    """
    item = self[i]
    grayscale_weights = item[0]
    return dist.DistributionDraw.from_grayscale_weights(grayscale_weights)

# Add the methods to the class.
# Monkey-patching ``QuickDraw`` gives the existing ``ds`` instance (and any other one)
# the ``compute_probability`` and ``get`` methods defined above.
QuickDraw.compute_probability = compute_probability
QuickDraw.get = get

# Smoke test: evaluate the probabilities of all the models for the sampled ``data``.
ds.compute_probability(data.reshape(1, -1))

In [16]:
# Fetch the first model through the ``get`` method attached to ``QuickDraw`` above
# (no ``_get`` attribute is defined anywhere in this notebook).
ds.get(0)

In [85]:


# Scratch cell: load (up to) 2**18 items in a single batch and try the indexing step
# used by ``compute_probability`` on it.
dataloader = DataLoader(ds, batch_size=2 ** 18, shuffle=False)

# Only the first batch is materialised.
features, labels = next(iter(dataloader))

# features = features.to("cpu")

# features /= features.sum(dim=1, keepdim=True)

# NOTE(review): unlike ``compute_probability``, ``features`` is neither flattened nor
# normalised before indexing here -- confirm this is intended. The result is also not
# displayed, because it is not the last expression of the cell.
torch.take_along_dim(features, data, 1)

# features.sum(dim=1).shape
# Number of batches the loader yields.
len(dataloader)

In [84]:
# Number of batches yielded by ``dataloader`` (defined in the previous cell).
len(dataloader)

In [36]:
from bwb.utils import _ArrayLike
import torch
from bwb.config import config
import bwb.distributions as dist


def _set_generator(seed=None, device="cpu") -> torch.Generator:
    gen = torch.Generator(device=device)
    if seed is None:
        gen.seed()
        return gen
    gen.manual_seed(seed)
    return gen



class ExplicitPosteriorSampler(DiscreteDistributionSampler[dist.DistributionDraw]):
    """
    Posterior sampler over a finite set of models whose probabilities are computed explicitly.

    :meth:`fit` asks the model set for the probability of every model given the data.
    Afterwards :meth:`draw` / :meth:`rvs` sample model indices from that probability vector
    with ``torch.multinomial`` and return the corresponding models, which are fetched lazily
    and cached in ``_models``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._fitted = False
        self.total_time = 0.0

    def fit(
            self,
            data: _ArrayLike,
            models: DiscreteModelsSet[dist.DistributionDraw],
            batch_size: int = 256,
        ):
        """
        Fit the posterior distribution.

        :param data: The data to fit the posterior.
        :param models: The models to fit the posterior.
        :param batch_size: The batch size to compute the probabilities.
        :return: The fitted posterior.
        """
        assert isinstance(models, DiscreteModelsSet), "The models must be a DiscreteModelsSet."

        self.data_: torch.Tensor = torch.as_tensor(data, device=config.device)
        self.models_: DiscreteModelsSet[dist.DistributionDraw] = models
        self.models_index_: torch.Tensor = torch.arange(len(models), device=config.device)
        # Forget models cached by a previous fit: with a new model set the same index may
        # refer to a different model.
        self._models = {}

        data = self.data_.reshape(1, -1)

        # Call the method declared by the ``DiscreteModelsSet`` protocol; the isinstance
        # check above guarantees it exists (an underscore-prefixed variant is not part of
        # the protocol).
        self.probabilities_: torch.Tensor = models.compute_probability(data, batch_size=batch_size)

        # Indices of the models with non-zero posterior probability.
        self.support_ = self.models_index_[self.probabilities_ > 0]

        self._fitted = True

        return self

    def get_model(self, i: int) -> dist.DistributionDraw:
        """Get the model with index i, fetching it from the model set on first access."""
        if i not in self._models:
            self._models[i] = self.models_.get(i)
        return self._models[i]

    def _draw(self, seed=None, *args, **kwargs) -> tuple[dist.DistributionDraw, int]:
        """
        Draw one model index from the fitted probabilities.

        :param seed: Seed for the random generator; ``None`` uses a non-deterministic seed.
        :return: The drawn model and its index.
        """
        rng: torch.Generator = _set_generator(seed=seed, device=config.device)

        i = torch.multinomial(input=self.probabilities_, num_samples=1, generator=rng).item()
        i = int(i)
        return self.get_model(i), i

    def _rvs(self, size=1, seed=None, *args, **kwargs) -> tuple[t.Sequence[dist.DistributionDraw], list[int]]:
        """
        Draw ``size`` model indices (with replacement) from the fitted probabilities.

        :param size: Number of models to draw.
        :param seed: Seed for the random generator; ``None`` uses a non-deterministic seed.
        :return: The drawn models and their indices.
        """
        rng: torch.Generator = _set_generator(seed=seed, device=config.device)

        indices = torch.multinomial(input=self.probabilities_, num_samples=size, replacement=True, generator=rng)
        indices = indices.tolist()
        return [self.get_model(i) for i in indices], indices

    def __repr__(self) -> str:
        to_return = self.__class__.__name__

        if not self._fitted:
            to_return += "()"
            return to_return

        to_return += "("
        # ``fit`` treats the data as a flat collection of observations (it reshapes it to
        # ``(1, -1)``), so count elements instead of the length of the first dimension.
        to_return += f"n_data={self.data_.numel()}, "
        to_return += f"n_models={len(self.models_)}, "
        to_return += f"n_support={len(self.support_)}"
        to_return += ")"

        return to_return
        


# ``save_samples=True`` makes the sampler record every drawn index (history and counter).
posterior = ExplicitPosteriorSampler(save_samples=True)

# Fit on the observations in ``data`` using the QuickDraw dataset as the set of models.
# ``fit`` returns the sampler itself, so the cell displays its ``repr``.
posterior.fit(data, ds)

In [39]:
# Draw a single model from the fitted posterior; the drawn index is recorded as well,
# since the sampler was created with ``save_samples=True``.
posterior.draw()

In [40]:
# Inspect the cache of models fetched so far (model index -> model).
posterior._models