Hi. I am trying to execute the GraphCast model in a conda environment built with the same package versions as a working execution on Google Colab, but whenever I try to build the model, the construct_wrapped_graphcast function returns this error:
Traceback (most recent call last):
File "/home/eloy.anguiano/repos/graphcast/0.get_model.py", line 76, in <module>
model = construct_wrapped_graphcast(model_config, task_config)
File "/home/eloy.anguiano/repos/graphcast/0.get_model.py", line 58, in construct_wrapped_graphcast
predictor = graphcast.GraphCast(model_config, task_config)
File "/home/eloy.anguiano/repos/graphcast/graphcast/graphcast.py", line 261, in __init__
self._grid2mesh_gnn = deep_typed_graph_net.DeepTypedGraphNet(
File "/home/eloy.anguiano/miniconda3/envs/graphcast_iic/lib/python3.10/site-packages/haiku/_src/module.py", line 139, in __call__
init(module, *args, **kwargs)
File "/home/eloy.anguiano/miniconda3/envs/graphcast_iic/lib/python3.10/site-packages/haiku/_src/module.py", line 433, in wrapped
raise ValueError(
ValueError: All `hk.Module`s must be initialized inside an `hk.transform`.
I checked that both dm-haiku versions (Colab and local) are 0.0.11. Is there any Dockerfile to build a working environment or something like that? It is very difficult to reproduce the Colab environment locally.
How to reproduce:
from google.cloud import storage

from graphcast import autoregressive
from graphcast import casting
from graphcast import checkpoint
from graphcast import graphcast
from graphcast import normalization
import xarray

MODEL_VERSION = 'GraphCast_operational - ERA5-HRES 1979-2021 - resolution 0.25 - pressure levels 13 - mesh 2to6 - precipitation output only.npz'

# @title Authenticate with Google Cloud Storage
gcs_client = storage.Client.create_anonymous_client()
gcs_bucket = gcs_client.get_bucket("dm_graphcast")

with gcs_bucket.blob(f"params/{MODEL_VERSION}").open("rb") as f:
  ckpt = checkpoint.load(f, graphcast.CheckPoint)
params = ckpt.params
state = {}
model_config = ckpt.model_config
task_config = ckpt.task_config
print("Model description:\n", ckpt.description, "\n")
print("Model license:\n", ckpt.license, "\n")

with gcs_bucket.blob("stats/diffs_stddev_by_level.nc").open("rb") as f:
  diffs_stddev_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob("stats/mean_by_level.nc").open("rb") as f:
  mean_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob("stats/stddev_by_level.nc").open("rb") as f:
  stddev_by_level = xarray.load_dataset(f).compute()


def construct_wrapped_graphcast(
    model_config: graphcast.ModelConfig,
    task_config: graphcast.TaskConfig):
  """Constructs and wraps the GraphCast Predictor."""
  # Deeper one-step predictor.
  predictor = graphcast.GraphCast(model_config, task_config)

  # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion
  # from/to float32 to/from BFloat16.
  predictor = casting.Bfloat16Cast(predictor)

  # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from
  # BFloat16 happens after applying normalization to the inputs/targets.
  predictor = normalization.InputsAndResiduals(
      predictor,
      diffs_stddev_by_level=diffs_stddev_by_level,
      mean_by_level=mean_by_level,
      stddev_by_level=stddev_by_level)

  # Wraps everything so the one-step model can produce trajectories.
  predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)
  return predictor


model = construct_wrapped_graphcast(model_config, task_config)
print("Done")
EloyAnguiano changed the title to "Haiku needs all hk.Module must be initialized inside an hk.transform" on Jan 15, 2024.
Thanks for your message. This is expected behaviour in Haiku: as the error says, all hk.Modules must be initialized inside an hk.transform, and GraphCast is a Haiku module. The reproduction above calls construct_wrapped_graphcast (and therefore graphcast.GraphCast) at the top level, outside of any transform, which is exactly what Haiku forbids.