In [5]:
%%bash
# Show GPU status on this machine: model, memory use and utilisation per device.
nvidia-smi

Wed Mar 17 21:13:15 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.118.02   Driver Version: 440.118.02   CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:86:00.0 Off |                    0 |
| N/A   35C    P0   158W / 250W |  16227MiB / 16280MiB |     76%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE...  Off  | 00000000:AF:00.0 Off |                    0 |
| N/A   51C    P0   155W / 250W |  15783MiB / 16280MiB |     93%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-------

In [125]:
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributions
from collections import namedtuple
from itertools import count

# Device and dtype for every tensor the environment creates.
device = "cpu"
floattype = torch.float


class environment:
    """Batched TSP tour-construction environment.

    ``reset`` samples random instances, ``update`` appends one city per batch
    entry to the tour while accumulating the travelled distance, and
    ``laststep`` closes each tour with the edge back to its first city.
    """

    def reset(self, npoints, batchsize, nsamples=1):
        """Sample ``batchsize`` instances of ``npoints`` uniform points in the unit square.

        Each instance is replicated ``nsamples`` times and the instance/sample
        dimensions are flattened together, so every per-tour tensor has leading
        size ``batchsize * nsamples`` (stored as ``self.batchsize``), with the
        copies of one instance adjacent.
        """
        # Instance and sample dimensions are stored together so the rest of the
        # code can treat them as a single batch dimension.
        self.batchsize = batchsize * nsamples
        self.nsamples = nsamples
        self.npoints = npoints
        self.points = (
            torch.rand([batchsize, npoints, 2], dtype=floattype, device=device)
            .unsqueeze(1)
            .expand(-1, nsamples, -1, -1)
            .reshape(self.batchsize, npoints, 2)
        )

        # Pairwise Euclidean distances: [batchsize * nsamples, npoints, npoints].
        self.distance_matrix = (
            (self.points.unsqueeze(1) - self.points.unsqueeze(2)).square().sum(-1).sqrt()
        )

        # City visited at the previous step (None before the first update).
        self.previous_point = None

        # True where a city is already in the tour: [batchsize * nsamples, npoints].
        self.points_mask = torch.zeros(
            [self.batchsize, npoints], dtype=torch.bool, device=device
        )
        # Visited city indices in visiting order: [batchsize * nsamples, steps so far].
        self.points_sequence = torch.empty(
            [self.batchsize, 0], dtype=torch.long, device=device
        )

        # Tour length accumulated so far: [batchsize * nsamples].
        self.cost = torch.zeros([self.batchsize], dtype=floattype, device=device)

        # NOTE(review): this is a leaf tensor with requires_grad=True, so it must
        # be updated out-of-place (e.g. self.logprob = self.logprob + x); an
        # in-place update raises a RuntimeError in autograd.
        self.logprob = torch.zeros(
            [self.batchsize], dtype=floattype, device=device, requires_grad=True
        )

    def _edge_lengths(self, row_point, col_point):
        """Return ``distance_matrix[b, row_point[b], col_point[b]]`` for every batch entry ``b``."""
        rows = torch.arange(self.batchsize, device=row_point.device)
        return self.distance_matrix[rows, row_point, col_point]

    def update(self, point_index):
        """Visit ``point_index[b]`` next in tour ``b``.

        Args:
            point_index: LongTensor of shape [batchsize * nsamples] naming one
                not-yet-visited city per batch entry.

        Adds the edge from the previously visited city (if any) to ``self.cost``
        and records the city in ``points_mask`` / ``points_sequence``.
        """
        assert list(point_index.size()) == [self.batchsize]
        # Compare devices structurally: str(tensor.device) is "cuda:0" even when
        # the configured device string is just "cuda".
        expected_device = torch.device(device)
        assert point_index.device.type == expected_device.type and (
            expected_device.index is None
            or point_index.device.index == expected_device.index
        )
        # A city may be visited only once per tour.
        assert self.points_mask.gather(1, point_index.unsqueeze(1)).sum() == 0

        if self.previous_point is not None:
            self.cost += self._edge_lengths(point_index, self.previous_point)

        self.previous_point = point_index
        self.points_mask.scatter_(1, point_index.unsqueeze(1), True)
        self.points_sequence = torch.cat(
            [self.points_sequence, point_index.unsqueeze(1)], dim=1
        )

    def laststep(self):
        """Close every tour by adding the edge from its last city back to its first."""
        assert self.points_sequence.size(1) == self.npoints

        self.cost += self._edge_lengths(
            self.points_sequence[:, -1], self.points_sequence[:, 0]
        )
    


