## Cross-checking the original implementation against the optimized implementation

In [1]:
# Torch
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch import optim 
from torchsummary import summary


# Train + Data 
import sys 
sys.path.append('../Layers')
from Conv1d_NN import *
from Conv2d_NN import *
from Conv1d_NN_spatial import * 
from Conv2d_NN_spatial import * 
from ConvNN_CNN_Branching import *

### 1. Distance Matrix Calculation

In [2]:
def _calculate_distance_matrix(matrix, sqrt=False):
    norm_squared = torch.sum(matrix ** 2, dim=1, keepdim=True)
    print("norm_squared", norm_squared.shape)
    print("matrix t", matrix.transpose(2, 1).shape)
    print("matrix", matrix.shape)
    print("matrix t * matrix", torch.bmm(matrix.transpose(2, 1), matrix).shape)
    
    dot_product = torch.bmm(matrix.transpose(2, 1), matrix)
    dist_matrix = norm_squared + norm_squared.transpose(2, 1) - 2 * dot_product
    
    dist_matrix = torch.clamp(dist_matrix, min=0)  
    
    if sqrt:
        dist_matrix = torch.sqrt(dist_matrix)
    return dist_matrix

# Shape check: 32 batch elements, 3 channels, 28 tokens -> distance matrix [32, 28, 28].
ex = torch.randn(32, 3, 28)
o1 = _calculate_distance_matrix(ex)
print(o1.shape)


norm_squared torch.Size([32, 1, 28])
matrix t torch.Size([32, 28, 3])
matrix torch.Size([32, 3, 28])
matrix t * matrix torch.Size([32, 28, 28])
torch.Size([32, 28, 28])


In [3]:
# N Samples
def original_calculate_distance_matrix_N(matrix, matrix_sample, sqrt=False):
    """Calculates distance matrix between two input matrices.

    Args:
        matrix: 3-D tensor [B, C, T] of query tokens (columns).
        matrix_sample: 3-D tensor [B, C, S] of sampled tokens (columns).
        sqrt: If True, return Euclidean distances instead of squared ones.

    Returns:
        Tensor [B, T, S]; entry [b, i, j] is the (squared) distance between
        token i of ``matrix`` and token j of ``matrix_sample``.
    """
    # Raw squared norms over the channel axis: [B, 1, T] and [B, 1, S].
    raw_norm_squared = torch.sum(matrix ** 2, dim=1, keepdim=True)
    raw_norm_squared_sample = torch.sum(matrix_sample ** 2, dim=1, keepdim=True)

    # Query norms as a column [B, T, 1]; sample norms are already the row
    # [B, 1, S] needed for broadcasting, so they are used as-is.
    norm_squared = raw_norm_squared.permute(0, 2, 1)
    norm_squared_sample = raw_norm_squared_sample

    print("norm squared", norm_squared.shape)
    print("norm sqaured no permute", raw_norm_squared.shape)
    print("norm squared sample", norm_squared_sample.shape)
    print("norm squared sample no permute", raw_norm_squared_sample.shape)

    # [B, T, C] @ [B, C, S] -> [B, T, S]
    dot_product = torch.bmm(matrix.transpose(2, 1), matrix_sample)
    print("dot product", dot_product.shape)

    # [B, T, 1] + [B, 1, S] - [B, T, S] -> [B, T, S]
    dist_matrix = norm_squared + norm_squared_sample - 2 * dot_product
    print("dist matrix", dist_matrix.shape)

    # Clamp tiny negatives caused by floating-point cancellation.
    dist_matrix = torch.clamp(dist_matrix, min=0.0)

    if sqrt:
        dist_matrix = torch.sqrt(dist_matrix)
    return dist_matrix

# Shape check: 28 query tokens vs 10 sampled tokens -> distance matrix [32, 28, 10].
ex = torch.randn(32, 3, 28)
ex1 = torch.randn(32, 3, 10)
o1 = original_calculate_distance_matrix_N(ex, ex1)
print(o1.shape)
    

norm squared torch.Size([32, 28, 1])
norm sqaured no permute torch.Size([32, 1, 28])
norm squared sample torch.Size([32, 1, 10])
norm squared sample no permute torch.Size([32, 1, 10])
dot product torch.Size([32, 28, 10])
dist matrix torch.Size([32, 28, 10])
torch.Size([32, 28, 10])


### 2. Similarity Matrix Calculation

