In [None]:
import torch
import pickle
from problems.tsp.problem_tsp import gen_mst_graph
from tqdm import tqdm
from torch_geometric.data import Data

In [None]:
def rand_gen_mst(idx):
    """Generate one shard of MST graphs and save it to ``save_prefix + "_{idx}.pkl"``.

    Meant to be mapped over ``range(n_repeats)`` by the worker ``Pool``.
    Reads the module-level ``graph_size``, ``n_samples`` and ``save_prefix``.

    Parameters
    ----------
    idx : int
        Shard index; used only in the output file name and the progress message.
    """
    import random

    import numpy as np  # local import: numpy is only imported in a later cell

    # Re-seed every RNG the generator might draw from, using OS entropy, once per
    # task. Pool workers started by fork inherit the parent's RNG state, so
    # without re-seeding, shards built by different workers could be identical.
    # The local gen_mst_graph in this notebook samples with np.random;
    # NOTE(review): presumably the imported one does too — confirm.
    torch.random.seed()
    np.random.seed()
    random.seed()
    g_list = [gen_mst_graph(graph_size) for _ in range(n_samples)]
    torch.save(g_list, save_prefix + f"_{idx}.pkl")
    print(f"dataset-{idx} created")

In [None]:
import torch.multiprocessing as mp
from torch.multiprocessing import Pool
# Switch torch's inter-process tensor sharing to the 'file_system' strategy.
# NOTE(review): presumably to avoid file-descriptor exhaustion when many objects
# cross the Pool boundary — confirm. The fully qualified name is kept on purpose:
# a later cell rebinds `mp` to the stdlib multiprocessing module.
torch.multiprocessing.set_sharing_strategy('file_system')


In [None]:
import os

# Generation settings.
n_process = 3              # worker processes in the generation Pool
graph_size = 20            # nodes per generated graph
n_samples = 128000 // 2    # graphs written to each shard file (64,000)
n_repeats = 20             # number of shard files; total = n_repeats * n_samples graphs
# Output directory: override with the TSP_DATA_DIR environment variable instead of
# editing the user-specific absolute path below (kept as the default).
data_dir = os.environ.get("TSP_DATA_DIR", "/home/pxh/attention-learn-to-route/data/tsp")
save_prefix = f"{data_dir}/tsp{graph_size}_train"

In [None]:
# Build the n_repeats shards in parallel. rand_gen_mst writes its own file and
# returns None, so the list returned by map is intentionally discarded.
with Pool(n_process) as p:
    p.map(rand_gen_mst, range(n_repeats))


In [None]:
# Empty container for the merged shards.
dataset = list()

In [None]:
# Merge the per-shard files written by rand_gen_mst into one list. The list is
# rebuilt from scratch so re-running this cell does not duplicate graphs.
dataset = []
for idx in tqdm(range(n_repeats), desc="loading shards"):
    # Shards hold pickled graph objects, not just tensors. torch>=2.6 defaults to
    # weights_only=True, which rejects them, so request the full unpickler
    # (parameter exists since torch 1.13). These files are produced by this
    # notebook, so they are trusted input.
    g_list = torch.load(save_prefix + f"_{idx}.pkl", weights_only=False)
    dataset.extend(g_list)
len(dataset)

In [None]:
# Sanity check: rand_gen_mst writes n_samples graphs per shard and n_repeats
# shards are merged, so the dataset should hold exactly their product.
assert len(dataset) == n_repeats * n_samples, (len(dataset), n_repeats * n_samples)
len(dataset)

In [None]:
# Persist the merged dataset as a single pickle next to the shard files.
pre_generated_path = save_prefix + "_pre_generated.pkl"
with open(pre_generated_path, "wb") as fh:
    pickle.dump(dataset, fh)

In [None]:
# Reload the merged dataset. pickle.load can execute arbitrary code, so only
# point this at files this notebook produced itself.
pre_generated_path = save_prefix + "_pre_generated.pkl"
with open(pre_generated_path, "rb") as fh:
    dataset = pickle.load(fh)

## Generating graphs with NetworkX

In [None]:
import networkx as nx
import numpy as np
import multiprocessing as mp
from scipy.spatial.distance import cdist
from itertools import combinations

In [None]:
# 50 random points in the unit square, one (x, y) row per node.
pos = np.random.uniform(0.0, 1.0, (50, 2))

In [None]:
# Full pairwise Euclidean distance matrix between the sampled points.
dist = cdist(XA=pos, XB=pos, metric="euclidean")

In [None]:
# Complete graph over the 50 nodes whose coordinates were sampled into `pos`.
graph = nx.complete_graph(50)

In [None]:
# Attach each node's coordinates as its "pos" attribute. Index by the node
# itself; writing to node 0 every time left all other nodes without a "pos".
for n in graph.nodes:
    graph.nodes[n]["pos"] = pos[n]

In [None]:
# Weight every edge by the Euclidean distance between its endpoints, writing
# straight into each edge's attribute dict.
for u, v, edge_attrs in graph.edges(data=True):
    edge_attrs["weight"] = dist[u, v]

In [None]:
# Minimum spanning tree of the weighted complete graph, via Prim's algorithm.
mst = nx.minimum_spanning_tree(graph, algorithm="prim")

