In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [None]:
import numpy as np
from torch.nn.functional import one_hot
from eda.TSP import TSP_Instance, TSP_Environment, TSP_State
from eda.solveTSP_v2 import solve

def generate_data(n_cities=50, nb_sample=512, seed=None):
    """Build a supervised TSP dataset of random instances and solver outputs.

    Cities are drawn uniformly in the unit square and each instance is passed
    to ``solve`` (imported from ``eda.solveTSP_v2``).

    Args:
        n_cities: number of cities per instance.
        nb_sample: number of instances to generate; must be >= 1.
        seed: optional seed for a private ``np.random.Generator``. When None,
            coordinates come from NumPy's global RNG, so ``np.random.seed``
            still controls them.

    Returns:
        X: float32 tensor of shape (nb_sample, n_cities, 2) with coordinates.
        Y: float32 tensor stacking ``solution.visited`` for each instance
           (presumably the solver's visiting order; TODO confirm its length
           against ``solve``).

    Raises:
        ValueError: if nb_sample < 1 (there would be nothing to stack).
    """
    if nb_sample < 1:
        raise ValueError(f"nb_sample must be >= 1, got {nb_sample}")
    # np.random.random draws from the same global stream as np.random.rand,
    # so the default path reproduces the previous coordinates exactly.
    rng = np.random if seed is None else np.random.default_rng(seed)
    X, Y = [], []
    for _ in range(nb_sample):
        city_points = rng.random((n_cities, 2))
        solution = solve(city_points)
        X.append(torch.from_numpy(city_points))
        Y.append(torch.tensor(solution.visited))
    return torch.stack(X).float(), torch.stack(Y).float()

np.random.seed(0)  # city coordinates are drawn from NumPy's global RNG; fix it for reproducible runs
X, Y = generate_data()
X.shape, Y.shape

In [None]:
class Encoder(nn.Module):
    """Stack of post-norm Transformer encoder layers over node embeddings.

    Each layer applies multi-head self-attention followed by a two-layer ReLU
    feed-forward network. Each sub-layer is wrapped in a residual connection
    and then normalized, using BatchNorm1d over the embedding channels or
    LayerNorm.

    All constructor arguments default to the values that used to be
    hard-coded, so ``Encoder()`` builds exactly the same network as before.

    Args:
        num_heads: number of attention heads.
        head_dim: width of each head; the embedding width is
            ``num_heads * head_dim``.
        ff_dim: hidden width of the feed-forward sub-layer.
        nb_layers: number of encoder layers.
        batchnorm: use BatchNorm1d when True, LayerNorm otherwise.
    """

    def __init__(self, num_heads=16, head_dim=8, ff_dim=256, nb_layers=6, batchnorm=True):
        super().__init__()
        self.input_dim = 6  # not used by this module; kept for attribute compatibility
        self.num_heads = num_heads  # number of heads in multi-head attention
        self.head_dim = head_dim  # width of each head
        self.node_dim = 2  # not used by this module; kept for attribute compatibility
        self.embd_dim = num_heads * head_dim
        self.ff_dim = ff_dim
        self.nb_layers = nb_layers
        self.batchnorm = batchnorm

        # Creation order matches the original so parameter initialisation
        # consumes the RNG identically and state_dict keys are unchanged.
        self.MHA_layers = nn.ModuleList([nn.MultiheadAttention(self.embd_dim, num_heads) for _ in range(nb_layers)])
        self.linear1_layers = nn.ModuleList([nn.Linear(self.embd_dim, ff_dim) for _ in range(nb_layers)])
        self.linear2_layers = nn.ModuleList([nn.Linear(ff_dim, self.embd_dim) for _ in range(nb_layers)])
        norm_cls = nn.BatchNorm1d if batchnorm else nn.LayerNorm
        self.norm1_layers = nn.ModuleList([norm_cls(self.embd_dim) for _ in range(nb_layers)])
        self.norm2_layers = nn.ModuleList([norm_cls(self.embd_dim) for _ in range(nb_layers)])

        # Not used in forward(); kept so previously saved state_dicts still load.
        self.norm = nn.BatchNorm1d(self.embd_dim)

    def _normalize(self, norm, h):
        """Apply ``norm`` to a sequence-first tensor h of size (nb_nodes, bsz, dim_emb)."""
        if self.batchnorm:
            # nn.BatchNorm1d expects (bsz, dim, seq_len)
            h = norm(h.permute(1, 2, 0).contiguous())
            return h.permute(2, 0, 1).contiguous()
        return norm(h)

    def forward(self, h):
        """Encode a batch of node embeddings.

        Args:
            h: node embeddings of size (bsz, nb_nodes, dim_emb).

        Returns:
            h: encoded embeddings of size (bsz, nb_nodes, dim_emb).
            score: head-averaged self-attention weights of the last layer,
                size (bsz, nb_nodes, nb_nodes); None when nb_layers == 0.
        """
        # nn.MultiheadAttention (batch_first=False) expects (seq_len, bsz, dim_emb)
        h = h.transpose(0, 1)
        score = None
        for i in range(self.nb_layers):
            attn_out, score = self.MHA_layers[i](h, h, h)
            h = self._normalize(self.norm1_layers[i], h + attn_out)  # residual + norm
            ff_out = self.linear2_layers[i](torch.relu(self.linear1_layers[i](h)))
            h = self._normalize(self.norm2_layers[i], h + ff_out)  # residual + norm
        return h.transpose(0, 1), score

