In [12]:
import torch
from functools import reduce
from pyraul.tools.dumping import print_torch_tensor, gen_cpp_dtVec
from pyraul.tools.seed import set_seed

In [4]:
def broadcast_test(size, broadcast_size, start=0):
    """Print how a tensor of shape ``size`` broadcasts against ``broadcast_size``.

    Builds ``x`` from consecutive integers ``start, start + 1, ...`` reshaped
    to ``size``, multiplies it by an all-ones tensor ``y`` of shape
    ``broadcast_size`` and prints both operands and the product ``z``
    (name, shape, values). Afterwards the elements of ``x`` and ``z`` are
    printed one per line together with their index in the flattened tensor.

    Args:
        size: shape of the integer tensor ``x`` (iterable of ints).
        broadcast_size: shape of the all-ones tensor ``y``; must be
            broadcast-compatible with ``size`` or the multiplication raises.
        start: first value stored in ``x``.

    Returns:
        None. All results are written to stdout.
    """
    # Number of elements needed to fill a tensor of shape `size`.
    numel = reduce(lambda a, b: a * b, size, 1)
    x = torch.arange(start, start + numel).reshape(size)
    y = torch.ones(broadcast_size)
    z = x * y
    for label, tensor in (("x", x), ("y", y), ("z", z)):
        print(label, tensor.shape, tensor, sep="\n")

    # Dump the elements in flattened order. The loop variable is deliberately
    # not called `x`, so the tensor `x` above is never shadowed.
    for i, value in enumerate(x.flatten()):
        print(f"{i}: x={value.item()}")
    for i, value in enumerate(z.flatten()):
        print(f"{i}: z={value.item()}")

In [5]:
# Size-1 leading dim: the single row of x is repeated along dim 0.
broadcast_test(
    size=(1, 2),
    broadcast_size=(2, 2),
)

x
torch.Size([1, 2])
tensor([[0, 1]])
y
torch.Size([2, 2])
tensor([[1., 1.],
        [1., 1.]])
z
torch.Size([2, 2])
tensor([[0., 1.],
        [0., 1.]])
0: x=0
1: x=1
0: z=0.0
1: z=1.0
2: z=0.0
3: z=1.0


In [54]:
# Size-1 trailing dim: each element of x is repeated along dim 1.
broadcast_test(
    size=(2, 1),
    broadcast_size=(2, 2),
)

x
torch.Size([2, 1])
tensor([[0],
        [1]])
y
torch.Size([2, 2])
tensor([[1., 1.],
        [1., 1.]])
z
torch.Size([2, 2])
tensor([[0., 0.],
        [1., 1.]])
0: x=0
1: x=1
0: z=0.0
1: z=0.0
2: z=1.0
3: z=1.0


In [55]:
# Identical shapes: nothing is broadcast, z carries the values of x as floats.
broadcast_test(
    size=(2, 2),
    broadcast_size=(2, 2),
)

x
torch.Size([2, 2])
tensor([[0, 1],
        [2, 3]])
y
torch.Size([2, 2])
tensor([[1., 1.],
        [1., 1.]])
z
torch.Size([2, 2])
tensor([[0., 1.],
        [2., 3.]])
0: x=0
1: x=1
2: x=2
3: x=3
0: z=0.0
1: z=1.0
2: z=2.0
3: z=3.0


In [56]:
# Size-1 middle dim: each row of x is repeated along dim 1.
broadcast_test(
    size=(2, 1, 2),
    broadcast_size=(2, 2, 2),
)

x
torch.Size([2, 1, 2])
tensor([[[0, 1]],

        [[2, 3]]])
y
torch.Size([2, 2, 2])
tensor([[[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]]])
z
torch.Size([2, 2, 2])
tensor([[[0., 1.],
         [0., 1.]],

        [[2., 3.],
         [2., 3.]]])
0: x=0
1: x=1
2: x=2
3: x=3
0: z=0.0
1: z=1.0
2: z=0.0
3: z=1.0
4: z=2.0
5: z=3.0
6: z=2.0
7: z=3.0


In [30]:
# Size-1 last dim: each element of x is repeated along dim 2.
# NOTE(review): the output recorded for this cell shows x = 1..4 and no
# flattened dump, which does not match the default start=0 of the current
# broadcast_test — it looks stale; re-run the cell to refresh it.
broadcast_test(
    size=(2, 2, 1),
    broadcast_size=(2, 2, 2),
)

x
torch.Size([2, 2, 1])
tensor([[[1],
         [2]],

        [[3],
         [4]]])
y
torch.Size([2, 2, 2])
tensor([[[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]]])
z
torch.Size([2, 2, 2])
tensor([[[1., 1.],
         [2., 2.]],

        [[3., 3.],
         [4., 4.]]])


In [58]:
# Size-1 leading dim: the whole 2x2 block of x is repeated along dim 0.
broadcast_test(
    size=(1, 2, 2),
    broadcast_size=(2, 2, 2),
)

x
torch.Size([1, 2, 2])
tensor([[[0, 1],
         [2, 3]]])
y
torch.Size([2, 2, 2])
tensor([[[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]]])
z
torch.Size([2, 2, 2])
tensor([[[0., 1.],
         [2., 3.]],

        [[0., 1.],
         [2., 3.]]])
0: x=0
1: x=1
2: x=2
3: x=3
0: z=0.0
1: z=1.0
2: z=2.0
3: z=3.0
4: z=0.0
5: z=1.0
6: z=2.0
7: z=3.0


