In [10]:
import torch.nn.functional as F

def load_model(exp_configure):
    """Build a DGL-LifeSci property-prediction model from a configuration mapping.

    Parameters
    ----------
    exp_configure : dict
        Experiment configuration. ``exp_configure['model']`` selects the
        architecture; the other keys that are read depend on that choice
        (e.g. ``in_node_feats``, ``gnn_hidden_feats``, ``num_gnn_layers``,
        ``predictor_hidden_feats``, ``dropout`` and ``n_tasks`` for GCN).

    Returns
    -------
    torch.nn.Module
        The constructed model. For the ``gin_supervised_*`` variants the GNN
        encoder is replaced by the pretrained one from ``load_pretrained``.

    Raises
    ------
    ValueError
        If ``exp_configure['model']`` is not one of the supported names.
    KeyError
        If a key required by the selected architecture is missing.
    """
    model_name = exp_configure['model']

    # dgllife imports are done lazily per branch so only the selected
    # architecture's module is imported.
    if model_name == 'GCN':
        from dgllife.model import GCNPredictor
        model = GCNPredictor(
            in_feats=exp_configure['in_node_feats'],
            hidden_feats=[exp_configure['gnn_hidden_feats']] * exp_configure['num_gnn_layers'],
            activation=[F.relu] * exp_configure['num_gnn_layers'],
            residual=[exp_configure['residual']] * exp_configure['num_gnn_layers'],
            batchnorm=[exp_configure['batchnorm']] * exp_configure['num_gnn_layers'],
            dropout=[exp_configure['dropout']] * exp_configure['num_gnn_layers'],
            predictor_hidden_feats=exp_configure['predictor_hidden_feats'],
            predictor_dropout=exp_configure['dropout'],
            n_tasks=exp_configure['n_tasks'])
    elif model_name == 'GAT':
        from dgllife.model import GATPredictor
        model = GATPredictor(
            in_feats=exp_configure['in_node_feats'],
            hidden_feats=[exp_configure['gnn_hidden_feats']] * exp_configure['num_gnn_layers'],
            num_heads=[exp_configure['num_heads']] * exp_configure['num_gnn_layers'],
            feat_drops=[exp_configure['dropout']] * exp_configure['num_gnn_layers'],
            attn_drops=[exp_configure['dropout']] * exp_configure['num_gnn_layers'],
            alphas=[exp_configure['alpha']] * exp_configure['num_gnn_layers'],
            residuals=[exp_configure['residual']] * exp_configure['num_gnn_layers'],
            predictor_hidden_feats=exp_configure['predictor_hidden_feats'],
            predictor_dropout=exp_configure['dropout'],
            n_tasks=exp_configure['n_tasks']
        )
    elif model_name == 'Weave':
        from dgllife.model import WeavePredictor
        model = WeavePredictor(
            node_in_feats=exp_configure['in_node_feats'],
            edge_in_feats=exp_configure['in_edge_feats'],
            num_gnn_layers=exp_configure['num_gnn_layers'],
            gnn_hidden_feats=exp_configure['gnn_hidden_feats'],
            graph_feats=exp_configure['graph_feats'],
            gaussian_expand=exp_configure['gaussian_expand'],
            n_tasks=exp_configure['n_tasks']
        )
    elif model_name == 'MPNN':
        from dgllife.model import MPNNPredictor
        model = MPNNPredictor(
            node_in_feats=exp_configure['in_node_feats'],
            edge_in_feats=exp_configure['in_edge_feats'],
            node_out_feats=exp_configure['node_out_feats'],
            edge_hidden_feats=exp_configure['edge_hidden_feats'],
            num_step_message_passing=exp_configure['num_step_message_passing'],
            num_step_set2set=exp_configure['num_step_set2set'],
            num_layer_set2set=exp_configure['num_layer_set2set'],
            n_tasks=exp_configure['n_tasks']
        )
    elif model_name == 'AttentiveFP':
        from dgllife.model import AttentiveFPPredictor
        model = AttentiveFPPredictor(
            node_feat_size=exp_configure['in_node_feats'],
            edge_feat_size=exp_configure['in_edge_feats'],
            num_layers=exp_configure['num_layers'],
            num_timesteps=exp_configure['num_timesteps'],
            graph_feat_size=exp_configure['graph_feat_size'],
            dropout=exp_configure['dropout'],
            n_tasks=exp_configure['n_tasks']
        )
    elif model_name in ['gin_supervised_contextpred', 'gin_supervised_infomax',
                        'gin_supervised_edgepred', 'gin_supervised_masking']:
        from dgllife.model import GINPredictor
        from dgllife.model import load_pretrained
        # Architecture hyperparameters are fixed here rather than read from the
        # config; presumably they must match the pretrained encoder — TODO confirm.
        model = GINPredictor(
            num_node_emb_list=[120, 3],
            num_edge_emb_list=[6, 3],
            num_layers=5,
            emb_dim=300,
            JK=exp_configure['jk'],
            dropout=0.5,
            readout=exp_configure['readout'],
            n_tasks=exp_configure['n_tasks']
        )
        # Swap in the pretrained encoder, then re-apply the requested JK setting,
        # which the loaded encoder would otherwise not carry from the config.
        model.gnn = load_pretrained(model_name)
        model.gnn.JK = exp_configure['jk']
    elif model_name == 'NF':
        from dgllife.model import NFPredictor
        model = NFPredictor(
            in_feats=exp_configure['in_node_feats'],
            n_tasks=exp_configure['n_tasks'],
            hidden_feats=[exp_configure['gnn_hidden_feats']] * exp_configure['num_gnn_layers'],
            batchnorm=[exp_configure['batchnorm']] * exp_configure['num_gnn_layers'],
            dropout=[exp_configure['dropout']] * exp_configure['num_gnn_layers'],
            predictor_hidden_size=exp_configure['predictor_hidden_feats'],
            predictor_batchnorm=exp_configure['batchnorm'],
            predictor_dropout=exp_configure['dropout']
        )
    else:
        # Raise (not return) so an unsupported name fails here instead of
        # handing callers an exception object that breaks later on `.to(...)`.
        raise ValueError("Expect model to be from ['GCN', 'GAT', 'Weave', 'MPNN', 'AttentiveFP', "
                         "'gin_supervised_contextpred', 'gin_supervised_infomax', "
                         "'gin_supervised_edgepred', 'gin_supervised_masking', 'NF'], "
                         "got {}".format(model_name))

    return model


