In [1]:
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributions
from collections import namedtuple
from itertools import count

# Use the first GPU when available; otherwise fall back to CPU so the notebook still runs
# (slowly) on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
floattype = torch.float

# NOTE(review): these module-level sizes are not referenced by the code shown in this notebook:
# train() takes its own npoints/batchsize/nsamples arguments and Graph_Transformer() defaults
# to emsize=64. Kept for reference / compatibility.
batchsize = 512
nsamples = 8
npoints = 5
emsize = 512


class Graph_Transformer(nn.Module):
    """Encoder-decoder transformer policy that builds a route one point at a time.

    The encoder embeds every 2-D point; the decoder attends over the points chosen so far
    (prefixed by a learned start token) and an "output attention" layer scores every point
    as the next choice.

    Every tensor created internally is placed on the device of the inputs (or of the
    module's parameters), so the model works on whatever device it is moved to with
    ``.to(...)``.
    """

    def __init__(self, emsize = 64, nhead = 8, nhid = 1024, nlayers = 4, ndecoderlayers = 2, dropout = 0):
        super().__init__()
        self.emsize = emsize
        encoder_layers = nn.TransformerEncoderLayer(emsize, nhead, nhid, dropout = dropout)
        decoder_layers = nn.TransformerDecoderLayer(emsize, nhead, nhid, dropout = dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, nlayers)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layers, ndecoderlayers)
        self.encoder = nn.Linear(2, emsize)  # (x, y) coordinates -> embedding
        self.outputattention_query = nn.Linear(emsize, emsize, bias = False)
        self.outputattention_key = nn.Linear(emsize, emsize, bias = False)
        # Created on the default device; `.to(device)` on the module moves it like any other parameter.
        self.start_token = nn.Parameter(torch.randn([emsize]))

    def generate_subsequent_mask(self, sz, device = None):
        """Return an additive causal mask of shape [sz, sz].

        Entry (i, j) is 0.0 for j <= i and -inf for j > i. When it is added to the attention
        logits (softmaxed over the last dimension), position i can only attend to 0..i.

        Args:
            sz: sequence length.
            device: device for the mask; defaults to the device of the module's parameters.
        """
        if device is None:
            device = self.start_token.device
        future = torch.ones(sz, sz, device = device).triu(1).bool()  # True strictly above the diagonal
        return torch.zeros(sz, sz, device = device).masked_fill(future, float('-inf'))

    def encode(self, src):
        """Embed and encode the points.

        Args:
            src: [batchsize * nsamples, npoints, 2] point coordinates (last dim must be 2 to
                match ``self.encoder = nn.Linear(2, emsize)``).

        Returns:
            [npoints, batchsize * nsamples, emsize] memory (sequence-first, because the
            transformer layers are built without ``batch_first``).
        """
        src = self.encoder(src).transpose(0, 1)
        return self.transformer_encoder(src)

    def decode_next(self, memory, tgt, route_mask):
        """Score every point as the next step of the route.

        Args:
            memory: [npoints, B, emsize] encoder output, B = batchsize * nsamples.
            tgt: [t, B, emsize] decoder input so far (start token followed by the
                embeddings of the points already chosen).
            route_mask: [B, npoints] bool, True for points already visited; those points
                get logit -inf.

        Returns:
            (tanh_logits, logits), both [B, npoints]. ``tanh_logits`` squashes the raw scores
            into [-1, 1] before masking.
        """
        # The decoder is recomputed over the whole prefix on every call and only output[-1]
        # is used. Earlier positions never change (causal mask), so caching them would be
        # cheaper.
        tgt_mask = self.generate_subsequent_mask(tgt.size(0), device = tgt.device)
        output = self.transformer_decoder(tgt, memory, tgt_mask)  # [t, B, emsize]
        output_query = self.outputattention_query(memory).transpose(0, 1)  # [B, npoints, emsize]
        output_key = self.outputattention_key(output[-1])  # [B, emsize]
        # Scaled dot product of every point (query) with the latest decoder state (key).
        output_attention = torch.matmul(output_query * self.emsize ** -0.5, output_key.unsqueeze(-1)).squeeze(-1)  # [B, npoints]
        output_attention_tanh = output_attention.tanh()  # [B, npoints]

        # masked_fill keeps the mask for backward. Cloning protects against route_mask being
        # modified in place afterwards (env.update scatters into it).
        output_attention = output_attention.masked_fill(route_mask.clone(), float('-inf'))
        output_attention_tanh = output_attention_tanh.masked_fill(route_mask.clone(), float('-inf'))
        return output_attention_tanh, output_attention

    def calculate_logprob(self, memory, routes):
        """Log-probability of complete routes, computed in a single decoder pass.

        This gives the same result as summing the per-step log-softmax of ``decode_next``
        along the route. Because of the causal mask, all steps run at once, which saves
        memory and time compared with backpropagating through the sampling loop.

        Args:
            memory: [npoints, B, emsize] encoder output.
            routes: [B, L] long, point indices in visiting order (L is normally npoints).

        Returns:
            (logprob_tanh, logprob), both [B]: route log-probabilities under the tanh logits
            and the raw logits respectively.
        """
        ninternalpoints = routes.size(1)
        bigbatchsize = memory.size(1)
        # memory_[k] = embedding of the k-th visited point: [L, B, emsize]
        memory_ = memory.gather(0, routes.transpose(0, 1).unsqueeze(2).expand(-1, -1, self.emsize))
        # The decoder input at step k is the start token followed by the first k visited points.
        tgt = torch.cat([self.start_token.view(1, 1, -1).expand(1, bigbatchsize, -1), memory_[:-1]])  # [L, B, emsize]
        tgt_mask = self.generate_subsequent_mask(ninternalpoints, device = memory.device)
        output = self.transformer_decoder(tgt, memory, tgt_mask)  # [L, B, emsize]
        output_query = self.outputattention_query(memory_).transpose(0, 1)  # [B, L, emsize]; row j = j-th visited point
        output_key = self.outputattention_key(output).transpose(0, 1)  # [B, L, emsize]; column k = decoder state at step k
        # At step k, the points at route positions j < k are already visited, so mask them out.
        attention_mask = torch.ones(ninternalpoints, ninternalpoints, device = memory.device).triu(1).bool()  # [L, L], True for j < k
        output_attention = torch.matmul(output_query * self.emsize ** -0.5, output_key.transpose(-1, -2))  # [B, L, L]
        output_attention_tanh = output_attention.tanh()  # bounded scores, a quick fix against divergence

        # Normalise over the candidate points (dim -2) for every step k. The diagonal then
        # holds log p(choose routes[:, k] at step k).
        # (Idea: leave visited points unmasked so the model has to learn to avoid them.)
        output_attention_tanh = output_attention_tanh.masked_fill(attention_mask, float('-inf'))
        output_attention_tanh = output_attention_tanh - output_attention_tanh.logsumexp(-2, keepdim = True)
        output_attention = output_attention.masked_fill(attention_mask, float('-inf'))
        output_attention = output_attention - output_attention.logsumexp(-2, keepdim = True)

        logprob_tanh = output_attention_tanh.diagonal(dim1 = -1, dim2 = -2).sum(-1)  # [B]
        logprob = output_attention.diagonal(dim1 = -1, dim2 = -2).sum(-1)  # [B]
        return logprob_tanh, logprob

# One global policy network (moved to `device`) and the Adam optimiser over its parameters.
NN = Graph_Transformer()
NN = NN.to(device)
optimizer = optim.Adam(params=NN.parameters())


