In [1]:
from gpytorch.lazy import CatLazyTensor, NonLazyTensor
import torch, gpytorch

In [2]:
# Non-batched concatenation along dim 0 (expected to pass): (5, 1) stacked on (4, 1)
x = NonLazyTensor(torch.randn((5, 1)))
y = NonLazyTensor(torch.randn((4, 1)))
z = CatLazyTensor(x, y, dim=0)
z.size()

torch.Size([9, 1])

In [4]:
# Non-batched cat dim 1 pass: (4, 3) and (4, 2) placed side by side -> (4, 5)
x = NonLazyTensor(torch.randn(4, 3))
y = NonLazyTensor(torch.randn(4, 2))

z = CatLazyTensor(x, y, dim=1)
z.size()

torch.Size([4, 5])

In [6]:
# Dense view of the dim-1 concatenation from the previous cell, checked against torch.cat of its parts
assert torch.allclose(z.evaluate(), torch.cat([x.evaluate(), y.evaluate()], dim=1))
z.evaluate()

tensor([[-0.0777, -2.9289,  1.7987, -1.3053, -1.6413],
        [-1.7167,  0.4824, -0.3797, -1.5872, -0.0787],
        [ 0.8224,  1.3980,  0.2137, -1.7026, -0.8035],
        [ 1.4306,  0.0311,  0.3998, -0.2568,  0.7621]])

In [7]:
# Non-batched cat dim 0 fail: column counts differ (3 vs 2), so construction must raise.
x = NonLazyTensor(torch.randn(5, 3))
y = NonLazyTensor(torch.randn(4, 2))

try:
    CatLazyTensor(x, y, dim=0)
except RuntimeError as err:
    # Expected failure: display it instead of aborting Restart & Run All.
    print(f"RuntimeError: {err}")
else:
    raise AssertionError("CatLazyTensor should reject mismatched non-concatenation sizes")

RuntimeError: All LazyTensors must have the same size in the non-concatenation dimension

In [8]:
# Non-batched cat dim 1 fail: row counts differ (5 vs 4), so construction must raise.
x = NonLazyTensor(torch.randn(5, 3))
y = NonLazyTensor(torch.randn(4, 2))

try:
    CatLazyTensor(x, y, dim=1)
except RuntimeError as err:
    # Expected failure: display it instead of aborting Restart & Run All.
    print(f"RuntimeError: {err}")
else:
    raise AssertionError("CatLazyTensor should reject mismatched non-concatenation sizes")

RuntimeError: All LazyTensors must have the same size in the non-concatenation dimension

In [3]:
# Scratch check: unsqueeze at position 0 prepends a singleton dimension to a 1-D tensor
x = torch.tensor([1, 2, 3])
x.unsqueeze(dim=0).size()

torch.Size([1, 3])

In [4]:
# Scratch 4-D tensor used by the slicing check in the next cell
x = torch.randn((5, 4, 3, 2))

In [8]:
# Indexing with a tuple of slice objects is the same as slice syntax: keep only batch entry 0
x[0:1, :, :, :].size()

torch.Size([1, 4, 3, 2])

In [2]:
# Batched cat dim 0 pass: (3, 4, 1) + (1, 4, 1) -> (4, 4, 1)
x = NonLazyTensor(torch.randn(3, 4, 1))
y = NonLazyTensor(torch.randn(1, 4, 1))

z = CatLazyTensor(x, y, dim=0)
print(z.size())
# Check the lazy result against an eager torch.cat of the same parts.
assert torch.allclose(z.evaluate(), torch.cat([x.evaluate(), y.evaluate()], dim=0))
z.evaluate()

torch.Size([4, 4, 1])
> /home/alex/ml/repos/gpytorch/gpytorch/lazy/cat_lazy_tensor.py(94)_matmul()
-> while rhs.ndimension() < self.ndimension():
(Pdb) c


tensor([[[-1.2167],
         [ 0.1334],
         [-0.3411],
         [ 0.2939]],

        [[-0.1123],
         [ 0.0913],
         [ 0.7820],
         [ 0.8626]],

        [[ 1.0736],
         [ 1.1179],
         [ 0.7341],
         [ 0.9163]],

        [[-0.0628],
         [ 0.5939],
         [ 0.3647],
         [-0.2154]]])

