# Alternative representation of MultitaskMultivariateNormal

## Issue:
Currently, the interleaved/non-interleaved representation of MTMVN creates a bunch of headaches. It would be nice to have a higher-level API that abstracts away these details.


## Suggestion:
Represent the covariance matrix as an "unrolled" tensor (e.g. `n x n x m x m`) instead of an `nm x nm` covariance matrix. That way, things like scalarizations are easy, and reshaping, viewing, and indexing become more straightforward.
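A minimal sketch of what the unrolled representation buys us, assuming the `n x n x m x m` layout in which `C[i, i_, j, j_]` is the covariance between output `j` at point `i` and output `j_` at point `i_` (the interleaved layout used below). Per-output blocks, per-point blocks and the marginal variances are then plain indexing operations:

```python
import torch

n, m = 3, 2  # number of points, number of outputs

# random positive-definite nm x nm covariance in interleaved order (row index i*m + j)
a = torch.rand(n * m, n * m)
A = a @ a.T + torch.eye(n * m)

# unrolled n x n x m x m tensor: C[i, i_, j, j_] = A[i*m + j, i_*m + j_]
C = A.reshape(n, m, n, m).permute(0, 2, 1, 3)

task_covar = C[:, :, 0, 0]  # n x n covariance of output 0 across all points
point_covar = C[1, 1]       # m x m cross-output covariance at point 1
# marginal variances as an n x m tensor: variances[i, j] = C[i, i, j, j]
variances = C.diagonal(dim1=0, dim2=1).diagonal(dim1=0, dim2=1)
```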


## Challenge:
In order to sample from this MVN, we need to compute the full `nm x nm` covariance matrix (either to compute the Cholesky decomposition or to apply iterative approximate root decomposition methods). We need to make sure this is transparent, fast, and happens without much overhead.
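A minimal sketch of that step, assuming the covariance is stored as an `n x m x n x m` tensor (the layout that turns out to be most convenient below) and the mean as `n x m`: flatten to `nm x nm`, take a Cholesky factor with `torch.linalg.cholesky`, and reshape the correlated draws back to `n x m`. An iterative root decomposition would take the place of the Cholesky call for large problems.

```python
import torch

n, m = 3, 2  # number of points, number of outputs
num_samples = 4

# random positive-definite covariance, stored as n x m x n x m
a = torch.rand(n * m, n * m)
C = (a @ a.T + torch.eye(n * m)).reshape(n, m, n, m)
mean = torch.randn(n, m)

# 1. materialize the full nm x nm matrix (for this layout a reshape suffices)
full = C.reshape(n * m, n * m)
# 2. Cholesky factor L with full = L @ L.T
L = torch.linalg.cholesky(full)
# 3. correlated draws: mean + L @ eps, with eps ~ N(0, I)
eps = torch.randn(num_samples, n * m, 1)
samples = mean + (L @ eps).squeeze(-1).reshape(num_samples, n, m)
```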

In [1]:
from __future__ import annotations

import torch

In [2]:
n = 3  # number of points
m = 2  # number of outputs

# create a random full covariance matrix (PSD part a @ a.T plus a random diagonal)

def make_rand_covar(k):
    a = torch.rand(k, k)
    return a @ a.t() + torch.diag_embed(torch.rand(k))
    
A = make_rand_covar(n * m)

### interleaved: block matrix where each block is an intra-point, cross-task covariance

In [3]:
C_inter = torch.zeros(n, n, m, m)

# this is obviously super inefficient, but it makes clear what we want
for i in range(n):
    for i_ in range(n):
        for j in range(m):
            for j_ in range(m):
                C_inter[i, i_, j, j_] = A[i*m+j, i_*m+j_]

In [4]:
A

tensor([[2.7159, 1.0785, 1.1687, 1.3138, 1.2568, 1.6569],
        [1.0785, 1.1027, 0.7116, 0.5418, 0.7340, 1.2092],
        [1.1687, 0.7116, 1.1992, 0.5868, 0.7759, 1.0248],
        [1.3138, 0.5418, 0.5868, 2.1368, 1.1354, 1.3857],
        [1.2568, 0.7340, 0.7759, 1.1354, 2.5999, 1.7795],
        [1.6569, 1.2092, 1.0248, 1.3857, 1.7795, 2.8634]])

In [5]:
# we can construct the full matrix as follows:
C_inter.permute(0, 2, 1, 3).reshape(m*n, m*n)

tensor([[2.7159, 1.0785, 1.1687, 1.3138, 1.2568, 1.6569],
        [1.0785, 1.1027, 0.7116, 0.5418, 0.7340, 1.2092],
        [1.1687, 0.7116, 1.1992, 0.5868, 0.7759, 1.0248],
        [1.3138, 0.5418, 0.5868, 2.1368, 1.1354, 1.3857],
        [1.2568, 0.7340, 0.7759, 1.1354, 2.5999, 1.7795],
        [1.6569, 1.2092, 1.0248, 1.3857, 1.7795, 2.8634]])

In [6]:
# batched version
b = 2
A_b = torch.stack([A, A + 1])
C_inter_b = torch.zeros(b, n, n, m, m)

# this is obviously super inefficient, but it makes clear what we want
for b_ in range(b):
    for i in range(n):
        for i_ in range(n):
            for j in range(m):
                for j_ in range(m):
                    C_inter_b[b_, i, i_, j, j_] = A_b[b_, i*m+j, i_*m+j_]
                    
C_inter_b.permute(0, 1, 3, 2, 4).reshape(b, m*n, m*n)

tensor([[[2.7159, 1.0785, 1.1687, 1.3138, 1.2568, 1.6569],
         [1.0785, 1.1027, 0.7116, 0.5418, 0.7340, 1.2092],
         [1.1687, 0.7116, 1.1992, 0.5868, 0.7759, 1.0248],
         [1.3138, 0.5418, 0.5868, 2.1368, 1.1354, 1.3857],
         [1.2568, 0.7340, 0.7759, 1.1354, 2.5999, 1.7795],
         [1.6569, 1.2092, 1.0248, 1.3857, 1.7795, 2.8634]],

        [[3.7159, 2.0785, 2.1687, 2.3138, 2.2568, 2.6569],
         [2.0785, 2.1027, 1.7116, 1.5418, 1.7340, 2.2092],
         [2.1687, 1.7116, 2.1992, 1.5868, 1.7759, 2.0248],
         [2.3138, 1.5418, 1.5868, 3.1368, 2.1354, 2.3857],
         [2.2568, 1.7340, 1.7759, 2.1354, 3.5999, 2.7795],
         [2.6569, 2.2092, 2.0248, 2.3857, 2.7795, 3.8634]]])

In [7]:
# the general formulation:
batch_shape = C_inter_b.shape[:-4]
C_inter_b.permute(*range(len(batch_shape)), -4, -2, -3, -1).reshape(*batch_shape, m*n, m*n)

tensor([[[2.7159, 1.0785, 1.1687, 1.3138, 1.2568, 1.6569],
         [1.0785, 1.1027, 0.7116, 0.5418, 0.7340, 1.2092],
         [1.1687, 0.7116, 1.1992, 0.5868, 0.7759, 1.0248],
         [1.3138, 0.5418, 0.5868, 2.1368, 1.1354, 1.3857],
         [1.2568, 0.7340, 0.7759, 1.1354, 2.5999, 1.7795],
         [1.6569, 1.2092, 1.0248, 1.3857, 1.7795, 2.8634]],

        [[3.7159, 2.0785, 2.1687, 2.3138, 2.2568, 2.6569],
         [2.0785, 2.1027, 1.7116, 1.5418, 1.7340, 2.2092],
         [2.1687, 1.7116, 2.1992, 1.5868, 1.7759, 2.0248],
         [2.3138, 1.5418, 1.5868, 3.1368, 2.1354, 2.3857],
         [2.2568, 1.7340, 1.7759, 2.1354, 3.5999, 2.7795],
         [2.6569, 2.2092, 2.0248, 2.3857, 2.7795, 3.8634]]])

In [8]:
# how do we go back, i.e. construct C_inter efficiently from the full matrix? Just do the same thing in reverse

batch_shape = A_b.shape[:-2]
C_inter_b_recov = A_b.reshape(*batch_shape, n, m, n, m).permute(*range(len(batch_shape)), -4, -2, -3, -1)

In [9]:
torch.allclose(C_inter_b, C_inter_b_recov)

True

### non-interleaved: block matrix where each block is an intra-task, cross-point covariance

In [10]:
C_noninter = torch.zeros(m, m, n, n)

# this is obviously super inefficient, but it makes clear what we want
for i in range(m):
    for i_ in range(m):
        for j in range(n):
            for j_ in range(n):
                C_noninter[i, i_, j, j_] = A[i*n+j, i_*n+j_]

In [11]:
A

tensor([[2.7159, 1.0785, 1.1687, 1.3138, 1.2568, 1.6569],
        [1.0785, 1.1027, 0.7116, 0.5418, 0.7340, 1.2092],
        [1.1687, 0.7116, 1.1992, 0.5868, 0.7759, 1.0248],
        [1.3138, 0.5418, 0.5868, 2.1368, 1.1354, 1.3857],
        [1.2568, 0.7340, 0.7759, 1.1354, 2.5999, 1.7795],
        [1.6569, 1.2092, 1.0248, 1.3857, 1.7795, 2.8634]])

In [12]:
# again we can construct the matrix as follows:
C_noninter.permute(0, 2, 1, 3).reshape(m*n, m*n)

tensor([[2.7159, 1.0785, 1.1687, 1.3138, 1.2568, 1.6569],
        [1.0785, 1.1027, 0.7116, 0.5418, 0.7340, 1.2092],
        [1.1687, 0.7116, 1.1992, 0.5868, 0.7759, 1.0248],
        [1.3138, 0.5418, 0.5868, 2.1368, 1.1354, 1.3857],
        [1.2568, 0.7340, 0.7759, 1.1354, 2.5999, 1.7795],
        [1.6569, 1.2092, 1.0248, 1.3857, 1.7795, 2.8634]])

In [13]:
# batched version
C_noninter_b = torch.zeros(b, m, m, n, n)

# this is obviously super inefficient, but it makes clear what we want
for b_ in range(b):
    for i in range(m):
        for i_ in range(m):
            for j in range(n):
                for j_ in range(n):
                    C_noninter_b[b_, i, i_, j, j_] = A_b[b_, i*n+j, i_*n+j_]
                    
torch.allclose(
    C_noninter_b.permute(0, 1, 3, 2, 4).reshape(b, m*n, m*n),
    A_b,
)

True

In [14]:
# and again we can go back the same way

batch_shape = A_b.shape[:-2]
C_noninter_b_recov = A_b.reshape(*batch_shape, m, n, m, n).permute(*range(len(batch_shape)), -4, -2, -3, -1)
torch.allclose(C_noninter_b, C_noninter_b_recov)

True
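Both layouts describe the same kind of covariance with rows and columns ordered differently. A sketch of converting a full interleaved matrix (index `i*m + j`) into the corresponding non-interleaved one (index `j*n + i`): a single reshape/permute/reshape, which is equivalent to applying one index permutation to both rows and columns.

```python
import torch

n, m = 3, 2  # number of points, number of outputs

a = torch.rand(n * m, n * m)
A_inter = a @ a.T + torch.eye(n * m)  # interleaved: row index i*m + j

# swap the point and task axes on both the row side and the column side
A_noninter = A_inter.reshape(n, m, n, m).permute(1, 0, 3, 2).reshape(m * n, m * n)

# the same thing as an explicit index permutation:
# position j*n + i of the non-interleaved order holds position i*m + j of the interleaved one
perm = torch.arange(n * m).reshape(n, m).T.reshape(-1)
```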

### alternate mixed-interleaved: interleaved block structure (intra-point, cross-task blocks), stored as `n x m x n x m`

This seems to be the most useful internal representation, since converting to and from the full `nm x nm` matrix only needs a `view`/`reshape`, without any permuting.
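A quick check of the "no permuting" point, assuming freshly allocated (contiguous) tensors: the `n x m x n x m` layout flattens with `view`, which shares memory with the original tensor, whereas the `n x n x m x m` interleaved layout needs a permute first, so `reshape` has to allocate a copy.

```python
import torch

n, m = 3, 2  # number of points, number of outputs

a = torch.rand(n * m, n * m)
A = a @ a.T + torch.eye(n * m)

C_alt = A.reshape(n, m, n, m)                      # n x m x n x m
C_inter = C_alt.permute(0, 2, 1, 3).contiguous()   # n x n x m x m, own contiguous memory

# alt layout: flattening is a view on the same memory
full_from_alt = C_alt.view(n * m, n * m)
# interleaved layout: after the permute the strides no longer fit, so reshape copies
full_from_inter = C_inter.permute(0, 2, 1, 3).reshape(n * m, n * m)
```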

In [15]:
# alternate representation
C_alt = torch.zeros(n, m, n, m)

# this is obviously super inefficient, but it makes clear what we want
for i in range(n):
    for j in range(m):
        for i_ in range(n):    
            for j_ in range(m):
                C_alt[i, j, i_, j_] = A[i*m+j, i_*m+j_]

In [16]:
A

tensor([[2.7159, 1.0785, 1.1687, 1.3138, 1.2568, 1.6569],
        [1.0785, 1.1027, 0.7116, 0.5418, 0.7340, 1.2092],
        [1.1687, 0.7116, 1.1992, 0.5868, 0.7759, 1.0248],
        [1.3138, 0.5418, 0.5868, 2.1368, 1.1354, 1.3857],
        [1.2568, 0.7340, 0.7759, 1.1354, 2.5999, 1.7795],
        [1.6569, 1.2092, 1.0248, 1.3857, 1.7795, 2.8634]])

In [17]:
# this is super straightforward
C_alt.view(n*m, n*m)

tensor([[2.7159, 1.0785, 1.1687, 1.3138, 1.2568, 1.6569],
        [1.0785, 1.1027, 0.7116, 0.5418, 0.7340, 1.2092],
        [1.1687, 0.7116, 1.1992, 0.5868, 0.7759, 1.0248],
        [1.3138, 0.5418, 0.5868, 2.1368, 1.1354, 1.3857],
        [1.2568, 0.7340, 0.7759, 1.1354, 2.5999, 1.7795],
        [1.6569, 1.2092, 1.0248, 1.3857, 1.7795, 2.8634]])

In [18]:
# alternate representation
C_alt_b = torch.zeros(b, n, m, n, m)

    # this is obviously super inefficient, but it makes clear what we want
for b_ in range(b):
    for i in range(n):
        for j in range(m):
            for i_ in range(n):    
                for j_ in range(m):
                    C_alt_b[b_, i, j, i_, j_] = A_b[b_, i*m+j, i_*m+j_]

batch_shape = C_alt_b.shape[:-4]
torch.allclose(
    C_alt_b.view(*batch_shape, n*m, n*m),
    A_b,
)

True

In [19]:
# ....and going back

batch_shape = A_b.shape[:-2]
C_alt_b_recov = A_b.reshape(*batch_shape, n, m, n, m)  # no permute needed for this layout
torch.allclose(C_alt_b, C_alt_b_recov)

True

## Question: What if the tensor memory layout is different (e.g. the tensor is not contiguous)? Do we need to handle that?
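One way to see what is at stake, assuming the covariance arrives as a non-contiguous `n x m x n x m` tensor (here produced by a permute): `view` then raises a `RuntimeError`, while `reshape` (like `.contiguous().view(...)`) falls back to a copy. Using `reshape` internally would therefore handle any layout, at the cost of a possible copy.

```python
import torch

n, m = 3, 2  # number of points, number of outputs

# an n x m x n x m tensor that is NOT contiguous in memory
C_nc = torch.rand(n, n, m, m).permute(0, 2, 1, 3)

try:
    C_nc.view(n * m, n * m)
    view_failed = False
except RuntimeError:
    view_failed = True  # the strides are incompatible with the requested shape

full = C_nc.reshape(n * m, n * m)  # works, but copies the data
```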

### Scalarizing

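This representation makes scalarizing across the outputs trivial: for output weights `w`, contracting both task dimensions with `w` gives the `n x n` covariance of the scalarized output.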

In [20]:
weights = torch.rand(m)
(C_alt @ weights).transpose(-1, -2) @ weights

tensor([[1.1511, 0.6632, 0.7950],
        [0.6632, 0.6734, 0.6435],
        [0.7950, 0.6435, 1.4682]])
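Written out, the scalarized covariance is `Cov(w^T f(x_i), w^T f(x_i_)) = sum_{j, j_} w_j w_j_ C_alt[i, j, i_, j_]`. A sketch that computes the same quantity with `torch.einsum` (which also passes batch dimensions through) and cross-checks it against the full `nm x nm` matrix via the linear map `W = kron(I_n, w^T)`, which picks out the interleaved indices `i*m + j`:

```python
import torch

n, m = 3, 2  # number of points, number of outputs

a = torch.rand(n * m, n * m)
A = a @ a.T + torch.eye(n * m)  # full covariance, interleaved index i*m + j
C_alt = A.reshape(n, m, n, m)   # n x m x n x m
weights = torch.rand(m)

# matmul version, as in the cell above
scal_matmul = (C_alt @ weights).transpose(-1, -2) @ weights
# einsum version: contract both task indices, keep both point indices
scal_einsum = torch.einsum("...ijkl,j,l->...ik", C_alt, weights, weights)
# full-matrix version: W is n x nm with W[i, i*m + j] = w_j
W = torch.kron(torch.eye(n), weights.unsqueeze(0))
scal_full = W @ A @ W.T
```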

## Representing independent outputs

Here we can just store the `m` individual `n x n` blocks as an `m x n x n` tensor; there is no need to store the cross-covariances.

We could also think of the case where we have independence across points, and only inter-task correlation. This will typically not be the case though, so we can punt on this for now.
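With independent outputs the cross-task blocks are zero, so scalarization reduces to a weighted sum of the stored blocks, `sum_j w_j^2 K_j`. A sketch under that assumption, cross-checked against the full block-diagonal matrix built with `torch.block_diag` (task-major, i.e. non-interleaved, index `j*n + i`):

```python
import torch

n, m = 3, 2  # number of points, number of outputs

def make_rand_covar(k):
    a = torch.rand(k, k)
    return a @ a.T + torch.diag_embed(torch.rand(k))

C_indep = torch.stack([make_rand_covar(n) for _ in range(m)])  # m x n x n
weights = torch.rand(m)

# only the m diagonal blocks are needed: sum_j w_j^2 * K_j
scal_indep = (weights.pow(2)[:, None, None] * C_indep).sum(dim=-3)

# cross-check on the full block-diagonal matrix (task-major ordering)
full = torch.block_diag(*C_indep)
W = torch.kron(weights.unsqueeze(0), torch.eye(n))  # W[i, j*n + i] = w_j
scal_full = W @ full @ W.T
```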

In [21]:
from gpytorch.lazy import BlockDiagLazyTensor

C_indep = torch.stack([make_rand_covar(n) for _ in range(m)])

# using a lazy tensor here avoids materializing the zero off-diagonal blocks, which speeds up matrix operations
BlockDiagLazyTensor(C_indep).evaluate()

tensor([[1.4507, 0.7111, 0.4856, 0.0000, 0.0000, 0.0000],
        [0.7111, 1.3544, 0.1176, 0.0000, 0.0000, 0.0000],
        [0.4856, 0.1176, 1.4699, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 2.0618, 0.7923, 0.7034],
        [0.0000, 0.0000, 0.0000, 0.7923, 1.4361, 0.1751],
        [0.0000, 0.0000, 0.0000, 0.7034, 0.1751, 1.9329]])

In [22]:
C_indep

tensor([[[1.4507, 0.7111, 0.4856],
         [0.7111, 1.3544, 0.1176],
         [0.4856, 0.1176, 1.4699]],

        [[2.0618, 0.7923, 0.7034],
         [0.7923, 1.4361, 0.1751],
         [0.7034, 0.1751, 1.9329]]])

In [23]:
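# how do we construct the full matrix? (we don't want to do this in general, but maybe sometimes)
# Once https://github.com/pytorch/pytorch/issues/31932 goes in, we can just use that instead

batch_shape = C_indep.shape[:-3]
num_tasks = C_indep.shape[-3]
out = torch.zeros(*batch_shape, num_tasks*n, num_tasks*n, device=C_indep.device, dtype=C_indep.dtype)
# index the task dimension explicitly: iterating over C_indep itself would loop
# over the leading batch dimension instead whenever batch_shape is non-empty
for i in range(num_tasks):
    start = i*n
    end = start + n
    out[..., start:end, start:end].copy_(C_indep[..., i, :, :])

out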

tensor([[1.4507, 0.7111, 0.4856, 0.0000, 0.0000, 0.0000],
        [0.7111, 1.3544, 0.1176, 0.0000, 0.0000, 0.0000],
        [0.4856, 0.1176, 1.4699, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 2.0618, 0.7923, 0.7034],
        [0.0000, 0.0000, 0.0000, 0.7923, 1.4361, 0.1751],
        [0.0000, 0.0000, 0.0000, 0.7034, 0.1751, 1.9329]])
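In the meantime, a loop-free alternative that also works with batch dimensions is to multiply the blocks with an identity over the task index using `torch.einsum` and then flatten. A sketch; the helper name `block_diag_batched` is ours and not an existing gpytorch or PyTorch function. Like the loop above, it materializes all the zeros, so it is only meant for the occasional case where the full matrix is really needed.

```python
import torch

def block_diag_batched(blocks: torch.Tensor) -> torch.Tensor:
    """Turn a `... x m x n x n` stack of blocks into a `... x mn x mn` block-diagonal matrix."""
    *batch_shape, m, n, _ = blocks.shape
    eye = torch.eye(m, dtype=blocks.dtype, device=blocks.device)
    # out[..., j, a, k, b] = blocks[..., j, a, b] if j == k else 0
    out = torch.einsum("...jab,jk->...jakb", blocks, eye)
    return out.reshape(*batch_shape, m * n, m * n)
```

For example, `block_diag_batched(torch.stack([C_indep, C_indep + 1]))` returns a `2 x mn x mn` tensor whose first element matches `out` above.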

# We can also allow arbitrary order in the construction...

### ... so long as we have a consistent internal representation

The following orders are admissible (`n` denotes a point dimension, `t` a task dimension):

```
nntt
ntnt
nttn
ttnn
tntn
tnnt

nnt
ntt
ntn
ttn
tnt
tnn
```

So basically we have the set 
```
nnt
ntt
ntn
ttn
tnt
tnn
```

and the full set of admissible orders is obtained by combining it with the four-letter orders, which we get by appending to each of these the letter that occurs only once in it (we could prepend instead, but that would just generate the same orders again).
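A small sanity check of this construction with `itertools`: the three-letter orders are the distinct arrangements of `nnt` and `ntt`, the four-letter orders are the distinct arrangements of `nntt`, and appending (or prepending) the letter that occurs only once to each three-letter order yields exactly the four-letter set.

```python
from itertools import permutations

def arrangements(letters: str) -> set:
    """All distinct orderings of the given multiset of letters."""
    return {"".join(p) for p in permutations(letters)}

three = arrangements("nnt") | arrangements("ntt")
four = arrangements("nntt")

def single_letter(order: str) -> str:
    # the letter that appears only once in a three-letter order
    return "t" if order.count("n") == 2 else "n"

appended = {o + single_letter(o) for o in three}
prepended = {single_letter(o) + o for o in three}
```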

Now we know which orders are admissible. Internally, however, we'd like to store things consistently. There are four cases:

- no independence -> can standardize to `n x t x n x t`
- cross-task independence only -> can standardize to `t x n x n`
- cross-point independence only -> can standardize to `n x t x t`
- cross-task AND cross-point independence: This is trivial -> can standardize to `n x t` (just marginal variances)

**For now we focus on the first two cases**, we can deal with the other ones later.
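A minimal sketch of that standardization, driven only by the order string. The helper `to_canonical` is hypothetical, and it assumes a convention the text above leaves open: the first occurrence of a letter indexes the row and the second the column. Four-letter orders map to `n x t x n x t`, `nnt`-type orders to `t x n x n`, and `ntt`-type orders to `n x t x t`; leading batch dimensions stay in place.

```python
import torch

def to_canonical(covar: torch.Tensor, order: str) -> torch.Tensor:
    """Permute `covar`, laid out according to `order`, into the canonical layout.

    Assumes the first occurrence of a letter is the row index, the second the column index.
    """
    if set(order) - {"n", "t"}:
        raise ValueError(f"inadmissible order: {order!r}")
    offset = covar.dim() - len(order)  # number of leading batch dims
    n_pos = [offset + k for k, c in enumerate(order) if c == "n"]
    t_pos = [offset + k for k, c in enumerate(order) if c == "t"]
    if len(n_pos) == 2 and len(t_pos) == 2:
        perm = [n_pos[0], t_pos[0], n_pos[1], t_pos[1]]  # no independence -> n x t x n x t
    elif len(n_pos) == 2 and len(t_pos) == 1:
        perm = [t_pos[0], n_pos[0], n_pos[1]]  # cross-task independence -> t x n x n
    elif len(n_pos) == 1 and len(t_pos) == 2:
        perm = [n_pos[0], t_pos[0], t_pos[1]]  # cross-point independence -> n x t x t
    else:
        raise ValueError(f"inadmissible order: {order!r}")
    return covar.permute(*range(offset), *perm)
```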

In [24]:
from gpytorch.distributions.new_multitask_multivariate_normal import MultitaskMultivariateNormal

In [25]:
mean_b = torch.randn(b, n, m)

mtmvn = MultitaskMultivariateNormal(mean=mean_b, covariance=C_alt_b, order="ntnt")

In [26]:
mtmvn.mean.shape

torch.Size([2, 3, 2])

In [27]:
mtmvn.variance.shape

torch.Size([2, 3, 2])

In [28]:
mtmvn.rsample(torch.Size([4])).shape

torch.Size([4, 2, 3, 2])

In [29]:
mtmvn.log_prob(mtmvn.mean + torch.randn_like(mtmvn.mean))

tensor([-7.0052, -8.6431])
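For the correlated `ntnt` case, one way to cross-check results like these is PyTorch's `torch.distributions.MultivariateNormal` on the flattened event, since for this layout flattening the `n x m` values and the `n x m x n x m` covariance is a plain reshape. A sketch, assuming the new `MultitaskMultivariateNormal` flattens in that same order (we have not confirmed its internals); the random inputs here are fresh, so the numbers will differ from the cells above.

```python
import torch
from torch.distributions import MultivariateNormal

b, n, m = 2, 3, 2  # batch size, number of points, number of outputs

a = torch.rand(b, n * m, n * m)
full = a @ a.transpose(-1, -2) + torch.eye(n * m)
C_alt_b = full.reshape(b, n, m, n, m)  # batch x n x m x n x m ("ntnt")
mean_b = torch.randn(b, n, m)

flat_mvn = MultivariateNormal(
    loc=mean_b.reshape(b, n * m),
    covariance_matrix=C_alt_b.reshape(b, n * m, n * m),
)
x = mean_b + torch.randn_like(mean_b)
log_prob = flat_mvn.log_prob(x.reshape(b, n * m))                 # shape (b,)
samples = flat_mvn.rsample(torch.Size([4])).reshape(4, b, n, m)  # shape (4, b, n, m)
```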

### Independent tasks

In [30]:
C_indep_b = torch.stack([C_indep, C_indep + 1])

mtmvn_indep = MultitaskMultivariateNormal(mean=mean_b, covariance=C_indep_b, order="tnn")

In [31]:
mtmvn_indep.mean.shape

torch.Size([2, 3, 2])

In [32]:
mtmvn_indep.variance.shape

torch.Size([2, 3, 2])

In [33]:
mtmvn_indep.rsample(torch.Size([4])).shape

torch.Size([4, 2, 3, 2])

In [34]:
mtmvn_indep.log_prob(mtmvn_indep.mean + torch.randn_like(mtmvn_indep.mean))

tensor([-10.5470, -12.9468])
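With independent tasks (`tnn` order), the log-density factorizes over tasks, so it can be computed from `m` separate `n`-dimensional normals without ever forming the `nm x nm` matrix. A sketch with `torch.distributions.MultivariateNormal` batched over tasks, cross-checked against the full block-diagonal (task-major) covariance; again the inputs are fresh random tensors, not the ones above.

```python
import torch
from torch.distributions import MultivariateNormal

b, n, m = 2, 3, 2  # batch size, number of points, number of outputs

a = torch.rand(b, m, n, n)
C_indep_b = a @ a.transpose(-1, -2) + torch.eye(n)  # batch x m x n x n ("tnn")
mean_b = torch.randn(b, n, m)
x = mean_b + torch.randn_like(mean_b)

# one n-dimensional MVN per (batch, task); the transpose moves tasks in front of points
per_task = MultivariateNormal(loc=mean_b.transpose(-1, -2), covariance_matrix=C_indep_b)
log_prob = per_task.log_prob(x.transpose(-1, -2)).sum(dim=-1)  # shape (b,)

# cross-check against the full block-diagonal covariance (task-major index j*n + i)
full_cov = torch.stack([torch.block_diag(*C_indep_b[k]) for k in range(b)])
full_mvn = MultivariateNormal(
    loc=mean_b.transpose(-1, -2).reshape(b, m * n), covariance_matrix=full_cov
)
log_prob_full = full_mvn.log_prob(x.transpose(-1, -2).reshape(b, m * n))
```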