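## Double-checking matrix prime calculations for random sampling

Verify that the spatial layer's `*_N` distance and similarity helpers reproduce the regular `Conv1d_NN` results when the same tensor is passed as both arguments.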


In [3]:
# Torch
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch import optim 


# Train + Data 
import sys 
sys.path.append('../Layers')
from Conv1d_NN import * 
from Conv2d_NN import * 
from Conv1d_NN_spatial import *
from Conv2d_NN_spatial import * 

sys.path.append('../Data')
from CIFAR10 import * 


sys.path.append('../Train')
from train2d import * 


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Seed the RNG so the random input -- and anything it triggers downstream,
# such as the NaN entries seen in the distance-matrix comparison -- is
# reproducible under Restart & Run All.
torch.manual_seed(0)
ex1 = torch.rand(1, 5, 5)
print(ex1.shape)

torch.Size([1, 5, 5])


In [5]:
# Distance matrix from the regular (non-spatial) 1D layer; used as the
# reference in the comparison below.
ds_Reg = Conv1d_NN.calculate_distance_matrix(ex1)
print(ds_Reg.size())

torch.Size([1, 5, 5])


In [6]:
# Spatial "_N" variant with the same tensor passed as both arguments; it is
# expected to match ds_Reg.
ds_Spatial = Conv1d_NN_spatial.calculate_distance_matrix_N(ex1, ex1)
print(ds_Spatial.size())

torch.Size([1, 5, 5])


In [7]:
# torch.equal demands bit-identical tensors, which is too strict when comparing
# two independent floating-point implementations. torch.allclose tolerates
# rounding noise but still returns False if either tensor contains NaN
# (equal_nan defaults to False), so a NaN bug is not masked. Report which side
# the NaNs come from to make such a failure diagnosable.
nan_reg = torch.isnan(ds_Reg)
nan_spatial = torch.isnan(ds_Spatial)
if nan_reg.any() or nan_spatial.any():
    print("NaN count -> reg:", int(nan_reg.sum()), "spatial:", int(nan_spatial.sum()))
result = torch.allclose(ds_Reg, ds_Spatial)


In [8]:
# Outcome of the distance-matrix comparison.
print(f"{result}")

False


In [9]:
# Element-wise difference between the two distance matrices. A NaN here means
# at least one of the two matrices is NaN (or inf) at that position.
sub = torch.sub(ds_Reg, ds_Spatial)
print(sub)

tensor([[[nan, 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., nan]]])


In [10]:
# Similarity matrices: regular layer vs. spatial "_N" variant, with the same
# tensor passed as both arguments to the spatial version.
sm_reg = Conv1d_NN.calculate_similarity_matrix(ex1)

sm_spatial = Conv1d_NN_spatial.calculate_similarity_matrix_N(ex1, ex1)

# Compare with a tolerance rather than bitwise equality: the two code paths may
# round differently. allclose still returns False if any entry is NaN.
result = torch.allclose(sm_reg, sm_spatial)
print(result)

sub = sm_reg - sm_spatial
print(sub)

True
tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]])


## I. Spatial sampling indices for batch processing


In [25]:
x1 = torch.randn(32, 12, 14, 14)
sample_padding = 0
samples = 14

# Equally spaced indices for rows (x) and columns (y). Use int64 (long), the
# canonical dtype for advanced indexing; int32 index tensors are rejected by
# older PyTorch releases.
x_ind = torch.round(torch.linspace(sample_padding,
                                   x1.shape[2] - sample_padding - 1,
                                   samples)).long()
y_ind = torch.round(torch.linspace(sample_padding,
                                   x1.shape[3] - sample_padding - 1,
                                   samples)).long()

# Create a meshgrid of indices (row-major, 'ij' indexing)
x_grid, y_grid = torch.meshgrid(x_ind, y_ind, indexing='ij')

# Gather the sampled pixels -> (B, C, samples, samples)
x_sample = x1[:, :, x_grid, y_grid]
print("x_sample shape: ", x_sample.shape)

x_sample_flatten = x_sample.flatten(start_dim=2)
print("x_sample_flatten shape: ", x_sample_flatten.shape)


flatten = nn.Flatten(start_dim=2)
x2 = flatten(x1)
print("x2 shape: ", x2.shape)

# Row-major flattening maps pixel (r, c) to r * W + c. Selecting those positions
# from the fully flattened tensor must reproduce the sampled tensor for any
# `samples` / `sample_padding`. With samples == 14 and no padding it selects
# all 196 positions, i.e. the whole of x2.
flat_idx = (x_grid * x1.shape[3] + y_grid).flatten()
result = torch.equal(x_sample_flatten, x2[:, :, flat_idx])
print(result)

x_sample shape:  torch.Size([32, 12, 14, 14])
x_sample_flatten shape:  torch.Size([32, 12, 196])
x2 shape:  torch.Size([32, 12, 196])
True


In [None]:
import torch

# Sample input tensor dimensions [32, 12, 14, 14]
x1 = torch.randn(32, 12, 14, 14)
sample_padding = 0
samples = 14  # samples per spatial axis -> a samples x samples grid (14x14 here)

# Generate equally spaced indices for rows (x) and columns (y); long dtype for
# advanced indexing.
x_ind = torch.round(torch.linspace(sample_padding,
                                   x1.shape[2] - sample_padding - 1,
                                   samples)).long()
y_ind = torch.round(torch.linspace(sample_padding,
                                   x1.shape[3] - sample_padding - 1,
                                   samples)).long()

# Create a meshgrid of indices
x_grid, y_grid = torch.meshgrid(x_ind, y_ind, indexing='ij')

# Flatten the grid indices (each holds samples * samples values)
x_idx_flat = x_grid.flatten()  # Row indices
y_idx_flat = y_grid.flatten()  # Column indices

# Row-major flat index into an H x W plane: row * W + col
width = x1.shape[3]
flat_indices = x_idx_flat * width + y_idx_flat

print(f"{samples}x{samples} sample pixel positions (flattened indices):", flat_indices)

# Flatten the SAME tensor so flat_indices address its last dimension directly,
# and confirm they pick out exactly the grid-sampled pixels.
flatten = nn.Flatten(start_dim=2)
x2 = flatten(x1)
assert torch.equal(x2[:, :, flat_indices], x1[:, :, x_grid, y_grid].flatten(start_dim=2))
print(x2.shape)



3x3 Sample pixel positions (flattened indices): tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
        140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
        154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
        168, 169