class environment:
    """Batched Euclidean TSP environment.

    ``reset`` draws ``batchsize`` random instances of ``npoints`` points, uniform in [0, 1)^2,
    and repeats each instance ``nsamples`` times. Every per-route tensor therefore has a
    leading dimension of ``batchsize * nsamples`` (stored in ``self.batchsize``), with the
    ``nsamples`` copies of an instance in adjacent rows.
    """

    def reset(self, npoints, batchsize, nsamples=1):
        """Start a new batch of episodes, replacing all per-episode state.

        Args:
            npoints: points per instance.
            batchsize: number of distinct instances.
            nsamples: copies of each instance, rolled out independently.
        """
        # Store both dimensions together so the rest of the code sees a single batch axis.
        self.batchsize = batchsize * nsamples
        self.nsamples = nsamples
        self.npoints = npoints
        self.points = (
            torch.rand([batchsize, npoints, 2], dtype = floattype, device=device)
            .unsqueeze(1)
            .expand(-1, nsamples, -1, -1)
            .reshape(self.batchsize, npoints, 2)
        )
        # Pairwise Euclidean distances: [batchsize * nsamples, npoints, npoints]
        self.distance_matrix = (self.points.unsqueeze(1) - self.points.unsqueeze(2)).square().sum(-1).sqrt()
        self.previous_point = None  # last visited point per row; None before the first step
        self.points_mask = torch.zeros([self.batchsize, npoints], dtype=torch.bool, device=device)  # True = visited
        self.points_sequence = torch.empty([self.batchsize, 0], dtype=torch.long, device=device)  # visiting order so far
        self.cost = torch.zeros([self.batchsize], dtype = floattype, device=device)  # accumulated route length
        self.logprob = torch.zeros([self.batchsize], dtype = floattype, device=device, requires_grad=True)  # accumulated sampling log-prob

    def _edge_length(self, a, b):
        """Return distance_matrix[i, a[i], b[i]] for every row i (a, b: [batchsize * nsamples] long)."""
        rows = torch.arange(self.batchsize, device=a.device)
        return self.distance_matrix[rows, a, b]

    def update(self, point_index):
        """Visit ``point_index`` ([batchsize * nsamples] long, one point per row).

        Adds the edge from the previous point (if any) to ``self.cost``, marks the point as
        visited and appends it to ``self.points_sequence``.

        Raises:
            AssertionError: wrong shape, wrong device, or a point that was already visited.
        """
        assert list(point_index.size()) == [self.batchsize]
        assert point_index.device == self.points_mask.device
        assert self.points_mask.gather(1, point_index.unsqueeze(1)).sum() == 0

        if self.previous_point is not None:
            self.cost += self._edge_length(point_index, self.previous_point)

        self.previous_point = point_index
        self.points_mask.scatter_(1, point_index.unsqueeze(1), True)
        self.points_sequence = torch.cat([self.points_sequence, point_index.unsqueeze(1)], dim = 1)

    def sample_point(self, logits):
        """Sample the next point of every row from Categorical(logits) and visit it.

        Args:
            logits: [batchsize * nsamples, npoints].

        Returns:
            [batchsize * nsamples] chosen indices. Their log-probability is added to ``self.logprob``.
        """
        dist = torch.distributions.categorical.Categorical(logits = logits)
        next_point = dist.sample()
        self.update(next_point)
        self.logprob = self.logprob + dist.log_prob(next_point)
        return next_point

    def sampleandgreedy_point(self, logits):
        """Sample the first ``nsamples - 1`` copies of each instance; the last copy picks greedily.

        Args:
            logits: [batchsize * nsamples, npoints].

        Returns:
            [batchsize * nsamples] chosen indices. Sampled choices add their log-probability
            to ``self.logprob``; greedy choices add 0.
        """
        per_instance = logits.view(-1, self.nsamples, self.npoints)  # [batchsize, nsamples, npoints]
        dist = torch.distributions.categorical.Categorical(logits = per_instance[:, :-1, :])

        sample_point = dist.sample()  # [batchsize, nsamples - 1]
        greedy_point = per_instance[:, -1, :].max(-1, keepdim = True)[1]  # [batchsize, 1]
        next_point = torch.cat([sample_point, greedy_point], dim = 1).view(-1)
        self.update(next_point)
        sample_logprob = dist.log_prob(sample_point)  # [batchsize, nsamples - 1]
        padded = torch.cat([sample_logprob, sample_logprob.new_zeros(sample_logprob.size(0), 1)], dim = 1)
        self.logprob = self.logprob + padded.view(-1)
        return next_point

    def laststep(self):
        """Close every route by adding the edge from the last visited point back to the first."""
        assert self.points_sequence.size(1) == self.npoints
        self.cost += self._edge_length(self.points_sequence[:, -1], self.points_sequence[:, 0])
    

env: environment = environment()  # shared instance; train() calls env.reset(...) to create all per-episode state


def train(epochs = 30000, npoints = 10, batchsize = 100, nsamples = 8, negative_cutoff = 1, log_every = 1):
    """Train the global ``NN`` on random TSP instances with REINFORCE and a greedy baseline.

    Each epoch:
      1. ``env.reset`` draws ``batchsize`` random instances of ``npoints`` points, each
         repeated ``nsamples`` times (flattened to ``batchsize * nsamples`` rows).
      2. Full routes are rolled out without gradient. The first ``nsamples - 1`` copies of
         each instance sample their next point; the last copy picks greedily.
      3. The log-probability of the chosen routes is recomputed in one decoder pass
         (``NN.calculate_logprob``), so only that pass is backpropagated.
      4. Loss = mean(logprob * (cost - greedy cost)) over the sampled copies.

    Args:
        epochs: number of optimisation steps.
        npoints: points per instance.
        batchsize: number of distinct instances per step.
        nsamples: copies per instance; the last copy is the greedy rollout.
        negative_cutoff: threshold used by earlier experimental loss variants; the current
            loss ignores it (kept for call compatibility).
        log_every: print statistics every ``log_every`` epochs (1 = every epoch, the previous
            behaviour; 0/None disables printing). The printed columns are: mean greedy cost,
            mean greedy-route logprob, mean sampled-route logprob, logprob of flat row
            ``batchsize - 1``, logprob of row 0 (recomputed), and ``env.logprob[0]``
            (accumulated during the rollout).
    """
    NN.train()
    for i in range(epochs):
        env.reset(npoints, batchsize, nsamples)
        # TODO(author note): boundary points could also contribute; for now they could only
        # enter through the encoder, and it is unclear how to feed them to the decoder.
        memory = NN.encode(env.points)  # [npoints, batchsize * nsamples, emsize], keeps grad for the loss
        tgt = NN.start_token.unsqueeze(0).unsqueeze(1).expand(1, batchsize * nsamples, -1).detach()  # [1, B, emsize]

        # Roll out routes without gradient (cheaper). The chosen routes are rescored with
        # gradient in a single pass below.
        with torch.no_grad():
            for _step in range(npoints):
                _, logits = NN.decode_next(memory.detach(), tgt, env.points_mask)
                next_point = env.sampleandgreedy_point(logits)
                # Feed the embedding of the chosen point back into the decoder. Feeding back
                # the previous decoder output instead would allow an evolving strategy, but not
                # the single-pass logprob computation.
                tgt = torch.cat([tgt, memory.gather(0, next_point.unsqueeze(0).unsqueeze(2).expand(1, -1, memory.size(2)))])  # [steps + 1, B, emsize]

        env.laststep()

        NN.eval()
        _, logprob = NN.calculate_logprob(memory, env.points_sequence)  # [batchsize * nsamples]
        NN.train()

        costs = env.cost.view(batchsize, nsamples)
        greedy_baseline = costs[:, -1]  # [batchsize], cost of the greedy rollout of each instance
        baseline = greedy_baseline
        advantage = costs[:, :-1] - baseline.unsqueeze(1)  # [batchsize, nsamples - 1]
        # Split into the part below the baseline (better routes) and the part above it (worse
        # routes). Their sum is exactly `advantage`; they are kept separate so each can be
        # clipped or reweighted independently.
        positive_reinforcement = - F.relu(- advantage)
        negative_reinforcement = F.relu(advantage)
        sampled_logprob = logprob.view(batchsize, nsamples)[:, :-1]
        # REINFORCE: minimising this raises the probability of routes shorter than the greedy
        # rollout and lowers it for longer ones (the costs carry no gradient).
        loss = (sampled_logprob * (positive_reinforcement + negative_reinforcement)).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if log_every and i % log_every == 0:
            print(greedy_baseline.mean().item(), logprob.view(batchsize, nsamples)[:, -1].mean().item(), sampled_logprob.mean().item(), logprob[batchsize - 1].item(), logprob[0].item(), env.logprob[0].item())
        
   

In [2]:
train(epochs = 300000, npoints = 50, batchsize = 400, nsamples = 8, negative_cutoff = 1)  # 50-point instances, 400 instances x 8 rollouts per step

19.33003044128418 -146.136474609375 -148.4093780517578 -146.40731811523438 -148.34957885742188 -148.3495635986328
14.266571044921875 -141.11985778808594 -147.7121124267578 -140.7115478515625 -147.5675048828125 -147.56748962402344
14.759800910949707 -120.56681060791016 -137.9326629638672 -121.71231079101562 -133.65328979492188 -133.65328979492188
13.19284439086914 -95.56456756591797 -117.49597930908203 -98.23407745361328 -118.01054382324219 -118.01053619384766
9.876015663146973 -101.44986724853516 -126.246337890625 -97.87459564208984 -132.3051300048828 -132.30514526367188
12.03374195098877 -83.46735382080078 -105.8130111694336 -87.72215270996094 -105.43649291992188 -105.43648529052734
12.128379821777344 -80.26680755615234 -103.36506652832031 -83.88217163085938 -106.64299774169922 -106.64299774169922
11.223166465759277 -76.7909927368164 -101.11714172363281 -75.93172454833984 -99.51309204101562 -99.5130615234375
9.960550308227539 -66.40434265136719 -91.32095336914062 -68.78376007080078 -9

