In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load the autoreload extension and reload modules before each cell executes,
# so edits to the project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Standard packages
import os
import torch
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('darkgrid')
from copy import deepcopy

import scipy
from scipy import special as s
import pandas as pd
import numpy
import math
from time import time
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Locate the project directories relative to the kernel's working directory.
cwd = os.getcwd()
# With the trailing '/' the inner dirname gives back cwd, so the outer dirname
# yields the directory that contains cwd.
NOTEBOOK_DIR = os.path.dirname(os.path.dirname(f"{cwd}/"))
# One level further up; this is the path later appended to sys.path.
ROOT = os.path.dirname(NOTEBOOK_DIR)

In [4]:
# Custom packages
import sys
sys.path.append(ROOT)
from utils.plot import *
from utils.tools import *
from utils.nn import TARGET_FUNCS_DICT, LOSS_DICT
from networks.muP_resnet import MuPResNet
from layers.jax.residual import Residual

In [5]:
# JAX packages
import jax
from jax import jit
import haiku as hk
import jax.numpy as jnp

# Constants

In [6]:
# Settings passed to the network (WIDTH, BIAS, ACTIVATION are used when the
# Residual block is built). Trailing notes list alternative values left by the author.
INPUT_DIM = 64        # alternative: 30
WIDTH = 256           # alternatives: 40, 512
# D_MODEL = 128       (disabled)
N_RES = 500           # alternative: 1000
BIAS = False
ALPHA = 1.0
SCALE = 1.0
ACTIVATION = 'relu'

# Seed and run settings (most are not referenced in the cells visible here).
SEED = 42
BATCH_SIZE = 64
# N_TRIALS = 10       (disabled)
BASE_LR = 0.01
N_STEPS = 4_000
N_VAL = 500
VAL_ITER = 50

# Reciprocal of N_RES; this lowercase `alpha` (not ALPHA) is the value handed
# to the Residual block.
alpha = 1 / N_RES

In [7]:
# Build the root PRNG key from the SEED constant instead of repeating the
# literal 42, so that changing SEED actually changes the randomness below.
RNG_KEY = jax.random.PRNGKey(SEED)
# Two sub-keys: key_0 is used to draw the sample input, key_1 to initialise
# the network parameters.
key_0, key_1 = jax.random.split(key=RNG_KEY, num=2)

# Net

In [8]:
from layers.jax.residual import Residual

In [9]:
def _forward_fn(x):
    """Apply one `Residual` block to `x`; meant to be wrapped by `hk.transform`.

    The block is configured from the notebook-level settings: `WIDTH` is
    passed as both `d` and `width`, together with `ACTIVATION`, `BIAS` and
    the lowercase `alpha`.
    """
    residual_block = Residual(
        d=WIDTH,
        width=WIDTH,
        activation=ACTIVATION,
        bias=BIAS,
        alpha=alpha,
    )
    return residual_block(x)

In [10]:
# Turn the module-building function into an (init, apply) pair; wrapping it in
# without_apply_rng means apply() is called without an rng argument.
forward_fn = hk.without_apply_rng(hk.transform(_forward_fn))

In [11]:
# Draw one standard-normal input vector of length WIDTH using the first sub-key.
x = jax.random.normal(key=key_0, shape=(WIDTH,))

In [12]:
# Initialise the network parameters on the sample input `x` with the second sub-key.
params = forward_fn.init(rng=key_1, x=x)
# NOTE(review): the saved output of this cell is a TypeError about the PRNG key
# being a list, which does not match how key_1 is created above — presumably
# stale kernel state; confirm by re-running from a fresh kernel.
params

TypeError: unexpected PRNG key type <class 'list'>

In [None]:
# List the public names exposed by hk.initializers. This replaces an unfinished
# `hk.initializers.` attribute access, which is a SyntaxError and would stop
# a Restart & Run All at this cell.
[name for name in dir(hk.initializers) if not name.startswith('_')]

In [23]:
# Forward pass with the initialised parameters; no rng is passed because
# forward_fn was wrapped with hk.without_apply_rng.
output = forward_fn.apply(x=x, params=params)
print(output.shape)
# Display the full output array as the cell result.
output

(256,)


