In [0]:
import time
import os
import random
import math
import torch
import torch.autograd
import torch.nn.functional as F
from torch.autograd import Variable
from bisect import bisect_left
import matplotlib.pyplot as plt


# Functions for dataset generation
#################################

def generate_dataset(N, sampler=None):
    """Build a sorted synthetic key set and a random (key, position) sampler.

    Args:
        N: number of keys to draw.
        sampler: zero-argument callable returning one key. Defaults to
            ``exponential``; any of the other distribution helpers in this
            module (``lognormal``, ``laplace``, ``uniform``, ``gaussian``) can
            be passed instead.

    Returns:
        ``(dataset, KVrand)``. ``dataset`` is the sorted list of keys.
        ``KVrand()`` picks a key uniformly at random and returns
        ``(key, position)``, where position is the index of the key's first
        occurrence in ``dataset``.
    """
    if sampler is None:
        sampler = exponential
    dataset = sorted(sampler() for _ in range(N))

    def KVrand():
        x = random.choice(dataset)
        # dataset is sorted, so bisect_left returns the first index holding x.
        # This is the same result as dataset.index(x), but O(log N) instead
        # of an O(N) scan on every sample.
        y = bisect_left(dataset, x)
        return x, y

    return dataset, KVrand

def exponential(lambda1=1.0):
    """Draw one sample from an exponential distribution with rate ``lambda1``.

    Uses inverse-transform sampling: for U uniform on (0, 1],
    ``-log(U) / lambda1`` is exponentially distributed.
    """
    # random.random() lies in [0.0, 1.0) and can return exactly 0.0, for which
    # math.log raises ValueError. 1.0 - random.random() lies in (0.0, 1.0],
    # so the logarithm is always defined.
    u = 1.0 - random.random()
    return -math.log(u) / lambda1

def lognormal(mu=0, sigma=5.0):
    """Return one draw from a log-normal distribution.

    ``mu`` and ``sigma`` are the mean and standard deviation of the
    underlying normal distribution; they are passed straight through to
    ``random.lognormvariate``.
    """
    return random.lognormvariate(mu, sigma)
  
def uniform(a=0.0, b=10.0):
    """Return a float drawn uniformly from the range between ``a`` and ``b``."""
    return random.uniform(a, b)

def gaussian(mu=0, sigma=5.0):
    """Return one draw from a normal distribution with mean ``mu`` and std ``sigma``."""
    return random.gauss(mu, sigma)

def laplace(mu=0.0, b=1.0):
    """Draw one sample from a Laplace distribution by inverse-transform sampling.

    Args:
        mu: location. Defaults to 0.0.
        b: scale. Defaults to 1.0.
        With the defaults this reproduces the original fixed location 0,
        scale 1 sampler.

    Returns:
        A float ``mu + b * x``, where ``x`` is the standard Laplace draw.
    """
    u = random.random()
    # random.random() can return exactly 0.0, which would make log(2*u)
    # raise ValueError, so redraw in that (rare) case. u is always < 1.0,
    # so the upper branch's argument 2 - 2*u stays positive.
    while u == 0.0:
        u = random.random()
    if u <= 0.5:
        x = math.log(2 * u)
    else:
        x = -math.log(2 - 2 * u)
    return mu + b * x


# Functions for neural network configuration
###########################################

def create_NN(dim=128):
    """Build a Linear(1, dim) -> ReLU -> Linear(dim, 1) network.

    It maps a scalar key to a scalar output; ``dim`` is the hidden width.
    """
    layers = [
        torch.nn.Linear(1, dim),
        torch.nn.ReLU(),
        torch.nn.Linear(dim, 1),
    ]
    return torch.nn.Sequential(*layers)


def to_tensor(x):
    """Convert a flat sequence of numbers to a column tensor of shape (len(x), 1).

    The dtype is torch's default floating dtype, the same one the legacy
    ``torch.Tensor(...)`` constructor produced.

    ``torch.autograd.Variable`` has been a no-op wrapper since PyTorch 0.4.
    The legacy ``torch.Tensor(n)`` constructor silently allocates an
    uninitialised tensor of size n when given an int. So build the tensor
    with ``torch.tensor`` and an explicit dtype instead.
    """
    return torch.tensor(x, dtype=torch.get_default_dtype()).unsqueeze(1)


# Traditional search functions
#############################

def linear_search(x, dataset):
    """Return the index of the last element of sorted ``dataset`` that is <= ``x``.

    Scans left to right and stops at the first element greater than ``x``.
    Returns -1 when ``x`` is smaller than every element or ``dataset`` is
    empty.
    """
    idx = -1  # result for an empty dataset
    for idx, n in enumerate(dataset):
        if n > x:
            return idx - 1
    # No element exceeds x, so every element is <= x and the answer is the
    # last index itself. Subtracting one here would be off by one.
    return idx


