# Queue

In [2]:
import os
import sys

sys.path.append("../")

In [None]:
import numpy as np
from sklearn.datasets import make_blobs, make_moons, make_circles
import seaborn as sns
from matplotlib import pyplot as plt
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Function

from src.ml.sinkhorn import pot_sinkhorn

In [None]:
sns.set(style="whitegrid")

## Dataset generation

In [4]:
n_features = 512
n_clusters = 128
batch_size = 64
n_samples = 2000

In [5]:
class ToyDataset(Dataset):
    def __init__(self, n_features, n_clusters, n_samples):
        self.n_features = n_features
        self.n_clusters = n_clusters
        self.n_samples = n_samples
    
    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return self.X[idx], self.y_true[idx]
    
    def plot(self):
        fig, ax = plt.subplots(
            nrows=1,
            ncols=1,
            figsize=(5, 5)
        )

        sns.scatterplot(  # plot first 2 components
            x=self.X[:, 0],
            y=self.X[:, 1],
            hue=map(str, self.y_true),
            ax=ax,
            legend=False
        )

        ax.set_xlabel("Component 1")
        ax.set_ylabel("Component 2")
        ax.set_title("Clusters visualization")

        return fig


class BlobsDataset(ToyDataset):
    """
    https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_blobs.html
    """
    def __init__(self, n_features, n_clusters, n_samples):
        super().__init__(n_features, n_clusters, n_samples)
        
        X, y_true = make_blobs(
            n_samples=n_samples,
            n_features=n_features,
            centers=n_clusters,
            cluster_std=.8,
            random_state=0
        )
        
        self.X = torch.FloatTensor(X)
        self.y_true = torch.LongTensor(y_true)

In [6]:
dataset = BlobsDataset(n_features, n_clusters, n_samples)
dataloader = DataLoader(dataset, batch_size=batch_size)

## Model definition

In [7]:
class Model(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim),
        )
    
    def forward(self, inputs):
        mlp_out = self.mlp(inputs)
        softmax_out = nn.LogSoftmax(dim=1)(mlp_out)
        
        return softmax_out
    
model = Model(
    input_dim=n_features,
    output_dim=n_clusters
)

## Queue implementation

- [x] stored_M (= previous batchs dans la queue)
- [x] M = current batch
- [x] tu rajoutes stored_M comme argument supplementaire à SinkhornValueFunc
- [x] sinkhorn est execute sur M_full = concatenation de M et stored_M
- [ ] puis dans SinkhornValue tu implemente la logique de storer les M quand tu call la fonction, avec une logique de queue FIFO 
- [x] et pas besoin d'utiliser une queue pour cela, juste un tensor me semble suffisant, juste quand tu arrives à la limite de taille, pour inserer le nouveau batch, tu "roll" le tensor


In [1]:
stored_M = torch.Tensor()   # tensor acts as queue
max_n_batches_in_queue = 4  # max number of batches in queue

for batch_ix, (inputs, labels) in enumerate(dataloader):
    # M is model output
    M = model(inputs)
    
    #################
    # Sinkhorn step #
    #################
    
    M_concat = torch.cat([M, stored_M])
    
    # Compute marginals
    a = torch.ones(M_concat.shape[0])
    b = torch.ones(M_concat.shape[0] / M.shape[1]) / M.shape[1]
    
    # Compute sinkhorn
    P = pot_sinkhorn(M_concat, a, b, epsilon=0.1)

    ################ 
    # Update queue #
    ################
    
    # Update stored M
    n_batches_in_queue = stored_M.shape[0] / batch_size
    if  n_batches_in_queue < max_n_batches_in_queue:
        # Append current batch to previous batches
        stored_M = M_concat
    else:
        # Roll stored M, older batch comes first, replace it with M
        stored_M = torch.roll(stored_M, batch_size, 0)  # roll, first batch is the oldest
        stored_M[:batch_size, :] = M                    # update first batch with new one
    
    # Print for debug
    print("Batch {0}: {1}".format(batch_ix, stored_M.shape))
    
    if batch_ix == 10:
        break

NameError: name 'torch' is not defined

## Queue integration