In [11]:
import torch
import json
from argparse import ArgumentParser

# The original command-line interface (ArgumentParser) is replaced by
# hard-coded settings so this cell runs interactively inside the notebook.
args = {}
args['model'] = 'GCN'
args['train_result_path'] = 'M2OR_Uniprot_original_GCN'

# Merge the model's JSON hyperparameter configuration into args; keys from the
# file override the defaults above. The path is derived from args['model'] so
# the model name and its configuration cannot drift apart.
config_path = 'data/configures/M2OR/{}_canonical.json'.format(args['model'])
with open(config_path, 'r') as f:
    args.update(json.load(f))

# Prefer the first GPU when available, otherwise fall back to CPU.
if torch.cuda.is_available():
    args['device'] = torch.device('cuda:0')
else:
    args['device'] = torch.device('cpu')


In [12]:
# Inspect the merged run configuration: hard-coded settings, JSON config values and device.
args

{'model': 'GCN',
 'train_result_path': 'M2OR_Uniprot_original_GCN',
 'lr': 0.02,
 'weight_decay': 0,
 'patience': 30,
 'batch_size': 32,
 'dropout': 0.05,
 'gnn_hidden_feats': 256,
 'predictor_hidden_feats': 128,
 'num_gnn_layers': 2,
 'residual': True,
 'batchnorm': False,
 'device': device(type='cuda', index=0)}

In [22]:
## Load the UniProt-trained GCN model and its saved weights
import os

from dgllife.utils import CanonicalAtomFeaturizer

