Testing out ml_collections.ConfigDict() to make saving and reloading agent and network objects easier.

In [None]:
import ml_collections

# Build the demo config in one go from a plain nested dict; ConfigDict
# converts the inner dict into a nested ConfigDict on construction.
cfg = ml_collections.ConfigDict({
    'float_field': 12.6,
    'integer_field': 123,
    'another_integer_field': 234,
    'nested': {'string_field': 'tom'},
})

print(cfg)

In [None]:
%load_ext autoreload
%autoreload
from retro_branching.agents import REINFORCEAgent
from retro_branching.networks import BipartiteGCN

# Settings shared by the policy and the filter network.
common_kwargs = dict(device='cpu', cons_nfeats=5, edge_nfeats=1)

# Policy network: 20 variable features when a filter network is used,
# 19 otherwise (per the original note on this argument).
policy_network = BipartiteGCN(emb_size=64,
                              num_rounds=1,
                              var_nfeats=20,
                              aggregator='add',
                              **common_kwargs)

# Filter network: larger embedding, two rounds, mean aggregation.
filter_network = BipartiteGCN(emb_size=128,
                              num_rounds=2,
                              var_nfeats=19,
                              aggregator='mean',
                              **common_kwargs)

# Wrap both networks in a REINFORCE agent.
rlgnn_agent = REINFORCEAgent(
    policy_network=policy_network,
    filter_network=filter_network,
    device='cpu',
    temperature=1.0,
    name='rl_gnn',
    filter_method='method_2',
)

# Ask the agent for its config object and show it.
rlgnn_config = rlgnn_agent.create_config()
print(rlgnn_config)

In [None]:
import json

# Serialise the agent config with ml_collections' best-effort JSON encoder.
json_rlgnn_config = rlgnn_config.to_json_best_effort()
print(json_rlgnn_config)

# NOTE: the serialised config is JSON-encoded once more when written, which is
# why the loading code decodes twice (json.load followed by json.loads).
with open('rl_gnn_config.json', 'w') as f:
    f.write(json.dumps(json_rlgnn_config))

In [None]:
import ml_collections

# First decode: read the file contents back into `json_config`.
with open('rl_gnn_config.json', 'r') as f:
    json_config = json.load(f)

# Second decode: parse `json_config` itself and wrap the result in a ConfigDict.
config = ml_collections.ConfigDict(json.loads(json_config))
print(config)

In [None]:
%autoreload
# Re-create the agent from the saved config file path alone, then print one
# attribute of the rebuilt policy network.
loaded_agent = REINFORCEAgent(config='rl_gnn_config.json', device='cpu')
print(loaded_agent.policy_network.var_nfeats)

# Generating, saving, and loading config.json for networks and agents

Some older networks were saved with only the network's `.pkl` state dict, rather than the `ml_collections` JSON config shown above, which makes them cumbersome to re-initialise each time you want to load the model. Use the cells below to generate and save `ml_collections` JSON configs for these networks so that they can easily be re-initialised thereafter.

In [None]:
%load_ext autoreload
%autoreload
from retro_branching.networks import BipartiteGCN
from retro_branching.agents import REINFORCEAgent

import glob
import json
import torch
import ml_collections

# Device passed to the network/agent constructors below and used as
# `map_location` when loading saved state dicts with torch.load.
device = 'cpu'

## gnn

In [None]:
# Architecture hyperparameters needed to rebuild each saved supervised GNN,
# keyed by network name. 'checkpoints' is either the string 'all' (discover
# every checkpoint_* folder on disk) or an iterable of checkpoint ids,
# e.g. np.arange(1, 225, 25).
graph_networks = {f'gnn_{i}': {'checkpoints': 'all',
                               'emb_size': 64,
                               'num_rounds': 1,
                               'cons_nfeats': 5,
                               'edge_nfeats': 1,
                               'var_nfeats': 19,
                               'aggregator': 'add'} for i in [1]}