7.349706649780273 -17.78327178955078 -31.465469360351562 -17.817176818847656 -26.24360466003418 -26.243610382080078
7.234799861907959 -17.48845672607422 -31.27753257751465 -22.975181579589844 -23.47623634338379 -23.47623062133789
7.0589799880981445 -16.931303024291992 -30.47750473022461 -16.198190689086914 -46.34967803955078 -46.34967803955078
6.989931583404541 -15.8538818359375 -28.983814239501953 -16.586029052734375 -31.43427848815918 -31.434276580810547
6.978456974029541 -15.694719314575195 -28.385976791381836 -19.445758819580078 -25.35430145263672 -25.354299545288086
6.999050617218018 -15.45067310333252 -28.465423583984375 -14.723217964172363 -28.52054214477539 -28.520530700683594
6.895200252532959 -15.701408386230469 -28.868518829345703 -15.743189811706543 -19.96455192565918 -19.964550018310547
6.861154079437256 -15.283720016479492 -28.291454315185547 -15.62879467010498 -27.030494689941406 -27.030467987060547
6.809391498565674 -14.937609672546387 -27.718908309936523 -14.1767978668

6.50493860244751 -8.65223217010498 -16.757326126098633 -6.61370325088501 -16.105363845825195 -16.105348587036133
6.435307502746582 -8.537440299987793 -16.412500381469727 -9.266490936279297 -26.11478614807129 -26.11479377746582
6.458516597747803 -8.46693229675293 -16.434785842895508 -9.517578125 -13.579084396362305 -13.579086303710938
6.486640930175781 -8.399648666381836 -16.228731155395508 -5.151573181152344 -15.31338882446289 -15.313384056091309
6.458113193511963 -8.285188674926758 -16.186893463134766 -8.554869651794434 -15.163000106811523 -15.162989616394043
6.489268779754639 -8.398886680603027 -16.318334579467773 -6.621795177459717 -16.829689025878906 -16.829710006713867
6.4669365882873535 -8.413708686828613 -16.378782272338867 -6.572201728820801 -12.979625701904297 -12.979629516601562
6.450483322143555 -8.373091697692871 -16.323598861694336 -6.680539131164551 -17.823410034179688 -17.823400497436523
6.478993892669678 -8.495205879211426 -16.44256591796875 -8.97665786743164 -20.077022

6.424017906188965 -6.701207160949707 -13.054659843444824 -5.761398792266846 -12.478910446166992 -12.478907585144043
6.4453606605529785 -6.509206295013428 -12.562358856201172 -5.167810440063477 -13.813067436218262 -13.813089370727539
6.4959001541137695 -6.482944965362549 -12.66742992401123 -6.447569847106934 -7.563160419464111 -7.5631608963012695
6.393218040466309 -6.580146312713623 -12.874239921569824 -6.856791973114014 -13.603631019592285 -13.603647232055664
6.42982292175293 -6.549321174621582 -12.837151527404785 -3.8388750553131104 -12.885007858276367 -12.885011672973633
6.466092586517334 -6.797470569610596 -13.2374849319458 -4.371448516845703 -8.721896171569824 -8.721891403198242
6.44587516784668 -6.597377300262451 -12.844226837158203 -4.847912788391113 -13.447333335876465 -13.447334289550781
6.430847644805908 -6.396340370178223 -12.515007019042969 -5.188965797424316 -9.52888298034668 -9.528868675231934
6.447946548461914 -6.337051868438721 -12.545817375183105 -5.7302093505859375 -11

6.42296838760376 -5.762937068939209 -11.171730041503906 -6.367124557495117 -7.822014331817627 -7.822031497955322
6.395906925201416 -5.774982929229736 -11.21397876739502 -4.4924774169921875 -7.223455429077148 -7.2234697341918945
6.3679680824279785 -5.636014461517334 -11.03567123413086 -8.73355770111084 -12.465396881103516 -12.465400695800781
6.379608154296875 -5.836000919342041 -11.344263076782227 -7.4495744705200195 -16.740419387817383 -16.740436553955078
6.363124847412109 -5.893789768218994 -11.63676929473877 -5.137575149536133 -10.743850708007812 -10.743842124938965
6.396153450012207 -5.845457553863525 -11.52864933013916 -7.913110733032227 -10.465043067932129 -10.465065002441406
6.371052265167236 -5.756358623504639 -11.439305305480957 -5.120235443115234 -11.633796691894531 -11.633796691894531
6.354645252227783 -5.7805304527282715 -11.301883697509766 -7.230766296386719 -16.568824768066406 -16.568811416625977
6.405754089355469 -5.70701265335083 -11.122907638549805 -3.7142016887664795 -

6.398036479949951 -5.094111919403076 -10.197518348693848 -3.8120028972625732 -15.774946212768555 -15.774941444396973
6.367917537689209 -5.090175628662109 -9.915542602539062 -3.4131813049316406 -8.115500450134277 -8.11546516418457
6.401500701904297 -5.02052116394043 -9.83594036102295 -2.675961494445801 -8.143256187438965 -8.143294334411621
6.412670612335205 -5.039845943450928 -9.931805610656738 -5.977276802062988 -4.510491371154785 -4.510499477386475
6.37891149520874 -4.960155963897705 -9.74666976928711 -4.351527214050293 -11.858598709106445 -11.858612060546875
6.384548187255859 -5.008721828460693 -9.740277290344238 -5.631043434143066 -10.946181297302246 -10.946163177490234
6.408669948577881 -4.997714519500732 -9.741244316101074 -5.796229362487793 -7.249935626983643 -7.249900817871094
6.410121917724609 -4.9725446701049805 -9.842970848083496 -4.470605850219727 -5.67562198638916 -5.675607204437256
6.368121147155762 -4.937530517578125 -9.773114204406738 -6.227197647094727 -11.5338726043701

6.332939147949219 -4.597762584686279 -9.10146713256836 -3.3583457469940186 -11.110270500183105 -11.110322952270508
6.3462724685668945 -4.4224467277526855 -8.730202674865723 -4.230208396911621 -7.713600158691406 -7.713608741760254
6.3326849937438965 -4.338639259338379 -8.654183387756348 -3.983428955078125 -9.21229362487793 -9.212287902832031
6.335909843444824 -4.346498966217041 -8.657280921936035 -2.3187570571899414 -11.856459617614746 -11.856476783752441
6.340419769287109 -4.505638122558594 -8.858529090881348 -3.813098907470703 -8.912325859069824 -8.912320137023926
6.32783317565918 -4.492217540740967 -8.862382888793945 -3.8281352519989014 -12.25655746459961 -12.256559371948242
6.353309154510498 -4.483640193939209 -8.9212064743042 -4.195924282073975 -5.108293533325195 -5.108331680297852
6.339332103729248 -4.51167631149292 -8.859567642211914 -4.3223066329956055 -5.0999979972839355 -5.100009441375732
6.378233432769775 -4.472928524017334 -8.848581314086914 -5.376364707946777 -8.31511497497

6.383330345153809 -4.220529079437256 -8.452473640441895 -4.641427993774414 -10.608854293823242 -10.60887336730957
6.367135524749756 -4.274581432342529 -8.428918838500977 -4.469339370727539 -10.535978317260742 -10.535974502563477
6.413577079772949 -4.33121395111084 -8.538346290588379 -3.412459373474121 -7.507007598876953 -7.507010459899902
6.31535005569458 -4.168598651885986 -8.132670402526855 -4.950207710266113 -3.1179802417755127 -3.1180107593536377
6.31187105178833 -3.962597370147705 -7.9137797355651855 -7.459897518157959 -3.6586742401123047 -3.6586971282958984
6.373387336730957 -4.074217319488525 -8.119643211364746 -5.046590805053711 -6.668691158294678 -6.66870641708374
6.33155632019043 -4.057862758636475 -8.070526123046875 -3.466696262359619 -3.5898101329803467 -3.589824914932251
6.390193939208984 -4.075043678283691 -8.167434692382812 -4.531830310821533 -6.242598533630371 -6.242581367492676
6.4419732093811035 -4.290799140930176 -8.42593765258789 -3.06825590133667 -6.122341632843018

