# Alternative representation of MultitaskMultivariateNormal

## Issue:
Currently, the interleaved/non-interleaved representation of MTMVN creates a bunch of headaches. It would be nice to have a higher-level API that abstracts away from these details. 


## Suggestion:
Represent the covariance matrix as an "unrolled" tensor (e.g. `n x n x m x m`), instead of a `nm x nm` covariance matrix. That way things like scalarizations are easily done, and reshaping/viewing/getting items can be done more straightforwardly. 


## Challenge:
In order to sample from this MVN, we need to compute the full `nm x nm` covariance matrix (to either compute the Cholesky decomposition or apply iterative approximate root decomposition methods). We need to make sure this is transparent, fast, and happens without much overhead.

In [1]:
from __future__ import annotations

import torch

In [2]:
n = 3  # number of points
m = 2  # number of outputs

# create some full covar

def make_rand_covar(k, batch_shape=torch.Size()):
    """Build a random symmetric ``k x k`` covariance matrix (optionally batched).

    The result is the Gram matrix ``a @ a^T`` of a uniformly random factor ``a``
    plus a random diagonal term drawn from ``torch.rand``.

    Args:
        k: Size of the (square) covariance matrix.
        batch_shape: Leading batch dimensions of the random factor.

    Returns:
        A tensor of shape ``batch_shape x k x k``. The diagonal term is drawn
        once with shape ``(k,)`` and broadcast across all batch elements.
    """
    factor = torch.rand(*batch_shape, k, k)
    gram = torch.matmul(factor, factor.transpose(-1, -2))
    # Drawn without batch dimensions, so the same diagonal is added to every batch.
    diagonal = torch.rand(k)
    return gram + torch.diag_embed(diagonal)
    
A = make_rand_covar(n * m)

### interleaved: block matrix where each block is an intra-point, cross-task covariance

In [3]:
C_inter = torch.zeros(n, n, m, m)

# this is obviously super inefficient, but it makes clear what we want
for i in range(n):
    for i_ in range(n):
        for j in range(m):
            for j_ in range(m):
                C_inter[i, i_, j, j_] = A[i*m+j, i_*m+j_]

In [4]:
A

tensor([[1.8296, 0.7429, 0.5221, 1.0727, 0.6583, 1.2720],
        [0.7429, 2.8304, 0.8278, 1.6706, 1.8234, 1.8707],
        [0.5221, 0.8278, 1.8112, 0.6711, 0.7029, 1.0244],
        [1.0727, 1.6706, 0.6711, 2.8601, 1.8187, 1.9545],
        [0.6583, 1.8234, 0.7029, 1.8187, 2.2391, 1.7424],
        [1.2720, 1.8707, 1.0244, 1.9545, 1.7424, 3.4549]])

In [5]:
# we can construct the full matrix as follows:
C_inter.permute(0, 2, 1, 3).reshape(m*n, m*n)

tensor([[2.2407, 1.1343, 1.0139, 1.4896, 1.4506, 1.1681],
        [1.1343, 1.7848, 1.1197, 1.4541, 1.7069, 0.9432],
        [1.0139, 1.1197, 1.5502, 1.4861, 1.6303, 1.1749],
        [1.4896, 1.4541, 1.4861, 3.1754, 2.1889, 1.8183],
        [1.4506, 1.7069, 1.6303, 2.1889, 2.7358, 1.4337],
        [1.1681, 0.9432, 1.1749, 1.8183, 1.4337, 1.9787]])

In [6]:
# batched version
b = 2
A_b = torch.stack([A, A + 1])
C_inter_b = torch.zeros(b, n, n, m, m)

# this is obviously super inefficient, but it makes clear what we want
for b_ in range(b):
    for i in range(n):
        for i_ in range(n):
            for j in range(m):
                for j_ in range(m):
                    C_inter_b[b_, i, i_, j, j_] = A_b[b_, i*m+j, i_*m+j_]
                    
C_inter_b.permute(0, 1, 3, 2, 4).reshape(b, m*n, m*n)