In [4]:
# All Samples
def original_calculate_similarity_matrix(matrix):
    """Return the cosine similarity between every pair of tokens.

    Each token (a column along the last axis) is scaled to unit L2 norm over
    the channel axis (dim=1); one batched matmul then yields all pairwise dot
    products, i.e. a [B, T, T] matrix for a [B, C, T] input.
    """
    unit_tokens = F.normalize(matrix, p=2, dim=1)
    tokens_first = unit_tokens.transpose(2, 1)
    return torch.bmm(tokens_first, unit_tokens)

def optimized_calculate_similarity_matrix(matrix):
    """Calculates the cosine-similarity matrix of the input matrix.

    Args:
        matrix: 3-D tensor [B, C, T]; tokens are columns, dim=1 is the feature axis.

    Returns:
        Tensor [B, T, T] with entry [b, i, j] = cos(token_i, token_j).
    """
    # `F.normalize` (not `F.normalized`, which does not exist) scales each
    # token to unit L2 norm across the channel axis.
    normalized_matrix = F.normalize(matrix, p=2, dim=1)
    similarity_matrix = torch.bmm(normalized_matrix.transpose(2, 1), normalized_matrix)
    return similarity_matrix

In [5]:
# N Samples
def original_calculate_similarity_matrix_N(matrix, matrix_sample):
    """Return cosine similarities between the tokens of two batched inputs.

    Both inputs are L2-normalised along the channel axis (dim=1). Entry
    [b, i, j] of the [B, T, S] result is the cosine similarity between token i
    of ``matrix`` and token j of ``matrix_sample``.
    """
    query_unit, sample_unit = (F.normalize(m, p=2, dim=1) for m in (matrix, matrix_sample))
    return torch.bmm(query_unit.transpose(2, 1), sample_unit)


### 3. Top-k Selection

In [6]:
# All Samples

def prime_vmap_2d(matrix, magnitude_matrix, num_nearest_neighbors, maximum):
    """Vectorization / Vmap Implementation for Nearest Neighbor Tensor 2D.

    Maps ``process_batch`` over dim 0 of ``matrix`` and ``magnitude_matrix``
    with ``torch.vmap``; ``num_nearest_neighbors`` is shared by every batch
    element. Neighbours are flattened per item, so a [B, C, T] input with K
    neighbours yields [B, C, T * K] (e.g. [32, 3, 140] for T=28, K=5).

    Args:
        matrix: [B, C, T] tensor.
        magnitude_matrix: [B, T, T] score matrix used to rank neighbours.
        num_nearest_neighbors: K, neighbours selected per token.
        maximum: Pick largest scores when True, smallest when False.
    """
    # Plain module-level function: a bare `@staticmethod` outside a class is
    # meaningless and is not callable on Python < 3.10.
    batched_process = torch.vmap(process_batch, in_dims=(0, 0, None), out_dims=0)
    prime = batched_process(matrix, magnitude_matrix, num_nearest_neighbors, flatten=True, maximum=maximum)
    return prime

def prime_vmap_3d(matrix, magnitude_matrix, num_nearest_neighbors, maximum):
    """Vectorization / Vmap Implementation for Nearest Neighbor Tensor 3D.

    Same as the 2D variant but keeps the neighbour axis separate: a [B, C, T]
    input with K neighbours yields [B, C, T, K].

    Args:
        matrix: [B, C, T] tensor.
        magnitude_matrix: [B, T, T] score matrix used to rank neighbours.
        num_nearest_neighbors: K, neighbours selected per token.
        maximum: Pick largest scores when True, smallest when False.
    """
    # Plain module-level function: a bare `@staticmethod` outside a class is
    # meaningless and is not callable on Python < 3.10.
    batched_process = torch.vmap(process_batch, in_dims=(0, 0, None), out_dims=0)
    prime = batched_process(matrix, magnitude_matrix, num_nearest_neighbors, flatten=False, maximum=maximum)
    return prime

def process_batch(matrix, magnitude_matrix, num_nearest_neighbors, flatten, maximum):
    """Process the batch of matrices by finding the K nearest neighbors with reshaping.

    Works on a single batch element (it is vmapped over the batch).

    Args:
        matrix: [C, T] tensor.
        magnitude_matrix: [T, T] scores; row i ranks every token for token i.
        num_nearest_neighbors: K, neighbours selected per token.
        flatten: If True return [C, T * K], otherwise [C, T, K].
        maximum: Pick largest scores when True, smallest when False.
    """
    # [T, K] indices of the selected neighbour tokens for every token.
    ind = torch.topk(magnitude_matrix, num_nearest_neighbors, largest=maximum).indices
    neigh = matrix[:, ind]  # [C, T, K]
    if flatten:
        return torch.flatten(neigh, start_dim=1)
    return neigh