6.312224388122559 -3.883446455001831 -7.776552677154541 -5.199129104614258 -4.204375267028809 -4.204367637634277
6.304249286651611 -3.8483400344848633 -7.567453384399414 -4.013813018798828 -4.112580299377441 -4.112567901611328
6.294275760650635 -3.8900344371795654 -7.707000255584717 -2.0189311504364014 -5.485822677612305 -5.485795974731445
6.278273582458496 -3.973362684249878 -7.901910305023193 -4.732812881469727 -9.190359115600586 -9.190376281738281
6.3122382164001465 -4.0597357749938965 -8.140106201171875 -3.544548511505127 -8.138221740722656 -8.138233184814453
6.332158088684082 -4.008086681365967 -8.217670440673828 -3.3971452713012695 -10.558113098144531 -10.558067321777344
6.3271989822387695 -4.099143981933594 -8.212441444396973 -4.246238708496094 -13.779958724975586 -13.779934883117676
6.334656715393066 -4.17632532119751 -8.198500633239746 -2.42252254486084 -10.987648010253906 -10.98758316040039
6.338342189788818 -4.00740385055542 -7.949265956878662 -1.535383701324463 -4.237078666

6.300744533538818 -3.6703038215637207 -7.395851135253906 -4.8920745849609375 -2.234426498413086 -2.2344188690185547
6.3208794593811035 -3.6564207077026367 -7.269511699676514 -2.8270044326782227 -6.3220415115356445 -6.322039604187012
6.421947002410889 -3.5887107849121094 -7.118345260620117 -3.7846689224243164 -4.083593368530273 -4.083600997924805
6.33948278427124 -3.613741159439087 -7.190025806427002 -2.8384876251220703 -12.217658996582031 -12.217693328857422
6.295749187469482 -3.5613112449645996 -6.943709373474121 -6.708425998687744 -7.940185546875 -7.9401702880859375
6.330727577209473 -3.478008985519409 -6.915922164916992 -3.345815658569336 -5.040170192718506 -5.040154933929443
6.285986423492432 -3.4998228549957275 -7.037168502807617 -3.483121871948242 -5.093550682067871 -5.0935468673706055
6.291315078735352 -3.5002260208129883 -6.918384552001953 -3.5399627685546875 -3.0581846237182617 -3.0581846237182617
6.305544376373291 -3.4822824001312256 -6.806455612182617 -2.3294625282287598 -6.

6.302193641662598 -3.5265748500823975 -6.9117536544799805 -4.655794143676758 -10.126945495605469 -10.126976013183594
6.295857906341553 -3.4791862964630127 -6.899807929992676 -5.218935012817383 -3.653620719909668 -3.6536359786987305
6.313863277435303 -3.4893548488616943 -6.9187164306640625 -3.0303053855895996 -5.748652458190918 -5.748679161071777
6.31577730178833 -3.3459136486053467 -6.6355085372924805 -2.082514762878418 -2.140220880508423 -2.140228509902954
6.344442844390869 -3.3126749992370605 -6.601051330566406 -1.9626253843307495 -6.621679306030273 -6.621696472167969
6.333854675292969 -3.361064910888672 -6.6317830085754395 -5.469881057739258 -5.91325569152832 -5.913237571716309
6.278262615203857 -3.4335176944732666 -6.869359493255615 -2.601118803024292 -10.853010177612305 -10.853006362915039
6.298494338989258 -3.4969987869262695 -6.848498821258545 -4.923319339752197 -7.251059532165527 -7.251058101654053
6.306710243225098 -3.429936408996582 -6.8593597412109375 -5.164907455444336 -5.1

6.285101413726807 -3.499882698059082 -6.884414196014404 -5.042047500610352 -6.0432281494140625 -6.043212890625
6.309528827667236 -3.3276290893554688 -6.744335651397705 -1.6096060276031494 -11.237285614013672 -11.237289428710938
6.283287048339844 -3.423074245452881 -6.849546432495117 -2.158611297607422 -13.170108795166016 -13.170075416564941
6.258761882781982 -3.3000435829162598 -6.635580062866211 -2.5669217109680176 -6.8896684646606445 -6.889674186706543
6.290558815002441 -3.211219310760498 -6.3635149002075195 -4.573883056640625 -7.146231651306152 -7.146224021911621
6.329382419586182 -3.229071855545044 -6.527505397796631 -3.347461700439453 -10.798393249511719 -10.798385620117188
6.284080505371094 -3.3032755851745605 -6.471895694732666 -4.05715274810791 -9.686749458312988 -9.686711311340332
6.2852325439453125 -3.299509286880493 -6.527124881744385 -4.638587951660156 -8.963298797607422 -8.963323593139648
6.288382530212402 -3.3209667205810547 -6.53515625 -1.81890869140625 -4.50093412399292

6.27094841003418 -3.2789909839630127 -6.6249775886535645 -3.487335205078125 -5.215730667114258 -5.215753555297852
6.251048564910889 -3.1862852573394775 -6.481245040893555 -1.8002243041992188 -9.765270233154297 -9.765270233154297
6.273274898529053 -3.318080425262451 -6.507038116455078 -5.135677337646484 -7.2785186767578125 -7.278511047363281
6.27197265625 -3.2887871265411377 -6.53822135925293 -4.439296722412109 -3.8030128479003906 -3.803070068359375
6.244131565093994 -3.241567850112915 -6.457894802093506 -2.9050493240356445 -11.908342361450195 -11.908332824707031
6.2694091796875 -3.298395872116089 -6.476773262023926 -1.9399375915527344 -5.801756858825684 -5.801760673522949
6.275209903717041 -3.223600149154663 -6.432542324066162 -2.0374603271484375 -5.627871036529541 -5.627835750579834
6.251208305358887 -3.2866809368133545 -6.458125114440918 -3.8596439361572266 -5.9874796867370605 -5.987498760223389
6.229615211486816 -3.213576555252075 -6.355610370635986 -4.171213150024414 -5.93226146697

6.245235919952393 -3.2401363849639893 -6.343776226043701 -2.280341625213623 -5.700952529907227 -5.700899124145508
6.254775047302246 -3.2405288219451904 -6.529824256896973 -1.6739435195922852 -7.7096710205078125 -7.709625244140625
6.205583572387695 -3.1535427570343018 -6.23030424118042 -4.461458206176758 -4.626974105834961 -4.626985549926758
6.232304573059082 -3.083543539047241 -6.027349472045898 -3.6819095611572266 -6.901102066040039 -6.901073455810547
6.219959735870361 -3.093790292739868 -6.199235916137695 -1.8879213333129883 -7.589620590209961 -7.589646339416504
6.255912780761719 -3.1088664531707764 -6.170839786529541 -2.879619598388672 -4.428622722625732 -4.4285807609558105
6.1850714683532715 -3.034588575363159 -6.132758140563965 -2.989377021789551 -2.142667770385742 -2.1426525115966797
6.221800327301025 -3.0769124031066895 -6.1267523765563965 -2.320927619934082 -10.170927047729492 -10.17091178894043
6.264737606048584 -2.96632719039917 -6.0014214515686035 -3.5585999488830566 -6.2318

6.261255741119385 -2.900623083114624 -5.832141399383545 -3.4887499809265137 -5.881644248962402 -5.881674766540527
6.231536865234375 -2.9428868293762207 -5.826429843902588 -1.3399735689163208 -2.308476448059082 -2.308487892150879
6.239958763122559 -2.8225882053375244 -5.595407485961914 -3.10581111907959 -4.014057159423828 -4.014072418212891
6.315671443939209 -2.954802989959717 -5.861447811126709 -3.1840970516204834 -2.411813974380493 -2.4118294715881348
6.239461421966553 -2.9992470741271973 -5.930800914764404 -2.705108642578125 -8.327198028564453 -8.327214241027832
6.21827507019043 -3.017901659011841 -5.9354567527771 -2.4029674530029297 -4.511209011077881 -4.511186122894287
6.301362037658691 -3.129953145980835 -6.277260780334473 -4.887813568115234 -5.696079254150391 -5.696014404296875
6.237692832946777 -3.028826951980591 -6.214446067810059 -3.585498809814453 -1.6277813911437988 -1.6278271675109863
6.273708820343018 -3.0301873683929443 -6.099236965179443 -2.0027847290039062 -11.633557319

6.269144535064697 -3.471403121948242 -6.871779918670654 -3.107593059539795 -3.2353668212890625 -3.2353434562683105
6.21986722946167 -3.5315470695495605 -6.896057605743408 -4.831926345825195 -7.06587028503418 -7.065820693969727
6.26936674118042 -3.5545711517333984 -7.011132717132568 -4.010720252990723 -6.597414970397949 -6.597373962402344
6.215976238250732 -3.5014560222625732 -6.978701114654541 -3.977710247039795 -8.976408004760742 -8.976408004760742
6.1989288330078125 -3.5549256801605225 -7.008191108703613 -2.416339874267578 -3.59897518157959 -3.598989963531494
6.197490215301514 -3.5170247554779053 -6.932877063751221 -4.445046901702881 -2.8258073329925537 -2.825824499130249
6.209249019622803 -3.5662448406219482 -7.137401580810547 -4.330692291259766 -5.706857681274414 -5.706865310668945
6.228809833526611 -3.487773895263672 -6.998098373413086 -3.8619511127471924 -6.12558650970459 -6.125591278076172
6.218449592590332 -3.656320810317993 -7.159782409667969 -3.4351558685302734 -9.46151924133

