This repository was archived by the owner on Oct 31, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 826
This repository was archived by the owner on Oct 31, 2025. It is now read-only.
Issue using L2Loss for 2-D regression problem 🐛 #980
Copy link
Copy link
Open
Description
Description
The ranks of data batches are different between the training and evaluation tasks.
In the output below note the following: during the training task, the (model, target, weights) arrays given to the L2Loss object have shapes:
Model (256, 2)
Targets (256, 2)
Weights (256, 2)
(note, I'm trying to do a 2-D regression).
However, during the evaluation task the (model, target, weights) given to the L2Loss object have shapes
Model (4, 2)
Targets (4,)
Weights (4,)
which breaks the evaluator. I would expect all three of these to have shape (4, 2) (or, more generally, (N, 2)), i.e., the same behavior as in the training task.
Environment information
Linux version 4.15.0-112-generic (buildd@lcy01-amd64-027)
(gcc version 7.5.0 (Ubuntu 7.5.0-3ubuntu1~18.04))
#113-Ubuntu SMP Thu Jul 9 23:41:39 UTC 2020
$ pip freeze | grep trax
trax==1.3.4
$ pip freeze | grep tensor
mesh-tensorflow==0.1.16
tensor2tensor==1.15.7
tensorboard==2.3.0
tensorboard-plugin-wit==1.7.0
tensorflow==2.3.0
tensorflow-addons==0.11.2
tensorflow-datasets==3.2.1
tensorflow-estimator==2.3.0
tensorflow-gan==2.0.0
tensorflow-hub==0.8.0
tensorflow-metadata==0.23.0
tensorflow-probability==0.7.0
tensorflow-text==2.3.0
$ pip freeze | grep jax
jax==0.1.75
jaxlib==0.1.52
$ python -V
Python 3.6.9
For bugs: reproduction and error logs
import os
import trax
from trax import layers as tl
from trax.supervised import training
import numpy
import random
#train_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()
#eval_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()
def generate_samples(n_samples=1024 * 8):
    """Yield randomly chosen ``(text, lat/lon)`` examples.

    Args:
        n_samples: Number of examples to yield before the generator is
            exhausted. Defaults to 8192, the count that used to be hard-coded.

    Yields:
        ``(place_name, coords)`` tuples, where ``coords`` is a NumPy array
        holding two floats (lat, lon). Each example is drawn with
        ``random.choice`` from a fixed table of seven places, so repeats
        are expected.
    """
    # (text, lat/lon)
    data = [
        ("Aberdeen MS", numpy.array((33.824742, -88.554591))),
        ("Aberdeen SD", numpy.array((45.463186, -98.471033))),
        ("Aberdeen WA", numpy.array((46.976432, -123.795781))),
        ("Amite City LA", numpy.array((30.733723, -90.5208))),
        ("Amory MS", numpy.array((33.984789, -88.48001))),
        ("Amouli AS", numpy.array((-14.26556, -170.589772))),
        ("Amsterdam NY", numpy.array((42.953149, -74.19505))),
    ]
    for _ in range(n_samples):
        yield random.choice(data)
# Two independent sample generators: one for training, one for evaluation.
train_stream, eval_stream = generate_samples(), generate_samples()

# Embedding -> mean over axis 1 (length of sentence) -> Dense(2) regressing
# to lat/lon.
model = tl.Serial(
    tl.Embedding(vocab_size=8192, d_feature=256),
    tl.Mean(axis=1),
    tl.Dense(2),
)

# Show the model structure, then one raw example.
# Note: next() consumes that example from train_stream.
print(model)
print(next(train_stream))
# Input pipeline applied to both streams. Element 0 of each example (the
# text) is the key used for tokenization and for length bucketing.
data_pipeline = trax.data.Serial(
    trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
    trax.data.Shuffle(),
    trax.data.BucketByLength(
        boundaries=[8, 128],
        batch_sizes=[256, 64, 4],
        length_keys=[0],
    ),
    trax.data.AddLossWeights(),
)
train_batches_stream = data_pipeline(train_stream)
eval_batches_stream = data_pipeline(eval_stream)

# Check the shapes of one batch from each stream, training first. map() is
# lazy, so each batch is pulled (and consumed) right before it is printed.
for example_batch in map(next, (train_batches_stream, eval_batches_stream)):
    print(f'shapes = {[x.shape for x in example_batch]}')
from trax.fastmath import numpy as jnp
from trax.layers.base import Fn
def SizeReportL2Loss():
    """Returns an 'L2Loss' layer that prints its input shapes, then the loss.

    The layer wraps a three-argument function with `Fn`; it is a debugging
    variant of an L2 loss that reports the shapes it is given.
    """

    def _weighted_l2(model_output, targets, weights):
        """Returns the weighted squared error summed and divided by sum(weights).

        Args:
          model_output: Output from one batch.
          targets: Target values; must have the same shape as `model_output`.
          weights: Element-wise weights; must have the same shape as `targets`.
        """
        # Report shapes before the assertions so a mismatch is visible in the
        # output even when the shape check below raises.
        for label, tensor in (("Model", model_output),
                              ("Targets", targets),
                              ("Weights", weights)):
            print(label, tensor.shape)
        trax.shapes.assert_same_shape(model_output, targets)
        trax.shapes.assert_same_shape(targets, weights)
        squared_error = (model_output - targets)**2
        return jnp.sum(weights * squared_error) / jnp.sum(weights)

    return Fn('L2Loss', _weighted_l2)
# Training task: minimize the shape-reporting L2 loss with Adam.
train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=SizeReportL2Loss(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=500,
)

# Evaluation task: report the same L2 metric on the evaluation stream.
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[SizeReportL2Loss()],
    n_eval_batches=20,  # For less variance in eval numbers.
)