for gnn in graph_networks.keys():
    # Always set the network directory. Previously it was only assigned in the
    # 'all' branch, so an explicit checkpoint list raised NameError or silently
    # reused the directory of the previous network.
    path = f'/scratch/datasets/retro_branching/supervised_learner/gnn/{gnn}/'

    # Guard with isinstance so an array of checkpoint ids is never compared
    # against the string 'all' (ambiguous truth value for numpy arrays).
    requested_checkpoints = graph_networks[gnn]['checkpoints']
    if isinstance(requested_checkpoints, str) and requested_checkpoints == 'all':
        # Discover all checkpoints; sorted because glob order is arbitrary.
        graph_networks[gnn]['checkpoints'] = sorted(
            int(p.split('/')[-1].split('_')[-1]) for p in glob.glob(path+'checkpoint_*'))

    for cp in graph_networks[gnn]['checkpoints']:
        # `path` already ends with '/', so do not add another separator.
        foldername = f'{path}checkpoint_{cp}'

        # Rebuild the network and load its weights from the checkpoint folder.
        graph_network = BipartiteGCN(device,
                                    emb_size=graph_networks[gnn]['emb_size'],
                                    num_rounds=graph_networks[gnn]['num_rounds'],
                                    cons_nfeats=graph_networks[gnn]['cons_nfeats'],
                                    edge_nfeats=graph_networks[gnn]['edge_nfeats'],
                                    var_nfeats=graph_networks[gnn]['var_nfeats'],
                                    aggregator=graph_networks[gnn]['aggregator'])
        # NOTE: torch.load unpickles the file; only use on trusted checkpoints.
        graph_network.load_state_dict(torch.load(foldername+'/trained_params.pkl', map_location=device))

        # Generate the config JSON and save it next to the weights.
        config = graph_network.create_config().to_json_best_effort()
        with open(foldername+'/config.json', 'w') as f:
            json.dump(config, f)

        print(f'Created {gnn} config.json and saved to {foldername}')

In [None]:
# Example: rebuild a network from its saved config.json, then restore weights.
# Uses `gnn` and `cp` as left by the config-generation loop.
config = f'/scratch/datasets/retro_branching/supervised_learner/gnn/{gnn}/checkpoint_{cp}/config.json'
print(config)
loaded_graph_network = BipartiteGCN(device=device, config=config)

# Load the trained parameters for the same checkpoint onto `device`.
trained_params = torch.load(f'/scratch/datasets/retro_branching/supervised_learner/gnn/{gnn}/checkpoint_{cp}/trained_params.pkl', map_location=device)
loaded_graph_network.load_state_dict(trained_params)

## rl_gnn

In [None]:
# Policy networks to generate agent configs for, keyed by agent name.
# 'filter_network' is None or a key of `filter_networks` below (e.g. 'gnn_261').
# 'checkpoints' is either the string 'all' (discover every checkpoint_* folder
# on disk) or an iterable of checkpoint ids, e.g. np.arange(1, 225, 25).
policy_networks = {f'rl_gnn_{i}': {'filter_network': None,
                                   'filter_method': 'method_2',
                                   'checkpoints': 'all',
                                   'emb_size': 64,
                                   'num_rounds': 1,
                                   'cons_nfeats': 5,
                                   'edge_nfeats': 1,
                                   'var_nfeats': 19, # 20 if filter_network is not None, else 19
                                   'aggregator': 'add'} for i in [578]}

# Supervised GNNs usable as filter networks, keyed by name, each with the
# single checkpoint id whose weights should be loaded.
filter_networks = {f'gnn_{i}': {'checkpoint': cp,
                                'emb_size': 128,
                                'num_rounds': 2,
                                'cons_nfeats': 5,
                                'edge_nfeats': 1,
                                'var_nfeats': 19,
                                'aggregator': 'add'} for i, cp in zip([261], [58])}