tensor([[[2.2407, 1.1343, 1.0139, 1.4896, 1.4506, 1.1681],
         [1.1343, 1.7848, 1.1197, 1.4541, 1.7069, 0.9432],
         [1.0139, 1.1197, 1.5502, 1.4861, 1.6303, 1.1749],
         [1.4896, 1.4541, 1.4861, 3.1754, 2.1889, 1.8183],
         [1.4506, 1.7069, 1.6303, 2.1889, 2.7358, 1.4337],
         [1.1681, 0.9432, 1.1749, 1.8183, 1.4337, 1.9787]],

        [[3.2407, 2.1343, 2.0139, 2.4896, 2.4506, 2.1681],
         [2.1343, 2.7848, 2.1197, 2.4541, 2.7069, 1.9432],
         [2.0139, 2.1197, 2.5502, 2.4861, 2.6303, 2.1749],
         [2.4896, 2.4541, 2.4861, 4.1754, 3.1889, 2.8183],
         [2.4506, 2.7069, 2.6303, 3.1889, 3.7358, 2.4337],
         [2.1681, 1.9432, 2.1749, 2.8183, 2.4337, 2.9787]]])

In [7]:
# the general formulation:
batch_shape = C_inter_b.shape[:-4]
C_inter_b.permute(*range(len(batch_shape)), -4, -2, -3, -1).reshape(*batch_shape, m*n, m*n)

tensor([[[2.2407, 1.1343, 1.0139, 1.4896, 1.4506, 1.1681],
         [1.1343, 1.7848, 1.1197, 1.4541, 1.7069, 0.9432],
         [1.0139, 1.1197, 1.5502, 1.4861, 1.6303, 1.1749],
         [1.4896, 1.4541, 1.4861, 3.1754, 2.1889, 1.8183],
         [1.4506, 1.7069, 1.6303, 2.1889, 2.7358, 1.4337],
         [1.1681, 0.9432, 1.1749, 1.8183, 1.4337, 1.9787]],

        [[3.2407, 2.1343, 2.0139, 2.4896, 2.4506, 2.1681],
         [2.1343, 2.7848, 2.1197, 2.4541, 2.7069, 1.9432],
         [2.0139, 2.1197, 2.5502, 2.4861, 2.6303, 2.1749],
         [2.4896, 2.4541, 2.4861, 4.1754, 3.1889, 2.8183],
         [2.4506, 2.7069, 2.6303, 3.1889, 3.7358, 2.4337],
         [2.1681, 1.9432, 2.1749, 2.8183, 2.4337, 2.9787]]])

In [8]:
# How do we go back, i.e. construct C_inter efficiently from the full matrix? Just do the same thing in reverse.

batch_shape = A_b.shape[:-2]
# Split each `n*m` axis into (n, m), giving `batch_shape x n x m x n x m`, then swap the two
# middle axes to arrive at the interleaved `batch_shape x n x n x m x m` layout.
C_inter_b_recov = A_b.reshape(*batch_shape, n, m, n, m).permute(*range(len(batch_shape)), -4, -2, -3, -1)

In [9]:
torch.allclose(C_inter_b, C_inter_b_recov)

True

### non-interleaved: block matrix where each block is an intra-task, cross-point covariance

In [10]:
C_noninter = torch.zeros(m, m, n, n)

# this is obviously super inefficient, but it makes clear what we want
for i in range(m):
    for i_ in range(m):
        for j in range(n):
            for j_ in range(n):
                C_noninter[i, i_, j, j_] = A[i*n+j, i_*n+j_]

In [11]:
A

tensor([[2.2407, 1.1343, 1.0139, 1.4896, 1.4506, 1.1681],
        [1.1343, 1.7848, 1.1197, 1.4541, 1.7069, 0.9432],
        [1.0139, 1.1197, 1.5502, 1.4861, 1.6303, 1.1749],
        [1.4896, 1.4541, 1.4861, 3.1754, 2.1889, 1.8183],
        [1.4506, 1.7069, 1.6303, 2.1889, 2.7358, 1.4337],
        [1.1681, 0.9432, 1.1749, 1.8183, 1.4337, 1.9787]])

In [12]:
# again we can construct the matrix as follows:
C_noninter.permute(0, 2, 1, 3).reshape(m*n, m*n)

tensor([[2.2407, 1.1343, 1.0139, 1.4896, 1.4506, 1.1681],
        [1.1343, 1.7848, 1.1197, 1.4541, 1.7069, 0.9432],
        [1.0139, 1.1197, 1.5502, 1.4861, 1.6303, 1.1749],
        [1.4896, 1.4541, 1.4861, 3.1754, 2.1889, 1.8183],
        [1.4506, 1.7069, 1.6303, 2.1889, 2.7358, 1.4337],
        [1.1681, 0.9432, 1.1749, 1.8183, 1.4337, 1.9787]])