In [16]:
# Batched cat dim 1 pass: (1, 5, 1) + (1, 4, 1) -> (1, 9, 1)
x = NonLazyTensor(torch.randn(1, 5, 1))
y = NonLazyTensor(torch.randn(1, 4, 1))

z = CatLazyTensor(x, y, dim=1)
print(z.size())
# Check the lazy result against an eager torch.cat of the same parts.
assert torch.allclose(z.evaluate(), torch.cat([x.evaluate(), y.evaluate()], dim=1))
z.evaluate()

torch.Size([1, 9, 1])


tensor([[[-0.6934],
         [-1.4482],
         [-0.7071],
         [-1.3278],
         [-0.6392],
         [-0.6384],
         [ 0.8040],
         [ 1.2791],
         [-1.0186]]])

In [17]:
# Batched concatenation along the last dim (expected to pass): (2, 4, 3) + (2, 4, 2) -> (2, 4, 5)
x = NonLazyTensor(torch.randn((2, 4, 3)))
y = NonLazyTensor(torch.randn((2, 4, 2)))
z = CatLazyTensor(x, y, dim=2)
print(z.size())
z.evaluate()

torch.Size([2, 4, 5])


tensor([[[ 0.6173, -1.6392,  0.3309,  0.6950,  1.8639],
         [-0.3047,  0.0773, -0.6098, -0.2654,  0.0663],
         [ 1.6305,  0.4748, -0.0281,  1.5121, -0.1897],
         [-0.8223, -1.7555,  0.6705, -0.0978,  0.2077]],

        [[-0.5612, -0.8767,  0.1047,  0.5907, -0.4869],
         [-0.8295, -0.2056, -0.7326,  0.0462, -0.6485],
         [ 1.4546,  1.6113,  0.8614, -0.8414,  0.7671],
         [ 0.0742, -0.4846,  0.0099,  2.7439,  0.2789]]])

In [19]:
# Batched cat dim 1 fail: last-dim sizes differ (3 vs 2), so construction must raise.
x = NonLazyTensor(torch.randn(1, 5, 3))
y = NonLazyTensor(torch.randn(1, 4, 2))

try:
    CatLazyTensor(x, y, dim=1)
except RuntimeError as err:
    # Expected failure: display it instead of aborting Restart & Run All.
    print(f"RuntimeError: {err}")
else:
    raise AssertionError("CatLazyTensor should reject mismatched non-concatenation sizes")

RuntimeError: All LazyTensors must have the same size in the non-concatenation dimension

In [18]:
# Batched cat dim 2 fail: dim-1 sizes differ (5 vs 4), so construction must raise.
x = NonLazyTensor(torch.randn(2, 5, 3))
y = NonLazyTensor(torch.randn(2, 4, 2))

try:
    CatLazyTensor(x, y, dim=2)
except RuntimeError as err:
    # Expected failure: display it instead of aborting Restart & Run All.
    print(f"RuntimeError: {err}")
else:
    raise AssertionError("CatLazyTensor should reject mismatched non-concatenation sizes")

RuntimeError: All LazyTensors must have the same size in the non-concatenation dimension

In [23]:
# _matmul, non-batched: (5, 2) stacked on (4, 2) along dim 0, times a length-2 vector
x = NonLazyTensor(torch.randn((5, 2)))
y = NonLazyTensor(torch.randn((4, 2)))
z = CatLazyTensor(x, y, dim=0)
b = torch.randn((2,))
print(z._matmul(b))

tensor([ 1.5682, -0.4692,  1.9467,  0.6504,  1.5787, -1.2740, -0.7203, -0.5776,
         0.8800])


In [24]:
# _matmul, non-batched: (5, 2) stacked on (4, 2) along dim 0, times a (2, 3) matrix
x = NonLazyTensor(torch.randn((5, 2)))
y = NonLazyTensor(torch.randn((4, 2)))
z = CatLazyTensor(x, y, dim=0)
b = torch.randn((2, 3))
z._matmul(b)

tensor([[-1.0397,  0.9819,  2.6137],
        [-0.6649, -0.2920,  1.2983],
        [-0.6485,  0.9176,  1.7541],
        [-0.2305, -0.6749,  0.2173],
        [-1.8805,  2.9155,  5.1901],
        [-0.8279, -0.8363,  1.4246],
        [ 1.5897, -0.7003, -3.6715],
        [ 0.8470, -2.0063, -2.6189],
        [ 0.4620,  0.7941, -0.6621]])

