In [1]:
import copy
import os
import sys

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
from plotly.subplots import make_subplots

# Make the project-local `solver` package importable from the notebook directory.
sys.path.append("../")

from solver.env import TSP
from solver.agent import QAgent, DQAgent
from solver.runner import Runner
from solver.memory import *
from solver.net import DQN

In [2]:
# Locate the dataset relative to the parent of the current working directory.
# os.path is used instead of splitting/joining on "\\" so the cell also runs
# on non-Windows machines.
path = os.getcwd()
path_ = os.path.dirname(path)
data_file = os.path.join(path_, "data", "trimmed_data.csv")

# Number of leading rows (stops) taken from the CSV for the TSP instance.
NUM_STOPS = 10

df = pd.read_csv(data_file)
# (NUM_STOPS, 2) array of [latitude, longitude] rows.
data = df[['OUTLET_LATITUDE', 'OUTLET_LONGITUDE']].head(NUM_STOPS).to_numpy()

# DataFrame.info() prints its summary itself and returns None, so it must not
# be wrapped in print() (that emitted a stray "None" line).
df.info()
df.head()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2233 entries, 0 to 2232
Data columns (total 3 columns):
 #   Column            Non-Null Count  Dtype  
---  ------            --------------  -----  
 0   key               2233 non-null   int64  
 1   OUTLET_LATITUDE   2233 non-null   float64
 2   OUTLET_LONGITUDE  2233 non-null   float64
dtypes: float64(2), int64(1)
memory usage: 52.5 KB
None


Unnamed: 0,key,OUTLET_LATITUDE,OUTLET_LONGITUDE
0,722,10.784402,106.699383
1,2087,10.767384,106.697778
2,2247,10.777919,106.698471
3,3178,10.7667,106.689
4,5915,10.762336,106.687443


In [3]:
# Build the TSP environment over the selected stops.
env = TSP(
    data = data, 
    # distance_method = 'euclidean_distance',
    fixed_start = True, 
    cum_reward = False, 
    render_mode = None
)

# Online (policy) network: one input/output unit per stop.
policy_net = DQN(
    num_stops = data.shape[0]
    # , hidden_dim = 128
    , output_dim = data.shape[0]
)
# The target network must be a separate object with the same initial weights.
# Plain assignment (`target_net = policy_net`) aliases the policy network, so
# both entries of `net` would be the very same module.
# NOTE(review): this assumes DQAgent syncs the target from the policy network
# itself -- confirm in solver.agent.
target_net = copy.deepcopy(policy_net)
memory = ReplayMemory(1_000)

# Fall back to CPU when no CUDA device is available instead of failing.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

agent = DQAgent(
    env = env, net = [policy_net, target_net],
    states_size = data.shape[0], actions_size = data.shape[0],
    device = device
)

# The runner ties environment, agent and replay memory together.
runner = Runner(
    env = env, agent = agent, mem = memory
)

In [4]:
# Run the training loop. runner.run returns a 5-tuple that is unpacked below;
# note that `env` and `agent` are rebound to the returned objects.
num_episodes = 100_000
(
    env,
    agent,
    best_R,
    best_route,
    fig,
) = runner.run(num_episodes=num_episodes)

# Display the figure returned by the run.
fig.show()

100%|██████████| 100000/100000 [1:28:14<00:00, 18.89it/s] 


In [8]:
# Reload the model via the runner (best = True) and evaluate its agent.
runner.load_model(best = True)
agent = runner.agent

# Number of inference rollouts to sample.
n_infer = 1_000

# agent.inference() returns (route, reward); only the reward is collected.
R_infer_list = [agent.inference()[1] for _ in range(n_infer)]

# NOTE(review): if rewards are negative tour lengths, min() picks the worst
# rollout -- confirm that max() is not what was intended here.
min(R_infer_list)

-11.155918942430949

In [6]:
# Fetch the agent's state-action value table for all stops and report its shape.
q_table = agent.inspect_qtable(nstops = data.shape[0])
print(q_table.shape)

# Render the table as a 3D surface; update_layout returns the figure itself,
# so construction and layout can be chained.
fig_qtable = go.Figure(data = [go.Surface(z = q_table)]).update_layout(
    title = {'text': 'State Action Values'},
    autosize = False,
    width = 700,
    height = 700,
    margin = {'l': 65, 'r': 50, 'b': 65, 't': 90},
)

fig_qtable.show()

(10, 10)
