In [1]:
import torch as t
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from timeit import timeit

In [2]:
DEVICE = t.device("cuda")
FLOAT_TYPE = t.float32
INT_TYPE = t.int32

In [10]:
class ConvexHullDataset(Dataset):
    def __init__(self, txt_file: str, n_points: int, device: t.device, float_dtype: t.dtype, int_dtype: t.dtype, cap: int | None = None):
        self.n_points = n_points

        with open(txt_file, 'r') as f:
            raw_data = f.readlines()
            self.data_x = t.zeros(len(raw_data), n_points, 2, dtype=float_dtype, device="cpu")
            self.data_y = t.zeros(len(raw_data), n_points + 1, dtype=int_dtype, device="cpu") # loops back around

            for i, line in enumerate(raw_data):
                if cap is not None and i >= cap:
                    self.data_x = self.data_x[:i]
                    self.data_y = self.data_y[:i]
                    break

                if (i + 1) % 10000 == 0:
                    print(f"Loaded data: {i + 1}/{len(raw_data)}; cap: {cap}")

                index = 0
                x_mode = True
                for token in line.split(" "):
                    if token == "\n":
                        continue

                    if token == "output":
                        x_mode = False
                        index = 0
                        continue
                    
                    if x_mode:
                        self.data_x[i, index // 2, index % 2] = float(token)
                        index += 1
                    else:
                        self.data_y[i, index] = float(token)
                        index += 1

            print("Copying x to device")
            self.data_x = self.data_x.to(device)
            print("Copying y to device")
            self.data_y = self.data_y.to(device)

    def __len__(self):
        return len(self.data_x)

    def __getitem__(self, idx):
        return self.data_x[idx], self.data_y[idx].long() # NO LONGER - 1 # -1 to make it 0-indexed

In [20]:
convex_hull_dataset = ConvexHullDataset("../data/ptr_nets/tsp10.txt", 10, DEVICE, FLOAT_TYPE, INT_TYPE, cap=100000)

Loaded data: 10000/1000000; cap: 100000
Loaded data: 20000/1000000; cap: 100000
Loaded data: 30000/1000000; cap: 100000
Loaded data: 40000/1000000; cap: 100000
Loaded data: 50000/1000000; cap: 100000
Loaded data: 60000/1000000; cap: 100000
Loaded data: 70000/1000000; cap: 100000
Loaded data: 80000/1000000; cap: 100000
Loaded data: 90000/1000000; cap: 100000
Loaded data: 100000/1000000; cap: 100000
Copying x to device
Copying y to device


In [43]:
class PtrNet2(nn.Module):
    def __init__(self, in_dim: int, m_out: int, dtype: t.dtype, hidden_dim: int = 512) -> None:
        super().__init__()

        self.in_dim = in_dim # input embedding size
        self.n_out = m_out   # number of output indices to return
        self.dtype = dtype

        self.hidden_dim = hidden_dim

        self.encoder = nn.LSTM(in_dim, hidden_dim, 1, dtype=dtype)
        self.decoder = nn.LSTM(in_dim, hidden_dim, 1, dtype=dtype)

        self.starting_token = t.ones(in_dim, dtype=dtype) * -1.0
        self.ending_token = t.ones(hidden_dim, dtype=dtype)

        for i in range(hidden_dim):
            self.ending_token[i] = - ((i + 1) % 2)

        self.W1 = nn.Parameter(t.randn(hidden_dim, hidden_dim, dtype=dtype))
        self.W2 = nn.Parameter(t.randn(hidden_dim, hidden_dim, dtype=dtype))
        self.v  = nn.Parameter(t.randn(1, hidden_dim, dtype=dtype))


    def forward(self, x: t.Tensor, y: t.Tensor | None = None):
        """
        x is a batch_size x in_len x in_dim tensor
        y is a batch_size x n_out tensor of indices. The indices are one-indexed:
          used for during training as teacher forcing. ignored during validation
          if y[n, i] = k, this means for the nth batch, the ith output vector is to be equal to the (k)th input vector
          the 0th index is reserved for the ending token

        returns: a tuple (A, (encoder_all_hiddens, decoder_all_hiddens)) where:
            A is a batch_size x (n_out + 1) x (in_len + 1) matrix; for each batch, we get (n_out + 1) probability pointers to the input sequence. The pointers are effectively 1-indexed: the 0th index refers to a pointer to the ending token, which is a signal to stop. The 1st index refers to a pointer to the 0th position in the input sequence, and so on.
            encoder_all_hiddens is a batch_size x in_len + 1 x hidden_dim matrix; the first position is the ending token
            decoder_all_hiddens is a batch_size x n_out + 1 x hidden_dim matrix; the first position is first output token, viz. the hidden state AFTER the starting token; the (i + 1)th position is the hidden state AFTER the ith output token; the (n_out + 1)th position is the hidden state AFTER the final output token, which should be used to generate a pointer to the start token
        """

        # force batched
        if x.dim() == 2:
            x = x.unsqueeze(0)
        assert x.dim() == 3
        batch_size = len(x)
        batch_size_arange = t.arange(batch_size)
        in_len = x.shape[1]

        if self.training:
            assert y is not None
            assert x.shape[0] == y.shape[0]
            assert y.shape[1] == self.n_out

        
        assert x.shape[2] == self.in_dim

        x = x.permute(1, 0, 2) # TODO: change it so that input is already permuted

        # define L
        # define encoder_all_hiddens, an N * L + 1 * embedding_dim matrix
        encoder_all_hiddens = t.zeros(in_len + 1, batch_size, self.hidden_dim, dtype=self.dtype, device=self.device)
        encoder_first_hidden = self.ending_token.repeat(1, batch_size, 1) # NOTE: FIRST position in L dimension is the ending token. Thus, a pointer to the FIRST (0)th position is a signal to stop
        encoder_all_hiddens[0:1, :, :] = encoder_first_hidden

        # encode
        encoder_all_hiddens[1:, :, :], (encoder_last_hidden, encoder_last_cell_state) = self.encoder(x, (encoder_first_hidden, t.zeros(1, batch_size, self.hidden_dim, dtype=self.dtype, device=self.device)))

        # initialize decode variables
        decoder_all_hiddens = t.zeros(self.n_out + 1, batch_size, self.hidden_dim, dtype=self.dtype, device=self.device)
        decoder_current_cell_state = encoder_last_cell_state

        # decode
        if self.training:
            decoder_input = t.empty(self.n_out + 1, batch_size, self.in_dim, dtype=self.dtype, device=self.device)
            decoder_input[0:1, :, :] = self.starting_token.repeat(1, batch_size, 1)
            decoder_input[1:, :, :] = x[y.t() - 1, batch_size_arange.view(1, batch_size), :]

            decoder_all_hiddens[:, :, :], _ = self.decoder(decoder_input, (encoder_last_hidden, decoder_current_cell_state))

            lin_enc = self.W1 @ encoder_all_hiddens.permute(1, 2, 0)
            lin_dec = self.W2 @ decoder_all_hiddens.permute(1, 2, 0)

            TANH = (lin_enc.unsqueeze(3) + lin_dec.unsqueeze(2)).tanh()
            U = self.v @ TANH.permute(0, 3, 1, 2) # batch_size x (n_out + 1) x 1 x in_len + 1

            # A = U.squeeze(dim=2)
            A = U.softmax(dim=3).squeeze(dim=2) # batch_size x (n_out + 1) x in_len + 1

            return A, (encoder_all_hiddens.permute(1, 0, 2), decoder_all_hiddens.permute(1, 0, 2))
                      # returns a batch_size x in_len + 1 x hidden_dim matrix,
                      # a batch_size x n_out + 1 x hidden_dim matrix,
                      # and a batch_size x (n_out + 1) x in_len + 1 matrix
        else:
            decoder_current_input  = self.starting_token.repeat(1, batch_size, 1)
            decoder_current_hidden = encoder_last_hidden

            
            A = t.empty(self.n_out + 1, batch_size, in_len + 1, dtype=self.dtype, device=self.device)

            for i in range(self.n_out + 1):
                _, (decoder_current_hidden, decoder_current_cell_state) = self.decoder(decoder_current_input, (decoder_current_hidden, decoder_current_cell_state))
                TANH_i = (self.W1 @ encoder_all_hiddens.permute(1, 2, 0) + self.W2 @ decoder_current_hidden.permute(1, 2, 0)).tanh()
                U_i    = self.v @ TANH_i

                # A_i = U_i.squeeze(dim=1) # no softmaxing here

                A_i    = U_i.softmax(dim=2).squeeze(dim=1) # gives an batch_size x in_len matrix where m[n] is a probability distribution over in_length input indices
                
                next_index = A_i.argmax(dim=1) # TODO: handle when next_index[n][0] is 0 for some n, meaning that for the nth batch, the decoder should stop. TODO: add beam search functionality
                A[i, :, :] = A_i
                decoder_current_input = x[next_index - 1, batch_size_arange].unsqueeze(0)
                decoder_all_hiddens[i, :, :] = decoder_current_hidden.squeeze(dim=0)
            
            return A.permute(1, 0, 2), (encoder_all_hiddens.permute(1, 0, 2), decoder_all_hiddens.permute(1, 0, 2))
        
    
    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.device = args[0]
        self.starting_token = self.starting_token.to(self.device)
        self.ending_token = self.ending_token.to(self.device)

        return self

In [45]:
model_one = PtrNet2(2, 11, FLOAT_TYPE).to(DEVICE)
optim_one = t.optim.Adam(model_one.parameters(), lr=0.001)

In [46]:
dataloader = DataLoader(convex_hull_dataset, batch_size=512, shuffle=True)

In [47]:
def train(model: PtrNet2, d: DataLoader, optim: t.optim.Optimizer, n_epochs: int = 100):
    # NOTE: assuming all inputs are of same sequence length model.n_out
    n_out = model.n_out
    n_batches = len(d)

    model.train()

    ce = t.nn.CrossEntropyLoss()
    for epoch_number in range(n_epochs):
        for batch_number, (x_train, y_train) in enumerate(d):
            optim.zero_grad()

            probabilities, _ = model(x_train, y_train)

            targets = t.zeros((y_train.shape[0], n_out + 1), dtype=t.long, device=x_train.device)
            targets[:,:-1] = y_train
            targets[:,-1] = 0

            loss = ce(probabilities.permute(0, 2, 1), targets)

            if batch_number % 100 == 0:
                print(f"Epoch {epoch_number + 1}/{n_epochs} Batch {batch_number + 1}/{n_batches} loss {loss.item()} accuracy: {(probabilities.argmax(dim=2)[:,1:-1] == targets[:,1:-1]).sum() / targets[:,1:-1].numel()}")
            
            loss.backward()
            optim.step()


In [48]:
train(model_one, dataloader, optim_one, n_epochs=10)

Epoch 1/10 Batch 1/196 loss 2.434884786605835 accuracy: 0.013671875
Epoch 1/10 Batch 101/196 loss 2.3484256267547607 accuracy: 0.21894530951976776
Epoch 2/10 Batch 1/196 loss 2.305101156234741 accuracy: 0.2701171934604645
Epoch 2/10 Batch 101/196 loss 2.267939805984497 accuracy: 0.32109376788139343
Epoch 3/10 Batch 1/196 loss 2.2464091777801514 accuracy: 0.35039064288139343
Epoch 3/10 Batch 101/196 loss 2.2127792835235596 accuracy: 0.3919921815395355
Epoch 4/10 Batch 1/196 loss 2.1680126190185547 accuracy: 0.4476562440395355
Epoch 4/10 Batch 101/196 loss 2.153106451034546 accuracy: 0.46699219942092896
Epoch 5/10 Batch 1/196 loss 2.125591278076172 accuracy: 0.504687488079071
Epoch 5/10 Batch 101/196 loss 2.0890491008758545 accuracy: 0.553906261920929
Epoch 6/10 Batch 1/196 loss 2.076425790786743 accuracy: 0.5654296875
Epoch 6/10 Batch 101/196 loss 2.0662741661071777 accuracy: 0.576171875
Epoch 7/10 Batch 1/196 loss 2.0463850498199463 accuracy: 0.5970703363418579
Epoch 7/10 Batch 101/196

In [49]:
x_test, y_test = convex_hull_dataset[0:3]

In [50]:
x_test.shape, y_test.shape

(torch.Size([3, 10, 2]), torch.Size([3, 11]))

In [51]:
model_one.train()
p, _ = model_one(x_test, y_test)

In [52]:
p

tensor([[[0.0000e+00, 1.0000e+00, 5.5496e-15, 4.6777e-10, 1.7652e-13,
          4.1180e-16, 7.6380e-13, 2.8813e-14, 2.1219e-13, 8.0688e-12,
          1.0039e-10],
         [0.0000e+00, 5.4404e-18, 3.1486e-15, 3.0507e-08, 9.9999e-01,
          3.6071e-15, 6.7045e-06, 1.1977e-13, 1.4881e-08, 4.8464e-16,
          2.5963e-07],
         [0.0000e+00, 9.2538e-12, 3.9776e-03, 4.0498e-01, 3.1658e-06,
          1.7585e-05, 2.9896e-08, 1.9362e-08, 1.4768e-08, 5.6171e-01,
          2.9312e-02],
         [0.0000e+00, 2.5635e-17, 2.3643e-03, 4.2935e-09, 9.7099e-07,
          1.2548e-05, 6.5435e-06, 9.8097e-08, 1.6975e-06, 9.9761e-01,
          2.9851e-08],
         [0.0000e+00, 1.2416e-15, 4.0563e-01, 8.4804e-12, 2.0894e-05,
          3.0493e-01, 1.3285e-05, 8.0567e-03, 4.2363e-07, 2.8135e-01,
          1.2779e-11],
         [0.0000e+00, 1.0593e-21, 7.1946e-02, 1.1762e-16, 1.4949e-09,
          9.2245e-01, 8.6392e-08, 5.5784e-03, 2.4779e-05, 4.4656e-07,
          2.0680e-17],
         [0.0000e+00, 