In [4]:
%%bash
# Report the NVIDIA GPUs visible on this machine (model, memory use, utilisation).
# NOTE(review): the following cell hardcodes device = "cpu", so these GPUs are
# not used by that code.
nvidia-smi

Wed Mar 17 00:19:05 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.118.02   Driver Version: 440.118.02   CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:86:00.0 Off |                    0 |
| N/A   37C    P0    33W / 250W |   9979MiB / 16280MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE...  Off  | 00000000:AF:00.0 Off |                    0 |
| N/A   55C    P0   153W / 250W |  15501MiB / 16280MiB |    100%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-------

In [None]:
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributions
from collections import namedtuple
from itertools import count

# Notebook-wide tensor precision and placement, passed explicitly when tensors are created below.
floattype = torch.float32  # the very same dtype object as torch.float
device = "cpu"


class environment:
    """Batched, random 2-D Euclidean TSP environment for step-by-step tour construction.

    Usage: ``reset`` samples a batch of instances, ``update`` is called once per
    step with the next city for every row, and ``laststep`` closes the tours.
    ``cost`` holds the accumulated tour length of each row.
    """

    def reset(self, npoints, batchsize, nsamples=1):
        """Sample ``batchsize`` random instances, each replicated ``nsamples`` times.

        Args:
            npoints: number of cities per instance; coordinates are drawn
                uniformly from [0, 1) in 2-D.
            batchsize: number of distinct random instances.
            nsamples: copies of each instance. Copies share one flattened
                leading axis: row ``b * nsamples + s`` is copy ``s`` of
                instance ``b``.

        Afterwards ``self.batchsize == batchsize * nsamples`` and every
        per-row tensor has that leading dimension.
        """
        # Instances and their copies are stored in one flattened axis so the
        # rest of the class only ever deals with a single batch dimension.
        self.batchsize = batchsize * nsamples
        self.nsamples = nsamples
        self.npoints = npoints
        self.points = (
            torch.rand([batchsize, npoints, 2], dtype=floattype, device=device)
            .unsqueeze(1)
            .expand(-1, nsamples, -1, -1)
            .reshape(self.batchsize, npoints, 2)
        )

        # [batchsize * nsamples, npoints, npoints]; entry [r, i, j] is the
        # Euclidean distance between cities i and j of row r (symmetric).
        self.distance_matrix = (
            (self.points.unsqueeze(1) - self.points.unsqueeze(2)).square().sum(-1).sqrt()
        )

        # No city has been visited yet.
        self.previous_point = None
        # True where a city has already been visited.
        self.points_mask = torch.zeros(
            [self.batchsize, npoints], dtype=torch.bool, device=device
        )
        # Visit order so far; one column is appended per `update` call.
        self.points_sequence = torch.empty(
            [self.batchsize, 0], dtype=torch.long, device=device
        )

        self.cost = torch.zeros([self.batchsize], dtype=floattype, device=device)

        # NOTE(review): this is a leaf tensor with requires_grad=True, so an
        # in-place accumulation such as `logprob += ...` would raise; callers
        # must update it out of place. Its use is not visible in this cell.
        self.logprob = torch.zeros(
            [self.batchsize], dtype=floattype, device=device, requires_grad=True
        )

    def _edge_length(self, city_a, city_b):
        """Return, for every row r, the distance between cities city_a[r] and city_b[r].

        ``city_a`` and ``city_b`` are integer index tensors of shape [batchsize].
        Advanced indexing reads exactly one matrix entry per row.
        """
        rows = torch.arange(self.batchsize, device=self.distance_matrix.device)
        return self.distance_matrix[rows, city_a, city_b]

    def update(self, point_index):
        """Visit one more city in every batch row.

        Args:
            point_index: integer tensor of shape [batchsize] giving the city to
                visit next in each row; none of them may already be visited.

        Adds the edge from the previously visited city (if any) to ``cost``,
        marks the city as visited and appends it to ``points_sequence``.
        """
        assert list(point_index.size()) == [self.batchsize]
        # Compare device objects rather than strings: str() of a "cuda" tensor's
        # device is "cuda:0", which a plain string check would reject.
        assert point_index.device == self.points_mask.device
        assert not self.points_mask.gather(1, point_index.unsqueeze(1)).any(), \
            "a city was selected that has already been visited"

        # `is not None` rather than `!= None`: the latter asks the tensor to
        # compare itself with None and only works via an implicit fallback.
        if self.previous_point is not None:
            self.cost += self._edge_length(point_index, self.previous_point)

        self.previous_point = point_index
        self.points_mask.scatter_(1, point_index.unsqueeze(1), True)
        self.points_sequence = torch.cat(
            [self.points_sequence, point_index.unsqueeze(1)], dim=1
        )

    def laststep(self):
        """Close every tour by adding the edge from the last city back to the first.

        Must only be called once all ``npoints`` cities have been visited.
        """
        assert self.points_sequence.size(1) == self.npoints

        self.cost += self._edge_length(
            self.points_sequence[:, -1], self.points_sequence[:, 0]
        )
    


def farthest_insertion(npoints, batchsize):
    points = torch.rand([batchsize, npoints, 2], device = device)
    points_mask = torch.zeros([self.batchsize, npoints], dtype=torch.bool, device=device)
    distance_matrix = (points.unsqueeze(1) - points.unsqueeze(2)).square().sum(-1).sqrt() # [batchsize * nsamples, npoints, npoints]
    points_sequence = torch.empty([batchsize, 0], dtype = torch.long, device = device_)
    for i in range(npoints - 1):
        distance_from_tour = distance_matrix.gather(2, self.points_mask.unsqueeze(1).expand(-1, npoints, -1)).min(2)[0] #[batchsize, npoints]
        next_point = distance_from_tour.max(1)[1] #[batchsize], is index of point
        points_mask.scatter_(1, next_point.unsqueeze(1), True)
        insertion_segment = torch.cat([points_sequence, next_point, points_sequence.roll(1, 1)], dim = 2) #[batchsize, sequence, 3]
        insertion_cost = distance_matrix.gather(2, ) + distance_matrix.gather(2, ) - distance_matrix.gather(2, ) #[batchsize, sequence]
        insertion_location = insertion_cost.min(insertion_cost)[1] #[batchsize], insertion location
        #turn location into indices i, i+1, modulo back to zero, so can insert next point into last position or something