### 网络模型

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class Attention(nn.Module):
    """Additive attention that serves as either a glimpse or a pointer.

    Every reference position ``j`` is scored with
    ``Vec . tanh(W_q(query) + W_ref(ref_j))``.

    * ``use_tanh=True``  -- glimpse mode: the scores are turned into softmax
      weights (temperature ``softmax_T``) and the weighted sum of the
      projected references is returned, shape ``(batch, hidden)``.
    * ``use_tanh=False`` -- pointer mode: the tanh activation is scaled by
      ``clip_logits`` and the masked scores themselves are returned, shape
      ``(batch, city_t)``.

    Args:
        cfg: config object; ``cfg.hidden`` is required. ``cfg.softmax_T``,
            ``cfg.init_min`` and ``cfg.init_max`` are used when present.
        use_tanh: selects glimpse mode (True) or pointer mode (False).
        clip_logits: scale applied to the tanh activation in pointer mode.
    """

    def __init__(self, cfg, use_tanh=False, clip_logits=10):
        super().__init__()
        self.use_tanh = use_tanh
        self.W_q = nn.Linear(cfg.hidden, cfg.hidden, bias=True)
        self.W_ref = nn.Conv1d(cfg.hidden, cfg.hidden, 1, 1)
        self.clip_logits = clip_logits
        # forward() reads softmax_T and Vec; both were previously never
        # created, so any call raised AttributeError.
        self.softmax_T = getattr(cfg, 'softmax_T', 1.0)
        self.Vec = nn.Parameter(torch.empty(cfg.hidden))
        nn.init.uniform_(
            self.Vec,
            getattr(cfg, 'init_min', -0.08),
            getattr(cfg, 'init_max', 0.08),
        )

    def forward(self, query, ref, mask=None, inf=1e8):
        """Attend over ``ref`` with ``query``.

        Args:
            query: ``(batch, hidden)`` tensor.
            ref: ``(batch, city_t, hidden)`` tensor.
            mask: optional ``(batch, city_t)`` tensor; positions holding 1
                have ``inf`` subtracted from their score. ``None`` disables
                masking.
            inf: magnitude subtracted from masked scores.

        Returns:
            ``(batch, hidden)`` in glimpse mode, ``(batch, city_t)`` scores
            in pointer mode.
        """
        # (batch, hidden, 1); broadcasting against u2 replaces the explicit
        # repeat over city_t.
        u1 = self.W_q(query).unsqueeze(-1)
        # Conv1d expects channels first: (batch, hidden, city_t).
        u2 = self.W_ref(ref.permute(0, 2, 1))
        activation = torch.tanh(u1 + u2)
        if not self.use_tanh:
            activation = self.clip_logits * activation
        # (hidden,) @ (batch, hidden, city_t) -> (batch, city_t)
        u = torch.matmul(self.Vec, activation)
        if mask is not None:
            u = u - inf * mask
        if not self.use_tanh:
            return u
        a = F.softmax(u / self.softmax_T, dim=1)
        # (batch, hidden, city_t) @ (batch, city_t, 1) -> (batch, hidden)
        return torch.bmm(u2, a.unsqueeze(2)).squeeze(2)


class Greedy(nn.Module):
    """Deterministic selector: returns the highest-scoring column per row."""

    def forward(self, log_p):
        """Return the argmax over dim 1 of ``log_p`` as int64 indices."""
        best_index = log_p.argmax(dim=1)
        return best_index.long()

class Categorical(nn.Module):
    """Stochastic selector: samples one column per row from ``exp(log_p)``."""

    def forward(self, log_p):
        """Draw one index per row with probability proportional to ``exp(log_p)``."""
        probabilities = torch.exp(log_p)
        drawn = torch.multinomial(probabilities, num_samples=1)
        return drawn.long().squeeze(1)

