# PyTorch to Jax/Haiku Weight Transfer (POC)

Proof of concept for transferring the weights of Werner Duvaud's PyTorch MuZero networks to a JAX/Haiku implementation.

## Imports and instantiation

In [1]:
# WERNER IMPORTS
import models
from games.cartpole import MuZeroConfig

# JAX IMPORTS
import jax
import jax.numpy as jnp
import haiku as hk
import numpy as np

In [2]:
# Instantiate the CartPole MuZero configuration; it is passed to the network constructor.
config = MuZeroConfig()

In [3]:
# Build the PyTorch MuZero network from the configuration.
model = models.MuZeroNetwork(config)
# Combined state_dict of the network; the parsing step iterates it as (key, value) pairs.
s = model.state_dict()

## Parse the PyTorch weights

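We rename each `state_dict` key to match Haiku's parameter naming convention, convert each tensor to a `jnp.array`, and transpose the weight matrices (biases are left as they are).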

In [4]:
def parse_pytorch_state_dict(state_dict):
    """Regroup a PyTorch ``state_dict`` into Haiku-style nested parameters.

    Every key is split on ``"."`` and interpreted as follows:

    * first component: the network name, with its final ``_``-separated token
      dropped (presumably a ``_network`` suffix -- verify against the model);
    * second-to-last component: the integer index of the layer inside the
      PyTorch sequential container. It is floor-divided by two to obtain the
      Haiku ``linear_<i>`` index, which assumes linear layers sit at the even
      positions (0, 2, ...) with a parameter-free layer in between;
    * last component: the parameter name, of which only the first character
      is kept (``"w"`` for weights, ``"b"`` for biases).

    Args:
        state_dict: Mapping from PyTorch parameter names to torch tensors.

    Returns:
        A dict mapping Haiku module paths
        (``"<network>/~/mlp/~/linear_<i>"``) to ``{"w": ..., "b": ...}`` dicts
        of ``jnp`` arrays. Non-bias arrays are transposed.

    Raises:
        ValueError: If the second-to-last key component is not an integer.
    """
    parsed = {}

    for key, tensor in state_dict.items():
        parts = key.split(".")

        network_name = "_".join(parts[0].split("_")[:-1])
        linear_index = int(parts[-2]) // 2
        module_path = f"{network_name}/~/mlp/~/linear_{linear_index}"
        param_name = parts[-1][0]

        # Convert explicitly through NumPy so tensors that track gradients or
        # live on an accelerator are handled as well.
        array = jnp.array(tensor.detach().cpu().numpy())

        # torch.nn.Linear stores weights as (out, in) whereas hk.Linear expects
        # (in, out); biases need no change.
        if param_name != "b":
            array = array.T

        parsed.setdefault(module_path, {})[param_name] = array

    return parsed


parsed_pytorch_weights = parse_pytorch_state_dict(s)

list(parsed_pytorch_weights.keys())



['representation/~/mlp/~/linear_0',
 'dynamics_encoded_state/~/mlp/~/linear_0',
 'dynamics_encoded_state/~/mlp/~/linear_1',
 'dynamics_reward/~/mlp/~/linear_0',
 'dynamics_reward/~/mlp/~/linear_1',
 'prediction_policy/~/mlp/~/linear_0',
 'prediction_policy/~/mlp/~/linear_1',
 'prediction_value/~/mlp/~/linear_0',
 'prediction_value/~/mlp/~/linear_1']

## Define the networks

In [5]:
class Representation(hk.Module):
    """Haiku module wrapping a single MLP.

    With the default module name, its parameters live under
    ``representation/~/mlp/~/linear_<i>``. That path depends on the class name
    and on the MLP being created inside ``__init__``; the weight transfer
    matches parameters by this path, so keep both unchanged.

    Args:
        output_sizes: Width of each linear layer. Defaults to ``(8,)``.
        name: Optional Haiku module name. ``None`` keeps the default derived
            from the class name.
    """

    def __init__(self, output_sizes=(8,), name=None):
        super().__init__(name=name)
        # ELU is passed explicitly for consistency with the other networks.
        # With the default single layer no activation is applied, because
        # hk.nets.MLP does not activate its final layer by default.
        self.fc = hk.nets.MLP(output_sizes, activation=jax.nn.elu)

    def __call__(self, x):
        """Apply the MLP to ``x`` and return the result."""
        return self.fc(x)

In [6]:
class DynamicsEncodedState(hk.Module):
    """Haiku module wrapping an ELU-activated MLP.

    With the default module name, its parameters live under
    ``dynamics_encoded_state/~/mlp/~/linear_<i>``. That path depends on the
    class name and on the MLP being created inside ``__init__``; the weight
    transfer matches parameters by this path, so keep both unchanged.

    Args:
        output_sizes: Width of each linear layer. Defaults to ``(16, 8)``.
        name: Optional Haiku module name. ``None`` keeps the default derived
            from the class name.
    """

    def __init__(self, output_sizes=(16, 8), name=None):
        super().__init__(name=name)
        self.fc = hk.nets.MLP(output_sizes, activation=jax.nn.elu)

    def __call__(self, x):
        """Apply the MLP to ``x`` and return the result."""
        return self.fc(x)

In [7]:
class DynamicsReward(hk.Module):
    """Haiku module wrapping an ELU-activated MLP.

    With the default module name, its parameters live under
    ``dynamics_reward/~/mlp/~/linear_<i>``. That path depends on the class
    name and on the MLP being created inside ``__init__``; the weight transfer
    matches parameters by this path, so keep both unchanged.

    Args:
        output_sizes: Width of each linear layer. Defaults to ``(16, 21)``.
        name: Optional Haiku module name. ``None`` keeps the default derived
            from the class name.
    """

    def __init__(self, output_sizes=(16, 21), name=None):
        super().__init__(name=name)
        self.fc = hk.nets.MLP(output_sizes, activation=jax.nn.elu)

    def __call__(self, x):
        """Apply the MLP to ``x`` and return the result."""
        return self.fc(x)

In [8]:
class PredictionPolicy(hk.Module):
    """Haiku module wrapping an ELU-activated MLP.

    With the default module name, its parameters live under
    ``prediction_policy/~/mlp/~/linear_<i>``. That path depends on the class
    name and on the MLP being created inside ``__init__``; the weight transfer
    matches parameters by this path, so keep both unchanged.

    Args:
        output_sizes: Width of each linear layer. Defaults to ``(16, 2)``.
        name: Optional Haiku module name. ``None`` keeps the default derived
            from the class name.
    """

    def __init__(self, output_sizes=(16, 2), name=None):
        super().__init__(name=name)
        self.fc = hk.nets.MLP(output_sizes, activation=jax.nn.elu)

    def __call__(self, x):
        """Apply the MLP to ``x`` and return the result."""
        return self.fc(x)

In [9]:
class PredictionValue(hk.Module):
    """Haiku module wrapping an ELU-activated MLP.

    With the default module name, its parameters live under
    ``prediction_value/~/mlp/~/linear_<i>``. That path depends on the class
    name and on the MLP being created inside ``__init__``; the weight transfer
    matches parameters by this path, so keep both unchanged.

    Args:
        output_sizes: Width of each linear layer. Defaults to ``(16, 21)``.
        name: Optional Haiku module name. ``None`` keeps the default derived
            from the class name.
    """

    def __init__(self, output_sizes=(16, 21), name=None):
        super().__init__(name=name)
        self.fc = hk.nets.MLP(output_sizes, activation=jax.nn.elu)

    def __call__(self, x):
        """Apply the MLP to ``x`` and return the result."""
        return self.fc(x)

We only test one network (`PredictionValue`) here. To transfer the others, instantiate each network in the same way and initialize it with a dummy input of the correct shape.

In [10]:
def _prediction_value_forward(x):
    """Build a ``PredictionValue`` module and run it on ``x``."""
    return PredictionValue()(x)


# hk.transform wraps the module-building function into an object exposing
# ``init`` and ``apply``.
f = hk.transform(_prediction_value_forward)

rng_key = jax.random.PRNGKey(42)
dummy_obs = jnp.zeros((1, 8)) # should be (1, 4) for representation network

# Create the Haiku parameters by tracing the network on the dummy input.
params = f.init(rng_key, dummy_obs)

## Automatic weight transfer

In [11]:
# Convert the initialised parameters into a mutable dict so its entries can be overwritten.
mod_params = hk.data_structures.to_mutable_dict(params)

In [12]:
# GENERAL PURPOSE WEIGHT TRANSFER (works for all networks, assuming you're using the combined state_dict)
# Each Haiku module path in mod_params is looked up in parsed_pytorch_weights.
# Names and shapes are checked first so that a mismatched network fails here
# with a clear message instead of later inside f.apply.
for module_path, haiku_module_params in mod_params.items():
    if module_path not in parsed_pytorch_weights:
        raise KeyError(f"no PyTorch weights found for Haiku module {module_path!r}")

    torch_module_params = parsed_pytorch_weights[module_path]

    for param_name, haiku_value in haiku_module_params.items():
        if param_name not in torch_module_params:
            raise KeyError(
                f"no PyTorch parameter {param_name!r} found for Haiku module {module_path!r}"
            )
        torch_value = torch_module_params[param_name]
        if torch_value.shape != haiku_value.shape:
            raise ValueError(
                f"shape mismatch for {module_path!r}/{param_name!r}: "
                f"PyTorch {torch_value.shape} vs Haiku {haiku_value.shape}"
            )

    # Store a shallow copy so later edits to mod_params do not alter
    # parsed_pytorch_weights.
    mod_params[module_path] = dict(torch_module_params)

## Test network with new parameters

In [13]:
# To validate the transfer, insert a breakpoint into Werner's model to capture an
# input/output pair. This network should produce the same output for the same
# input (up to floating-point rounding).

test_obs = jnp.array([[0.3075, 0.2620, 1.0000, 0.0000, 0.9065, 0.9287, 0.0450, 0.7008]])
# Run the network with the transferred parameters; None is passed in the rng position.
f.apply(mod_params, None, test_obs)

DeviceArray([[-0.32268435, -0.13136506, -0.00575749, -0.12883782,
              -0.14013603,  0.16850004,  0.23229945, -0.02030243,
              -0.09497203,  0.1749499 ,  0.2773396 , -0.634164  ,
              -0.4616676 , -0.23096934,  0.30355474, -0.42916992,
               0.17451528,  0.26546034, -0.06063245, -0.2616744 ,
              -0.24721263]], dtype=float32)

## Manual weight transfer

You probably don't need this, but it's here just in case.

In [14]:
# mod_params['prediction_value/~/mlp/~/linear_0']['w'] = temp_p['module.0.weight'].T
# mod_params['prediction_value/~/mlp/~/linear_0']['b'] = temp_p['module.0.bias']
# mod_params['prediction_value/~/mlp/~/linear_1']['w'] = temp_p['module.2.weight'].T
# mod_params['prediction_value/~/mlp/~/linear_1']['b'] = temp_p['module.2.bias']

In [15]:
# mod_params['prediction_policy/~/mlp/~/linear_0']['w'] = temp_p['module.0.weight'].T
# mod_params['prediction_policy/~/mlp/~/linear_0']['b'] = temp_p['module.0.bias']
# mod_params['prediction_policy/~/mlp/~/linear_1']['w'] = temp_p['module.2.weight'].T
# mod_params['prediction_policy/~/mlp/~/linear_1']['b'] = temp_p['module.2.bias']

In [16]:
# mod_params['dynamics_encoded_state/~/mlp/~/linear_0']['w'] = temp_p['module.0.weight'].T
# mod_params['dynamics_encoded_state/~/mlp/~/linear_0']['b'] = temp_p['module.0.bias']
# mod_params['dynamics_encoded_state/~/mlp/~/linear_1']['w'] = temp_p['module.2.weight'].T
# mod_params['dynamics_encoded_state/~/mlp/~/linear_1']['b'] = temp_p['module.2.bias']

# mod_params['dynamics_reward/~/mlp/~/linear_0']['w'] = temp_p['module.0.weight'].T
# mod_params['dynamics_reward/~/mlp/~/linear_0']['b'] = temp_p['module.0.bias']
# mod_params['dynamics_reward/~/mlp/~/linear_1']['w'] = temp_p['module.2.weight'].T
# mod_params['dynamics_reward/~/mlp/~/linear_1']['b'] = temp_p['module.2.bias']

In [17]:
# mod_params['representation/~/mlp/~/linear_0']['w'] = jnp.array([[ 0.4536,  0.3991, -0.0396,  0.2076],
#         [-0.2074,  0.3541,  0.3898, -0.1555],
#         [ 0.3405, -0.0368, -0.0834,  0.2472],
#         [-0.3188,  0.4464,  0.4505, -0.0251],
#         [-0.2798, -0.2960,  0.2525,  0.1830],
#         [ 0.4745, -0.3207,  0.0425, -0.3299],
#         [ 0.4757, -0.1136, -0.0694,  0.4792],
#         [ 0.1024,  0.0567,  0.2998, -0.2345]]).T
# mod_params['representation/~/mlp/~/linear_0']['b'] = jnp.array([ 0.4359,  0.4771,  0.2069, -0.0961, -0.0662, -0.2580, -0.0164, -0.2669])