### TSP Policy-Based RL Notebook

In [2]:
# # %%
# # Imports & Setup
# import os
# import glob
# import time
# import math
# import random
# import numpy as np

# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import matplotlib.pyplot as plt

# from TSProblemDef import augment_xy_data_by_8_fold, get_random_problems
# from utils.utils import TimeEstimator

# # Device
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# torch.set_default_tensor_type('torch.cuda.FloatTensor' if device.type=='cuda' else 'torch.FloatTensor')

In [None]:

# %% [markdown]
# ## Environment Definition (TSPEnv)

class ResetState:
    """Container handed back by ``TSPEnv.reset`` holding the problem batch."""

    def __init__(self, problems):
        # Node coordinates, laid out as (batch, problem_size, 2).
        self.problems = problems

class StepState:
    """Mutable per-step view of the environment shared with the policy.

    Stores the two index grids and the additive mask it is constructed
    with; ``current_node`` starts as ``None`` and is filled in by the
    environment once a node has been selected.
    """

    def __init__(self, BATCH_IDX, POMO_IDX, ninf_mask):
        # Nothing has been selected yet at construction time.
        self.current_node = None
        self.ninf_mask = ninf_mask
        self.POMO_IDX = POMO_IDX
        self.BATCH_IDX = BATCH_IDX