# Run the vmap-based neighbour gather on random data and report shapes.
mat = torch.randn(32, 3, 28)
mag = original_calculate_similarity_matrix(mat)
prime = prime_vmap_2d(mat, mag, 5, True)

print('original matrix shape', mat.shape)
print('original similarity matrix shape', mag.shape)
print('prime shape', prime.shape)

# Expected last dim: 28 tokens * 5 neighbours = 140



original matrix shape torch.Size([32, 3, 28])
original similarity matrix shape torch.Size([32, 28, 28])
prime shape torch.Size([32, 3, 140])


In [7]:
# Compact tensor printing: 4 decimals, no scientific notation.
torch.set_printoptions(precision=4, sci_mode=False)


def prime_brute(matrix, dist_matrix, K, maximum=True):
    """Loop-based reference construction of the nearest-neighbour "prime" tensor.

    For every batch element ``i`` and token ``j`` the K scores in
    ``dist_matrix[i, j, :]`` that are largest (``maximum=True``, the default)
    or smallest (``maximum=False``) are selected. The corresponding token
    columns of ``matrix`` are gathered, and the per-token groups are
    concatenated along the token axis.

    Args:
        matrix: [B, C, T] tensor.
        dist_matrix: [B, T, T] similarity/distance scores.
        K: Number of neighbours per token.
        maximum: Select largest scores when True, smallest when False.

    Returns:
        Tensor [B, C, T * K].
    """
    batches = []
    for i in range(matrix.shape[0]):
        # One [C, K] block of neighbour columns per token j.
        per_token = [
            matrix[i, :, torch.topk(dist_matrix[i, j, :], K, largest=maximum).indices]
            for j in range(matrix.shape[2])
        ]
        batches.append(torch.cat(per_token, dim=1))
    return torch.stack(batches, dim=0)
        

# Build the brute-force prime tensor on fresh random data and report shapes.
mat = torch.randn(32, 3, 28)
sim = original_calculate_similarity_matrix(mat)
prime_b = prime_brute(mat, sim, 5)
print('original matrix shape', mat.shape)
print('original similarity matrix shape', sim.shape)
# Report the brute-force result computed in this cell, not the stale
# `prime` left over from the earlier vmap cell.
print('prime shape', prime_b.shape)
print(prime_b)

