In [1]:
# THIS CELL SETS STUFF UP FOR DEMO / COLAB. THIS CELL CAN BE IGNORED.

#-------------------------------------GET RID OF TF DEPRECATION WARNINGS--------------------------------------#
import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

#----------------------------------INSTALL PSYCHRNN IF IN A COLAB NOTEBOOK-------------------------------------#
# Installs the correct branch / release version based on the URL. If no branch is provided, loads from master.
# Loads saved weights from correct branch and saves a local copy for later use.
# A successful import of google.colab is used as the signal that this
# notebook is running inside Colab.
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    # Only a missing module means "not in Colab"; any other failure
    # (including KeyboardInterrupt) should surface instead of being hidden.
    IN_COLAB = False

if IN_COLAB:
    import json
    import re
    import ipykernel
    import requests 
    from requests.compat import urljoin
    from io import BytesIO
    import numpy as np
    import os

    from notebook.notebookapp import list_running_servers
    kernel_id = re.search('kernel-(.*).json',
                          ipykernel.connect.get_connection_file()).group(1)
    servers = list_running_servers()
    for ss in servers:
        response = requests.get(urljoin(ss['url'], 'api/sessions'),
                                params={'token': ss.get('token', '')})
        for nn in json.loads(response.text):
            if nn['kernel']['id'] == kernel_id:
                relative_path = nn['notebook']['path'].split('%2F')
                if 'blob' in relative_path:
                  blob = relative_path[relative_path.index('blob') + 1]
                  !pip install git+https://github.com/murraylab/PsychRNN@$blob
                  file_location = "https://github.com/murraylab/PsychRNN/blob/" + blob + "/docs/notebooks/weights/saved_weights.npz?raw=true"
                else:
                  !pip install git+https://github.com/murraylab/PsychRNN
                  file_location = "https://github.com/murraylab/PsychRNN/docs/notebooks/weights/saved_weights.npz?raw=true"

    r = requests.get(file_location, stream = True)
    data = dict(np.load(BytesIO(r.raw.read()), allow_pickle = True))
    if not os.path.exists("./weights"):
        os.makedirs("./weights")
    np.savez("./weights/saved_weights.npz", **data)

# Accessing and Modifying Weights

In [Simple Example](PerceptualDiscrimination.ipynb#Get-&-Save-Model-Weights), we saved weights to ``./weights/saved_weights``. Here we will load those weights and modify them by silencing a few recurrent units: the recurrent weights among the first ten units (``W_rec[:10, :10]``) are set to zero.

In [2]:
import numpy as np

# Materialise every array while the archive is open, then let the context
# manager close the underlying .npz file handle (a bare np.load leaves it open).
# allow_pickle=True is needed for object-valued entries; only load trusted files.
with np.load('./weights/saved_weights.npz', allow_pickle = True) as saved_archive:
    weights = dict(saved_archive)

# Zero the top-left 10x10 block of the recurrent weight matrix.
weights['W_rec'][:10, :10] = 0

Here are all the weights you can access and modify. The ones whose names don't end in ``Adam`` or ``Adam_1`` will be read in when loading a model from weights.

In [3]:
# Show the name of every entry in the weights dictionary.
print(weights.keys())

dict_keys(['init_state', 'W_in', 'W_rec', 'W_out', 'b_rec', 'b_out', 'Dale_rec', 'Dale_out', 'input_connectivity', 'rec_connectivity', 'output_connectivity', 'init_state/Adam', 'init_state/Adam_1', 'W_in/Adam', 'W_in/Adam_1', 'W_rec/Adam', 'W_rec/Adam_1', 'W_out/Adam', 'W_out/Adam_1', 'b_rec/Adam', 'b_rec/Adam_1', 'b_out/Adam', 'b_out/Adam_1', 'dale_ratio'])


Save the modified weights at ``'./weights/modified_saved_weights.npz'``.

In [4]:
# Write every entry of the modified dictionary to a new .npz archive,
# one named array per dictionary key.
np.savez('./weights/modified_saved_weights.npz', **weights)

# Loading Model with Weights 

In [5]:
from psychrnn.backend.models.basic import Basic

In [6]:
# Base parameters shared by both models built below; each model copies this
# dictionary and adds its own 'name' and weight source.
network_params = dict(
    N_batch=50,
    N_in=2,
    N_out=2,
    dt=10,
    tau=100,
    T=2000,
    N_steps=200,
    N_rec=50,
)

### Load from File

Set network parameters.

In [7]:
# Start from the shared parameters and add the model name plus the path of
# the modified archive to load weights from.
file_network_params = {
    **network_params,
    'name': 'file',
    'load_weights_path': './weights/modified_saved_weights.npz',
}

Instantiate model.

In [8]:
# Build the model from parameters whose 'load_weights_path' points at the
# modified weights archive.
fileModel = Basic(file_network_params)

Verify that the W_rec weights are modified as expected.

In [9]:
# Fetch the recurrent matrix from the live model, then show the 10x10 block
# that was zeroed before saving.
file_W_rec = fileModel.get_weights()['W_rec']
print(file_W_rec[:10, :10])

[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]


In [10]:
# Tear down this model before the next one is built below.
# NOTE(review): presumably this releases the model's underlying TensorFlow
# resources so a second model can be created — confirm in the PsychRNN docs.
fileModel.destruct()

### Load from Weights Dictionary

Set network parameters.

In [11]:
# Start from the shared parameters, name the model, and merge in every entry
# of the weights dictionary so the arrays are passed directly as parameters.
dict_network_params = {**network_params, 'name': 'dict', **weights}

# Displayed as the cell result: 'dale_ratio' comes back from the archive as an
# ndarray whose single item is None, not as a bare None.
(type(dict_network_params['dale_ratio']) is np.ndarray
 and dict_network_params['dale_ratio'].item() is None)

True

Instantiate model.

In [12]:
# Build the model from parameters that carry the weight arrays themselves
# (merged in from the weights dictionary) rather than a path to a file.
dictModel = Basic(dict_network_params)

Verify that the W_rec weights are modified as expected.

In [13]:
# Fetch the recurrent matrix from the live model, then show the 10x10 block
# that was zeroed in the weights dictionary.
dict_W_rec = dictModel.get_weights()['W_rec']
print(dict_W_rec[:10, :10])

[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]


In [14]:
# Tear down this model now that its weights have been inspected.
# NOTE(review): presumably this releases the model's underlying TensorFlow
# resources — confirm in the PsychRNN docs.
dictModel.destruct()