In [13]:
# batched version
C_noninter_b = torch.zeros(b, m, m, n, n)

# this is obviously super inefficient, but it makes clear what we want
for b_ in range(b):
    for i in range(m):
        for i_ in range(m):
            for j in range(n):
                for j_ in range(n):
                    C_noninter_b[b_, i, i_, j, j_] = A_b[b_, i*n+j, i_*n+j_]
                    
torch.allclose(
    C_noninter_b.permute(0, 1, 3, 2, 4).reshape(b, m*n, m*n),
    A_b,
)

True

In [14]:
# and again we can go back the same way

batch_shape = A_b.shape[:-2]
C_noninter_b_recov = A_b.reshape(*batch_shape, m, n, m, n).permute(*range(len(batch_shape)), -4, -2, -3, -1)
torch.allclose(C_noninter_b, C_noninter_b_recov)

True

### alternate mixed-interleaved: block matrix where each block is an intra-point, cross-task covariance

This seems to be the most useful internal representation, as this means we don't have to do any permuting

In [15]:
# alternate representation
C_alt = torch.zeros(n, m, n, m)

# this is (obviously) super inefficient, but it makes clear what we want
for i in range(n):
    for j in range(m):
        for i_ in range(n):    
            for j_ in range(m):
                C_alt[i, j, i_, j_] = A[i*m+j, i_*m+j_]

In [16]:
A

tensor([[2.2407, 1.1343, 1.0139, 1.4896, 1.4506, 1.1681],
        [1.1343, 1.7848, 1.1197, 1.4541, 1.7069, 0.9432],
        [1.0139, 1.1197, 1.5502, 1.4861, 1.6303, 1.1749],
        [1.4896, 1.4541, 1.4861, 3.1754, 2.1889, 1.8183],
        [1.4506, 1.7069, 1.6303, 2.1889, 2.7358, 1.4337],
        [1.1681, 0.9432, 1.1749, 1.8183, 1.4337, 1.9787]])

In [17]:
# this is super straightforward
C_alt.view(n*m, n*m)

tensor([[2.2407, 1.1343, 1.0139, 1.4896, 1.4506, 1.1681],
        [1.1343, 1.7848, 1.1197, 1.4541, 1.7069, 0.9432],
        [1.0139, 1.1197, 1.5502, 1.4861, 1.6303, 1.1749],
        [1.4896, 1.4541, 1.4861, 3.1754, 2.1889, 1.8183],
        [1.4506, 1.7069, 1.6303, 2.1889, 2.7358, 1.4337],
        [1.1681, 0.9432, 1.1749, 1.8183, 1.4337, 1.9787]])

In [18]:
# alternate representation
C_alt_b = torch.zeros(b, n, m, n, m)

# this is (obviously) super inefficient, but it makes clear what we want
for b_ in range(b):
    for i in range(n):
        for j in range(m):
            for i_ in range(n):    
                for j_ in range(m):
                    C_alt_b[b_, i, j, i_, j_] = A_b[b_, i*m+j, i_*m+j_]

batch_shape = C_alt_b.shape[:-4]
torch.allclose(
    C_alt_b.view(*batch_shape, n*m, n*m),
    A_b,
)

True

In [19]:
# ....and going back

batch_shape = A_b.shape[:-2]
# No permute is needed here: splitting each `n*m` axis into (n, m) directly yields the
# `batch_shape x n x m x n x m` layout used by `C_alt_b`.
C_alt_b_recov = A_b.reshape(*batch_shape, n, m, n, m)
torch.allclose(C_alt_b, C_alt_b_recov)

True

## Question: What if the tensor memory layout is different? Do we need to handle that?

### Scalarizing

This representation makes scalarizing across the outputs trivial

In [20]:
weights = torch.rand(m)
(C_alt @ weights).transpose(-1, -2) @ weights

tensor([[1.7039, 1.0163, 1.2945],
        [1.0163, 1.4423, 1.4991],
        [1.2945, 1.4991, 2.0874]])

In [21]:
((C_alt[0, :, 0, :] @ weights) * weights).sum()

tensor(1.7039)