class Ptr_Actor(nn.Module):
    """Pointer-network actor that decodes a tour over the input cities.

    The cities are embedded and encoded by an LSTM. A second LSTM decodes
    one city per step: its hidden state is refined by ``n_glimpse`` glimpse
    attentions and then scored against the encoder outputs by the pointer
    attention. Cities already chosen are masked out.

    Args:
        cfg: config object providing ``embed``, ``hidden``, ``init_min``,
            ``init_max``, ``clip_logits``, ``softmax_T``, ``n_glimpse`` and
            ``decode_type`` (``'greedy'`` or ``'sampling'``).
    """

    def __init__(self, cfg):
        super().__init__()
        self.Embedding = nn.Linear(2, cfg.embed, bias = False)

        # LSTM input:  input, (h_0, c_0) with input  (batch, seq, feature)
        # LSTM output: output, (h_n, c_n) with output (batch, seq, D*H_out)
        self.Encoder = nn.LSTM(input_size = cfg.embed, hidden_size = cfg.hidden, batch_first = True)
        self.Decoder = nn.LSTM(input_size = cfg.embed, hidden_size = cfg.hidden, batch_first = True)

        # Legacy attention parameters. forward() does not use them (the
        # Attention sub-modules own their weights); they are kept so the
        # module's attributes and state_dict keys stay unchanged. They are
        # allocated device-neutrally and filled by _initialize_weights;
        # move the model with .to(device) as usual.
        self.Vec = nn.Parameter(torch.empty(cfg.embed))
        self.Vec2 = nn.Parameter(torch.empty(cfg.embed))
        self.W_q = nn.Linear(cfg.hidden, cfg.hidden, bias = True)
        self.W_ref = nn.Conv1d(cfg.hidden, cfg.hidden, 1, 1)
        self.W_q2 = nn.Linear(cfg.hidden, cfg.hidden, bias = True)
        self.W_ref2 = nn.Conv1d(cfg.hidden, cfg.hidden, 1, 1)

        # Learned input for the first decoding step.
        self.dec_input = nn.Parameter(torch.empty(cfg.embed))
        self.clip_logits = cfg.clip_logits
        self.softmax_T = cfg.softmax_T
        self.n_glimpse = cfg.n_glimpse
        # None for an unknown decode_type; forward() reports that clearly.
        self.city_selecter = {'greedy': Greedy(), 'sampling': Categorical()}.get(cfg.decode_type, None)
        # Attention's use_tanh=True branch returns the (batch, hidden)
        # glimpse vector; use_tanh=False returns (batch, city_t) logits.
        # The flags were previously swapped between the two modules.
        self.pointer = Attention(cfg, use_tanh=False, clip_logits=cfg.clip_logits)
        self.glimpse = Attention(cfg, use_tanh=True, clip_logits=cfg.clip_logits)
        # Run last so every parameter, including those of the attention
        # sub-modules, receives the uniform initialisation.
        self._initialize_weights(cfg.init_min, cfg.init_max)

    def _initialize_weights(self, init_min = -0.08, init_max = 0.08):
        """Fill every parameter with values drawn from U(init_min, init_max)."""
        for param in self.parameters():
            nn.init.uniform_(param.data, init_min, init_max)

    def forward(self, x, device):
        '''Decode one tour per batch element.

            x: (batch, city_t, 2)
            enc_h: (batch, city_t, hidden)
            dec_input: (batch, 1, embed)
            h: (1, batch, hidden)
            return: pi: (batch, city_t), ll: (batch)

            Raises ValueError when cfg.decode_type was not recognised.
        '''
        if self.city_selecter is None:
            raise ValueError("decode_type must be 'greedy' or 'sampling'")
        x = x.to(device)
        batch, city_t, _ = x.size()
        embed_enc_inputs = self.Embedding(x)
        embed = embed_enc_inputs.size(2)
        mask = torch.zeros((batch, city_t), device = device)
        enc_h, (h, c) = self.Encoder(embed_enc_inputs, None)
        ref = enc_h
        pi_list, log_ps = [], []
        dec_input = self.dec_input.unsqueeze(0).repeat(batch, 1).unsqueeze(1).to(device)
        for _step in range(city_t):
            _, (h, c) = self.Decoder(dec_input, (h, c))
            query = h.squeeze(0)
            for _glimpse_round in range(self.n_glimpse):
                query = self.glimpse(query, ref, mask)
            logits = self.pointer(query, ref, mask)
            log_p = torch.log_softmax(logits, dim = -1)
            next_node = self.city_selecter(log_p)
            # The embedding of the chosen city is the next decoder input.
            gather_index = next_node.view(batch, 1, 1).expand(-1, -1, embed)
            dec_input = torch.gather(input = embed_enc_inputs, dim = 1, index = gather_index)

            pi_list.append(next_node)
            log_ps.append(log_p)
            # Out-of-place update: mark the chosen city as visited.
            mask = mask.scatter(1, next_node.unsqueeze(1), 1.0)

        pi = torch.stack(pi_list, dim = 1)
        ll = self.get_log_likelihood(torch.stack(log_ps, 1), pi)
        return pi, ll

    def get_log_likelihood(self, _log_p, pi):
        '''Sum the log-probabilities of the chosen cities along each tour.

            _log_p: (batch, city_t, city_t), per-step log-probabilities
            pi: (batch, city_t), chosen city index at each step
            return: (batch)
        '''
        chosen = torch.gather(input = _log_p, dim = 2, index = pi.unsqueeze(-1))
        return chosen.squeeze(-1).sum(dim = 1)
    
    # def glimpse(self, query, ref, mask, inf = 1e8):
    # 	"""	-ref about torch.bmm, torch.matmul and so on
    # 		https://qiita.com/tand826/items/9e1b6a4de785097fe6a5
    # 		https://qiita.com/shinochin/items/aa420e50d847453cc296
            
    # 			Args: 
    # 		query: the hidden state of the decoder at the current
    # 		(batch, 128)
    # 		ref: the set of hidden states from the encoder. 
    # 		(batch, city_t, 128)
    # 		mask: model only points at cities that have yet to be visited, so prevent them from being reselected
    # 		(batch, city_t)
    # 	"""
    # 	u1 = self.W_q(query).unsqueeze(-1).repeat(1,1,ref.size(1))# u1: (batch, 128, city_t)
    # 	u2 = self.W_ref(ref.permute(0,2,1))# u2: (batch, 128, city_t)
    # 	V = self.Vec.unsqueeze(0).unsqueeze(0).repeat(ref.size(0), 1, 1)
    # 	u = torch.bmm(V, torch.tanh(u1 + u2)).squeeze(1)
    # 	# V: (batch, 1, 128) * u1+u2: (batch, 128, city_t) => u: (batch, 1, city_t) => (batch, city_t)
    # 	u = u - inf * mask
    # 	a = F.softmax(u / self.softmax_T, dim = 1)
    # 	d = torch.bmm(u2, a.unsqueeze(2)).squeeze(2)
    # 	# u2: (batch, 128, city_t) * a: (batch, city_t, 1) => d: (batch, 128)
    # 	return d

    # def pointer(self, query, ref, mask, inf = 1e8):
    # 	"""	Args: 
    # 		query: the hidden state of the decoder at the current
    # 		(batch, 128)
    # 		ref: the set of hidden states from the encoder. 
    # 		(batch, city_t, 128)
    # 		mask: model only points at cities that have yet to be visited, so prevent them from being reselected
    # 		(batch, city_t)
    # 	"""
    # 	u1 = self.W_q2(query).unsqueeze(-1).repeat(1,1,ref.size(1))# u1: (batch, 128, city_t)
    # 	u2 = self.W_ref2(ref.permute(0,2,1))# u2: (batch, 128, city_t)
    # 	V = self.Vec2.unsqueeze(0).unsqueeze(0).repeat(ref.size(0), 1, 1)
    # 	u = torch.bmm(V, self.clip_logits * torch.tanh(u1 + u2)).squeeze(1)
    # 	# V: (batch, 1, 128) * u1+u2: (batch, 128, city_t) => u: (batch, 1, city_t) => (batch, city_t)
    # 	u = u - inf * mask
    # 	return u
    