original matrix shape torch.Size([32, 3, 28])
original similarity matrix shape torch.Size([32, 28, 28])
prime shape torch.Size([32, 3, 140])
tensor([[[-0.5865, -0.5183, -0.5359,  ...,  0.2010, -0.4622,  0.7154],
         [ 1.2395,  0.5388,  1.0418,  ..., -0.5152, -0.3412, -0.3728],
         [ 0.4944,  0.5314,  1.0704,  ...,  0.1287,  1.3529,  0.7918]],

        [[-1.7115, -1.7791, -1.9969,  ..., -0.2711,  0.3555, -0.6732],
         [-1.3086, -0.7488, -0.6359,  ...,  1.1209,  1.3102,  1.9382],
         [-0.4066, -0.1557, -0.1028,  ...,  0.4914, -0.5082, -1.0255]],

        [[-2.8810, -0.4386, -1.6857,  ..., -0.9492, -0.2095, -0.2460],
         [ 0.8553,  0.2906,  0.4509,  ..., -0.6014, -1.1020, -0.2151],
         [-0.4248,  0.0626, -1.8269,  ...,  1.0447,  0.1566,  0.4234]],

        ...,

        [[-1.3938, -0.7453, -1.3284,  ..., -1.3938, -0.8943, -0.7827],
         [-0.0819, -0.2186, -0.2709,  ..., -0.0819, -1.2010, -1.0662],
         [ 0.5803,  0.4700, -0.3583,  ...,  0.5803, -0.583

In [8]:






# Vectorised walkthrough of what `prime_brute` does: rank neighbours once with
# topk, then gather them for every batch/channel via advanced indexing.
# Uses `sim` and `mat` from the previous cell.
_, topk_indices = torch.topk(sim, k=5, dim=2, largest=True)  
print("topk indices: ", topk_indices.shape)  


batch_size, channels, tokens = mat.shape
K = topk_indices.shape[-1]  
print("batch_size: ", batch_size)
print("channels: ", channels)
print("tokens: ", tokens)
print("K: ", K)

# Expand topk_indices: add a channel dimension at dim=1 so every channel
# gathers the same neighbour tokens.
indices_expanded = topk_indices.unsqueeze(1).expand(batch_size, channels, tokens, K)
print(indices_expanded)

print("indices_expanded shape:", indices_expanded.shape)  # torch.Size([32, 3, 28, 5])



print()

# Create index tensors for batch and channel dimensions (broadcast to the
# same [B, C, T, K] shape as indices_expanded).
# NOTE(review): torch.arange here uses the default (CPU) device; if `mat`
# lives on a GPU these index grids would need device=mat.device.
batch_indices = torch.arange(batch_size).view(batch_size, 1, 1, 1).expand(batch_size, channels, tokens, K)

channel_indices = torch.arange(channels).view(1, channels, 1, 1).expand(batch_size, channels, tokens, K)

# Now use advanced indexing:
# For each [b, c, i, j], we select mat[b, c, indices_expanded[b, c, i, j]]
prime = mat[batch_indices, channel_indices, indices_expanded]
# prime will have shape [32, 3, 28, 5]

# Finally, reshape (flatten the token and neighbor dimensions) to get [32, 3, 28*5] = [32, 3, 140]
prime_new = prime.view(batch_size, channels, -1)

print("prime_new shape:", prime_new.shape)


topk indices:  torch.Size([32, 28, 5])
batch_size:  32
channels:  3
tokens:  28
K:  5
tensor([[[[ 0, 20, 11,  6,  9],
          [ 1, 26, 18, 14,  8],
          [ 2, 22, 21,  0, 13],
          ...,
          [25,  4,  5,  7, 19],
          [26,  8,  1, 18, 14],
          [27, 16,  3, 12, 19]],

         [[ 0, 20, 11,  6,  9],
          [ 1, 26, 18, 14,  8],
          [ 2, 22, 21,  0, 13],
          ...,
          [25,  4,  5,  7, 19],
          [26,  8,  1, 18, 14],
          [27, 16,  3, 12, 19]],

         [[ 0, 20, 11,  6,  9],
          [ 1, 26, 18, 14,  8],
          [ 2, 22, 21,  0, 13],
          ...,
          [25,  4,  5,  7, 19],
          [26,  8,  1, 18, 14],
          [27, 16,  3, 12, 19]]],


        [[[ 0, 15, 22,  2,  5],
          [ 1, 19, 24, 21, 23],
          [ 2,  6, 25, 20,  9],
          ...,
          [25,  9, 12,  2, 18],
          [26, 10,  8, 27, 16],
          [27, 16, 23,  8, 26]],

         [[ 0, 15, 22,  2,  5],
          [ 1, 19, 24, 21, 23],
          [ 

In [9]:
# The loop-based and advanced-indexing results should match element-wise.
print(torch.allclose(prime_b, prime_new))



True


### Optimized check


In [10]:
# Compact tensor printing: 4 decimals, no scientific notation.
torch.set_printoptions(precision=4, sci_mode=False)


In [11]:
def calculate_similarity_matrix(matrix):
    """Cosine-similarity matrix of a batched token tensor.

    Args:
        matrix: [B, C, T] tensor; columns (last axis) are tokens and dim=1
            holds their features.

    Returns:
        [B, T, T] tensor with entry [b, i, j] = cos(token_i, token_j).
    """
    normalized = F.normalize(matrix, p=2, dim=1)  # unit L2 norm per token
    return normalized.transpose(2, 1).bmm(normalized)

In [12]:
def prime_brute(matrix, dist_matrix, K, maximum=True):
    """Loop-based reference construction of the nearest-neighbour "prime" tensor.

    For every batch element ``i`` and token ``j`` the K scores in
    ``dist_matrix[i, j, :]`` that are largest (``maximum=True``, the default)
    or smallest (``maximum=False``) are selected. The corresponding token
    columns of ``matrix`` are gathered, and the per-token groups are
    concatenated along the token axis.

    Args:
        matrix: [B, C, T] tensor.
        dist_matrix: [B, T, T] similarity/distance scores.
        K: Number of neighbours per token.
        maximum: Select largest scores when True, smallest when False.

    Returns:
        Tensor [B, C, T * K].
    """
    batches = []
    for i in range(matrix.shape[0]):
        # One [C, K] block of neighbour columns per token j.
        per_token = [
            matrix[i, :, torch.topk(dist_matrix[i, j, :], K, largest=maximum).indices]
            for j in range(matrix.shape[2])
        ]
        batches.append(torch.cat(per_token, dim=1))
    return torch.stack(batches, dim=0)
        

In [13]:
# All Samples

def prime_vmap_2d(matrix, magnitude_matrix, num_nearest_neighbors, maximum):
    """Vectorization / Vmap Implementation for Nearest Neighbor Tensor 2D.

    Maps ``process_batch`` over dim 0 of ``matrix`` and ``magnitude_matrix``
    with ``torch.vmap``; ``num_nearest_neighbors`` is shared by every batch
    element. Neighbours are flattened per item, so a [B, C, T] input with K
    neighbours yields [B, C, T * K].

    Args:
        matrix: [B, C, T] tensor.
        magnitude_matrix: [B, T, T] score matrix used to rank neighbours.
        num_nearest_neighbors: K, neighbours selected per token.
        maximum: Pick largest scores when True, smallest when False.
    """
    # Plain module-level function: a bare `@staticmethod` outside a class is
    # meaningless and is not callable on Python < 3.10.
    batched_process = torch.vmap(process_batch, in_dims=(0, 0, None), out_dims=0)
    prime = batched_process(matrix, magnitude_matrix, num_nearest_neighbors, flatten=True, maximum=maximum)
    return prime

def prime_vmap_3d(matrix, magnitude_matrix, num_nearest_neighbors, maximum):
    """Vectorization / Vmap Implementation for Nearest Neighbor Tensor 3D.

    Same as the 2D variant but keeps the neighbour axis separate: a [B, C, T]
    input with K neighbours yields [B, C, T, K].

    Args:
        matrix: [B, C, T] tensor.
        magnitude_matrix: [B, T, T] score matrix used to rank neighbours.
        num_nearest_neighbors: K, neighbours selected per token.
        maximum: Pick largest scores when True, smallest when False.
    """
    # Plain module-level function: a bare `@staticmethod` outside a class is
    # meaningless and is not callable on Python < 3.10.
    batched_process = torch.vmap(process_batch, in_dims=(0, 0, None), out_dims=0)
    prime = batched_process(matrix, magnitude_matrix, num_nearest_neighbors, flatten=False, maximum=maximum)
    return prime

def process_batch(matrix, magnitude_matrix, num_nearest_neighbors, flatten, maximum):
    """Process the batch of matrices by finding the K nearest neighbors with reshaping.

    Works on a single batch element (it is vmapped over the batch).

    Args:
        matrix: [C, T] tensor.
        magnitude_matrix: [T, T] scores; row i ranks every token for token i.
        num_nearest_neighbors: K, neighbours selected per token.
        flatten: If True return [C, T * K], otherwise [C, T, K].
        maximum: Pick largest scores when True, smallest when False.
    """
    # [T, K] indices of the selected neighbour tokens for every token.
    ind = torch.topk(magnitude_matrix, num_nearest_neighbors, largest=maximum).indices
    neigh = matrix[:, ind]  # [C, T, K]
    if flatten:
        return torch.flatten(neigh, start_dim=1)
    return neigh


In [14]:
def prime_optimized(matrix, dist_matrix, K=5, maximum=True):
    """Vectorised (loop-free) nearest-neighbour gather.

    Args:
        matrix: [B, C, T] tensor.
        dist_matrix: [B, T, T] scores; row i ranks every token for token i.
        K: Neighbours selected per token.
        maximum: If True (default) take the K largest scores (similarity);
            if False take the K smallest (distance). Previously this flag was
            ignored and the largest scores were always used, so the new
            default keeps existing calls unchanged.

    Returns:
        Tensor [B, C, T * K]; columns i*K .. i*K + K - 1 hold the neighbours
        of token i.
    """
    batch_size, channels, tokens = matrix.shape
    device = matrix.device

    # [B, T, K] indices of the selected neighbour tokens.
    topk_indices = torch.topk(dist_matrix, k=K, dim=2, largest=maximum).indices

    # Same neighbour indices for every channel: [B, C, T, K].
    indices_expanded = topk_indices.unsqueeze(1).expand(batch_size, channels, tokens, K)
    # Batch/channel index grids on `matrix`'s device so advanced indexing
    # also works for GPU tensors.
    batch_indices = torch.arange(batch_size, device=device).view(batch_size, 1, 1, 1).expand(batch_size, channels, tokens, K)
    channel_indices = torch.arange(channels, device=device).view(1, channels, 1, 1).expand(batch_size, channels, tokens, K)

    prime = matrix[batch_indices, channel_indices, indices_expanded]  # [B, C, T, K]
    return prime.view(batch_size, channels, -1)

In [15]:
def _prime(matrix, magnitude_matrix, K, maximum):
    b, c, t = matrix.shape 

    _, topk_indices = torch.topk(magnitude_matrix, k = K, dim=2, largest=maximum)
    
    tk = topk_indices.shape[-1]
    
    assert K == tk, "Error: K must be same as tk. K == tk."
    
    indices_expanded = topk_indices.unsqueeze(1).expand(b, c, t, tk)
    batch_indices = torch.arange(b).view(b, 1, 1, 1).expand(b, c, t, tk)
    channel_indices = torch.arange(c).view(1, c, 1, 1).expand(b, c, t, tk)
    
    prime = matrix[batch_indices, channel_indices, indices_expanded]
    prime = prime.view(b, c, -1)
    return prime

In [16]:
# Cross-check all four all-samples implementations on the same input:
# brute-force loop, advanced indexing, vmap, and `_prime`.
mat = torch.randn(32, 3, 28)
sim = calculate_similarity_matrix(mat)
prime_b = prime_brute(mat, sim, 5)
prime_o = prime_optimized(mat, sim, 5)
prime_f = prime_vmap_2d(mat, sim, 5, True)
prime_ = _prime(mat, sim, 5, True)
print('original matrix shape', mat.shape)
print('original similarity matrix shape', sim.shape)
print()
print('prime brute shape', prime_b.shape)
print('prime optimized shape', prime_o.shape)
print('prime vmap shape', prime_f.shape)

# Each comparison should print True if the implementations agree.
print(torch.allclose(prime_b, prime_o))
print(torch.allclose(prime_b, prime_f))
print(torch.allclose(prime_o, prime_f))
print(torch.allclose(prime_b, prime_))


original matrix shape torch.Size([32, 3, 28])
original similarity matrix shape torch.Size([32, 28, 28])

prime brute shape torch.Size([32, 3, 140])
prime optimized shape torch.Size([32, 3, 140])
prime vmap shape torch.Size([32, 3, 140])
True
True
True
True


#### N Samples

In [17]:
# Compact tensor printing; numpy is used below to draw random sample indices.
torch.set_printoptions(precision=4, sci_mode=False)
import numpy as np


In [18]:
# N Samples
def calculate_similarity_matrix_N(matrix, matrix_sample):
    """Cosine similarity of every token in ``matrix`` against every sampled token.

    Args:
        matrix: [B, C, T] tensor of query tokens.
        matrix_sample: [B, C, S] tensor of sampled tokens.

    Returns:
        [B, T, S] tensor; channels (dim=1) are L2-normalised before the dot product.
    """
    def _unit(x):
        return F.normalize(x, p=2, dim=1)

    return torch.bmm(_unit(matrix).transpose(2, 1), _unit(matrix_sample))

In [19]:

def prime_vmap_2d_N(matrix, magnitude_matrix, num_nearest_neighbors, rand_idx, maximum):
    """Vmap-based neighbour gather for the N-samples case, flattened per item.

    ``process_batch_N`` is mapped over dim 0 of ``matrix`` and
    ``magnitude_matrix``; ``num_nearest_neighbors`` and ``rand_idx`` are
    shared by every batch element.
    """
    per_item = torch.vmap(process_batch_N, in_dims=(0, 0, None, None), out_dims=0)
    return per_item(
        matrix,
        magnitude_matrix,
        num_nearest_neighbors,
        rand_idx,
        flatten=True,
        maximum=maximum,
    )

def prime_vmap_3d_N(matrix, magnitude_matrix, num_nearest_neighbors, rand_idx, maximum):
    """Vmap-based neighbour gather for the N-samples case, keeping the neighbour axis.

    Like the 2D variant, but ``process_batch_N`` is called with
    ``flatten=False`` so each item keeps a separate neighbour dimension.
    """
    per_item = torch.vmap(process_batch_N, in_dims=(0, 0, None, None), out_dims=0)
    return per_item(
        matrix,
        magnitude_matrix,
        num_nearest_neighbors,
        rand_idx,
        flatten=False,
        maximum=maximum,
    )

def process_batch_N(matrix, magnitude_matrix, num_nearest_neighbors, rand_idx, flatten, maximum):
    """Process the batch of matrices by finding the K nearest neighbors with reshaping.

    Single-batch-element worker for the N-samples case (vmapped over the
    batch by ``prime_vmap_2d_N`` / ``prime_vmap_3d_N``).

    Args:
        matrix: [C, T] tensor for one batch element.
        magnitude_matrix: [T, S] scores of every token against the S sampled tokens.
        num_nearest_neighbors: K, total neighbours per token *including* the
            token itself; ``K - 1`` are chosen among the sampled tokens.
        rand_idx: 1-D tensor of length S mapping sample positions to token indices.
        flatten: If True return [C, T * K]; otherwise [C, T, K].
        maximum: Pick the largest scores when True, the smallest when False.
    """
    # [T, K-1] positions (within the sample) of the selected sampled tokens.
    topk_ind = torch.topk(magnitude_matrix, num_nearest_neighbors - 1, largest=maximum).indices
    device = topk_ind.device
    # Map sample positions back to token indices of `matrix`: [T, K-1].
    mapped_tensor = rand_idx.to(device)[topk_ind]
    # Every token is its own first neighbour: column 0..T-1 of shape [T, 1].
    index_tensor = torch.arange(0, matrix.shape[1], device=device).unsqueeze(1)
    final_tensor = torch.cat([index_tensor, mapped_tensor], dim=1)  # [T, K]
    neigh = matrix[:, final_tensor]  # [C, T, K]
    if flatten:
        return torch.flatten(neigh, start_dim=1)
    return neigh

In [20]:
def _prime_N(matrix, magnitude_matrix, K, rand_idx, maximum):
    """
    Create prime tensor for the N-samples case.
    
    Args:
        matrix (torch.Tensor): Input tensor of shape [b, c, t].
        magnitude_matrix (torch.Tensor): Magnitude matrix of shape [b, t, s] 
                                           (where s is the number of sampled tokens).
        K (int): Total number of neighbors to consider (including the self index).
        rand_idx (torch.Tensor): 1D tensor of length s containing the sampled indices.
        maximum (bool): If True, select the largest values (e.g. for similarity); if False, select the smallest (e.g. for distance).
        
    Returns:
        torch.Tensor: A prime tensor of shape [b, c, t*K].
    """
    b, c, t = matrix.shape
    # print("batch ", b)
    # print("channel ", c)
    # print("token ", t)
        

    _, topk_indices = torch.topk(magnitude_matrix, k=K - 1, dim=2, largest=maximum)
    tk = topk_indices.shape[-1]
    assert K == tk + 1, "Error: K must be same as tk + 1. K == tk + 1."
    # print("topk_indices shape: ", topk_indices.shape)  # torch.Size([32, 28, 4])
    
    mapped_tensor = rand_idx[topk_indices]
    # print("mapped_tensor shape: ", mapped_tensor.shape)  # torch.Size([32, 28, 4])
    # print(mapped_tensor)

    token_indices = torch.arange(t, device=matrix.device).view(1, t, 1).expand(b, t, 1)
    # print("token_indices shape: ", token_indices.shape)  # torch.Size([32, 28, 1])
    # 
    
    final_indices = torch.cat([token_indices, mapped_tensor], dim=2)
    # print("final_indices shape: ", final_indices.shape)  # torch.Size([32, 28, 5])

    indices_expanded = final_indices.unsqueeze(1).expand(b, c, t, K)
    # print("indices_expanded shape: ", indices_expanded.shape)  # torch.Size([32, 3, 28, 5])
    
    batch_indices = torch.arange(b, device=matrix.device).view(b, 1, 1, 1).expand(b, c, t, K)
    channel_indices = torch.arange(c, device=matrix.device).view(1, c, 1, 1).expand(b, c, t, K)
    

    prime = matrix[batch_indices, channel_indices, indices_expanded]  
    
    prime = prime.view(b, c, -1)
    return prime



In [21]:


# Cross-check the vmap and advanced-indexing N-samples implementations:
# draw `samples` distinct token indices, score every token against them,
# and compare the resulting prime tensors.
samples = 10
x1 = torch.randn(32, 3, 28)


rand_idx = torch.tensor(np.random.choice(x1.shape[2], samples, replace=False))  # 1-D tensor of distinct token indices
x1_sample = x1[:, :, rand_idx]


sim = calculate_similarity_matrix_N(x1, x1_sample)
p1 = prime_vmap_2d_N(x1, sim, 5, rand_idx, True)
o1 = _prime_N(x1, sim, 5, rand_idx, True)
print('p1 shape', p1.shape)
print('o1 shape', o1.shape)
print(torch.allclose(p1, o1))

p1 shape torch.Size([32, 3, 140])
o1 shape torch.Size([32, 3, 140])
True


In [22]:
# Two ways to draw `samples` distinct token indices: numpy's choice without
# replacement, and a pure-torch randperm (which can be created directly on
# x1's device).
rand_idx = torch.tensor(np.random.choice(x1.shape[2], samples, replace=False))  # 1-D tensor of distinct token indices
print('rand_idx', rand_idx)

rand_idx_1 = torch.randperm(x1.shape[2], device=x1.device)[:samples]

print('rand_idx_1', rand_idx_1)

rand_idx tensor([17,  5, 26, 15, 11, 27,  8, 14, 13, 25])
rand_idx_1 tensor([24, 12,  3,  5, 25, 13, 16, 20,  6,  2])


In [23]:
# Inspect the vmap result (should match `o1` printed below).
print(p1)

tensor([[[-0.2694, -0.2975,  0.0531,  ..., -0.8539, -0.5894, -0.4942],
         [ 0.4386,  1.2617,  1.9679,  ..., -0.1308, -0.1011, -0.2730],
         [ 0.2822,  0.1671,  0.3515,  ..., -1.7980,  0.0064,  0.0574]],

        [[ 1.7518,  1.2711,  2.3639,  ...,  1.1061,  2.3639,  1.2711],
         [-0.7599, -1.1533, -2.1803,  ...,  2.1738, -2.1803, -1.1533],
         [-0.5267, -0.1448, -0.0572,  ...,  0.9943, -0.0572, -0.1448]],

        [[-1.2693, -0.8747, -1.3881,  ..., -0.8747, -0.0742, -1.3881],
         [ 0.7128,  0.2238, -0.1755,  ...,  0.2238, -0.0850, -0.1755],
         [-0.4798, -0.4872, -0.6020,  ..., -0.4872, -3.1228, -0.6020]],

        ...,

        [[ 2.6302,  2.9570,  1.4378,  ...,  0.4437,  2.9570,  0.2356],
         [ 1.6285,  0.6963,  0.4922,  ..., -0.2838,  0.6963, -0.6965],
         [ 0.0815,  0.7962,  0.6793,  ..., -0.4098,  0.7962,  0.1178]],

        [[ 0.0246,  0.2150, -0.0120,  ..., -0.9588,  1.1807,  1.5767],
         [ 0.6311,  0.7079, -0.0340,  ...,  1.1877,  0.

In [24]:
# Inspect the advanced-indexing result (should match `p1` printed above).
print(o1)

tensor([[[-0.2694, -0.2975,  0.0531,  ..., -0.8539, -0.5894, -0.4942],
         [ 0.4386,  1.2617,  1.9679,  ..., -0.1308, -0.1011, -0.2730],
         [ 0.2822,  0.1671,  0.3515,  ..., -1.7980,  0.0064,  0.0574]],

        [[ 1.7518,  1.2711,  2.3639,  ...,  1.1061,  2.3639,  1.2711],
         [-0.7599, -1.1533, -2.1803,  ...,  2.1738, -2.1803, -1.1533],
         [-0.5267, -0.1448, -0.0572,  ...,  0.9943, -0.0572, -0.1448]],

        [[-1.2693, -0.8747, -1.3881,  ..., -0.8747, -0.0742, -1.3881],
         [ 0.7128,  0.2238, -0.1755,  ...,  0.2238, -0.0850, -0.1755],
         [-0.4798, -0.4872, -0.6020,  ..., -0.4872, -3.1228, -0.6020]],

        ...,

        [[ 2.6302,  2.9570,  1.4378,  ...,  0.4437,  2.9570,  0.2356],
         [ 1.6285,  0.6963,  0.4922,  ..., -0.2838,  0.6963, -0.6965],
         [ 0.0815,  0.7962,  0.6793,  ..., -0.4098,  0.7962,  0.1178]],

        [[ 0.0246,  0.2150, -0.0120,  ..., -0.9588,  1.1807,  1.5767],
         [ 0.6311,  0.7079, -0.0340,  ...,  1.1877,  0.