In [8]:
import torch as t
from torch.utils.data import DataLoader, Dataset, Subset
import torch.nn as nn
from timeit import timeit
import pickle
from os import path

In [2]:
# Global configuration shared by the data-loading and model cells.
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing at the first `.to(DEVICE)`.
DEVICE = t.device("cuda" if t.cuda.is_available() else "cpu")
FLOAT_TYPE = t.float32  # dtype of point coordinates and of all model weights
INT_TYPE = t.int32      # dtype used to store the label indices on disk

In [3]:
def generate_pickle(in_file: str, out_file: str, n_points: int, float_dtype: t.dtype, int_dtype: t.dtype, cap: int | None = None):
    with open(in_file, 'r') as f:
        raw_data = f.readlines()
        data = {
            "x": t.zeros(len(raw_data), n_points, 2, dtype=float_dtype, device="cpu"),
            "y": t.zeros(len(raw_data), n_points + 1, dtype=int_dtype, device="cpu") # loops back around
        }
        for i, line in enumerate(raw_data):
            if cap is not None and i >= cap:
                data["x"] = data["x"][:i]
                data["y"] = data["y"][:i]

                break

            if (i + 1) % 10000 == 0:
                print(f"Loaded data: {i + 1}/{len(raw_data)}; cap: {cap}")
                pickle.dump(data, open(out_file, "wb"))

            index = 0
            x_mode = True
            for token in line.split(" "):
                if token == "\n":
                    continue

                if token == "output":
                    x_mode = False
                    index = 0
                    continue
                
                if x_mode:
                    data["x"][i, index // 2, index % 2] = float(token)
                    index += 1
                else:
                    data["y"][i, index] = float(token)
                    index += 1

        pickle.dump(data, open(out_file, "wb"))

In [11]:
# Parsing the raw text file is slow, so it is done once and cached as a pickle;
# later runs skip straight to loading the cached tensors.
tsp10_pickle_exists = path.isfile("../data/ptr_nets/tsp10.pkl")
if not tsp10_pickle_exists:
    generate_pickle(
        in_file="../data/ptr_nets/tsp10.txt",
        out_file="../data/ptr_nets/tsp10.pkl",
        n_points=10,
        float_dtype=FLOAT_TYPE,
        int_dtype=INT_TYPE,
    )

In [13]:
class ConvexHullDataset(Dataset):
    """In-memory dataset backed by a pickled ``{"x": ..., "y": ...}`` tensor dict.

    The whole dataset is moved to ``device`` once at construction time, so
    ``__getitem__`` only indexes tensors that already live there.

    Args:
        in_file: path of the pickle to load.
        n_points: expected number of points per sample (size of dim 1 of ``x``).
        device: device both tensors are copied to.
        float_dtype: expected dtype of ``x``.
        int_dtype: expected dtype of ``y``.

    Raises:
        AssertionError: if the loaded tensors do not match the expected
            point count, coordinate count or dtypes.
    """

    def __init__(self, in_file: str, n_points: int, device: t.device, float_dtype: t.dtype, int_dtype: t.dtype):
        self.n_points = n_points
        # NOTE: pickle.load executes arbitrary code from the file; only load
        # pickles produced by this project. The `with` block closes the handle.
        with open(in_file, "rb") as pickle_file:
            self.data = pickle.load(pickle_file)

        assert self.data["x"].shape[1] == self.n_points, "Number of points in each sample does not match n_points"
        assert self.data["x"].shape[2] == 2, "Each point must have 2 coordinates"
        assert self.data["x"].dtype == float_dtype, "Data type of x does not match float_dtype"
        assert self.data["y"].dtype == int_dtype, "Data type of y does not match int_dtype"

        print("Copying x to device")
        self.data["x"] = self.data["x"].to(device)
        print("Copying y to device")
        self.data["y"] = self.data["y"].to(device)

    def __len__(self):
        """Number of samples (rows of ``x``)."""
        return len(self.data["x"])

    def __getitem__(self, idx):
        """Return ``(x[idx], y[idx])`` with the labels cast to int64.

        ``idx`` is applied directly to the tensors, so slices return an
        already-batched pair. Labels are returned exactly as stored, with no
        index shift applied.
        """
        return self.data["x"][idx], self.data["y"][idx].long()

In [14]:
# Load the cached tensors and keep them resident on DEVICE for the whole session.
convex_hull_dataset = ConvexHullDataset(
    in_file="../data/ptr_nets/tsp10.pkl",
    n_points=10,
    device=DEVICE,
    float_dtype=FLOAT_TYPE,
    int_dtype=INT_TYPE,
)

Copying x to device
Copying y to device


In [32]:
class PtrNet2(nn.Module):
    def __init__(self, in_dim: int, m_out: int, dtype: t.dtype, hidden_dim: int = 512) -> None:
        super().__init__()

        self.in_dim = in_dim # input embedding size
        self.n_out = m_out   # number of output indices to return
        self.dtype = dtype

        self.hidden_dim = hidden_dim

        self.encoder = nn.LSTM(in_dim, hidden_dim, 1, dtype=dtype)
        self.decoder = nn.LSTM(in_dim, hidden_dim, 1, dtype=dtype)

        self.starting_token = t.ones(in_dim, dtype=dtype) * -1.0
        self.ending_token = t.ones(hidden_dim, dtype=dtype)

        for i in range(hidden_dim):
            self.ending_token[i] = - ((i + 1) % 2)

        self.W1 = nn.Parameter(t.randn(hidden_dim, hidden_dim, dtype=dtype, requires_grad=True))
        self.W2 = nn.Parameter(t.randn(hidden_dim, hidden_dim, dtype=dtype, requires_grad=True))
        self.v  = nn.Parameter(t.randn(1, hidden_dim, dtype=dtype, requires_grad=True))


    def forward(self, x: t.Tensor, y: t.Tensor | None = None):
        """
        x is a batch_size x in_len x in_dim tensor
        y is a batch_size x n_out tensor of indices. The indices are one-indexed:
          used for during training as teacher forcing. ignored during validation
          if y[n, i] = k, this means for the nth batch, the ith output vector is to be equal to the (k)th input vector
          the 0th index is reserved for the ending token

        returns: a tuple (A, (encoder_all_hiddens, decoder_all_hiddens)) where:
            A is a batch_size x (n_out + 1) x (in_len + 1) matrix; for each batch, we get (n_out + 1) probability pointers to the input sequence. The pointers are effectively 1-indexed: the 0th index refers to a pointer to the ending token, which is a signal to stop. The 1st index refers to a pointer to the 0th position in the input sequence, and so on.
            encoder_all_hiddens is a batch_size x in_len + 1 x hidden_dim matrix; the first position is the ending token
            decoder_all_hiddens is a batch_size x n_out + 1 x hidden_dim matrix; the first position is first output token, viz. the hidden state AFTER the starting token; the (i + 1)th position is the hidden state AFTER the ith output token; the (n_out + 1)th position is the hidden state AFTER the final output token, which should be used to generate a pointer to the start token
        """

        # force batched
        if x.dim() == 2:
            x = x.unsqueeze(0)
        assert x.dim() == 3
        batch_size = len(x)
        batch_size_arange = t.arange(batch_size)
        in_len = x.shape[1]

        if self.training:
            assert y is not None
            assert x.shape[0] == y.shape[0]
            assert y.shape[1] == self.n_out

        
        assert x.shape[2] == self.in_dim

        x = x.permute(1, 0, 2) # TODO: change it so that input is already permuted

        # define L
        # define encoder_all_hiddens, an N * L + 1 * embedding_dim matrix
        encoder_all_hiddens = t.zeros(in_len + 1, batch_size, self.hidden_dim, dtype=self.dtype, device=self.device)
        encoder_first_hidden = self.ending_token.repeat(1, batch_size, 1) # NOTE: FIRST position in L dimension is the ending token. Thus, a pointer to the FIRST (0)th position is a signal to stop
        encoder_all_hiddens[0:1, :, :] = encoder_first_hidden

        # encode
        encoder_all_hiddens[1:, :, :], (encoder_last_hidden, encoder_last_cell_state) = self.encoder(x, (encoder_first_hidden, t.zeros(1, batch_size, self.hidden_dim, dtype=self.dtype, device=self.device)))

        # initialize decode variables
        decoder_all_hiddens = t.zeros(self.n_out + 1, batch_size, self.hidden_dim, dtype=self.dtype, device=self.device)
        decoder_current_cell_state = encoder_last_cell_state

        # decode
        if self.training:
            decoder_input = t.empty(self.n_out + 1, batch_size, self.in_dim, dtype=self.dtype, device=self.device)
            decoder_input[0:1, :, :] = self.starting_token.repeat(1, batch_size, 1)
            decoder_input[1:, :, :] = x[y.t() - 1, batch_size_arange.view(1, batch_size), :]

            decoder_all_hiddens[:, :, :], _ = self.decoder(decoder_input, (encoder_last_hidden, decoder_current_cell_state))

            lin_enc = self.W1 @ encoder_all_hiddens.permute(1, 2, 0)
            lin_dec = self.W2 @ decoder_all_hiddens.permute(1, 2, 0)

            TANH = (lin_enc.unsqueeze(3) + lin_dec.unsqueeze(2)).tanh()
            U = self.v @ TANH.permute(0, 3, 1, 2) # batch_size x (n_out + 1) x 1 x in_len + 1

            A = U.squeeze(dim=2)
            # A = U.softmax(dim=3).squeeze(dim=2) # batch_size x (n_out + 1) x in_len + 1

            return A, (encoder_all_hiddens.permute(1, 0, 2), decoder_all_hiddens.permute(1, 0, 2))
                      # returns a batch_size x in_len + 1 x hidden_dim matrix,
                      # a batch_size x n_out + 1 x hidden_dim matrix,
                      # and a batch_size x (n_out + 1) x in_len + 1 matrix
        else:
            decoder_current_input  = self.starting_token.repeat(1, batch_size, 1)
            decoder_current_hidden = encoder_last_hidden

            
            A = t.empty(self.n_out + 1, batch_size, in_len + 1, dtype=self.dtype, device=self.device)

            for i in range(self.n_out + 1):
                _, (decoder_current_hidden, decoder_current_cell_state) = self.decoder(decoder_current_input, (decoder_current_hidden, decoder_current_cell_state))
                TANH_i = (self.W1 @ encoder_all_hiddens.permute(1, 2, 0) + self.W2 @ decoder_current_hidden.permute(1, 2, 0)).tanh()
                U_i    = self.v @ TANH_i

                A_i = U_i.squeeze(dim=1) # no softmaxing here

                # A_i    = U_i.softmax(dim=2).squeeze(dim=1) # gives an batch_size x in_len matrix where m[n] is a probability distribution over in_length input indices
                
                next_index = A_i.argmax(dim=1) # TODO: handle when next_index[n][0] is 0 for some n, meaning that for the nth batch, the decoder should stop. TODO: add beam search functionality
                A[i, :, :] = A_i
                decoder_current_input = x[next_index - 1, batch_size_arange].unsqueeze(0)
                decoder_all_hiddens[i, :, :] = decoder_current_hidden.squeeze(dim=0)
            
            return A.permute(1, 0, 2), (encoder_all_hiddens.permute(1, 0, 2), decoder_all_hiddens.permute(1, 0, 2))
        
    
    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.device = args[0]
        self.starting_token = self.starting_token.to(self.device)
        self.ending_token = self.ending_token.to(self.device)

        return self

In [48]:
# 2-D points in, 11 pointers out: the tsp10 labels hold n_points + 1 entries
# because the stored sequence loops back around.
model_one = PtrNet2(in_dim=2, m_out=11, dtype=FLOAT_TYPE)
model_one = model_one.to(DEVICE)
optim_one = t.optim.Adam(params=model_one.parameters(), lr=3e-3)

In [49]:
# Total number of samples in the loaded dataset; despite the name this is the
# size BEFORE the train/eval split, which is derived from it in the next cell.
train_dataset_size = len(convex_hull_dataset)

In [50]:
# Sequential 90/10 split: the first nine tenths of the samples are used for
# training, the remaining tenth for evaluation.
split_index = 9 * train_dataset_size // 10
train_subset = Subset(convex_hull_dataset, range(split_index))
eval_subset = Subset(convex_hull_dataset, range(split_index, train_dataset_size))

train_dataloader = DataLoader(train_subset, batch_size=512, shuffle=True)
eval_dataloader  = DataLoader(eval_subset, batch_size=512, shuffle=True)

In [51]:
def train(model: PtrNet2, train_dataloader: DataLoader, eval_dataloader: DataLoader, optim: t.optim.Optimizer, model_path: str, n_epochs: int = 100):
    """Train ``model`` with cross-entropy on its pointer scores, evaluating after every epoch.

    Each epoch runs one pass over ``train_dataloader`` (with optimizer steps)
    followed by one pass over ``eval_dataloader`` (no gradients, no updates).
    Before the evaluation pass the whole model object is pickled to
    ``f"{model_path}_{epoch_number}.pkl"``.

    The target for each sample is its label sequence followed by a final 0,
    i.e. a pointer to the ending token. Loss and accuracy are printed every
    100 batches; the accuracy skips the first and the last decoding step.

    Args:
        model: network to train; its ``n_out`` must equal the label length.
        train_dataloader: yields ``(x, y)`` batches used for optimisation.
        eval_dataloader: yields ``(x, y)`` batches used for evaluation only.
        optim: optimizer already bound to ``model.parameters()``.
        model_path: path prefix of the per-epoch checkpoint files.
        n_epochs: number of epochs to run.
    """
    # NOTE: assuming all inputs are of same sequence length model.n_out
    n_out = model.n_out
    ce = t.nn.CrossEntropyLoss()

    for epoch_number in range(n_epochs):
        for training in (True, False):
            dataloader = train_dataloader if training else eval_dataloader
            phase_name = 'Training' if training else 'Evaluation'
            n_batches = len(dataloader)

            if training:
                model.train()
            else:
                # Checkpoint the freshly trained epoch before evaluating it;
                # the `with` block guarantees the file is flushed and closed.
                with open(f"{model_path}_{epoch_number}.pkl", "wb") as checkpoint_file:
                    pickle.dump(model, checkpoint_file)
                model.eval()

            # Evaluation never calls backward(), so skip building the autograd graph there.
            with t.set_grad_enabled(training):
                for batch_number, (x_batch, y_batch) in enumerate(dataloader):
                    if training:
                        optim.zero_grad()

                    scores, _ = model(x_batch, y_batch)

                    # Labels followed by a trailing 0 (pointer to the ending token).
                    targets = t.zeros((y_batch.shape[0], n_out + 1), dtype=t.long, device=x_batch.device)
                    targets[:, :-1] = y_batch

                    # CrossEntropyLoss expects the class dimension second: batch x classes x steps.
                    loss = ce(scores.permute(0, 2, 1), targets)

                    if batch_number % 100 == 0:
                        predicted = scores.argmax(dim=2)[:, 1:-1]
                        expected = targets[:, 1:-1]
                        accuracy = (predicted == expected).sum() / expected.numel()
                        print(f"Epoch {epoch_number + 1}/{n_epochs} {phase_name} Batch {batch_number + 1}/{n_batches} loss {loss.item()} accuracy: {accuracy}")

                    if training:
                        loss.backward()
                        optim.step()


In [53]:
# Evaluate on the held-out split: passing the training loader twice would
# report "Evaluation" metrics on data the model has already been fitted to.
train(model_one, train_dataloader, eval_dataloader, optim_one, "../models/ptr_net")

Epoch 1/100 Training Batch 1/1758 loss 0.6837473511695862 accuracy: 0.7007812857627869
Epoch 1/100 Training Batch 101/1758 loss 0.6291791796684265 accuracy: 0.7220703363418579
Epoch 1/100 Training Batch 201/1758 loss 0.9140381813049316 accuracy: 0.6216797232627869
Epoch 1/100 Training Batch 301/1758 loss 0.6196596026420593 accuracy: 0.7275390625
Epoch 1/100 Training Batch 401/1758 loss 0.6746962070465088 accuracy: 0.708203136920929
Epoch 1/100 Training Batch 501/1758 loss 0.5961706638336182 accuracy: 0.738476574420929
Epoch 1/100 Training Batch 601/1758 loss 0.6433029174804688 accuracy: 0.715624988079071
Epoch 1/100 Training Batch 701/1758 loss 0.5862244963645935 accuracy: 0.7437500357627869
Epoch 1/100 Training Batch 801/1758 loss 0.5921444892883301 accuracy: 0.742382824420929
Epoch 1/100 Training Batch 901/1758 loss 0.5942407250404358 accuracy: 0.7496094107627869
Epoch 1/100 Training Batch 1001/1758 loss 0.6093252897262573 accuracy: 0.740429699420929
Epoch 1/100 Training Batch 1101/1

KeyboardInterrupt: 

In [10]:
# Take the first three samples; the slice is applied to the stored tensors,
# so the result is already a small batch of inputs and labels.
x_test, y_test = convex_hull_dataset[0:3]

In [11]:
# Sanity check: 3 samples x 10 points x 2 coordinates, and 3 x 11 label indices.
x_test.shape, y_test.shape

(torch.Size([3, 10, 2]), torch.Size([3, 11]))

In [12]:
# Training mode: forward requires y and teacher-forces the decoder with it.
model_one.train()
# p holds the pointer scores; the returned hidden states are discarded.
p, _ = model_one(x_test, y_test)

In [14]:
# Highest-scoring pointer per decoding step; index 0 is the ending token.
p.argmax(dim=2)

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0')