6.168106555938721 -3.4633560180664062 -6.797721862792969 -2.8923072814941406 -5.325981140136719 -5.325981140136719
6.229274749755859 -3.5846681594848633 -7.1020731925964355 -4.174298286437988 -9.556041717529297 -9.555973052978516
6.145666599273682 -3.7302911281585693 -7.286090850830078 -2.8435726165771484 -4.985599517822266 -4.985591888427734
6.158493518829346 -3.518249750137329 -6.938694000244141 -3.4250049591064453 -2.9973602294921875 -2.9973526000976562
6.194765567779541 -3.5460901260375977 -7.080442905426025 -4.161993503570557 -8.913322448730469 -8.913326263427734
6.177839756011963 -3.849719762802124 -7.502414703369141 -4.481311798095703 -7.934525489807129 -7.9344892501831055
6.21146821975708 -4.019595146179199 -7.902556896209717 -2.813243865966797 -7.441913604736328 -7.441905975341797
6.108489990234375 -3.9294233322143555 -7.74910831451416 -3.866504669189453 -5.387165069580078 -5.387164115905762
6.121530532836914 -3.5423967838287354 -6.9878997802734375 -4.362213134765625 -4.456031

6.112399578094482 -3.3052432537078857 -6.519288539886475 -3.2561912536621094 -12.263957977294922 -12.263919830322266
6.111662864685059 -3.2946367263793945 -6.457895278930664 -6.336267471313477 -7.854175567626953 -7.854183197021484
6.089831829071045 -3.2774693965911865 -6.518589019775391 -3.0568313598632812 -3.4407119750976562 -3.440673828125
6.109322547912598 -3.3928091526031494 -6.586215019226074 -3.6458892822265625 -6.136911869049072 -6.1368889808654785
6.13066291809082 -3.3738203048706055 -6.753149509429932 -3.528657913208008 -7.804178237915039 -7.804216384887695
6.114846229553223 -3.374063014984131 -6.647347927093506 -2.456453323364258 -4.149716377258301 -4.149746894836426
6.09478235244751 -3.3287246227264404 -6.60198450088501 -2.9261083602905273 -3.0642776489257812 -3.0642471313476562
6.134787559509277 -3.369821786880493 -6.836653232574463 -2.1365699768066406 -5.919167518615723 -5.919186592102051
6.118931293487549 -3.40010929107666 -6.8310699462890625 -2.791105270385742 -5.4955129

6.0759735107421875 -3.475083827972412 -6.87922477722168 -3.081745147705078 -7.665187835693359 -7.665189743041992
6.087902069091797 -3.413564920425415 -6.804405689239502 -4.3321075439453125 -2.148893356323242 -2.1488780975341797
6.095430850982666 -3.38116455078125 -6.684708118438721 -2.0181503295898438 -4.046564102172852 -4.0465497970581055
6.048058986663818 -3.260507106781006 -6.60371732711792 -5.447132110595703 -2.3708038330078125 -2.3707962036132812
6.086666107177734 -3.2307214736938477 -6.439690113067627 -4.780635833740234 -5.631145477294922 -5.631092071533203
6.051230430603027 -3.262873411178589 -6.305365562438965 -4.199394226074219 -9.466197967529297 -9.466320037841797
6.074657917022705 -3.2318103313446045 -6.323136329650879 -1.7059974670410156 -6.29414176940918 -6.294111251831055
6.063990116119385 -3.2291345596313477 -6.3616461753845215 -2.924335479736328 -4.78382682800293 -4.783773422241211
6.053036212921143 -3.2248895168304443 -6.478322982788086 -2.0599002838134766 -4.477096557

6.093143939971924 -2.926995277404785 -5.908349514007568 -2.601837158203125 -6.283924102783203 -6.283893585205078
6.135760307312012 -2.9747283458709717 -5.894532680511475 -3.9228363037109375 -2.8638648986816406 -2.8639183044433594
6.107645034790039 -2.9097955226898193 -5.741368293762207 -2.0919189453125 -3.1867122650146484 -3.1866893768310547
6.050647258758545 -2.756359815597534 -5.584629058837891 -2.2510414123535156 -3.949756622314453 -3.9497108459472656
6.108016014099121 -2.8809802532196045 -5.782443046569824 -5.162746429443359 -3.2178192138671875 -3.2178268432617188
6.107272624969482 -2.924046516418457 -5.841464996337891 -3.2058677673339844 -3.5302047729492188 -3.5301513671875
6.08718729019165 -2.9290077686309814 -5.8221259117126465 -2.390705108642578 -6.508964538574219 -6.508979797363281
6.083410739898682 -3.0037102699279785 -5.929259300231934 -1.9429473876953125 -10.399035453796387 -10.399008750915527
6.094064712524414 -2.950944185256958 -5.889968395233154 -6.703163146972656 -8.560

6.055828094482422 -2.826214551925659 -5.652393341064453 -5.0360260009765625 -6.953893661499023 -6.95380973815918
6.066459655761719 -2.8260061740875244 -5.646817207336426 -2.1577095985412598 -5.090801239013672 -5.090785980224609
6.088109493255615 -2.7913196086883545 -5.671806812286377 -2.8521575927734375 -2.624969482421875 -2.6249237060546875
6.048030376434326 -2.7310779094696045 -5.430140972137451 -1.809276819229126 -6.927343368530273 -6.92744255065918
6.065072059631348 -2.7685742378234863 -5.5350022315979 -4.135601043701172 -4.700969696044922 -4.701007843017578
6.072907447814941 -2.8305304050445557 -5.536027908325195 -2.2179603576660156 -3.693775177001953 -3.6937904357910156
6.09714937210083 -2.792818546295166 -5.52345609664917 -4.930141925811768 -2.132740020751953 -2.1327247619628906
6.066073894500732 -2.7423930168151855 -5.512932300567627 -3.8340606689453125 -3.0182838439941406 -3.0183143615722656
6.06967830657959 -2.7225711345672607 -5.4788923263549805 -2.0877418518066406 -6.578083

6.08971643447876 -2.6568474769592285 -5.319100379943848 -1.6044502258300781 -4.937778472900391 -4.937770843505859
6.072056770324707 -2.6000969409942627 -5.297317981719971 -2.2915077209472656 -0.5346145629882812 -0.5345993041992188
6.135871887207031 -2.623330593109131 -5.25899600982666 -2.338958740234375 -4.054931640625 -4.054908752441406
6.112475395202637 -2.4805238246917725 -4.986461162567139 -2.187774658203125 -1.7108039855957031 -1.7107887268066406
6.0692362785339355 -2.6436562538146973 -5.153634548187256 -2.562776565551758 -2.4654088020324707 -2.4654316902160645
6.080073833465576 -2.6961936950683594 -5.360445022583008 -2.356393814086914 -4.890890121459961 -4.890893936157227
6.087014198303223 -2.6756503582000732 -5.462314128875732 -3.0678977966308594 -7.3002166748046875 -7.300203323364258
6.054144382476807 -2.7595083713531494 -5.454588890075684 -2.9574756622314453 -2.0174407958984375 -2.0174484252929688
6.101603984832764 -2.7851037979125977 -5.559943675994873 -3.1632919311523438 -13

6.07007360458374 -2.542877674102783 -5.043116092681885 -0.5458526611328125 -0.49703216552734375 -0.49703216552734375
6.0580058097839355 -2.341946840286255 -4.707261562347412 -3.2445526123046875 -3.12310791015625 -3.1231002807617188
6.127530574798584 -2.432572364807129 -4.8191142082214355 -2.3025588989257812 -2.562591552734375 -2.5625762939453125
6.12160587310791 -2.2830193042755127 -4.611243724822998 -0.9390411376953125 -1.164581298828125 -1.1645660400390625
6.090504169464111 -2.2591235637664795 -4.584822177886963 -2.4644012451171875 -3.1871509552001953 -3.187135696411133
6.089017391204834 -2.3500888347625732 -4.659417629241943 -2.990875244140625 -4.01824951171875 -4.018226623535156
6.100786209106445 -2.5173563957214355 -4.934963703155518 -2.4536590576171875 -1.9864025115966797 -1.9864025115966797
6.12663459777832 -2.5343892574310303 -4.969429969787598 -1.0419769287109375 -1.3942852020263672 -1.3943157196044922
6.073231220245361 -2.5598256587982178 -4.991123199462891 -2.253440856933593

