In [1]:
import numpy as np
import matplotlib.pyplot as plt
from itertools import accumulate, product, chain
import torch

In [2]:
nCells = (10,10,10)
cellSize = (3,3,12)
n = 2.2 # PWO refractive index
c0 = 299.792458/n # mm/ns
c = c0/n

In [3]:
def onAxis_SolidAngle(a, b, d):
    alpha = a/(2*d)
    beta  = b/(2*d)
    return 4*np.arcsin(alpha*beta/np.sqrt((1+alpha**2)*(1+beta**2)))

In [4]:
def offAxis_SolidAngle(A, B, a, b, d):
    sign_A = np.sign(A)
    A *= sign_A
    sign_B = np.sign(B)
    B *= sign_B

    o1 = onAxis_SolidAngle(2*(a+sign_A*A), 2*(b+sign_B*B), d)
    o2 = onAxis_SolidAngle(2*A,            2*(b+sign_B*B), d)
    o3 = onAxis_SolidAngle(2*(a+sign_A*A), 2*B,            d)
    o4 = onAxis_SolidAngle(2*A,            2*B,            d)

    return (o1 - sign_A*o2 - sign_B*o3 + sign_A*sign_B*o4)/4


In [5]:
def generate_tuples_two_integers(n):
    """
    Generate all tuples of two integers (a, b) such that
    the sum of their absolute values equals n.

    Parameters:
        n (int): Target sum of absolute values.

    Returns:
        list of tuples: List of (a, b) satisfying the condition.
    """
    result = []
    for a in range(-n, n + 1):  # Iterate over possible values of a
        b_abs = n - abs(a)      # Compute the absolute value of b
        if b_abs >= 0:          # Ensure b_abs is non-negative
            # Add both positive and negative combinations of b
            result.append((a, b_abs))
            if b_abs != 0:  # Avoid duplicates when b_abs is 0
                result.append((a, -b_abs))
    return result


def all_reflections(n):
    """
    Generate all possible reflection paths such that
    the total amount of reflections equals n. A reflection 
    path is a (x, y, z) tuple where each coordinate is a
    (positive or negative) integer which signifies the amount
    of times a ray reflected in that direction.
    Positive values mean that the ray first reflected in the face
    in the positive direction of the axis and viceversa.

    Parameters:
        n (int): Target number of reflections.

    Returns:
        list of tuples: List of (x,y,z) satisfying the condition.
    """
    paths_horizontal = generate_tuples_two_integers(n)
    paths_vertical = generate_tuples_two_integers(n-1)
    paths = [(x, 0, z) for x, z in paths_horizontal] + [(x, -1, z) for x, z in paths_vertical]
    return paths
    

In [None]:
x = np.arange(0,6)
n_extra_point = [len(all_reflections(n)) for n in x]
total_point = list(accumulate(n_extra_point))
def test(N):
    if N==0:
        return 1
    elif N==1:
        return 5
    elif N > 1:
        return 4*(2*N-1)
r = [test(n) for n in x]
plt.plot(x, n_extra_point, 'o', label="Extra points per reflection")
plt.plot(x, r, 'o', label='4(2n-1)')
#plt.plot(x, total_point, 'o', label='Total points')
#plt.plot(x, 4*x**2, 'o',label='4x^2')
plt.legend()
plt.xlabel("# Reflections")
plt.xticks(np.arange(0, 22, 2))

In [None]:
n_extra_point

In [6]:
def reflect(old_coordinates, reflection, nCells=nCells):

    def reflect_onedim(old_coordinate, reflection, nCell):
        nc = old_coordinate
        if reflection != 0:
            sign = np.sign(reflection)
            reflection *= sign 
            partial_shift = reflection%2
            if partial_shift:
                if sign>0:
                    nc = 2*nCell-1-old_coordinate
                else:
                    nc = -1*old_coordinate-1
            whole_shifts = reflection//2
            nc += whole_shifts*sign*2*nCell
        return nc
    
    nc = [reflect_onedim(oc, r, n) for oc, r, n in zip(old_coordinates, reflection, nCells)]

    return nc

