# Alternative representation of MultitaskMultivariateNormal

## Issue:
Currently, the interleaved/non-interleaved representation of MTMVN creates a bunch of headaches. It would be nice to have a higher-level API that abstracts away from these details. 


## Suggestion:
Represent the covariance matrix as an "unrolled" tensor (e.g. `n x n x m x m`), instead of a `nm x nm` covariance matrix. That way things like scalarizations are easily done, and reshaping/viewing/getting items can be done more straightforwardly. 


## Challenge:
In order to sample from this MVN, we need to compute the full `nm x nm` covariance matrix (to either compute the Cholesky decomposition or apply iterative approximate root decomposition methods). We need to make sure this is transparent, fast, and happens without much overhead.

In [4]:
from __future__ import annotations

import torch

In [5]:
n = 3  # number of points
m = 2  # number of outputs

# create some full covar

def make_rand_covar(k, batch_shape=torch.Size()):
    """Build a random symmetric ``k x k`` covariance-like matrix, optionally batched.

    The result is the Gram matrix ``a @ a^T`` of a uniformly random factor ``a``
    of shape ``(*batch_shape, k, k)``, plus a random diagonal drawn once from
    ``torch.rand(k)`` and shared (via broadcasting) across all batch entries.

    Args:
        k: Size of the trailing two (square) dimensions.
        batch_shape: Leading batch dimensions of the random factor.

    Returns:
        A tensor of shape ``(*batch_shape, k, k)``.
    """
    factor = torch.rand(*batch_shape, k, k)
    gram = torch.matmul(factor, factor.transpose(-2, -1))
    diagonal = torch.rand(k)
    return gram + torch.diag_embed(diagonal)
    
A = make_rand_covar(n * m)

### interleaved: block matrix where each block is an intra-point, cross-task covariance

In [6]:
# Interleaved layout as a 4-d tensor: C_inter[i, i_, j, j_] is the entry of A for
# (point i, task j) vs. (point i_, task j_).
C_inter = torch.zeros(n, n, m, m)

# Deliberately naive element-by-element copy: slow, but it spells out the index mapping.
for i in range(n):
    for i_ in range(n):
        for j in range(m):
            row = i * m + j  # flat row index of (point i, task j) in A
            for j_ in range(m):
                col = i_ * m + j_  # flat column index of (point i_, task j_) in A
                C_inter[i, i_, j, j_] = A[row, col]

In [7]:
A

tensor([[2.7497, 1.9997, 2.5260, 1.8324, 1.0587, 1.5959],
        [1.9997, 2.3089, 2.5956, 1.2006, 0.9716, 1.3898],
        [2.5260, 2.5956, 3.9785, 1.4866, 1.2309, 1.7182],
        [1.8324, 1.2006, 1.4866, 1.8368, 0.6344, 1.1747],
        [1.0587, 0.9716, 1.2309, 0.6344, 1.4897, 0.7867],
        [1.5959, 1.3898, 1.7182, 1.1747, 0.7867, 1.8870]])

In [8]:
# we can construct the full matrix as follows:
C_inter.permute(0, 2, 1, 3).reshape(m*n, m*n)

tensor([[2.7497, 1.9997, 2.5260, 1.8324, 1.0587, 1.5959],
        [1.9997, 2.3089, 2.5956, 1.2006, 0.9716, 1.3898],
        [2.5260, 2.5956, 3.9785, 1.4866, 1.2309, 1.7182],
        [1.8324, 1.2006, 1.4866, 1.8368, 0.6344, 1.1747],
        [1.0587, 0.9716, 1.2309, 0.6344, 1.4897, 0.7867],
        [1.5959, 1.3898, 1.7182, 1.1747, 0.7867, 1.8870]])

In [9]:
# Batched variant: stack two full matrices along a new leading batch dimension.
b = 2
A_b = torch.stack((A, A + 1))
C_inter_b = torch.zeros(b, n, n, m, m)

# Deliberately naive element-by-element copy: slow, but it spells out the index mapping.
for b_ in range(b):
    for i in range(n):
        for i_ in range(n):
            for j in range(m):
                row = i * m + j  # flat row index of (point i, task j)
                for j_ in range(m):
                    col = i_ * m + j_  # flat column index of (point i_, task j_)
                    C_inter_b[b_, i, i_, j, j_] = A_b[b_, row, col]