# Training loop saves checkpoints to output_dir.
# NOTE(review): the reported traceback fails inside Loop.__init__ ->
# load_checkpoint -> init_from_file, with input shapes int32[4,1024],
# int32[4], float32[4]. This pipeline produces (256, 8)/(256, 2)/(256, 2)
# batches, so those shapes most likely come from a checkpoint left in a
# shared '~/output_dir/' by an earlier experiment. Use a directory dedicated
# to this experiment so no stale checkpoint is restored; delete it when the
# model or data shapes change. TODO confirm against the installed trax version.
output_dir = os.path.expanduser('~/output_dir_l2_regression/')
training_loop = training.Loop(
    model,
    train_task,
    eval_tasks=[eval_task],
    output_dir=output_dir,
)

# Run 8 steps (batches).
training_loop.run(8)
Error logs:
2020-08-27 21:21:08.059843: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcudart.so.10.1'; dlerror: libcudart.so.10.1: cannot open shared object file: No such file or directory
Serial[
Embedding_8192_256
Mean
Dense_2
]
('Aberdeen WA', array([ 46.976432, -123.795781]))
2020-08-27 21:21:10.585194: W tensorflow/core/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "Not found: Could not locate the credentials file.". Retrieving token from GCE failed with "Failed precondition: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Couldn't resolve host 'metadata'".
shapes = [(256, 8), (256, 2), (256, 2)]
shapes = [(256, 8), (256, 2), (256, 2)]
/home/dmcnamara/tfhub-env/lib/python3.6/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
Model (256, 2)
Targets (256, 2)
Weights (256, 2)
Model (4, 2)
Targets (4,)
Weights (4,)
Traceback (most recent call last):
File "trax04.py", line 119, in <module>
output_dir=output_dir)
File "/home/dmcnamara/tfhub-env/lib/python3.6/site-packages/trax/supervised/training.py", line 187, in __init__
self.load_checkpoint()
File "/home/dmcnamara/tfhub-env/lib/python3.6/site-packages/trax/supervised/training.py", line 588, in load_checkpoint
self._model_in_training.init_from_file(path)
File "/home/dmcnamara/tfhub-env/lib/python3.6/site-packages/trax/layers/base.py", line 311, in init_from_file
weights_and_state_sig = self.weights_and_state_signature(input_signature)
File "/home/dmcnamara/tfhub-env/lib/python3.6/site-packages/trax/layers/base.py", line 443, in weights_and_state_signature
return abstract_init(input_signature)
File "/home/dmcnamara/tfhub-env/lib/python3.6/site-packages/trax/fastmath/jax.py", line 310, in shape_fun
jax_shapes = jax.eval_shape(f, *args, **kwargs)
File "/home/dmcnamara/tfhub-env/lib/python3.6/site-packages/jax/api.py", line 1753, in eval_shape
*map(abstractify, args_flat))
File "/home/dmcnamara/tfhub-env/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 324, in abstract_eval_fun
instantiate=True, stage_out=True)
File "/home/dmcnamara/tfhub-env/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 423, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/dmcnamara/tfhub-env/lib/python3.6/site-packages/jax/linear_util.py", line 150, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/dmcnamara/tfhub-env/lib/python3.6/site-packages/jax/linear_util.py", line 150, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/dmcnamara/tfhub-env/lib/python3.6/site-packages/trax/layers/base.py", line 288, in init
input_signature, trace) from None
trax.layers.base.LayerError: Exception passing through layer Serial (in init):
layer created in file [...]/trax/supervised/training.py, line 146
layer input shapes: (Traced<ShapedArray(int32[4,1024]):JaxprTrace(level=0/0)>, Traced<ShapedArray(int32[4]):JaxprTrace(level=0/0)>, Traced<ShapedArray(float32[4]):JaxprTrace(level=0/0)>)
File [...]/trax/layers/combinators.py, line 105, in init_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)
LayerError: Exception passing through layer L2Loss (in _forward_abstract):
layer created in file [...]/trax/layers/base.py, line 704
layer input shapes: (ShapeDtype{shape:(4, 2), dtype:float32}, Traced<ShapedArray(int32[4]):JaxprTrace(level=0/0)>, Traced<ShapedArray(float32[4]):JaxprTrace(level=0/0)>)
File [...]/jax/interpreters/partial_eval.py, line 324, in abstract_eval_fun
instantiate=True, stage_out=True)
File [...]/jax/interpreters/partial_eval.py, line 423, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File [...]/site-packages/jax/linear_util.py, line 150, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/site-packages/jax/linear_util.py, line 150, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
LayerError: Exception passing through layer L2Loss (in pure_fn):
layer created in file [...]/trax/layers/base.py, line 704
layer input shapes: (ShapeDtype{shape:(4, 2), dtype:float32}, ShapeDtype{shape:(4,), dtype:int32}, ShapeDtype{shape:(4,), dtype:float32})
File [...]/trax/layers/base.py, line 658, in forward
raw_output = self._forward_fn(inputs)
File [...]/trax/layers/base.py, line 700, in _forward
return f(*xs)
File [...]/trax04.py, line 88, in f
trax.shapes.assert_same_shape(model_output, targets)
File [...]/site-packages/trax/shapes.py, line 138, in assert_same_shape
assert_shape_equals(array1, array2.shape)
File [...]/site-packages/trax/shapes.py, line 132, in assert_shape_equals
'Invalid shape {}; expected {}.'.format(array.shape, shape)
AssertionError: Invalid shape (4, 2); expected (4,).
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels