# Alternative representation of MultitaskMultivariateNormal

## Issue:
Currently, the interleaved/non-interleaved representation of MTMVN creates a bunch of headaches. It would be nice to have a higher-level API that abstracts away from these details. 


## Suggestion:
Represent the covariance matrix as an "unrolled" tensor (e.g. `n x n x m x m`), instead of a `nm x nm` covariance matrix. That way things like scalarizations are easily done, and reshaping/viewing/getting items can be done more straightforwardly. 


## Challenge:
In order to sample from this MVN, we need to compute the full `nm x nm` covariance matrix (to either compute the Cholesky decomposition or apply iterative approximate root decomposition methods). We need to make sure this is transparent, fast, and happens without much overhead.

In [1]:
from __future__ import annotations

import torch

In [2]:
n = 3  # number of points
m = 2  # number of outputs

# create some full covar

def make_rand_covar(k, batch_shape=torch.Size()):
    """Draw a random symmetric ``k x k`` matrix (optionally batched) to use as a covariance.

    The result is ``R @ R^T + D``, where ``R`` is a uniform-random
    ``batch_shape x k x k`` tensor and ``D`` is a diagonal matrix with
    uniform-random entries. A single ``D`` is drawn and broadcast across
    all batch elements.

    Args:
        k: Size of the (square) matrix.
        batch_shape: Leading batch dimensions of the result.

    Returns:
        A tensor of shape ``batch_shape x k x k``.
    """
    root = torch.rand(*batch_shape, k, k)
    gram = torch.matmul(root, root.transpose(-2, -1))
    diag_term = torch.diag_embed(torch.rand(k))
    return gram + diag_term
    
A = make_rand_covar(n * m)

### interleaved: block matrix where each block is an intra-point, cross-task covariance

In [3]:
# Unrolled `n x n x m x m` tensor: entry [p, q, s, t] holds the covariance
# between output s at point p and output t at point q.
C_inter = torch.zeros(n, n, m, m)

# Deliberately naive element-by-element fill (super inefficient), but it makes
# the index mapping explicit: row/col `point * m + output` of the interleaved
# full matrix A splits into (point, output) via divmod.
for row in range(n * m):
    pt_r, out_r = divmod(row, m)
    for col in range(n * m):
        pt_c, out_c = divmod(col, m)
        C_inter[pt_r, pt_c, out_r, out_c] = A[row, col]

In [4]:
A

tensor([[1.6730, 1.4068, 1.4398, 1.0731, 1.8534, 0.9899],
        [1.4068, 3.3130, 2.0636, 1.8746, 3.3943, 2.4631],
        [1.4398, 2.0636, 2.5348, 1.4629, 2.6626, 1.8031],
        [1.0731, 1.8746, 1.4629, 2.2579, 2.3630, 1.6523],
        [1.8534, 3.3943, 2.6626, 2.3630, 4.4928, 3.0654],
        [0.9899, 2.4631, 1.8031, 1.6523, 3.0654, 3.7077]])

In [5]:
# we can construct the full matrix as follows:
C_inter.permute(0, 2, 1, 3).reshape(m*n, m*n)

tensor([[1.6730, 1.4068, 1.4398, 1.0731, 1.8534, 0.9899],
        [1.4068, 3.3130, 2.0636, 1.8746, 3.3943, 2.4631],
        [1.4398, 2.0636, 2.5348, 1.4629, 2.6626, 1.8031],
        [1.0731, 1.8746, 1.4629, 2.2579, 2.3630, 1.6523],
        [1.8534, 3.3943, 2.6626, 2.3630, 4.4928, 3.0654],
        [0.9899, 2.4631, 1.8031, 1.6523, 3.0654, 3.7077]])

In [6]:
# batched version: stack two full matrices along a leading batch dimension
b = 2
A_b = torch.stack([A, A + 1])
C_inter_b = torch.zeros(b, n, n, m, m)

# Deliberately naive element-by-element fill (super inefficient), but it makes
# clear what we want: each row/col index `point * m + output` of the full
# matrix is split into its (point, output) pair.
for batch_idx in range(b):
    for row in range(n * m):
        pt_r, out_r = divmod(row, m)
        for col in range(n * m):
            pt_c, out_c = divmod(col, m)
            C_inter_b[batch_idx, pt_r, pt_c, out_r, out_c] = A_b[batch_idx, row, col]