# Swap the two middle trailing dims (b, n, n, m, m) -> (b, n, m, n, m), then flatten.
C_inter_b.transpose(-3, -2).reshape(b, m * n, m * n)

tensor([[[2.7497, 1.9997, 2.5260, 1.8324, 1.0587, 1.5959],
         [1.9997, 2.3089, 2.5956, 1.2006, 0.9716, 1.3898],
         [2.5260, 2.5956, 3.9785, 1.4866, 1.2309, 1.7182],
         [1.8324, 1.2006, 1.4866, 1.8368, 0.6344, 1.1747],
         [1.0587, 0.9716, 1.2309, 0.6344, 1.4897, 0.7867],
         [1.5959, 1.3898, 1.7182, 1.1747, 0.7867, 1.8870]],

        [[3.7497, 2.9997, 3.5260, 2.8324, 2.0587, 2.5959],
         [2.9997, 3.3089, 3.5956, 2.2006, 1.9716, 2.3898],
         [3.5260, 3.5956, 4.9785, 2.4866, 2.2309, 2.7182],
         [2.8324, 2.2006, 2.4866, 2.8368, 1.6344, 2.1747],
         [2.0587, 1.9716, 2.2309, 1.6344, 2.4897, 1.7867],
         [2.5959, 2.3898, 2.7182, 2.1747, 1.7867, 2.8870]]])

In [10]:
# General formulation for any number of leading batch dimensions: keep the batch
# dims in place, swap the two middle trailing dims, then flatten to (m*n) x (m*n).
batch_shape = C_inter_b.shape[:-4]
batch_dims = tuple(range(len(batch_shape)))
full_shape = (*batch_shape, m * n, m * n)
C_inter_b.permute(*batch_dims, -4, -2, -3, -1).reshape(full_shape)

tensor([[[2.7497, 1.9997, 2.5260, 1.8324, 1.0587, 1.5959],
         [1.9997, 2.3089, 2.5956, 1.2006, 0.9716, 1.3898],
         [2.5260, 2.5956, 3.9785, 1.4866, 1.2309, 1.7182],
         [1.8324, 1.2006, 1.4866, 1.8368, 0.6344, 1.1747],
         [1.0587, 0.9716, 1.2309, 0.6344, 1.4897, 0.7867],
         [1.5959, 1.3898, 1.7182, 1.1747, 0.7867, 1.8870]],

        [[3.7497, 2.9997, 3.5260, 2.8324, 2.0587, 2.5959],
         [2.9997, 3.3089, 3.5956, 2.2006, 1.9716, 2.3898],
         [3.5260, 3.5956, 4.9785, 2.4866, 2.2309, 2.7182],
         [2.8324, 2.2006, 2.4866, 2.8368, 1.6344, 2.1747],
         [2.0587, 1.9716, 2.2309, 1.6344, 2.4897, 1.7867],
         [2.5959, 2.3898, 2.7182, 2.1747, 1.7867, 2.8870]]])

In [11]:
# How do we go back, i.e. construct C_inter efficiently from the full matrix?
# Just apply the same steps in reverse: un-flatten, then swap the middle dims.

batch_shape = A_b.shape[:-2]
batch_dims = tuple(range(len(batch_shape)))
unrolled = A_b.reshape(*batch_shape, n, m, n, m)
C_inter_b_recov = unrolled.permute(*batch_dims, -4, -2, -3, -1)

In [12]:
torch.allclose(C_inter_b, C_inter_b_recov)

True

### non-interleaved: block matrix where each block is an intra-task, cross-point covariance

In [13]:
# Non-interleaved layout as a 4-d tensor: C_noninter[i, i_, j, j_] is the entry of A
# for (task i, point j) vs. (task i_, point j_).
C_noninter = torch.zeros(m, m, n, n)

