objax package

objax

Modules

Module, ModuleList, ForceArgs, Function, Grad, GradValues, Jit, Parallel, Vectorize

Module

ModuleList

See also: objax.nn.Sequential.

Usage example:

import objax

ml = objax.ModuleList(['hello', objax.TrainVar(objax.random.normal((10,2)))])
print(ml.vars())
# (ModuleList)[1]            20 (10, 2)
# +Total(1)                  20

ml.pop()
ml.append(objax.nn.Linear(2, 3))
print(ml.vars())
# (ModuleList)[1](Linear).b        3 (3,)
# (ModuleList)[1](Linear).w        6 (2, 3)
# +Total(2)                        9

Function

Usage example:

import objax
import jax.numpy as jn

m = objax.nn.Linear(2, 3)

def f1(x, y):
    return ((m(x) - y) ** 2).mean()

# Method 1: Create a module by calling objax.Function, specifying which variables are used.
m1 = objax.Function(f1, m.vars())

# Method 2: Use the decorator.
@objax.Function.with_vars(m.vars())
def f2(x, y):
    return ((m(x) - y) ** 2).mean()

# All behave like functions
x = jn.arange(10).reshape((5, 2))
y = jn.arange(15).reshape((5, 3))
print(type(f1), f1(x, y))  # <class 'function'> 237.01947
print(type(m1), m1(x, y))  # <class 'objax.module.Function'> 237.01947
print(type(f2), f2(x, y))  # <class 'objax.module.Function'> 237.01947

Usage of Function is not necessary: it is made available for aesthetic reasons (to accommodate users' personal taste). It is also used internally to keep the code simple for Grad, Jit, Parallel, Vectorize, and future primitives.

with_vars

Usage example:

import objax

m = objax.nn.Linear(2, 3)

@objax.Function.with_vars(m.vars())
def f(x, y):
    return ((m(x) - y) ** 2).mean()

print(type(f))  # <class 'objax.module.Function'>

auto_vars
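
Unlike with_vars, auto_vars discovers the variables used by the function automatically, so they do not have to be passed explicitly.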

Usage example:

import objax

m = objax.nn.Linear(2, 3)

@objax.Function.auto_vars
def f(x, y):
    return ((m(x) - y) ** 2).mean()

print(type(f))  # <class 'objax.module.Function'>

ForceArgs

One example of ForceArgs usage is to override the training argument of batch normalization:

import objax
from objax.zoo.resnet_v2 import ResNet50

model = ResNet50(in_channels=3, num_classes=1000)

# Modify model to force training=False on the first resnet block.
# The first two ops in the resnet are convolution and padding;
# resnet blocks start at index 2.
model[2] = objax.ForceArgs(model[2], training=False)

# model(x, training=True) will use training=False on model[2] due to ForceArgs
# ...

# Undo specific value of forced arguments in `model` and all submodules of `model`
objax.ForceArgs.undo(model, training=True)

# Undo all values of specific argument in `model` and all submodules of `model`
objax.ForceArgs.undo(model, training=objax.ForceArgs.ANY)

# Undo all values of all arguments in `model` and all submodules of `model`
objax.ForceArgs.undo(model)

Grad

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])

def f(x, y):
    return ((m(x) - y) ** 2).mean()

# Create module to compute gradients of f for m.vars()
grad_f = objax.Grad(f, m.vars())

# Create module to compute gradients of f for input 0 (x) and m.vars()
grad_fx = objax.Grad(f, m.vars(), input_argnums=(0,))
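
Calling the resulting modules returns a list of gradients. A minimal sketch, assuming example inputs x and y of matching shapes:

import jax.numpy as jn

x = jn.zeros((4, 2))
y = jn.zeros((4, 3))
g = grad_f(x, y)    # list of gradients, one per variable in m.vars()
gx = grad_fx(x, y)  # gradient for input x first, then the variable gradients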

For more information and examples, see Understanding Gradients.

GradValues

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu, objax.nn.Linear(3, 2)])

@objax.Function.with_vars(m.vars())
def f(x, y):
    return ((m(x) - y) ** 2).mean()

# Create module to compute gradients of f for m.vars()
grad_val_f = objax.GradValues(f, m.vars())

# Create module to compute gradients of f for only some variables
grad_val_f_head = objax.GradValues(f, m[:1].vars())

# Create module to compute gradients of f for input 0 (x) and m.vars()
grad_val_fx = objax.GradValues(f, m.vars(), input_argnums=(0,))
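
Calling a GradValues module returns both the gradients and the function's outputs. A minimal sketch, assuming example inputs x and y:

import jax.numpy as jn

x = jn.zeros((4, 2))
y = jn.zeros((4, 2))
g, v = grad_val_f(x, y)  # g: list of gradients, v: list of outputs (here [loss])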

For more information and examples, see Understanding Gradients.

Jit

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
jit_m = objax.Jit(m)                          # Jit a module

# For jitting functions, use objax.Function.with_vars
@objax.Function.with_vars(m.vars())
def f(x):
    return m(x) + 1

jit_f = objax.Jit(f)

For more information, refer to JIT Compilation. Also note that one can pass an explicit collection of variables to Jit for a module `m`: the remaining variables are optimized away as constants. For more information, refer to Constant optimization.
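
As a minimal sketch of constant optimization, passing only the trainable variables of a module treats all its other variables as compile-time constants:

jit_train_only = objax.Jit(m, m.vars().subset(is_a=objax.TrainVar))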

Parallel

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
para_m = objax.Parallel(m)                         # Parallelize a module

# For parallelizing functions, use objax.Function.with_vars
@objax.Function.with_vars(m.vars())
def f(x):
    return m(x) + 1

para_f = objax.Parallel(f)

When calling a parallelized module, one must replicate the variables it uses on all devices:

x = objax.random.normal((16, 2))
with m.vars().replicate():
    y = para_m(x)

For more information, refer to Parallelism. Also note that one can pass an explicit collection of variables to Parallel for a module `m`: the remaining variables are optimized away as constants. For more information, refer to Constant optimization.

Vectorize

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
vec_m = objax.Vectorize(m)                         # Vectorize a module

# For vectorizing functions, use objax.Function.with_vars
@objax.Function.with_vars(m.vars())
def f(x):
    return m(x) + 1

vec_f = objax.Vectorize(f)
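
A minimal calling sketch, assuming a batch of inputs along axis 0 (the default batch axis):

x = objax.random.normal((8, 2))  # batch of 8 examples of shape (2,)
y = vec_m(x)                     # m is applied to each example; y has shape (8, 3)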

For more information and examples, refer to Vectorization.

Variables

BaseVar, TrainVar, BaseState, StateVar, TrainRef, RandomState, VarCollection

BaseVar

TrainVar
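
A minimal sketch: a TrainVar wraps a tensor that optimizers will train; its current content is read through .value:

import objax
import jax.numpy as jn

v = objax.TrainVar(jn.zeros(3))  # a trainable variable
print(v.value)                   # [0. 0. 0.]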

BaseState

TrainRef
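
A minimal sketch: a TrainRef is a mutable reference to a TrainVar; this is how optimizers update weights:

import objax
import jax.numpy as jn

v = objax.TrainVar(jn.zeros(3))
ref = objax.TrainRef(v)  # reference through which the TrainVar can be updated
ref.value += 1           # updates the underlying TrainVar
print(v.value)           # [1. 1. 1.]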

StateVar
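
A minimal sketch: a StateVar holds non-trainable state (such as running statistics) and can be assigned directly:

import objax
import jax.numpy as jn

s = objax.StateVar(jn.zeros(3))
s.value += 1    # direct assignment is allowed for state variables
print(s.value)  # [1. 1. 1.]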

RandomState

VarCollection

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
vc = m.vars()  # This is a VarCollection

# It is a dictionary
print(repr(vc))
# {'(Sequential)[0](Linear).b': <objax.variable.TrainVar object at 0x7faecb506390>,
#  '(Sequential)[0](Linear).w': <objax.variable.TrainVar object at 0x7faec81ee350>}
print(vc.keys())  # dict_keys(['(Sequential)[0](Linear).b', '(Sequential)[0](Linear).w'])
assert (vc['(Sequential)[0](Linear).w'].value == m[0].w.value).all()

# Convenience print
print(vc)
# (Sequential)[0](Linear).b        3 (3,)
# (Sequential)[0](Linear).w        6 (2, 3)
# +Total(2)                        9

# Extra methods for manipulation of variables:
# For example, increment all variables by 1
vc.assign([x+1 for x in vc.tensors()])

# It's used by other modules.
# For example, it's used to tell which variables a function uses.

@objax.Function.with_vars(vc)
def my_function(x):
    return objax.functional.softmax(m(x))

For more information and examples, refer to VarCollection.

assign
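
As used in the example above, assign replaces the content of every variable in the collection from a list of tensors:

import objax

vc = objax.nn.Linear(2, 3).vars()
vc.assign([t * 0 for t in vc.tensors()])  # zero out all variables in the collection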

rename

Renaming entries in a VarCollection is a powerful tool that can be used for

  • mapping weights between models that differ slightly.
  • loading data checkpoints from foreign ML frameworks.

Usage example:

import re
import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
print(m.vars())
# (Sequential)[0](Linear).b        3 (3,)
# (Sequential)[0](Linear).w        6 (2, 3)
# +Total(2)                        9

# For example, remove module names from the variable names
renamer = objax.util.Renamer([(re.compile(r'\([^)]+\)'), '')])
print(m.vars().rename(renamer))
# [0].b                       3 (3,)
# [0].w                       6 (2, 3)
# +Total(2)                   9

# Renamers can be chained; their syntax is flexible and can use a string mapping:
renamer_all = objax.util.Renamer({'[': '.', ']': ''}, renamer)
print(m.vars().rename(renamer_all))
# .0.b                        3 (3,)
# .0.w                        6 (2, 3)
# +Total(2)                   9

replicate
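
replicate is the context manager shown in the Parallel example above: it copies the collection's variables to all available devices for the duration of the with block, and reduces them back to single values when the block exits.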

subset
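
A minimal sketch: subset filters a collection by variable type, for example to keep only the trainable variables:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.nn.BatchNorm0D(3)])
train_vars = m.vars().subset(is_a=objax.TrainVar)  # TrainVar entries only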

tensors
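
As used in the examples above, tensors returns the current value of every variable in the collection:

import objax

vc = objax.nn.Linear(2, 3).vars()
values = vc.tensors()          # list of JAX arrays, one per variable
assert len(values) == len(vc)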

update

Constants

ConvPadding
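
ConvPadding is the enum of padding options (SAME, VALID) accepted by convolution layers. A minimal sketch:

import objax

# SAME preserves the spatial size; VALID applies no padding.
conv = objax.nn.Conv2D(nin=3, nout=16, k=3, padding=objax.ConvPadding.SAME)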