When I have a mix of tf.Variable and KerasVariable objects, I get the following error:
--> 632 if v.overwrite_with_gradient:
633 if self.gradient_accumulation_steps:
634 # Utilize a stateless manner for JAX compatibility
635 steps = self.gradient_accumulation_steps
AttributeError: 'ResourceVariable' object has no attribute 'overwrite_with_gradient'
I suspect this happens because my variable list is [KerasVariables] + [tf.Variables], and the following line checks only the first variable in the list to decide whether overwrite_with_gradient can be used for all of them:
if not hasattr(vars[0], "overwrite_with_gradient"):
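To illustrate the suspected failure mode, here is a minimal sketch that does not use Keras or TensorFlow at all. KerasLikeVariable and TFLikeVariable are hypothetical stand-ins: the first carries an overwrite_with_gradient attribute, the second lacks it, much like the ResourceVariable in the traceback. flagged_names_buggy mimics the first-element check quoted above. flagged_names_safe shows one possible per-variable alternative using getattr with a False default; this is an assumption for illustration, not the actual Keras fix.

```python
from dataclasses import dataclass


@dataclass
class KerasLikeVariable:
    """Hypothetical stand-in for a Keras variable that carries the flag."""
    name: str
    overwrite_with_gradient: bool = False


@dataclass
class TFLikeVariable:
    """Hypothetical stand-in for a tf.Variable, which has no such attribute."""
    name: str


def flagged_names_buggy(variables):
    # Mirrors the reported pattern: only variables[0] is inspected.
    if not hasattr(variables[0], "overwrite_with_gradient"):
        return []
    # Raises AttributeError as soon as a later variable lacks the attribute.
    return [v.name for v in variables if v.overwrite_with_gradient]


def flagged_names_safe(variables):
    # Per-variable lookup: a missing attribute is treated as False.
    return [v.name for v in variables if getattr(v, "overwrite_with_gradient", False)]


mixed = [
    KerasLikeVariable("kernel"),
    KerasLikeVariable("scale", overwrite_with_gradient=True),
    TFLikeVariable("extra"),
]

try:
    flagged_names_buggy(mixed)
except AttributeError as err:
    print("buggy:", err)

print("safe:", flagged_names_safe(mixed))
```

Note that with the first-element check, the order of the list also matters: if a tf.Variable comes first, the check returns early and any flagged Keras variables later in the list are silently ignored instead of raising.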