DeviceArray([ 0.45846346,  0.06503269, -0.07381058, -0.26981813,
              0.6445098 , -0.40612713,  3.600132  ,  0.54357916,
             -0.3267934 ,  2.1096427 ,  1.3077282 ,  1.1968408 ,
             -1.3792636 ,  1.9957889 , -1.6408458 ,  1.0788362 ,
             -0.02287359,  0.882253  ,  0.48119825,  0.1783561 ,
              0.30261105,  0.80525875,  0.6288748 , -0.24098386,
             -1.0255009 ,  0.7501734 , -0.19838047,  0.07536075,
              0.66395146, -0.6117503 , -0.6955691 , -0.4440646 ,
             -1.7750372 ,  0.02283927,  0.03941284,  0.35539868,
             -0.47521847, -0.9878623 , -0.24215735, -1.0723262 ,
             -0.9975525 ,  0.22574413,  1.4150982 ,  1.5500433 ,
             -0.12253696,  0.20015198,  0.6174063 ,  0.23970821,
              0.9243334 ,  1.8405286 ,  0.88899297,  0.39620674,
             -1.521879  ,  0.29617462,  1.521638  , -0.34095857,
              0.24158362, -0.52253723, -0.23626004,  0.92786276,
             -0.6374453 ,

In [24]:
class MyLinear1(hk.Module):
    """Affine layer `x @ w + b` with explicitly declared Haiku parameters.

    Args:
        output_size: size of the last axis of the output.
        name: optional Haiku module name.
    """

    def __init__(self, output_size, name=None):
        super().__init__(name=name)
        self.output_size = output_size

    def __call__(self, x):
        """Return `jnp.dot(x, w) + b`, declaring `w` and `b` via `hk.get_parameter`.

        `w` has shape `[x.shape[-1], output_size]` and is initialised from a
        truncated normal with stddev `1 / sqrt(x.shape[-1])`; `b` has shape
        `[output_size]` and is initialised with `jnp.ones`. Both use `x.dtype`.
        """
        j, k = x.shape[-1], self.output_size
        # The visible imports only contain `import numpy`, never `np`, so the
        # original `np.sqrt` could only work through a star import. `j` is a
        # static Python int, so the stdlib square root is sufficient here.
        w_init = hk.initializers.TruncatedNormal(1. / math.sqrt(j))
        w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
        b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.ones)
        return jnp.dot(x, w) + b

In [41]:
# NOTE(review): same value as the `alpha` assigned in the constants cell. The
# module in this cell is constructed as `MyModuleCustom()` with no arguments,
# so its own default alpha=1.0 is used rather than this value — confirm intent.
alpha = 1 / N_RES

class MyModuleCustom(hk.Module):
    """Residual block of two bias-free linear layers: `x + alpha * W2(W1(x))`.

    Args:
        d: output size of both linear layers. It should match the size of the
            last axis of the input so the skip connection can be added.
        name: Haiku module name.
        alpha: multiplier applied to the output of the two stacked layers
            before it is added back to the input.
    """

    def __init__(self, d=WIDTH, name='custom_linear', alpha=1.0):
        super().__init__(name=name)
        self.alpha = alpha
        # Size the layers from `d`. The argument used to be ignored in favour
        # of the global WIDTH, which remains the default value.
        self.first_layer = hk.Linear(output_size=d, with_bias=False, name='first_layer')
        # The module name 'seccond_layer' (sic) is kept verbatim because it is
        # part of the keys of the parameter dictionary.
        self.second_layer = hk.Linear(output_size=d, with_bias=False, name='seccond_layer')

    def __call__(self, x):
        """Return `x` plus `alpha` times the output of the two stacked layers."""
        return x + self.alpha * self.second_layer(self.first_layer(x))

def _custom_forward_fn(x):
    """Build a `MyModuleCustom` with its default arguments and apply it to `x`."""
    return MyModuleCustom()(x)

# One example row [0, 1, ..., WIDTH-1] cast to float, shaped (1, WIDTH).
sample_x = jnp.arange(WIDTH).reshape(1,-1).astype(float)
custom_forward_without_rng = hk.without_apply_rng(hk.transform(_custom_forward_fn))
# `rng_key` is not defined in any visible cell (NameError on a fresh kernel);
# use the RNG_KEY created near the top of the notebook instead.
# NOTE: this rebinds `params`, replacing the parameters of the Residual network.
params = custom_forward_without_rng.init(rng=RNG_KEY, x=sample_x)
params

{'custom_linear/~/first_layer': {'w': DeviceArray([[ 0.10137872, -0.06849594, -0.05502961, ..., -0.0580392 ,
                -0.08042663, -0.00842963],
               [-0.02619505,  0.08058587, -0.04703137, ..., -0.04650393,
                -0.09204443, -0.09061009],
               [ 0.07700656, -0.02135302,  0.01246158, ..., -0.02489848,
                 0.01373748,  0.04003981],
               ...,
               [-0.10954096,  0.03559806, -0.00922422, ..., -0.00256604,
                -0.08048458, -0.0076758 ],
               [-0.01473674, -0.02222958,  0.00128034, ...,  0.09191827,
                -0.06019867, -0.04916775],
               [ 0.00722767, -0.01513949, -0.04333597, ..., -0.05536818,
                -0.02729566, -0.06372111]], dtype=float32)},
 'custom_linear/~/seccond_layer': {'w': DeviceArray([[ 0.07734542, -0.06559161,  0.01210385, ..., -0.02081679,
                 0.00435181, -0.05712689],
               [-0.02065541,  0.05118439,  0.02106534, ...,  0.08045646,
   

In [38]:
# Shape check: a single row of WIDTH features.
sample_x.shape

(1, 256)