In [1]:
import random, math, time
import igraph as ig
from fractions import Fraction
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

In [2]:
input_dim = 22        # length of the state vector fed to the policy network
output_dim = 100      # number of candidate roots k (parts are k**2) scored by the policy
normal_factor = 3000  # divisor applied to the remaining amount in the state vector
upper_cnt = 1500      # root-visit budget of each from-scratch tree search in `finder`
upper_cnt2 = 1000     # base root-visit budget handed to `grow`

# The search samples from np.random.choice and torch Categorical; seed every RNG
# so that Restart & Run All reproduces the same partitions and timings.
random_seed = 0
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

In [3]:
class net(nn.Module):
    """Policy network: an MLP with two ReLU hidden layers that maps a state vector
    to a softmax distribution over ``output_dim`` entries.

    Args:
        in_dim: length of the state vector; defaults to the notebook-level ``input_dim``.
        out_dim: number of output probabilities; defaults to the notebook-level
            ``output_dim``. Both hidden layers have width ``2 * out_dim``.
    """

    def __init__(self, in_dim=None, out_dim=None):
        super().__init__()
        # Fall back to the notebook config so that plain `net()` keeps working.
        self.input_dim = input_dim if in_dim is None else in_dim
        self.output_dim = output_dim if out_dim is None else out_dim
        hidden = 2 * self.output_dim

        # The attribute names fc1 / fc2 / fc_pi are the state_dict keys, so they
        # must match the checkpoint read by `load_models`.
        self.fc1 = nn.Linear(self.input_dim, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc_pi = nn.Linear(hidden, self.output_dim)

    def pi(self, x, softmax_dim=1):
        """Return probabilities for the batch ``x`` (softmax taken over ``softmax_dim``)."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return F.softmax(self.fc_pi(x), dim=softmax_dim)

    def action(self, state):
        """Return the flat probability vector for a single, unbatched ``state``.

        ``state`` is any array-like of length ``input_dim``. It is cast to float32,
        given a batch axis of size 1, and evaluated without gradient tracking.
        """
        batch = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            return self.pi(batch).view(-1)

    def load_models(self, name=None):
        """Load weights from the file ``name + '_net.pt'`` into this network.

        The checkpoint is mapped to CPU so GPU-saved weights also load on CPU-only
        machines. ``load_state_dict`` then copies them into the existing parameters.
        NOTE: ``torch.load`` unpickles the file, so only load trusted checkpoints.

        Raises:
            ValueError: if ``name`` is not given.
        """
        if name is None:
            raise ValueError("load_models() needs the checkpoint name prefix, e.g. 'sq3'")
        state_dict = torch.load(name + '_net.pt', map_location='cpu')
        self.load_state_dict(state_dict)
        print('Models loaded succesfully')

In [4]:
# Policy network shared by every tree search below; weights are read from "sq3_net.pt".
agent = net()
agent.load_models(name='sq3')

Models loaded succesfully


In [5]:
def sq(x):
    """Return floor(sqrt(x)) as an exact Python int for a non-negative number ``x``.

    The float estimate is corrected with integer arithmetic, so the result stays exact
    even where float64 rounding of sqrt would be off by one (large ``x``).

    Raises:
        ValueError: if ``x`` is negative.
    """
    if x < 0:
        raise ValueError("sq() is undefined for negative input: %r" % (x,))
    n = int(x)  # truncation == floor for x >= 0, and floor(sqrt(x)) == floor(sqrt(floor(x)))
    r = int(math.sqrt(n))  # float estimate, may be off by one for large n
    while r * r > n:
        r -= 1
    while (r + 1) * (r + 1) <= n:
        r += 1
    return r

In [6]:
class tree_operator:
    """Randomised tree search for square partitions, guided by a policy network.

    Every vertex of the igraph tree ``self.G`` is a partial partition with attributes
      ``wt1``   - amount still to be partitioned (``init_n`` at the root),
      ``wt2``   - sum of ``Fraction(1, sq(part))`` over all parts so far (exact),
      ``count`` - number of finished rollouts that passed through the vertex,
      ``score`` - total reward backed up through the vertex.
    Every edge carries ``count``, the number of times it was traversed.
    A rollout stops as soon as ``wt1 <= 0`` or ``wt2 >= 1``. It succeeds when
    ``wt1 == 0`` and ``wt2 == 1``, and the full partition is then appended to ``self.records``.

    Args:
        agent: policy network; ``agent.action(state)`` must return probabilities where
            entry ``k - 1`` scores the candidate part ``k**2``.
        init_n: amount to partition at the root.
        init_ptn: parts fixed before the search; they are prepended to every record.
        due: number of finished rollouts through the root after which ``run`` stops.
    """

    def __init__(self, agent, init_n, init_ptn, due=1):
        self.agent = agent
        self.due = due
        self.records = []  # exact partitions found by finished rollouts
        self.G = ig.Graph()
        self.G.add_vertices(1)
        self.G.vs[0]['wt1'] = init_n
        self.G.vs[0]['wt2'] = sum(Fraction(1, sq(pt)) for pt in init_ptn)
        self.G.vs[0]['count'] = 0
        self.G.vs[0]['score'] = 0
        self.init_n = init_n
        self.init_ptn = init_ptn
        self.pos = 0      # vertex id of the current rollout position
        self.path = [0]   # vertex ids visited by the current rollout, root first
        self.counter = 0  # number of `select` calls; scales the score term in `select`

    def reset(self):
        """Restart the next rollout at the root."""
        self.pos = 0
        self.path = [0]

    def _path_parts(self):
        """Parts chosen along the current path (successive differences of ``wt1``)."""
        wts = [self.G.vs[v]['wt1'] for v in self.path]
        return [wts[i] - wts[i + 1] for i in range(len(wts) - 1)]

    def _state(self):
        """Build the policy input for the current vertex.

        Layout: [remaining / normal_factor, sum(1/sq(part)) as float,
        sq(part) / output_dim for every part (init_ptn first), zero padding to input_dim].

        Raises:
            ValueError: if there are too many parts to fit in ``input_dim`` features.
        """
        parts = self._path_parts()
        st = [(self.init_n - sum(parts)) / normal_factor]
        all_parts = self.init_ptn + parts
        recip = sum(Fraction(1, sq(pt)) for pt in all_parts)
        st.append(recip.numerator / recip.denominator)
        st += [sq(pt) / output_dim for pt in all_parts]
        if len(st) > input_dim:
            raise ValueError(
                "state needs %d features but the policy net accepts input_dim=%d "
                "(too many parts: %d)" % (len(st), input_dim, len(all_parts)))
        st += [0] * (input_dim - len(st))
        return np.array(st, dtype=np.float64)

    def expand(self):
        """Sample one candidate part for the current vertex and add it as a child.

        The part is ``k**2``. ``k`` is drawn from the policy, restricted to
        ``min_n <= k <= min(upper_n, output_dim)``. Here ``upper_n = sq(wt1)`` is the
        largest ``k`` with ``k**2 <= wt1``, and ``min_n = ceil(1 / (1 - wt2))`` is the
        smallest ``k`` that keeps ``wt2 + 1/k <= 1``. If that range is empty,
        ``k = upper_n`` is used. No vertex is added if a child with the same remaining
        amount already exists.
        """
        t_wt1 = self.G.vs[self.pos]['wt1']
        t_wt2 = self.G.vs[self.pos]['wt2']
        upper_n = sq(t_wt1)
        # math.ceil is exact on the Fraction 1 / (1 - t_wt2)
        min_n = math.ceil(1 / (1 - t_wt2)) if t_wt2 != 1 else 1

        pi = self.agent.action(self._state())
        if min(output_dim, upper_n) >= min_n:
            # policy entry i scores k = i + 1, so this slice covers
            # k in [min_n, min(upper_n, output_dim)]
            new_v = min_n + Categorical(pi[min_n - 1:upper_n]).sample().item()
        else:
            new_v = upper_n
        new_v = new_v ** 2

        remaining = t_wt1 - new_v
        existing = {self.G.vs[v]['wt1'] for v in self.G.neighbors(self.pos) if v > self.pos}
        if remaining in existing:
            return  # this child already exists
        self.G.add_vertices(1)
        child = len(self.G.vs) - 1
        self.G.add_edges([(self.pos, child)])
        self.G.es[-1]['count'] = 0
        self.G.vs[child]['score'] = 0
        self.G.vs[child]['wt1'] = remaining
        self.G.vs[child]['wt2'] = t_wt2 + Fraction(1, sq(new_v)) if new_v != 0 else t_wt2
        self.G.vs[child]['count'] = 0

    def select(self):
        """Expand the current vertex if it still has room, then sample one of its children.

        Expansion is attempted while the vertex has fewer than
        ``min(15, sq(wt1) + 1)`` incident edges (for non-root vertices this count
        includes the edge to the parent). A child is then drawn with weight
        ``100 / (1 + edge count) + sqrt(counter) * score / max(1, visits)``. The first
        term favours rarely tried edges, the second favours children that scored well.

        Returns:
            The vertex id of the chosen child.
        """
        if len(self.G.incident(self.pos)) < min(15, sq(self.G.vs[self.pos]['wt1']) + 1):
            self.expand()
        self.counter += 1

        child_e = [e for e in self.G.incident(self.pos) if self.G.es[e].target > self.pos]
        child = [self.G.es[e].target for e in child_e]
        weights = np.array(
            [100 / (1 + self.G.es[e]['count'])
             + np.sqrt(self.counter) * self.G.vs[v]['score'] / max(1, self.G.vs[v]['count'])
             for e, v in zip(child_e, child)],
            dtype=np.float64)
        return np.random.choice(child, p=weights / np.sum(weights))

    def move(self, a):
        """Step from the current vertex to its child ``a``.

        If the new vertex is terminal (``wt1 <= 0`` or ``wt2 >= 1``), back up the
        rollout reward along the path and restart at the root. Reward: +1 for an exact
        partition (``wt1 == 0`` and ``wt2 == 1``, which is also recorded), plus 0.2 if
        ``wt1 == 0`` and plus 0.8 if ``wt2 == 1``.
        """
        a = int(a)  # np.random.choice yields a NumPy integer; keep plain ints in the path
        e_id = self.G.get_eid(self.pos, a)
        self.G.es[e_id]['count'] += 1
        self.pos = a
        self.path.append(a)

        wt1 = self.G.vs[a]['wt1']
        wt2 = self.G.vs[a]['wt2']
        if wt1 > 0 and wt2 < 1:
            return  # not terminal: the rollout continues from here

        score = 0
        if wt1 == 0 and wt2 == 1:
            self.records.append(self.init_ptn + self._path_parts())
            score += 1
        score += .2 * int(wt1 == 0) + .8 * int(wt2 == 1)

        for v in self.path:
            self.G.vs[v]['count'] += 1
            self.G.vs[v]['score'] += score
        self.reset()

    def run(self):
        """Run rollouts until the root has been passed ``due`` times, then return a
        distribution over the first part.

        Returns:
            np.ndarray of length ``sq(init_n)``. Entry ``a - 1`` is proportional to the
            traversal count of the root edge whose part is ``a**2``. Only the 10
            most-traversed root edges contribute.

        Raises:
            ValueError: if no root edge was traversed (e.g. ``due <= 0``), so the
                distribution cannot be normalised.
        """
        while self.G.vs[0]['count'] < self.due:
            self.move(self.select())

        arcs = list(self.G.incident(0))
        order = np.argsort([self.G.es[e]['count'] for e in arcs])
        top_arcs = [arcs[k] for k in order[-min(10, len(arcs)):]]

        visits = {}  # first part -> traversal count
        for e in top_arcs:
            pt = self.init_n - self.G.vs[self.G.es[e].target]['wt1']
            visits[pt] = visits.get(pt, 0) + self.G.es[e]['count']
        r_prob = np.array([visits.get(a ** 2, 0) for a in range(1, sq(self.init_n) + 1)],
                          dtype=np.float64)
        total = np.sum(r_prob)
        if total == 0:
            raise ValueError("no rollout reached the root's children; 'due' must be positive")
        return r_prob / total

In [7]:
def grow(target_n, ptn, due):
    """Extend a partial square partition of ``target_n`` one sampled part at a time.

    While the parts sum to less than ``target_n`` and ``sum(1/sq(part)) < 1``, a fresh
    ``tree_operator`` is run on the remainder. Its root budget is
    ``max(due * (1 - 0.1 * n_parts), 100)``. The next part ``a**2`` is then sampled from
    the distribution it returns and appended.

    Args:
        target_n: integer to partition.
        ptn: initial parts. The caller's list is not modified; a copy is extended.
        due: base root-visit budget per tree search.

    Returns:
        (grown, s_founds): the extended partition, which need not be valid (the loop
        only stops once the sum or the reciprocal sum reaches its limit), and the
        de-duplicated, sorted exact partitions found by the rollouts.
    """
    grown = list(ptn)  # copy: never mutate the caller's list
    founds = []
    while sum(grown) < target_n and sum(Fraction(1, sq(pt)) for pt in grown) < 1:
        wt1 = target_n - sum(grown)
        tt = tree_operator(agent, wt1, grown, max(due * (1 - len(grown) * .1), 100))
        prob = tt.run()
        founds += tt.records
        k = np.random.choice(list(range(1, sq(wt1) + 1)), p=prob)
        grown.append(int(k) ** 2)

    # De-duplicate by sorted content. A distinct loop name keeps `grown` intact: the
    # old code reused `ptn` here and returned the last record instead of the grown list.
    s_founds = []
    seen = set()
    for found in founds:
        key = tuple(sorted(found))
        if key not in seen:
            seen.add(key)
            s_founds.append(list(key))
    return grown, s_founds

In [8]:
def finder(target_n, n_trees=4, n_tries=3):
    """Search for square Graham partitions of ``target_n``.

    Phase 1 runs ``n_trees`` independent tree searches from scratch (budget
    ``upper_cnt``). It keeps every exact partition their rollouts complete and samples
    ``n_tries`` first parts from each search's distribution. Phase 2 extends each
    sampled first part with ``grow`` (budget ``upper_cnt2``). A partition counts as
    exact when its parts sum to ``target_n`` and ``sum(1/sq(part)) == 1``.

    Args:
        target_n: integer to partition.
        n_trees: number of phase-1 tree searches (default 4).
        n_tries: first parts sampled per phase-1 search (default 3).

    Returns:
        dict with keys
          'count'    - root visit count of each phase-1 tree,
          'size'     - vertex count of each phase-1 tree,
          'prob'     - first-part distribution of each phase-1 tree,
          'founds'   - de-duplicated, sorted exact partitions,
          'f_founds' - sorted grown partitions that were not exact,
          'f_time'   - seconds until the first exact partition appeared (0 if none),
          'time'     - total seconds.
    """
    st = time.time()
    e_data = {'count': [], 'size': [], 'prob': [], 'founds': [], 'f_founds': []}
    e_data['f_time'] = 0
    founds = []
    tries = []
    hopes = []

    def note_first_find():
        # Record the elapsed time once, as soon as `founds` becomes non-empty.
        if founds and e_data['f_time'] == 0:
            e_data['f_time'] = time.time() - st

    for _ in range(n_trees):
        tt = tree_operator(agent, target_n, [], upper_cnt)
        prob = tt.run()
        e_data['count'].append(tt.G.vs[0]['count'])
        e_data['size'].append(len(tt.G.vs))
        e_data['prob'].append(prob)
        founds += tt.records
        note_first_find()
        for _ in range(n_tries):
            pt = np.random.choice(list(range(1, sq(target_n) + 1)), p=prob)
            tries.append(int(pt) ** 2)

    for pt in tries:
        grown, found = grow(target_n, [pt], upper_cnt2)
        hopes.append(grown)
        founds += found
        note_first_find()

    for ptn in hopes:
        if sum(ptn) == target_n and sum(Fraction(1, sq(pt)) for pt in ptn) == 1:
            founds.append(ptn)
            note_first_find()
        else:
            e_data['f_founds'].append(sorted(ptn))

    s_founds = []
    seen = set()
    for ptn in founds:
        key = tuple(sorted(ptn))
        if key not in seen:
            seen.add(key)
            s_founds.append(list(key))
    e_data['founds'] = s_founds
    e_data['time'] = time.time() - st
    return e_data

In [9]:
# Search 4729 and report the partitions found plus the total and first-hit timings.
FF = finder(target_n=4729)
print("found square Graham partitions : \n {}".format(FF['founds']))
print("total running time : {:.2f}".format(FF['time']))
print("time to find the first one : {:.2f}".format(FF['f_time']))

found square Graham partitions : 
 [[16, 25, 100, 100, 100, 144, 144, 900, 1600, 1600], [64, 100, 100, 100, 100, 100, 144, 144, 225, 576, 576, 900, 1600]]
total running time : 152.94
time to find the first one : 35.17
