In [1]:
import json
from collections import Counter
import pickle
import torch_geometric
import torch
import torch.nn.functional as F
from torch import nn
from torch_geometric.nn import GATConv
import matplotlib.pyplot as plt
import seaborn as sns
import math
from torch_geometric.data import Data, DataLoader
import random
import time
from earlystopping import EarlyStopping
import numpy as np
from torch.nn.utils.rnn import pad_sequence
import os, psutil

In [2]:
def RAM():
    """Return the resident set size (RSS) of the current process in GiB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    # Three successive divisions by 1024: bytes -> KiB -> MiB -> GiB.
    return rss_bytes / 1024 / 1024 / 1024

In [3]:
# Load the pickled `common_Terminals` object.
# NOTE: unpickling can execute arbitrary code -- only load files you trust.
# The context manager closes the file even if unpickling raises.
with open('common_Terminals', 'rb') as infile:
    common_terminals = pickle.load(infile)

In [4]:
# Load the terminal-value vocabulary (terminal value -> integer id). The code
# below also looks up the special keys 'UNK' and 'none' in this mapping.
# NOTE: unpickling can execute arbitrary code -- only load files you trust.
# The context manager closes the file even if unpickling raises.
with open('map_terminals_val', 'rb') as infile:
    map_terminals_val = pickle.load(infile)

In [5]:
# Load the node-type vocabulary (node 'type' string -> integer id), used as
# the first component of the keys of `non_terminals_mapping`.
# NOTE: unpickling can execute arbitrary code -- only load files you trust.
# The context manager closes the file even if unpickling raises.
with open('map_terminals_type', 'rb') as infile:
    map_type = pickle.load(infile)

In [6]:
def non_terminal(node):
    """Return True when `node` carries a 'children' entry, False otherwise."""
    return 'children' in node

# Traversal state shared by construct_Seq (which resets it) and dfs (which
# advances it while walking one program).
curr_idx = 0  # next position to hand out in visiting order
map_idx_val = {}  # index into `program` -> position assigned by dfs


def dfs(program, idx, flattened_tree, edge_list, map_right_sibling, map_non_terminal_child):
    """Walk the subtree rooted at ``program[idx]`` in pre-order.

    Each visited node is given the next value of the global ``curr_idx``
    (recorded in the global ``map_idx_val``), contributes one
    ``[type_class, value_id]`` row to ``flattened_tree`` and one self-loop to
    ``edge_list``. A node with children additionally contributes an edge in
    both directions for every child, appended after all of its children have
    been visited.

    Args:
        program: list of node dicts; children are referenced by list index.
        idx: index in ``program`` of the node to visit.
        flattened_tree: output list, extended in place.
        edge_list: output list of 2-element tensors, extended in place.
        map_right_sibling: node index -> whether the node has a right sibling.
        map_non_terminal_child: node index -> whether any child has children.
    """
    global curr_idx
    global map_idx_val
    node = program[idx]

    # Hand this node the next pre-order position and give it a self-loop.
    map_idx_val[idx] = curr_idx
    curr_idx += 1
    own_position = map_idx_val[idx]
    edge_list.append(torch.tensor([own_position, own_position]))

    if not non_terminal(node):
        # Leaf: values outside the vocabulary collapse to the 'UNK' id.
        if node['value'] in map_terminals_val:
            value_id = map_terminals_val[node['value']]
        else:
            value_id = map_terminals_val['UNK']
        type_class = non_terminals_mapping[(map_type[node['type']], 0, 0)]
        flattened_tree.append([type_class, value_id])
        return

    # Internal node: the class index also encodes the two structural flags.
    class_key = (map_type[node['type']], map_right_sibling[idx], map_non_terminal_child[idx])
    flattened_tree.append([non_terminals_mapping[class_key], map_terminals_val['none']])

    for child in node['children']:
        dfs(program, child, flattened_tree, edge_list, map_right_sibling, map_non_terminal_child)

    # Parent <-> child edges, both directions, once every child has a position.
    for child in node['children']:
        parent_position = map_idx_val[idx]
        child_position = map_idx_val[child]
        edge_list.append(torch.tensor([parent_position, child_position]))
        edge_list.append(torch.tensor([child_position, parent_position]))
            

def construct_Seq(program):
    """Flatten one program into node rows plus an edge list.

    A first pass over ``program`` records, for every child index, whether it
    has a right sibling, and for every node with children (keyed by its
    ``'id'``), whether at least one child itself has children. The global
    traversal state is then reset and ``dfs`` is run from index 0.

    Args:
        program: list of node dicts; children are referenced by list index.

    Returns:
        (flattened_tree, edge_list) exactly as filled in by ``dfs``.
    """
    global curr_idx
    global map_idx_val
    flattened_tree = []
    edge_list = []
    map_non_terminal_child = {}
    # Index 0 is the traversal root and is recorded as having no right sibling.
    map_right_sibling = {0: False}

    for node in program:
        if not non_terminal(node):
            continue
        children = node['children']
        last_slot = len(children) - 1
        for slot, child in enumerate(children):
            # Every child except the last one has a right sibling.
            map_right_sibling[child] = slot != last_slot
        map_non_terminal_child[node['id']] = any('children' in program[child] for child in children)

    # Reset the shared traversal state before walking this program.
    curr_idx = 0
    map_idx_val = {}
    dfs(program, 0, flattened_tree, edge_list, map_right_sibling, map_non_terminal_child)
    return flattened_tree, edge_list

In [7]:
# (type id, right-sibling flag, non-terminal-child flag) -> dense class index.
# Keys are enumerated with the type id varying slowest, so each of the 46 type
# ids owns four consecutive class indices.
non_terminals_mapping = {
    key: class_index
    for class_index, key in enumerate(
        (type_id, sibling_flag, child_flag)
        for type_id in range(46)
        for sibling_flag in range(2)
        for child_flag in range(2)
    )
}

# Dense class index -> type id (the two flags are discarded).
inverse_mapping = {class_index: key[0] for key, class_index in non_terminals_mapping.items()}

In [8]:
# Report the process's resident memory at this point; RAM() returns GiB.
print("RAM used is: ", RAM())

RAM used is:  1.1315650939941406


In [9]:
def genProgramsEval(path='programs_eval.json'):
    """Lazily yield ``(tree, edges)`` tensors for every program in a JSON-lines file.

    Each non-blank line is parsed as a JSON list of node dicts. The final list
    element is discarded (presumably a terminator entry -- verify against the
    dataset) and an ``EOF`` node is appended in its place. Nodes that have
    neither ``'children'`` nor ``'value'`` get their ``'type'`` as value.

    Args:
        path: file to read; defaults to the evaluation split.

    Yields:
        tree: ``torch.long`` tensor of the rows produced by ``construct_Seq``.
        edges: tensor of stacked edge pairs, ordered by the larger endpoint
            (ties keep the order in which ``construct_Seq`` emitted them).
    """
    with open(path, encoding='latin1') as f:
        for line in f:
            if not line.strip():
                # A blank (e.g. trailing) line is not a program; json.loads would raise.
                continue
            nodes = json.loads(line)
            nodes.pop()
            nodes.append({'id': nodes[-1]['id'] + 1, 'type': "EOF", 'value': "EOF"})
            for node in nodes:
                if "children" not in node and "value" not in node:
                    node['value'] = node['type']
            tree, edges = construct_Seq(nodes)
            # Sort on a plain int key: same stable order as comparing 0-d
            # tensors, without a tensor comparison per sort step.
            edges = sorted(edges, key=lambda edge: int(edge.max()))
            edges = torch.stack(edges, dim=0)
            tree = torch.tensor(tree, dtype=torch.long)
            yield tree, edges

In [None]:
def genPrograms(path='programs_training.json'):
    """Lazily yield ``(tree, edges)`` tensors for every program in a JSON-lines file.

    Each non-blank line is parsed as a JSON list of node dicts. The final list
    element is discarded (presumably a terminator entry -- verify against the
    dataset) and an ``EOF`` node is appended in its place. Nodes that have
    neither ``'children'`` nor ``'value'`` get their ``'type'`` as value.

    Args:
        path: file to read; defaults to the training split.

    Yields:
        tree: ``torch.long`` tensor of the rows produced by ``construct_Seq``.
        edges: tensor of stacked edge pairs, ordered by the larger endpoint
            (ties keep the order in which ``construct_Seq`` emitted them).
    """
    with open(path, encoding='latin1') as f:
        for line in f:
            if not line.strip():
                # A blank (e.g. trailing) line is not a program; json.loads would raise.
                continue
            nodes = json.loads(line)
            nodes.pop()
            nodes.append({'id': nodes[-1]['id'] + 1, 'type': "EOF", 'value': "EOF"})
            for node in nodes:
                if "children" not in node and "value" not in node:
                    node['value'] = node['type']
            tree, edges = construct_Seq(nodes)
            # Sort on a plain int key: same stable order as comparing 0-d
            # tensors, without a tensor comparison per sort step.
            edges = sorted(edges, key=lambda edge: int(edge.max()))
            edges = torch.stack(edges, dim=0)
            tree = torch.tensor(tree, dtype=torch.long)
            yield tree, edges

In [10]:
def constructTensorfromProgram(data, edges):
    """Cut one flattened program into fixed-size windows with next-node targets.

    Args:
        data: 2-D tensor of node rows, split into chunks of 50 rows along
            dim 0. Assumes two columns (type class, value id), matching the
            padding row built below -- TODO confirm against the caller.
        edges: 2-D tensor whose rows are node-index pairs. The windowing loop
            assumes the rows are ordered by their larger endpoint -- TODO
            confirm against the caller.

    Returns:
        res: the first 49 rows of every chunk, shape (num_chunks, 49, cols).
        ans: row 49 of every chunk, shape (num_chunks, cols); column 0 is
            rewritten in place through ``inverse_mapping``.
        edge_list: list of stacked edge tensors, one per window emitted by
            the loop. Endpoints keep their whole-program indices; nothing is
            re-based here.
    """
    #print(edges)
    res = torch.split(data, 50)
    res = list(res)
    edge_list = []
    
    c = 1  # NOTE(review): never read afterwards.
    curr_idx = []  # edges collected for the window currently being filled
    i = 0
    min_val = 0  # lower bound of the current edge window
    # NOTE(review): data chunks advance by 50 rows (torch.split above) while
    # edge windows advance by 49 (min_val += 49) -- confirm that the two are
    # meant to drift apart like this.
    while(i < edges.shape[0]):
        if(max(edges[i][0], edges[i][1]) >= min_val + 48):
            # Edge reaches past the window: emit what was collected and move
            # on to the next window.
            # NOTE(review): edges[i] itself is not kept -- the unconditional
            # `i += 1` at the bottom of the loop steps over it.
            edge_list.append(torch.stack(curr_idx, dim = 0))
            curr_idx = []
            min_val += 49
        else:
            # Skip edges whose smaller endpoint is at or before the window start.
            while(i < edges.shape[0] and min(edges[i][0], edges[i][1]) <= min_val):
                i += 1
            # NOTE(review): the edge reached after skipping is appended
            # without re-testing it against the upper bound checked above.
            if(i < edges.shape[0]):
                curr_idx.append(edges[i])
        i += 1
        
    # Emit the last, partially filled window (if it holds any edge).
    if(len(curr_idx) > 0):
        edge_list.append(torch.stack(curr_idx, dim = 0))
    
    
    # Pad a short final chunk up to 50 rows with (EOF class, 'none' value) rows.
    if(res[-1].shape[0] != 50):
        x = non_terminals_mapping[(map_type['EOF'], 0, 0)]
        dummy_tensor = torch.tensor([x, map_terminals_val['none']])
        dummy_tensor = dummy_tensor.repeat((50 - res[-1].shape[0], 1))
        dummy_tensor = torch.cat([res[-1], dummy_tensor])
        res = res[:-1]
        res.append(dummy_tensor)
    res = torch.stack(res, dim = 0)
    # Row 49 of each chunk is the prediction target; rows 0..48 are the input.
    ans = res[:, 49, :]
    res = res[:, :-1, :]
    # Replace the target's class index with its type id, in place.
    ans[:, 0].apply_(lambda x : inverse_mapping[x])
    return res, ans, edge_list

In [11]:
# Hyper-parameters of Graph2CodeNet.
SEQ_LEN = 49  # number of input rows per window
# Embedding table sizes: one slot more than the respective vocabulary
# (NOTE(review): the purpose of the extra slot is not visible here).
NUM_CLASSES_VALUE = len(map_terminals_val) + 1
NUM_CLASSES_TYPE = len(non_terminals_mapping) + 1
WORD_EMBEDDING_DIM1 = 128  # width of the value embedding
WORD_EMBEDDING_DIM2 = 64  # width of the type embedding
HIDDEN_SIZE = 512  # GRU hidden size per layer and direction
NUM_HEADS = 1  # attention heads of both GATConv layers


class Graph2CodeNet(torch.nn.Module):
    """GRU encoder followed by two graph-attention layers and a type classifier.

    Every input row is a (type class, value id) pair. Both ids are embedded,
    concatenated and fed through a 2-layer bidirectional GRU; the per-position
    GRU outputs are then treated as graph nodes and refined by two GATConv
    layers. The feature vector of the last position is projected onto
    ``NUM_OUTPUT_TYPES`` logits and is also handed back, reshaped, for the
    caller to use as the next GRU initial state.

    Parameter names are unchanged, so previously saved state dicts still load.
    """

    # Number of type ids the classifier scores.
    NUM_OUTPUT_TYPES = 46

    def __init__(self):
        super().__init__()
        self.embeddingLayervalue = nn.Embedding(NUM_CLASSES_VALUE, WORD_EMBEDDING_DIM1)
        self.embeddingLayertype = nn.Embedding(NUM_CLASSES_TYPE, WORD_EMBEDDING_DIM2)
        self.GRU = nn.GRU(
            input_size=WORD_EMBEDDING_DIM1 + WORD_EMBEDDING_DIM2,
            hidden_size=HIDDEN_SIZE,
            batch_first=True,
            num_layers=2,
            bidirectional=True,
        )
        # A bidirectional GRU emits 2 * HIDDEN_SIZE features per position.
        self.att1 = GATConv(HIDDEN_SIZE * 2, HIDDEN_SIZE * 4, NUM_HEADS)
        self.att2 = GATConv(HIDDEN_SIZE * 4, HIDDEN_SIZE * 4, NUM_HEADS)
        self.linearNon_Terminal = nn.Linear(
            in_features=HIDDEN_SIZE * 4, out_features=self.NUM_OUTPUT_TYPES
        )

    def forward(self, x, h_0, edges):
        """Score the next node type for one window.

        Args:
            x: long tensor indexed as ``x[:, :, 0]`` (type class) and
                ``x[:, :, 1]`` (value id). The reshape below assumes a single
                window of ``SEQ_LEN`` rows, i.e. shape (1, SEQ_LEN, 2) --
                TODO confirm no caller passes a larger batch.
            h_0: initial GRU hidden state.
            edges: tensor of edge pairs, one pair per row; transposed here
                into the layout passed to the GATConv layers.

        Returns:
            out: logits of shape (1, NUM_OUTPUT_TYPES).
            h_next: last node's features reshaped to (4, 1, HIDDEN_SIZE).
        """
        type_embedded = self.embeddingLayertype(x[:, :, 0])
        value_embedded = self.embeddingLayervalue(x[:, :, 1])
        embedded = torch.cat((type_embedded, value_embedded), 2)
        h_seq, _ = self.GRU(embedded, h_0)
        # One graph node per sequence position.
        h_seq = h_seq.view(SEQ_LEN, -1)
        edge_index = edges.t().contiguous()
        # NOTE(review): the two attention layers are applied back to back with
        # no activation in between, as in the original implementation.
        h_seq = self.att1(h_seq, edge_index)
        h_seq = self.att2(h_seq, edge_index)
        h_out = h_seq[-1, :].view(1, -1)
        out = self.linearNon_Terminal(h_out)
        # 4 = 2 GRU layers * 2 directions, so the 4 * HIDDEN_SIZE features fit
        # the hidden-state layout expected by self.GRU.
        return out, h_out.view(4, 1, HIDDEN_SIZE)
    
    
def count_parameters(model):
    """Return the total number of elements across the trainable parameters of `model`."""
    trainable_elements = 0
    for parameter in model.parameters():
        if parameter.requires_grad:
            trainable_elements += parameter.numel()
    return trainable_elements

In [12]:
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [13]:
#model = Graph2CodeNet()
#device = torch.device('cuda:0')
#model = model.to(device)
# if torch.cuda.device_count() > 1:
#     print("Let's use", torch.cuda.device_count(), "GPUs!")
#   # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
#     model = nn.DataParallel(model)

# model.to(device)

In [14]:
#!shuf programs_training.json -o shuffled_programs_training.json

In [None]:
# Train Graph2CodeNet on the training split, one window at a time, stepping
# the optimizer once every BATCH_SIZE windows (gradient accumulation).

# Create the model here so this cell runs on a fresh kernel: the only other
# definition of `model` in the notebook is the checkpoint-loading cell below.
model = Graph2CodeNet().to(device)
model.train()

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
BATCH_SIZE = 128
LAST_PROGRAM_INDEX = 20000  # stop after the program with this index
LAST_CHUNK_INDEX = 100  # per program, stop after the window with this index
PROGRESS_EVERY = 1000  # print progress every this many programs

train_losses = 0
avg_train_losses = []

start = time.time()
count = 0
idx = 0

optimizer.zero_grad()
for train_data, edge_list in genPrograms():
    # Each program starts from a zero hidden state.
    h_0 = torch.zeros(4, 1, 512)
    h_0 = h_0.to(device)
    res, ans, edges = constructTensorfromProgram(train_data, edge_list)
    for j in range(res.shape[0]):
        x = res[j, :, :].view(1, 49, -1)
        x = x.to(device)
        edge = edges[j]
        edge = edge.to(device)
        # Shift the window's edge endpoints down by 49 * j.
        out_N, h_0 = model(x, h_0, torch.sub(edge, other=49, alpha=j))
        count += 1
        ans1 = ans[j, 0:1]
        ans1 = ans1.to(device)
        loss = criterion(out_N, ans1)
        loss.backward()
        nn.utils.clip_grad_value_(model.parameters(), clip_value=5.0)
        train_losses += loss.item()
        if count % BATCH_SIZE == 0:
            avg_train_losses.append(train_losses / BATCH_SIZE)
            train_losses = 0
            optimizer.step()
            optimizer.zero_grad()
        # Carry the state to the next window without back-propagating through it.
        h_0 = h_0.detach()
        if j >= LAST_CHUNK_INDEX:
            break
    if idx == LAST_PROGRAM_INDEX:
        break
    if idx % PROGRESS_EVERY == 0:
        print(f"Train time is: {(time.time() - start)/60} -> {idx*100 / LAST_PROGRAM_INDEX}%")
    idx += 1

# Apply the gradients of the final, partially filled batch; without this they
# would be accumulated but never used.
leftover = count % BATCH_SIZE
if leftover:
    avg_train_losses.append(train_losses / leftover)
    train_losses = 0
    optimizer.step()
    optimizer.zero_grad()

plt.plot(avg_train_losses)
end = time.time()
print("Time taken:", (end - start)/60)

In [None]:
# Save only the weights (state_dict), not the whole module object.
# NOTE(review): this writes './GNN3', whereas the loading cell reads './GNN2'.
torch.save(model.state_dict(), './GNN3')

Load a previously trained model.

In [15]:
# Rebuild the architecture and restore weights from the './GNN2' checkpoint;
# map_location places the loaded tensors on `device`.
model = Graph2CodeNet()
model.load_state_dict(torch.load('./GNN2', map_location=device))

<All keys matched successfully>

In [18]:
# Move the model's parameters to the selected device.
model = model.to(device)

In [19]:
# Print the module tree of the loaded model.
print(model)

Graph2CodeNet(
  (embeddingLayervalue): Embedding(10002, 128)
  (embeddingLayertype): Embedding(185, 64)
  (GRU): GRU(192, 512, num_layers=2, batch_first=True, bidirectional=True)
  (att1): GATConv(1024, 2048, heads=1)
  (att2): GATConv(2048, 2048, heads=1)
  (linearNon_Terminal): Linear(in_features=2048, out_features=46, bias=True)
)


In [22]:
# Next-node-type accuracy on the evaluation split, measured over the first
# MAX_EVAL_SAMPLES windows.
MAX_EVAL_SAMPLES = 10000

# Inference mode for the modules; no_grad below only disables autograd.
model.eval()
with torch.no_grad():
    total = 0
    correct = 0
    for eval_data, edge_list in genProgramsEval():
        # Each program starts from a zero hidden state.
        h_0 = torch.zeros(4, 1, 512)
        h_0 = h_0.to(device)
        res, ans, edges = constructTensorfromProgram(eval_data, edge_list)
        for j in range(res.shape[0]):
            x = res[j, :, :].view(1, 49, -1)
            x = x.to(device)
            edge = edges[j]
            edge = edge.to(device)
            # Shift the window's edge endpoints down by 49 * j.
            out_N, h_0 = model(x, h_0, torch.sub(edge, other=49, alpha=j))
            total += 1
            target = ans[j, 0:1]
            target = target.to(device)
            predicted = torch.argmax(out_N, dim=1)
            correct += (predicted == target).sum().item()
            if total >= MAX_EVAL_SAMPLES:
                break
        if total >= MAX_EVAL_SAMPLES:
            break

    print(f"Validation set prediction acccuracy is: {100*correct/total}%")

Validation set prediction acccuracy is: 74.03%
