
Haiku needs all hk.Module must be initialized inside an hk.transform #53

Closed
EloyAnguiano opened this issue Jan 15, 2024 · 1 comment
EloyAnguiano commented Jan 15, 2024

Hi. I am trying to run the graphcast model in a conda environment built with the same package versions as a working Google Colab execution, but building the model in the construct_wrapped_graphcast function raises this error:

Traceback (most recent call last):
  File "/home/eloy.anguiano/repos/graphcast/0.get_model.py", line 76, in <module>
    model = construct_wrapped_graphcast(model_config, task_config)
  File "/home/eloy.anguiano/repos/graphcast/0.get_model.py", line 58, in construct_wrapped_graphcast
    predictor = graphcast.GraphCast(model_config, task_config)
  File "/home/eloy.anguiano/repos/graphcast/graphcast/graphcast.py", line 261, in __init__
    self._grid2mesh_gnn = deep_typed_graph_net.DeepTypedGraphNet(
  File "/home/eloy.anguiano/miniconda3/envs/graphcast_iic/lib/python3.10/site-packages/haiku/_src/module.py", line 139, in __call__
    init(module, *args, **kwargs)
  File "/home/eloy.anguiano/miniconda3/envs/graphcast_iic/lib/python3.10/site-packages/haiku/_src/module.py", line 433, in wrapped
    raise ValueError(
ValueError: All `hk.Module`s must be initialized inside an `hk.transform`.

I checked that both dm-haiku versions (Colab and local) are 0.0.11. Is there a Dockerfile or something similar for building a working environment? It is very difficult to reproduce the Colab environment locally.

How to reproduce:

from google.cloud import storage
from graphcast import autoregressive
from graphcast import casting
from graphcast import checkpoint
from graphcast import graphcast
from graphcast import normalization
import xarray


MODEL_VERSION = 'GraphCast_operational - ERA5-HRES 1979-2021 - resolution 0.25 - pressure levels 13 - mesh 2to6 - precipitation output only.npz'

# @title Authenticate with Google Cloud Storage
gcs_client = storage.Client.create_anonymous_client()
gcs_bucket = gcs_client.get_bucket("dm_graphcast")

with gcs_bucket.blob(f"params/{MODEL_VERSION}").open("rb") as f:
    ckpt = checkpoint.load(f, graphcast.CheckPoint)
params = ckpt.params
state = {}

model_config = ckpt.model_config
task_config = ckpt.task_config
print("Model description:\n", ckpt.description, "\n")
print("Model license:\n", ckpt.license, "\n")



with gcs_bucket.blob("stats/diffs_stddev_by_level.nc").open("rb") as f:
  diffs_stddev_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob("stats/mean_by_level.nc").open("rb") as f:
  mean_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob("stats/stddev_by_level.nc").open("rb") as f:
  stddev_by_level = xarray.load_dataset(f).compute()

def construct_wrapped_graphcast(
    model_config: graphcast.ModelConfig,
    task_config: graphcast.TaskConfig):
    """Constructs and wraps the GraphCast Predictor."""
    # Deeper one-step predictor.
    predictor = graphcast.GraphCast(model_config, task_config)

    # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion
    # from/to float32 to/from BFloat16.
    predictor = casting.Bfloat16Cast(predictor)

    # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from
    # BFloat16 happens after applying normalization to the inputs/targets.
    predictor = normalization.InputsAndResiduals(
        predictor,
        diffs_stddev_by_level=diffs_stddev_by_level,
        mean_by_level=mean_by_level,
        stddev_by_level=stddev_by_level)

    # Wraps everything so the one-step model can produce trajectories.
    predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)
    return predictor

model = construct_wrapped_graphcast(model_config, task_config)
print("Done")
EloyAnguiano changed the title from "Haiku needs all hk.Modules must be initialized inside an hk.transform" to "Haiku needs all hk.Module must be initialized inside an hk.transform" on Jan 15, 2024
alvarosg (Collaborator) commented:

Thanks for your message. This is expected behavior in Haiku: as the error says, all hk.Modules must be initialized inside an hk.transform, and GraphCast is a Haiku module.

See this bit of code in the GraphCast demo:

@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
  predictor = construct_wrapped_graphcast(model_config, task_config)
  return predictor(inputs, targets_template=targets_template, forcings=forcings)

which contains an example of how to use hk.transform.