class Ptr_Critic(nn.Module):
    """Critic that predicts one scalar value per input instance.

    The cities are embedded and encoded by an LSTM. The final encoder hidden
    state is refined by ``n_process * n_glimpse`` glimpse attentions over the
    encoder outputs and mapped to a scalar by a two-layer MLP.

    Args:
        cfg: config object providing ``embed``, ``hidden``, ``init_min``,
            ``init_max``, ``n_glimpse`` and ``n_process``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.Embedding = nn.Linear(2, cfg.embed, bias = False)
        self.Encoder = nn.LSTM(input_size = cfg.embed, hidden_size = cfg.hidden, batch_first = True)
        # Decoder, Vec, W_q and W_ref are not used by forward(); they are
        # kept so attributes and state_dict keys stay unchanged. Vec is
        # allocated device-neutrally and filled by _initialize_weights.
        self.Decoder = nn.LSTM(input_size = cfg.embed, hidden_size = cfg.hidden, batch_first = True)
        self.Vec = nn.Parameter(torch.empty(cfg.embed))
        self.W_q = nn.Linear(cfg.hidden, cfg.hidden, bias = True)
        self.W_ref = nn.Conv1d(cfg.hidden, cfg.hidden, 1, 1)
        self.final2FC = nn.Sequential(
                    nn.Linear(cfg.hidden, cfg.hidden, bias = False),
                    nn.ReLU(inplace = False),
                    nn.Linear(cfg.hidden, 1, bias = False))
        self.n_glimpse = cfg.n_glimpse
        self.n_process = cfg.n_process
        # use_tanh=True selects Attention's glimpse branch, which returns a
        # (batch, hidden) vector; use_tanh=False would return logits.
        self.glimpse = Attention(cfg, use_tanh=True, clip_logits=10)
        # Run last so the glimpse parameters are initialised as well.
        self._initialize_weights(cfg.init_min, cfg.init_max)

    def _initialize_weights(self, init_min = -0.08, init_max = 0.08):
        """Fill every parameter with values drawn from U(init_min, init_max)."""
        for param in self.parameters():
            nn.init.uniform_(param.data, init_min, init_max)

    def forward(self, x, device):
        '''Predict one value per batch element.

            x: (batch, city_t, 2)
            enc_h: (batch, city_t, hidden)
            query: (batch, hidden)
            h: (1, batch, hidden)
            return: pred_l: (batch)
        '''
        x = x.to(device)
        batch, city_t, _ = x.size()
        embed_enc_inputs = self.Embedding(x)
        enc_h, (h, c) = self.Encoder(embed_enc_inputs, None)
        ref = enc_h
        query = h[-1]
        # The critic attends over every city, so the mask is all zeros.
        # Passing it explicitly matches Attention.forward(query, ref, mask).
        no_mask = torch.zeros((batch, city_t), device = device)
        for _process_step in range(self.n_process):
            for _glimpse_round in range(self.n_glimpse):
                query = self.glimpse(query, ref, no_mask)
        # Author's notes, kept from the original inline string:
        # - page 5/15 in paper: critic architecture details are under
        #   "Critic's architecture for TSP".
        # - page 14/15 in paper: glimpsing more than once with the same
        #   parameters made the model less likely to learn and barely
        #   improved the results.
        #
        # query (batch, hidden) -> FC(hidden, hidden) -> FC(hidden, 1)
        #   -> (batch, 1) -> pred_l (batch)
        # A single squeeze keeps the batch dimension even when batch == 1.
        pred_l = self.final2FC(query).squeeze(-1)
        return pred_l
    
    # def glimpse(self, query, ref, infinity = 1e8):
    # 	"""	Args: 
    # 		query: the hidden state of the decoder at the current
    # 		(batch, 128)
    # 		ref: the set of hidden states from the encoder. 
    # 		(batch, city_t, 128)
    # 	"""
    # 	u1 = self.W_q(query).unsqueeze(-1).repeat(1,1,ref.size(1))# u1: (batch, 128, city_t)
    # 	u2 = self.W_ref(ref.permute(0,2,1))# u2: (batch, 128, city_t)
    # 	V = self.Vec.unsqueeze(0).unsqueeze(0).repeat(ref.size(0), 1, 1)
    # 	u = torch.bmm(V, torch.tanh(u1 + u2)).squeeze(1)
    # 	# V: (batch, 1, 128) * u1+u2: (batch, 128, city_t) => u: (batch, 1, city_t) => (batch, city_t)
    # 	a = F.softmax(u, dim = 1)
    # 	d = torch.bmm(u2, a.unsqueeze(2)).squeeze(2)
    # 	# u2: (batch, 128, city_t) * a: (batch, city_t, 1) => d: (batch, 128)
    # 	return d


