Thanks for releasing the source code, it helps a lot.
I'm currently working on a similar project, and I'm wondering whether there is any way to display the current values of each tensor with rlax during training. When I try to print them, I only get their shape and trace information:
ListenerLossOutputs(loss=Traced<ShapedArray(float32[])>with<JVPTrace(level=2/1)> with primal = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)> tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=1/1)>
Do you have any advice?
Thanks in advance
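For context, the output above is what JAX prints for an abstract tracer: the loss is being computed inside a function transformed by `jax.jit` and differentiated (the `JVPTrace` part), and a Python `print` runs only while JAX traces the function, when just the shape and dtype are known. Below is a minimal sketch in plain JAX. `loss_fn`, `train_step`, `LEARNING_RATE` and the toy data are hypothetical stand-ins, not rlax or project APIs. It shows two common workarounds: `jax.debug.print`, which is staged into the compiled computation and prints concrete values at run time, and returning the loss from the jitted step so it can be printed outside.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical hyperparameter for this sketch


def loss_fn(params, x):
    # Hypothetical stand-in for the project's listener loss (not rlax API).
    pred = x @ params
    loss = jnp.mean(pred ** 2)
    # Plain print() runs only while JAX traces the function, so it shows an
    # abstract Tracer (shape/dtype), like the output in the question.
    print("trace-time loss:", loss)
    # jax.debug.print is staged into the computation and prints the concrete
    # value each time the compiled step actually runs.
    jax.debug.print("run-time loss: {l}", l=loss)
    return loss


@jax.jit
def train_step(params, x):
    # value_and_grad returns the loss alongside the gradients, so the step can
    # hand it back to Python, where it is an ordinary concrete array.
    loss, grads = jax.value_and_grad(loss_fn)(params, x)
    return params - LEARNING_RATE * grads, loss


params = jnp.ones(3)
x = jnp.arange(6.0).reshape(2, 3)
new_params, loss = train_step(params, x)
print("returned loss:", float(loss))
```

Returning values from the step keeps the compiled code unchanged and is cheap for scalars. `jax.debug.print` is handy for intermediate values deep inside the loss that are not otherwise returned.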
Hi Bastien,
This framework is based on Jaxline. To debug the tensors, you can try the "--jaxline_disable_pmap_jit" flag, as explained here: https://github.com/deepmind/jaxline
Good luck,
Best,
Rahma
(In reply to Bastien's message of Tue, 4 Apr 2023, 11:07.)
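The exact behaviour of "--jaxline_disable_pmap_jit" is documented in the jaxline README; the sketch below does not reproduce jaxline itself. Assuming the flag amounts to running the training step without jit/pmap, the corresponding plain-JAX mechanism is the `jax.disable_jit()` context manager: inside it, a `jax.jit`-decorated function runs eagerly, op by op, so an ordinary `print` sees concrete arrays. `listener_loss` and the toy logits/labels are hypothetical. Values inside `jax.grad` are still wrapped by the autodiff machinery even with jit disabled, so `jax.debug.print` or returning values from the step remain the more reliable options there.

```python
import jax
import jax.numpy as jnp


@jax.jit
def listener_loss(logits, labels):
    # Hypothetical stand-in for a loss computed inside a jitted training step.
    log_probs = jax.nn.log_softmax(logits)
    # Negative log-likelihood of the correct label, averaged over the batch.
    loss = -jnp.mean(jnp.take_along_axis(log_probs, labels[:, None], axis=-1))
    # Under jit this prints a Tracer; with jit disabled it prints the array.
    print("loss:", loss)
    return loss


logits = jnp.array([[2.0, 0.0], [0.0, 2.0]])
labels = jnp.array([0, 1])

# Eager execution: each jnp op runs immediately, so print sees real values.
with jax.disable_jit():
    eager_loss = listener_loss(logits, labels)

# Normal execution: traced and compiled; the print above shows a Tracer once.
jitted_loss = listener_loss(logits, labels)
```

Disabling jit makes every step much slower, so it is best kept to short debugging runs.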