args['node_featurizer'] = CanonicalAtomFeaturizer()
args['in_node_feats'] = args['node_featurizer'].feat_size()

# Number of output tasks the checkpoint was trained with; load_state_dict is
# strict by default and fails if the prediction head's shape does not match.
args['n_tasks'] = 574
# Use the GPU when present, otherwise fall back to CPU instead of crashing on
# CPU-only machines.
args['device'] = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = load_model(args).to(args['device'])
# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
checkpoint = torch.load(os.path.join(args['train_result_path'], 'model.pth'),
                        map_location=args['device'])
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()


GCNPredictor(
  (gnn): GCN(
    (gnn_layers): ModuleList(
      (0): GCNLayer(
        (graph_conv): GraphConv(in=74, out=256, normalization=none, activation=<function relu at 0x7f0cf524b310>)
        (dropout): Dropout(p=0.05, inplace=False)
        (res_connection): Linear(in_features=74, out_features=256, bias=True)
      )
      (1): GCNLayer(
        (graph_conv): GraphConv(in=256, out=256, normalization=none, activation=<function relu at 0x7f0cf524b310>)
        (dropout): Dropout(p=0.05, inplace=False)
        (res_connection): Linear(in_features=256, out_features=256, bias=True)
      )
    )
  )
  (readout): WeightedSumAndMax(
    (weight_and_sum): WeightAndSum(
      (atom_weighting): Sequential(
        (0): Linear(in_features=256, out_features=1, bias=True)
        (1): Sigmoid()
      )
    )
  )
  (predict): MLPPredictor(
    (predict): Sequential(
      (0): Dropout(p=0.05, inplace=False)
      (1): Linear(in_features=512, out_features=128, bias=True)
      (2): ReLU()
 

In [19]:
# Prefer the first GPU, but fall back to CPU so the notebook also runs on
# machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [36]:
# Output size used when building the replacement prediction head.
args.update(n_tasks=128)

In [37]:
from dgllife.model.model_zoo.mlp_predictor import MLPPredictor

# Replace the loaded prediction head with a freshly initialised one that has
# args['n_tasks'] outputs; the GNN encoder and readout keep their loaded weights.
# The WeightedSumAndMax readout concatenates two pooled vectors, hence the
# 2 * gnn_out_feats input size (512 = 2 * 256 in the printed model).
gnn_out_feats = model.gnn.hidden_feats[-1]
new_head = MLPPredictor(2 * gnn_out_feats, args['predictor_hidden_feats'],
                        args['n_tasks'], dropout=args['dropout'])
# A newly constructed module lives on CPU in training mode. Align it with the
# rest of the model to avoid a device mismatch in the forward pass and to keep
# dropout behaviour consistent with the model's current train/eval state.
new_head = new_head.to(args['device'])
new_head.train(model.training)
model.predict = new_head



In [38]:
# Display the model structure after the prediction head was replaced.
model

GCNPredictor(
  (gnn): GCN(
    (gnn_layers): ModuleList(
      (0): GCNLayer(
        (graph_conv): GraphConv(in=74, out=256, normalization=none, activation=<function relu at 0x7f0cf524b310>)
        (dropout): Dropout(p=0.05, inplace=False)
        (res_connection): Linear(in_features=74, out_features=256, bias=True)
      )
      (1): GCNLayer(
        (graph_conv): GraphConv(in=256, out=256, normalization=none, activation=<function relu at 0x7f0cf524b310>)
        (dropout): Dropout(p=0.05, inplace=False)
        (res_connection): Linear(in_features=256, out_features=256, bias=True)
      )
    )
  )
  (readout): WeightedSumAndMax(
    (weight_and_sum): WeightAndSum(
      (atom_weighting): Sequential(
        (0): Linear(in_features=256, out_features=1, bias=True)
        (1): Sigmoid()
      )
    )
  )
  (predict): MLPPredictor(
    (predict): Sequential(
      (0): Dropout(p=0.05, inplace=False)
      (1): Linear(in_features=512, out_features=128, bias=True)
      (2): ReLU()
 