# Deliberately naive element-by-element copy: slow, but it spells out the index mapping.
for i in range(m):
    for i_ in range(m):
        for j in range(n):
            row = i * n + j  # flat row index of (task i, point j) in A
            for j_ in range(n):
                col = i_ * n + j_  # flat column index of (task i_, point j_) in A
                C_noninter[i, i_, j, j_] = A[row, col]

In [14]:
A

tensor([[2.7497, 1.9997, 2.5260, 1.8324, 1.0587, 1.5959],
        [1.9997, 2.3089, 2.5956, 1.2006, 0.9716, 1.3898],
        [2.5260, 2.5956, 3.9785, 1.4866, 1.2309, 1.7182],
        [1.8324, 1.2006, 1.4866, 1.8368, 0.6344, 1.1747],
        [1.0587, 0.9716, 1.2309, 0.6344, 1.4897, 0.7867],
        [1.5959, 1.3898, 1.7182, 1.1747, 0.7867, 1.8870]])

In [15]:
# again we can construct the matrix as follows:
C_noninter.permute(0, 2, 1, 3).reshape(m*n, m*n)

tensor([[2.7497, 1.9997, 2.5260, 1.8324, 1.0587, 1.5959],
        [1.9997, 2.3089, 2.5956, 1.2006, 0.9716, 1.3898],
        [2.5260, 2.5956, 3.9785, 1.4866, 1.2309, 1.7182],
        [1.8324, 1.2006, 1.4866, 1.8368, 0.6344, 1.1747],
        [1.0587, 0.9716, 1.2309, 0.6344, 1.4897, 0.7867],
        [1.5959, 1.3898, 1.7182, 1.1747, 0.7867, 1.8870]])

In [16]:
# Batched variant of the non-interleaved layout.
C_noninter_b = torch.zeros(b, m, m, n, n)

# Deliberately naive element-by-element copy: slow, but it spells out the index mapping.
for b_ in range(b):
    for i in range(m):
        for i_ in range(m):
            for j in range(n):
                row = i * n + j  # flat row index of (task i, point j)
                for j_ in range(n):
                    col = i_ * n + j_  # flat column index of (task i_, point j_)
                    C_noninter_b[b_, i, i_, j, j_] = A_b[b_, row, col]

# Swap the two middle trailing dims and flatten; this must reproduce A_b.
full_noninter_b = C_noninter_b.transpose(-3, -2).reshape(b, m * n, m * n)
torch.allclose(full_noninter_b, A_b)

True

In [17]:
# Going back works the same way: un-flatten into (task, point, task, point),
# then swap the two middle dims.

batch_shape = A_b.shape[:-2]
batch_dims = tuple(range(len(batch_shape)))
unrolled = A_b.reshape(*batch_shape, m, n, m, n)
C_noninter_b_recov = unrolled.permute(*batch_dims, -4, -2, -3, -1)
torch.allclose(C_noninter_b, C_noninter_b_recov)

True

### alternate mixed-interleaved: block matrix where each block is an intra-point, cross-task covariance

This seems to be the most useful internal representation, as this means we don't have to do any permuting

In [18]:
# Alternate representation: dims ordered (point, task, point, task), so that
# C_alt[i, j, i_, j_] is the entry of A for (point i, task j) vs. (point i_, task j_).
C_alt = torch.zeros(n, m, n, m)

# This is (obviously) super inefficient, but it spells out the index mapping.
for i in range(n):
    for j in range(m):
        row = i * m + j  # flat row index of (point i, task j) in A
        for i_ in range(n):
            for j_ in range(m):
                col = i_ * m + j_  # flat column index of (point i_, task j_) in A
                C_alt[i, j, i_, j_] = A[row, col]

In [19]:
A

tensor([[2.7497, 1.9997, 2.5260, 1.8324, 1.0587, 1.5959],
        [1.9997, 2.3089, 2.5956, 1.2006, 0.9716, 1.3898],
        [2.5260, 2.5956, 3.9785, 1.4866, 1.2309, 1.7182],
        [1.8324, 1.2006, 1.4866, 1.8368, 0.6344, 1.1747],
        [1.0587, 0.9716, 1.2309, 0.6344, 1.4897, 0.7867],
        [1.5959, 1.3898, 1.7182, 1.1747, 0.7867, 1.8870]])

