In [3]:

# XXX 
# TODO: this cell starts the weighted-average fix (recompute border cells without zero padding)
# XXX

import torch
from typing import Iterable, SupportsIndex

def _index_out_of_bounds(arr: Iterable | SupportsIndex, row: int, col: int) -> bool:
    """
    Checks if a row and column are out of bounds of a 2D array.
    """
    return row < 0 or col < 0 or row >= len(arr) or col >= len(arr[row])

def weighted_pool(arr: torch.Tensor | SupportsIndex, row: int, col: int, kernel: torch.Tensor) -> list:
    """
    Finds the weighted average of a specified cell's neighbors, based on `kernel`.
    Averages are calculated by (neighborhood of cell * kernel) / sum(kernel)
    + If kernel is partially out of bounds, a partial average is calculated; no padding will
    be added (neighborhood of cell that are in bounds * kernel in bounds) / sum(kernel in bounds)
    """
    
    weighted_sum = 0.0
    kernel_weights_used = 0.0
    
    _kernel_lrow = kernel.shape[0] // 2 # radius along rows (horizontal), also the center row of the kernel
    _kernel_lcol = kernel.shape[1] // 2 # radius along cols (vertical), also the center col of the kernel
    
    for mov_row in range(-_kernel_lrow, _kernel_lrow + 1):
        for mov_col in range(-_kernel_lcol, _kernel_lcol + 1):
            if _index_out_of_bounds(arr, row + mov_row, col + mov_col):
                continue
            
            weight = kernel[mov_row + _kernel_lrow][mov_col + _kernel_lcol]
            _dot = weight * arr[row + mov_row][col + mov_col]
            weighted_sum += _dot
            kernel_weights_used += weight
            
    return weighted_sum / kernel_weights_used

# create a tensor
tensor = torch.ones((5, 5))
tensor[0][0] = 0.0
print(tensor)
# create weights. conv2d and weighted_pool both apply the kernel as a
# cross-correlation, so they agree even for an asymmetric kernel.
weights = torch.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
weights = weights / weights.sum()  # normalize so interior cells get a true average

n_rows, n_cols = tensor.shape
k_rows, k_cols = weights.shape
# "same"-size output for an odd kernel: pad by the kernel radius on each axis
pad_rows, pad_cols = k_rows // 2, k_cols // 2

# apply the weighted average to each element based on its neighbors.
# conv2d wants (batch, channels, H, W) input and (out_ch, in_ch, kH, kW) weights.
weighted_tensor = torch.nn.functional.conv2d(
    tensor.view(1, 1, n_rows, n_cols),
    weights.view(1, 1, k_rows, k_cols),
    padding=(pad_rows, pad_cols),
).view(n_rows, n_cols)

# Zero padding skews every cell whose kernel overlaps the border, i.e. every cell
# within the kernel radius of an edge. Recompute those with weighted_pool, which
# renormalizes by the in-bounds weights instead of padding.
aux_arr = weighted_tensor.clone()
for row in range(n_rows):
    for col in range(n_cols):
        near_edge = (
            row < pad_rows or row >= n_rows - pad_rows
            or col < pad_cols or col >= n_cols - pad_cols
        )
        if near_edge:
            value = weighted_pool(tensor, row, col, weights)
            print(f"value at {row}, {col} is {value}")
            aux_arr[row, col] = value

print(aux_arr)


tensor([[0., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
value at 0, 0 is 0.75
value at 0, 1 is 0.8333333730697632
value at 0, 2 is 1.0
value at 0, 3 is 1.0
value at 0, 4 is 1.0
value at 1, 0 is 0.8333333730697632
value at 1, 4 is 1.0
value at 2, 0 is 1.0
value at 2, 4 is 1.0
value at 3, 0 is 1.0
value at 3, 4 is 1.0
value at 4, 0 is 1.0
value at 4, 1 is 1.0
value at 4, 2 is 1.0
value at 4, 3 is 1.0
value at 4, 4 is 1.0
tensor([[0.7500, 0.8333, 1.0000, 1.0000, 1.0000],
        [0.8333, 0.8889, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000]])


In [70]:
def _index_out_of_bounds(arr, row: int, col: int) -> bool:
    """
    Checks if a row and column are out of bounds of a 2D array.
    """
    return row < 0 or col < 0 or row >= len(arr) or col >= len(arr[row])

def get_numbers_around_location(arr, row: int, col: int, radius: int = 1) -> list:
    """
    Gets all numbers around a specified location in a 2D array as a list.

    The neighborhood is the (2 * radius + 1)-wide square centered on (row, col),
    excluding the center cell itself, returned in row-major order. Positions
    outside the array are skipped (no padding), so a location on or near an
    edge yields fewer numbers. Rows may differ in length.

    Args:
        arr: 2D array-like indexed as arr[r][c].
        row: row of the center location.
        col: column of the center location.
        radius: how many cells to extend in each direction (default 1).

    Returns:
        The neighboring values, row by row, left to right.
    """
    # Clamp the window to the array once per axis instead of bounds-checking
    # every candidate position.
    first_row = max(0, row - radius)
    last_row = min(len(arr), row + radius + 1)
    first_col = max(0, col - radius)

    nums = []
    for r in range(first_row, last_row):
        last_col = min(len(arr[r]), col + radius + 1)  # per-row length for ragged input
        for c in range(first_col, last_col):
            if r == row and c == col:
                continue  # the center is not its own neighbor
            nums.append(arr[r][c])
    return nums
            
# 5x5 grid holding 1..25 in row-major order
testarr = [[5 * r + c + 1 for c in range(5)] for r in range(5)]

print(get_numbers_around_location(testarr, row=1, col=0, radius=2))

[1, 2, 3, 7, 8, 11, 12, 13, 16, 17, 18]


tensor([[0.1778, 0.3000, 0.3667, 0.4333, 0.3111],
        [0.4333, 0.7000, 0.8000, 0.9000, 0.6333],
        [0.7667, 1.2000, 1.3000, 1.4000, 0.9667],
        [1.1000, 1.7000, 1.8000, 1.9000, 1.3000],
        [0.8444, 1.3000, 1.3667, 1.4333, 0.9778]])
