In [None]:
# Automatically reload edited project modules (e.g. the `src` package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import networkx as nx
import numpy as np
import cvxpy as cp
from tqdm import tqdm
from pathlib import Path

from src.load_data import (
    read_metadata_networks_tntp,
    read_graph_transport_networks_tntp,
    read_traffic_mat_transport_networks_tntp,
)

from src.models import SDModel, BeckmannModel, TwostageModel
from src.algs import subgd, ustm, frank_wolfe, cyclic, N_conjugate_frank_wolfe
from src.cvxpy_solvers import get_max_traffic_mat_mul
from src.commons import Correspondences
from src.saddle_ta import salim_ta, chambolle_pock_ta

import matplotlib.pyplot as plt

# Larger default font so figure labels stay readable.
plt.rcParams.update({"font.size": 14})
# High-resolution (retina) rendering for inline figures.
%config InlineBackend.figure_format = 'retina'

%matplotlib inline

In [None]:
networks_path = Path("./TransportationNetworks")

# Network to study. Each network lives in `<networks_path>/<folder>/` with files
# `<folder>_net.tntp` (links) and `<folder>_trips.tntp` (OD demand).
# Other networks used with this notebook: "Anaheim", "Barcelona".
folder = "SiouxFalls"
net_name = f"{folder}_net"
traffic_mat_name = f"{folder}_trips"

net_file = networks_path / folder / f"{net_name}.tntp"
traffic_mat_file = networks_path / folder / f"{traffic_mat_name}.tntp"

# Fail early with an actionable message instead of an opaque error from the readers.
missing_files = [str(p) for p in (net_file, traffic_mat_file) if not p.is_file()]
if missing_files:
    raise FileNotFoundError(f"TNTP files not found: {missing_files} (expected under {networks_path})")

graph, metadata = read_graph_transport_networks_tntp(net_file)
correspondences = read_traffic_mat_transport_networks_tntp(traffic_mat_file, metadata)
n = graph.number_of_nodes()

print(f"{graph.number_of_edges()=}, {graph.number_of_nodes()=}")

In [None]:
# Beckmann model built from the loaded graph and OD correspondences.
beckmann_model = BeckmannModel(graph, correspondences)

eps = 1e-4  # relative accuracy used to derive the absolute tolerances below
mean_cost = beckmann_model.graph.ep.free_flow_times.a.mean()  # average link free-flow time
mean_bw = beckmann_model.graph.ep.capacities.a.mean()  # average link capacity

# Objective tolerance: cost suboptimality <= eps * (average link cost * average capacity * |E|),
# which approximates the total cost when beta = 1.
eps_abs = eps * mean_cost * mean_bw * graph.number_of_edges()

# Constraint tolerance: sum of capacity violation <= eps * average link capacity.
eps_cons_abs = eps * mean_bw

print(eps_abs, eps_cons_abs)

# Beckmann model: solver comparison

In [None]:
# USTM on the Beckmann model (stop_by_crit=False presumably disables early stopping).
# Logs: duality gap, constraint log and timestamps.
times_e_ustm, flows_e_ustm, logs, optimal = ustm(
    beckmann_model,
    eps_abs,
    max_iter=1000,
    stop_by_crit=False,
)
dgap_ustm, cons_log_ustm, time_log_ustm = logs
print(len(dgap_ustm), "shortest paths calls")

In [None]:
# Frank-Wolfe on the Beckmann model; unlike USTM its logs carry a primal relative gap
# instead of a constraint log.
times_e_fw, flows_e_fw, logs, optimal = frank_wolfe(
    beckmann_model,
    eps_abs,
    stop_by_crit=False,
    max_iter=7700,
)
dgap_fw, time_log_fw, primal_r_gap_fw = logs

In [None]:
# N-conjugate Frank-Wolfe with 3 conjugate directions and line search enabled.
times_e_nfw, flows_e_nfw, logs, optimal = N_conjugate_frank_wolfe(
    beckmann_model,
    eps_abs,
    max_iter=4000,
    stop_by_crit=False,
    cnt_conjugates=3,
    linesearch=True,
)
dgap_nfw, time_log_nfw, primal_r_gap_nfw = logs

In [None]:
# Absolute dgap vs. time for the three first-order methods, with the eps_abs target.
# np.abs is applied on the fly so the raw gap logs are not overwritten and the cell is rerunnable.
fig, ax = plt.subplots(figsize=(10, 4))
for time_log, dgap, color, label in (
    (time_log_ustm, dgap_ustm, "C4", "USTM"),
    (time_log_fw, dgap_fw, "C5", "FW"),
    (time_log_nfw, dgap_nfw, "C3", "NFW"),
):
    ax.plot(time_log, np.abs(dgap), c=color, label=label)
# Horizontal target line; equivalent to the former np.ones(...)[0] * eps_abs without
# failing on empty logs.
ax.axhline(y=eps_abs, linestyle="--", label="eps_abs")
ax.set_yscale("log")
ax.set_title("abs(dgap)")
ax.set_xlabel("time")
ax.set_ylabel("|dgap|")
ax.legend();

In [None]:
from src.salim import SaddleOracle

In [None]:
# Oriented node-by-edge incidence matrix as a dense ndarray.
# `.toarray()` (instead of `.todense()`) avoids the legacy `np.matrix` type.
A = incidence_mat = nx.incidence_matrix(beckmann_model.nx_graph, oriented=True).toarray()

# Only Bmul is needed here; the other constructor arguments are passed as None
# (assumes Bmul does not use them — TODO confirm in src.salim).
saddle_oracle = SaddleOracle(beckmann_model, None, None, None)

Ld = saddle_oracle.Bmul(beckmann_model.correspondences.traffic_mat).T
# NOTE(review): b appears unused in the rest of this notebook.
b = -Ld

In [None]:
# Spectral constants of the incidence matrix A. Only singular values are needed,
# so skip computing U and V^T (much cheaper on larger networks).
svals = np.linalg.svd(A, compute_uv=False)  # sorted in descending order
sval_tol = 1e-8  # singular values at or below this are treated as zero

lam1 = svals[0] ** 2  # largest eigenvalue of A A^T
# Smallest nonzero eigenvalue of A A^T. Each column of an oriented incidence matrix holds
# one +1 and one -1, so A has a zero singular value that must be skipped.
lam2 = svals[svals > sval_tol][-1] ** 2

# Must use the unscaled lam1: L_sq sets the Chambolle-Pock step sizes (nu = 1 / L_sq / gamma).
L_sq = lam1 + Ld.shape[0]

# Widen the spectral bounds (presumably a safety margin) before they are passed to salim_ta.
lam1 *= 2
lam2 /= 2
lam1, lam2

In [None]:
# salim_ta saddle-point solver settings.
mu = 1e-2  # passed as `mu` (presumably a strong-convexity constant — confirm in src.saddle_ta)
L = 100  # passed as `L` (presumably a smoothness constant — confirm)
iters = 100000

# NOTE(review): cons_log / opt_log are reused (and overwritten) by the Chambolle-Pock run.
y_salim, f_salim, cons_log, opt_log = salim_ta(
    beckmann_model,
    iters=iters,
    mu=mu,
    L=L,
    lam1=lam1,
    lam2=lam2,
)

In [None]:
# Convergence logs of salim_ta (constraint and objective), log scale.
# NOTE(review): cons_log / opt_log are overwritten by the Chambolle-Pock cell; re-run the
# salim_ta cell first if re-plotting after that.
plt.plot(cons_log, label="cons")
plt.plot(opt_log, label="func")
plt.title("salim_ta")
plt.legend()
plt.grid()
plt.yscale("log");

In [None]:
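# Optional sanity check (uncomment): overlay NFW edge flows and salim_ta's f_salim summed over axis 1.
# `x` was a stale, undefined name — presumably f_salim; TODO confirm.
# plt.plot(flows_e_nfw)
# plt.plot(f_salim.sum(axis=1))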

In [None]:
# Chambolle-Pock primal-dual method on the Beckmann model.
iters = 10000

# Step sizes are coupled through gamma * nu * L_sq = 1, where L_sq presumably bounds the
# squared norm of the linear operator. Other choices with the same product that were tried:
#   gamma = nu = 1 / sqrt(L_sq)
#   nu = 0.01, gamma = 1 / (L_sq * nu)
gamma = 0.001
nu = 1 / L_sq / gamma

print(f"{gamma, nu, L_sq =}")

y_cp, f_cp, cons_log, opt_log = chambolle_pock_ta(beckmann_model, iters=iters, gamma=gamma, nu=nu)

In [None]:
# Convergence logs of chambolle_pock_ta (constraint and objective), log scale.
plt.plot(cons_log, label="cons")
plt.plot(opt_log, label="func")
plt.title("chambolle_pock_ta")
plt.legend()
plt.grid()
plt.yscale("log");

In [None]:
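# Optional sanity check (uncomment): overlay NFW edge flows and chambolle_pock_ta's f_cp summed over axis 1.
# `f` was a stale, undefined name — the Chambolle-Pock cell returns f_cp.
# plt.plot(flows_e_nfw)
# plt.plot(f_cp.sum(axis=1))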