In [7]:
def find_distances(edep_idx, sensor_idx, cellSize=cellSize, nCells=nCells):
    shift_x = abs(edep_idx[0] - sensor_idx[0])
    if shift_x == 0:
        A = -0.5*cellSize[0]
    else:
        A = cellSize[0]*(shift_x-1+0.5)
    d = cellSize[1]*(nCells[1]-1-edep_idx[1]+0.5)
    shift_z = abs(edep_idx[2] - sensor_idx[1])
    if shift_z == 0:
        B = -0.5*cellSize[2]
    else:
        B = cellSize[2]*(shift_z-1+0.5)
    
    t = np.sqrt(d**2 + (shift_x*cellSize[0])**2 + (shift_z*cellSize[2])**2)/c

    return A, B, d, t

In [8]:
def compute_angles(edep_idx, sensor_idx, N, cellSize=cellSize):
    reflections = [all_reflections(n) for n in range(N+1)]
    new_coordinates = [[reflect(edep_idx, r) for r in ref] for ref in reflections]
    lengths = [[find_distances(edep, sensor_idx) for edep in n] for n in new_coordinates]
    omegas = [[(offAxis_SolidAngle(A, B, cellSize[0], cellSize[2], d), t) for A, B, d, t in l] for l in lengths]
    return omegas