In [7]:
# Single-element x (value 1) broadcast along dims 1 and 2.
broadcast_test(
    size=(1, 1, 1, 1),
    broadcast_size=(1, 2, 2, 1),
    start=1,
)

x
torch.Size([1, 1, 1, 1])
tensor([[[[1]]]])
y
torch.Size([1, 2, 2, 1])
tensor([[[[1.],
          [1.]],

         [[1.],
          [1.]]]])
z
torch.Size([1, 2, 2, 1])
tensor([[[[1.],
          [1.]],

         [[1.],
          [1.]]]])
0: x=1
0: z=1.0
1: z=1.0
2: z=1.0
3: z=1.0


## Mul: reference data for broadcasted element-wise multiplication

In [19]:
# Forward-pass reference data: x (5,1,2,3) * y (5,3,1,3), emitted as C++ dtVec
# initialisers. x is drawn before y so the seeded random stream is consumed in
# the same order as before.
set_seed(0)
x, y = torch.rand(5, 1, 2, 3), torch.rand(5, 3, 1, 3)
z = torch.mul(x, y)
print(z.shape)
print(gen_cpp_dtVec(torch.flatten(x.data), "x"))
print(gen_cpp_dtVec(torch.flatten(y.data), "y"))
print(gen_cpp_dtVec(torch.flatten(z.data), "z"))

torch.Size([5, 3, 2, 3])
const raul::dtVec x{0.49625658988952637_dt, 0.7682217955589294_dt, 0.08847743272781372_dt, 0.13203048706054688_dt, 0.30742281675338745_dt, 0.6340786814689636_dt, 0.4900934100151062_dt, 0.8964447379112244_dt, 0.455627977848053_dt, 0.6323062777519226_dt, 0.3488934636116028_dt, 0.40171730518341064_dt, 0.022325754165649414_dt, 0.16885894536972046_dt, 0.2938884496688843_dt, 0.518521785736084_dt, 0.6976675987243652_dt, 0.800011396408081_dt, 0.16102945804595947_dt, 0.28226858377456665_dt, 0.6816085577011108_dt, 0.9151939749717712_dt, 0.39709991216659546_dt, 0.8741558790206909_dt, 0.41940832138061523_dt, 0.5529070496559143_dt, 0.9527381062507629_dt, 0.036164820194244385_dt, 0.1852310299873352_dt, 0.37341737747192383_dt};
const raul::dtVec y{0.3051000237464905_dt, 0.9320003986358643_dt, 0.17591017484664917_dt, 0.2698335647583008_dt, 0.15067976713180542_dt, 0.03171950578689575_dt, 0.20812976360321045_dt, 0.9297990202903748_dt, 0.7231091856956482_dt, 0.7423362731933594_dt

In [40]:
# Backward-pass reference data for a broadcasted multiply:
# x (2,1,1,1) * y (1,1,1,3) -> z (2,1,1,3).
set_seed(0)
x = torch.rand(2, 1, 1, 1, requires_grad=True)
y = torch.rand(1, 1, 1, 3, requires_grad=True)
# z is a non-leaf result of two tensors that require grad, so it already
# requires grad itself; no explicit requires_grad_ call is needed.
z = x * y
# The upstream gradient of sum() is all ones, so every entry of x.grad is
# sum(y) and every entry of y.grad is sum(x).
z.sum().backward()
print_torch_tensor("x", x, grad=True)
print_torch_tensor("y", y, grad=True)
print_torch_tensor("z", z)
print("==============")
# detach() yields the values without autograd history; it is the recommended
# replacement for the legacy .data attribute.
print(gen_cpp_dtVec(x.detach().flatten(), "x"))
print(gen_cpp_dtVec(y.detach().flatten(), "y"))
print(gen_cpp_dtVec(z.detach().flatten(), "z"))
print(gen_cpp_dtVec(x.grad.flatten(), "x_grad"))
print(gen_cpp_dtVec(y.grad.flatten(), "y_grad"))

x (torch.Size([2, 1, 1, 1])):
[0.49625658988952637, 0.7682217955589294]
grad of x (torch.Size([2, 1, 1, 1])):
[0.527930736541748, 0.527930736541748]
y (torch.Size([1, 1, 1, 3])):
[0.08847743272781372, 0.13203048706054688, 0.30742281675338745]
grad of y (torch.Size([1, 1, 1, 3])):
[1.2644784450531006, 1.2644784450531006, 1.2644784450531006]
z (torch.Size([2, 1, 1, 3])):
[0.04390750825405121, 0.0655210018157959, 0.15256059169769287, 0.06797029078006744, 0.10142869502305984, 0.23616890609264374]
const raul::dtVec x{0.49625658988952637_dt, 0.7682217955589294_dt};
const raul::dtVec y{0.08847743272781372_dt, 0.13203048706054688_dt, 0.30742281675338745_dt};
const raul::dtVec z{0.04390750825405121_dt, 0.0655210018157959_dt, 0.15256059169769287_dt, 0.06797029078006744_dt, 0.10142869502305984_dt, 0.23616890609264374_dt};
const raul::dtVec x_grad{0.527930736541748_dt, 0.527930736541748_dt};
const raul::dtVec y_grad{1.2644784450531006_dt, 1.2644784450531006_dt, 1.2644784450531006_dt};