# reorder (b, n, n, m, m) -> (b, n, m, n, m) and flatten to recover the full matrices
C_inter_b.permute(0, 1, 3, 2, 4).reshape(b, m*n, m*n)

tensor([[[1.6730, 1.4068, 1.4398, 1.0731, 1.8534, 0.9899],
         [1.4068, 3.3130, 2.0636, 1.8746, 3.3943, 2.4631],
         [1.4398, 2.0636, 2.5348, 1.4629, 2.6626, 1.8031],
         [1.0731, 1.8746, 1.4629, 2.2579, 2.3630, 1.6523],
         [1.8534, 3.3943, 2.6626, 2.3630, 4.4928, 3.0654],
         [0.9899, 2.4631, 1.8031, 1.6523, 3.0654, 3.7077]],

        [[2.6730, 2.4068, 2.4398, 2.0731, 2.8534, 1.9899],
         [2.4068, 4.3130, 3.0636, 2.8746, 4.3943, 3.4631],
         [2.4398, 3.0636, 3.5348, 2.4629, 3.6626, 2.8031],
         [2.0731, 2.8746, 2.4629, 3.2579, 3.3630, 2.6523],
         [2.8534, 4.3943, 3.6626, 3.3630, 5.4928, 4.0654],
         [1.9899, 3.4631, 2.8031, 2.6523, 4.0654, 4.7077]]])

In [7]:
# The general formulation for an arbitrary batch shape: swapping the two middle
# dims of the trailing four, (n, n, m, m) -> (n, m, n, m), is the same
# reordering as permute(<batch dims>, -4, -2, -3, -1); flattening each (n, m)
# pair then yields the full matrix.
batch_shape = C_inter_b.shape[:-4]
interleaved = C_inter_b.transpose(-3, -2)
interleaved.reshape(*batch_shape, m*n, m*n)

tensor([[[1.6730, 1.4068, 1.4398, 1.0731, 1.8534, 0.9899],
         [1.4068, 3.3130, 2.0636, 1.8746, 3.3943, 2.4631],
         [1.4398, 2.0636, 2.5348, 1.4629, 2.6626, 1.8031],
         [1.0731, 1.8746, 1.4629, 2.2579, 2.3630, 1.6523],
         [1.8534, 3.3943, 2.6626, 2.3630, 4.4928, 3.0654],
         [0.9899, 2.4631, 1.8031, 1.6523, 3.0654, 3.7077]],

        [[2.6730, 2.4068, 2.4398, 2.0731, 2.8534, 1.9899],
         [2.4068, 4.3130, 3.0636, 2.8746, 4.3943, 3.4631],
         [2.4398, 3.0636, 3.5348, 2.4629, 3.6626, 2.8031],
         [2.0731, 2.8746, 2.4629, 3.2579, 3.3630, 2.6523],
         [2.8534, 4.3943, 3.6626, 3.3630, 5.4928, 4.0654],
         [1.9899, 3.4631, 2.8031, 2.6523, 4.0654, 4.7077]]])

In [8]:
# How do we go back, i.e. construct C_inter efficiently from the full matrix?
# Just do the same thing in reverse: un-flatten both matrix dims into (n, m),
# then swap the two middle dims to arrive at the `n x n x m x m` layout.

batch_shape = A_b.shape[:-2]
unrolled = A_b.reshape(*batch_shape, n, m, n, m)
C_inter_b_recov = unrolled.transpose(-3, -2)

In [9]:
torch.allclose(C_inter_b, C_inter_b_recov)

True

### non-interleaved: block matrix where each block is an intra-task, cross-point covariance

In [10]:
# Unrolled `m x m x n x n` tensor: entry [s, t, p, q] holds the covariance
# between output s at point p and output t at point q.
C_noninter = torch.zeros(m, m, n, n)