for rl_gnn in policy_networks.keys():
    # Always set the agent directory. Previously it was only assigned in the
    # 'all' branch, so an explicit checkpoint list raised NameError or silently
    # reused the directory of the previous agent.
    path = f'/scratch/datasets/retro_branching/reinforce_learner/rl_gnn/{rl_gnn}/'

    # Guard with isinstance so an array of checkpoint ids is never compared
    # against the string 'all' (ambiguous truth value for numpy arrays).
    requested_checkpoints = policy_networks[rl_gnn]['checkpoints']
    if isinstance(requested_checkpoints, str) and requested_checkpoints == 'all':
        # Discover all checkpoints; sorted because glob order is arbitrary.
        policy_networks[rl_gnn]['checkpoints'] = sorted(
            int(p.split('/')[-1].split('_')[-1]) for p in glob.glob(path+'checkpoint_*'))

    for cp in policy_networks[rl_gnn]['checkpoints']:
        # `path` already ends with '/', so do not add another separator.
        foldername = f'{path}checkpoint_{cp}'

        # Rebuild the policy network and load its weights from the checkpoint folder.
        # NOTE: torch.load unpickles the file; only use on trusted checkpoints.
        policy_network = BipartiteGCN(device,
                                    emb_size=policy_networks[rl_gnn]['emb_size'],
                                    num_rounds=policy_networks[rl_gnn]['num_rounds'],
                                    cons_nfeats=policy_networks[rl_gnn]['cons_nfeats'],
                                    edge_nfeats=policy_networks[rl_gnn]['edge_nfeats'],
                                    var_nfeats=policy_networks[rl_gnn]['var_nfeats'],
                                    aggregator=policy_networks[rl_gnn]['aggregator'])
        policy_network.load_state_dict(torch.load(foldername+'/trained_params.pkl', map_location=device))

        # Load the filter network (if applicable).
        if policy_networks[rl_gnn]['filter_network'] is not None:
            # Rebuild the filter network from its own hyperparameters and weights.
            filter_name = policy_networks[rl_gnn]['filter_network']
            filter_network = BipartiteGCN(device,
                                    emb_size=filter_networks[filter_name]['emb_size'],
                                    num_rounds=filter_networks[filter_name]['num_rounds'],
                                    cons_nfeats=filter_networks[filter_name]['cons_nfeats'],
                                    edge_nfeats=filter_networks[filter_name]['edge_nfeats'],
                                    var_nfeats=filter_networks[filter_name]['var_nfeats'],
                                    aggregator=filter_networks[filter_name]['aggregator'])
            filter_network.load_state_dict(torch.load('/scratch/datasets/retro_branching/supervised_learner/gnn/{}/checkpoint_{}/trained_params.pkl'.format(filter_name, filter_networks[filter_name]['checkpoint']), map_location=device))

            # Save the filter weights in the same folder as the config, so the
            # agent can be restored from this one checkpoint folder.
            filename = foldername+'/filter_params.pkl'
            torch.save(filter_network.state_dict(), filename)
        else:
            filter_network = None

        # Initialise the agent around the loaded network(s).
        agent = REINFORCEAgent(policy_network=policy_network,
                               filter_network=filter_network,
                               device=device,
                               name=rl_gnn,
                               filter_method=policy_networks[rl_gnn]['filter_method'])

        # Generate the config JSON and save it next to the weights.
        config = agent.create_config().to_json_best_effort()
        with open(foldername+'/config.json', 'w') as f:
            json.dump(config, f)

        print(f'Created {rl_gnn} config.json and saved to {foldername}')
        

In [None]:
# Example: rebuild an agent from its saved config.json, then restore weights.
# Uses `rl_gnn` and `cp` as left by the config-generation loop.
config = f'/scratch/datasets/retro_branching/reinforce_learner/rl_gnn/{rl_gnn}/checkpoint_{cp}/config.json'
print(config)
loaded_agent = REINFORCEAgent(device=device, config=config)

# Restore the policy network weights for the same checkpoint.
policy_params = torch.load(f'/scratch/datasets/retro_branching/reinforce_learner/rl_gnn/{rl_gnn}/checkpoint_{cp}/trained_params.pkl', map_location=device)
loaded_agent.policy_network.load_state_dict(policy_params)

# Restore the filter network weights only when the agent has a filter network.
if loaded_agent.filter_network is not None:
    filter_params = torch.load(f'/scratch/datasets/retro_branching/reinforce_learner/rl_gnn/{rl_gnn}/checkpoint_{cp}/filter_params.pkl', map_location=device)
    loaded_agent.filter_network.load_state_dict(filter_params)