In [1]:
import torch
import torch.nn.functional as F
from multiprocessing import Pool

In [2]:
def read_dataset(dataset_path: str):
    edge_index = [[], []]
    with open(dataset_path, "r") as f:
        f.readline()  # Skip header
        for line in f:
            a, b = map(int, line.strip().split(","))
            edge_index[0].append(a)
            edge_index[0].append(b)
            edge_index[1].append(b)
            edge_index[1].append(a)
    edge_index = torch.tensor(edge_index)
    num_nodes = edge_index.max() + 1
    sparse = torch.sparse_coo_tensor(
        edge_index,
        torch.ones(edge_index.shape[1]),
        (num_nodes, num_nodes),
        dtype=torch.float32
    ).coalesce()
    return sparse.indices(), sparse

In [3]:
edge_index, sparse = read_dataset("data/lasftm_asia/lastfm_asia_edges.csv")
num_nodes = edge_index.max() + 1

In [4]:
def cost(v: int, sparse: torch.sparse.FloatTensor):
    return 1

In [5]:
def f1(S: torch.Tensor, sparse: torch.sparse.FloatTensor):
    a = (sparse * S) # Get intersections between S and the Graph
    intersection = a.indices()[0, a.values().nonzero().flatten()].bincount()
    intersection = F.pad(intersection, (0, num_nodes - intersection.shape[0]), value=0)
    return torch.minimum(
        intersection,
        edge_index[0].bincount() / 2
    ).sum()

In [6]:
num_nodes = sparse.indices().max() + 1
S_p = torch.zeros(num_nodes, dtype=torch.float32)
S_d = torch.zeros(num_nodes, dtype=torch.float32)
nodes = torch.ones(num_nodes, dtype=torch.float32)
ss_cost = 0

In [7]:
def worker(v):
    return f1(S_d + F.one_hot(torch.tensor(v), num_nodes).detach(), sparse)

with Pool(12) as pool:
    results = pool.map(worker, [v for v in range(num_nodes)])

KeyboardInterrupt: 

In [None]:
def cost_seeds_greedy(sparse: torch.sparse.FloatTensor, k: int, f: callable, c: callable):
    num_nodes = sparse.indices().max() + 1
    S_p = torch.zeros(num_nodes, dtype=torch.float32)
    S_d = torch.zeros(num_nodes, dtype=torch.float32)
    nodes = torch.ones(num_nodes, dtype=torch.float32)
    ss_cost = 0
    while True:
        nodes = nodes * (1. - S_d)
        u = torch.tensor([f1(S_d + F.one_hot(torch.tensor(v), num_nodes).detach(), sparse) for v in range(num_nodes)]).argmax()
        S_p = S_d.detach().clone()
        S_d = S_d + F.one_hot(torch.tensor(u), num_nodes)
        ss_cost_temp = ss_cost + c(u, sparse)
        if ss_cost_temp > k:
            break
        ss_cost = ss_cost_temp
    return S_p, ss_cost

In [None]:
cost_seeds_greedy(sparse, 10, f1, cost)