In [25]:
# _matmul, batched: (3, 5, 2) + (3, 4, 2) joined along dim 1, times a length-2 vector
x = NonLazyTensor(torch.randn((3, 5, 2)))
y = NonLazyTensor(torch.randn((3, 4, 2)))
z = CatLazyTensor(x, y, dim=1)
b = torch.randn((2,))
z._matmul(b)

tensor([[-2.1308e-01,  1.0324e-02,  1.3557e-01, -8.6884e-02, -5.6084e-02,
         -4.2278e-01, -6.8678e-02,  2.1316e-01, -1.0412e-01],
        [-3.6976e-02, -2.6244e-02,  2.1939e-02, -3.0445e-04, -8.6439e-03,
          7.3080e-02,  7.2223e-02, -5.6279e-02, -9.2961e-03],
        [-6.6874e-03,  2.9258e-01, -1.8811e-01,  2.5883e-01, -1.4092e-04,
         -2.6298e-02, -8.4113e-03, -7.9362e-03,  5.3528e-02]])

In [26]:
# _matmul batched against 2D, dim=1: (3, 5, 2) + (3, 4, 2) -> (3, 9, 2); times (2, 6) -> (3, 9, 6)
x = NonLazyTensor(torch.randn(3, 5, 2))
y = NonLazyTensor(torch.randn(3, 4, 2))
z = CatLazyTensor(x, y, dim=1)

b = torch.randn(2, 6)

# Compare against a dense matmul (tolerance covers float32 summation-order differences).
assert torch.allclose(z._matmul(b), z.evaluate().matmul(b), atol=1e-5)
z._matmul(b)