6.072000980377197 -2.359368324279785 -4.7469072341918945 -2.3563995361328125 -1.593231201171875 -1.5932159423828125
6.048222541809082 -2.2766215801239014 -4.609384536743164 -0.88519287109375 -4.4438629150390625 -4.4437713623046875
6.0573344230651855 -2.3383564949035645 -4.698730945587158 -2.6113827228546143 -6.5953521728515625 -6.595298767089844
6.037463188171387 -2.3774752616882324 -4.7384514808654785 -1.6270294189453125 -6.755159378051758 -6.755136489868164
6.043300628662109 -2.467989683151245 -4.842337608337402 -1.935760498046875 -6.135139465332031 -6.1351470947265625
6.051362991333008 -2.6362011432647705 -5.16688346862793 -1.3928070068359375 -4.216571807861328 -4.216709136962891
6.0502142906188965 -2.552194595336914 -5.073945999145508 -1.1188793182373047 -4.660163879394531 -4.660148620605469
6.03835916519165 -2.5029122829437256 -4.889787197113037 -2.0335235595703125 -10.853607177734375 -10.8536376953125
6.063566207885742 -2.5122108459472656 -4.905876159667969 -3.3374252319335938 -5

6.030551433563232 -2.3812191486358643 -4.624149799346924 -1.248270034790039 -7.5692596435546875 -7.5692901611328125
6.056416034698486 -2.2755300998687744 -4.551759243011475 -0.7900810241699219 -5.517099380493164 -5.517023086547852
6.027668476104736 -2.3018112182617188 -4.607569217681885 -2.133350372314453 -3.6969375610351562 -3.6969680786132812
6.059947490692139 -2.27689266204834 -4.523604393005371 -1.379974365234375 -6.331153869628906 -6.3311309814453125
6.016168117523193 -2.2416305541992188 -4.498420238494873 -4.661170959472656 -1.5457038879394531 -1.5456733703613281
6.021347522735596 -2.292048692703247 -4.5627121925354 -4.8507843017578125 -3.6234922409057617 -3.623538017272949
6.036349773406982 -2.2251455783843994 -4.383234977722168 -0.66265869140625 -3.4762535095214844 -3.4762611389160156
6.023169040679932 -2.2928578853607178 -4.535923004150391 -1.5513429641723633 -5.995487213134766 -5.9954833984375
6.054195404052734 -2.3561599254608154 -4.639978885650635 -1.4018936157226562 -3.828

6.087625980377197 -2.193477153778076 -4.407115936279297 -2.2352752685546875 -1.6598917245864868 -1.6598840951919556
6.059452533721924 -2.349642276763916 -4.623088359832764 -3.7745704650878906 -3.95050048828125 -3.95050048828125
6.064529895782471 -2.310576915740967 -4.57757043838501 -3.007068634033203 -10.645018577575684 -10.645049095153809
6.072282791137695 -2.374436855316162 -4.6751179695129395 -3.437347412109375 -6.092987060546875 -6.092994689941406
6.028260231018066 -2.2571280002593994 -4.4633259773254395 -3.9361953735351562 -4.60394287109375 -4.60394287109375
6.0176591873168945 -2.1545705795288086 -4.186982154846191 -3.4924774169921875 -6.002010345458984 -6.002040863037109
6.073052883148193 -1.996539831161499 -4.0293145179748535 -1.0962295532226562 -2.245208740234375 -2.245269775390625
6.082691669464111 -2.0503463745117188 -3.985731601715088 -3.314685821533203 -9.398246765136719 -9.3983154296875
6.09188175201416 -2.0297558307647705 -4.107326030731201 -3.07940673828125 -6.6874265670

6.035006999969482 -2.222302198410034 -4.320028781890869 -2.7564544677734375 -1.3267822265625 -1.3267669677734375
6.027291297912598 -2.1183176040649414 -4.250256061553955 -0.407257080078125 -3.4328155517578125 -3.4328041076660156
6.052084922790527 -2.1418426036834717 -4.279172897338867 -2.5140380859375 -1.8045425415039062 -1.8045730590820312
6.014750957489014 -2.162362813949585 -4.330291748046875 -3.4791312217712402 -1.5287219285964966 -1.5286990404129028
6.010194778442383 -2.2634592056274414 -4.47779655456543 -2.3000259399414062 -5.013877868652344 -5.013896942138672
6.042282581329346 -2.231218099594116 -4.354273319244385 -2.9033126831054688 -7.077056884765625 -7.0770416259765625
6.063281536102295 -2.294200897216797 -4.398778915405273 -1.9475860595703125 -5.490623474121094 -5.490509033203125
6.013682842254639 -2.140162229537964 -4.286874294281006 -1.740002155303955 -7.708763122558594 -7.708824157714844
6.056893825531006 -2.086667060852051 -4.329707145690918 -1.5630302429199219 -3.152256

6.01199197769165 -2.1882803440093994 -4.299367904663086 -1.2047653198242188 -1.5101699829101562 -1.510162353515625
6.024225234985352 -2.256056070327759 -4.489656925201416 -2.1157302856445312 -3.070718765258789 -3.070772171020508
6.021371364593506 -2.2262611389160156 -4.407971382141113 -3.0254364013671875 -8.33857536315918 -8.33863639831543
6.000516414642334 -2.226447105407715 -4.476352691650391 -0.5669479370117188 -2.3306045532226562 -2.3306198120117188
6.017801284790039 -2.2522053718566895 -4.492978572845459 -3.3592376708984375 -1.5117874145507812 -1.5117568969726562
6.020867824554443 -2.1966488361358643 -4.249969959259033 -3.0810317993164062 -9.055950164794922 -9.055904388427734
6.05754280090332 -2.0860719680786133 -4.123812198638916 -3.1264190673828125 -1.2356185913085938 -1.2356491088867188
6.037773132324219 -2.0908102989196777 -4.120652675628662 -1.0196075439453125 -1.4515838623046875 -1.451568603515625
6.026955604553223 -2.1282107830047607 -4.2291693687438965 -3.0281600952148438 

KeyboardInterrupt: 

In [3]:
# Long-running training run; in the recorded output it was stopped manually (KeyboardInterrupt).
# NOTE(review): positional args presumably mirror evaluate's signature
# (epochs, npoints, batchsize, nsamples, negative_cutoff) — TODO confirm against train's definition.
train(300000, 50, 400, 8, 1)

6.004909515380859 -1.869944453239441 -3.7802224159240723 -1.39208984375 -14.177070617675781 -14.176979064941406
6.49668550491333 -2.217228889465332 -4.407912254333496 -2.334476947784424 -1.00799560546875 -1.0079498291015625
6.18903112411499 -1.9306493997573853 -3.8357763290405273 -1.1486587524414062 -2.7324609756469727 -2.7324447631835938
6.253580093383789 -1.9833091497421265 -3.8868629932403564 -1.2449569702148438 -1.9937667846679688 -1.9938125610351562
6.172737121582031 -2.0072288513183594 -3.973785161972046 -1.337432861328125 -4.444782257080078 -4.444873809814453
6.157509803771973 -2.045542001724243 -4.009119987487793 -1.0617847442626953 -5.1882171630859375 -5.1881866455078125
6.085348606109619 -2.0759365558624268 -4.1287455558776855 -1.024871826171875 -2.605335235595703 -2.605335235595703
6.065239906311035 -2.1254994869232178 -4.173759937286377 -1.240203857421875 -4.008415222167969 -4.00848388671875
6.100636959075928 -1.9560027122497559 -3.8220300674438477 -0.6887622475624084 -1.93

5.997255802154541 -1.8261302709579468 -3.640043258666992 -1.990997314453125 -1.0973854064941406 -1.0973844528198242
6.045840263366699 -1.8626693487167358 -3.756639242172241 -0.577545166015625 -5.560691833496094 -5.560722351074219
6.03754997253418 -1.812174916267395 -3.6294639110565186 -1.7860107421875 -1.1716346740722656 -1.1716499328613281
6.065587043762207 -1.7870512008666992 -3.566493511199951 -1.9807968139648438 -1.7948379516601562 -1.7948684692382812
6.021229267120361 -1.6907621622085571 -3.4814305305480957 -1.955718994140625 -4.212974548339844 -4.2129669189453125
6.027080059051514 -1.7766759395599365 -3.595275402069092 -1.3526763916015625 -3.2512893676757812 -3.2512893676757812
6.032229900360107 -1.914474606513977 -3.786656618118286 -2.108245849609375 -2.0050735473632812 -2.0051040649414062
6.049139499664307 -2.0102767944335938 -3.974022388458252 -1.7793731689453125 -3.781024932861328 -3.781024932861328
6.040006160736084 -2.04783296585083 -4.086305141448975 -2.6056365966796875 -2