In [22]:
((C_alt[1, :, 2, :] @ weights) * weights).sum()

tensor(1.4991)

### Indexing outputs is also easy

(it is straightforward to do this with more than one element of the outputs)

In [23]:
C_alt[:, 1, :, 1]

tensor([[1.7848, 1.4541, 0.9432],
        [1.4541, 3.1754, 1.8183],
        [0.9432, 1.8183, 1.9787]])

In fact, we can have MTMVN themselves be indexable - if we have the external API assume that the shape is `batch_shape x n x m`, then doing `mtmvn[:4, :]` would access all outputs of the first four data points. Similarly, `mtmvn[0, ..., 2:]` would extract the mtmvn across all datapoints of the first batch element for all but the first two outputs.

In [24]:
from gpytorch.distributions import MultivariateNormal, MultitaskMultivariateNormal
mean = torch.rand(4)
covar = torch.eye(4)
mvn = MultivariateNormal(mean, covar)

In [25]:
mvn[-1].covariance_matrix

tensor(1.)

In [26]:
mvn[1:4].covariance_matrix

tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])

In [27]:
mvn.covariance_matrix

tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])

In [28]:
mtmean = torch.rand(4, 2)
mtcovar = torch.eye(mtmean.numel())
mtmvn = MultitaskMultivariateNormal(mtmean, mtcovar)
mtmvn[0]

RuntimeError: mean should be a matrix or a batch matrix (batch mode)

## Representing independent outputs

Here we can just store the `m` individual `n x n` blocks as a `m x n x n` tensor, no need to store the cross covariances.

We could also think of the case where we have independence across points, and only inter-task correlation. This will typically not be the case though, so we can punt on this for now.

In [29]:
from gpytorch.lazy import BlockDiagLazyTensor

C_indep = torch.stack([make_rand_covar(n) for _ in range(m)])

# using the lazy here will speed up matrix operations
BlockDiagLazyTensor(C_indep).evaluate()

tensor([[1.9865, 1.0179, 0.7360, 0.0000, 0.0000, 0.0000],
        [1.0179, 0.9832, 0.3909, 0.0000, 0.0000, 0.0000],
        [0.7360, 0.3909, 1.0778, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 1.8661, 1.1281, 0.8259],
        [0.0000, 0.0000, 0.0000, 1.1281, 1.9154, 0.5993],
        [0.0000, 0.0000, 0.0000, 0.8259, 0.5993, 1.0213]])

In [30]:
C_indep

tensor([[[1.9865, 1.0179, 0.7360],
         [1.0179, 0.9832, 0.3909],
         [0.7360, 0.3909, 1.0778]],

        [[1.8661, 1.1281, 0.8259],
         [1.1281, 1.9154, 0.5993],
         [0.8259, 0.5993, 1.0213]]])

In [31]:
# how do we construct the full matrix (we don't want to do this in general, but maybe sometimes)
# Once https://github.com/pytorch/pytorch/issues/31932 goes in, we can just use that instead

# Assemble the dense block-diagonal matrix from the per-task `n x n` blocks stored in
# `C_indep` (shape `batch_shape x m x n x n`).
batch_shape = C_indep.shape[:-3]
out = torch.zeros(*batch_shape, m*n, m*n, device=C_indep.device, dtype=C_indep.dtype)
# Iterate over the task dimension (-3) explicitly. Iterating over `C_indep` itself walks
# its leading dimension, which is a batch dimension whenever `batch_shape` is non-empty.
for i, vals in enumerate(C_indep.unbind(dim=-3)):
    start = i*n
    end = start + n
    # `vals` has shape `batch_shape x n x n` and fills the i-th diagonal block.
    out[..., start:end, start:end].copy_(vals)

out

tensor([[1.9865, 1.0179, 0.7360, 0.0000, 0.0000, 0.0000],
        [1.0179, 0.9832, 0.3909, 0.0000, 0.0000, 0.0000],
        [0.7360, 0.3909, 1.0778, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 1.8661, 1.1281, 0.8259],
        [0.0000, 0.0000, 0.0000, 1.1281, 1.9154, 0.5993],
        [0.0000, 0.0000, 0.0000, 0.8259, 0.5993, 1.0213]])

# We can also allow arbitrary order in the construction...

### ... so long as we have a consistent internal representation

The following orders are admissible (`n` is the number of points, `t` is the number of tasks):