In [20]:
# this is super straightforward
C_alt.view(n*m, n*m)

tensor([[2.7497, 1.9997, 2.5260, 1.8324, 1.0587, 1.5959],
        [1.9997, 2.3089, 2.5956, 1.2006, 0.9716, 1.3898],
        [2.5260, 2.5956, 3.9785, 1.4866, 1.2309, 1.7182],
        [1.8324, 1.2006, 1.4866, 1.8368, 0.6344, 1.1747],
        [1.0587, 0.9716, 1.2309, 0.6344, 1.4897, 0.7867],
        [1.5959, 1.3898, 1.7182, 1.1747, 0.7867, 1.8870]])

In [21]:
# Batched alternate representation: dims ordered (batch, point, task, point, task).
C_alt_b = torch.zeros(b, n, m, n, m)

# This is (obviously) super inefficient, but it spells out the index mapping.
for b_ in range(b):
    for i in range(n):
        for j in range(m):
            row = i * m + j  # flat row index of (point i, task j)
            for i_ in range(n):
                for j_ in range(m):
                    col = i_ * m + j_  # flat column index of (point i_, task j_)
                    C_alt_b[b_, i, j, i_, j_] = A_b[b_, row, col]

# A plain view (no permute) flattens this layout into the full matrices.
batch_shape = C_alt_b.shape[:-4]
full_alt_b = C_alt_b.view(*batch_shape, n * m, n * m)
torch.allclose(full_alt_b, A_b)

True

In [22]:
# ... and going back: a plain reshape suffices, no permute of the trailing dims is needed.

batch_shape = A_b.shape[:-2]
alt_shape = (*batch_shape, n, m, n, m)
C_alt_b_recov = A_b.reshape(alt_shape)
torch.allclose(C_alt_b, C_alt_b_recov)

True

## Question: What if the tensor memory layout is different? Do we need to handle that?

### Scalarizing

This representation makes scalarizing across the outputs trivial

In [23]:
# Contract both task dimensions of C_alt (n x m x n x m) with the same weight vector,
# leaving an n x n matrix for the weighted combination of the outputs.
weights = torch.rand(m)
right_contracted = torch.matmul(C_alt, weights)  # n x m x n: last task dim summed out
torch.matmul(right_contracted.transpose(-2, -1), weights)  # n x n

tensor([[0.7856, 0.6805, 0.4484],
        [0.6805, 0.7218, 0.4162],
        [0.4484, 0.4162, 0.4479]])

## Representing independent outputs

Here we can just store the `m` individual `n x n` blocks as a `m x n x n` tensor, no need to store the cross covariances.

We could also think of the case where we have independence across points, and only inter-task correlation. This will typically not be the case though, so we can punt on this for now.

In [24]:
from gpytorch.lazy import BlockDiagLazyTensor

C_indep = torch.stack([make_rand_covar(n) for _ in range(m)])

# using the lazy tensor here will speed up matrix operations
BlockDiagLazyTensor(C_indep).evaluate()

tensor([[0.7513, 0.2001, 0.3947, 0.0000, 0.0000, 0.0000],
        [0.2001, 1.3723, 0.1992, 0.0000, 0.0000, 0.0000],
        [0.3947, 0.1992, 1.4604, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 1.0000, 0.7192, 0.7787],
        [0.0000, 0.0000, 0.0000, 0.7192, 1.9332, 1.2580],
        [0.0000, 0.0000, 0.0000, 0.7787, 1.2580, 1.5258]])

In [25]:
C_indep

tensor([[[0.7513, 0.2001, 0.3947],
         [0.2001, 1.3723, 0.1992],
         [0.3947, 0.1992, 1.4604]],

        [[1.0000, 0.7192, 0.7787],
         [0.7192, 1.9332, 1.2580],
         [0.7787, 1.2580, 1.5258]]])

In [26]:
# how do we construct the full matrix (we don't want to do this in general, but maybe sometimes)
# Once https://github.com/pytorch/pytorch/issues/31932 goes in, we can just use that instead