In [None]:
# NOTE(review): this redefines the gen_mst_graph imported from
# problems.tsp.problem_tsp in the first cell; any cell run after this one
# (e.g. rand_gen_mst) will use this local version instead.
def gen_mst_graph(graph_size, rng=None):
    """Sample a random Euclidean complete graph and return its minimum spanning tree.

    Parameters
    ----------
    graph_size : int
        Number of nodes.
    rng : optional
        Object exposing ``uniform(low, high, size)``, e.g.
        ``np.random.default_rng(seed)`` for reproducible graphs. Defaults to the
        global ``np.random`` state.

    Returns
    -------
    The Prim MST of the graph built by ``gen_complete_graph(graph_size, rng)``.
    """
    # Reuse gen_complete_graph instead of duplicating the construction code.
    graph = gen_complete_graph(graph_size, rng=rng)
    return nx.algorithms.tree.mst.minimum_spanning_tree(graph, algorithm="prim")


def gen_complete_graph(graph_size, rng=None):
    """Build a complete graph on random points in the unit square.

    Parameters
    ----------
    graph_size : int
        Number of nodes; node ``n`` gets the ``n``-th sampled point.
    rng : optional
        Object exposing ``uniform(low, high, size)``; defaults to the global
        ``np.random`` state.

    Returns
    -------
    networkx complete graph whose nodes carry a ``"pos"`` attribute (2-vector)
    and whose edges carry a ``"weight"`` equal to the Euclidean distance
    between the endpoints.
    """
    sampler = np.random if rng is None else rng
    pos = sampler.uniform(low=0.0, high=1.0, size=(graph_size, 2))
    dist = cdist(pos, pos, metric='euclidean')
    graph = nx.complete_graph(graph_size)
    for n in graph.nodes:
        graph.nodes[n]["pos"] = pos[n]
    for u, v in graph.edges:
        graph[u][v]["weight"] = dist[u, v]
    return graph

In [None]:
# line_profiler provides the %lprun magic used by the profiling cells below.
%load_ext line_profiler

In [None]:
# Time one 50-node MST generation with whichever gen_mst_graph is currently
# bound (the imported one or the local redefinition, depending on run order).
%timeit gen_mst_graph(50)

In [None]:
# Batch copies of `graph` into one disconnected graph. Each copy's node labels
# are prefixed with "<copy index>_"; the separator keeps labels unique for any
# number of copies (without it, "1"+"10" and "11"+"0" would both be "110").
n_copies = 10
batch_g = nx.union_all([graph] * n_copies, rename=[f"{i}_" for i in range(n_copies)])

In [None]:
# One Prim run over the whole batch; the result has one tree per copy.
batch_mst = nx.minimum_spanning_tree(batch_g, algorithm="prim")

In [None]:
# from_networkx lives in torch_geometric.utils and is not imported by any earlier cell.
from torch_geometric.utils import from_networkx

# Convert each tree of the batched MST into its own PyG object.
# NOTE(review): this cell previously referenced an undefined `batch_test`;
# `batch_mst` (the batched MST computed in the previous cell) is assumed to be
# the intended graph — confirm.
for subg in nx.connected_components(batch_mst):
    print(from_networkx(batch_mst.subgraph(subg)))

In [None]:
def test_batch_mst(graph_size, batch_size):
    """Compute MSTs of `batch_size` random complete graphs with a single Prim call.

    The graphs are merged with ``nx.union_all`` (node labels prefixed with
    ``"<i>_"``). One MST is computed over the disconnected union, and the
    result is split into one subgraph per connected component.
    """
    graphs = [gen_complete_graph(graph_size) for _ in range(batch_size)]
    prefixes = [f"{i}_" for i in range(batch_size)]
    merged = nx.union_all(graphs, rename=prefixes)
    forest = nx.algorithms.tree.mst.minimum_spanning_tree(merged, algorithm="prim")
    components = nx.connected_components(forest)
    return [forest.subgraph(nodes) for nodes in components]

def test_sep_mst(graph_size, batch_size):
    """Compute `batch_size` MSTs one graph at a time via gen_mst_graph."""
    msts = []
    for _ in range(batch_size):
        msts.append(gen_mst_graph(graph_size))
    return msts


In [None]:
# Line-profile building 128 MSTs of 50 nodes one graph at a time.
%lprun -f test_sep_mst -f gen_mst_graph test_sep_mst(50, 128)

In [None]:
# Line-profile the batched alternative: one Prim MST over the union of 128 graphs.
%lprun -f test_batch_mst test_batch_mst(50, 128)

In [None]:
# Scratch experiment: `cg` is not used by any later cell in this notebook.
# NOTE(review): presumably 2 cliques of 10 nodes each — see networkx docs.
cg = nx.caveman_graph(2,10)

In [None]:
# Batch of 128 random 50-node instances in the unit square: (batch, node, xy).
# Note: this rebinds `pos`, replacing the single-instance points sampled earlier.
pos = np.random.uniform(0.0, 1.0, (128, 50, 2))

In [None]:
# One (50, 50) Euclidean distance matrix per instance, stacked along axis 0.
dist = np.stack(
    [cdist(instance, instance, metric='euclidean') for instance in pos],
    axis=0,
)

In [None]:
def test_cal_dist(pos):
    bs, gs, _ = pos.shape
    out = []
    dist = np.stack([cdist(p, p, metric='euclidean') for p in pos])
    for b in range(bs):
        for u, v in combinations(range(gs), 2):
            d = dist[b, u, v]
            out.append(d)
    return np.array(out)

In [None]:
def test_sep_dist(pos):
    bs, gs, _ = pos.shape
    out = []
    for b in range(bs):
        for u, v in combinations(range(gs), 2):
            d = np.linalg.norm(pos[b,u] - pos[b,v])
            out.append(d)
    return np.array(out)

In [None]:
# Time the per-pair np.linalg.norm approach on the batched positions in `pos`.
%timeit test_sep_dist(pos)

In [None]:
# Time the cdist-then-index approach on the same positions, for comparison.
%timeit test_cal_dist(pos)