# Deliberately naive element-by-element fill (super inefficient), but it makes
# the index mapping explicit: here A is read as non-interleaved, so row/col
# `output * n + point` splits into (output, point) via divmod.
for row in range(m * n):
    out_r, pt_r = divmod(row, n)
    for col in range(m * n):
        out_c, pt_c = divmod(col, n)
        C_noninter[out_r, out_c, pt_r, pt_c] = A[row, col]

In [11]:
A

tensor([[1.6730, 1.4068, 1.4398, 1.0731, 1.8534, 0.9899],
        [1.4068, 3.3130, 2.0636, 1.8746, 3.3943, 2.4631],
        [1.4398, 2.0636, 2.5348, 1.4629, 2.6626, 1.8031],
        [1.0731, 1.8746, 1.4629, 2.2579, 2.3630, 1.6523],
        [1.8534, 3.3943, 2.6626, 2.3630, 4.4928, 3.0654],
        [0.9899, 2.4631, 1.8031, 1.6523, 3.0654, 3.7077]])

In [12]:
# again we can construct the matrix as follows:
C_noninter.permute(0, 2, 1, 3).reshape(m*n, m*n)

tensor([[1.6730, 1.4068, 1.4398, 1.0731, 1.8534, 0.9899],
        [1.4068, 3.3130, 2.0636, 1.8746, 3.3943, 2.4631],
        [1.4398, 2.0636, 2.5348, 1.4629, 2.6626, 1.8031],
        [1.0731, 1.8746, 1.4629, 2.2579, 2.3630, 1.6523],
        [1.8534, 3.3943, 2.6626, 2.3630, 4.4928, 3.0654],
        [0.9899, 2.4631, 1.8031, 1.6523, 3.0654, 3.7077]])

In [13]:
# batched version of the non-interleaved unrolled tensor
C_noninter_b = torch.zeros(b, m, m, n, n)

# Deliberately naive element-by-element fill (super inefficient), but it makes
# clear what we want: each row/col index `output * n + point` of the full
# matrix is split into its (output, point) pair.
for batch_idx in range(b):
    for row in range(m * n):
        out_r, pt_r = divmod(row, n)
        for col in range(m * n):
            out_c, pt_c = divmod(col, n)
            C_noninter_b[batch_idx, out_r, out_c, pt_r, pt_c] = A_b[batch_idx, row, col]

# sanity check: reordering to (b, m, n, m, n) and flattening recovers A_b
full_noninter_b = C_noninter_b.permute(0, 1, 3, 2, 4).reshape(b, m*n, m*n)
torch.allclose(full_noninter_b, A_b)

True

In [14]:
# And again we can go back the same way: un-flatten both matrix dims into
# (m, n), then swap the two middle dims to get the `m x m x n x n` layout
# (the same reordering as permute(<batch dims>, -4, -2, -3, -1)).

batch_shape = A_b.shape[:-2]
unrolled_noninter = A_b.reshape(*batch_shape, m, n, m, n)
C_noninter_b_recov = unrolled_noninter.transpose(-3, -2)
torch.allclose(C_noninter_b, C_noninter_b_recov)

True

### alternate mixed-interleaved: block matrix where each block is an intra-point, cross-task covariance

This seems to be the most useful internal representation, as this means we don't have to do any permuting

In [15]:
# Alternate representation, `n x m x n x m`: entry [p, s, q, t] holds the
# covariance between output s at point p and output t at point q.
C_alt = torch.zeros(n, m, n, m)

# Deliberately naive element-by-element fill (obviously super inefficient),
# but it makes clear what we want: row/col `point * m + output` of the full
# matrix A splits into (point, output), stored in that same order.
for row in range(n * m):
    pt_r, out_r = divmod(row, m)
    for col in range(n * m):
        pt_c, out_c = divmod(col, m)
        C_alt[pt_r, out_r, pt_c, out_c] = A[row, col]

In [16]:
A

tensor([[1.6730, 1.4068, 1.4398, 1.0731, 1.8534, 0.9899],
        [1.4068, 3.3130, 2.0636, 1.8746, 3.3943, 2.4631],
        [1.4398, 2.0636, 2.5348, 1.4629, 2.6626, 1.8031],
        [1.0731, 1.8746, 1.4629, 2.2579, 2.3630, 1.6523],
        [1.8534, 3.3943, 2.6626, 2.3630, 4.4928, 3.0654],
        [0.9899, 2.4631, 1.8031, 1.6523, 3.0654, 3.7077]])

