In [1]:
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributions
from collections import namedtuple
from itertools import count

device = "cuda"


class Graph_Transformer(nn.Module):
    def __init__(self, emsize = 128, nhead = 1, nhid = 512, nlayers = 2, ndecoderlayers = 0, dropout = 0):
        super().__init__()
        self.emsize = emsize
        from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer
        encoder_layers = TransformerEncoderLayer(emsize, nhead, nhid, dropout = dropout)
        decoder_layers = TransformerDecoderLayer(emsize, nhead, nhid, dropout = dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.transformer_decoder = TransformerDecoder(decoder_layers, ndecoderlayers)
        self.encoder = nn.Linear(2, emsize)
        self.outputattention_query = nn.Linear(emsize, emsize, bias = False)
        self.outputattention_key = nn.Linear(emsize, emsize, bias = False)
        self.start_token = nn.Parameter(torch.randn([emsize], device = device))
    
    def generate_subsequent_mask(self, sz):
        mask = torch.triu(torch.ones([sz, sz], dtype = torch.bool, device = device), diagonal = 1) #first index < second index for True. True means target (first index) does not get any value from the source (second index)
    
    def encode(self, src): #src must be [batchsize * nsamples, npoints, 2]
        src = self.encoder(src).transpose(0, 1)
        output = self.transformer_encoder(src)
        return output #[npoints, batchsize * nsamples, emsize]
    
    def decode_next(self, memory, tgt, route_mask): #route mask is [batchsize * nsamples, npoints], both memory and tgt must have batchsize and nsamples in same dimension (the 1th one)
        npoints = memory.size(0)
        batchsize = tgt.size(1)
        """if I really wanted this to be efficient I'd only recompute the decoder for the last tgt, and just remebering what the others looked like from before (won't change due to mask)"""
        """have the option to freeze the autograd on all but the last part of tgt, although at the moment this is a very natural way to say: initial choices matter more"""
        tgt_mask = self.generate_subsequent_mask(tgt.size(0))
        output = self.transformer_decoder(tgt, memory, tgt_mask) #[tgt, batchsize * nsamples, emsize]
        output_query = self.outputattention_query(memory).transpose(0, 1) #[batchsize * nsamples, npoints, emsize]
        output_key = self.outputattention_key(output[-1]) #[batchsize * nsamples, emsize]
        output_attention = torch.matmul(output_query * self.emsize ** -0.5, output_key.unsqueeze(-1)).squeeze(-1) #[batchsize * nsamples, npoints], technically don't need to scale attention as we divide by variance next anyway
        output_attention = output_attention.tanh() #[batchsize * nsamples, npoints]
        output_attention = output_attention.masked_fill(route_mask, float('-inf')) #[batchsize * nsamples, npoints]
        return output_attention #[batchsize * nsamples, npoints]
    
    def calculate_logprob(self, memory, routes): #memory is [npoints, batchsize * nsamples, emsize], routes is [batchsize * nsamples, npoints - 3], rather than backproping the entire loop, this saves vram (and computation)
        npoints = memory.size(0)
        ninternalpoints = routes.size(1)
        bigbatchsize = memory.size(1)
        memory_ = memory.gather(0, routes.transpose(0, 1).unsqueeze(2).expand(-1, -1, self.emsize)) #[npoints - 3, batchsize * nsamples, emsize] reorder memory into order of routes
        tgt = torch.cat([self.start_token.unsqueeze(0).unsqueeze(1).expand(1, bigbatchsize, -1), memory_[:-1]]) #[npoints - 3, batchsize * nroutes, emsize], want to go from memory to tgt
        tgt_mask = self.generate_subsequent_mask(ninternalpoints)
        output = self.transformer_decoder(tgt, memory, tgt_mask) #[npoints - 3, batchsize * nsamples, emsize]
        """want probability of going from key to query, but first need to normalise (softmax with mask)"""
        output_query = self.outputattention_query(memory_).transpose(0, 1) #[batchsize * nsamples, npoints - 3, emsize]
        output_key = self.outputattention_key(output).transpose(0, 1) #[batchsize * nsamples, npoints - 3, emsize]
        attention_mask = torch.full([ninternalpoints, ninternalpoints], True, device = device).triu(1) #[npoints - 3, npoints - 3], True for i < j
        output_attention = torch.matmul(output_query * self.emsize ** -0.5, output_key.transpose(-1, -2))
        
        """quick fix to stop divergence"""
        output_attention = output_attention.tanh()
        
        output_attention = output_attention.masked_fill(attention_mask, float('-inf'))
        output_attention = output_attention.softmax(-2) #[batchsize * nsamples, npoints - 3, npoints - 3]
        
        """infact I'm almost tempted to not mask choosing a previous point, so it's forced to learn it and somehow incorporate it into its computation, but without much impact on reinforcing good examples"""
        logprob = output_attention.masked_select(torch.eye(ninternalpoints, dtype = torch.bool, device = device)).view(-1, ninternalpoints).log().sum(-1) #[batchsize * nsamples]
        return logprob #[batchsize * nsamples]

NN = Graph_Transformer().to(device)
optimizer = optim.Adam(NN.parameters(), lr = 0.0001)


class environment():
    def reset(self, npoints, batchsize, nsamples):
        if npoints <= 3:
            print("Error: not enough points for valid problem instance")
            return
        self.batchsize = batchsize * nsamples #so that I don't have to rewrite all this code, we store these two dimensions together
        self.nsamples = nsamples
        self.npoints = npoints
        self.points = torch.rand([batchsize, npoints - 3, 2], device = device).unsqueeze(1).expand(-1, nsamples, -1, -1).reshape(self.batchsize, npoints - 3, 2)
        self.corner_points = torch.tensor([[0, 0], [2, 0], [0, 2]], dtype = torch.float, device = device)
        self.points = torch.cat([self.corner_points.unsqueeze(0).expand(self.batchsize, -1, -1), self.points], dim = -2) #[batchsize * nsamples, npoints, 2]
        self.points_mask = torch.cat([torch.ones([self.batchsize, 3], dtype = torch.bool, device = device), torch.zeros([self.batchsize, npoints - 3], dtype = torch.bool, device = device)], dim = 1)
        self.points_sequence = torch.empty([self.batchsize, 0], dtype = torch.long, device = device)
        
        """use a trick, for the purpose of an 'external' triangle that is always left untouched, which means we don't have to deal with boundary edges as being different. external triangle is [0, 1, 2] traversed clockwise..."""
        self.partial_delaunay_triangles = torch.tensor([[0, 2, 1], [0, 1, 2]], dtype = torch.int64, device = device).unsqueeze(0).expand(self.batchsize, -1, -1).contiguous() #[batchsize, ntriangles, 3] contains index of points, always anticlockwise
        self.partial_delaunay_edges = torch.tensor([5, 4, 3, 2, 1, 0], dtype = torch.int64, device = device).unsqueeze(0).expand(self.batchsize, -1).contiguous() #[batchsize, ntriangles * 3] contains location of corresponding edge (edges go in order 01, 12, 20). Edges will always flip since triangles are stored anticlockwise.
        
        self.ntriangles = 2 #can store as scalar, since will always be the same
        self.cost = torch.zeros([self.batchsize], device = device)
    
    def update(self, point_index): #point_index is [batchsize]
        if point_index.size(0) != self.batchsize:
            print("Error: point_index.size() doesn't match expected size, should be [batchsize * nsamples]")
            return
        if self.points_mask.gather(1, point_index.unsqueeze(1)).sum():
            print("Error: some points already added")
            return
        triangles_coordinates = self.points.gather(1, self.partial_delaunay_triangles.view(self.batchsize, self.ntriangles * 3).unsqueeze(2).expand(-1, -1, 2)).view(self.batchsize, self.ntriangles, 3, 2) #[batchsize, ntriangles, 3, 2]
        newpoint = self.points.gather(1, point_index.unsqueeze(1).unsqueeze(2).expand(self.batchsize, 1, 2)).squeeze(1) #[batchsize, 2]
        
        incircle_matrix = torch.cat([triangles_coordinates, newpoint.unsqueeze(1).unsqueeze(2).expand(-1, self.ntriangles, 1, -1)], dim = -2) #[batchsize, ntriangles, 4, 2]
        incircle_matrix = torch.cat([incircle_matrix, (incircle_matrix * incircle_matrix).sum(-1, keepdim = True), torch.ones([self.batchsize, self.ntriangles, 4, 1], device = device)], dim = -1) #[batchsize, ntriangles, 4, 4]
        incircle_test = incircle_matrix.det() > 0 #[batchsize, ntriangles], is True if inside incircle
        removed_edge_mask = incircle_test.unsqueeze(2).expand(-1, -1, 3).reshape(-1) #[batchsize * ntriangles * 3]
        
        edges = (self.partial_delaunay_edges + self.ntriangles * 3 * torch.arange(self.batchsize, device = device).unsqueeze(1)).view(-1) #[batchsize * ntriangles * 3]
        neighbouring_edge = edges.masked_select(removed_edge_mask)
        neighbouring_edge_mask = torch.zeros([self.batchsize * self.ntriangles * 3], device = device, dtype = torch.bool)
        neighbouring_edge_mask[neighbouring_edge] = True
        neighbouring_edge_mask = (neighbouring_edge_mask * removed_edge_mask.logical_not()) #[batchsize * ntriangles * 3]
        
        n_new_triangles = neighbouring_edge_mask.view(self.batchsize, -1).sum(-1) #[batchsize]
        
        new_point = point_index.unsqueeze(1).expand(-1, self.ntriangles * 3).masked_select(neighbouring_edge_mask.view(self.batchsize, -1))
        
        second_point_mask = neighbouring_edge_mask.view(self.batchsize, -1, 3) #[batchsize, ntriangles 3]
        (first_point_indices0, first_point_indices1, first_point_indices2) = second_point_mask.nonzero(as_tuple = True)
        first_point_indices2 = (first_point_indices2 != 2) * (first_point_indices2 + 1)
        
        first_point = self.partial_delaunay_triangles[first_point_indices0, first_point_indices1, first_point_indices2] #[?]
        second_point = self.partial_delaunay_triangles.masked_select(second_point_mask) #[?]
        
        new_triangles_mask = torch.cat([incircle_test, torch.ones([self.batchsize, 2], dtype = torch.bool, device = device)], dim = 1) #[batchsize, ntriangles + 2]
        
        new_neighbouring_edges = 3 * new_triangles_mask.nonzero(as_tuple = True)[1] #[?], 3* since is the 01 edge of new triangles (see later)
        self.partial_delaunay_edges.masked_scatter_(neighbouring_edge_mask.view(self.batchsize, -1), new_neighbouring_edges) #still [batchsize, ntriangles * 3] for now
        
        self.partial_delaunay_triangles = torch.cat([self.partial_delaunay_triangles, torch.empty([self.batchsize, 2, 3], dtype = torch.long, device = device)], dim = 1)
        self.partial_delaunay_edges = torch.cat([self.partial_delaunay_edges, torch.empty([self.batchsize, 6], dtype = torch.long, device = device)], dim = 1)
        new_triangles = torch.stack([first_point, second_point, new_point], dim = 1) #[?, 3], edge here is flipped compared to edge in neighbouring triangle (so first_point is the second point in neighbouring edge)
        self.partial_delaunay_triangles.masked_scatter_(new_triangles_mask.unsqueeze(2).expand(-1, -1, 3), new_triangles) #[batchsize, ntriangles + 2, 3]
        
        new_edge01 = neighbouring_edge_mask.view(self.batchsize, -1).nonzero(as_tuple = True)[1] #[?]
        
        """we are currently storing which triangles have to be inserted, via the edges along the perimeter of the delaunay cavity, we need to compute which edge is to the 'left'/'right' of each edge"""
        """don't have the memory to do a batchsize * n * n boolean search, don't have the speed to do a batchsize^2 search (as would be the case for sparse matrix or similar)"""
        """best alternative: rotate the edge around right point, repeat until hit edge in mask (will never go to an edge of a removed triangle before we hit edge in mask) should basically be order 1!!!!!"""
        
        neighbouring_edge_index = neighbouring_edge_mask.nonzero(as_tuple = True)[0] #[?]
        next_neighbouring_edge_index = torch.empty_like(neighbouring_edge_index) #[?]
        
        rotating_flipped_neighbouring_edge_index = neighbouring_edge_mask.nonzero(as_tuple = True)[0] #[?], initialise
        todo_mask = torch.ones_like(next_neighbouring_edge_index, dtype = torch.bool) #[?]
        while todo_mask.sum():
            rotating_neighbouring_edge_index = rotating_flipped_neighbouring_edge_index + 1 - 3 * (rotating_flipped_neighbouring_edge_index % 3 == 2) #[todo_mask.sum()], gets smaller until nothing left EFFICIENCY (this may be seriously stupid, as it requires making a bunch of copies when I could be doing stuff inplace)
            
            update_mask = neighbouring_edge_mask[rotating_neighbouring_edge_index] #[todo_mask.sum()]
            update_mask_unravel = torch.zeros_like(todo_mask).masked_scatter(todo_mask, update_mask) #[?]
            
            next_neighbouring_edge_index.masked_scatter_(update_mask_unravel, rotating_neighbouring_edge_index.masked_select(update_mask)) #[?]
            
            todo_mask.masked_fill_(update_mask_unravel, False) #[?]
            rotating_flipped_neighbouring_edge_index = edges[rotating_neighbouring_edge_index.masked_select(update_mask.logical_not())] #[todo_mask.sum()]
        triangle_index = new_triangles_mask.view(-1).nonzero(as_tuple = True)[0] #[?], index goes up to batchsize * (ntriangles + 2), this is needed for when we invert the permutation by scattering (won't scatter same number of triangles per batch)
        
        next_triangle_index = torch.empty_like(edges).masked_scatter_(neighbouring_edge_mask, triangle_index)[next_neighbouring_edge_index] #[?], index goes up to batchsize * (ntriangles + 2)
        next_edge = 3 * next_triangle_index + 1 #[?]
        
        invert_permutation = torch.empty_like(new_triangles_mask.view(-1), dtype=torch.long) #[batchsize * (ntriangles + 2)]
        invert_permutation[next_triangle_index] = triangle_index #[batchsize * (ntriangles + 2)]
        previous_triangle_index = invert_permutation.masked_select(new_triangles_mask.view(-1)) #[?]
        previous_edge = 3 * previous_triangle_index + 2 #[?]
        
        """in the above we rotated around 'first_point' in our new triangles"""
        new_edge20 = next_edge % ((self.ntriangles + 2) * 3) #[?]
        new_edge12 = previous_edge % ((self.ntriangles + 2) * 3) #[?]
        
        new_edges = torch.stack([new_edge01, new_edge12, new_edge20], dim = 1) #[?, 3]
        self.partial_delaunay_edges.masked_scatter_(new_triangles_mask.unsqueeze(2).expand(-1, -1, 3).reshape(self.batchsize, -1), new_edges) #[batchsize, (ntriangles + 2) * 3]
        
        self.ntriangles += 2
        """currently only count the extra triangles you replace (not the on you have to remove because you're located there, and not the ones you make because you have to create two more"""
        self.cost += (n_new_triangles - 3)
        self.points_mask.scatter_(1, point_index.unsqueeze(1).expand(-1, self.npoints), True)
        self.points_sequence = torch.cat([self.points_sequence, point_index.unsqueeze(1)], dim = 1)
    
    def sample_point(self, logits): #logits must be [batchsize * nsamples, points]
        probs = torch.distributions.categorical.Categorical(logits = logits)
        next_point = probs.sample() #size is [batchsize * nsamples]
        self.update(next_point)
        return next_point #[batchsize * nsamples]
    
    def sampleandgreedy_point(self, logits): #logits must be [batchsize * nsamples, npoints], last sample will be the greedy choice (but we still need to keep track of its logits)
        probs = torch.distributions.categorical.Categorical(logits = logits.view(-1, self.nsamples, self.npoints)[:, :-1, :])
        sample_point = probs.sample() #[batchsize, (nsamples - 1)]
        greedy_point = logits.view(-1, self.nsamples, self.npoints)[:, -1, :].max(-1, keepdim = True)[1] #[batchsize, 1]
        next_point = torch.cat([sample_point, greedy_point], dim = 1).view(-1)
        self.update(next_point)
        return next_point
    

env = environment()


def train(epochs = 30000, npoints = 5, batchsize = 100, nsamples = 100):
    NN.train()
    for i in range(epochs):
        env.reset(npoints, batchsize, nsamples)
        """include the boundary points, kinda makes sense that they should contribute (atm only in the encoder, difficult to see how in the decoder)"""
        memory = NN.encode(env.points)
#        print('encode done')
        tgt = NN.start_token.unsqueeze(0).unsqueeze(1).expand(1, batchsize * nsamples, -1).detach()
        with torch.no_grad(): #to speed up computation, selecting routes is done without gradient
            for j in range(3, npoints):
                logits = NN.decode_next(memory.detach(), tgt, env.points_mask)
#                print('decode step done')
                next_point = env.sampleandgreedy_point(logits)
                tgt = torch.cat([tgt, memory.gather(0, next_point.unsqueeze(0).unsqueeze(2).expand(1, -1, memory.size(2)))])
        
        logprob = NN.calculate_logprob(memory, env.points_sequence) #[batchsize * nsamples]
#        print('logprob done')
        baseline = env.cost.view(batchsize, nsamples)[:, -1] #[batchsize], greedy sample
        loss = (logprob.view(batchsize, nsamples)[:, :-1] * (env.cost.view(batchsize, nsamples)[:, :-1] - baseline.unsqueeze(1))).mean() #don't include greedy choice (tbh probably doesn't matter)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
#        print('backprop done')
        print(baseline.mean().item())

In [2]:
train() #6 total

1.3633333444595337
1.0533334016799927
1.0333333015441895
1.0466667413711548
1.0566667318344116
1.0166666507720947
0.996666669845581
1.0333333015441895
1.0266666412353516
0.9800000190734863
0.9633333683013916
1.1166666746139526
1.0833333730697632
1.0
1.1100000143051147
1.1200000047683716
1.003333330154419
0.9933333396911621
0.9333333373069763
1.0166666507720947
1.1233333349227905
1.1066666841506958
1.0800000429153442
1.1333333253860474
1.090000033378601
1.0299999713897705
0.9900000095367432
0.9700000286102295
1.0133333206176758
0.9433333277702332
0.9766666889190674
0.9133333563804626
0.9399999976158142
0.9533333778381348
0.9466667175292969
0.9133333563804626
0.9233333468437195
0.9433333277702332
0.9066666960716248
0.9266666769981384
0.9866666793823242
0.9866666793823242
0.9900000095367432
1.0466667413711548
1.0733333826065063
1.0633333921432495
1.1666667461395264
1.096666693687439
1.15666663646698
1.2066667079925537
1.2233333587646484
1.3133333921432495
1.443333387374878
1.3333333730697

0.9700000286102295
0.9766666889190674
1.0766667127609253
1.0333333015441895
0.9500000476837158
0.8533333539962769
0.9700000286102295
0.9666666984558105
0.9600000381469727
0.9000000357627869
0.9233333468437195
0.9266666769981384
0.9066666960716248
0.89000004529953
0.8966667056083679
0.9200000166893005
0.8799999952316284
0.9700000286102295
0.9266666769981384
0.9433333277702332
0.89000004529953
0.9399999976158142
0.9900000095367432
0.9200000166893005
0.9366666674613953
0.9200000166893005
0.9566667079925537
0.9399999976158142
0.9933333396911621
0.9866666793823242
0.9866666793823242
1.0199999809265137
0.9433333277702332
0.9666666984558105
0.9766666889190674
1.0199999809265137
0.9700000286102295
1.0633333921432495
1.036666750907898
1.0466667413711548
1.006666660308838
0.9433333277702332
0.9466667175292969
1.003333330154419
0.9533333778381348
1.0333333015441895
0.9033333659172058
0.9466667175292969
0.9399999976158142
0.8366667032241821
0.89000004529953
0.9000000357627869
0.9233333468437195
0.

0.8799999952316284
0.8633333444595337
1.003333330154419
0.9033333659172058
0.9166666865348816
1.0333333015441895
1.0233333110809326
1.1266666650772095
1.0666667222976685
1.0533334016799927
0.9800000190734863
0.9700000286102295
1.0433334112167358
0.9500000476837158
0.9399999976158142
0.9133333563804626
0.9233333468437195
0.8733333349227905
0.8666666746139526
0.9033333659172058
0.9500000476837158
0.9100000262260437
0.9066666960716248
0.8966667056083679
0.9933333396911621
0.9533333778381348
0.9566667079925537
0.9566667079925537
0.9233333468437195
1.0166666507720947
1.006666660308838
1.0433334112167358
0.9399999976158142
1.0299999713897705
0.9000000357627869
0.89000004529953
0.9266666769981384
0.9200000166893005
0.9100000262260437
0.9266666769981384
0.8633333444595337
0.9433333277702332
0.8833333253860474
0.9066666960716248
0.8966667056083679
0.8666666746139526
0.9666666984558105
0.9033333659172058
0.8833333253860474
0.89000004529953
0.8566666841506958
0.8733333349227905
0.8733333349227905

KeyboardInterrupt: 

In [2]:
train() #5 total

0.5099999904632568
0.3700000047683716
0.3199999928474426
0.35999998450279236
0.3100000023841858
0.25999999046325684
0.23999999463558197
0.2199999988079071
0.25999999046325684
0.2800000011920929
0.2199999988079071
0.14999999105930328
0.23999999463558197
0.3199999928474426
0.22999998927116394
0.3999999761581421
0.23999999463558197
0.3400000035762787
0.2800000011920929
0.2199999988079071
0.29999998211860657
0.17999999225139618
0.20999999344348907
0.3100000023841858
0.23999999463558197
0.23999999463558197
0.22999998927116394
0.19999998807907104
0.17000000178813934
0.19999998807907104
0.20999999344348907
0.2199999988079071
0.25999999046325684
0.20999999344348907
0.22999998927116394
0.23999999463558197
0.23999999463558197
0.1899999976158142
0.22999998927116394
0.1599999964237213
0.14999999105930328
0.22999998927116394
0.1599999964237213
0.20999999344348907
0.14999999105930328
0.20999999344348907
0.19999998807907104
0.28999999165534973
0.2800000011920929
0.28999999165534973
0.2599999904632568

0.22999998927116394
0.26999998092651367
0.2800000011920929
0.20999999344348907
0.2199999988079071
0.19999998807907104
0.14000000059604645
0.17999999225139618
0.1599999964237213
0.1899999976158142
0.2800000011920929
0.19999998807907104
0.2199999988079071
0.17999999225139618
0.20999999344348907
0.25
0.1899999976158142
0.1599999964237213
0.17999999225139618
0.28999999165534973
0.17999999225139618
0.25
0.2199999988079071
0.17000000178813934
0.22999998927116394
0.20999999344348907
0.2199999988079071
0.17999999225139618
0.25
0.1899999976158142
0.17999999225139618
0.17999999225139618
0.17000000178813934
0.1899999976158142
0.22999998927116394
0.25999999046325684
0.20999999344348907
0.25
0.19999998807907104
0.19999998807907104
0.2199999988079071
0.1899999976158142
0.2199999988079071
0.1899999976158142
0.12999999523162842
0.25999999046325684
0.20999999344348907
0.19999998807907104
0.1899999976158142
0.25
0.1899999976158142
0.25999999046325684


KeyboardInterrupt: 

In [2]:
train() #5 total, 0 decoder layers, dropout = 0, adam, lr = 0.0001, 2 encoder layers, 1 head, emsize = 128, nhid = 512

0.8100000023841858
0.6699999570846558
0.2199999988079071
0.25
0.26999998092651367
0.28999999165534973
0.2199999988079071
0.2199999988079071
0.28999999165534973
0.14999999105930328
0.19999998807907104
0.17999999225139618
0.2199999988079071
0.17000000178813934
0.1899999976158142
0.23999999463558197
0.1899999976158142
0.2199999988079071
0.17000000178813934
0.1899999976158142
0.22999998927116394
0.25999999046325684
0.2199999988079071
0.1599999964237213
0.2800000011920929
0.25
0.17000000178813934
0.10999999940395355
0.1899999976158142
0.2199999988079071
0.23999999463558197
0.1599999964237213
0.2199999988079071
0.2199999988079071
0.19999998807907104
0.22999998927116394
0.14000000059604645
0.17000000178813934
0.19999998807907104
0.23999999463558197
0.14000000059604645
0.25
0.2199999988079071
0.17000000178813934
0.1599999964237213
0.17000000178813934
0.17000000178813934
0.1899999976158142
0.2199999988079071
0.19999998807907104
0.29999998211860657
0.23999999463558197
0.17999999225139618
0.23999

0.1899999976158142
0.17000000178813934
0.20999999344348907
0.14000000059604645
0.2199999988079071
0.17999999225139618
0.17999999225139618
0.11999999731779099
0.25
0.14999999105930328
0.2199999988079071
0.20999999344348907
0.20999999344348907
0.17999999225139618
0.17000000178813934
0.1899999976158142
0.2199999988079071
0.23999999463558197
0.17999999225139618
0.23999999463558197
0.1899999976158142
0.20999999344348907
0.14999999105930328
0.2199999988079071
0.1899999976158142
0.17000000178813934
0.17999999225139618
0.14999999105930328
0.1599999964237213
0.20999999344348907
0.17000000178813934
0.19999998807907104
0.2199999988079071
0.17999999225139618
0.20999999344348907
0.25
0.09999999403953552
0.14999999105930328
0.23999999463558197
0.11999999731779099
0.22999998927116394
0.19999998807907104
0.22999998927116394
0.17999999225139618
0.19999998807907104
0.19999998807907104
0.23999999463558197
0.20999999344348907
0.17999999225139618
0.2199999988079071
0.14999999105930328
0.17999999225139618
0

0.2199999988079071
0.22999998927116394
0.19999998807907104
0.1899999976158142
0.17000000178813934
0.25999999046325684
0.2199999988079071
0.2199999988079071
0.26999998092651367
0.2199999988079071
0.25
0.17000000178813934
0.23999999463558197
0.22999998927116394
0.25999999046325684
0.3100000023841858
0.2800000011920929
0.20999999344348907
0.14999999105930328
0.2199999988079071
0.23999999463558197
0.25
0.17999999225139618
0.17000000178813934
0.2199999988079071
0.1899999976158142
0.2199999988079071
0.2199999988079071
0.19999998807907104
0.17000000178813934
0.1599999964237213
0.2199999988079071
0.17000000178813934
0.17999999225139618
0.2199999988079071
0.25
0.12999999523162842
0.12999999523162842
0.14000000059604645
0.17999999225139618
0.23999999463558197
0.25999999046325684
0.17999999225139618
0.19999998807907104
0.14999999105930328
0.1899999976158142
0.25
0.17999999225139618
0.1599999964237213
0.17000000178813934
0.1899999976158142
0.19999998807907104
0.25999999046325684
0.1099999994039535

0.17999999225139618
0.23999999463558197
0.23999999463558197
0.25
0.19999998807907104
0.25999999046325684
0.19999998807907104
0.17999999225139618
0.22999998927116394
0.23999999463558197
0.22999998927116394
0.14000000059604645
0.1599999964237213
0.17999999225139618
0.19999998807907104
0.14999999105930328
0.2199999988079071
0.17000000178813934
0.2199999988079071
0.26999998092651367
0.1599999964237213
0.20999999344348907
0.19999998807907104
0.20999999344348907
0.14999999105930328
0.23999999463558197
0.1899999976158142
0.14999999105930328
0.2199999988079071
0.19999998807907104
0.1599999964237213
0.26999998092651367
0.20999999344348907
0.20999999344348907
0.20999999344348907
0.1599999964237213
0.25999999046325684
0.20999999344348907
0.19999998807907104
0.2199999988079071
0.2199999988079071
0.14000000059604645
0.25
0.19999998807907104
0.14999999105930328
0.22999998927116394
0.17999999225139618
0.23999999463558197
0.1899999976158142
0.1899999976158142
0.14999999105930328
0.1599999964237213
0.1

0.25999999046325684
0.2800000011920929
0.1899999976158142
0.1599999964237213
0.1899999976158142
0.22999998927116394
0.12999999523162842
0.17999999225139618
0.17000000178813934
0.19999998807907104
0.17000000178813934
0.23999999463558197
0.1899999976158142
0.14999999105930328
0.26999998092651367
0.22999998927116394
0.17999999225139618
0.1599999964237213
0.19999998807907104
0.07999999821186066
0.17000000178813934
0.1599999964237213
0.2199999988079071
0.2199999988079071
0.26999998092651367
0.17000000178813934
0.19999998807907104
0.2199999988079071
0.2199999988079071
0.19999998807907104
0.2199999988079071
0.17999999225139618
0.2199999988079071
0.2199999988079071
0.22999998927116394
0.1899999976158142
0.23999999463558197
0.14000000059604645
0.17999999225139618
0.2199999988079071
0.14000000059604645
0.2800000011920929
0.22999998927116394
0.23999999463558197
0.14000000059604645
0.2199999988079071
0.14999999105930328
0.1599999964237213
0.1899999976158142
0.2800000011920929
0.20999999344348907
0

0.17999999225139618
0.23999999463558197
0.25999999046325684
0.1599999964237213
0.2199999988079071
0.17999999225139618
0.11999999731779099
0.1599999964237213
0.2199999988079071
0.1599999964237213
0.1899999976158142
0.25
0.14999999105930328
0.14999999105930328
0.2199999988079071
0.20999999344348907
0.23999999463558197
0.17999999225139618
0.19999998807907104
0.17999999225139618
0.20999999344348907
0.14000000059604645
0.28999999165534973
0.1599999964237213
0.19999998807907104
0.19999998807907104
0.12999999523162842
0.20999999344348907
0.17999999225139618
0.14999999105930328
0.19999998807907104
0.17000000178813934
0.1599999964237213
0.1599999964237213
0.14000000059604645
0.25999999046325684
0.12999999523162842
0.22999998927116394
0.1599999964237213
0.23999999463558197
0.2199999988079071
0.1599999964237213
0.28999999165534973
0.22999998927116394
0.17999999225139618
0.17000000178813934
0.14000000059604645
0.26999998092651367
0.17999999225139618
0.20999999344348907
0.2199999988079071
0.1499999

0.12999999523162842
0.1899999976158142
0.1899999976158142
0.20999999344348907
0.19999998807907104
0.1599999964237213
0.1899999976158142
0.2199999988079071
0.1899999976158142
0.14999999105930328
0.23999999463558197
0.19999998807907104
0.17999999225139618
0.09999999403953552
0.20999999344348907
0.26999998092651367
0.25
0.17999999225139618
0.19999998807907104
0.14000000059604645
0.12999999523162842
0.1899999976158142
0.2199999988079071
0.22999998927116394
0.20999999344348907
0.12999999523162842
0.17000000178813934
0.20999999344348907
0.25
0.22999998927116394
0.17000000178813934
0.19999998807907104
0.23999999463558197
0.1899999976158142
0.22999998927116394
0.28999999165534973
0.19999998807907104
0.1599999964237213
0.2199999988079071
0.17000000178813934
0.22999998927116394
0.25
0.1899999976158142
0.14999999105930328
0.20999999344348907
0.14000000059604645
0.2199999988079071
0.2199999988079071
0.20999999344348907
0.20999999344348907
0.26999998092651367
0.1899999976158142
0.14999999105930328


0.17999999225139618
0.17999999225139618
0.19999998807907104
0.22999998927116394
0.17000000178813934
0.19999998807907104
0.1599999964237213
0.17000000178813934
0.19999998807907104
0.17000000178813934
0.17000000178813934
0.25
0.22999998927116394
0.1599999964237213
0.1899999976158142
0.2800000011920929
0.17999999225139618
0.1899999976158142
0.25999999046325684
0.17999999225139618
0.14999999105930328
0.17000000178813934
0.1899999976158142
0.1899999976158142
0.20999999344348907
0.1899999976158142
0.23999999463558197
0.17999999225139618
0.1599999964237213
0.20999999344348907
0.1899999976158142
0.19999998807907104
0.17000000178813934
0.17999999225139618
0.12999999523162842
0.2199999988079071
0.1599999964237213
0.17999999225139618
0.17000000178813934
0.19999998807907104
0.17000000178813934
0.1599999964237213
0.12999999523162842
0.1899999976158142
0.20999999344348907
0.2199999988079071
0.1899999976158142
0.2199999988079071
0.1599999964237213
0.1899999976158142
0.19999998807907104
0.280000001192

0.19999998807907104
0.20999999344348907
0.14999999105930328
0.19999998807907104
0.20999999344348907
0.32999998331069946
0.22999998927116394
0.1599999964237213
0.20999999344348907
0.11999999731779099
0.19999998807907104
0.19999998807907104
0.14999999105930328
0.25999999046325684
0.19999998807907104
0.1599999964237213
0.19999998807907104
0.19999998807907104
0.22999998927116394
0.09999999403953552
0.1599999964237213
0.19999998807907104
0.3100000023841858
0.17000000178813934
0.25
0.22999998927116394
0.23999999463558197
0.1599999964237213
0.20999999344348907
0.25
0.19999998807907104
0.1599999964237213
0.19999998807907104
0.1899999976158142
0.1899999976158142
0.26999998092651367
0.2199999988079071
0.14999999105930328
0.17000000178813934
0.14000000059604645
0.22999998927116394
0.22999998927116394
0.2199999988079071
0.1599999964237213
0.20999999344348907
0.1599999964237213
0.20999999344348907
0.1899999976158142
0.1899999976158142
0.20999999344348907
0.20999999344348907
0.19999998807907104
0.15

0.28999999165534973
0.17000000178813934
0.17000000178813934
0.17000000178813934
0.10999999940395355
0.14000000059604645
0.17000000178813934
0.20999999344348907
0.1599999964237213
0.09999999403953552
0.12999999523162842
0.14000000059604645
0.12999999523162842
0.20999999344348907
0.1599999964237213
0.20999999344348907
0.17999999225139618
0.29999998211860657
0.2199999988079071
0.23999999463558197
0.22999998927116394
0.19999998807907104
0.20999999344348907
0.25
0.23999999463558197
0.17000000178813934
0.19999998807907104
0.23999999463558197
0.22999998927116394
0.12999999523162842
0.17000000178813934
0.2199999988079071
0.1599999964237213
0.17000000178813934
0.1899999976158142
0.11999999731779099
0.25
0.17999999225139618
0.22999998927116394
0.14000000059604645
0.23999999463558197
0.19999998807907104
0.12999999523162842
0.14999999105930328
0.25
0.19999998807907104
0.17000000178813934
0.17000000178813934
0.17000000178813934
0.17000000178813934
0.19999998807907104
0.1599999964237213
0.1799999922

0.10999999940395355
0.22999998927116394
0.28999999165534973
0.17999999225139618
0.19999998807907104
0.23999999463558197
0.25
0.17999999225139618
0.10999999940395355
0.17999999225139618
0.19999998807907104
0.2199999988079071
0.25
0.17999999225139618
0.14999999105930328
0.14000000059604645
0.14999999105930328
0.17000000178813934
0.19999998807907104
0.2199999988079071
0.12999999523162842
0.14999999105930328
0.2199999988079071
0.20999999344348907
0.22999998927116394
0.1899999976158142
0.17000000178813934
0.2199999988079071
0.14000000059604645
0.2199999988079071
0.23999999463558197
0.14000000059604645
0.14000000059604645
0.14000000059604645
0.1899999976158142
0.17000000178813934
0.2199999988079071
0.17000000178813934
0.17999999225139618
0.14999999105930328
0.2199999988079071
0.25
0.20999999344348907
0.2199999988079071
0.19999998807907104
0.1899999976158142
0.10999999940395355
0.19999998807907104
0.17000000178813934
0.1899999976158142
0.14999999105930328
0.22999998927116394
0.179999992251396

0.14000000059604645
0.20999999344348907
0.23999999463558197
0.23999999463558197
0.14999999105930328
0.2199999988079071
0.2199999988079071
0.22999998927116394
0.09999999403953552
0.1599999964237213
0.14000000059604645
0.17999999225139618
0.17000000178813934
0.14999999105930328
0.1599999964237213
0.17999999225139618
0.2199999988079071
0.17999999225139618
0.25
0.17000000178813934
0.20999999344348907
0.1899999976158142
0.1599999964237213
0.1899999976158142
0.2199999988079071
0.2199999988079071
0.1599999964237213
0.17999999225139618
0.17000000178813934
0.2199999988079071
0.3100000023841858
0.23999999463558197
0.23999999463558197
0.2199999988079071
0.1599999964237213
0.25999999046325684
0.17999999225139618
0.17999999225139618
0.19999998807907104
0.25999999046325684
0.2199999988079071
0.17999999225139618
0.23999999463558197
0.1899999976158142
0.14000000059604645
0.11999999731779099
0.17000000178813934
0.19999998807907104
0.1899999976158142
0.22999998927116394
0.17999999225139618
0.25
0.259999

KeyboardInterrupt: 

In [3]:
a = torch.rand([4])
a

tensor([0.4737, 0.5515, 0.9339, 0.9452])

In [4]:
a.tanh()
a

tensor([0.4737, 0.5515, 0.9339, 0.9452])

In [3]:
c = torch.rand([5, 5, 5])
c

tensor([[[0.8950, 0.5087, 0.9918, 0.3349, 0.9168],
         [0.3521, 0.1771, 0.9841, 0.2360, 0.0987],
         [0.2457, 0.6453, 0.0882, 0.6625, 0.8833],
         [0.3182, 0.4032, 0.7740, 0.2487, 0.7766],
         [0.4535, 0.3391, 0.6288, 0.9605, 0.6645]],

        [[0.1674, 0.4566, 0.2110, 0.6609, 0.2833],
         [0.9507, 0.1512, 0.3692, 0.4357, 0.5276],
         [0.7551, 0.3018, 0.4106, 0.7149, 0.6627],
         [0.2312, 0.0712, 0.8497, 0.8244, 0.9473],
         [0.4749, 0.4027, 0.2531, 0.7649, 0.8583]],

        [[0.9508, 0.3054, 0.8421, 0.1475, 0.0499],
         [0.9176, 0.4006, 0.2401, 0.5482, 0.9270],
         [0.8670, 0.4944, 0.1959, 0.0772, 0.1017],
         [0.7348, 0.2407, 0.1236, 0.1750, 0.9848],
         [0.6862, 0.1035, 0.6500, 0.9341, 0.3612]],

        [[0.1083, 0.9244, 0.6368, 0.9206, 0.0429],
         [0.8638, 0.9787, 0.6948, 0.5904, 0.7577],
         [0.3385, 0.1159, 0.5848, 0.3209, 0.6606],
         [0.7108, 0.0280, 0.8824, 0.4495, 0.7364],
         [0.7790, 0.8497,

In [8]:
c[:, -1, :].max(1, keepdim = True)[1].size()

torch.Size([5, 1])

In [4]:
c.scatter(-1, torch.tensor([0, 1]).unsqueeze(0).unsqueeze(0).expand(5, 5, -1), 0)

tensor([[[0.0000, 0.0000, 0.9918, 0.3349, 0.9168],
         [0.0000, 0.0000, 0.9841, 0.2360, 0.0987],
         [0.0000, 0.0000, 0.0882, 0.6625, 0.8833],
         [0.0000, 0.0000, 0.7740, 0.2487, 0.7766],
         [0.0000, 0.0000, 0.6288, 0.9605, 0.6645]],

        [[0.0000, 0.0000, 0.2110, 0.6609, 0.2833],
         [0.0000, 0.0000, 0.3692, 0.4357, 0.5276],
         [0.0000, 0.0000, 0.4106, 0.7149, 0.6627],
         [0.0000, 0.0000, 0.8497, 0.8244, 0.9473],
         [0.0000, 0.0000, 0.2531, 0.7649, 0.8583]],

        [[0.0000, 0.0000, 0.8421, 0.1475, 0.0499],
         [0.0000, 0.0000, 0.2401, 0.5482, 0.9270],
         [0.0000, 0.0000, 0.1959, 0.0772, 0.1017],
         [0.0000, 0.0000, 0.1236, 0.1750, 0.9848],
         [0.0000, 0.0000, 0.6500, 0.9341, 0.3612]],

        [[0.0000, 0.0000, 0.6368, 0.9206, 0.0429],
         [0.0000, 0.0000, 0.6948, 0.5904, 0.7577],
         [0.0000, 0.0000, 0.5848, 0.3209, 0.6606],
         [0.0000, 0.0000, 0.8824, 0.4495, 0.7364],
         [0.0000, 0.0000,

In [1]:
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributions
from collections import namedtuple
from itertools import count

device = "cuda"


class environment():
    def reset(self, npoints, batchsize, nsamples):
        if npoints <= 3:
            print("Error: not enough points for valid problem instance")
            return
        self.batchsize = batchsize * nsamples #so that I don't have to rewrite all this code, we store these two dimensions together
        self.nsamples = nsamples
        self.npoints = npoints
        self.points = torch.rand([batchsize, npoints - 3, 2], device = device).unsqueeze(1).expand(-1, nsamples, -1, -1).reshape(self.batchsize, npoints - 3, 2)
        self.corner_points = torch.tensor([[0, 0], [2, 0], [0, 2]], dtype = torch.float, device = device)
        self.points = torch.cat([self.corner_points.unsqueeze(0).expand(self.batchsize, -1, -1), self.points], dim = -2) #[batchsize * nsamples, npoints, 2]
        self.points_mask = torch.cat([torch.ones([self.batchsize, 3], dtype = torch.bool, device = device), torch.zeros([self.batchsize, npoints - 3], dtype = torch.bool, device = device)], dim = 1)
        self.points_sequence = torch.empty([self.batchsize, 0], dtype = torch.long, device = device)
        
        """use a trick, for the purpose of an 'external' triangle that is always left untouched, which means we don't have to deal with boundary edges as being different. external triangle is [0, 1, 2] traversed clockwise..."""
        self.partial_delaunay_triangles = torch.tensor([[0, 2, 1], [0, 1, 2]], dtype = torch.int64, device = device).unsqueeze(0).expand(self.batchsize, -1, -1).contiguous() #[batchsize, ntriangles, 3] contains index of points, always anticlockwise
        self.partial_delaunay_edges = torch.tensor([5, 4, 3, 2, 1, 0], dtype = torch.int64, device = device).unsqueeze(0).expand(self.batchsize, -1).contiguous() #[batchsize, ntriangles * 3] contains location of corresponding edge (edges go in order 01, 12, 20). Edges will always flip since triangles are stored anticlockwise.
        
        self.ntriangles = 2 #can store as scalar, since will always be the same
        self.cost = torch.zeros([self.batchsize], device = device)
    
    def update(self, point_index): #point_index is [batchsize]
        if point_index.size(0) != self.batchsize:
            print("Error: point_index.size() doesn't match expected size, should be [batchsize]")
            return
        if self.points_mask.gather(1, point_index.unsqueeze(1)).sum():
            print("Error: some points already added")
            return
        triangles_coordinates = self.points.gather(1, self.partial_delaunay_triangles.view(self.batchsize, self.ntriangles * 3).unsqueeze(2).expand(-1, -1, 2)).view(self.batchsize, self.ntriangles, 3, 2) #[batchsize, ntriangles, 3, 2]
        newpoint = self.points.gather(1, point_index.unsqueeze(1).unsqueeze(2).expand(self.batchsize, 1, 2)).squeeze(1) #[batchsize, 2]
        
        incircle_matrix = torch.cat([triangles_coordinates, newpoint.unsqueeze(1).unsqueeze(2).expand(-1, self.ntriangles, 1, -1)], dim = -2) #[batchsize, ntriangles, 4, 2]
        incircle_matrix = torch.cat([incircle_matrix, (incircle_matrix * incircle_matrix).sum(-1, keepdim = True), torch.ones([self.batchsize, self.ntriangles, 4, 1], device = device)], dim = -1) #[batchsize, ntriangles, 4, 4]
        incircle_test = incircle_matrix.det() > 0 #[batchsize, ntriangles], is True if inside incircle
        removed_edge_mask = incircle_test.unsqueeze(2).expand(-1, -1, 3).reshape(-1) #[batchsize * ntriangles * 3]
        
        edges = (self.partial_delaunay_edges + self.ntriangles * 3 * torch.arange(self.batchsize, device = device).unsqueeze(1)).view(-1) #[batchsize * ntriangles * 3]
        neighbouring_edge = edges.masked_select(removed_edge_mask)
        neighbouring_edge_mask = torch.zeros([self.batchsize * self.ntriangles * 3], device = device, dtype = torch.bool)
        neighbouring_edge_mask[neighbouring_edge] = True
        neighbouring_edge_mask = (neighbouring_edge_mask * removed_edge_mask.logical_not()) #[batchsize * ntriangles * 3]
        
        n_new_triangles = neighbouring_edge_mask.view(self.batchsize, -1).sum(-1) #[batchsize]
        
        new_point = point_index.unsqueeze(1).expand(-1, self.ntriangles * 3).masked_select(neighbouring_edge_mask.view(self.batchsize, -1))
        
        second_point_mask = neighbouring_edge_mask.view(self.batchsize, -1, 3) #[batchsize, ntriangles 3]
        (first_point_indices0, first_point_indices1, first_point_indices2) = second_point_mask.nonzero(as_tuple = True)
        first_point_indices2 = (first_point_indices2 != 2) * (first_point_indices2 + 1)
        
        first_point = self.partial_delaunay_triangles[first_point_indices0, first_point_indices1, first_point_indices2] #[?]
        second_point = self.partial_delaunay_triangles.masked_select(second_point_mask) #[?]
        
        new_triangles_mask = torch.cat([incircle_test, torch.ones([self.batchsize, 2], dtype = torch.bool, device = device)], dim = 1) #[batchsize, ntriangles + 2]
        
        new_neighbouring_edges = 3 * new_triangles_mask.nonzero(as_tuple = True)[1] #[?], 3* since is the 01 edge of new triangles (see later)
        self.partial_delaunay_edges.masked_scatter_(neighbouring_edge_mask.view(self.batchsize, -1), new_neighbouring_edges) #still [batchsize, ntriangles * 3] for now
        
        self.partial_delaunay_triangles = torch.cat([self.partial_delaunay_triangles, torch.empty([self.batchsize, 2, 3], dtype = torch.long, device = device)], dim = 1)
        self.partial_delaunay_edges = torch.cat([self.partial_delaunay_edges, torch.empty([self.batchsize, 6], dtype = torch.long, device = device)], dim = 1)
        new_triangles = torch.stack([first_point, second_point, new_point], dim = 1) #[?, 3], edge here is flipped compared to edge in neighbouring triangle (so first_point is the second point in neighbouring edge)
        self.partial_delaunay_triangles.masked_scatter_(new_triangles_mask.unsqueeze(2).expand(-1, -1, 3), new_triangles) #[batchsize, ntriangles + 2, 3]
        
        new_edge01 = neighbouring_edge_mask.view(self.batchsize, -1).nonzero(as_tuple = True)[1] #[?]
        
        """we are currently storing which triangles have to be inserted, via the edges along the perimeter of the delaunay cavity, we need to compute which edge is to the 'left'/'right' of each edge"""
        """don't have the memory to do a batchsize * n * n boolean search, don't have the speed to do a batchsize^2 search (as would be the case for sparse matrix or similar)"""
        """best alternative: rotate the edge around right point, repeat until hit edge in mask (will never go to an edge of a removed triangle before we hit edge in mask) should basically be order 1!!!!!"""
        
        neighbouring_edge_index = neighbouring_edge_mask.nonzero(as_tuple = True)[0] #[?]
        next_neighbouring_edge_index = torch.empty_like(neighbouring_edge_index) #[?]
        
        rotating_flipped_neighbouring_edge_index = neighbouring_edge_mask.nonzero(as_tuple = True)[0] #[?], initialise
        todo_mask = torch.ones_like(next_neighbouring_edge_index, dtype = torch.bool) #[?]
        while todo_mask.sum():
            rotating_neighbouring_edge_index = rotating_flipped_neighbouring_edge_index + 1 - 3 * (rotating_flipped_neighbouring_edge_index % 3 == 2) #[todo_mask.sum()], gets smaller until nothing left EFFICIENCY (this may be seriously stupid, as it requires making a bunch of copies when I could be doing stuff inplace)
            
            update_mask = neighbouring_edge_mask[rotating_neighbouring_edge_index] #[todo_mask.sum()]
            update_mask_unravel = torch.zeros_like(todo_mask).masked_scatter(todo_mask, update_mask) #[?]
            
            next_neighbouring_edge_index.masked_scatter_(update_mask_unravel, rotating_neighbouring_edge_index.masked_select(update_mask)) #[?]
            
            todo_mask.masked_fill_(update_mask_unravel, False) #[?]
            rotating_flipped_neighbouring_edge_index = edges[rotating_neighbouring_edge_index.masked_select(update_mask.logical_not())] #[todo_mask.sum()]
        triangle_index = new_triangles_mask.view(-1).nonzero(as_tuple = True)[0] #[?], index goes up to batchsize * (ntriangles + 2), this is needed for when we invert the permutation by scattering (won't scatter same number of triangles per batch)
        
        next_triangle_index = torch.empty_like(edges).masked_scatter_(neighbouring_edge_mask, triangle_index)[next_neighbouring_edge_index] #[?], index goes up to batchsize * (ntriangles + 2)
        next_edge = 3 * next_triangle_index + 1 #[?]
        
        invert_permutation = torch.empty_like(new_triangles_mask.view(-1), dtype=torch.long) #[batchsize * (ntriangles + 2)]
        invert_permutation[next_triangle_index] = triangle_index #[batchsize * (ntriangles + 2)]
        previous_triangle_index = invert_permutation.masked_select(new_triangles_mask.view(-1)) #[?]
        previous_edge = 3 * previous_triangle_index + 2 #[?]
        
        """in the above we rotated around 'first_point' in our new triangles"""
        new_edge20 = next_edge % ((self.ntriangles + 2) * 3) #[?]
        new_edge12 = previous_edge % ((self.ntriangles + 2) * 3) #[?]
        
        new_edges = torch.stack([new_edge01, new_edge12, new_edge20], dim = 1) #[?, 3]
        self.partial_delaunay_edges.masked_scatter_(new_triangles_mask.unsqueeze(2).expand(-1, -1, 3).reshape(self.batchsize, -1), new_edges) #[batchsize, (ntriangles + 2) * 3]
        
        self.ntriangles += 2
        """currently only count the extra triangles you replace (not the on you have to remove because you're located there, and not the ones you make because you have to create two more"""
        self.cost += (n_new_triangles - 3)
        self.points_mask.scatter_(1, point_index.unsqueeze(1).expand(-1, self.npoints), True)
        self.points_sequence = torch.cat([self.points_sequence, point_index.unsqueeze(1)], dim = 1)
    
    def allindices(self): #generate all orders of point insertion
        npoints = self.npoints - 3
        allroutes = torch.empty([1, 0], dtype = torch.long, device = device)
        for i in range(npoints):
            nroutes = allroutes.size(0)
            remaining_mask = torch.ones([nroutes], dtype = torch.bool, device = device).unsqueeze(1).expand(-1, npoints).clone().scatter_(-1, allroutes, False)
            remaining_indices = remaining_mask.nonzero(as_tuple = True)[1]
            allroutes = allroutes.unsqueeze(1).expand(-1, remaining_mask[0, :].sum(), -1)
            allroutes = torch.cat([allroutes, remaining_indices.view(nroutes, -1).unsqueeze(2)], dim = -1).view(-1, allroutes.size(-1) + 1)
        return allroutes #[npoints!, npoints]


env = environment()
npoints = 8
batchsize = 7
env.reset(npoints + 3, batchsize, math.factorial(npoints))
allroutes = env.allindices() + 3
allroutes = allroutes.unsqueeze(0).expand(batchsize, -1, -1).reshape(-1, npoints)
for j in range(300):
    for i in range(npoints):
        env.update(allroutes[:, i])
    print(env.cost.view(batchsize, -1).min(-1)[0].mean().item())
    env.reset(npoints + 3, batchsize, math.factorial(npoints))

5.142857551574707
4.714285850524902
3.857142925262451
4.285714626312256
4.5714287757873535
4.428571701049805
4.428571701049805
4.428571701049805
4.142857551574707
4.0
4.142857551574707
3.000000238418579
4.857142925262451
4.428571701049805
4.5714287757873535
4.5714287757873535
4.428571701049805
4.428571701049805
4.714285850524902
4.5714287757873535
4.142857551574707
3.7142858505249023
4.0
4.142857551574707
4.714285850524902
4.5714287757873535
4.428571701049805
4.285714626312256
4.0
4.285714626312256
4.428571701049805
4.714285850524902
4.285714626312256
5.0
4.5714287757873535
4.285714626312256
4.5714287757873535
4.857142925262451
4.142857551574707
4.285714626312256
4.0
4.142857551574707
4.0
4.428571701049805
4.428571701049805
4.285714626312256
4.142857551574707
4.142857551574707
4.142857551574707
4.5714287757873535
4.142857551574707
4.285714626312256
3.5714287757873535
4.285714626312256
4.0
4.285714626312256
4.428571701049805
4.5714287757873535
4.857142925262451
3.7142858505249023
4.2857

In [4]:
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributions
from collections import namedtuple
from itertools import count

device = "cuda"


class environment():
    def reset(self, npoints, batchsize, nsamples):
        if npoints <= 3:
            print("Error: not enough points for valid problem instance")
            return
        self.batchsize = batchsize * nsamples #so that I don't have to rewrite all this code, we store these two dimensions together
        self.nsamples = nsamples
        self.npoints = npoints
        self.points = torch.rand([batchsize, npoints - 3, 2], device = device).unsqueeze(1).expand(-1, nsamples, -1, -1).reshape(self.batchsize, npoints - 3, 2)
        self.corner_points = torch.tensor([[0, 0], [2, 0], [0, 2]], dtype = torch.float, device = device)
        self.points = torch.cat([self.corner_points.unsqueeze(0).expand(self.batchsize, -1, -1), self.points], dim = -2) #[batchsize * nsamples, npoints, 2]
        self.points_mask = torch.cat([torch.ones([self.batchsize, 3], dtype = torch.bool, device = device), torch.zeros([self.batchsize, npoints - 3], dtype = torch.bool, device = device)], dim = 1)
        self.points_sequence = torch.empty([self.batchsize, 0], dtype = torch.long, device = device)
        
        """use a trick, for the purpose of an 'external' triangle that is always left untouched, which means we don't have to deal with boundary edges as being different. external triangle is [0, 1, 2] traversed clockwise..."""
        self.partial_delaunay_triangles = torch.tensor([[0, 2, 1], [0, 1, 2]], dtype = torch.int64, device = device).unsqueeze(0).expand(self.batchsize, -1, -1).contiguous() #[batchsize, ntriangles, 3] contains index of points, always anticlockwise
        self.partial_delaunay_edges = torch.tensor([5, 4, 3, 2, 1, 0], dtype = torch.int64, device = device).unsqueeze(0).expand(self.batchsize, -1).contiguous() #[batchsize, ntriangles * 3] contains location of corresponding edge (edges go in order 01, 12, 20). Edges will always flip since triangles are stored anticlockwise.
        
        self.ntriangles = 2 #can store as scalar, since will always be the same
        self.cost = torch.zeros([self.batchsize], device = device)
    
    def update(self, point_index): #point_index is [batchsize]
        if point_index.size(0) != self.batchsize:
            print("Error: point_index.size() doesn't match expected size, should be [batchsize]")
            return
        if self.points_mask.gather(1, point_index.unsqueeze(1)).sum():
            print("Error: some points already added")
            return
        triangles_coordinates = self.points.gather(1, self.partial_delaunay_triangles.view(self.batchsize, self.ntriangles * 3).unsqueeze(2).expand(-1, -1, 2)).view(self.batchsize, self.ntriangles, 3, 2) #[batchsize, ntriangles, 3, 2]
        newpoint = self.points.gather(1, point_index.unsqueeze(1).unsqueeze(2).expand(self.batchsize, 1, 2)).squeeze(1) #[batchsize, 2]
        
        incircle_matrix = torch.cat([triangles_coordinates, newpoint.unsqueeze(1).unsqueeze(2).expand(-1, self.ntriangles, 1, -1)], dim = -2) #[batchsize, ntriangles, 4, 2]
        incircle_matrix = torch.cat([incircle_matrix, (incircle_matrix * incircle_matrix).sum(-1, keepdim = True), torch.ones([self.batchsize, self.ntriangles, 4, 1], device = device)], dim = -1) #[batchsize, ntriangles, 4, 4]
        incircle_test = incircle_matrix.det() > 0 #[batchsize, ntriangles], is True if inside incircle
        removed_edge_mask = incircle_test.unsqueeze(2).expand(-1, -1, 3).reshape(-1) #[batchsize * ntriangles * 3]
        
        edges = (self.partial_delaunay_edges + self.ntriangles * 3 * torch.arange(self.batchsize, device = device).unsqueeze(1)).view(-1) #[batchsize * ntriangles * 3]
        neighbouring_edge = edges.masked_select(removed_edge_mask)
        neighbouring_edge_mask = torch.zeros([self.batchsize * self.ntriangles * 3], device = device, dtype = torch.bool)
        neighbouring_edge_mask[neighbouring_edge] = True
        neighbouring_edge_mask = (neighbouring_edge_mask * removed_edge_mask.logical_not()) #[batchsize * ntriangles * 3]
        
        n_new_triangles = neighbouring_edge_mask.view(self.batchsize, -1).sum(-1) #[batchsize]
        
        new_point = point_index.unsqueeze(1).expand(-1, self.ntriangles * 3).masked_select(neighbouring_edge_mask.view(self.batchsize, -1))
        
        second_point_mask = neighbouring_edge_mask.view(self.batchsize, -1, 3) #[batchsize, ntriangles 3]
        (first_point_indices0, first_point_indices1, first_point_indices2) = second_point_mask.nonzero(as_tuple = True)
        first_point_indices2 = (first_point_indices2 != 2) * (first_point_indices2 + 1)
        
        first_point = self.partial_delaunay_triangles[first_point_indices0, first_point_indices1, first_point_indices2] #[?]
        second_point = self.partial_delaunay_triangles.masked_select(second_point_mask) #[?]
        
        new_triangles_mask = torch.cat([incircle_test, torch.ones([self.batchsize, 2], dtype = torch.bool, device = device)], dim = 1) #[batchsize, ntriangles + 2]
        
        new_neighbouring_edges = 3 * new_triangles_mask.nonzero(as_tuple = True)[1] #[?], 3* since is the 01 edge of new triangles (see later)
        self.partial_delaunay_edges.masked_scatter_(neighbouring_edge_mask.view(self.batchsize, -1), new_neighbouring_edges) #still [batchsize, ntriangles * 3] for now
        
        self.partial_delaunay_triangles = torch.cat([self.partial_delaunay_triangles, torch.empty([self.batchsize, 2, 3], dtype = torch.long, device = device)], dim = 1)
        self.partial_delaunay_edges = torch.cat([self.partial_delaunay_edges, torch.empty([self.batchsize, 6], dtype = torch.long, device = device)], dim = 1)
        new_triangles = torch.stack([first_point, second_point, new_point], dim = 1) #[?, 3], edge here is flipped compared to edge in neighbouring triangle (so first_point is the second point in neighbouring edge)
        self.partial_delaunay_triangles.masked_scatter_(new_triangles_mask.unsqueeze(2).expand(-1, -1, 3), new_triangles) #[batchsize, ntriangles + 2, 3]
        
        new_edge01 = neighbouring_edge_mask.view(self.batchsize, -1).nonzero(as_tuple = True)[1] #[?]
        
        """we are currently storing which triangles have to be inserted, via the edges along the perimeter of the delaunay cavity, we need to compute which edge is to the 'left'/'right' of each edge"""
        """don't have the memory to do a batchsize * n * n boolean search, don't have the speed to do a batchsize^2 search (as would be the case for sparse matrix or similar)"""
        """best alternative: rotate the edge around right point, repeat until hit edge in mask (will never go to an edge of a removed triangle before we hit edge in mask) should basically be order 1!!!!!"""
        
        neighbouring_edge_index = neighbouring_edge_mask.nonzero(as_tuple = True)[0] #[?]
        next_neighbouring_edge_index = torch.empty_like(neighbouring_edge_index) #[?]
        
        rotating_flipped_neighbouring_edge_index = neighbouring_edge_mask.nonzero(as_tuple = True)[0] #[?], initialise
        todo_mask = torch.ones_like(next_neighbouring_edge_index, dtype = torch.bool) #[?]
        while todo_mask.sum():
            rotating_neighbouring_edge_index = rotating_flipped_neighbouring_edge_index + 1 - 3 * (rotating_flipped_neighbouring_edge_index % 3 == 2) #[todo_mask.sum()], gets smaller until nothing left EFFICIENCY (this may be seriously stupid, as it requires making a bunch of copies when I could be doing stuff inplace)
            
            update_mask = neighbouring_edge_mask[rotating_neighbouring_edge_index] #[todo_mask.sum()]
            update_mask_unravel = torch.zeros_like(todo_mask).masked_scatter(todo_mask, update_mask) #[?]
            
            next_neighbouring_edge_index.masked_scatter_(update_mask_unravel, rotating_neighbouring_edge_index.masked_select(update_mask)) #[?]
            
            todo_mask.masked_fill_(update_mask_unravel, False) #[?]
            rotating_flipped_neighbouring_edge_index = edges[rotating_neighbouring_edge_index.masked_select(update_mask.logical_not())] #[todo_mask.sum()]
        triangle_index = new_triangles_mask.view(-1).nonzero(as_tuple = True)[0] #[?], index goes up to batchsize * (ntriangles + 2), this is needed for when we invert the permutation by scattering (won't scatter same number of triangles per batch)
        
        next_triangle_index = torch.empty_like(edges).masked_scatter_(neighbouring_edge_mask, triangle_index)[next_neighbouring_edge_index] #[?], index goes up to batchsize * (ntriangles + 2)
        next_edge = 3 * next_triangle_index + 1 #[?]
        
        invert_permutation = torch.empty_like(new_triangles_mask.view(-1), dtype=torch.long) #[batchsize * (ntriangles + 2)]
        invert_permutation[next_triangle_index] = triangle_index #[batchsize * (ntriangles + 2)]
        previous_triangle_index = invert_permutation.masked_select(new_triangles_mask.view(-1)) #[?]
        previous_edge = 3 * previous_triangle_index + 2 #[?]
        
        """in the above we rotated around 'first_point' in our new triangles"""
        new_edge20 = next_edge % ((self.ntriangles + 2) * 3) #[?]
        new_edge12 = previous_edge % ((self.ntriangles + 2) * 3) #[?]
        
        new_edges = torch.stack([new_edge01, new_edge12, new_edge20], dim = 1) #[?, 3]
        self.partial_delaunay_edges.masked_scatter_(new_triangles_mask.unsqueeze(2).expand(-1, -1, 3).reshape(self.batchsize, -1), new_edges) #[batchsize, (ntriangles + 2) * 3]
        
        self.ntriangles += 2
        """currently only count the extra triangles you replace (not the on you have to remove because you're located there, and not the ones you make because you have to create two more"""
        self.cost += (n_new_triangles - 3)
        self.points_mask.scatter_(1, point_index.unsqueeze(1).expand(-1, self.npoints), True)
        self.points_sequence = torch.cat([self.points_sequence, point_index.unsqueeze(1)], dim = 1)
    
    def allindices(self): #generate all orders of point insertion
        npoints = self.npoints - 3
        allroutes = torch.empty([1, 0], dtype = torch.long, device = device)
        for i in range(npoints):
            nroutes = allroutes.size(0)
            remaining_mask = torch.ones([nroutes], dtype = torch.bool, device = device).unsqueeze(1).expand(-1, npoints).clone().scatter_(-1, allroutes, False)
            remaining_indices = remaining_mask.nonzero(as_tuple = True)[1]
            allroutes = allroutes.unsqueeze(1).expand(-1, remaining_mask[0, :].sum(), -1)
            allroutes = torch.cat([allroutes, remaining_indices.view(nroutes, -1).unsqueeze(2)], dim = -1).view(-1, allroutes.size(-1) + 1)
        return allroutes #[npoints!, npoints]


env = environment()
npoints = 10
minibatchsize = 1
nbatches = 1
env.reset(npoints + 3, minibatchsize, math.factorial(npoints))
allroutes = env.allindices() + 3
allroutes = allroutes.unsqueeze(0).expand(minibatchsize, -1, -1).reshape(-1, npoints)
for j in range(300):
    for i in range(npoints):
        env.update(allroutes[:, i])
    print(env.cost.view(batchsize, -1).min(-1)[0].mean().item(), env.cost.mean().item())
    env.reset(npoints + 3, minibatchsize, math.factorial(npoints))

4.857142925262451 9.414966583251953
4.0 9.334183692932129
4.5714287757873535 9.08231258392334
4.0 9.242517471313477
4.428571701049805 9.127891540527344
3.7142858505249023 9.360713958740234
4.5714287757873535 9.433503150939941
4.142857551574707 9.15799331665039
4.857142925262451 9.376020431518555
4.5714287757873535 9.37704086303711
4.714285850524902 9.245918273925781
4.285714626312256 9.0469388961792
4.0 9.293367385864258
4.428571701049805 9.130271911621094
4.5714287757873535 9.245068550109863
4.714285850524902 9.286395072937012
3.7142858505249023 9.128401756286621
4.0 9.243537902832031
4.142857551574707 9.222108840942383
4.5714287757873535 9.41972827911377
3.4285717010498047 9.317007064819336
4.5714287757873535 9.14234733581543
4.714285850524902 9.497109413146973
4.285714626312256 9.22465991973877
4.857142925262451 9.20272159576416
4.428571701049805 9.261734962463379
4.285714626312256 9.303231239318848
4.0 9.463775634765625
4.285714626312256 9.194387435913086
4.0 9.091496467590332
4.42

4.0 9.26547622680664
4.0 9.28282356262207
4.5714287757873535 9.133163452148438
4.142857551574707 9.360203742980957
4.5714287757873535 9.146939277648926
4.857142925262451 9.324830055236816
4.714285850524902 9.2268705368042
4.428571701049805 9.190646171569824
4.0 9.347959518432617
4.285714626312256 9.479251861572266
3.7142858505249023 9.297449111938477
4.5714287757873535 9.240816116333008
4.0 9.330612182617188
4.0 9.251871109008789
4.5714287757873535 9.435033798217773
4.714285850524902 9.234354019165039
4.285714626312256 9.530101776123047
4.142857551574707 9.233674049377441
4.285714626312256 9.068197250366211
4.714285850524902 9.588775634765625
4.857142925262451 9.303231239318848
4.5714287757873535 9.308333396911621
4.714285850524902 9.427891731262207
4.0 9.187755584716797
4.0 9.069047927856445
4.0 9.324999809265137
4.428571701049805 9.345748901367188
4.5714287757873535 9.167346954345703
4.0 9.3869047164917
4.142857551574707 9.063776016235352
4.5714287757873535 9.239115715026855
4.0 9.12

In [None]:
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributions
from collections import namedtuple
from itertools import count

device = "cpu"


class environment():
    def reset(self, npoints, batchsize, nsamples):
        if npoints <= 3:
            print("Error: not enough points for valid problem instance")
            return
        self.batchsize = batchsize * nsamples #so that I don't have to rewrite all this code, we store these two dimensions together
        self.nsamples = nsamples
        self.npoints = npoints
        self.points = torch.rand([batchsize, npoints - 3, 2], device = device).unsqueeze(1).expand(-1, nsamples, -1, -1).reshape(self.batchsize, npoints - 3, 2)
        self.corner_points = torch.tensor([[0, 0], [2, 0], [0, 2]], dtype = torch.float, device = device)
        self.points = torch.cat([self.corner_points.unsqueeze(0).expand(self.batchsize, -1, -1), self.points], dim = -2) #[batchsize * nsamples, npoints, 2]
        self.points_mask = torch.cat([torch.ones([self.batchsize, 3], dtype = torch.bool, device = device), torch.zeros([self.batchsize, npoints - 3], dtype = torch.bool, device = device)], dim = 1)
        self.points_sequence = torch.empty([self.batchsize, 0], dtype = torch.long, device = device)
        
        """use a trick, for the purpose of an 'external' triangle that is always left untouched, which means we don't have to deal with boundary edges as being different. external triangle is [0, 1, 2] traversed clockwise..."""
        self.partial_delaunay_triangles = torch.tensor([[0, 2, 1], [0, 1, 2]], dtype = torch.int64, device = device).unsqueeze(0).expand(self.batchsize, -1, -1).contiguous() #[batchsize, ntriangles, 3] contains index of points, always anticlockwise
        self.partial_delaunay_edges = torch.tensor([5, 4, 3, 2, 1, 0], dtype = torch.int64, device = device).unsqueeze(0).expand(self.batchsize, -1).contiguous() #[batchsize, ntriangles * 3] contains location of corresponding edge (edges go in order 01, 12, 20). Edges will always flip since triangles are stored anticlockwise.
        
        self.ntriangles = 2 #can store as scalar, since will always be the same
        self.cost = torch.zeros([self.batchsize], device = device)
    
    def update(self, point_index): #point_index is [batchsize]
        if point_index.size(0) != self.batchsize:
            print("Error: point_index.size() doesn't match expected size, should be [batchsize]")
            return
        if self.points_mask.gather(1, point_index.unsqueeze(1)).sum():
            print("Error: some points already added")
            return
        triangles_coordinates = self.points.gather(1, self.partial_delaunay_triangles.view(self.batchsize, self.ntriangles * 3).unsqueeze(2).expand(-1, -1, 2)).view(self.batchsize, self.ntriangles, 3, 2) #[batchsize, ntriangles, 3, 2]
        newpoint = self.points.gather(1, point_index.unsqueeze(1).unsqueeze(2).expand(self.batchsize, 1, 2)).squeeze(1) #[batchsize, 2]
        
        incircle_matrix = torch.cat([triangles_coordinates, newpoint.unsqueeze(1).unsqueeze(2).expand(-1, self.ntriangles, 1, -1)], dim = -2) #[batchsize, ntriangles, 4, 2]
        incircle_matrix = torch.cat([incircle_matrix, (incircle_matrix * incircle_matrix).sum(-1, keepdim = True), torch.ones([self.batchsize, self.ntriangles, 4, 1], device = device)], dim = -1) #[batchsize, ntriangles, 4, 4]
        incircle_test = incircle_matrix.det() > 0 #[batchsize, ntriangles], is True if inside incircle
        removed_edge_mask = incircle_test.unsqueeze(2).expand(-1, -1, 3).reshape(-1) #[batchsize * ntriangles * 3]
        
        edges = (self.partial_delaunay_edges + self.ntriangles * 3 * torch.arange(self.batchsize, device = device).unsqueeze(1)).view(-1) #[batchsize * ntriangles * 3]
        neighbouring_edge = edges.masked_select(removed_edge_mask)
        neighbouring_edge_mask = torch.zeros([self.batchsize * self.ntriangles * 3], device = device, dtype = torch.bool)
        neighbouring_edge_mask[neighbouring_edge] = True
        neighbouring_edge_mask = (neighbouring_edge_mask * removed_edge_mask.logical_not()) #[batchsize * ntriangles * 3]
        
        n_new_triangles = neighbouring_edge_mask.view(self.batchsize, -1).sum(-1) #[batchsize]
        
        new_point = point_index.unsqueeze(1).expand(-1, self.ntriangles * 3).masked_select(neighbouring_edge_mask.view(self.batchsize, -1))
        
        second_point_mask = neighbouring_edge_mask.view(self.batchsize, -1, 3) #[batchsize, ntriangles 3]
        (first_point_indices0, first_point_indices1, first_point_indices2) = second_point_mask.nonzero(as_tuple = True)
        first_point_indices2 = (first_point_indices2 != 2) * (first_point_indices2 + 1)
        
        first_point = self.partial_delaunay_triangles[first_point_indices0, first_point_indices1, first_point_indices2] #[?]
        second_point = self.partial_delaunay_triangles.masked_select(second_point_mask) #[?]
        
        new_triangles_mask = torch.cat([incircle_test, torch.ones([self.batchsize, 2], dtype = torch.bool, device = device)], dim = 1) #[batchsize, ntriangles + 2]
        
        new_neighbouring_edges = 3 * new_triangles_mask.nonzero(as_tuple = True)[1] #[?], 3* since is the 01 edge of new triangles (see later)
        self.partial_delaunay_edges.masked_scatter_(neighbouring_edge_mask.view(self.batchsize, -1), new_neighbouring_edges) #still [batchsize, ntriangles * 3] for now
        
        self.partial_delaunay_triangles = torch.cat([self.partial_delaunay_triangles, torch.empty([self.batchsize, 2, 3], dtype = torch.long, device = device)], dim = 1)
        self.partial_delaunay_edges = torch.cat([self.partial_delaunay_edges, torch.empty([self.batchsize, 6], dtype = torch.long, device = device)], dim = 1)
        new_triangles = torch.stack([first_point, second_point, new_point], dim = 1) #[?, 3], edge here is flipped compared to edge in neighbouring triangle (so first_point is the second point in neighbouring edge)
        self.partial_delaunay_triangles.masked_scatter_(new_triangles_mask.unsqueeze(2).expand(-1, -1, 3), new_triangles) #[batchsize, ntriangles + 2, 3]
        
        new_edge01 = neighbouring_edge_mask.view(self.batchsize, -1).nonzero(as_tuple = True)[1] #[?]
        
        """we are currently storing which triangles have to be inserted, via the edges along the perimeter of the delaunay cavity, we need to compute which edge is to the 'left'/'right' of each edge"""
        """don't have the memory to do a batchsize * n * n boolean search, don't have the speed to do a batchsize^2 search (as would be the case for sparse matrix or similar)"""
        """best alternative: rotate the edge around right point, repeat until hit edge in mask (will never go to an edge of a removed triangle before we hit edge in mask) should basically be order 1!!!!!"""
        
        neighbouring_edge_index = neighbouring_edge_mask.nonzero(as_tuple = True)[0] #[?]
        next_neighbouring_edge_index = torch.empty_like(neighbouring_edge_index) #[?]
        
        rotating_flipped_neighbouring_edge_index = neighbouring_edge_mask.nonzero(as_tuple = True)[0] #[?], initialise
        todo_mask = torch.ones_like(next_neighbouring_edge_index, dtype = torch.bool) #[?]
        while todo_mask.sum():
            rotating_neighbouring_edge_index = rotating_flipped_neighbouring_edge_index + 1 - 3 * (rotating_flipped_neighbouring_edge_index % 3 == 2) #[todo_mask.sum()], gets smaller until nothing left EFFICIENCY (this may be seriously stupid, as it requires making a bunch of copies when I could be doing stuff inplace)
            
            update_mask = neighbouring_edge_mask[rotating_neighbouring_edge_index] #[todo_mask.sum()]
            update_mask_unravel = torch.zeros_like(todo_mask).masked_scatter(todo_mask, update_mask) #[?]
            
            next_neighbouring_edge_index.masked_scatter_(update_mask_unravel, rotating_neighbouring_edge_index.masked_select(update_mask)) #[?]
            
            todo_mask.masked_fill_(update_mask_unravel, False) #[?]
            rotating_flipped_neighbouring_edge_index = edges[rotating_neighbouring_edge_index.masked_select(update_mask.logical_not())] #[todo_mask.sum()]
        triangle_index = new_triangles_mask.view(-1).nonzero(as_tuple = True)[0] #[?], index goes up to batchsize * (ntriangles + 2), this is needed for when we invert the permutation by scattering (won't scatter same number of triangles per batch)
        
        next_triangle_index = torch.empty_like(edges).masked_scatter_(neighbouring_edge_mask, triangle_index)[next_neighbouring_edge_index] #[?], index goes up to batchsize * (ntriangles + 2)
        next_edge = 3 * next_triangle_index + 1 #[?]
        
        invert_permutation = torch.empty_like(new_triangles_mask.view(-1), dtype=torch.long) #[batchsize * (ntriangles + 2)]
        invert_permutation[next_triangle_index] = triangle_index #[batchsize * (ntriangles + 2)]
        previous_triangle_index = invert_permutation.masked_select(new_triangles_mask.view(-1)) #[?]
        previous_edge = 3 * previous_triangle_index + 2 #[?]
        
        """in the above we rotated around 'first_point' in our new triangles"""
        new_edge20 = next_edge % ((self.ntriangles + 2) * 3) #[?]
        new_edge12 = previous_edge % ((self.ntriangles + 2) * 3) #[?]
        
        new_edges = torch.stack([new_edge01, new_edge12, new_edge20], dim = 1) #[?, 3]
        self.partial_delaunay_edges.masked_scatter_(new_triangles_mask.unsqueeze(2).expand(-1, -1, 3).reshape(self.batchsize, -1), new_edges) #[batchsize, (ntriangles + 2) * 3]
        
        self.ntriangles += 2
        """currently only count the extra triangles you replace (not the on you have to remove because you're located there, and not the ones you make because you have to create two more"""
        self.cost += (n_new_triangles - 3)
        self.points_mask.scatter_(1, point_index.unsqueeze(1).expand(-1, self.npoints), True)
        self.points_sequence = torch.cat([self.points_sequence, point_index.unsqueeze(1)], dim = 1)
    
    def allindices(self): #generate all orders of point insertion
        npoints = self.npoints - 3
        allroutes = torch.empty([1, 0], dtype = torch.long, device = device)
        for i in range(npoints):
            nroutes = allroutes.size(0)
            remaining_mask = torch.ones([nroutes], dtype = torch.bool, device = device).unsqueeze(1).expand(-1, npoints).clone().scatter_(-1, allroutes, False)
            remaining_indices = remaining_mask.nonzero(as_tuple = True)[1]
            allroutes = allroutes.unsqueeze(1).expand(-1, remaining_mask[0, :].sum(), -1)
            allroutes = torch.cat([allroutes, remaining_indices.view(nroutes, -1).unsqueeze(2)], dim = -1).view(-1, allroutes.size(-1) + 1)
        return allroutes #[npoints!, npoints]


env = environment()
npoints = 10
minibatchsize = 1
nbatches = 1
env.reset(npoints + 3, minibatchsize, math.factorial(npoints))
allroutes = env.allindices() + 3
allroutes = allroutes.unsqueeze(0).expand(minibatchsize, -1, -1).reshape(-1, npoints)
for j in range(300):
    for i in range(npoints):
        env.update(allroutes[:, i])
    print(env.cost.view(minibatchsize, -1).min(-1)[0].mean().item(), env.cost.mean().item())
    env.reset(npoints + 3, minibatchsize, math.factorial(npoints))

6.0 13.032936096191406
5.0 12.896031379699707
7.0 13.512301445007324
5.0 13.460317611694336
8.0 13.252778053283691
7.0 12.421428680419922
6.0 13.231348991394043
6.0 13.295635223388672
7.0 12.44603157043457
6.0 13.15238094329834
5.0 13.85793685913086
7.0 13.214285850524902
7.0 12.856745719909668
6.0 13.197221755981445


In [32]:
allroutes

tensor([[0, 1, 2, 3, 4],
        [0, 1, 2, 4, 3],
        [0, 1, 3, 2, 4],
        [0, 1, 3, 4, 2],
        [0, 1, 4, 2, 3],
        [0, 1, 4, 3, 2],
        [0, 2, 1, 3, 4],
        [0, 2, 1, 4, 3],
        [0, 2, 3, 1, 4],
        [0, 2, 3, 4, 1],
        [0, 2, 4, 1, 3],
        [0, 2, 4, 3, 1],
        [0, 3, 1, 2, 4],
        [0, 3, 1, 4, 2],
        [0, 3, 2, 1, 4],
        [0, 3, 2, 4, 1],
        [0, 3, 4, 1, 2],
        [0, 3, 4, 2, 1],
        [0, 4, 1, 2, 3],
        [0, 4, 1, 3, 2],
        [0, 4, 2, 1, 3],
        [0, 4, 2, 3, 1],
        [0, 4, 3, 1, 2],
        [0, 4, 3, 2, 1],
        [1, 0, 2, 3, 4],
        [1, 0, 2, 4, 3],
        [1, 0, 3, 2, 4],
        [1, 0, 3, 4, 2],
        [1, 0, 4, 2, 3],
        [1, 0, 4, 3, 2],
        [1, 2, 0, 3, 4],
        [1, 2, 0, 4, 3],
        [1, 2, 3, 0, 4],
        [1, 2, 3, 4, 0],
        [1, 2, 4, 0, 3],
        [1, 2, 4, 3, 0],
        [1, 3, 0, 2, 4],
        [1, 3, 0, 4, 2],
        [1, 3, 2, 0, 4],
        [1, 3, 2, 4, 0],