class TSPEnv:
    """Batched TSP environment running ``pomo_size`` rollouts per instance.

    Typical use: ``load_problems`` -> ``reset`` -> repeated ``step`` until
    ``done`` is True, at which point ``step`` returns the reward (negative
    tour length) for every rollout.
    """

    def __init__(self, problem_size, pomo_size):
        """Store the number of nodes per instance and rollouts per instance."""
        self.problem_size = problem_size
        self.pomo_size = pomo_size

    def load_problems(self, batch_size, aug_factor=1):
        """Generate a fresh batch of random problems and the index grids.

        Args:
            batch_size: number of instances to generate.
            aug_factor: ``1`` for no augmentation, ``8`` to apply
                ``augment_xy_data_by_8_fold`` (the batch size is then
                multiplied by 8).

        Raises:
            NotImplementedError: if ``aug_factor`` is neither 1 nor 8.
        """
        # Fail loudly instead of silently skipping an unsupported factor.
        if aug_factor not in (1, 8):
            raise NotImplementedError(
                f"aug_factor must be 1 or 8, got {aug_factor!r}"
            )
        self.batch_size = batch_size
        self.problems = get_random_problems(batch_size, self.problem_size)
        if aug_factor == 8:
            self.batch_size *= 8
            self.problems = augment_xy_data_by_8_fold(self.problems)
        # Keep the index grids on the same device as the coordinates so that
        # indexing never mixes devices.
        device = self.problems.device
        self.BATCH_IDX = torch.arange(self.batch_size, device=device)[:, None].expand(self.batch_size, self.pomo_size)
        self.POMO_IDX = torch.arange(self.pomo_size, device=device)[None, :].expand(self.batch_size, self.pomo_size)

    def reset(self):
        """Clear the rollout state and return ``(ResetState, None, False)``.

        Relies on ``problems``, ``batch_size``, ``BATCH_IDX`` and
        ``POMO_IDX`` having been set beforehand (``load_problems`` does so).
        """
        device = self.problems.device
        self.selected_count = 0
        self.current_node = None
        # Grows by one column per step: (batch, pomo, steps_so_far).
        self.selected_node_list = torch.zeros(
            (self.batch_size, self.pomo_size, 0), dtype=torch.long, device=device
        )
        # Additive mask: 0 for selectable nodes, -inf once a node is visited.
        mask = torch.zeros(
            (self.batch_size, self.pomo_size, self.problem_size), device=device
        )
        self.step_state = StepState(self.BATCH_IDX, self.POMO_IDX, mask)
        return ResetState(self.problems), None, False

    def pre_step(self):
        """Return the current step state without advancing the rollout."""
        return self.step_state, None, False

    def step(self, selected):
        """Record ``selected`` as the next node of every rollout.

        Args:
            selected: integer tensor of shape (batch, pomo) with node indices.

        Returns:
            ``(step_state, reward, done)``. ``reward`` is ``None`` until
            ``problem_size`` nodes have been selected; on the final step it
            is the negative tour length with shape (batch, pomo).
        """
        self.selected_count += 1
        self.current_node = selected
        self.selected_node_list = torch.cat((self.selected_node_list, selected[:, :, None]), dim=2)
        self.step_state.current_node = selected
        # The mask is updated in place, so holders of step_state see it too.
        m = self.step_state.ninf_mask
        m[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
        done = (self.selected_count == self.problem_size)
        reward = -self._get_travel_distance() if done else None
        return self.step_state, reward, done

    def _get_travel_distance(self):
        """Return the closed-tour length of every rollout, shape (batch, pomo)."""
        coords = self.problems
        # Broadcast the visiting order over the two coordinate channels.
        order = self.selected_node_list.unsqueeze(3).expand(self.batch_size, -1, self.problem_size, 2)
        seq = coords[:, None, :, :].expand(self.batch_size, self.pomo_size, self.problem_size, 2)
        ordered = seq.gather(2, order)
        # Shifting by -1 pairs each node with its successor; the last node
        # wraps around to the first, closing the tour.
        rolled = ordered.roll(dims=2, shifts=-1)
        seg = ((ordered - rolled) ** 2).sum(3).sqrt()
        return seg.sum(2)

# %% [markdown]
# ## Transformer Encoder Model

def reshape_by_heads(x, head_num):
    """Split the last axis of a 3-D tensor into ``head_num`` heads.

    Turns (batch, n, head_num * d) into (batch, head_num, n, d).
    """
    batch, length, _ = x.shape
    per_head = x.view(batch, length, head_num, -1)
    # Move the head axis in front of the sequence axis.
    return per_head.transpose(1, 2)

class AddNorm(nn.Module):
    """Residual addition followed by affine instance normalisation."""

    def __init__(self, emb):
        super().__init__()
        self.norm = nn.InstanceNorm1d(emb, affine=True)

    def forward(self, x, sub):
        """Return ``norm(x + sub)`` with the same layout as ``x``."""
        residual = x + sub
        # InstanceNorm1d wants channels on axis 1, so swap the last two axes
        # going in and swap them back coming out.
        channels_first = residual.transpose(1, 2)
        normalised = self.norm(channels_first)
        return normalised.transpose(1, 2)

class FeedForward(nn.Module):
    """Two-layer position-wise MLP: Linear -> ReLU -> Linear."""

    def __init__(self, emb, hidden):
        super().__init__()
        self.w1 = nn.Linear(emb, hidden)
        self.w2 = nn.Linear(hidden, emb)

    def forward(self, x):
        """Project to the hidden width, apply ReLU, project back to ``emb``."""
        hidden = self.w1(x)
        activated = F.relu(hidden)
        return self.w2(activated)

class TSP_Encoder(nn.Module):
    """Stack of self-attention layers embedding 2-D node coordinates.

    Extra keyword arguments are accepted and ignored, so a shared
    hyper-parameter dict can be splatted into the constructor.
    """

    def __init__(self, embedding_dim, encoder_layer_num, head_num, qkv_dim, ff_hidden_dim, **_):
        super().__init__()
        self.embed = nn.Linear(2, embedding_dim)
        proj_dim = head_num * qkv_dim

        def make_layer():
            # Submodule names are part of the state-dict layout; keep them.
            return nn.ModuleDict({
                'Wq': nn.Linear(embedding_dim, proj_dim, bias=False),
                'Wk': nn.Linear(embedding_dim, proj_dim, bias=False),
                'Wv': nn.Linear(embedding_dim, proj_dim, bias=False),
                'combine': nn.Linear(proj_dim, embedding_dim),
                'addnorm1': AddNorm(embedding_dim),
                'ff': FeedForward(embedding_dim, ff_hidden_dim),
                'addnorm2': AddNorm(embedding_dim)
            })

        self.layers = nn.ModuleList()
        for _ in range(encoder_layer_num):
            self.layers.append(make_layer())
        self.head_num = head_num

    def forward(self, coords):
        """Encode ``coords`` (last axis of size 2) into node embeddings.

        Returns a tensor of shape (batch, nodes, embedding_dim).
        """
        h = self.embed(coords)
        batch, nodes = h.size(0), h.size(1)
        for layer in self.layers:
            queries = reshape_by_heads(layer['Wq'](h), self.head_num)
            keys = reshape_by_heads(layer['Wk'](h), self.head_num)
            values = reshape_by_heads(layer['Wv'](h), self.head_num)

            # Scaled dot-product attention, evaluated per head.
            scale = math.sqrt(queries.size(-1))
            weights = F.softmax(torch.matmul(queries, keys.transpose(-2, -1)) / scale, dim=-1)
            per_head = torch.matmul(weights, values)

            # Concatenate the heads back into a single feature axis.
            merged = per_head.transpose(1, 2).contiguous().view(batch, nodes, -1)

            h = layer['addnorm1'](h, layer['combine'](merged))
            h = layer['addnorm2'](h, layer['ff'](h))
        return h

# %% [markdown]
# ## Instance Loading Helpers

def load_npz_instances(N, count):
    """Load ``count`` randomly chosen instances from ``data/TSP{N}/instances``.

    Each ``.npz`` file must contain a ``coords`` array; all chosen arrays
    must share one shape so they can be stacked.

    Args:
        N: problem size used to pick the directory ``data/TSP{N}/instances``.
        count: number of distinct instance files to sample.

    Returns:
        A float32 tensor with the chosen ``coords`` arrays stacked along a
        new leading axis.

    Raises:
        ValueError: if fewer than ``count`` instance files are available.
    """
    pattern = f'data/TSP{N}/instances/*.npz'
    # glob order is filesystem-dependent; sort so a seeded `random` always
    # picks the same files.
    files = sorted(glob.glob(pattern))
    if count > len(files):
        raise ValueError(
            f"requested {count} instances but only {len(files)} match {pattern!r}"
        )
    chosen = random.sample(files, count)
    data = []
    for f in chosen:
        # np.load returns a lazy NpzFile that keeps the archive open; close
        # it as soon as the array has been read.
        with np.load(f) as npz:
            arr = npz['coords']
        data.append(torch.from_numpy(arr).float())
    return torch.stack(data, 0)

# %% [markdown]
# ## Training with Curriculum

# Hyperparameters -----------------------------------------------------------
model_params = {
    'embedding_dim': 128,
    'encoder_layer_num': 3,
    'head_num': 8,
    'qkv_dim': 64,
    'ff_hidden_dim': 512,
    # Divisor applied to the policy logits before tanh clipping.
    'sqrt_embedding_dim': math.sqrt(128),
    'logit_clipping': 10.0,
}
env_params = {
    'problem_size': 20,
    'pomo_size': 10,
}
opt_params = {
    'lr': 1e-4,
    'betas': (0.9, 0.999),
    'weight_decay': 1e-4,
}
sched_params = {
    'milestones': [20000, 50000],
    'gamma': 0.5,
}
train_params = {
    'batch_size': 512,
    'log_interval': 1000,
}
# Each stage trains for `steps` iterations, drawing the problem size for
# every iteration from `sizes`.
curriculum = [
    {'sizes': [10, 20], 'steps': 10000},
    {'sizes': [20, 30, 40], 'steps': 30000},
    {'sizes': [30, 40, 50], 'steps': 60000},
]

# Environment, model, optimiser and LR schedule -----------------------------
env = TSPEnv(**env_params)
encoder = TSP_Encoder(**model_params).to(device)
opt = torch.optim.Adam(encoder.parameters(), **opt_params)
sched = torch.optim.lr_scheduler.MultiStepLR(opt, **sched_params)

# Bookkeeping for logging and plotting --------------------------------------
time_est = TimeEstimator()
steps = []
loss_log = []
dist_log = []
global_step = 0
start = time.time()

def policy_head(q, emb, mask):
    """Turn query/embedding compatibilities into a distribution over nodes.

    The dot products of ``q`` with ``emb`` are divided by
    ``model_params['sqrt_embedding_dim']``, squashed with ``tanh`` and
    multiplied by ``model_params['logit_clipping']``, then ``mask`` is
    added and a softmax over the last axis is returned.
    """
    scale = model_params['sqrt_embedding_dim']
    clip = model_params['logit_clipping']
    compat = torch.matmul(q, emb.transpose(1, 2)) / scale
    clipped = clip * torch.tanh(compat)
    return F.softmax(clipped + mask, dim=-1)

# Curriculum training: each stage runs `steps` policy-gradient updates with a
# problem size drawn per update from the stage's `sizes`, then evaluates the
# encoder greedily on stored 20-node instances.
for stage in curriculum:
    for _ in range(stage['steps']):
        global_step += 1
        sz = random.choice(stage['sizes'])
        env.problem_size = sz
        env.load_problems(train_params['batch_size'])
        state, _, _ = env.reset()
        emb = encoder(state.problems.to(device))
        B, P = train_params['batch_size'], env.pomo_size
        # Rollout p uses the embedding of node p as its initial query.
        first = torch.arange(P, device=device).unsqueeze(0).repeat(B, 1)
        q = emb.gather(1, first[:, :, None].expand(-1, -1, model_params['embedding_dim']))
        mask = torch.zeros((B, P, sz), device=device)
        logs, done = [], False
        while not done:
            probs = policy_head(q, emb, mask)
            # Sample one node per (instance, rollout) pair.
            idx = probs.reshape(B * P, -1).multinomial(1).view(B, P)
            # Log-probability of the sampled node, shape (B, P).
            logs.append(torch.log(probs.gather(2, idx[:, :, None]).squeeze(2)))
            state, reward, done = env.step(idx)
            mask = state.ninf_mask.to(device)
            q = emb.gather(1, state.current_node[:, :, None].expand(-1, -1, model_params['embedding_dim']))
        r = reward.to(device)
        # Baseline: mean reward over the rollouts of the same instance.
        adv = r - r.mean(1, keepdim=True)
        lp = torch.stack(logs, 2).sum(2)
        loss = -(adv * lp).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        sched.step()
        if global_step % train_params['log_interval'] == 0:
            # Best rollout per instance, averaged over the batch.
            avgd = -r.max(1)[0].mean().item()
            print(f"Step {global_step} | Sz {sz} | Loss {loss.item():.4f} | Dist {avgd:.4f} | Elap {time.time()-start:.1f}s")
            steps.append(global_step); loss_log.append(loss.item()); dist_log.append(avgd)
    # stage eval on 20
    eval_size = 20
    test = load_npz_instances(eval_size, 20).to(device)
    # The env still holds the last sampled training size; its mask width and
    # termination test must match the evaluation instances instead.
    env.problem_size = eval_size
    d = []
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for inst in test:
            env.problems = inst.unsqueeze(0)
            env.batch_size = 1
            env.BATCH_IDX = torch.zeros((1, env.pomo_size), dtype=torch.long)
            env.POMO_IDX = torch.arange(env.pomo_size)[None, :]
            state, _, _ = env.reset()
            emb = encoder(state.problems.to(device))
            B, P = 1, env.pomo_size
            first = torch.arange(P, device=device).unsqueeze(0)
            q = emb.gather(1, first[:, :, None].expand(-1, -1, model_params['embedding_dim']))
            mask = torch.zeros((B, P, eval_size), device=device)
            done = False
            while not done:
                probs = policy_head(q, emb, mask)
                # Greedy decoding: take the most likely node.
                idx = probs.argmax(dim=2)
                state, reward, done = env.step(idx)
                mask = state.ninf_mask.to(device)
                q = emb.gather(1, state.current_node[:, :, None].expand(-1, -1, model_params['embedding_dim']))
            # reward is (1, P), so .item() on it directly would raise; report
            # the shortest tour among the P rollouts, matching the training log.
            d.append((-reward).min().item())
    print(f"Stage eval avg dist: {sum(d)/len(d):.4f}")

# %% [markdown]
# ## Plot Learning Curve

# Logged best-rollout tour length against the training step.
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(steps, dist_log)
ax.set_title('Avg Tour Length')
ax.set_xlabel('Steps')
ax.set_ylabel('Distance')
plt.show()

# %% [markdown]
# ## Final Evaluation

# Greedy evaluation on 100 stored instances each of 20 and 50 nodes.
coords20 = load_npz_instances(20, 100).to(device)
coords50 = load_npz_instances(50, 100).to(device)
for N, data in [(20, coords20), (50, coords50)]:
    # The env's mask width and termination test depend on problem_size, so
    # it has to match the instances being evaluated.
    env.problem_size = N
    d = []
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for inst in data:
            env.problems = inst.unsqueeze(0)
            env.batch_size = 1
            env.BATCH_IDX = torch.zeros((1, env.pomo_size), dtype=torch.long)
            env.POMO_IDX = torch.arange(env.pomo_size)[None, :]
            state, _, _ = env.reset()
            emb = encoder(state.problems.to(device))
            B, P = 1, env.pomo_size
            # Rollout p uses the embedding of node p as its initial query.
            first = torch.arange(P, device=device).unsqueeze(0)
            q = emb.gather(1, first[:, :, None].expand(-1, -1, model_params['embedding_dim']))
            mask = torch.zeros((B, P, N), device=device)
            done = False
            while not done:
                probs = policy_head(q, emb, mask)
                # Greedy decoding: take the most likely node.
                idx = probs.argmax(dim=2)
                state, reward, done = env.step(idx)
                mask = state.ninf_mask.to(device)
                q = emb.gather(1, state.current_node[:, :, None].expand(-1, -1, model_params['embedding_dim']))
            # reward is (1, P), so .item() on it directly would raise; report
            # the shortest tour among the P rollouts.
            d.append((-reward).min().item())
    print(f"Final {N}-city avg dist: {sum(d)/len(d):.4f}")