In [17]:
# this is super straightforward
C_alt.view(n*m, n*m)

tensor([[1.6730, 1.4068, 1.4398, 1.0731, 1.8534, 0.9899],
        [1.4068, 3.3130, 2.0636, 1.8746, 3.3943, 2.4631],
        [1.4398, 2.0636, 2.5348, 1.4629, 2.6626, 1.8031],
        [1.0731, 1.8746, 1.4629, 2.2579, 2.3630, 1.6523],
        [1.8534, 3.3943, 2.6626, 2.3630, 4.4928, 3.0654],
        [0.9899, 2.4631, 1.8031, 1.6523, 3.0654, 3.7077]])

In [18]:
# alternate representation, batched: `b x n x m x n x m`
C_alt_b = torch.zeros(b, n, m, n, m)

# Deliberately naive element-by-element fill (obviously super inefficient),
# but it makes clear what we want: each row/col index `point * m + output`
# of the full matrix is split into its (point, output) pair.
for batch_idx in range(b):
    for row in range(n * m):
        pt_r, out_r = divmod(row, m)
        for col in range(n * m):
            pt_c, out_c = divmod(col, m)
            C_alt_b[batch_idx, pt_r, out_r, pt_c, out_c] = A_b[batch_idx, row, col]

# sanity check: a plain view (no permute) flattens back to the full matrices
batch_shape = C_alt_b.shape[:-4]
full_alt_b = C_alt_b.view(*batch_shape, n*m, n*m)
torch.allclose(full_alt_b, A_b)

True

In [19]:
# ...and going back from the full matrices to the `n x m x n x m` layout

# everything but the trailing two (matrix) dims is treated as batch shape
batch_shape = A_b.shape[:-2]
C_alt_b_recov = A_b.reshape(*batch_shape, n, m, n, m)  # unlike the other layouts, no .permute(*range(len(batch_shape)), -4, -2, -3, -1) is needed
torch.allclose(C_alt_b, C_alt_b_recov)

True

## Question: What if the tensor memory layout is different? Do we need to handle that?

### Scalarizing

This representation makes scalarizing across the outputs trivial

In [20]:
# One random weight per output; scalarize by contracting both output dims of
# C_alt (`n x m x n x m`) against the weights.
weights = torch.rand(m)
contracted_last = C_alt @ weights  # n x m x n: trailing output dim summed out
contracted_last.transpose(-1, -2) @ weights  # n x n: remaining output dim summed out

tensor([[2.1402, 1.9346, 2.5527],
        [1.9346, 2.9159, 3.2304],
        [2.5527, 3.2304, 5.3240]])

In [21]:
((C_alt[0, :, 0, :] @ weights) * weights).sum()

tensor(2.1402)

In [22]:
((C_alt[1, :, 2, :] @ weights) * weights).sum()

tensor(3.2304)

### Indexing outputs is also easy

(it is straightforward to do this with more than one element of the outputs)

In [23]:
C_alt[:, 1, :, 1]

tensor([[3.3130, 1.8746, 2.4631],
        [1.8746, 2.2579, 1.6523],
        [2.4631, 1.6523, 3.7077]])

In fact, we can have MTMVNs themselves be indexable - if we have the external API assume that the shape is `batch_shape x n x m`, then doing `mtmvn[:4, :]` would access all outputs of the first four data points. Similarly, `mtmvn[0, ..., 2:]` would extract the MTMVN across all data points of the first batch element for all but the first two outputs.

In [24]:
from gpytorch.distributions import MultivariateNormal, MultitaskMultivariateNormal

# Batched single-output example: a batch of 2 MVNs over 4 points each.
mean = torch.rand(2, 4)  # 2 x 4 mean
covar = make_rand_covar(4, batch_shape=(2,))  # 2 x 4 x 4 covariance
mvn = MultivariateNormal(mean, covar)

In [25]:
mvn[-1].covariance_matrix