```
nntt
ntnt
nttn
ttnn
tntn
tnnt

nnt
ntt
ntn
ttn
tnt
tnn
```

So basically we have the set 
```
nnt
ntt
ntn
ttn
tnt
tnn
```

and then the whole set of admissible combinations we get by combining this with the set we get by post-pending with the single letter (we could pre-pend too, but that would just generate duplicates).

Now we know this is admissible. We'd like to store things consistently internally. There are four options:

- no independence -> can standardize to `n x t x n x t`
- cross-task independence only -> can standardize to `t x n x n`
- cross-point independence only -> can standardize to `n x t x t`
- cross-task AND cross-point independence: This is trivial -> can standardize to `n x t` (just marginal variances)

**For now we focus on the first two cases**, we can deal with the other ones later.

In [52]:
from gpytorch.distributions.new_multitask_multivariate_normal import MultitaskMultivariateNormal

ModuleNotFoundError: No module named 'gpytorch.distributions.new_multitask_multivariate_normal'

In [32]:
mean_b = torch.randn(b, n, m)

mtmvn = MultitaskMultivariateNormal(mean=mean_b, covariance=C_alt_b, order=("n" ,"t", "n", "t"))

In [33]:
mtmvn.mean.shape

torch.Size([2, 3, 2])

In [34]:
mtmvn.variance.shape

torch.Size([2, 3, 2])

In [35]:
mtmvn.rsample(torch.Size([4])).shape

torch.Size([4, 2, 3, 2])

In [36]:
mtmvn.log_prob(mtmvn.mean + torch.randn_like(mtmvn.mean))

tensor([ -9.1266, -28.4039])

Independent tasks

In [37]:
C_indep_b = torch.stack([C_indep, C_indep + 1])

mtmvn_indep = MultitaskMultivariateNormal(mean=mean_b, covariance=C_indep_b, order=("t", "n", "n"))

In [38]:
mtmvn_indep.mean.shape

torch.Size([2, 3, 2])

In [39]:
mtmvn_indep.variance.shape

torch.Size([2, 3, 2])

In [40]:
mtmvn_indep.rsample(torch.Size([4])).shape

torch.Size([4, 2, 3, 2])

In [41]:
mtmvn_indep.log_prob(mtmvn_indep.mean + torch.randn_like(mtmvn_indep.mean))

tensor([ -7.5314, -11.0381])

# Kronecker product structure

In [42]:
from gpytorch.lazy import KroneckerProductLazyTensor

In [43]:
batch_shape = torch.Size([2])

Kn = make_rand_covar(n, batch_shape)
Km = make_rand_covar(m, batch_shape)

covar = KroneckerProductLazyTensor(Km, Kn)

In [44]:
Kn

tensor([[[1.2282, 0.2174, 0.5029],
         [0.2174, 0.8848, 0.4578],
         [0.5029, 0.4578, 1.3140]],

        [[2.1553, 1.2025, 0.7948],
         [1.2025, 1.7784, 0.6121],
         [0.7948, 0.6121, 1.5613]]])

In [45]:
Km

tensor([[[0.5423, 0.1257],
         [0.1257, 0.9608]],

        [[0.8937, 0.7911],
         [0.7911, 2.2355]]])

In [46]:
covar.evaluate()

tensor([[[0.6661, 0.1179, 0.2727, 0.1544, 0.0273, 0.0632],
         [0.1179, 0.4798, 0.2483, 0.0273, 0.1113, 0.0576],
         [0.2727, 0.2483, 0.7125, 0.0632, 0.0576, 0.1652],
         [0.1544, 0.0273, 0.0632, 1.1801, 0.2089, 0.4832],
         [0.0273, 0.1113, 0.0576, 0.2089, 0.8501, 0.4399],
         [0.0632, 0.0576, 0.1652, 0.4832, 0.4399, 1.2624]],

        [[1.9263, 1.0747, 0.7104, 1.7050, 0.9513, 0.6287],
         [1.0747, 1.5894, 0.5471, 0.9513, 1.4068, 0.4842],
         [0.7104, 0.5471, 1.3954, 0.6287, 0.4842, 1.2351],
         [1.7050, 0.9513, 0.6287, 4.8183, 2.6882, 1.7768],
         [0.9513, 1.4068, 0.4842, 2.6882, 3.9756, 1.3684],
         [0.6287, 0.4842, 1.2351, 1.7768, 1.3684, 3.4904]]])

