In [1]:
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributions
from collections import namedtuple
from itertools import count

# Notebook-level configuration.
# Target device for the model (`.to(device)`) and for the tensors the environment creates.
# NOTE(review): hard-coded to the second GPU -- confirm this exists on the machine running the notebook.
device = "cuda:1"

# NOTE(review): in the cells shown, these four names are never read -- train() takes its own
# npoints/batchsize/nsamples arguments, and Graph_Transformer() is built with its default
# emsize (32), not 512. Kept for compatibility with any other cells that may use them.
batchsize, nsamples, npoints, emsize = 512, 8, 5, 512


class Graph_Transformer(nn.Module):
    """Pointer-style transformer policy that chooses the order in which points are inserted.

    A transformer encoder embeds the 2D points into ``memory`` ([npoints, batch, emsize],
    sequence-first). A transformer decoder (``ndecoderlayers`` layers, 0 by default)
    consumes the embeddings of the points chosen so far, prefixed by a learned
    ``start_token``. The next-point logits are a scaled dot product between projected
    point embeddings (``outputattention_query``) and the projected decoder output
    (``outputattention_key``).

    Masks are created on the device of the tensors they are combined with, so the module
    works wherever it and its inputs live (it no longer depends on a module-level ``device``).
    """

    def __init__(self, emsize = 32, nhead = 8, nhid = 1024, nlayers = 3, ndecoderlayers = 0, dropout = 0):
        super().__init__()
        self.emsize = emsize
        from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer
        encoder_layers = TransformerEncoderLayer(emsize, nhead, nhid, dropout = dropout)
        decoder_layers = TransformerDecoderLayer(emsize, nhead, nhid, dropout = dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.transformer_decoder = TransformerDecoder(decoder_layers, ndecoderlayers)
        self.encoder = nn.Linear(2, emsize)
        self.outputattention_query = nn.Linear(emsize, emsize, bias = False)
        self.outputattention_key = nn.Linear(emsize, emsize, bias = False)
        # Created like every other parameter; `.to(device)` on the module moves it along with the rest.
        self.start_token = nn.Parameter(torch.randn([emsize]))

        # Layers used only by the experimental `test_encode`.
        self.lineartest0 = nn.Linear(2, emsize)
        self.lineartest1 = nn.Linear(2, emsize)
        self.lineartest2 = nn.Linear(2, emsize)
        self.lineartest3 = nn.Linear(2 * emsize, emsize)

    def generate_subsequent_mask(self, sz, device = None):
        """Boolean causal mask of shape [sz, sz]: True (= not allowed to attend) strictly above the diagonal.

        Args:
            sz: sequence length.
            device: where to create the mask; defaults to the device of this module's parameters.
        """
        if device is None:
            device = self.start_token.device
        return torch.triu(torch.ones([sz, sz], dtype = torch.bool, device = device), diagonal = 1)

    def encode(self, src):
        """Embed points: src [batch * nsamples, npoints, 2] -> memory [npoints, batch * nsamples, emsize]."""
        # Sequence-first layout, because the nn.Transformer* modules here use the default batch_first=False.
        src = self.encoder(src).transpose(0, 1)
        return self.transformer_encoder(src)

    def test_encode(self, src): #src must be [batchsize * nsamples, npoints = 5, 2]
        """Experimental hand-built encoder for exactly 5 points (3 corners + 2 free points).

        The corners get all-zero embeddings. Each free point gets a linear self-embedding
        concatenated with a linear "message" from the other free point, passed through
        ReLU and then ``lineartest3``. Returns [5, batch * nsamples, emsize].

        NOTE(review): both self-embeddings use ``lineartest1``, while the messages use
        ``lineartest1`` (point 1) and ``lineartest2`` (point 2) -- confirm this asymmetry is
        intended. ``.squeeze(1)`` only has an effect when emsize == 1.
        """
        point_1 = self.lineartest1(src[:, 3, :] - 0.5) #want mean 0!!!!
        point_2 = self.lineartest1(src[:, 4, :] - 0.5)
        remaining = self.lineartest0(src[:, :3, :]) #[batchsize * nsamples, 3, emsize], only its shape is used
        point_1_message = self.lineartest1(src[:, 3, :] - 0.5).squeeze(1)
        point_2_message = self.lineartest2(src[:, 4, :] - 0.5).squeeze(1)
        point_1 = F.relu(torch.cat([point_1, point_2_message], dim = 1))
        point_2 = F.relu(torch.cat([point_2, point_1_message], dim = 1))
        point_1 = self.lineartest3(point_1)
        point_2 = self.lineartest3(point_2)
        src = torch.cat([torch.zeros_like(remaining.transpose(0, 1)), point_1.unsqueeze(0), point_2.unsqueeze(0)])
        return src #[npoints = 5, batchsize * nsamples, emsize]

    def decode_next(self, memory, tgt, route_mask):
        """Unnormalised logits for the next point to insert.

        Args:
            memory: [npoints, B, emsize] encoder output.
            tgt: [t, B, emsize] decoder input (start token followed by the embeddings of chosen points).
            route_mask: [B, npoints] bool, True for points that may not be chosen.

        Returns:
            (tanh-squashed logits, raw logits), each [B, npoints], with masked entries set to -inf.
        """
        # Recomputes the whole decoder each step; the causal mask means earlier positions would not change.
        tgt_mask = self.generate_subsequent_mask(tgt.size(0), device = tgt.device)
        output = self.transformer_decoder(tgt, memory, tgt_mask) #[t, B, emsize]
        output_query = self.outputattention_query(memory).transpose(0, 1) #[B, npoints, emsize]
        output_key = self.outputattention_key(output[-1]) #[B, emsize]
        output_attention = torch.matmul(output_query * self.emsize ** -0.5, output_key.unsqueeze(-1)).squeeze(-1) #[B, npoints]
        output_attention_tanh = output_attention.tanh() #[B, npoints], bounded to [-1, 1]

        # clone so a later in-place update of the caller's mask cannot affect tensors saved for backward
        output_attention = output_attention.masked_fill(route_mask.clone(), float('-inf'))
        output_attention_tanh = output_attention_tanh.masked_fill(route_mask.clone(), float('-inf'))

        return output_attention_tanh, output_attention #[B, npoints]

    def calculate_logprob(self, memory, routes):
        """Log-probability of whole routes in one teacher-forced decoder pass (cheaper than backpropagating through the step loop).

        Args:
            memory: [npoints, B, emsize] encoder output.
            routes: [B, n] long, point indices in visiting order.

        Returns:
            (logprob with tanh-squashed logits, logprob with raw logits), each [B].

        At step j, the point at route position i is a candidate iff i >= j (points already
        visited are masked). Normalisation runs only over the route's own points, so it
        agrees with ``decode_next`` when the latter's route_mask masks every point outside
        ``routes`` plus those already visited.
        """
        ninternalpoints = routes.size(1)
        bigbatchsize = memory.size(1)
        # reorder memory into route order: [n, B, emsize]
        memory_ = memory.gather(0, routes.transpose(0, 1).unsqueeze(2).expand(-1, -1, self.emsize))
        # decoder input at step j is the start token (j = 0) or the embedding of the point chosen at step j-1
        tgt = torch.cat([self.start_token.unsqueeze(0).unsqueeze(1).expand(1, bigbatchsize, -1), memory_[:-1]]) #[n, B, emsize]
        tgt_mask = self.generate_subsequent_mask(ninternalpoints, device = memory.device)
        output = self.transformer_decoder(tgt, memory, tgt_mask) #[n, B, emsize]
        output_query = self.outputattention_query(memory_).transpose(0, 1) #[B, n (route position i), emsize]
        output_key = self.outputattention_key(output).transpose(0, 1) #[B, n (step j), emsize]
        attention_mask = torch.full([ninternalpoints, ninternalpoints], True, dtype = torch.bool, device = memory.device).triu(1) #[n, n], True for i < j
        output_attention = torch.matmul(output_query * self.emsize ** -0.5, output_key.transpose(-1, -2)) #[B, i, j]
        # tanh variant: a quick fix to stop divergence by bounding logits to [-1, 1]
        output_attention_tanh = output_attention.tanh()

        # log-softmax over candidate points (dim -2) for each step j
        output_attention_tanh = output_attention_tanh.masked_fill(attention_mask, float('-inf'))
        output_attention_tanh = output_attention_tanh - output_attention_tanh.logsumexp(-2, keepdim = True)

        output_attention = output_attention.masked_fill(attention_mask, float('-inf'))
        output_attention = output_attention - output_attention.logsumexp(-2, keepdim = True)

        # the point actually chosen at step j sits at route position j: take the diagonal and sum over steps
        logprob_tanh = output_attention_tanh.diagonal(dim1 = -1, dim2 = -2).sum(-1) #[B]
        logprob = output_attention.diagonal(dim1 = -1, dim2 = -2).sum(-1) #[B]
        return logprob_tanh, logprob #[B]

# Policy network with the class defaults (emsize=32, nhead=8, nhid=1024, 3 encoder layers,
# 0 decoder layers, dropout 0), moved to the configured device; Adam with default hyperparameters.
NN = Graph_Transformer()
NN = NN.to(device)
optimizer = optim.Adam(params = NN.parameters())


class environment():
    """Batched incremental Delaunay triangulation used as the RL environment.

    Each problem instance has three fixed corner points (0, 0), (2, 0), (0, 2) followed by
    ``npoints - 3`` uniformly random points in [0, 1)^2. Points are inserted one at a time,
    Bowyer-Watson style: every triangle whose circumcircle contains the new point is
    removed, and the cavity is re-triangulated by joining its boundary edges to the new point.
    Each insertion adds ``(number of new triangles - 3)`` to ``cost``.

    The batch and sample dimensions are flattened together: every per-instance tensor has
    leading dimension ``batchsize * nsamples``, and the ``nsamples`` copies of one instance
    share identical points.

    Storage:
        partial_delaunay_triangles: [B, ntriangles, 3] vertex indices, anticlockwise, except
            triangle 0, the 'external' triangle stored clockwise.
        partial_delaunay_edges: [B, ntriangles * 3]; entry 3*t + k (the edge of triangle t from
            vertex k to vertex (k + 1) % 3) holds the index of the same edge seen from the
            neighbouring triangle.
    """

    def reset(self, npoints, batchsize, nsamples):
        """Start a fresh batch of ``batchsize`` instances, each repeated ``nsamples`` times.

        Raises:
            ValueError: if ``npoints <= 3`` (no point to insert besides the corners).
        """
        if npoints <= 3:
            # Raising (instead of print + return) avoids silently keeping the previous episode's state.
            raise ValueError("not enough points for valid problem instance")
        self.batchsize = batchsize * nsamples #so that I don't have to rewrite all this code, we store these two dimensions together
        self.nsamples = nsamples
        self.npoints = npoints
        self.points = torch.rand([batchsize, npoints - 3, 2], device = device).unsqueeze(1).expand(-1, nsamples, -1, -1).reshape(self.batchsize, npoints - 3, 2)
        self.corner_points = torch.tensor([[0, 0], [2, 0], [0, 2]], dtype = torch.float, device = device)
        self.points = torch.cat([self.corner_points.unsqueeze(0).expand(self.batchsize, -1, -1), self.points], dim = -2) #[batchsize * nsamples, npoints, 2]
        # corners start as already inserted
        self.points_mask = torch.cat([torch.ones([self.batchsize, 3], dtype = torch.bool, device = device), torch.zeros([self.batchsize, npoints - 3], dtype = torch.bool, device = device)], dim = 1)
        self.points_sequence = torch.empty([self.batchsize, 0], dtype = torch.long, device = device)

        # 'External' triangle trick: triangle [0, 2, 1] (clockwise) is never removed, so boundary
        # edges need no special handling. Triangle [0, 1, 2] is the initial real triangle.
        self.partial_delaunay_triangles = torch.tensor([[0, 2, 1], [0, 1, 2]], dtype = torch.int64, device = device).unsqueeze(0).expand(self.batchsize, -1, -1).contiguous() #[batchsize, ntriangles, 3]
        self.partial_delaunay_edges = torch.tensor([5, 4, 3, 2, 1, 0], dtype = torch.int64, device = device).unsqueeze(0).expand(self.batchsize, -1).contiguous() #[batchsize, ntriangles * 3], edges in order 01, 12, 20

        self.ntriangles = 2 #same for every instance, since each insertion adds exactly two triangles
        self.cost = torch.zeros([self.batchsize], device = device)

        self.logprob = torch.zeros([self.batchsize], device = device, requires_grad = True)

    def update(self, point_index):
        """Insert point ``point_index[b]`` into triangulation ``b`` for every row b.

        Args:
            point_index: [batchsize * nsamples] long tensor of point indices.

        Raises:
            ValueError: if ``point_index`` has the wrong length or names an already-inserted point.
        """
        if point_index.size(0) != self.batchsize:
            raise ValueError("point_index.size() doesn't match expected size, should be [batchsize]")
        if self.points_mask.gather(1, point_index.unsqueeze(1)).any():
            raise ValueError("some points already added")
        dev = self.points.device
        triangles_coordinates = self.points.gather(1, self.partial_delaunay_triangles.view(self.batchsize, self.ntriangles * 3).unsqueeze(2).expand(-1, -1, 2)).view(self.batchsize, self.ntriangles, 3, 2) #[batchsize, ntriangles, 3, 2]
        newpoint = self.points.gather(1, point_index.unsqueeze(1).unsqueeze(2).expand(self.batchsize, 1, 2)).squeeze(1) #[batchsize, 2]

        # incircle determinant with rows (x, y, x^2 + y^2, 1): positive iff newpoint lies inside the
        # circumcircle of an anticlockwise triangle (the clockwise external triangle never tests positive)
        incircle_matrix = torch.cat([triangles_coordinates, newpoint.unsqueeze(1).unsqueeze(2).expand(-1, self.ntriangles, 1, -1)], dim = -2) #[batchsize, ntriangles, 4, 2]
        incircle_matrix = torch.cat([incircle_matrix, (incircle_matrix * incircle_matrix).sum(-1, keepdim = True), torch.ones([self.batchsize, self.ntriangles, 4, 1], device = dev)], dim = -1) #[batchsize, ntriangles, 4, 4]
        incircle_test = incircle_matrix.det() > 0 #[batchsize, ntriangles], True if triangle is removed
        removed_edge_mask = incircle_test.unsqueeze(2).expand(-1, -1, 3).reshape(-1) #[batchsize * ntriangles * 3]

        # flat (batch-offset) twin index of every edge
        edges = (self.partial_delaunay_edges + self.ntriangles * 3 * torch.arange(self.batchsize, device = dev).unsqueeze(1)).view(-1) #[batchsize * ntriangles * 3]
        neighbouring_edge = edges.masked_select(removed_edge_mask)
        neighbouring_edge_mask = torch.zeros([self.batchsize * self.ntriangles * 3], device = dev, dtype = torch.bool)
        neighbouring_edge_mask[neighbouring_edge] = True
        # cavity boundary: edges of kept triangles whose twin belongs to a removed triangle
        neighbouring_edge_mask = (neighbouring_edge_mask * removed_edge_mask.logical_not()) #[batchsize * ntriangles * 3]

        n_new_triangles = neighbouring_edge_mask.view(self.batchsize, -1).sum(-1) #[batchsize]

        new_point = point_index.unsqueeze(1).expand(-1, self.ntriangles * 3).masked_select(neighbouring_edge_mask.view(self.batchsize, -1))

        second_point_mask = neighbouring_edge_mask.view(self.batchsize, -1, 3) #[batchsize, ntriangles, 3]
        (first_point_indices0, first_point_indices1, first_point_indices2) = second_point_mask.nonzero(as_tuple = True)
        first_point_indices2 = (first_point_indices2 != 2) * (first_point_indices2 + 1) #(k + 1) % 3

        # boundary edge (a, b) of a kept triangle becomes edge 01 = (b, a) of a new triangle (b, a, new_point)
        first_point = self.partial_delaunay_triangles[first_point_indices0, first_point_indices1, first_point_indices2] #[?]
        second_point = self.partial_delaunay_triangles.masked_select(second_point_mask) #[?]

        # new triangles reuse the removed slots plus two appended ones
        new_triangles_mask = torch.cat([incircle_test, torch.ones([self.batchsize, 2], dtype = torch.bool, device = dev)], dim = 1) #[batchsize, ntriangles + 2]

        new_neighbouring_edges = 3 * new_triangles_mask.nonzero(as_tuple = True)[1] #[?], 3* since is the 01 edge of new triangles
        self.partial_delaunay_edges.masked_scatter_(neighbouring_edge_mask.view(self.batchsize, -1), new_neighbouring_edges) #still [batchsize, ntriangles * 3] for now

        self.partial_delaunay_triangles = torch.cat([self.partial_delaunay_triangles, torch.empty([self.batchsize, 2, 3], dtype = torch.long, device = dev)], dim = 1)
        self.partial_delaunay_edges = torch.cat([self.partial_delaunay_edges, torch.empty([self.batchsize, 6], dtype = torch.long, device = dev)], dim = 1)
        new_triangles = torch.stack([first_point, second_point, new_point], dim = 1) #[?, 3]
        self.partial_delaunay_triangles.masked_scatter_(new_triangles_mask.unsqueeze(2).expand(-1, -1, 3), new_triangles) #[batchsize, ntriangles + 2, 3]

        new_edge01 = neighbouring_edge_mask.view(self.batchsize, -1).nonzero(as_tuple = True)[1] #[?]

        # For each boundary edge find the next boundary edge around the cavity: rotate around the
        # edge's end point through kept triangles until an edge in the mask is hit (a removed
        # triangle is never entered first), so this typically takes very few iterations.
        neighbouring_edge_index = neighbouring_edge_mask.nonzero(as_tuple = True)[0] #[?]
        next_neighbouring_edge_index = torch.empty_like(neighbouring_edge_index) #[?]

        rotating_flipped_neighbouring_edge_index = neighbouring_edge_mask.nonzero(as_tuple = True)[0] #[?], initialise
        todo_mask = torch.ones_like(next_neighbouring_edge_index, dtype = torch.bool) #[?]
        while todo_mask.sum():
            # next edge within the same triangle: (k + 1) % 3
            rotating_neighbouring_edge_index = rotating_flipped_neighbouring_edge_index + 1 - 3 * (rotating_flipped_neighbouring_edge_index % 3 == 2) #[todo_mask.sum()]

            update_mask = neighbouring_edge_mask[rotating_neighbouring_edge_index] #[todo_mask.sum()]
            update_mask_unravel = torch.zeros_like(todo_mask).masked_scatter(todo_mask, update_mask) #[?]

            next_neighbouring_edge_index.masked_scatter_(update_mask_unravel, rotating_neighbouring_edge_index.masked_select(update_mask)) #[?]

            todo_mask.masked_fill_(update_mask_unravel, False) #[?]
            # not a boundary edge yet: cross to its twin and keep rotating
            rotating_flipped_neighbouring_edge_index = edges[rotating_neighbouring_edge_index.masked_select(update_mask.logical_not())] #[todo_mask.sum()]
        triangle_index = new_triangles_mask.view(-1).nonzero(as_tuple = True)[0] #[?], flat index up to batchsize * (ntriangles + 2)

        next_triangle_index = torch.empty_like(edges).masked_scatter_(neighbouring_edge_mask, triangle_index)[next_neighbouring_edge_index] #[?]
        next_edge = 3 * next_triangle_index + 1 #[?], edge 12 of the next new triangle

        invert_permutation = torch.empty_like(new_triangles_mask.view(-1), dtype=torch.long) #[batchsize * (ntriangles + 2)]
        invert_permutation[next_triangle_index] = triangle_index
        previous_triangle_index = invert_permutation.masked_select(new_triangles_mask.view(-1)) #[?]
        previous_edge = 3 * previous_triangle_index + 2 #[?], edge 20 of the previous new triangle

        # edge 20 of a new triangle is twinned with edge 12 of the next one, and vice versa; modulo gives in-batch indices
        new_edge20 = next_edge % ((self.ntriangles + 2) * 3) #[?]
        new_edge12 = previous_edge % ((self.ntriangles + 2) * 3) #[?]

        new_edges = torch.stack([new_edge01, new_edge12, new_edge20], dim = 1) #[?, 3]
        self.partial_delaunay_edges.masked_scatter_(new_triangles_mask.unsqueeze(2).expand(-1, -1, 3).reshape(self.batchsize, -1), new_edges) #[batchsize, (ntriangles + 2) * 3]

        self.ntriangles += 2
        # counts only the extra triangles replaced (not the one containing the point, nor the two always created)
        self.cost += (n_new_triangles - 3)
        self.points_mask.scatter_(1, point_index.unsqueeze(1), True)
        self.points_sequence = torch.cat([self.points_sequence, point_index.unsqueeze(1)], dim = 1)

    def sample_point(self, logits): #logits must be [batchsize * nsamples, npoints]
        """Sample one point per row from ``logits``, insert it, and accumulate its log-probability."""
        probs = torch.distributions.categorical.Categorical(logits = logits)
        next_point = probs.sample() #[batchsize * nsamples]
        self.update(next_point)
        self.logprob = self.logprob + probs.log_prob(next_point)
        return next_point #[batchsize * nsamples]

    def sampleandgreedy_point(self, logits): #logits must be [batchsize * nsamples, npoints]
        """Like ``sample_point``, but the last of each instance's ``nsamples`` copies takes the argmax.

        The greedy choice contributes 0 to ``self.logprob``.
        """
        logits_sample = logits.view(-1, self.nsamples, self.npoints)[:, :-1, :]
        probs = torch.distributions.categorical.Categorical(logits = logits_sample)

        sample_point = probs.sample() #[batchsize, (nsamples - 1)]
        greedy_point = logits.view(-1, self.nsamples, self.npoints)[:, -1, :].max(-1, keepdim = True)[1] #[batchsize, 1]
        next_point = torch.cat([sample_point, greedy_point], dim = 1).view(-1)
        self.update(next_point)
        self.logprob = self.logprob + torch.cat([probs.log_prob(sample_point), torch.zeros([sample_point.size(0), 1], device = sample_point.device)], dim = 1).view(-1)
        return next_point
    

# Single shared environment instance; train() calls env.reset(...) at the start of every epoch.
env = environment()


def train(epochs = 300000, npoints = 13, batchsize = 2000, nsamples = 8, log_every = 1):
    """Train the notebook-level ``NN`` on ``env`` with a binary REINFORCE-style objective.

    Each epoch:
      1. ``env.reset`` creates ``batchsize`` instances, each duplicated ``nsamples`` times.
      2. Without gradients, an insertion order is rolled out: the first ``nsamples - 1``
         copies of each instance are sampled from the policy, the last copy is greedy.
      3. The route log-probabilities are recomputed with gradients (``NN.calculate_logprob``),
         with the model in training mode.
      4. Loss: a sampled route whose cost is at least 0.05 below the greedy route of the same
         instance has its log-probability increased; one whose cost exceeds greedy by more
         than 1 has it decreased; all others contribute 0. (Earlier continuous / clipped /
         variance-rescaled weightings were tried and dropped.)

    Args:
        epochs: number of rollout + update iterations.
        npoints, batchsize, nsamples: forwarded to ``env.reset``.
        log_every: print the diagnostic line every ``log_every`` epochs (1 = every epoch).

    Uses the notebook-level globals ``NN``, ``env`` and ``optimizer``.
    """
    if log_every < 1:
        raise ValueError("log_every must be >= 1")
    for i in range(epochs):
        env.reset(npoints, batchsize, nsamples)
        NN.train()
        # boundary (corner) points are included in the encoder input, though never chosen by the decoder
        memory = NN.encode(env.points) #[npoints, batchsize * nsamples, emsize]
        tgt = NN.start_token.unsqueeze(0).unsqueeze(1).expand(1, batchsize * nsamples, -1).detach() #[1, batchsize * nsamples, emsize]

        # route selection needs no gradient; the log-probabilities are recomputed in one pass below
        NN.eval()
        with torch.no_grad():
            for j in range(3, npoints):
                _, logits = NN.decode_next(memory.detach(), tgt, env.points_mask)
                next_point = env.sampleandgreedy_point(logits)
                # feed the embedding of the chosen point as the next decoder input
                tgt = torch.cat([tgt, memory.gather(0, next_point.unsqueeze(0).unsqueeze(2).expand(1, -1, memory.size(2)))]) #[nsofar, batchsize * nsamples, emsize]

        # switch back before the gradient-bearing forward pass (previously this ran in eval mode)
        NN.train()
        _, logprob = NN.calculate_logprob(memory, env.points_sequence) #[batchsize * nsamples]

        cost = env.cost.view(batchsize, nsamples)
        sample_logprob = logprob.view(batchsize, nsamples)[:, :-1] #[batchsize, nsamples - 1]
        greedy_baseline = cost[:, -1] #[batchsize], cost of the greedy copy
        advantage = cost[:, :-1] - greedy_baseline.unsqueeze(1) #[batchsize, nsamples - 1], negative = better than greedy
        positive_reinforcement_binary = advantage <= -0.05
        negative_reinforcement_binary = advantage > 1
        loss = (
            - sample_logprob * positive_reinforcement_binary
            + sample_logprob * negative_reinforcement_binary
        ).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % log_every == 0:
            # mean greedy cost, mean greedy logprob, mean sampled logprob, two individual logprobs,
            # and the rollout-accumulated logprob of row 0 (should match logprob[0])
            print(greedy_baseline.mean().item(), logprob.view(batchsize, nsamples)[:, -1].mean().item(), sample_logprob.mean().item(), logprob[batchsize - 1].item(), logprob[0].item(), env.logprob[0].item())
  

In [5]:
# Long run: 53 points per instance (3 corners + 50 random), 350 instances x 8 copies (7 sampled + 1 greedy).
train(nsamples = 8, batchsize = 350, npoints = 53)

96.2028579711914 -2.817742109298706 -5.433173656463623 -7.984920501708984 -5.265766143798828 -5.265705108642578
95.13713836669922 -2.7170419692993164 -5.217680931091309 -7.64599609375 -5.500823974609375 -5.50067138671875
95.74857330322266 -2.7514443397521973 -5.246238708496094 -5.35504150390625 -4.742530822753906 -4.742591857910156
96.25142669677734 -2.6464622020721436 -5.108283042907715 -4.303173065185547 -6.6016082763671875 -6.601799011230469
95.77999877929688 -2.628082275390625 -5.098591327667236 -8.219062805175781 -5.448735237121582 -5.448828220367432
95.9142837524414 -2.7377288341522217 -5.225794792175293 -5.06256103515625 -2.46185302734375 -2.4619140625
95.6257095336914 -2.5805139541625977 -4.984772682189941 -4.4173126220703125 -1.70361328125 -1.703704833984375
96.40571594238281 -2.549777030944824 -5.058006286621094 -10.236221313476562 -1.364044189453125 -1.36407470703125
96.05428314208984 -2.7187411785125732 -5.216022491455078 -7.47216796875 -2.056049346923828 -2.056041717529297

94.85713958740234 -2.276479721069336 -4.3953728675842285 -3.0658397674560547 -1.1288604736328125 -1.1288604736328125
95.14571380615234 -2.269887685775757 -4.399417877197266 -1.5493345260620117 -4.729991912841797 -4.730083465576172
94.67428588867188 -2.3611514568328857 -4.566633224487305 -2.643310546875 -1.82257080078125 -1.822479248046875
94.03713989257812 -2.3356099128723145 -4.543118476867676 -3.5706329345703125 -3.021209716796875 -3.021209716796875
95.58000183105469 -2.281839370727539 -4.486057281494141 -1.9053659439086914 -7.021148681640625 -7.021209716796875
95.15714263916016 -2.3799145221710205 -4.58156156539917 -4.922651290893555 -5.881744384765625 -5.88165283203125
95.63713836669922 -2.280189275741577 -4.482450485229492 -1.9474849700927734 -4.261946678161621 -4.26185417175293
94.31428527832031 -2.2048237323760986 -4.256336688995361 -2.999908447265625 -0.5205078125 -0.5205078125
94.94571685791016 -2.238103151321411 -4.4544501304626465 -0.5386753082275391 -2.00201416015625 -2.001

94.67142486572266 -2.3800315856933594 -4.655252933502197 -5.230918884277344 -2.6638755798339844 -2.6638755798339844
94.53142547607422 -2.441631555557251 -4.692161560058594 -4.067265033721924 -6.360076904296875 -6.360107421875
94.22285461425781 -2.3381125926971436 -4.529796600341797 -5.018604278564453 -6.611236572265625 -6.611297607421875
93.79999542236328 -2.2905168533325195 -4.490406513214111 -3.7185287475585938 -6.353900909423828 -6.353900909423828
94.61428833007812 -2.3218724727630615 -4.517297744750977 -9.555667877197266 -5.57481575012207 -5.574850082397461
94.28571319580078 -2.3702123165130615 -4.688745975494385 -2.4048538208007812 -3.332489013671875 -3.332489013671875
94.14856719970703 -2.2996113300323486 -4.447468280792236 -3.1917877197265625 -2.2655715942382812 -2.2657241821289062
94.32571411132812 -2.3617630004882812 -4.501012325286865 -7.061457633972168 -1.4602909088134766 -1.460352897644043
94.24285888671875 -2.3255462646484375 -4.514655590057373 -4.949127197265625 -5.992942

94.58285522460938 -2.5224013328552246 -4.916953086853027 -4.8538818359375 -4.694797515869141 -4.694881439208984
94.4942855834961 -2.4894015789031982 -4.824251651763916 -1.2738456726074219 -5.37103271484375 -5.371185302734375
95.47999572753906 -2.62065052986145 -4.953302383422852 -3.662883758544922 -5.500640869140625 -5.500640869140625
95.05714416503906 -2.541916847229004 -4.813852310180664 -8.613540649414062 -2.8950843811035156 -2.8951454162597656
94.5857162475586 -2.522571325302124 -4.827142238616943 -8.083354949951172 -0.7862815856933594 -0.78631591796875
94.95143127441406 -2.4753334522247314 -4.745116710662842 -4.738481521606445 -7.339282989501953 -7.339347839355469
95.49713897705078 -2.3838250637054443 -4.696080207824707 -2.7295589447021484 -3.1089401245117188 -3.108959197998047
94.86285400390625 -2.339850425720215 -4.580604553222656 -4.311546325683594 -2.3746795654296875 -2.3748016357421875
94.47428131103516 -2.357893466949463 -4.529815196990967 -5.332035064697266 -3.3733406066894

94.58000183105469 -2.7711658477783203 -5.306251525878906 -4.2852582931518555 -2.1588287353515625 -2.15887451171875
94.94857025146484 -2.7569878101348877 -5.350726127624512 -5.3659515380859375 -8.000213623046875 -8.000244140625
94.92285919189453 -2.7505385875701904 -5.373875617980957 -7.39056921005249 -3.2291030883789062 -3.2289810180664062
94.37999725341797 -2.774604558944702 -5.35710334777832 -3.689992904663086 -7.0530595779418945 -7.0529680252075195
94.8228530883789 -2.72041654586792 -5.289041996002197 -8.11734676361084 -6.554668426513672 -6.554634094238281
94.71142578125 -2.7542483806610107 -5.366069793701172 -6.313694000244141 -3.8449859619140625 -3.8448638916015625
94.93714141845703 -2.735367774963379 -5.3448486328125 -4.116241455078125 -3.2806777954101562 -3.2807388305664062
93.95143127441406 -2.7402353286743164 -5.24866247177124 -2.461002826690674 -8.916141510009766 -8.91617202758789
95.01142883300781 -2.7974298000335693 -5.362894058227539 -3.2625675201416016 -4.011096954345703 

94.97428131103516 -1.691773533821106 -3.365349054336548 -0.9038257598876953 -3.488372802734375 -3.488433837890625
95.65142822265625 -1.7319509983062744 -3.2750821113586426 -1.1093902587890625 -1.8329849243164062 -1.8329849243164062
95.4028549194336 -1.6877137422561646 -3.3510076999664307 -6.609134674072266 -5.367616653442383 -5.367364883422852
94.53713989257812 -1.7413051128387451 -3.4209980964660645 -5.990745544433594 -2.230107307434082 -2.230168342590332
95.45428466796875 -1.6800751686096191 -3.255901575088501 -2.830455780029297 -4.661781311035156 -4.661872863769531
95.36856842041016 -1.6205215454101562 -3.153916835784912 -3.642688751220703 -7.152423858642578 -7.152545928955078
95.4085693359375 -1.6314133405685425 -3.2466280460357666 -5.242618560791016 -1.7973594665527344 -1.797332763671875
95.12000274658203 -1.6962238550186157 -3.24605131149292 -2.655059814453125 -2.6165237426757812 -2.6165847778320312
95.49142456054688 -1.6722311973571777 -3.370788097381592 -2.1342544555664062 -2.8

98.0142822265625 -2.511359691619873 -4.987445831298828 -3.5419158935546875 -4.354698181152344 -4.354576110839844
99.04000091552734 -2.6662914752960205 -5.178682804107666 -0.6598052978515625 -3.693134307861328 -3.6931190490722656
98.24857330322266 -2.5881717205047607 -4.9739089012146 -6.5665740966796875 -8.549476623535156 -8.549537658691406
98.0199966430664 -2.4816365242004395 -4.924543857574463 -2.4296722412109375 -2.6877288818359375 -2.6876068115234375
98.07142639160156 -2.6345367431640625 -5.006759166717529 -1.3584461212158203 -2.7095489501953125 -2.7097320556640625
97.73999786376953 -2.561185121536255 -4.9750518798828125 -1.9773406982421875 -10.044151306152344 -10.044120788574219
98.26856994628906 -2.517153739929199 -4.9322123527526855 -4.1324005126953125 -2.6591339111328125 -2.6591644287109375
97.77143096923828 -2.4345414638519287 -4.746306419372559 -0.9798812866210938 -4.103294372558594 -4.103385925292969
97.29428100585938 -2.357898712158203 -4.659854412078857 -5.4071807861328125 

95.07142639160156 -1.623996615409851 -3.2111635208129883 -1.2919769287109375 -1.8094024658203125 -1.8092803955078125
95.44000244140625 -1.7515723705291748 -3.356403112411499 -7.382408142089844 -5.457511901855469 -5.45758056640625
94.88285827636719 -1.6763995885849 -3.306159496307373 -3.9527435302734375 -0.3148326873779297 -0.3148326873779297
95.39142608642578 -1.625185489654541 -3.222222089767456 -5.632704734802246 -5.737758636474609 -5.737697601318359
94.6914291381836 -1.6138092279434204 -3.171530246734619 -3.890270233154297 -4.32183837890625 -4.321830749511719
95.37714385986328 -1.6297626495361328 -3.199995994567871 -8.526611328125 -1.5338516235351562 -1.5337295532226562
95.21142578125 -1.5951372385025024 -3.1457090377807617 -1.7613525390625 -8.12155532836914 -8.121484756469727
94.95143127441406 -1.5602926015853882 -3.0978970527648926 -0.280029296875 -6.098243713378906 -6.098148345947266
95.18571472167969 -1.5509840250015259 -3.050346851348877 -4.29669189453125 -1.5485343933105469 -1

95.11714172363281 -2.061457633972168 -4.0460004806518555 -5.020275115966797 -4.1421051025390625 -4.1421051025390625
95.52857208251953 -1.910813331604004 -3.7009997367858887 -0.43552589416503906 -4.066465377807617 -4.066526412963867
95.7828598022461 -1.981675148010254 -3.8557960987091064 -2.1038455963134766 -3.682098388671875 -3.68206787109375
95.6114273071289 -1.938044786453247 -3.723644733428955 -4.067107200622559 -7.381597518920898 -7.381536483764648
95.0142822265625 -1.9505336284637451 -3.7333664894104004 -5.6673431396484375 -8.988100051879883 -8.988065719604492
95.12000274658203 -1.918845772743225 -3.8039653301239014 -6.644536972045898 -5.562839508056641 -5.562652587890625
95.16571044921875 -1.9643754959106445 -3.8025996685028076 -10.690528869628906 -3.1545639038085938 -3.1545372009277344
94.394287109375 -1.7790576219558716 -3.549448251724243 -0.6894645690917969 -2.3993606567382812 -2.3993606567382812
95.5857162475586 -1.8802196979522705 -3.6785969734191895 -5.593254089355469 -2.06

94.47714233398438 -1.8761107921600342 -3.716312885284424 -1.8247947692871094 -3.321310043334961 -3.321187973022461
95.3114242553711 -1.932646632194519 -3.8496010303497314 -2.5327682495117188 -1.8916015625 -1.8915252685546875
94.43142700195312 -1.8551026582717896 -3.5917036533355713 -2.558124542236328 -4.816318511962891 -4.816310882568359
94.5857162475586 -1.891646146774292 -3.6985151767730713 -7.159307479858398 -2.978168487548828 -2.978168487548828
95.02285766601562 -1.834983468055725 -3.6285152435302734 -5.727668762207031 -3.1288375854492188 -3.1289024353027344
94.4085693359375 -1.8238526582717896 -3.579113483428955 -8.001506805419922 -5.220996856689453 -5.220905303955078
94.91714477539062 -1.9000533819198608 -3.7509686946868896 -4.5007781982421875 -1.1580429077148438 -1.1580429077148438
94.58285522460938 -1.8959832191467285 -3.6639842987060547 -4.8026123046875 -3.071216583251953 -3.071247100830078
94.49142456054688 -1.7964951992034912 -3.542417526245117 -1.50482177734375 -10.06214523

94.317138671875 -2.2448890209198 -4.303950309753418 -5.35772705078125 -2.997283935546875 -2.9972686767578125
93.99713897705078 -2.2654831409454346 -4.296201229095459 -2.6813907623291016 -4.6073150634765625 -4.6073150634765625
94.54571533203125 -2.286083221435547 -4.385415554046631 -3.01324462890625 -5.030342102050781 -5.030357360839844
94.72571563720703 -2.2500360012054443 -4.2782697677612305 -7.751911163330078 -6.794551849365234 -6.7945709228515625
94.44285583496094 -2.3752975463867188 -4.529047012329102 -1.5380706787109375 -11.333553314208984 -11.333492279052734
93.85142517089844 -2.4903814792633057 -4.773692607879639 -2.01605224609375 -7.237205505371094 -7.23736572265625
94.64285278320312 -2.3694732189178467 -4.646427631378174 -3.884765625 -3.231903076171875 -3.2319488525390625
94.65999603271484 -2.5128650665283203 -4.757615566253662 -2.2860031127929688 -4.77768611907959 -4.77774715423584
94.67142486572266 -2.6252684593200684 -4.871620178222656 -7.925449371337891 -1.12884521484375 -

KeyboardInterrupt: 

In [2]:
# NOTE(review): this cell ran as In [2], i.e. before the In [5] training cell in the recorded session;
# under Restart & Run All it would instead overwrite the weights just trained.
# map_location keeps tensors on the configured device instead of the device they were saved from.
# torch.load unpickles its input -- only load checkpoints from a trusted source.
NN.load_state_dict(torch.load('2d_50points_temp', map_location = device))

<All keys matched successfully>

In [4]:
# Persist only the learned parameters (state dict), not the module object.
torch.save(obj = NN.state_dict(), f = '2D_50points_small_93')