tensor([[3.2598, 1.2003, 1.2292, 0.9164],
        [1.2003, 1.7207, 0.9716, 0.7842],
        [1.2292, 0.9716, 1.9955, 0.6534],
        [0.9164, 0.7842, 0.6534, 1.0543]])

In [26]:
covar[-1]

tensor([[3.2598, 1.2003, 1.2292, 0.9164],
        [1.2003, 1.7207, 0.9716, 0.7842],
        [1.2292, 0.9716, 1.9955, 0.6534],
        [0.9164, 0.7842, 0.6534, 1.0543]])

In [27]:
mvn[:, -2:].covariance_matrix

tensor([[[2.2289, 1.1750],
         [1.1750, 2.3230]],

        [[1.9955, 0.6534],
         [0.6534, 1.0543]]])

In [28]:
covar[:, -2:, -2:]

tensor([[[2.2289, 1.1750],
         [1.1750, 2.3230]],

        [[1.9955, 0.6534],
         [0.6534, 1.0543]]])

In [29]:
# Existing multitask distribution for comparison.
mtmean = torch.rand(4, 2)  # 4 x 2 mean; presumably (points x tasks), matching the `n x m` convention used here
mtcovar = make_rand_covar(8)  # 8 x 8 = (4 * 2) x (4 * 2) full covariance
mtmvn = MultitaskMultivariateNormal(mtmean, mtcovar)

In [30]:
mtmvn.mean

tensor([[0.4372, 0.5764],
        [0.4178, 0.2967],
        [0.3096, 0.2036],
        [0.1308, 0.0618]])

In [31]:
# This fails - we want to enable this easily with the new representation
# mtmvn[0]

## Representing independent outputs

Here we can just store the `m` individual `n x n` blocks as an `m x n x n` tensor; there is no need to store the cross-covariances.

We could also think of the case where we have independence across points, and only inter-task correlation. This will typically not be the case though, so we can punt on this for now.

In [32]:
from gpytorch.lazy import BlockDiagLazyTensor

# One independent random `n x n` covariance per output, stacked into `m x n x n`.
C_indep = torch.stack([make_rand_covar(n) for _ in range(m)])

# using the lazy tensor here will speed up matrix operations;
# .evaluate() is only called to display the dense block-diagonal matrix
BlockDiagLazyTensor(C_indep).evaluate()

tensor([[1.2450, 0.7732, 0.4257, 0.0000, 0.0000, 0.0000],
        [0.7732, 2.5714, 1.7067, 0.0000, 0.0000, 0.0000],
        [0.4257, 1.7067, 2.0056, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 2.3348, 0.7860, 1.4855],
        [0.0000, 0.0000, 0.0000, 0.7860, 1.1482, 1.0489],
        [0.0000, 0.0000, 0.0000, 1.4855, 1.0489, 1.8948]])

In [33]:
C_indep

tensor([[[1.2450, 0.7732, 0.4257],
         [0.7732, 2.5714, 1.7067],
         [0.4257, 1.7067, 2.0056]],

        [[2.3348, 0.7860, 1.4855],
         [0.7860, 1.1482, 1.0489],
         [1.4855, 1.0489, 1.8948]]])

In [34]:
# to construct the full matrix (we generally don't want to do this!) we can just do
torch.block_diag(*C_indep)

tensor([[1.2450, 0.7732, 0.4257, 0.0000, 0.0000, 0.0000],
        [0.7732, 2.5714, 1.7067, 0.0000, 0.0000, 0.0000],
        [0.4257, 1.7067, 2.0056, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 2.3348, 0.7860, 1.4855],
        [0.0000, 0.0000, 0.0000, 0.7860, 1.1482, 1.0489],
        [0.0000, 0.0000, 0.0000, 1.4855, 1.0489, 1.8948]])

# We can also allow arbitrary order in the construction...

### ... so long as we have a consistent internal representation

The following orders are admissible (`n` is the number of points, `t` is the number of tasks):

```
nntt
ntnt
nttn
ttnn
tntn
tnnt

nnt
ntt
ntn
ttn
tnt
tnn
```

So basically we have the set 
```
nnt
ntt
ntn
ttn
tnt
tnn
```

and then the whole set of admissible combinations is what we get by combining this with the set we get by post-pending the single missing letter (we could pre-pend too, but that would just generate duplicates).