# Assemble the dense block-diagonal matrix from the per-task blocks.
# C_indep is laid out as (*batch_shape, m, n, n), so the task dimension is -3.
batch_shape = C_indep.shape[:-3]
out = torch.zeros(*batch_shape, m*n, m*n, device=C_indep.device, dtype=C_indep.dtype)
# Iterate over the task dimension (-3) rather than dim 0: with leading batch
# dimensions, iterating the tensor directly would walk the batch dim instead.
for i, vals in enumerate(C_indep.unbind(dim=-3)):
    start = i*n
    end = start + n
    # vals has shape (*batch_shape, n, n) and fills the i-th diagonal block.
    out[..., start:end, start:end].copy_(vals)
    
out

tensor([[0.7513, 0.2001, 0.3947, 0.0000, 0.0000, 0.0000],
        [0.2001, 1.3723, 0.1992, 0.0000, 0.0000, 0.0000],
        [0.3947, 0.1992, 1.4604, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 1.0000, 0.7192, 0.7787],
        [0.0000, 0.0000, 0.0000, 0.7192, 1.9332, 1.2580],
        [0.0000, 0.0000, 0.0000, 0.7787, 1.2580, 1.5258]])

# We can also allow arbitrary order in the construction...

### ... so long as we have a consistent internal representation

The following orders are admissible (`n` is the number of points, `t` is the number of tasks):

```
nntt
ntnt
nttn
ttnn
tntn
tnnt

nnt
ntt
ntn
ttn
tnt
tnn
```

So basically we have the set 
```
nnt
ntt
ntn
ttn
tnt
tnn
```

and then the whole set of admissible combinations we get by combining this with the set we get by appending a single letter (we could prepend too, but that would just generate duplicates).

Now we know which orders are admissible. We'd like to store things consistently internally. There are four options:

- no independence -> can standardize to `n x t x n x t`
- cross-task independence only -> can standardize to `t x n x n`
- cross-point independence only -> can standardize to `n x t x t`
- cross-task AND cross-point independence: This is trivial -> can standardize to `n x t` (just marginal variances)

**For now we focus on the first two cases**, we can deal with the other ones later.

In [27]:
from gpytorch.distributions.new_multitask_multivariate_normal import MultitaskMultivariateNormal

In [32]:
# Random batched mean of shape (b, n, m).
mean_b = torch.randn(b, n, m)

# No-independence case: the covariance is passed as the unrolled tensor C_alt_b.
# NOTE(review): `order` presumably labels each trailing covariance dim as a point ("n")
# or task ("t") dimension — confirm against the MultitaskMultivariateNormal implementation.
mtmvn = MultitaskMultivariateNormal(mean=mean_b, covariance=C_alt_b, order=("n" ,"t", "n", "t"))

In [33]:
mtmvn.mean.shape

torch.Size([2, 3, 2])

In [34]:
mtmvn.variance.shape

torch.Size([2, 3, 2])

In [35]:
mtmvn.rsample(torch.Size([4])).shape

torch.Size([4, 2, 3, 2])

In [36]:
mtmvn.log_prob(mtmvn.mean + torch.randn_like(mtmvn.mean))

tensor([ -9.1266, -28.4039])

Independent tasks

In [37]:
# Batched per-task blocks: stack two copies of C_indep along a new leading batch dim.
C_indep_b = torch.stack([C_indep, C_indep + 1])

# Independent-tasks case: only the per-task blocks are passed as the covariance.
# NOTE(review): `order` presumably labels the trailing dims as task ("t") / point ("n")
# — confirm against the MultitaskMultivariateNormal implementation.
mtmvn_indep = MultitaskMultivariateNormal(mean=mean_b, covariance=C_indep_b, order=("t", "n", "n"))

In [38]:
mtmvn_indep.mean.shape

torch.Size([2, 3, 2])

In [39]:
mtmvn_indep.variance.shape

torch.Size([2, 3, 2])

In [40]:
mtmvn_indep.rsample(torch.Size([4])).shape

torch.Size([4, 2, 3, 2])

In [41]:
mtmvn_indep.log_prob(mtmvn_indep.mean + torch.randn_like(mtmvn_indep.mean))

tensor([ -7.5314, -11.0381])

# Kronecker product structure