tensor([[[-1.6745e-01, -1.9775e-01, -7.4203e-01, -3.3731e+00, -6.8577e-01,
          -5.9837e-01],
         [-1.8873e-01,  2.4295e-01,  1.9252e-01,  2.9572e-01,  9.9126e-02,
          -1.6060e-01],
         [ 1.2610e+00, -1.8315e+00, -1.7465e+00, -3.8086e+00, -1.0523e+00,
           8.4321e-01],
         [ 6.5896e-01, -4.6355e-01,  1.7752e-01,  2.3516e+00,  3.7411e-01,
           9.8511e-01],
         [-1.7409e+00,  2.0945e+00,  1.4522e+00,  1.4388e+00,  6.4003e-01,
          -1.6431e+00],
         [-1.4900e+00,  1.4595e+00,  5.0731e-01, -1.6980e+00, -7.5672e-02,
          -1.7736e+00],
         [ 1.5061e+00, -1.6302e+00, -8.5493e-01,  3.5390e-01, -2.1348e-01,
           1.6220e+00],
         [ 2.0969e+00, -2.1030e+00, -8.2221e-01,  1.9585e+00,  1.4727e-02,
           2.4419e+00],
         [ 4.5337e-01, -5.9462e-01, -4.8679e-01, -8.0724e-01, -2.5873e-01,
           3.7366e-01]],

        [[-6.6169e-03,  2.8002e-01,  6.0642e-01,  2.3987e+00,  5.1176e-01,
           2.9385e-01],
        

In [27]:
# _matmul, non-batched: (3, 2) and (3, 5) side by side along dim 1 -> (3, 7), times a length-7 vector
x = NonLazyTensor(torch.randn((3, 2)))
y = NonLazyTensor(torch.randn((3, 5)))
z = CatLazyTensor(x, y, dim=1)
b = torch.randn((7,))
z._matmul(b)

tensor([-0.1509,  0.8554,  0.6617])

In [28]:
# _matmul, non-batched: (7, 3) and (7, 2) side by side along dim 1 -> (7, 5), times a (5, 3) matrix
x = NonLazyTensor(torch.randn((7, 3)))
y = NonLazyTensor(torch.randn((7, 2)))
z = CatLazyTensor(x, y, dim=1)
b = torch.randn((5, 3))
z._matmul(b)

tensor([[-0.5197,  0.2915,  1.1359],
        [ 1.8880, -0.1331, -1.0217],
        [-1.0260, -0.1039,  1.2906],
        [ 0.7460,  1.0556,  0.1110],
        [-1.2223,  0.1739,  0.1970],
        [ 2.4149, -0.7756,  0.7513],
        [ 3.0334, -3.3928, -2.1592]])

In [29]:
# _matmul batched against 1D, dim=2: (3, 4, 8) + (3, 4, 2) -> (3, 4, 10); times a length-10 vector -> (3, 4)
x = NonLazyTensor(torch.randn(3, 4, 8))
y = NonLazyTensor(torch.randn(3, 4, 2))
z = CatLazyTensor(x, y, dim=2)

b = torch.randn(10)

# Compare against a dense matmul (tolerance covers float32 summation-order differences).
assert torch.allclose(z._matmul(b), z.evaluate().matmul(b), atol=1e-5)
z._matmul(b)

tensor([[ 2.0460,  1.1405,  3.7088, -3.5238],
        [ 0.4248, -0.1310, -2.9800,  1.2447],
        [ 4.2467, -2.2254,  3.6346,  1.2865]])

In [30]:
# _matmul batched against 2D, dim=2: (3, 4, 10) + (3, 4, 2) -> (3, 4, 12); times (12, 1) -> (3, 4, 1)
x = NonLazyTensor(torch.randn(3, 4, 10))
y = NonLazyTensor(torch.randn(3, 4, 2))
z = CatLazyTensor(x, y, dim=2)

b = torch.randn(12, 1)

# Compare against a dense matmul (tolerance covers float32 summation-order differences).
assert torch.allclose(z._matmul(b), z.evaluate().matmul(b), atol=1e-5)
z._matmul(b)

tensor([[[ 0.1869],
         [ 3.0840],
         [-4.1718],
         [ 5.6867]],

        [[-0.6346],
         [ 1.2843],
         [ 7.0701],
         [-1.1512]],

        [[ 1.8299],
         [ 1.1328],
         [-3.2156],
         [ 0.3970]]])

In [31]:
# transpose: (3, 4, 10) + (3, 4, 2) along dim 2 -> (3, 4, 12); swapping dims 1 and 2 -> (3, 12, 4)
x = NonLazyTensor(torch.randn(3, 4, 10))
y = NonLazyTensor(torch.randn(3, 4, 2))
# Build z here rather than relying on the z left over from the previous cell.
z = CatLazyTensor(x, y, dim=2)

print(z.size(), z.transpose(1, 2).size())
assert z.transpose(1, 2).size() == torch.Size([3, 12, 4])

torch.Size([3, 4, 12]) torch.Size([3, 12, 4])


In [32]:
# CatLazyTensor of CatLazyTensors: a (5, 1) + (4, 1) column, then that column placed beside itself
x = NonLazyTensor(torch.randn((5, 1)))
y = NonLazyTensor(torch.randn((4, 1)))
z = CatLazyTensor(x, y, dim=0)
zz = CatLazyTensor(z, z, dim=1)

In [33]:
# Shape of the nested concatenation from the previous cell: two copies of z side by side along dim 1
zz.size()

torch.Size([9, 2])

In [34]:
# Dense view of zz; it should equal z's dense column concatenated with itself along dim 1
assert torch.allclose(zz.evaluate(), torch.cat([z.evaluate(), z.evaluate()], dim=1))
zz.evaluate()

tensor([[-0.0991, -0.0991],
        [-0.3708, -0.3708],
        [ 2.0858,  2.0858],
        [ 1.5745,  1.5745],
        [-0.3786, -0.3786],
        [ 0.4450,  0.4450],
        [ 0.1694,  0.1694],
        [ 0.1170,  0.1170],
        [-0.9882, -0.9882]])

In [9]:
# Slicing a nested CatLazyTensor: print the full dense tensor, then only rows 2..4
x = NonLazyTensor(torch.randn((5, 1)))
y = NonLazyTensor(torch.randn((4, 1)))
z = CatLazyTensor(x, y, dim=0)
zz = CatLazyTensor(z, z, dim=1)
print(zz.evaluate())
print(zz[2:5, :].evaluate())

tensor([[-1.3493, -1.3493],
        [ 0.5002,  0.5002],
        [-0.6933, -0.6933],
        [-0.1996, -0.1996],
        [ 0.4352,  0.4352],
        [-2.2446, -2.2446],
        [-0.1572, -0.1572],
        [ 0.9347,  0.9347],
        [ 0.1795,  0.1795]])
tensor([[-0.6933, -0.6933],
        [-0.1996, -0.1996],
        [ 0.4352,  0.4352]])