Now we know this is admissible. We'd like to store things consistently internally. There are four options:

- no independence -> can standardize to `n x t x n x t`
- cross-task independence only -> can standardize to `t x n x n`
- cross-point independence only -> can standardize to `n x t x t`
- cross-task AND cross-point independence: This is trivial -> can standardize to `n x t` (just marginal variances)

**For now we focus on the first two cases**, we can deal with the other ones later.

In [35]:
from gpytorch.distributions.multioutput_multivariate_normal import MultioutputMultivariateNormal, NormalizedOrder

In [36]:
# Batched mean of shape `b x n x m`, paired with the `b x n x m x n x m` covariance C_alt_b.
mean_b = torch.randn(b, n, m)
# `order` names the trailing covariance dims; presumably "n" = points and "m" = outputs — confirm against the class
mtmvn = MultioutputMultivariateNormal(mean=mean_b, covariance=C_alt_b, order=("n" ,"m", "n", "m"))

In [37]:
mtmvn

MultioutputMultivariateNormal(batch_shape=(2,), n=3, m=2)

In [38]:
mtmvn.mean.shape

torch.Size([2, 3, 2])

In [39]:
mtmvn.variance.shape

torch.Size([2, 3, 2])

In [40]:
mtmvn.rsample(torch.Size([4])).shape

torch.Size([4, 2, 3, 2])

In [41]:
mtmvn.log_prob(mtmvn.mean + torch.randn_like(mtmvn.mean))

tensor([-10.2001,  -8.4040])

#### Test subsetting

In [42]:
# select batch <- this should probably return an un-batched MOMVN
mtmvn[0]

MultioutputMultivariateNormal(n=3, m=2)

In [43]:
# select range of batches
mtmvn[:]

MultioutputMultivariateNormal(batch_shape=(2,), n=3, m=2)

In [44]:
# select data point <- should this return a MOMVN or just a batched MVN? batched MVN seems most logical
mtmvn[..., 0, :]

MultivariateNormal(batch_shape=(2,), n=2)

In [45]:
# select range of data points
mtmvn[..., :2, :]

MultioutputMultivariateNormal(batch_shape=(2,), n=2, m=2)

In [46]:
# select output <- this should definitely return just a MVN
mtmvn[..., 0]

MultivariateNormal(batch_shape=(2,), n=3)

In [47]:
# select range of outputs
mtmvn[..., :2]

MultioutputMultivariateNormal(batch_shape=(2,), n=3, m=2)

In [48]:
# mixed indexing: select a specific data point and output from a specific batch
# ^this is just a single Normal - how to interpret?
mtmvn[0, 1, 0]

MultivariateNormal(n=1)

### Independent tasks

In [49]:
# Batched independent-output covariance: stack two `m x n x n` tensors into `b x m x n x n`.
C_indep_b = torch.stack([C_indep, C_indep + 1])

# a three-element `order` is used here, since only the per-output `n x n` blocks are stored
mtmvn_indep = MultioutputMultivariateNormal(mean=mean_b, covariance=C_indep_b, order=("m", "n", "n"))

In [50]:
C_indep_b.shape  # batch_shape x m x n x n

torch.Size([2, 2, 3, 3])

In [51]:
mtmvn_indep

MultioutputMultivariateNormal(batch_shape=(2,), n=3, m=2, independent outputs)

In [52]:
mtmvn_indep.mean.shape

torch.Size([2, 3, 2])

In [53]:
mtmvn_indep.variance.shape

torch.Size([2, 3, 2])

In [54]:
mtmvn_indep.rsample(torch.Size([4])).shape

torch.Size([4, 2, 3, 2])

In [55]:
mtmvn_indep.log_prob(mtmvn_indep.mean + torch.randn_like(mtmvn_indep.mean))

tensor([-23.9029,  -9.9000])

In [56]:
mtmvn_indep

MultioutputMultivariateNormal(batch_shape=(2,), n=3, m=2, independent outputs)

In [57]:
mtmvn_indep[0]

MultioutputMultivariateNormal(n=3, m=2, independent outputs)

In [58]:
mtmvn_indep[..., :, 0]