class DRL(nn.Module):
    """Container bundling the pointer-network actor and critic.

    Args:
        cfg: config object forwarded unchanged to ``Ptr_Actor`` and
            ``Ptr_Critic``.
    """

    def __init__(self, cfg):
        super().__init__()
        # Both networks require cfg; constructing them without it raised
        # TypeError.
        self.actor = Ptr_Actor(cfg)
        self.critic = Ptr_Critic(cfg)

    def forward(self, inputs):
        """Placeholder: not implemented yet, returns None."""
        pass

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import argparse

class MPSEnv:
    """Environment placeholder; defines no attributes or methods yet."""

class A2C:
    """Actor-critic agent wrapping ``Ptr_Actor`` and ``Ptr_Critic``.

    Args:
        cfg: config object forwarded to both networks. When it carries a
            truthy ``no_cuda`` attribute the agent stays on the CPU even if
            CUDA is available.
    """

    def __init__(self, cfg) -> None:
        # The networks need cfg and the script constructs A2C(cfg); the old
        # zero-argument constructor could not build either network.
        use_cuda = torch.cuda.is_available() and not getattr(cfg, "no_cuda", False)
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.actor = Ptr_Actor(cfg).to(self.device)
        self.critic = Ptr_Critic(cfg).to(self.device)

    def get_action(self, state):
        """Run the actor on ``state``.

        Ptr_Actor.forward performs the sampling itself and returns the
        chosen indices together with their summed log-probability, so no
        separate distribution object is needed. (The name ``Categorical``
        in this file is an nn.Module sampler, not a distribution.)

        Args:
            state: tensor accepted by ``Ptr_Actor.forward``
                (documented there as ``(batch, city_t, 2)``).

        Returns:
            ``(action, logp_action)`` as returned by the actor.
        """
        action, logp_action = self.actor(state, self.device)
        return action, logp_action

    def compute_value_loss(self, bs, blogp_a, br, bd, bns):
        """Mean-squared error between the critic's prediction for ``bs`` and ``br``.

        ``blogp_a``, ``bd`` and ``bns`` are accepted for a uniform signature
        but are not used.
        """
        target_value = br.detach()
        predicted = self.critic(bs, self.device).squeeze()
        value_loss = F.mse_loss(predicted, target_value)
        return value_loss


    def compute_policy_loss(self, bs, blogp_a, br, bd, bns):
        """Policy-gradient loss: ``-log_prob * advantage`` accumulated over ``blogp_a``.

        The advantage ``br - critic(bs)`` is computed without gradient so
        this loss only updates the actor. ``bd`` and ``bns`` are not used.
        """
        with torch.no_grad():
            target_value = br
            adv = target_value - self.critic(bs, self.device).squeeze()
        policy_loss = 0
        for i, logp_a in enumerate(blogp_a):
            policy_loss += -logp_a * adv[i]
        policy_loss = policy_loss.mean()
        return policy_loss

    def update(self):
        """Placeholder hook called after each optimisation step; does nothing."""
        pass

