v1.0.1
Bug Fixes
- Fix `batch_size=1` JAX bug
- Fix gradient computation in backward pass
- Fix `verbose=True` at layer construction
- Suppress PyTorch sparse CSR beta warning
Improvements
- Add `torch.compile` and `jax.jit` + `vmap` support for Moreau interface (#202)
- Use scipy sparse CSR for CPU matmuls instead of framework sparse tensors
- Add runtime solver args for Moreau torch interface
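The idea behind the CSR change can be illustrated with a minimal sketch. Assume a hypothetical sparse matrix `A` that stands in for whatever sparse operator the layer applies, and a dense CPU input. A scipy CSR matrix supports `@` against a dense NumPy array directly, for a single vector or a batch of columns, without building a framework sparse tensor. The helper name `csr_matmul_cpu` is hypothetical and is not part of the library's API.

```python
import numpy as np
import scipy.sparse as sp

# Hypothetical 3x3 sparse operator stored in scipy CSR format
# (illustrative only; not the library's internal data).
rows = np.array([0, 0, 1, 2])
cols = np.array([0, 2, 1, 2])
vals = np.array([1.0, 2.0, 3.0, 4.0])
A = sp.csr_matrix((vals, (rows, cols)), shape=(3, 3))


def csr_matmul_cpu(A_csr, x):
    """Hypothetical helper: multiply a scipy CSR matrix by a dense
    NumPy array of shape (n,) or (n, batch) and return a dense array."""
    return A_csr @ x


# Single dense vector.
x = np.array([1.0, 1.0, 1.0])
y = csr_matmul_cpu(A, x)

# Batch of two dense columns.
X = np.ones((3, 2))
Y = csr_matmul_cpu(A, X)

# The sparse product matches the equivalent dense product.
assert np.allclose(y, A.toarray() @ x)
assert np.allclose(Y, A.toarray() @ X)
```

A CPU framework tensor would first need to be converted to a NumPy array (for example with `.numpy()` on a detached PyTorch CPU tensor) before such a product; how the release actually wires this into the layers is not described here.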