MultivariateNormal(batch_shape=(2,), n=3)

In [59]:
mtmvn_indep[..., 0, :]

MultivariateNormal(batch_shape=(2,), n=2)

# Kronecker product structure

In [60]:
from gpytorch.lazy import KroneckerProductLazyTensor

In [61]:
# Batch of 2 Kronecker-structured covariances.
batch_shape = torch.Size([2])

Kn = make_rand_covar(n, batch_shape)  # 2 x n x n factor (sized by the number of points)
Km = make_rand_covar(m, batch_shape)  # 2 x m x m factor (sized by the number of outputs)

# Km is passed first and Kn second, i.e. the lazy product of Km and Kn
covar = KroneckerProductLazyTensor(Km, Kn)

In [62]:
Kn

tensor([[[1.1333, 0.2893, 0.7803],
         [0.2893, 1.1635, 0.5701],
         [0.7803, 0.5701, 1.8817]],

        [[1.5586, 0.5207, 0.9835],
         [0.5207, 2.0579, 1.0481],
         [0.9835, 1.0481, 1.4063]]])

In [63]:
Km

tensor([[[0.5969, 0.0491],
         [0.0491, 0.8590]],

        [[0.6028, 0.0862],
         [0.0862, 1.5217]]])

In [64]:
covar.evaluate()

tensor([[[0.6764, 0.1727, 0.4658, 0.0557, 0.0142, 0.0383],
         [0.1727, 0.6945, 0.3403, 0.0142, 0.0572, 0.0280],
         [0.4658, 0.3403, 1.1232, 0.0383, 0.0280, 0.0924],
         [0.0557, 0.0142, 0.0383, 0.9734, 0.2485, 0.6703],
         [0.0142, 0.0572, 0.0280, 0.2485, 0.9995, 0.4897],
         [0.0383, 0.0280, 0.0924, 0.6703, 0.4897, 1.6163]],

        [[0.9396, 0.3139, 0.5929, 0.1343, 0.0449, 0.0847],
         [0.3139, 1.2406, 0.6318, 0.0449, 0.1773, 0.0903],
         [0.5929, 0.6318, 0.8478, 0.0847, 0.0903, 0.1212],
         [0.1343, 0.0449, 0.0847, 2.3717, 0.7923, 1.4967],
         [0.0449, 0.1773, 0.0903, 0.7923, 3.1316, 1.5949],
         [0.0847, 0.0903, 0.1212, 1.4967, 1.5949, 2.1401]]])

In [65]:
# We can construct the full tensor from the Kronecker factors (we DON'T want to do this)

# Km gets two trailing singleton dims and Kn two singleton dims in front of its
# matrix dims, so the product broadcasts to `batch x m x m x n x n`.
Dense = Km[..., :, :, None, None] * Kn[..., None, None, :, :]
# swap the two middle dims -> `batch x m x n x m x n`, then flatten each (m, n) pair
Dense.transpose(-3, -2).reshape(*Km.shape[:-2], m*n, m*n)

tensor([[[0.6764, 0.1727, 0.4658, 0.0557, 0.0142, 0.0383],
         [0.1727, 0.6945, 0.3403, 0.0142, 0.0572, 0.0280],
         [0.4658, 0.3403, 1.1232, 0.0383, 0.0280, 0.0924],
         [0.0557, 0.0142, 0.0383, 0.9734, 0.2485, 0.6703],
         [0.0142, 0.0572, 0.0280, 0.2485, 0.9995, 0.4897],
         [0.0383, 0.0280, 0.0924, 0.6703, 0.4897, 1.6163]],

        [[0.9396, 0.3139, 0.5929, 0.1343, 0.0449, 0.0847],
         [0.3139, 1.2406, 0.6318, 0.0449, 0.1773, 0.0903],
         [0.5929, 0.6318, 0.8478, 0.0847, 0.0903, 0.1212],
         [0.1343, 0.0449, 0.0847, 2.3717, 0.7923, 1.4967],
         [0.0449, 0.1773, 0.0903, 0.7923, 3.1316, 1.5949],
         [0.0847, 0.0903, 0.1212, 1.4967, 1.5949, 2.1401]]])