def farthest_insertion(npoints, batchsize, points=None):
    """Mean closed-tour length produced by the farthest-insertion TSP heuristic.

    The tour starts from the two mutually farthest cities. The heuristic then
    repeatedly takes the unvisited city whose distance to its nearest tour city
    is largest. It inserts that city at the position that lengthens the tour
    least. All ``batchsize`` instances are processed in parallel.

    Args:
        npoints: number of cities per instance.
        batchsize: number of independent instances.
        points: optional [batchsize, npoints, 2] tensor of city coordinates.
            When omitted, coordinates are drawn uniformly from the unit square
            on the module-level ``device``.

    Returns:
        0-dim tensor: tour length averaged over the batch.

    Raises:
        ValueError: if ``points`` is given with a shape other than
            [batchsize, npoints, 2].
    """
    if points is None:
        points = torch.rand([batchsize, npoints, 2], device=device)
    elif tuple(points.shape) != (batchsize, npoints, 2):
        raise ValueError(
            f"points must have shape {(batchsize, npoints, 2)}, got {tuple(points.shape)}"
        )
    dev = points.device
    batch_index = torch.arange(batchsize, device=dev).unsqueeze(1)  # [batchsize, 1]

    # Pairwise Euclidean distances: [batchsize, npoints, npoints].
    distance_matrix = (points.unsqueeze(1) - points.unsqueeze(2)).square().sum(-1).sqrt()

    def distances_to(city):
        # [batchsize] city indices -> [batchsize, npoints] distances of every city to it.
        return distance_matrix.gather(2, city.view(-1, 1, 1).expand(-1, npoints, 1)).squeeze(2)

    # Seed each tour with its farthest pair of cities.
    max_distances, second_point = distance_matrix.max(2)  # [batchsize, npoints] each
    _, first_point = max_distances.max(1)  # [batchsize]
    points_sequence = torch.stack(
        [first_point, second_point.gather(1, first_point.unsqueeze(1)).squeeze(1)], dim=1
    )  # [batchsize, tour length]
    points_mask = torch.zeros([batchsize, npoints], dtype=torch.bool, device=dev)
    points_mask.scatter_(1, points_sequence, True)

    # Distance from every city to its nearest tour city. It is updated
    # incrementally (O(npoints) per step) instead of being recomputed over the
    # whole [batchsize, npoints, npoints] matrix.
    distance_from_tour = torch.min(
        distances_to(points_sequence[:, 0]), distances_to(points_sequence[:, 1])
    )

    for i in range(2, npoints):
        assert points_sequence.size(1) == i
        assert points_mask.sum() == batchsize * i

        # Farthest unvisited city. Visited cities are masked out so they can
        # never be picked again, even when points coincide.
        next_point = distance_from_tour.masked_fill(points_mask, float("-inf")).max(1)[1]
        points_mask.scatter_(1, next_point.unsqueeze(1), True)
        distance_from_tour = torch.min(distance_from_tour, distances_to(next_point))

        # Cost of inserting next_point between tour[s-1] and tour[s], for every s.
        previous = points_sequence.roll(1, 1)  # [batchsize, i]
        new = next_point.unsqueeze(1)  # [batchsize, 1]
        insertion_cost = (
            distance_matrix[batch_index, previous, new]
            + distance_matrix[batch_index, new, points_sequence]
            - distance_matrix[batch_index, previous, points_sequence]
        )  # [batchsize, i]
        insertion_location = insertion_cost.min(1)[1]  # [batchsize]

        # Rotate each tour so it starts at tour[s] and ends at tour[s-1].
        # Appending next_point then yields the cycle ... tour[s-1] -> next -> tour[s] ...
        insertion_indices = (
            torch.arange(i, device=dev).unsqueeze(0) + insertion_location.unsqueeze(1)
        ) % i
        points_sequence = torch.cat([points_sequence.gather(1, insertion_indices), new], dim=1)

    # Closed-tour length: edges tour[s] -> tour[s-1], including the wrap-around edge.
    pair_distances = distance_matrix[batch_index, points_sequence, points_sequence.roll(1, 1)]
    return pair_distances.sum(1).mean()
        
        
        

In [132]:
# Ten independent single-instance runs at 1000 cities. The seed makes the
# printed tour lengths reproducible on Restart & Run All.
torch.manual_seed(0)
for _ in range(10):
    print(farthest_insertion(1000, 1))

tensor(25.3708)
tensor(25.7802)
tensor(26.5136)
tensor(25.6348)
tensor(25.8166)
tensor(25.9177)
tensor(26.0191)
tensor(25.9999)
tensor(25.5729)
tensor(25.1513)


In [None]:
# Ten runs, each averaging the tour length over a batch of 100 instances of
# 1000 cities. The seed makes the printed values reproducible.
torch.manual_seed(0)
for _ in range(10):
    print(farthest_insertion(1000, 100))

In [2]:
%%bash
# GPU status (nvidia-smi) followed by host memory usage in MiB (free -m).
nvidia-smi
free -m

Thu Mar 18 00:04:34 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.118.02   Driver Version: 440.118.02   CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:86:00.0 Off |                    0 |
| N/A   58C    P0   169W / 250W |  16229MiB / 16280MiB |    100%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE...  Off  | 00000000:AF:00.0 Off |                    0 |
| N/A   30C    P0    32W / 250W |    813MiB / 16280MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-------