In [None]:
from __future__ import division
from __future__ import print_function

import time
import argparse
import numpy as np

import torch
import torch.nn.functional as F
import torch.optim as optim

In [None]:
import pickle
import pdb
import json
import numpy as np

from utils import *
from GCN.GCN import GCN
from GAT.GAT import GAT
from Dense.Dense import Dense
from full_data_process import graphDataProcess
from sub_data_process import subGraphProcess

from adjacency_functions import *
from feature_functions import *
from label_functions import *
from train_test_functions import *

In [None]:
def full_graph_process(param_dict, data_dir, full_processed_path, full_redo):
    """Build the full-graph data object, or load a previously pickled copy.

    Args:
        param_dict: Parsed configuration. Only read when the object is rebuilt;
            it must then provide ``['gen_params']['raw_data_path']``,
            ``['full_names_dict']`` and ``['full_redo_dict']``.
        data_dir: Passed unchanged to ``graphDataProcess``.
        full_processed_path: Pickle file that caches the processed object.
        full_redo: When truthy, rebuild even if the cache file already exists.

    Returns:
        The freshly built ``graphDataProcess`` instance (after ``run_all()``),
        or whatever object is stored in the cache file.
    """
    # Imported locally: the notebook never imports ``os`` explicitly and would
    # otherwise depend on one of its star-imports leaking the name.
    import os

    if (not os.path.exists(full_processed_path)) or full_redo:
        graph_data_obj = graphDataProcess(
            param_dict['gen_params']['raw_data_path'],
            data_dir,
            param_dict['full_names_dict'],
            param_dict['full_redo_dict'],
        )
        graph_data_obj.run_all()
        # The context manager guarantees the cache is flushed and closed.
        with open(full_processed_path, 'wb') as cache_file:
            pickle.dump(graph_data_obj, cache_file)
    else:
        # NOTE(review): unpickling can execute arbitrary code; only point this
        # at cache files written by this pipeline.
        with open(full_processed_path, 'rb') as cache_file:
            graph_data_obj = pickle.load(cache_file)
    return graph_data_obj

def sub_graph_process(param_dict, data_path, full_processed_path, sub_processed_path, sub_redo):
    """Build the sub-graph data object, or load a previously pickled copy.

    Args:
        param_dict: Parsed configuration. Only read when the object is rebuilt;
            it must then provide ``['sampling_params']``, ``['sub_names_dict']``,
            ``['sub_redo_dict']`` and ``['sub_functions_dict']``.
        data_path: Passed unchanged to ``subGraphProcess``.
        full_processed_path: Path of the processed full-graph pickle, passed
            unchanged to ``subGraphProcess``.
        sub_processed_path: Pickle file that caches the processed sub-graph object.
        sub_redo: When truthy, rebuild even if the cache file already exists.

    Returns:
        The freshly built ``subGraphProcess`` instance (after ``run_all()``),
        or whatever object is stored in the cache file.
    """
    # Imported locally: the notebook never imports ``os`` explicitly and would
    # otherwise depend on one of its star-imports leaking the name.
    import os

    if (not os.path.exists(sub_processed_path)) or sub_redo:
        sampling_params = param_dict['sampling_params']
        sub_names_dict = param_dict['sub_names_dict']
        sub_redo_dict = param_dict['sub_redo_dict']
        # get_func_dict is not defined in this notebook; its result is handed
        # to subGraphProcess as the functions dictionary.
        sub_functions_dict = get_func_dict(param_dict['sub_functions_dict'])
        subgraph_data_obj = subGraphProcess(
            full_processed_path,
            data_path,
            sub_names_dict,
            sub_redo_dict,
            sub_functions_dict,
            sampling_params,
        )
        subgraph_data_obj.run_all()
        # The context manager guarantees the cache is flushed and closed.
        with open(sub_processed_path, 'wb') as cache_file:
            pickle.dump(subgraph_data_obj, cache_file)
    else:
        # NOTE(review): unpickling can execute arbitrary code; only point this
        # at cache files written by this pipeline.
        with open(sub_processed_path, 'rb') as cache_file:
            subgraph_data_obj = pickle.load(cache_file)
    return subgraph_data_obj

In [None]:
ver = 'v0.1'
# NOTE(review): absolute, machine-specific location; consider making the base
# directory configurable so the notebook runs on other machines.
param_path = '/home/ds-team/aaron/other/MoonBoard/data/train_test/pytorch/graphNet/GraphNet/' + ver + '/params.json'
# Read the run configuration; the context manager closes the file handle as
# soon as the JSON has been parsed.
with open(param_path, 'r') as param_file:
    param_dict = json.load(param_file)

# Unwrap and set general parameters

In [None]:
# General run settings. Note that `ver` is re-read from the configuration here,
# replacing the value that was used to locate params.json.
model_type, ver, data_dir, result_dir = (
    param_dict['gen_params'][key]
    for key in ('model_type', 'ver', 'data_dir', 'result_dir')
)

# set_paths is not defined in this notebook; it returns the data and result
# locations used by the rest of the run.
data_path, result_path = set_paths(model_type, ver, data_dir, result_dir)