class INFO:
    """Placeholder for run statistics; defines no attributes or methods yet."""

class Rollout:
    """Placeholder for a trajectory buffer; defines no attributes or methods yet."""
    
def train(cfg, env, agent):
    """Interact with ``env`` for ``cfg.max_steps`` steps, updating at episode end.

    Args:
        cfg: config object; ``cfg.max_steps`` is read.
        env: object with ``reset() -> (state, info)`` and
            ``step(action) -> (next_state, reward, terminated, truncated, info)``.
        agent: object with ``get_action``, ``compute_value_loss``,
            ``compute_policy_loss`` and ``update``. Its value network is
            read from ``agent.V`` if present, otherwise ``agent.critic``;
            its policy network from ``agent.pi`` if present, otherwise
            ``agent.actor``.
    """
    # A2C exposes its networks as `critic` / `actor`; the `V` / `pi` names
    # are still honoured for agents that define them.
    value_net = agent.V if hasattr(agent, "V") else agent.critic
    policy_net = agent.pi if hasattr(agent, "pi") else agent.actor
    V_optimizer = torch.optim.Adam(value_net.parameters(), lr=3e-3)
    pi_optimizer = torch.optim.Adam(policy_net.parameters(), lr=3e-3)
    info = INFO()

    # NOTE(review): nothing in this loop stores transitions in `rollout`,
    # and `reward` is never recorded — confirm where the buffer should be
    # filled before relying on rollout.tensor().
    rollout = Rollout()
    state, _ = env.reset()
    for step in range(cfg.max_steps):
        action, logp_action = agent.get_action(torch.as_tensor(state, dtype=torch.float32))
        # NOTE(review): .item() only works for a single-element action.
        next_state, reward, terminated, truncated, _ = env.step(action.item())
        done = terminated or truncated

        state = next_state

        # Truthiness test: `done is True` is False for numpy.bool_ flags,
        # which would silently skip every update.
        if done:
            bs, ba, blogp_a, br, bd, bns = rollout.tensor()
            value_loss = agent.compute_value_loss(bs, blogp_a, br, bd, bns)
            V_optimizer.zero_grad()
            value_loss.backward(retain_graph=True)
            V_optimizer.step()

            policy_loss = agent.compute_policy_loss(bs, blogp_a, br, bd, bns)
            pi_optimizer.zero_grad()
            policy_loss.backward()
            pi_optimizer.step()

            agent.update()
            state, _ = env.reset()

        if step % 1000 == 0:
            print()