In [None]:
def compute_matrices(N):

    reflections = [all_reflections(n) for n in range(N+1)]

    new_coordinates = [[[[[reflect((x, y, z), r) for x in range(nCells[0])] # cell idz x
                                                 for y in range(nCells[1])] # cell idx y
                                                 for z in range(nCells[2])] # cell idx z
                                                 for r in ref]              # possible reflections fixing n
                                                 for ref in reflections]    # all different ns
        
    lengths = [[[[[[[find_distances(edep_x, (s_x, s_z)) for s_x in range(nCells[0])]   # sensor idx x
                                                        for s_z in range(nCells[2])]   # sensor idx z
                                                        for edep_x in edep_y]          # cell idx x
                                                        for edep_y in edep_z]          # cell idx y
                                                        for edep_z in edep_ref]        # cell idx z
                                                        for edep_ref in edep_n]        # possible reflections fixing n
                                                        for edep_n in new_coordinates] # all different ns

    omegas = [[[[[[[[offAxis_SolidAngle(A, B, cellSize[0], cellSize[2], d) for A, B, d in l] for l in lengths]
    times = []

    return omegas, times

In [None]:
N = 5
reflections = [all_reflections(n) for n in range(N+1)]
new_coordinates = [[[[[reflect((x, y, z), r) for x in range(nCells[0])]
                                                     for y in range(nCells[1])]
                                                      for z in range(nCells[2])]
                                                      for r in ref]
                                                      for ref in reflections]

In [None]:
lengths = [[[[[[[find_distances(edep_x, (s_x, s_z)) for s_x in range(nCells[0])]
                                                 for s_z in range(nCells[2])]
                                                 for edep_x in edep_y]
                                                 for edep_y in edep_z]
                                                 for edep_z in edep_ref]
                                                 for edep_ref in edep_n]
                                                 for edep_n in new_coordinates]

In [None]:
import sys
def total_size(obj, seen=None):
    """Recursively find the size of an object, including referenced objects."""
    if seen is None:
        seen = set()

    obj_id = id(obj)
    if obj_id in seen:
        return 0
    seen.add(obj_id)

    size = sys.getsizeof(obj)
    if isinstance(obj, dict):
        size += sum(total_size(k, seen) + total_size(v, seen) for k, v in obj.items())
    elif isinstance(obj, (list, tuple, set)):
        size += sum(total_size(i, seen) for i in obj)

    return size
total_size(lengths)

In [None]:
a = [[10*y+x for x in range(3)] for y in range(5)]
print(a)
print(a[0][2])

In [None]:
a = torch.rand(3,5)
print(a)
print(a[0][2])

In [None]:
o = torch.zeros(*nCells, *(nCells[0,2]))

In [None]:
o.size()

In [None]:
N = 6
edep_idx = (0,2,4)
reflections = [all_reflections(n) for n in range(N+1)]
new_coordinates = [[reflect(edep_idx, r) for r in ref] for ref in reflections]

In [None]:
from itertools import chain
a = list(chain.from_iterable(reflections))
len(a)

In [None]:
for n in reflections:
    print(len(n))
    print(type(n[0]))

In [None]:
N = 5
edep_idx = (0,0,0)
sensor_idx = (9,9)
omegas = compute_angles(edep_idx, sensor_idx, N)

In [None]:
omega_sum = [np.sum(x) for x in omegas]
omega_cum = list(accumulate(omega_sum))

In [None]:
x = 6
plt.title("New Solid Angle coverage added by each reflection")
plt.ylabel("Solid Angle [srad]")
plt.xlabel("# Reflections")
plt.plot(np.arange(x), omega_sum[:x], 'o')
plt.plot(np.arange(x), omega_sum[:x], '--', color="lightblue")

In [None]:
x = 6
plt.title("Total Solid Angle coverage")
plt.ylabel("Solid Angle [srad]")
plt.xlabel("# Reflections")
plt.plot(np.arange(x), omega_cum[:x], 'o')
plt.plot(np.arange(x), omega_cum[:x], '--', color="lightblue")

In [None]:
a = [0,2,3,4]
for i, n in enumerate(a):
    print(i, n)

In [None]:
Blue (#1f77b4)
Orange (#ff7f0e)
Green (#2ca02c)
Red (#d62728)
Purple (#9467bd)
Brown (#8c564b)
Pink (#e377c2)
Gray (#7f7f7f)
Olive (#bcbd22)
Cyan (#17becf)

In [None]:
N = 5
o = compute_angles((5,5,5), (1, 1), N)
a = list(chain.from_iterable(o))
a

In [17]:
# Generate all possible combinations
combinations = list(product(range(10), range(10), range(10)))
combinations

[(0, 0, 0),
 (0, 0, 1),
 (0, 0, 2),
 (0, 0, 3),
 (0, 0, 4),
 (0, 0, 5),
 (0, 0, 6),
 (0, 0, 7),
 (0, 0, 8),
 (0, 0, 9),
 (0, 1, 0),
 (0, 1, 1),
 (0, 1, 2),
 (0, 1, 3),
 (0, 1, 4),
 (0, 1, 5),
 (0, 1, 6),
 (0, 1, 7),
 (0, 1, 8),
 (0, 1, 9),
 (0, 2, 0),
 (0, 2, 1),
 (0, 2, 2),
 (0, 2, 3),
 (0, 2, 4),
 (0, 2, 5),
 (0, 2, 6),
 (0, 2, 7),
 (0, 2, 8),
 (0, 2, 9),
 (0, 3, 0),
 (0, 3, 1),
 (0, 3, 2),
 (0, 3, 3),
 (0, 3, 4),
 (0, 3, 5),
 (0, 3, 6),
 (0, 3, 7),
 (0, 3, 8),
 (0, 3, 9),
 (0, 4, 0),
 (0, 4, 1),
 (0, 4, 2),
 (0, 4, 3),
 (0, 4, 4),
 (0, 4, 5),
 (0, 4, 6),
 (0, 4, 7),
 (0, 4, 8),
 (0, 4, 9),
 (0, 5, 0),
 (0, 5, 1),
 (0, 5, 2),
 (0, 5, 3),
 (0, 5, 4),
 (0, 5, 5),
 (0, 5, 6),
 (0, 5, 7),
 (0, 5, 8),
 (0, 5, 9),
 (0, 6, 0),
 (0, 6, 1),
 (0, 6, 2),
 (0, 6, 3),
 (0, 6, 4),
 (0, 6, 5),
 (0, 6, 6),
 (0, 6, 7),
 (0, 6, 8),
 (0, 6, 9),
 (0, 7, 0),
 (0, 7, 1),
 (0, 7, 2),
 (0, 7, 3),
 (0, 7, 4),
 (0, 7, 5),
 (0, 7, 6),
 (0, 7, 7),
 (0, 7, 8),
 (0, 7, 9),
 (0, 8, 0),
 (0, 8, 1),
 (0, 8, 2),
 (0,

In [None]:
L = sum([len(all_reflections(n)) for n in range(N+1)])
L

In [51]:
N = 5
L = sum([len(all_reflections(n)) for n in range(N+1)])
ot = torch.zeros(size=(nCells[0], nCells[1], nCells[2], L, nCells[0], nCells[2], 2))
edep_idx = list(product(range(nCells[0]), range(nCells[1]), range(nCells[2])))
for i, edep in enumerate(edep_idx):
    for x_i in range(nCells[0]):
        for z_i in range(nCells[2]):
            omegas = compute_angles(edep, (x_i, z_i), N)
            omegas = torch.tensor(list(chain.from_iterable(omegas)))
            ot[edep[0], edep[1], edep[2], :, x_i, z_i, 0] = omegas[:, 0]
            ot[edep[0], edep[1], edep[2], :, x_i, z_i, 1] = omegas[:, 1]


In [52]:
ot.numpy().astype(np.float32).tofile("emission_matrix.bin")
with open("shape.txt", "w") as f:
    f.write(" ".join(map(str, ot.shape)))

In [53]:
ot.dtype

torch.float32

In [54]:
shape = ot.shape
edep = (0,0,0)
sensor = (0,0)
n = 0
i = 0
o = compute_angles(edep, sensor, N)
ot_flat =np.fromfile("emission_matrix.bin", dtype=np.float32)
l = [len(o[i]) for i in range(n)]
print(sum(l)+i)
print(o[n][i])
print(ot[edep[0], edep[1], edep[2], sum(l)+i, sensor[0], sensor[1]])
idx = edep[0]   *shape[-6]*shape[-5]*shape[-4]*shape[-3]*shape[-2]*shape[-1] + \
      edep[1]   *shape[-5]*shape[-4]*shape[-3]*shape[-2]*shape[-1] + \
      edep[2]   *shape[-4]*shape[-3]*shape[-2]*shape[-1] + \
      (sum(l)+i)*shape[-3]*shape[-2]*shape[-1]+\
      sensor[0] *shape[-2]*shape[-1] + \
      sensor[1] *shape[-1]
print(ot_flat[idx], ot_flat[idx+1])

0
(np.float64(0.04331152830433155), np.float64(0.46011831291633104))
tensor([0.0433, 0.4601])
0.04331153 0.46011832


In [23]:
print(o)
a = torch.tensor(list(chain.from_iterable(o)))
a.shape

[[(np.float64(0.0028299358542824005), np.float64(1.0196949901292196))], [(np.float64(0.002689118286626213), np.float64(1.0368051607468494)), (np.float64(0.0017069397802056407), np.float64(1.205254942375085)), (np.float64(0.0001453931993487395), np.float64(2.7309154434715524)), (np.float64(0.001416890811996474), np.float64(1.2798273464253047)), (np.float64(0.003864828421238666), np.float64(1.1696967606279667))], [(np.float64(0.001146538307050582), np.float64(1.372686240863051)), (np.float64(0.001645754735865701), np.float64(1.2197649546434102)), (np.float64(0.00014436664910044694), np.float64(2.7373502201669586)), (np.float64(0.00011848542762614434), np.float64(2.923414523632331)), (np.float64(2.5852388478381227e-05), np.float64(4.853813107770435)), (np.float64(0.0010111754610275803), np.float64(1.4320683083858037)), (np.float64(0.00012945467825464796), np.float64(2.8383233295653323)), (np.float64(0.0009935270791392448), np.float64(1.4394206754652996)), (np.float64(0.0037186586980251235

torch.Size([102, 2])

In [29]:
o[1]

[(np.float64(0.002689118286626213), np.float64(1.0368051607468494)),
 (np.float64(0.0017069397802056407), np.float64(1.205254942375085)),
 (np.float64(0.0001453931993487395), np.float64(2.7309154434715524)),
 (np.float64(0.001416890811996474), np.float64(1.2798273464253047)),
 (np.float64(0.003864828421238666), np.float64(1.1696967606279667))]

In [32]:
print(a[:,0])

tensor([2.8299e-03, 2.6891e-03, 1.7069e-03, 1.4539e-04, 1.4169e-03, 3.8648e-03,
        1.1465e-03, 1.6458e-03, 1.4437e-04, 1.1849e-04, 2.5852e-05, 1.0112e-03,
        1.2945e-04, 9.9353e-04, 3.7187e-03, 2.5998e-03, 2.8300e-04, 2.2325e-03,
        8.0188e-04, 8.5192e-04, 1.2372e-04, 1.1776e-04, 2.5795e-05, 2.2993e-05,
        1.0304e-05, 1.0702e-04, 2.4897e-05, 7.5673e-04, 1.1963e-04, 3.5889e-04,
        1.8687e-03, 2.5238e-03, 2.8109e-04, 2.3251e-04, 5.2589e-05, 1.6779e-03,
        2.5318e-04, 1.6531e-03, 2.9895e-04, 6.3157e-04, 1.1322e-04, 1.0283e-04,
        2.4521e-05, 2.2945e-05, 1.0292e-05, 9.4477e-06, 4.4503e-06, 2.2205e-05,
        1.0095e-05, 9.9812e-05, 2.4240e-05, 3.1065e-04, 8.6973e-05, 2.6578e-04,
        1.3718e-03, 1.4463e-03, 2.4239e-04, 2.3113e-04, 5.2473e-05, 4.6844e-05,
        2.1161e-05, 2.1078e-04, 5.0671e-05, 1.3034e-03, 2.3468e-04, 6.6354e-04,
        2.2418e-04, 2.6290e-04, 8.0792e-05, 9.5048e-05, 2.3779e-05, 2.1894e-05,
        1.0011e-05, 9.4369e-06, 4.4472e-

In [None]:
N = 10
fig, ax = plt.subplots(ncols=2, figsize=(16,6))
edep_idx=[(0,y,0) for y in range(nCells[1])]
col=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
     '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
for i, edep in enumerate(edep_idx):
    tot_sum = []
    tot_cum = []
    for x_i in range(nCells[0]):
        for z_i in range(nCells[2]):
            omegas = compute_angles(edep, (x_i, z_i), N)
            omegas = [[pair[0] for pair in sublist] for sublist in omegas]
            omega_sum = [np.sum(x) for x in omegas]
            omega_cum = list(accumulate(omega_sum))

            tot_sum.append(omega_sum)
            tot_cum.append(omega_cum)

    full_sum = [sum(values) for values in zip(*tot_sum)]
    full_cum = [sum(values) for values in zip(*tot_cum)]

    ax[0].plot(np.arange(N+1), full_sum, 'o',  color=col[i], label=f'y={edep[1]}')
    ax[0].plot(np.arange(N+1), full_sum, '--', color=col[i])

    ax[1].plot(np.arange(N+1), full_cum, 'o',  color=col[i], label=f'y={edep[1]}')
    ax[1].plot(np.arange(N+1), full_cum, '--', color=col[i])

ax[0].set_title("New Solid Angle coverage added by each reflection")
ax[1].set_title("Total Solid Angle coverage")
ax[0].set_ylabel("Solid Angle [srad]")
ax[1].axhline(4*np.pi)
ax[1].axhline(0.9*4*np.pi)
fig.text(0.5, 0.04, '# Reflections', ha='center')
plt.tight_layout(rect=[0.05, 0.05, 1, 1])
ax[0].legend()

In [None]:
N = 5
fig, ax = plt.subplots(ncols=2, figsize=(16,6))
edep_idx=[(5,5,z) for z in [0,1,2,3,4]]
col=['lightblue', 'orange', 'red', 'black', 'green']
for i, edep in enumerate(edep_idx):
    tot_sum = []
    tot_cum = []
    for x_i in range(nCells[0]):
        for z_i in range(nCells[2]):
            omegas = compute_angles(edep, (x_i, z_i), N)
            omegas = [[pair[0] for pair in sublist] for sublist in omegas]
            omega_sum = [np.sum(x) for x in omegas]
            omega_cum = list(accumulate(omega_sum))

            tot_sum.append(omega_sum)
            tot_cum.append(omega_cum)

    full_sum = [sum(values) for values in zip(*tot_sum)]
    full_cum = [sum(values) for values in zip(*tot_cum)]

    ax[0].plot(np.arange(N+1), full_sum, 'o',  color=col[i], label=f'z={edep[2]}')
    ax[0].plot(np.arange(N+1), full_sum, '--', color=col[i])

    ax[1].plot(np.arange(N+1), full_cum, 'o',  color=col[i], label=f'z={edep[2]}')
    ax[1].plot(np.arange(N+1), full_cum, '--', color=col[i])

ax[0].set_title("New Solid Angle coverage added by each reflection")
ax[1].set_title("Total Solid Angle coverage")
ax[0].set_ylabel("Solid Angle [srad]")
fig.text(0.5, 0.04, '# Reflections', ha='center')
plt.tight_layout(rect=[0.05, 0.05, 1, 1])
ax[1].legend()

In [None]:
N = 10
fig, ax = plt.subplots(ncols=2, figsize=(16,6))
edep_idx=[(x,5,5) for x in [0,1,2,3,4]]
col=['lightblue', 'orange', 'red', 'black', 'green']
for i, edep in enumerate(edep_idx):
    tot_sum = []
    tot_cum = []
    for x_i in range(nCells[0]):
        for z_i in range(nCells[2]):
            omegas = compute_angles(edep, (x_i, z_i), N)
            omegas = [[pair[0] for pair in sublist] for sublist in omegas]
            omega_sum = [np.sum(x) for x in omegas]
            omega_cum = list(accumulate(omega_sum))

            tot_sum.append(omega_sum)
            tot_cum.append(omega_cum)

    full_sum = [sum(values) for values in zip(*tot_sum)]
    full_cum = [sum(values) for values in zip(*tot_cum)]

    ax[0].plot(np.arange(N+1), full_sum, 'o',  color=col[i], label=f'x={edep[0]}')
    ax[0].plot(np.arange(N+1), full_sum, '--', color=col[i])

    ax[1].plot(np.arange(N+1), full_cum, 'o',  color=col[i], label=f'x={edep[0]}')
    ax[1].plot(np.arange(N+1), full_cum, '--', color=col[i])

ax[0].set_title("New Solid Angle coverage added by each reflection")
ax[1].set_title("Total Solid Angle coverage")
ax[0].set_ylabel("Solid Angle [srad]")
fig.text(0.5, 0.04, '# Reflections', ha='center')
plt.tight_layout(rect=[0.05, 0.05, 1, 1])
ax[1].legend()

In [None]:
N = 5
fig, ax = plt.subplots(ncols=2, figsize=(16,6))
edep_idx=[(0,0,0), (5,9,5)]
col=['lightblue', 'orange', 'red', 'black', 'green']
for i, edep in enumerate(edep_idx):
    tot_sum = []
    tot_cum = []
    for x_i in range(nCells[0]):
        for z_i in range(nCells[2]):
            omegas = compute_angles(edep, (x_i, z_i), N)
            omega_sum = [np.sum(x) for x in omegas]
            omega_cum = list(accumulate(omega_sum))

            tot_sum.append(omega_sum)
            tot_cum.append(omega_cum)

    full_sum = [sum(values) for values in zip(*tot_sum)]
    full_cum = [sum(values) for values in zip(*tot_cum)]

    ax[0].plot(np.arange(N+1), full_sum, 'o',  color=col[i], label=f'pos={edep}')
    ax[0].plot(np.arange(N+1), full_sum, '--', color=col[i])

    ax[1].plot(np.arange(N+1), full_cum, 'o',  color=col[i], label=f'pos={edep}')
    ax[1].plot(np.arange(N+1), full_cum, '--', color=col[i])

ax[0].set_title("New Solid Angle coverage added by each reflection")
ax[1].set_title("Total Solid Angle coverage")
ax[0].set_ylabel("Solid Angle [srad]")
fig.text(0.5, 0.04, '# Reflections', ha='center')
plt.tight_layout(rect=[0.05, 0.05, 1, 1])
ax[1].legend()

In [None]:
full_sum = [sum(values) for values in zip(*tot_sum)]
x = 6
plt.title("New Solid Angle coverage added by each reflection")
plt.ylabel("Solid Angle [srad]")
plt.xlabel("# Reflections")
plt.plot(np.arange(x), full_sum[:x], 'o')
plt.plot(np.arange(x), full_sum[:x], '--', color="lightblue")

In [None]:
a = [1,2,3]
b = [84,4,3]
c = list(zip(a,b))
c

In [None]:
full_cum = [sum(values) for values in zip(*tot_cum)]
x = 6
plt.axhline(4*np.pi, label="$4\pi$", color='orange')
plt.title("Total Solid Angle coverage")
plt.ylabel("Solid Angle [srad]")
plt.xlabel("# Reflections")
plt.plot(np.arange(x), full_cum[:x], 'o')
plt.plot(np.arange(x), full_cum[:x], '--', color="lightblue")
plt.legend()

## Simulate Fake Event

In [None]:
import torch
a = torch.rand((100, 10, 10, 10, 10, 10))

In [None]:
a.nelement() * a.element_size()