# Loading and saving `RCModel`s
This notebook will show you:
- how to load any `arcd.RCModel` by hand
- how to save any `arcd.RCModel` by hand

Note: This notebook depends on files created in the first notebook, `1_Toy_pytorch_simple_setup.ipynb`. Please run it first if you have not done so already.

In [1]:
%matplotlib inline

In [2]:
import os
import arcd
import numpy as np
import matplotlib.pyplot as plt
import openpathsampling as paths

Using TensorFlow backend.


In [3]:
# change to the working directory of choice
# (should be the same as in the first notebook)
wdir = '/homeloc/scratch/hejung/arcd_scratch/toys/'
#wdir = None
if wdir is not None:
    os.chdir(wdir)

## Loading: The harder but more flexible way
Of course you can also load an `RCModel` by hand. This is useful, for example, for analyzing intermediate models that have been saved along the way.

In [4]:
# open old storage for reading
# we need the storage to get the openpathsampling CV function that is attached to the model
# if no OPS CV is attached as descriptor_transform, we can do without it
storage = paths.Storage('22dim_toy_pytorch.nc', 'r')

In [5]:
# let's see which saved models there are
possible_mods = os.listdir()
# the model save file names are usually derived from the corresponding OPS storage
possible_mods = [m for m in possible_mods if '22dim_toy_pytorch.nc' in m]

In [6]:
possible_mods

['22dim_toy_pytorch.nc',
 '22dim_toy_pytorch.nc_RCmodel_at_step100.pckl',
 '22dim_toy_pytorch.nc_RCmodel_at_step200.pckl',
 '22dim_toy_pytorch.nc_RCmodel_at_step300.pckl',
 '22dim_toy_pytorch.nc_RCmodel_at_step400.pckl',
 '22dim_toy_pytorch.nc_RCmodel_at_step500.pckl',
 '22dim_toy_pytorch.nc_RCmodel.pckl']
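Instead of typing a file name or relying on its position in the list, the model can also be picked by training step by parsing the step number out of the file names. A minimal sketch, assuming the `<storage name>_RCmodel_at_step<N>.pckl` pattern seen in the listing above; `models_by_step` is a hypothetical helper, not part of arcd:

```python
import re


def models_by_step(filenames, storage_name):
    """Map training step -> file name for '<storage_name>_RCmodel_at_step<N>.pckl'.

    Hypothetical helper; the naming pattern is inferred from the directory listing.
    """
    pattern = re.compile(re.escape(storage_name) + r"_RCmodel_at_step(\d+)\.pckl$")
    found = {}
    for name in filenames:
        match = pattern.match(name)
        if match:
            found[int(match.group(1))] = name
    return found


files = [
    "22dim_toy_pytorch.nc",
    "22dim_toy_pytorch.nc_RCmodel_at_step100.pckl",
    "22dim_toy_pytorch.nc_RCmodel_at_step200.pckl",
    "22dim_toy_pytorch.nc_RCmodel.pckl",
]
by_step = models_by_step(files, "22dim_toy_pytorch.nc")
print(sorted(by_step))  # [100, 200]
print(by_step[200])     # 22dim_toy_pytorch.nc_RCmodel_at_step200.pckl
```

In the notebook this would be called as `models_by_step(os.listdir(), '22dim_toy_pytorch.nc')`. The final model without a step number is deliberately not matched.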

In [7]:
# let's load the model saved at step 200
# select it by name, since the order returned by os.listdir() is arbitrary
fname = '22dim_toy_pytorch.nc_RCmodel_at_step200.pckl'

from arcd.base.rcmodel import RCModel
# the first step is always to ask the `RCModel` base class which subclass the saved model belongs to
# this returns that subclass and a partially fixed state, i.e. descriptor_transform is already set to the OPS CV
state, cls = RCModel.load_state(fname, storage)
# now we let the specific RCModel subclass fix the rest of its state,
# e.g. load the ANN with its weights
state = cls.fix_state(state)
# this finally instantiates the correct RCModel subclass from the fixed state
model = cls.set_state(state)
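The three calls above split loading into: read the state and learn which subclass it belongs to, let that subclass repair what could not be stored directly, then build the instance. The toy classes below (`ToyRCModel`, `ToyNetModel`, and their `get_state` method) are hypothetical and only mimic this division of labour with the standard `pickle` module; they are not arcd's implementation, and arcd's file format may differ.

```python
import os
import pickle
import tempfile


class ToyRCModel:
    """Hypothetical base class mimicking the load_state/fix_state/set_state split."""

    def get_state(self):
        # everything in __dict__ by default; subclasses drop what cannot be pickled
        return dict(self.__dict__)

    def save(self, fname):
        state = self.get_state()
        # remember the concrete subclass so the base class can dispatch on load
        state["__class__"] = type(self)
        with open(fname, "wb") as f:
            pickle.dump(state, f)

    @staticmethod
    def load_state(fname):
        """Return the (partially fixed) state and the subclass it belongs to."""
        with open(fname, "rb") as f:
            state = pickle.load(f)
        subclass = state.pop("__class__")
        return state, subclass

    @classmethod
    def fix_state(cls, state):
        return state  # nothing to repair at the base-class level

    @classmethod
    def set_state(cls, state):
        obj = cls.__new__(cls)  # bypass __init__, the state is already complete
        obj.__dict__.update(state)
        return obj


class ToyNetModel(ToyRCModel):
    """Toy model whose 'network' is a closure that cannot be pickled."""

    def __init__(self, weight):
        self.weight = weight
        self.net = self._build_net(weight)

    @staticmethod
    def _build_net(weight):
        return lambda x: weight * x

    def get_state(self):
        state = super().get_state()
        del state["net"]  # keep only the weight on disk
        return state

    @classmethod
    def fix_state(cls, state):
        state = super().fix_state(state)
        state["net"] = cls._build_net(state["weight"])  # rebuild the 'network'
        return state

    def __call__(self, x):
        return self.net(x)


with tempfile.TemporaryDirectory() as tmp:
    fname = os.path.join(tmp, "toy_model.pckl")
    ToyNetModel(weight=3.0).save(fname)
    state, cls = ToyRCModel.load_state(fname)  # base class names the subclass
    state = cls.fix_state(state)               # subclass repairs its own state
    model = cls.set_state(state)               # finally build the instance

print(cls.__name__, model(2.0))  # ToyNetModel 6.0
```

Keeping the repair step in the subclass means the base class never needs to know how a particular model rebuilds its network.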

In [8]:
# check that it worked
print(model)
print(model(storage.snapshots[-1]))

<arcd.pytorch.rcmodel.EEPytorchRCModel object at 0x7f2168307518>
[[0.5737502]]


## Saving: The flexible way
Of course you can also save any `arcd.RCModel` at any time. If the model has an OPS `CollectiveVariable` attached as its `descriptor_transform`, only the CV's name is saved. To completely reconstruct the model later, you therefore have to pass an `openpathsampling.Storage` that contains a CV with that name.

Note that if no OPS storage is passed to `load_state()`, the model is still reconstructed, but its `descriptor_transform` is left untouched. This lets you attach arbitrary (pickleable) functions as `descriptor_transform`, not only OPS CVs, and have them restored as intended.
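The save-by-name behaviour described above can be illustrated with a plain dictionary standing in for the OPS storage. This is a hypothetical sketch (`NamedTransform`, `pack_state`, `unpack_state` and the `registry` dict are made up for illustration), not arcd's code: a named transform is replaced by its name before pickling and looked up again on loading, and without a registry the stored value is left untouched.

```python
import pickle


class NamedTransform:
    """Hypothetical stand-in for an OPS CollectiveVariable: has a name, is not pickled itself."""

    def __init__(self, name, func):
        self.name = name
        self.func = func

    def __call__(self, x):
        return self.func(x)


def pack_state(state):
    """Replace a named transform by its name, then pickle the state."""
    state = dict(state)
    transform = state.get("descriptor_transform")
    if isinstance(transform, NamedTransform):
        state["descriptor_transform"] = transform.name
    return pickle.dumps(state)


def unpack_state(data, registry=None):
    """Resolve the transform by name if a registry is given, otherwise leave it as stored."""
    state = pickle.loads(data)
    transform = state.get("descriptor_transform")
    if registry is not None and isinstance(transform, str):
        state["descriptor_transform"] = registry[transform]
    return state


double = NamedTransform("double", lambda x: 2 * x)  # the lambda itself is never pickled
data = pack_state({"descriptor_transform": double, "n_steps": 200})

restored = unpack_state(data, registry={"double": double})
print(restored["descriptor_transform"](21))        # 42
print(unpack_state(data)["descriptor_transform"])  # double
```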

In [9]:
model.save('save_test1')

In [10]:
# see that it worked
possible_mods = os.listdir()
# here we chose the file name ourselves; note the '.pckl' extension added on saving
possible_mods = [m for m in possible_mods if 'save_test1' in m]
print(possible_mods)

['save_test1.pckl']


In [11]:
# now let's test saving an arbitrary Python function as descriptor_transform
def test_func(x, y):
    return x + y

# for demonstration purposes only: this breaks the model, because it changes the call signature of descriptor_transform
model.descriptor_transform = test_func

In [12]:
model.save('save_test2')

In [13]:
# load it to check if it worked
fname = 'save_test2.pckl'  # need filename with extension to load
# same as above, except we do not (have to) pass an OPS storage
state, cls = RCModel.load_state(fname)
state = cls.fix_state(state)
loaded_model = cls.set_state(state)

In [14]:
# should return 2
loaded_model.descriptor_transform(1, 1)

2
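The round trip above happened within a single session. If the state is written with Python's standard `pickle` module (an assumption suggested by the `.pckl` extension, not confirmed here), functions are stored by reference: only their module and name are saved, not their code. Loading `save_test2.pckl` in a fresh session would then require `test_func` to be defined or importable again, and a lambda could not be saved at all. A small stand-alone check with `pickle`, where `add_one` is a hypothetical example function:

```python
import pickle


def add_one(x):
    return x + 1


data = pickle.dumps(add_one)
# only the qualified name is stored, so unpickling looks the function up again
print(b"add_one" in data)             # True
print(pickle.loads(data) is add_one)  # True: the same object, not a rebuilt copy

try:
    pickle.dumps(lambda x: x + 1)
    lambda_pickled = True
except (pickle.PicklingError, AttributeError, TypeError):
    # a lambda has no importable name, so it cannot be pickled by reference
    lambda_pickled = False
print(lambda_pickled)  # False
```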