In [42]:
from gpytorch.lazy import KroneckerProductLazyTensor

In [43]:
batch_shape = torch.Size([2])

# Batched random covariance factors: Kn has trailing shape n x n, Km has trailing shape m x m.
Kn = make_rand_covar(n, batch_shape)
Km = make_rand_covar(m, batch_shape)

# Lazy Kronecker product with Km as the first factor and Kn as the second.
covar = KroneckerProductLazyTensor(Km, Kn)

In [44]:
Kn

tensor([[[1.2282, 0.2174, 0.5029],
         [0.2174, 0.8848, 0.4578],
         [0.5029, 0.4578, 1.3140]],

        [[2.1553, 1.2025, 0.7948],
         [1.2025, 1.7784, 0.6121],
         [0.7948, 0.6121, 1.5613]]])

In [45]:
Km

tensor([[[0.5423, 0.1257],
         [0.1257, 0.9608]],

        [[0.8937, 0.7911],
         [0.7911, 2.2355]]])

In [46]:
covar.evaluate()

tensor([[[0.6661, 0.1179, 0.2727, 0.1544, 0.0273, 0.0632],
         [0.1179, 0.4798, 0.2483, 0.0273, 0.1113, 0.0576],
         [0.2727, 0.2483, 0.7125, 0.0632, 0.0576, 0.1652],
         [0.1544, 0.0273, 0.0632, 1.1801, 0.2089, 0.4832],
         [0.0273, 0.1113, 0.0576, 0.2089, 0.8501, 0.4399],
         [0.0632, 0.0576, 0.1652, 0.4832, 0.4399, 1.2624]],

        [[1.9263, 1.0747, 0.7104, 1.7050, 0.9513, 0.6287],
         [1.0747, 1.5894, 0.5471, 0.9513, 1.4068, 0.4842],
         [0.7104, 0.5471, 1.3954, 0.6287, 0.4842, 1.2351],
         [1.7050, 0.9513, 0.6287, 4.8183, 2.6882, 1.7768],
         [0.9513, 1.4068, 0.4842, 2.6882, 3.9756, 1.3684],
         [0.6287, 0.4842, 1.2351, 1.7768, 1.3684, 3.4904]]])

In [51]:
# can construct the full tensor from the kronecker product (we DON'T want to do this)

# Outer product of the two factors, shape (*batch, m, m, n, n):
# Dense[..., j, j_, i, i_] = Km[..., j, j_] * Kn[..., i, i_].
# Broadcasting performs the expansion, so no explicit `expand` is needed.
Dense = Km[..., :, :, None, None] * Kn[..., None, None, :, :]
# Swap the two middle trailing dims -> (*batch, m, n, m, n), then flatten to (m*n) x (m*n).
# `transpose` with negative dims (instead of a hard-coded 5-d `permute`) works for any
# number of leading batch dimensions, matching the general batch shape used in `reshape`.
Dense.transpose(-3, -2).reshape(*Km.shape[:-2], m*n, m*n)

tensor([[[0.6661, 0.1179, 0.2727, 0.1544, 0.0273, 0.0632],
         [0.1179, 0.4798, 0.2483, 0.0273, 0.1113, 0.0576],
         [0.2727, 0.2483, 0.7125, 0.0632, 0.0576, 0.1652],
         [0.1544, 0.0273, 0.0632, 1.1801, 0.2089, 0.4832],
         [0.0273, 0.1113, 0.0576, 0.2089, 0.8501, 0.4399],
         [0.0632, 0.0576, 0.1652, 0.4832, 0.4399, 1.2624]],

        [[1.9263, 1.0747, 0.7104, 1.7050, 0.9513, 0.6287],
         [1.0747, 1.5894, 0.5471, 0.9513, 1.4068, 0.4842],
         [0.7104, 0.5471, 1.3954, 0.6287, 0.4842, 1.2351],
         [1.7050, 0.9513, 0.6287, 4.8183, 2.6882, 1.7768],
         [0.9513, 1.4068, 0.4842, 2.6882, 3.9756, 1.3684],
         [0.6287, 0.4842, 1.2351, 1.7768, 1.3684, 3.4904]]])