def generate_positional_encoding(d_model, max_len):
    """
    Create standard sinusoidal transformer positional encodings.

    Inputs :
      d_model is a scalar corresponding to the hidden dimension (odd values allowed)
      max_len is the maximum length of the sequence
    Output :
      pe of size (max_len, d_model), where for position p and column pair index i
      pe[p, 2i]   = sin(p / 10000^(2i / d_model))
      pe[p, 2i+1] = cos(p / 10000^(2i / d_model))
    """
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
    angles = position * div_term  # (max_len, ceil(d_model / 2))
    pe[:, 0::2] = torch.sin(angles)
    # When d_model is odd there is one fewer odd column than even column.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
    return pe


class Decoder(nn.Module):
    """Stand-alone decoder stub.

    It only records hyper-parameters that mirror the other model classes in
    this notebook. It defines no sub-modules and no ``forward`` method yet.
    """

    def __init__(self):
        super().__init__()
        # Model hyper-parameters (16 heads x 8 dims per head = 128-wide embeddings).
        hyper_params = {
            "input_dim": 6,  # input dimension
            "num_heads": 16,  # heads in multi-head attention
            "head_dim": 8,  # width of each head
            "node_dim": 2,
            "ff_dim": 256,
            "nb_dec_layers": 2,
            "nb_nodes": 50,
        }
        for attr_name, attr_value in hyper_params.items():
            setattr(self, attr_name, attr_value)
        self.embd_dim = self.num_heads * self.head_dim

class CustomModel(nn.Module):
    """Encoder plus autoregressive attention decoder that outputs a TSP tour.

    A learned start placeholder is appended as an extra node after the cities,
    and the sequence is encoded by ``Encoder``. The tour is then decoded one
    city per step. At each step the current query attends over the encoded
    nodes through ``self.mha``, and the head-averaged attention weights serve
    as the probability of choosing each node next. Already-visited nodes and
    the placeholder are masked out.
    """

    def __init__(self):
        super().__init__()
        # Model hyper-parameters
        self.input_dim = 6  # not used in this class
        self.num_heads = 16  # number of heads in multi-head attention
        self.head_dim = 8  # width of each head
        self.node_dim = 2  # input features per city (matches input_emb)
        self.embd_dim = self.num_heads * self.head_dim
        self.ff_dim = 256
        self.nb_dec_layers = 2
        self.nb_nodes = 50  # default instance size; forward() reads the real size from x
        max_len_PE = 10000

        self.input_emb = nn.Linear(self.node_dim, self.embd_dim)
        self.ff = nn.Linear(self.embd_dim, 1)  # currently unused in forward()
        self.start_placeholder = nn.Parameter(torch.randn(self.embd_dim))

        self.enc = Encoder()

        # Currently unused in forward(); kept so parameter initialisation order
        # and state_dict keys stay unchanged.
        self.WKatt_dec = nn.Linear(self.embd_dim, self.nb_dec_layers * self.embd_dim)
        self.WVatt_dec = nn.Linear(self.embd_dim, self.nb_dec_layers * self.embd_dim)
        # Non-persistent buffer: moves with model.to(device) but stays out of
        # the state_dict, just like the former plain attribute.
        self.register_buffer("PE", generate_positional_encoding(self.embd_dim, max_len_PE), persistent=False)

        self.mha = nn.MultiheadAttention(self.embd_dim, self.num_heads)

    def forward(self, x, return_probabilities=False, deterministic=True):
        """Encode a batch of instances and decode one tour per instance.

        Args:
            x: city features of size (bsz, nb_nodes, node_dim).
            return_probabilities: if True, also return the per-step
                selection distributions.
            deterministic: pick the most probable node (argmax) when True,
                otherwise sample from the distribution.

        Returns:
            tours: LongTensor (bsz, nb_nodes) of city indices in visiting order.
            probs (only when return_probabilities is True): tensor of size
                (bsz, nb_nodes, nb_nodes + 1) holding the distribution used
                at each step. The last column is the start placeholder and
                is always 0.
        """
        bsz, nb_nodes = x.shape[0], x.shape[1]
        device = x.device
        batch_idx = torch.arange(bsz, device=device)

        h = self.input_emb(x)  # (bsz, nb_nodes, dim_emb)
        placeholder = self.start_placeholder.expand(bsz, 1, self.embd_dim)
        h = torch.cat([h, placeholder], dim=1)  # (bsz, nb_nodes + 1, dim_emb)
        h_enc, _ = self.enc(h)

        # nn.MultiheadAttention (batch_first=False) wants keys/values as (seq_len, bsz, dim_emb)
        memory = h_enc.transpose(0, 1)

        # True marks nodes that cannot be chosen; the placeholder is never a city.
        visited = torch.zeros(bsz, nb_nodes + 1, dtype=torch.bool, device=device)
        visited[:, nb_nodes] = True

        h_t = h_enc[:, nb_nodes, :] + self.PE[0]  # query starts at the placeholder
        tours, step_probs = [], []
        for t in range(nb_nodes):
            _, attn = self.mha(h_t.unsqueeze(0), memory, memory, key_padding_mask=visited)
            prob_next_node = attn.squeeze(1)  # (bsz, nb_nodes + 1); masked entries are 0
            if deterministic:
                idx = torch.argmax(prob_next_node, dim=1)
            else:
                idx = torch.distributions.Categorical(probs=prob_next_node).sample()
            tours.append(idx)
            step_probs.append(prob_next_node)

            visited = visited.clone()  # avoid mutating a mask already handed to mha
            visited[batch_idx, idx] = True
            h_t = h_enc[batch_idx, idx, :] + self.PE[t + 1]

        tours = torch.stack(tours, dim=1)
        if return_probabilities:
            return tours, torch.stack(step_probs, dim=1)
        return tours

In [None]:
torch.manual_seed(0)  # reproducible weight initialisation
model = CustomModel()
with torch.no_grad():  # shape smoke test only; no autograd graph needed
    out = model(X)
out.shape