model and optimizer.target diff #192

Closed
us opened this issue Apr 9, 2020 · 2 comments

us (Contributor) commented Apr 9, 2020

I am working on the pix2pix example. I don't get an error when I pass the image to the model, but when I pass the same image to generator_optimizer.target I get an error. Why?

Code

import jax
import flax

import numpy as onp
import jax.numpy as jnp


OUTPUT_CHANNELS = 3

class DownSample(flax.nn.Module):
  def apply(self, x, features, size, apply_batchnorm=True):
    x = flax.nn.Conv(x, features=features, kernel_size=(size, size), strides=(2, 2), padding='SAME', bias=False)
    if apply_batchnorm:
      x = flax.nn.BatchNorm(x)
    x = flax.nn.leaky_relu(x)
    return x

class UpSample(flax.nn.Module):
  def apply(self, x, features, size, apply_dropout=True):
    x = flax.nn.ConvTranspose(x, features=features, kernel_size=(size, size), strides=(2, 2), padding='SAME', bias=False)
    x = flax.nn.BatchNorm(x)
    if apply_dropout:
      x = flax.nn.dropout(x, 0.5)
    x = flax.nn.relu(x)
    return x

down_list = [[64, 4, False],
             [128, 4],
             [256, 4],
             [512, 4],
             [512, 4],
             [512, 4],
             [512, 4],
             [512, 4]]

up_list = [[512, 4, True],
           [512, 4, True],
           [512, 4, True],
           [512, 4],
           [256, 4],
           [128, 4],
           [64, 4]]

class Generator(flax.nn.Module):
  def apply(self, x):
    skips = []
    for down in down_list:
      x = DownSample(x, *down)
      skips.append(x)
    
    skips = list(reversed(skips[:-1]))
    for up, skip in zip(up_list, skips):
      x = UpSample(x, *up)
      x = jnp.concatenate((x, skip), axis=-1)  # skip connection: concatenate along the channel axis
    
    x = flax.nn.ConvTranspose(x, features=OUTPUT_CHANNELS, kernel_size=(4,4), strides=(2,2), padding='SAME')
    x = flax.nn.tanh(x)
    return x

def create_model(key, batch_size, image_size, model_def):
  input_shape = (batch_size, image_size, image_size, 3)
  with flax.nn.stateful() as init_state:
    with flax.nn.stochastic(jax.random.PRNGKey(0)):
      _, initial_params = model_def.init_by_shape(key, [(input_shape, jnp.float32)])
      model = flax.nn.Model(model_def, initial_params)
  return model, init_state

def create_optimizer(model, learning_rate, beta):
  optimizer_def = flax.optim.Adam(learning_rate=learning_rate,
                                 beta1=beta)
  optimizer = optimizer_def.create(model)
  optimizer = flax.jax_utils.replicate(optimizer)
  return optimizer

key = jax.random.PRNGKey(0)
generator_model, generator_state = create_model(key, 1, 256, Generator)
generator_optimizer = create_optimizer(generator_model, 2e-4, 0.5)

test_input = jax.random.normal(jax.random.PRNGKey(1), (1, 256, 256, 3))
with flax.nn.stochastic(jax.random.PRNGKey(0)):
  prediction = generator_model(test_input)  # works with no error
  print('prediction ok')
with flax.nn.stochastic(jax.random.PRNGKey(0)):
  prediction_opt = generator_optimizer.target(test_input)  # raises ValueError

Error

ValueError: Existing shape (1, 4, 4, 3, 64) differs from requested shape (4, 4, 3, 64)

Related PR: #186

Thank you!

us added the bug label Apr 9, 2020
levskaya (Collaborator) commented Apr 9, 2020

OK, so if you were just trying to evaluate this on a single device, you'd simply remove this line from create_optimizer:

  optimizer = flax.jax_utils.replicate(optimizer)

The replicate fn broadcasts the params to every "device" in use - on a single GPU that's just 1 (which is a bit silly), but on multi-GPU or TPU it's 4, 8, etc. Notice that you didn't replicate the model above, but you did replicate the optimizer (both hold params inside them for convenience), which is why the second call broke.
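
To see where the (1, 4, 4, 3, 64) in your error comes from, here's a rough sketch (assuming a single local device) of what replicate does to a parameter tree - it stacks one copy per device, adding a leading device axis to every leaf:

import jax
import flax
import jax.numpy as jnp

# A toy parameter tree shaped like the first conv kernel in the Generator.
params = {'Conv_0': {'kernel': jnp.zeros((4, 4, 3, 64))}}
replicated = flax.jax_utils.replicate(params)
print(jax.tree_map(lambda p: p.shape, replicated))
# On one GPU this prints {'Conv_0': {'kernel': (1, 4, 4, 3, 64)}} -
# the "existing shape" in the error, vs. the requested (4, 4, 3, 64).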

Now, assuming you do want to do replicated SPMD computation across multiple devices, you would keep that replicate call in create_optimizer, but you also need to define a model eval function to be pmapped so that it can use the replicated parameters stored across devices, for example:

@jax.pmap
def eval_w_pmap(model, x, prng_key):
  with flax.nn.stochastic(prng_key):
    # model is just a container for replicated params
    return model(x)
ldc = jax.local_device_count()
pmap_test_input = jax.random.normal(jax.random.PRNGKey(1), (ldc, 1, 256, 256, 3))
pmap_rngs = jax.random.split(jax.random.PRNGKey(0), ldc)
eval_w_pmap(generator_optimizer.target, pmap_test_input, pmap_rngs).shape  # (ldc, 1, 256, 256, 3)
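
As a rough sketch of a third option (assuming flax.jax_utils.unreplicate behaves as the inverse of replicate), you can also pull a single copy of the params back out of the replicated optimizer and call it exactly like the model:

# unreplicate strips the leading device axis, giving back an ordinary
# single-device Model that can be called directly.
single_model = flax.jax_utils.unreplicate(generator_optimizer.target)
with flax.nn.stochastic(jax.random.PRNGKey(0)):
  prediction_opt = single_model(test_input)  # the call that previously failed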

I hope that helps explain what's going on - please let me know if it's still not clear!

levskaya self-assigned this Apr 9, 2020
us mentioned this issue Apr 9, 2020
us (Contributor, Author) commented Apr 9, 2020

Thank you @levskaya, this is a very clear explanation!

us closed this as completed Apr 9, 2020