6.030548095703125 -1.9952175617218018 -3.9769704341888428 -2.1142425537109375 -2.836221694946289 -2.836282730102539
5.996162414550781 -2.0118911266326904 -3.9964849948883057 -1.812255859375 -1.34075927734375 -1.3408050537109375
6.01146936416626 -1.9804205894470215 -3.8357889652252197 -2.3402786254882812 -8.576751708984375 -8.576751708984375
6.022223949432373 -1.9923995733261108 -3.9087865352630615 -1.64013671875 -3.4050846099853516 -3.4050350189208984
5.991706371307373 -1.9302839040756226 -3.8418662548065186 -1.887054443359375 -6.03167724609375 -6.031829833984375
6.008272171020508 -2.0481112003326416 -4.077005863189697 -2.5167617797851562 -2.9688568115234375 -2.968841552734375
6.024898529052734 -2.1048049926757812 -4.1211018562316895 -1.1552581787109375 -3.6933517456054688 -3.6933822631835938
5.991414546966553 -2.1069235801696777 -4.118629455566406 -3.238872528076172 -6.633544921875 -6.6335906982421875
6.037352085113525 -1.8918495178222656 -3.8686647415161133 -0.8825454711914062 -3.529

6.035398960113525 -1.773677945137024 -3.5000460147857666 -1.4200897216796875 -6.2127838134765625 -6.2129364013671875
6.058258056640625 -1.7054002285003662 -3.353588104248047 -2.2430286407470703 -11.639595031738281 -11.639656066894531
6.057218074798584 -1.6277618408203125 -3.2485101222991943 -1.5027313232421875 -1.5826644897460938 -1.5826339721679688
6.0239152908325195 -1.6338906288146973 -3.2486793994903564 -1.9890594482421875 -5.047975540161133 -5.047898292541504
6.035024166107178 -1.5731292963027954 -3.0920004844665527 -1.0812764167785645 -1.934967041015625 -1.9349365234375
6.033423900604248 -1.690094232559204 -3.3169965744018555 -1.8498458862304688 -3.0706329345703125 -3.0706024169921875
6.081745624542236 -1.7873481512069702 -3.47613525390625 -2.4507102966308594 -1.2267608642578125 -1.2267608642578125
6.073150634765625 -1.6477165222167969 -3.218644618988037 -1.03118896484375 -0.5271453857421875 -0.5271148681640625
6.009598731994629 -1.6554718017578125 -3.3197083473205566 -2.50999450

6.1192626953125 -1.7735440731048584 -3.5366196632385254 -1.4469757080078125 -8.351341247558594 -8.351234436035156
6.0413737297058105 -1.7647161483764648 -3.5172293186187744 -1.5680084228515625 -5.353481292724609 -5.353572845458984
6.060417175292969 -1.878129482269287 -3.7589075565338135 -1.9548797607421875 -3.9859390258789062 -3.9858169555664062
6.037166595458984 -1.843079686164856 -3.6620075702667236 -1.0366439819335938 -4.251495361328125 -4.25152587890625
6.03261137008667 -1.798761248588562 -3.568784236907959 -1.1295013427734375 -0.8187408447265625 -0.818695068359375
6.045101165771484 -1.7162823677062988 -3.466093063354492 -2.1299285888671875 -8.547836303710938 -8.547836303710938
6.0938849449157715 -1.743327021598816 -3.428516387939453 -2.3689231872558594 -3.0028533935546875 -3.0028228759765625
6.072945594787598 -1.722800850868225 -3.476851224899292 -1.9338645935058594 -2.7174301147460938 -2.7174835205078125
6.090677261352539 -1.7241483926773071 -3.44905686378479 -1.9411697387695312 

6.036983489990234 -1.6578279733657837 -3.352673053741455 -1.95489501953125 -1.200836181640625 -1.200897216796875
6.0212273597717285 -1.6576473712921143 -3.2970738410949707 -1.8939743041992188 -2.1944732666015625 -2.1945648193359375
6.022148132324219 -1.6383761167526245 -3.3220436573028564 -0.64471435546875 -0.9096908569335938 -0.9097213745117188
6.049388885498047 -1.6296846866607666 -3.258125066757202 -1.801849365234375 -4.214599609375 -4.2145538330078125
6.043952465057373 -1.7414300441741943 -3.415268898010254 -1.255218505859375 -5.592720031738281 -5.592811584472656
6.026594638824463 -1.8127541542053223 -3.6080172061920166 -1.5889968872070312 -1.9635696411132812 -1.963653564453125
6.048704624176025 -1.7884533405303955 -3.534937620162964 -0.7066192626953125 -1.4062347412109375 -1.4062957763671875
6.037310600280762 -1.8749374151229858 -3.702195405960083 -0.5509834289550781 -1.7465972900390625 -1.7465972900390625
6.026337623596191 -1.8309201002120972 -3.683464288711548 -1.245094299316406

6.015767574310303 -1.6385222673416138 -3.2347731590270996 -1.3040084838867188 -3.3925018310546875 -3.39251708984375
6.062524318695068 -1.5907937288284302 -3.171383857727051 -2.5362548828125 -0.6418685913085938 -0.6418075561523438
6.090701580047607 -1.6793643236160278 -3.4145679473876953 -1.2707977294921875 -7.805572509765625 -7.8057098388671875
6.038266658782959 -1.6359552145004272 -3.1970155239105225 -1.727783203125 -3.3731460571289062 -3.37310791015625
6.044846057891846 -1.7770164012908936 -3.524195671081543 -2.116384506225586 -1.478790283203125 -1.4788818359375
6.0650315284729 -1.8306427001953125 -3.563337802886963 -2.4067230224609375 -7.1583404541015625 -7.158416748046875
6.045806884765625 -1.9490456581115723 -3.8091659545898438 -0.936798095703125 -4.460784912109375 -4.460784912109375
6.058125019073486 -2.162958860397339 -4.231051921844482 -1.2232131958007812 -7.825592041015625 -7.825592041015625
6.008068561553955 -2.065128803253174 -4.009923458099365 -2.92041015625 -5.506238937377

6.02264404296875 -1.7152533531188965 -3.410659074783325 -1.77789306640625 -2.0276641845703125 -2.02764892578125
6.017127513885498 -1.7025271654129028 -3.4097604751586914 -1.86383056640625 -1.5777130126953125 -1.5777130126953125
6.005137920379639 -1.7417515516281128 -3.496816635131836 -1.9308624267578125 -3.2418975830078125 -3.2420501708984375
6.018027305603027 -1.6733760833740234 -3.2563133239746094 -3.4626007080078125 -10.507278442382812 -10.507308959960938
5.992719650268555 -1.6362756490707397 -3.3193917274475098 -1.7988967895507812 -3.2456207275390625 -3.2456817626953125
5.975152492523193 -1.7506613731384277 -3.4518063068389893 -2.3191604614257812 -5.5950927734375 -5.595123291015625
5.994629383087158 -1.6597614288330078 -3.270170211791992 -0.8511276245117188 -1.2983779907226562 -1.2983627319335938
5.99791145324707 -1.7211240530014038 -3.4324915409088135 -0.9601287841796875 -3.6985397338867188 -3.6984634399414062
6.014777183532715 -1.6918773651123047 -3.4006619453430176 -0.8170204162

6.042608737945557 -1.5208269357681274 -3.0947957038879395 -2.5006790161132812 -3.2865066528320312 -3.2863845825195312
6.034857273101807 -1.5334330797195435 -3.021559715270996 -4.085968017578125 -1.1708526611328125 -1.1707916259765625
6.004877090454102 -1.4982424974441528 -3.119779109954834 -1.705169677734375 -5.600555419921875 -5.60064697265625
6.043605804443359 -1.6771557331085205 -3.3249897956848145 -2.01007080078125 -0.3430938720703125 -0.3430633544921875
6.073487281799316 -1.7461150884628296 -3.4523842334747314 -1.7568817138671875 -13.676651000976562 -13.676712036132812
6.064028263092041 -1.853680968284607 -3.6526708602905273 -1.3271331787109375 -9.764236450195312 -9.764144897460938
6.030018329620361 -1.8606481552124023 -3.732206344604492 -0.10736083984375 -3.3785858154296875 -3.3784942626953125
6.083928108215332 -1.8833980560302734 -3.6625537872314453 -1.9105072021484375 -3.71063232421875 -3.710601806640625
6.0242533683776855 -1.893986463546753 -3.7922215461730957 -1.3481292724609