In [51]:
# can construct the full tensor from the kronecker product (we DON'T want to do this)

# Km is viewed as `... x m x m x 1 x 1` and Kn is expanded to `... x m x m x n x n`, so that
# Dense[..., i, i_, j, j_] = Km[..., i, i_] * Kn[..., j, j_].
Dense = Km.unsqueeze(-1).unsqueeze(-1) * Kn.unsqueeze(-3).unsqueeze(-3).expand(*Km.shape, *Kn.shape[-2:])
# Reorder to `b x m x n x m x n` and flatten each (m, n) pair into one axis of size m*n.
# NOTE: the hard-coded permutation assumes exactly one batch dimension.
Dense.permute(0, 1, 3, 2, 4).reshape(*Km.shape[:-2], m*n, m*n)

tensor([[[0.6661, 0.1179, 0.2727, 0.1544, 0.0273, 0.0632],
         [0.1179, 0.4798, 0.2483, 0.0273, 0.1113, 0.0576],
         [0.2727, 0.2483, 0.7125, 0.0632, 0.0576, 0.1652],
         [0.1544, 0.0273, 0.0632, 1.1801, 0.2089, 0.4832],
         [0.0273, 0.1113, 0.0576, 0.2089, 0.8501, 0.4399],
         [0.0632, 0.0576, 0.1652, 0.4832, 0.4399, 1.2624]],

        [[1.9263, 1.0747, 0.7104, 1.7050, 0.9513, 0.6287],
         [1.0747, 1.5894, 0.5471, 0.9513, 1.4068, 0.4842],
         [0.7104, 0.5471, 1.3954, 0.6287, 0.4842, 1.2351],
         [1.7050, 0.9513, 0.6287, 4.8183, 2.6882, 1.7768],
         [0.9513, 1.4068, 0.4842, 2.6882, 3.9756, 1.3684],
         [0.6287, 0.4842, 1.2351, 1.7768, 1.3684, 3.4904]]])

# More testing

In [4]:
import torch
from gpytorch.distributions.multioutput_multivariate_normal import MultioutputMultivariateNormal

In [5]:
n, m = C_inter.shape[0], C_inter.shape[2]
mean = torch.rand(n, m)

In [6]:
C_inter.shape

torch.Size([3, 3, 2, 2])

In [7]:
n, m

(3, 2)

In [11]:
momvn = MultioutputMultivariateNormal(mean, C_inter, "nnmm")

In [30]:
momvn.log_prob(torch.rand(3, 2))

tensor([-6.9305])

In [12]:
momvn._covar.shape

torch.Size([3, 2, 3, 2])

In [13]:
momvn._nlzd_order

<NormalizedOrder.FULL: 'nmnm'>

In [14]:
momvn.event_shape

torch.Size([3, 2])

In [15]:
momvn.sample(torch.Size([2]))

tensor([[[ 0.1760, -0.0507],
         [-1.2688,  1.2880],
         [ 0.0351,  0.1168]],

        [[ 2.4282,  3.6193],
         [ 3.2335,  2.6217],
         [ 1.4828,  2.8011]]])

In [16]:
momvn.batch_shape

torch.Size([])

In [17]:
momvn.event_shape

torch.Size([3, 2])

In [18]:
momvn.variance.shape

torch.Size([3, 2])

In [19]:
momvn.mean.shape

torch.Size([3, 2])

In [20]:
momvn._covar.shape

torch.Size([3, 2, 3, 2])

In [25]:
msub = momvn[:1]

In [26]:
msub.event_shape

torch.Size([1, 2])

In [27]:
msub.mean

tensor([[0.9740, 0.5032]])

In [28]:
msub.variance

tensor([[2.1095, 2.7778]])

In [29]:
msub._covar.shape

torch.Size([1, 2, 1, 2])

In [31]:
msub.log_prob(torch.rand(1, 2))

tensor([-2.3864])

In [32]:
msub = momvn[:2, 1:]

In [33]:
msub.mean

tensor([[0.5032],
        [0.2452]])

In [34]:
msub.variance

tensor([[2.7778],
        [2.8220]])

In [35]:
momvn.mean

tensor([[0.9740, 0.5032],
        [0.3761, 0.2452],
        [0.1522, 0.1482]])