In [86]:
class SinkhornValueFunc(Function):
    @staticmethod
    def forward(ctx, M, stored_M, a, b, epsilon, solver, solver_options):
        # Run Sinkhorn
        P = solver(
            torch.cat([M, stored_M]),  # Use the queue
            a,
            b,
            epsilon,
            **solver_options
        )
        P = P[:M.shape[0], :]  # Take only current batch

        ctx.save_for_backward(P)
        return (P*M).sum()

    @staticmethod
    def backward(ctx, grad_output):
        P, = ctx.saved_tensors
        grad_M = P * grad_output

        return grad_M, None, None, None, None, None, None


class SinkhornValue(nn.Module):
    """Sinkhorn value.

    Returns optimal value for the regularized OT problem:
        L(M) = max <M, P> + \epsilon H[P] s.t. \sum_j P_ij = a_i and \sum_i P_ij = b_j
    with entropy H[P] = - \sum_ij P_ij [log(P_ij) - 1]

    Args:
        epsilon (float): regularization parameter
        solver (function): OT solver
        solver_kwargs (int): options to pass to the solver
    """
    def __init__(self, epsilon, max_n_batches_in_queue, solver, **solver_options):
        super().__init__()
        # Sinkhorn params
        self.epsilon = epsilon
        self.solver = solver
        self.solver_options = solver_options
        
        # Queue params
        self.stored_M = torch.Tensor()                        # tensor acts as queue
        self.max_n_batches_in_queue = max_n_batches_in_queue  # max number of batches in queue

    def forward(self, M):
        batch_size = M.shape[0]
        
        #################
        # Sinkhorn step #
        #################
        # Compute marginals
        M_concat = torch.cat([M, self.stored_M])
        a = torch.ones(M_concat.shape[0])
        b = torch.ones(M.shape[1]) / (M.shape[0] / M.shape[1])

        # Compute sinkhorn
        loss = SinkhornValueFunc.apply(
            M,
            self.stored_M,
            a,
            b,
            self.epsilon,
            self.solver,
            self.solver_options
        )
        
        ################
        # Update queue #
        ################
        n_batches_in_queue = self.stored_M.shape[0] / batch_size
        if n_batches_in_queue < self.max_n_batches_in_queue:
            # Append current batch to previous batches
            self.stored_M = M_concat
        else:
            # Roll stored M, older batch comes first, replace it with M
            self.stored_M = torch.roll(self.stored_M, batch_size, 0)  # roll, first batch is the oldest
            self.stored_M[:batch_size, :] = M                         # update first batch with new one
        
        return loss

    def extra_repr(self):
        return (
            f"epsilon={self.epsilon:.2e}, solver={self.solver}"
            "solver_options={self.solver_options}"
        )

In [87]:
SV = SinkhornValue(
    epsilon=0.1,
    solver=pot_sinkhorn,
    max_n_batches_in_queue=4
)

for batch_ix, (inputs, labels) in enumerate(dataloader):
    M = model(inputs)
    
    loss = SV(-M)
    print(SV.stored_M.shape, loss)
    
    if batch_ix == 10:
        break

torch.Size([64, 128]) tensor(4.2025, grad_fn=<SinkhornValueFuncBackward>)
torch.Size([128, 128]) tensor(2.0841, grad_fn=<SinkhornValueFuncBackward>)
torch.Size([192, 128]) tensor(1.3816, grad_fn=<SinkhornValueFuncBackward>)
torch.Size([256, 128]) tensor(1.0406, grad_fn=<SinkhornValueFuncBackward>)
torch.Size([256, 128]) tensor(0.8220, grad_fn=<SinkhornValueFuncBackward>)
torch.Size([256, 128]) tensor(0.8202, grad_fn=<SinkhornValueFuncBackward>)
torch.Size([256, 128]) tensor(0.8283, grad_fn=<SinkhornValueFuncBackward>)
torch.Size([256, 128]) tensor(0.8239, grad_fn=<SinkhornValueFuncBackward>)
torch.Size([256, 128]) tensor(0.8215, grad_fn=<SinkhornValueFuncBackward>)
torch.Size([256, 128]) tensor(0.8300, grad_fn=<SinkhornValueFuncBackward>)
torch.Size([256, 128]) tensor(0.8281, grad_fn=<SinkhornValueFuncBackward>)
