In [28]:
from torch import Tensor as T
import torch
import random

### symmetrize_edge_weights function

In [29]:
def symmetrize_edge_weights(edge_indices: T, edge_weights: T) -> T:
    """Symmetrize edge weights by averaging each edge with its reverse edge.

    For every edge ``(u, v)`` that has a matching reverse edge ``(v, u)``,
    both edges receive the mean of their two weights. Edges without a reverse
    keep their original weight. Each edge is paired with at most one reverse
    edge: occurrences are matched in order of appearance, and surplus
    duplicates are left unchanged.

    Args:
        edge_indices: Tensor of shape ``(num_edges, 2)``; row ``i`` holds the
            source and target node of edge ``i``. Node ids are converted with
            ``int()``, so float tensors of whole numbers are accepted.
        edge_weights: Tensor with one weight per edge, aligned with the rows
            of ``edge_indices``.

    Returns:
        A new tensor with the same dtype as ``edge_weights`` holding the
        symmetrized weights. ``edge_weights`` itself is not modified.
    """
    # Maps a directed edge (u, v) to the indices of earlier occurrences that
    # are still waiting for their reverse (v, u). Keying on the full ordered
    # pair makes the reverse lookup exact regardless of node degree.
    unmatched = {}
    symmetric_pairings = []

    for i, (a1, a2) in enumerate(edge_indices.tolist()):
        a1, a2 = int(a1), int(a2)
        waiting = unmatched.get((a2, a1))
        if waiting:
            # Pair with the earliest unmatched reverse edge and consume it so
            # it can never be paired a second time.
            symmetric_pairings.append((i, waiting.pop(0)))
        else:
            unmatched.setdefault((a1, a2), []).append(i)

    # Update weights for symmetric pairs; every index appears in at most one
    # pair, so reading from the untouched input is equivalent and clearer.
    updated_values = edge_weights.clone()
    for i, j in symmetric_pairings:
        average_weight = (edge_weights[i] + edge_weights[j]) / 2
        updated_values[i] = average_weight
        updated_values[j] = average_weight
    return updated_values


In [30]:
# Opposite edges 1->2 and 2->1 with weights 1 and 3: both should become 2.
symmetrize_edge_weights(edge_indices=T([[1, 2], [2, 1]]), edge_weights=T([1, 3]))

tensor([2., 2.])

### Unit Test Code:

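`MockEdgeData` creates a randomly ordered set of edges and weights. Some edges are symmetric (both `(u, v)` and `(v, u)` are present) and the others are non-symmetric. The `size` parameter is the exclusive upper bound of the random counts that determine how many symmetric and non-symmetric edges are generated, and it also bounds the random integer weights. Note that each node belonging to a symmetric pair appears in no other edge, so this mock does not exercise nodes with several incident edges.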

In [31]:
class MockEdgeData:
    """Random edge list with known symmetrized weights, for testing.

    Builds ``2 * (overlap // 2)`` symmetric edges plus ``rand`` non-symmetric
    edges, where ``overlap`` and ``rand`` are drawn from ``[0, size)``. Each
    symmetric edge ``[i, overlap - i]`` is added together with its reverse
    ``[overlap - i, i]``. Weights are integers in ``[0, size)``. All three
    tensors are shuffled with one shared permutation, so their rows stay
    aligned.

    Attributes:
        edges: float tensor of ``[source, target]`` rows, shape
            ``(num_edges, 2)`` (1-D and empty when no edges were generated).
        weights: float tensor with one weight per edge.
        targets: float tensor with the weight each edge should have after
            symmetrization. For a symmetric edge this is the mean of its
            weight and its reverse's weight; otherwise it is the edge's own
            weight.
    """

    # Exclusive upper bound on the random offset from source to target of a
    # non-symmetric edge.
    MAX_TARGET_OFFSET = 50

    def __init__(self, size, seed=None):
        """Generate the mock data.

        Args:
            size: Exclusive upper bound for the random edge counts and weights.
            seed: Optional integer seed. When given, private ``random.Random``
                and ``torch.Generator`` instances are used, so the data is
                reproducible. When ``None`` (default), the global RNGs are
                used exactly as before.
        """
        if seed is None:
            py_rng = random
            torch_rng = None
        else:
            py_rng = random.Random(seed)
            torch_rng = torch.Generator().manual_seed(seed)

        overlap = torch.randint(0, size, (1,), generator=torch_rng).item()
        rand = torch.randint(0, size, (1,), generator=torch_rng).item()
        half = overlap // 2

        # Symmetric edges first as [i, overlap - i], then their reverses.
        # Their node ids lie in [0, overlap], below every non-symmetric id.
        forward = [[i, overlap - i] for i in range(half)]
        backward = [[overlap - i, i] for i in range(half)]
        edges = forward + backward
        weights = [int(size * py_rng.random()) for _ in range(2 * half)]
        pair_means = [(weights[i] + weights[i + half]) / 2 for i in range(half)]
        targets = pair_means + pair_means

        # Non-symmetric edges: target >= source and sources are unique, so no
        # edge in this group has a reverse.
        for i in range(rand):
            source = i + overlap + 1
            offset = int(self.MAX_TARGET_OFFSET * py_rng.random())
            edges.append([source, source + offset])
            weights.append(int(size * py_rng.random()))
            targets.append(weights[-1])

        edges = T(edges)
        weights = T(weights)
        targets = T(targets)

        # Shuffle all three with the same permutation to keep rows aligned.
        perm = torch.randperm(edges.shape[0], generator=torch_rng)
        self.edges = edges[perm]
        self.weights = weights[perm]
        self.targets = targets[perm]

In [32]:
def test_symmetrize_edge_weights(size=1000):
    """Check symmetrize_edge_weights on a fixed example and on random data.

    Args:
        size: Scale passed to ``MockEdgeData`` for the randomized check.
    """
    # Deterministic regression case: node 1 has two outgoing edges, so the
    # (1, 3) / (3, 1) pair must be found even though node 1 has another
    # neighbour. The random mock never produces this situation.
    shared_node = symmetrize_edge_weights(T([[1, 2], [1, 3], [3, 1]]), T([1, 3, 5]))
    assert torch.equal(shared_node, T([1, 4, 4])), f"shared-node case: {shared_node}"

    mock = MockEdgeData(size)
    result = symmetrize_edge_weights(mock.edges, mock.weights)
    assert torch.equal(result, mock.targets), "randomized MockEdgeData case failed"
    print("Test Passed")

In [33]:
# Run the unit test; it prints "Test Passed" once every assertion holds.
test_symmetrize_edge_weights()

Test Passed