full_processed_name, sub_processed_name = (
    param_dict['gen_params'][key]
    for key in ('full_processed_name', 'sub_processed_name')
)

# Plain string concatenation: presumably the directory values already end with
# a path separator -- TODO confirm.
full_processed_path = data_dir + full_processed_name
sub_processed_path = data_path + sub_processed_name

# Flags that force the cached full-graph / sub-graph objects to be rebuilt.
full_redo, sub_redo = (
    param_dict['gen_params'][key]
    for key in ('full_redo', 'sub_redo')
)

In [None]:
# Build the processed full-graph object, or load it from its pickle cache.
graph_data_obj = full_graph_process(
    param_dict=param_dict,
    data_dir=data_dir,
    full_processed_path=full_processed_path,
    full_redo=full_redo,
)

In [None]:
# Build the processed sub-graph object, or load it from its pickle cache.
subgraph_data_obj = sub_graph_process(
    param_dict=param_dict,
    data_path=data_path,
    full_processed_path=full_processed_path,
    sub_processed_path=sub_processed_path,
    sub_redo=sub_redo,
)

In [None]:
# Split configuration for the train/dev/test partition.
split_ratio_dict = param_dict['split_ratio_dict']
# NOTE(review): -1 looks like a sentinel value (perhaps "no single target
# grade"); confirm against sample_and_load_pytorch_data.
target_grade = -1
# The helper is not defined in this notebook; its six results are unpacked into
# the features, adjacency, labels and the train/dev/test index sets used below.
features, adj, labels, idx_train, idx_dev, idx_test = sample_and_load_pytorch_data(subgraph_data_obj, split_ratio_dict, result_path, target_grade, sub_redo)

In [None]:
# Number of distinct label values present in `labels`; passed to the models
# below as their class count.
num_labels = len(set(np.asarray(labels).tolist()))

In [None]:
# Dense model: built only when the 'dense_params' configuration entry is truthy.
# A later model cell that is also enabled rebinds `model`, `optimizer` and
# `num_epochs`.
dense_params = param_dict['dense_params']
if dense_params:
    num_epochs = dense_params['num_epochs']
    model = Dense(
        nfeat=features.shape[1],  # input width = second dimension of the feature matrix
        nhid_list=dense_params['hidden'],
        nclass=num_labels,
        dropout=dense_params['dropout'],
    )

    optimizer = optim.Adam(
        model.parameters(),
        lr=dense_params['lr'],
        weight_decay=dense_params['weight_decay'],
    )

In [None]:
# GCN model: built only when the 'gcn_params' configuration entry is truthy.
# When built, it rebinds `model`, `optimizer` and `num_epochs`.
gcn_params = param_dict['gcn_params']
if gcn_params:
    num_epochs = gcn_params['num_epochs']
    model = GCN(
        nfeat=features.shape[1],  # input width = second dimension of the feature matrix
        nhid_list=gcn_params['hidden'],
        nclass=num_labels,
        dropout=gcn_params['dropout'],
    )

    optimizer = optim.Adam(
        model.parameters(),
        lr=gcn_params['lr'],
        weight_decay=gcn_params['weight_decay'],
    )

In [None]:
# GAT model: unlike the Dense/GCN cells, which test the whole params entry for
# truthiness, this one is toggled by the explicit 'on' flag inside 'gat_params'.
# When built, it rebinds `model`, `optimizer` and `num_epochs`.
gat_params = param_dict['gat_params']
if gat_params['on']:
    num_epochs = gat_params['num_epochs']
    model = GAT(
        nfeat=features.shape[1],  # input width = second dimension of the feature matrix
        nhid=gat_params['hidden'],
        nclass=num_labels,
        dropout=gat_params['dropout'],
        alpha=gat_params['alpha'],
        nheads=gat_params['nb_heads'],
    )

    optimizer = optim.Adam(
        model.parameters(),
        lr=gat_params['lr'],
        weight_decay=gat_params['weight_decay'],
    )

In [None]:
# Show the model chosen by the cells above as this cell's output.
# NOTE(review): raises NameError if none of the model cells built one.
model

In [None]:
# Bundle everything handed to run_train into a single dictionary.
train_dict = {
    'optimizer': optimizer,
    'features': features,
    'adj': adj,
    'labels': labels,
    'idx_train': idx_train,
    'idx_val': idx_dev,  # the dev split is passed under the 'idx_val' key
    'num_epochs': num_epochs,
}
# run_train is not defined in this notebook; its return value replaces `model`.
model = run_train(model, train_dict)

In [None]:
# Testing: run the model returned by run_train on the held-out test indices.
test_dict = dict(
    features=features,
    adj=adj,
    labels=labels,
    idx_test=idx_test,
)
test(model, test_dict)

In [None]:
# Persist the model object next to the other results of this run.
model_name = 'model.pickle'
# The context manager closes the handle deterministically, so the pickle is
# fully flushed to disk before the notebook moves on.
with open(result_path + model_name, 'wb') as model_file:
    pickle.dump(model, model_file)