In [1]:
import os
import sys
# Make the project root (the parent of this notebook's folder) both the working
# directory -- later cells use root-relative paths such as 'misc/', 'results/' and
# 'saved_models/' -- and importable, so `from modules... import ...` resolves.
# Guarded so that re-running this cell does not climb one more directory each time.
if not os.path.isdir('modules'):
    os.chdir('..')
module_path = os.getcwd()
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
from modules.function import get_elem_count, alt_read_gfa_dataset, check_cuda, get_metrics, image, read_gfa_dataset
from modules.representation_schemes import get_atomic_number_features, get_pettifor_features, get_modified_pettifor_features, get_random_features, get_random_features_dense, random_order_alpha, get_1D_features_gfa, get_dense_features_gfa
from modules.encoder import Encoder1D, EncoderDNN, Encoder
import re
import torch.optim as optim
from torch.utils.data import DataLoader
import tqdm
import joblib
import random
import torch
import pickle
from decimal import Decimal
import pandas as pd
import numpy as np
import json
from sklearn.model_selection import KFold, StratifiedKFold

In [3]:
# One-off environment setup (IPython shell escape): enables the ipywidgets notebook
# extension (jupyter-js-widgets), presumably so the tqdm.notebook progress bars in the
# training cell render as widgets.
# NOTE(review): this is environment setup, not analysis -- consider moving it to the
# project's install instructions.
!jupyter nbextension enable --py widgetsnbextension

Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: [32mOK[0m


In [6]:
# Build (or reload) the stratified train/test folds shared by every method in the
# training cell, so all representations are scored on identical splits.
np.random.seed(0)
kfold_loc = 'misc/gfa_kfold.pkl'
create_new = False  # set True to regenerate and overwrite the saved folds
if os.path.exists(kfold_loc) and not create_new:
    # NOTE: pickle.load can execute arbitrary code -- only load fold files you created.
    with open(kfold_loc, 'rb') as fid:
        fold_dict = pickle.load(fid)
    print('Fold dictionary loaded!')
else:
    # Folds are stratified on the dense-feature labels, and their row indices are later
    # applied to every representation's arrays. This assumes all feature loaders return
    # rows in the same order -- TODO confirm.
    X, y, _ = get_dense_features_gfa()
    kfold = StratifiedKFold(n_splits=10, random_state=0, shuffle=True)
    fold_dict = {
        i: {'train_inds': train_index, 'test_inds': test_index}
        for i, (train_index, test_index) in enumerate(kfold.split(X, y))
    }
    # The target folder may not exist on a fresh checkout; open(..., 'wb') would fail.
    kfold_dir = os.path.dirname(kfold_loc)
    if kfold_dir:
        os.makedirs(kfold_dir, exist_ok=True)
    with open(kfold_loc, 'wb') as fid:
        pickle.dump(fold_dict, fid)
    print('Fold dictionary created!')

Fold dictionary loaded!


In [12]:
# Train and score one encoder per (feature representation, fold).
# For each method, every fold in `fold_dict` trains a fresh encoder on that fold's
# training rows, saves it as TorchScript under saved_models/Encoders/<method>/, and
# evaluates it on the held-out rows. Per-fold metrics are merged into
# results/gfa_predict_results.json after each method completes, so methods that
# already finished are kept if a later one fails.
saveloc = 'saved_models/Encoders'
results_loc = 'results/gfa_predict_results.json'
# Full set of supported methods: 'dense', 'atomic', 'random', 'random-tr',
# 'pettifor', 'mod_pettifor', 'PTR'.
methods = ['random', 'random-tr', 'pettifor', 'mod_pettifor', 'PTR']

BATCH_SIZE = 64
NUM_EPOCHS = 2000      # full passes over the training fold
LEARNING_RATE = 2e-4
LOG_INTERVAL = 500     # print the epoch loss every this many epochs (plus epoch 1)
TORCH_SEED = 0         # re-applied per fold: reproducible weight init and batch order
# Methods encoded as 1D feature vectors and trained with Encoder1D.
ONE_D_METHODS = ('atomic', 'pettifor', 'mod_pettifor', 'random', 'random-tr')
# Methods without a feature loader of their own, mapped to the method whose features
# they reuse. Previously 'random-tr' reused them only implicitly, through stale X/y/p
# left over from the preceding 'random' iteration.
# NOTE(review): confirm 'random-tr' is really meant to share the 'random' features.
FEATURE_ALIASES = {'random-tr': 'random'}


def _load_gfa_features(method):
    """Load the GFA dataset in the representation named `method`.

    Returns the (X, y, p) triple produced by the project loader for that method.
    Raises ValueError for a method that has no loader (instead of silently reusing
    whatever arrays a previous iteration left behind).
    """
    if method == 'dense':
        return get_dense_features_gfa()
    if method in ('atomic', 'pettifor', 'mod_pettifor', 'random'):
        return get_1D_features_gfa(method)
    if method == 'PTR':
        return read_gfa_dataset()
    raise ValueError('No feature loader for method {!r}'.format(method))


def _build_encoder(method, X):
    """Return (encoder_type, untrained encoder) for `method`.

    encoder_type is 0 for the 1D encoder, 1 for the dense DNN and 2 for the PTR
    encoder; it is only used in the saved model's filename.
    """
    if method in ONE_D_METHODS:
        return 0, Encoder1D(1, 1)
    if method == 'dense':
        return 1, EncoderDNN(X.shape[-1], 3, 42, 1)
    return 2, Encoder(1, 1)


def _train_encoder(encoder, train_loader, use_cuda):
    """Fit `encoder` in place with Adam + BCE loss for NUM_EPOCHS epochs.

    Prints the loss summed over batches on the first epoch and every LOG_INTERVAL
    epochs. The encoder must already be on the GPU when `use_cuda` is true.
    """
    # Explicit: `import tqdm` alone does not load the `tqdm.notebook` submodule.
    import tqdm.notebook

    optimizer = optim.Adam(encoder.parameters(), lr=LEARNING_RATE)
    criterion = torch.nn.BCELoss()
    encoder.train()
    for epoch in tqdm.notebook.tqdm(range(NUM_EPOCHS)):
        train_loss = 0.0
        for X_batch, y_batch, p_batch in train_loader:
            if use_cuda:
                X_batch, y_batch, p_batch = X_batch.cuda(), y_batch.cuda(), p_batch.cuda()
            optimizer.zero_grad()
            output = encoder(X_batch, p_batch)
            if use_cuda:
                # No-op when the output is already on the GPU; kept in case the encoder
                # returns a CPU tensor.
                output = output.cuda()
            loss = criterion(output, y_batch)
            # NOTE(review): retain_graph=True looks unnecessary (a fresh graph is built
            # every step) but is kept because the encoder source is not visible here.
            loss.backward(retain_graph=True)
            optimizer.step()
            train_loss += loss.item()
        if epoch == 0 or (epoch + 1) % LOG_INTERVAL == 0:
            print('Epoch : {}, Loss : {}'.format(epoch + 1, train_loss))


os.makedirs(saveloc, exist_ok=True)
os.makedirs(os.path.dirname(results_loc), exist_ok=True)
if os.path.exists(results_loc):
    with open(results_loc, 'r') as fid:
        sup_metrics_dict = json.load(fid)
else:
    sup_metrics_dict = {}

cuda = check_cuda()
loaded_source = None
for method in methods:
    print('Method : {}'.format(method))
    metrics_list = {}
    source = FEATURE_ALIASES.get(method, method)
    # Aliased methods reuse the arrays already in memory when their source was the
    # last one loaded; everything else is (re)loaded from its own loader.
    if not (method in FEATURE_ALIASES and source == loaded_source):
        X, y, p = _load_gfa_features(source)
        loaded_source = source
    spec_saveloc = os.path.join(saveloc, method)
    os.makedirs(spec_saveloc, exist_ok=True)

    for fold, split in fold_dict.items():
        i_tr, i_te = split['train_inds'], split['test_inds']
        X_train, X_test = X[i_tr], X[i_te]
        y_train, y_test = y[i_tr], y[i_te]
        p_train, p_test = p[i_tr], p[i_te]
        train_loader = DataLoader(list(zip(X_train, y_train, p_train)),
                                  batch_size=BATCH_SIZE, shuffle=True)

        torch.manual_seed(TORCH_SEED)
        encoder_type, encoder = _build_encoder(method, X)
        if cuda:
            encoder = encoder.cuda()
        _train_encoder(encoder, train_loader, cuda)

        # Export and evaluate on the CPU in inference mode (disables dropout /
        # batch-norm updates, if the encoder has any).
        encoder = encoder.cpu().eval()
        model_scripted = torch.jit.script(encoder)
        model_scripted.save(os.path.join(
            spec_saveloc, 'Encoder{}D_{}_fold{}.pt'.format(encoder_type, method, fold)))

        # as_tensor accepts numpy arrays (zero-copy) and passes tensors through.
        X_test = torch.as_tensor(X_test)
        p_test = torch.as_tensor(p_test)
        with torch.no_grad():
            y_predict = encoder(X_test, p_test).cpu().numpy()
        metrics = get_metrics(y_test, np.round(y_predict))
        metrics_list[fold] = metrics
        print('accuracy : {},precision : {},recall : {},F1 : {}'.format(metrics[0], metrics[1], metrics[2], metrics[3]))

    sup_metrics_dict[method] = metrics_list
    with open(results_loc, 'w') as f:
        json.dump(sup_metrics_dict, f)

Method : random


  0%|          | 0/2000 [00:00<?, ?it/s]

Epoch : 1, Loss : 182.2973966896534
Epoch : 500, Loss : 53.34674897044897
Epoch : 1000, Loss : 43.749323051422834
Epoch : 1500, Loss : 38.012205604463816
Epoch : 2000, Loss : 33.06775265187025
accuracy : 0.9387,precision : 0.94,recall : 0.9387,F1 : 0.9391


  0%|          | 0/2000 [00:00<?, ?it/s]

Epoch : 1, Loss : 183.68972831964493
Epoch : 500, Loss : 48.916234750300646
Epoch : 1000, Loss : 38.17846620082855
Epoch : 1500, Loss : 32.83130997419357
Epoch : 2000, Loss : 29.041188701987267
accuracy : 0.9463,precision : 0.9463,recall : 0.9463,F1 : 0.9463


  0%|          | 0/2000 [00:00<?, ?it/s]

Epoch : 1, Loss : 184.38523265719414
Epoch : 500, Loss : 51.61269997432828
Epoch : 1000, Loss : 43.10521576553583
Epoch : 1500, Loss : 38.169378567487
Epoch : 2000, Loss : 34.682810032740235
accuracy : 0.932,precision : 0.9318,recall : 0.932,F1 : 0.9311


  0%|          | 0/2000 [00:00<?, ?it/s]

Epoch : 1, Loss : 182.16963243484497
Epoch : 500, Loss : 47.06188079342246
Epoch : 1000, Loss : 36.04837105423212
Epoch : 1500, Loss : 30.444127559661865
Epoch : 2000, Loss : 27.06372257322073
accuracy : 0.9459,precision : 0.9464,recall : 0.9459,F1 : 0.9461


  0%|          | 0/2000 [00:00<?, ?it/s]

Epoch : 1, Loss : 183.01399743556976
Epoch : 500, Loss : 43.074997156858444