def binary_search(x, dataset):
    """Return the index of the rightmost element of sorted ``dataset`` strictly less than ``x``.

    Raises:
        ValueError: when no element of ``dataset`` is less than ``x``.

    NOTE(review): this follows the bisect "rightmost value less than x"
    recipe, so an exact key match yields the index *before* that key.
    ``linear_search`` in this file instead stops at the first element > x,
    i.e. it returns the last index whose element is <= x. Confirm which
    semantics callers want.
    """
    pos = bisect_left(dataset, x)
    if not pos:
        raise ValueError
    return pos - 1

#Main
#####

def main(N=1000, lr=0.0001, batch_size=256, target_loss=1.0):
    """Train the learned-index network on a synthetic dataset and plot its loss.

    Steps:
    - Draws ``N`` keys with ``generate_dataset`` and builds the network with
      ``create_NN``.
    - Repeatedly samples ``batch_size`` random (key, position) pairs.
    - Regresses ``N * NN(key)`` onto the position with a smooth-L1 loss,
      using Adam at learning rate ``lr``.
    - Stops once a batch loss falls below ``target_loss`` or on Ctrl-C.

    Prints each new minimum loss and the total training time, then shows a
    loss-vs-batch convergence plot. All arguments default to the values the
    script originally hard-coded.
    """
    dataset, KVrand = generate_dataset(N)
    NN = create_NN()
    optimizer = torch.optim.Adam(NN.parameters(), lr=lr)

    # Training
    ##########
    LF_x = []  # batch numbers
    LF_y = []  # per-batch loss, stored as plain Python floats
    minloss = float('inf')  # any real loss improves on this sentinel
    batch_no = 0

    start = time.time()
    try:
        while True:
            batch_no += 1
            pairs = [KVrand() for _ in range(batch_size)]
            batch_x = to_tensor([x for x, _ in pairs])
            batch_y = to_tensor([y for _, y in pairs])

            # Scale the network output by N so it is compared against
            # positions in [0, N).
            Predicted_idx = NN(batch_x) * N
            output = F.smooth_l1_loss(Predicted_idx, batch_y)
            loss = output.item()

            if loss < minloss:
                minloss = loss
                print('Minloss =', minloss, 'at', time.time())

            LF_x.append(batch_no)
            LF_y.append(loss)

            # Checked before the optimizer step: the batch that meets the
            # target is recorded, and no further update is applied.
            if loss < target_loss:
                break

            optimizer.zero_grad()
            output.backward()
            optimizer.step()
    except KeyboardInterrupt:
        # Ctrl-C is the intended way to stop a run that never reaches
        # target_loss; fall through and report/plot what was collected.
        pass
    TrainingTime = time.time() - start
    print('Time required for Training =', TrainingTime)

    # Convergence plot
    ##################
    plt.plot(LF_x, LF_y, 'b', label='Learned Index Structure')
    plt.xlabel('Number of Batches')
    plt.ylabel('Loss')
    plt.title('Learning Convergence')
    plt.legend(loc='best')
    plt.show()

if __name__ == "__main__":
    # Start training only when this file is executed directly, not on import.
    main()

Minloss = 516.3041381835938 at 1552561163.5412402
Minloss = 503.1219482421875 at 1552561163.545436
Minloss = 468.69677734375 at 1552561163.5524855
Minloss = 466.441650390625 at 1552561163.5690515
Minloss = 440.1553955078125 at 1552561163.5853813
Minloss = 426.336181640625 at 1552561163.6020646
Minloss = 419.918212890625 at 1552561163.6081598
Minloss = 390.0388488769531 at 1552561163.614062
Minloss = 386.1229248046875 at 1552561163.6203244
Minloss = 384.57281494140625 at 1552561163.6394212
Minloss = 384.49053955078125 at 1552561163.6429129
Minloss = 369.75128173828125 at 1552561163.6495514
Minloss = 369.7259826660156 at 1552561163.653085
Minloss = 336.18115234375 at 1552561163.6564503
Minloss = 334.3038024902344 at 1552561163.6728754
Minloss = 300.8546447753906 at 1552561163.6842046
Minloss = 293.1239929199219 at 1552561163.7007108
Minloss = 291.32415771484375 at 1552561163.7093334
Minloss = 286.3512268066406 at 1552561163.7129045
Minloss = 276.8747253417969 at 1552561163.716218
Minloss