if __name__ == "__main__":
    # Script entry point: parse the command line, build the environment and
    # the agent, then run training.
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", default="CartPole-v1", type=str, help="Environment name.")
    parser.add_argument("--dim_state", default=4, type=int, help="Dimension of state.")
    parser.add_argument("--num_action", default=2, type=int, help="Number of action.")
    parser.add_argument("--output_dir", default="output", type=str, help="Output directory.")
    parser.add_argument("--seed", default=42, type=int, help="Random seed.")

    parser.add_argument("--max_steps", default=100_000, type=int, help="Maximum steps for interaction.")
    parser.add_argument("--discount", default=0.99, type=float, help="Discount coefficient.")
    parser.add_argument("--lr", default=1e-3, type=float, help="Learning rate.")
    parser.add_argument("--batch_size", default=32, type=int, help="Batch size.")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")

    parser.add_argument("--do_train", action="store_true", help="Train policy.")
    parser.add_argument("--do_eval", action="store_true", help="Evaluate policy.")

    # Hyper-parameters read by Ptr_Actor / Ptr_Critic / Attention. Without
    # them the networks fail with AttributeError on cfg.
    parser.add_argument("--embed", default=128, type=int, help="Embedding size.")
    parser.add_argument("--hidden", default=128, type=int, help="LSTM hidden size.")
    parser.add_argument("--init_min", default=-0.08, type=float, help="Lower bound of uniform weight init.")
    parser.add_argument("--init_max", default=0.08, type=float, help="Upper bound of uniform weight init.")
    parser.add_argument("--clip_logits", default=10, type=int, help="Scale applied to pointer logits.")
    parser.add_argument("--softmax_T", default=1.0, type=float, help="Softmax temperature of the glimpse.")
    parser.add_argument("--n_glimpse", default=1, type=int, help="Number of glimpses per decoding step.")
    parser.add_argument("--n_process", default=3, type=int, help="Number of critic process steps.")
    parser.add_argument("--decode_type", default="sampling", choices=["greedy", "sampling"], type=str, help="How the actor picks the next city.")
    cfg = parser.parse_args()
    env = MPSEnv()
    agent = A2C(cfg)
    train(cfg, env, agent)