6.040686130523682 -1.7920717000961304 -3.4942362308502197 -1.0713539123535156 -3.2369537353515625 -3.2369232177734375
6.03404426574707 -1.710086464881897 -3.396620988845825 -1.7251434326171875 -3.29974365234375 -3.29974365234375
6.011738300323486 -1.6136746406555176 -3.2204792499542236 -1.086517333984375 -2.516357421875 -2.5163116455078125
6.050605297088623 -1.5877883434295654 -3.2053773403167725 -0.8821418285369873 -7.742801666259766 -7.742877960205078
6.038130760192871 -1.5775623321533203 -3.1714560985565186 -1.7802734375 -3.2465667724609375 -3.24664306640625
6.012831687927246 -1.5068713426589966 -2.9889984130859375 -1.8244476318359375 -3.8126068115234375 -3.8125762939453125
6.017024993896484 -1.4765697717666626 -2.9406487941741943 -1.7510948181152344 -5.197662353515625 -5.197662353515625
6.038521766662598 -1.5154833793640137 -3.059558868408203 -1.097686767578125 -4.595664978027344 -4.595619201660156
6.051908493041992 -1.474731683731079 -2.9482626914978027 -1.3247222900390625 -3.3669

6.049709320068359 -1.5136773586273193 -2.9927453994750977 -1.876220703125 -2.314544677734375 -2.3145294189453125
6.060441970825195 -1.482019066810608 -2.9858195781707764 -2.1031951904296875 -5.2828369140625 -5.282867431640625
6.083563327789307 -1.59439218044281 -3.1417407989501953 -1.4553070068359375 -1.1304702758789062 -1.1304397583007812
6.034404277801514 -1.5418996810913086 -3.0569894313812256 -3.41656494140625 -4.236347198486328 -4.236362457275391
6.048730373382568 -1.5044838190078735 -2.9275708198547363 -1.494964599609375 -3.184864044189453 -3.184833526611328
6.098546028137207 -1.4414870738983154 -2.8696117401123047 -0.9112091064453125 -2.4860992431640625 -2.4861602783203125
6.077377796173096 -1.3702208995819092 -2.7554383277893066 -0.9411373138427734 -0.8272705078125 -0.827301025390625
6.0778632164001465 -1.451485276222229 -2.895240545272827 -1.0886993408203125 -3.6116180419921875 -3.6115875244140625
5.997026443481445 -1.4005727767944336 -2.7939157485961914 -1.69610595703125 -3.5

5.995612144470215 -1.700370192527771 -3.3754546642303467 -0.7646198272705078 -2.4296722412109375 -2.4298248291015625
6.057459354400635 -1.6444588899612427 -3.256709575653076 -4.2755126953125 -1.3205413818359375 -1.3205413818359375
6.083680152893066 -1.669374942779541 -3.3533401489257812 -0.23010492324829102 -2.6248931884765625 -2.6249847412109375
6.069260120391846 -1.6517308950424194 -3.3576900959014893 -1.648651123046875 -2.53765869140625 -2.537628173828125
6.052239418029785 -1.7264243364334106 -3.361056089401245 -0.33538055419921875 -8.159957885742188 -8.159988403320312
6.039010524749756 -1.7132558822631836 -3.389414072036743 -2.067596435546875 -2.8674392700195312 -2.8674087524414062
6.013553142547607 -1.6173319816589355 -3.310044527053833 -1.3915863037109375 -2.9675140380859375 -2.96759033203125
6.0594377517700195 -1.6874297857284546 -3.30767822265625 -1.7785797119140625 -0.4288673400878906 -0.4288368225097656
6.065240383148193 -1.5135748386383057 -3.0735371112823486 -1.262527465820

KeyboardInterrupt: 

In [2]:
# Restore model weights from a checkpoint file (relative path, i.e. the working directory).
# NOTE: torch.load unpickles the file, so only load checkpoints from trusted sources.
# NOTE(review): the save cell writes 'TSP_50points_big_6.00', not this '6.0' file — confirm which is intended.
NN.load_state_dict(torch.load('TSP_50points_big_6.0'))

<All keys matched successfully>

In [4]:
# Checkpoint the current model weights (relative path, i.e. the working directory).
# NOTE(review): this filename ends in '6.00', while the load cell reads '6.0' — re-running
# the load cell will not pick up this checkpoint.
torch.save(NN.state_dict(),'TSP_50points_big_6.00')

In [5]:
def evaluate(epochs = 30000, npoints = 10, batchsize = 100, nsamples = 8, negative_cutoff = 1):
    """Run the model on freshly sampled instances and print one line of diagnostics per epoch.

    Nothing is trained: every forward pass (encoding, decoding and the log-probability
    recomputation) runs under ``torch.no_grad()``, so no autograd graph is kept alive.

    For each epoch:
      1. ``env.reset`` draws a new batch of instances.
      2. The points are encoded once with ``NN.encode``.
      3. A route is decoded one point at a time; ``env.sampleandgreedy_point`` picks the next
         point, and that point's encoder embedding is appended as the next decoder input.
      4. ``env.laststep()`` finishes the episode.
      5. ``NN.calculate_logprob`` recomputes the log-probability of each chosen sequence.

    Args:
        epochs: number of evaluation iterations (one printed line each).
        npoints: points per instance; also the number of decoding steps.
        batchsize: instances per iteration.
        nsamples: routes decoded per instance. The last route of each group (``[:, -1]``)
            is treated as the greedy sample.
        negative_cutoff: unused; kept so existing calls keep working.

    Printed columns, in order:
        mean greedy cost, mean greedy logprob, mean logprob of the non-greedy samples
        (nan when nsamples == 1 because the slice is empty), logprob[batchsize - 1],
        logprob[0], env.logprob[0].
        The last column is presumably the log-probability accumulated during sampling;
        in recorded output it matches logprob[0] to ~1e-3.

    The module's previous train/eval mode is restored on exit.
    """
    was_training = NN.training
    NN.eval()
    try:
        # No backward pass happens here. Disabling autograd for the whole iteration, not just
        # the decoding loop, stops encode()/calculate_logprob() from holding activation graphs,
        # which matters for large npoints.
        with torch.no_grad():
            for _epoch in range(epochs):
                env.reset(npoints, batchsize, nsamples)
                # Boundary points are included; at the moment they only influence the encoder.
                memory = NN.encode(env.points)  # [npoints, batchsize * nsamples, emsize]
                # Every route starts from the learned start token.
                tgt = NN.start_token.unsqueeze(0).unsqueeze(1).expand(1, batchsize * nsamples, -1)  # [1, batchsize * nsamples, emsize]

                for _step in range(npoints):
                    _, logits = NN.decode_next(memory, tgt, env.points_mask)
                    next_point = env.sampleandgreedy_point(logits)
                    # Feed the encoder embedding of the chosen point back in as the next decoder input.
                    chosen = memory.gather(0, next_point.unsqueeze(0).unsqueeze(2).expand(1, -1, memory.size(2)))
                    tgt = torch.cat([tgt, chosen])  # [nsofar, batchsize * nsamples, emsize]

                env.laststep()

                _, logprob = NN.calculate_logprob(memory, env.points_sequence)  # [batchsize * nsamples]
                logprob_grid = logprob.view(batchsize, nsamples)
                greedy_cost = env.cost.view(batchsize, nsamples)[:, -1]  # [batchsize], greedy sample

                print(greedy_cost.mean().item(),
                      logprob_grid[:, -1].mean().item(),
                      logprob_grid[:, :-1].mean().item(),
                      logprob[batchsize - 1].item(),
                      logprob[0].item(),
                      env.logprob[0].item())
    finally:
        # Do not leave the shared model stuck in eval mode for later cells.
        NN.train(was_training)
     

In [9]:
evaluate(epochs=10, npoints=1000, batchsize=1, nsamples=2)

36.123863220214844 -374.4749755859375 -705.6748046875 -705.6748046875 -705.6748046875 -705.6744995117188
36.29134750366211 -386.007568359375 -686.5157470703125 -686.5157470703125 -686.5157470703125 -686.5166015625
36.29013442993164 -370.65765380859375 -662.8458251953125 -662.8458251953125 -662.8458251953125 -662.8469848632812
37.08647918701172 -347.39495849609375 -647.4351806640625 -647.4351806640625 -647.4351806640625 -647.4359741210938
37.4781608581543 -355.7390441894531 -674.520263671875 -674.520263671875 -674.520263671875 -674.5214233398438
37.6519889831543 -381.50885009765625 -668.430419921875 -668.430419921875 -668.430419921875 -668.42919921875
37.44511413574219 -342.01507568359375 -608.0435791015625 -608.0435791015625 -608.0435791015625 -608.0423583984375
38.30666732788086 -341.02490234375 -552.3333740234375 -552.3333740234375 -552.3333740234375 -552.3348388671875
37.618324279785156 -346.57269287109375 -624.7066650390625 -624.7066650390625 -624.7066650390625 -624.7078247070312
3