In [2]:
# Project imports. `mc` / `sm` come from the local `models` package (the `sm`
# spline returns jax Arrays below); `mct` / `smc` come from
# `weakly_convex_ridge_regularizer.models` (the `smc` spline returns torch
# tensors below). NOTE(review): presumably `mct` is likewise the PyTorch
# MultiConv2d reference -- confirm.
from models import multi_conv as mc
from weakly_convex_ridge_regularizer.models import multi_conv as mct
from weakly_convex_ridge_regularizer.models import spline_module as smc
from models import spline_module as sm
# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np
import torch 
import jax.numpy as jnp

In [4]:
# Sanity-check each MultiConv2d implementation; every call prints ps_1, ps_2
# and their ratio (presumably a dot-product / adjoint test, ratio ~= 1 meaning
# the transpose is consistent -- confirm against the library).
# NOTE(review): `check_tranpose` (sic) is the library's own spelling.
# NOTE(review): no seed is set here, so printed values may differ per run.
conv_layers = mc.MultiConv2d()
conv_layers.check_tranpose()

conv_layers_t = mct.MultiConv2d()
conv_layers_t.check_tranpose()

ps_1: 169.39486694335938
ps_2: 169.3951416015625
ratio: 0.9999983906745911
ps_1: 238.53005981445312
ps_2: 238.530029296875
ratio: 1.0000001279401936


# Check spline_module.py (with spline_autograd_func.py)

**Note:** only the forward passes of spline_autograd_func.py have been tested so far.


In [5]:
# Test parameters. A single kwargs dict feeds both the PyTorch reference and
# the JAX port, so the two splines are guaranteed to be built identically.
# NOTE(review): slope_min (1) is greater than slope_max (-1), which looks
# inverted; with clamp=True the slope clamp may be degenerate. Confirm intent.
spline_kwargs = dict(
    num_activations=1,
    num_knots=10,
    x_min=-2,
    x_max=2,
    init='identity',
    slope_max=-1,
    slope_min=1,
    antisymmetric=False,
    clamp=True,
)
spline_torch = smc.LinearSpline(**spline_kwargs)
spline_jax = sm.LinearSpline(**spline_kwargs)

In [6]:
# Random test input in [-2, 2), i.e. the spline's [x_min, x_max] range, stored
# as float32 (the dtype both backends report below). Legacy global seeding is
# kept so the recorded outputs stay reproducible.
np.random.seed(4)
x_np = np.asarray(np.random.uniform(low=-2, high=2, size=(3, 4)), dtype=np.float32)
x_np

array([[ 1.8681194 ,  0.18892899,  1.8907374 ,  0.85926396],
       [ 0.7909153 , -1.135642  ,  1.9050978 , -1.9750789 ],
       [-0.98807055, -0.26083386,  1.1175317 , -1.2092597 ]],
      dtype=float32)

Check the forward pass (PyTorch vs. JAX)

In [7]:
# PyTorch forward pass on the shared input (torch.tensor copies x_np).
x_torch = torch.tensor(x_np)
spline_torch(x_torch)

tensor([[-4.6667, -3.3000, -4.6667, -3.9704],
        [-3.9020, -1.9755, -4.6667, -1.5556],
        [-2.1230, -2.8503, -4.2286, -1.9019]], grad_fn=<ViewBackward0>)

In [8]:
# JAX forward pass on the same input as the PyTorch cell above.
x_jax = jnp.array(x_np)
y_spline_jax = spline_jax(x_jax)
# Assert parity with the PyTorch reference instead of comparing printouts by eye.
# NOTE(review): tolerances chosen for float32; adjust if implementations diverge.
np.testing.assert_allclose(
    np.asarray(y_spline_jax),
    spline_torch(x_torch).detach().numpy(),
    rtol=1e-5,
    atol=1e-5,
)
y_spline_jax

Array([[-4.666667 , -3.3000405, -4.666667 , -3.9703753],
       [-3.9020264, -1.9754691, -4.666667 , -1.5555556],
       [-2.1230407, -2.8502772, -4.228643 , -1.9018514]], dtype=float32)

Check `derivative` (PyTorch vs. JAX)

In [9]:
# PyTorch spline derivative evaluated at each input point.
spline_torch.derivative(x_torch)

tensor([[ 0.0000, -1.0000,  0.0000, -1.0000],
        [-1.0000, -1.0000,  0.0000,  0.0000],
        [-1.0000, -1.0000, -1.0000, -1.0000]], grad_fn=<ViewBackward0>)

In [10]:
# JAX spline derivative, checked against the PyTorch reference.
dy_spline_jax = spline_jax.derivative(x_jax)
np.testing.assert_allclose(
    np.asarray(dy_spline_jax),
    spline_torch.derivative(x_torch).detach().numpy(),
    rtol=1e-5,
    atol=1e-5,
)
dy_spline_jax

Array([[ 0.        , -1.0000005 ,  0.        , -0.99999994],
       [-0.99999994, -0.99999994,  0.        ,  0.        ],
       [-0.99999994, -0.99999994, -1.0000005 , -0.99999994]],      dtype=float32)

Check `integrate` (PyTorch vs. JAX)

In [11]:
# PyTorch integral of the spline (the recorded output shows a coefficient-update
# message printed on this call).
spline_torch.integrate(x_torch)

**** Updating integrated spline coefficients ****


tensor([[-6.0512,  0.8512, -6.1568, -1.5856],
        [-1.3166,  4.3451, -6.2238,  5.7390],
        [ 4.0426,  2.2343, -2.6444,  4.4878]], grad_fn=<ViewBackward0>)

In [12]:
# JAX integral of the spline, checked against the PyTorch reference.
# NOTE(review): re-calling the torch `integrate` may print its coefficient-update
# message again; that is harmless.
int_spline_jax = spline_jax.integrate(x_jax)
np.testing.assert_allclose(
    np.asarray(int_spline_jax),
    spline_torch.integrate(x_torch).detach().numpy(),
    rtol=1e-5,
    atol=1e-5,
)
int_spline_jax

**** Updating integrated spline coefficients ****


Array([[-6.0512247,  0.8511638, -6.156777 , -1.5856433],
       [-1.3166085,  4.345057 , -6.2237926,  5.7390122],
       [ 4.042646 ,  2.234256 , -2.6444144,  4.4877768]], dtype=float32)

Check `get_clip_equivalent` (PyTorch vs. JAX)

In [13]:
# Apply the PyTorch spline's clip-equivalent; the output gains two leading
# singleton dimensions, per the recorded output.
spline_torch.get_clip_equivalent()(x_torch)

tensor([[[[-1.5556, -2.9222, -1.5556, -2.2518],
          [-2.3202, -4.2468, -1.5556, -4.6667],
          [-4.0992, -3.3719, -1.9936, -4.3204]]]], grad_fn=<AddBackward0>)

In [14]:
# JAX clip-equivalent forward pass, checked against the PyTorch reference.
clip_y_jax = spline_jax.get_clip_equivalent()(x_jax)
np.testing.assert_allclose(
    np.asarray(clip_y_jax),
    spline_torch.get_clip_equivalent()(x_torch).detach().numpy(),
    rtol=1e-5,
    atol=1e-5,
)
clip_y_jax

Array([[[[-1.5555553, -2.922182 , -1.5555553, -2.2518473],
         [-2.320196 , -4.246753 , -1.5555553, -4.666667 ],
         [-4.099182 , -3.3719451, -1.9935796, -4.320371 ]]]],      dtype=float32)

In [15]:
# Integral of the PyTorch clip-equivalent.
spline_torch.get_clip_equivalent().integrate(x_torch)

tensor([[[[10.1709,  1.9624, 10.2940,  4.9206],
          [ 4.5984, -2.5617, 10.3722, -4.6085],
          [-2.1445,  0.2295,  6.1803, -2.7616]]]], grad_fn=<MulBackward0>)

In [16]:
# Integral of the JAX clip-equivalent, checked against the PyTorch reference.
clip_int_jax = spline_jax.get_clip_equivalent().integrate(x_jax)
np.testing.assert_allclose(
    np.asarray(clip_int_jax),
    spline_torch.get_clip_equivalent().integrate(x_torch).detach().numpy(),
    rtol=1e-5,
    atol=1e-5,
)
clip_int_jax

Array([[[[10.170874  ,  1.9624476 , 10.294016  ,  4.920626  ],
         [ 4.598432  , -2.5616682 , 10.3722    , -4.608518  ],
         [-2.1444788 ,  0.22953984,  6.180272  , -2.761646  ]]]],      dtype=float32)

In [17]:
# slope_max attribute of the PyTorch clip-equivalent.
spline_torch.get_clip_equivalent().slope_max

tensor([[[-1.0000]]], grad_fn=<MaxBackward0>)

In [18]:
# slope_max of the JAX clip-equivalent, checked against the PyTorch reference.
clip_slope_max_jax = spline_jax.get_clip_equivalent().slope_max
np.testing.assert_allclose(
    np.asarray(clip_slope_max_jax),
    spline_torch.get_clip_equivalent().slope_max.detach().numpy(),
    rtol=1e-5,
    atol=1e-5,
)
clip_slope_max_jax

Array([[[-1.0000001]]], dtype=float32)

## Scratch: miscellaneous PyTorch vs. JAX API comparisons

In [19]:
# Elementwise exp in torch (method form).
x_torch.exp()

tensor([[6.4761, 1.2080, 6.6243, 2.3614],
        [2.2054, 0.3212, 6.7201, 0.1388],
        [0.3723, 0.7704, 3.0573, 0.2984]])

In [20]:
# Elementwise exp in JAX (function form).
jnp.exp(x_jax)

Array([[6.4761057 , 1.2079551 , 6.6242514 , 2.361422  ],
       [2.205414  , 0.3212158 , 6.720065  , 0.13875036],
       [0.3722943 , 0.77040887, 3.0572984 , 0.2984181 ]], dtype=float32)

In [21]:
# Re-display the JAX input.
x_jax

Array([[ 1.8681194 ,  0.18892899,  1.8907374 ,  0.85926396],
       [ 0.7909153 , -1.135642  ,  1.9050978 , -1.9750789 ],
       [-0.98807055, -0.26083386,  1.1175317 , -1.2092597 ]],      dtype=float32)

In [22]:
def activation(x, sigma=None, skip_scaling=False):
    """Toy affine map ``x -> x * 2 + 1`` used to compare torch and jax arithmetic.

    Works on anything supporting ``*`` and ``+`` with Python ints (torch
    tensors, jax arrays, plain numbers).

    ``sigma`` and ``skip_scaling`` are accepted but ignored; presumably kept to
    mirror the signature of a real activation function.
    """
    doubled = x * 2
    return doubled + 1



In [23]:
# Display the torch input.
x_torch

tensor([[ 1.8681,  0.1889,  1.8907,  0.8593],
        [ 0.7909, -1.1356,  1.9051, -1.9751],
        [-0.9881, -0.2608,  1.1175, -1.2093]])

In [24]:
# NOTE(review): duplicate of the previous cell; safe to delete.
x_torch

tensor([[ 1.8681,  0.1889,  1.8907,  0.8593],
        [ 0.7909, -1.1356,  1.9051, -1.9751],
        [-0.9881, -0.2608,  1.1175, -1.2093]])

In [25]:
# Apply the toy activation to the torch tensor.
y_torch = activation(x_torch)
y_torch

tensor([[ 4.7362,  1.3779,  4.7815,  2.7185],
        [ 2.5818, -1.2713,  4.8102, -2.9502],
        [-0.9761,  0.4783,  3.2351, -1.4185]])

In [26]:
# Display the JAX input.
x_jax

Array([[ 1.8681194 ,  0.18892899,  1.8907374 ,  0.85926396],
       [ 0.7909153 , -1.135642  ,  1.9050978 , -1.9750789 ],
       [-0.98807055, -0.26083386,  1.1175317 , -1.2092597 ]],      dtype=float32)

In [27]:
# Apply the same toy activation to the JAX array (compare with y_torch above).
y_jax = activation(x_jax) 
y_jax

Array([[ 4.7362385 ,  1.3778579 ,  4.781475  ,  2.7185278 ],
       [ 2.5818305 , -1.2712841 ,  4.810196  , -2.9501579 ],
       [-0.9761411 ,  0.47833228,  3.2350633 , -1.4185195 ]],      dtype=float32)

In [28]:
# Reduce over both axes to a scalar (torch uses `dim=`).
torch.sum(x_torch, dim = (1,0))

tensor(3.0517)

In [29]:
# Same reduction in JAX (`axis=` instead of `dim=`).
jnp.sum(x_jax, axis = (1,0))

Array(3.0517094, dtype=float32)

In [30]:
# Display the JAX input.
x_jax

Array([[ 1.8681194 ,  0.18892899,  1.8907374 ,  0.85926396],
       [ 0.7909153 , -1.135642  ,  1.9050978 , -1.9750789 ],
       [-0.98807055, -0.26083386,  1.1175317 , -1.2092597 ]],      dtype=float32)

In [31]:
# expand() can only broadcast singleton dimensions, so expanding the (3, 4)
# tensor to (3, 3) fails. Catch the expected error so Restart & Run All
# continues past this demonstration.
try:
    x_torch.expand((3, 3))
except RuntimeError as err:
    print(f"RuntimeError (expected): {err}")

RuntimeError: The expanded size of the tensor (3) must match the existing size (4) at non-singleton dimension 1.  Target sizes: [3, 3].  Tensor sizes: [3, 4]

In [None]:
# torch: expand() broadcasts the (5,) tensor to (2, 5) as a view; shape[0] is 2.
torch.linspace(1,2,5).expand((2,5)).shape[0]

2

In [None]:
# JAX analogue of torch's expand(): jnp.broadcast_to.
jnp.broadcast_to(jnp.linspace(1,5,5), (2,5)).shape

(2, 5)

In [None]:
# torch.arange excludes the end point: 5 elements.
torch.arange(1,6).shape

torch.Size([5])

In [None]:
# Same in JAX: end point excluded, 5 elements.
jnp.arange(1,6).shape

(5,)

In [None]:
import equinox as eqx


In [None]:
class class2(eqx.Module):
    """Minimal equinox module with a single int field, used to probe immutability."""
    attr2: int

    def __init__(self, att):
        # Fields can be set here; later plain assignment is rejected with
        # FrozenInstanceError (see the recorded output further down).
        self.attr2 = att


class class1(eqx.Module):
    """Equinox module wrapping a ``class2`` instance, used to probe in-place mutation."""

    attr1: class2

    def __init__(self, x):
        self.attr1 = class2(x)

    def change(self, x):
        """Add 2 to ``self.attr1.attr2`` in place, bypassing the frozen-field guard.

        NOTE(review): ``x`` is ignored and the increment is hard-coded to 2
        (the only visible call passes 2, so ``+ x`` may have been intended).
        In-place mutation sidesteps equinox's immutability; ``eqx.tree_at`` is
        the supported functional way to update a field.
        """
        inner = self.attr1
        object.__setattr__(inner, 'attr2', inner.attr2 + 2)


In [None]:
# Build a class1 whose inner class2 starts with attr2 == 9.
class_1 = class1(9)

In [None]:
class_2 = class2(32)
# Plain assignment to a field of an eqx.Module is rejected with
# FrozenInstanceError (a subclass of AttributeError). Catch it so the notebook
# still runs top to bottom while demonstrating the behavior.
try:
    class_2.attr2 = 21
except AttributeError as err:
    print(f"{type(err).__name__}: {err}")

FrozenInstanceError: cannot assign to field 'attr2'

In [None]:
# print(class_2.attr2)
# NOTE(review): the recorded output (11) does not follow from class1(9) above;
# it reflects an earlier out-of-order call to class_1.change(...). A fresh
# top-to-bottom run should print 9 here.
print(class_1.attr1.attr2)

11


In [None]:
# Bypass the frozen-field guard by calling object.__setattr__ directly.
object.__setattr__(class_2, 'attr2', 199)

In [None]:
# NOTE(review): the recorded output (207) is stale; after the cell above,
# class_2.attr2 should be 199, and class_1.attr1 is a separate class2 instance.
print(class_2.attr2)
print(class_1.attr1)

207
class2(attr2=207)


In [None]:
# Mutate class_1.attr1.attr2 in place (+2; change() ignores its argument).
class_1.change(2)


In [None]:
# Observe the in-place increment made by change().
print(class_1.attr1.attr2)

13


In [None]:
# Place x_jax on the device it already lives on (effectively a no-op move).
x_jax.to_device(x_jax.device)

Array([[ 1.8681194 ,  0.18892899,  1.8907374 ,  0.85926396],
       [ 0.7909153 , -1.135642  ,  1.9050978 , -1.9750789 ],
       [-0.98807055, -0.26083386,  1.1175317 , -1.2092597 ]],      dtype=float32)

In [None]:
# Integer arange in JAX; int32 here (presumably because 64-bit mode is off).
jnp.arange(0 ,10)

Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [None]:
# JAX counterpart of the torch clone experiment below. `clone` is a PyTorch
# method; JAX arrays are copied with jnp.copy (or `.copy()`).
a = jnp.copy(x_jax)

# JAX arrays are immutable: `+=` builds a new array and rebinds `a`;
# x_jax itself is untouched.
a += 1

In [None]:
# Display the incremented JAX copy.
a

Array([[ 2.8681192 ,  1.188929  ,  2.8907375 ,  1.8592639 ],
       [ 1.7909153 , -0.13564205,  2.905098  , -0.97507894],
       [ 0.01192945,  0.73916614,  2.1175318 , -0.20925975]],      dtype=float32)

In [None]:
# Default integer dtype of jnp.arange (int32 in the recorded output).
jnp.arange(0, 2).dtype

dtype('int32')

In [None]:
# torch.arange for comparison; torch omits the dtype from the repr for its
# default integer dtype (int64).
torch.arange(0,2)

tensor([0, 1])

In [None]:
# Clone, then add in place: only `b` changes, x_torch is unaffected
# (see the next two cells).
b =torch.clone(x_torch)
b+=1

In [None]:
# x_torch after the in-place add on its clone: unchanged.
x_torch

tensor([[ 1.8681,  0.1889,  1.8907,  0.8593],
        [ 0.7909, -1.1356,  1.9051, -1.9751],
        [-0.9881, -0.2608,  1.1175, -1.2093]])

In [None]:
# The clone, incremented by 1.
b

tensor([[ 2.8681,  1.1889,  2.8907,  1.8593],
        [ 1.7909, -0.1356,  2.9051, -0.9751],
        [ 0.0119,  0.7392,  2.1175, -0.2093]])

In [58]:
# Per-sample L2 norm over all non-batch axes of a (3, 2, 3, 3) tensor.
# torch.norm is deprecated; torch.linalg.vector_norm is the supported API and
# computes the same value here.
d = torch.arange(3 * 18, dtype=torch.float).reshape(3, 2, 3, 3)
d_norms = torch.linalg.vector_norm(d, ord=2, dim=(1, 2, 3))
d_norms

tensor([ 42.2493, 114.5644, 190.0763])

In [91]:
# Flat indices where d > 1; single-argument torch.where mirrors the
# jnp.where(cond)[0] cell below.
d = torch.arange(3 * 18, dtype=torch.float)
torch.where(d > 1)[0]

tensor([ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
        20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37,
        38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53])

In [99]:
# JAX analogue of torch's nonzero: single-argument jnp.where returns a tuple of
# index arrays; [0] selects the only axis here.
# NOTE: the result size depends on the data, so this form cannot be used
# under jit unless `size=` is passed.
d = jnp.arange(3*18)
jnp.where(d>1)[0]

Array([ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
       19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
       36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52,
       53], dtype=int32)

In [None]:
# NOTE(review): for the 2-D x_jax, single-argument jnp.where returns a tuple
# (row_indices, col_indices). The recorded output shows only the row indices,
# so it is probably stale (e.g. from a `[0]`-indexed version); re-run to check.
jnp.where(x_jax>1)

Array([0, 0, 1, 2], dtype=int32)

In [None]:
# JAX counterpart of the torch per-sample L2 norm above. `d` is built as
# float32 (matching dtype=torch.float in the torch cell) so `d**2` cannot
# overflow int32 for larger ranges. jnp.linalg.norm only handles 1 or 2 axes,
# hence the explicit sqrt-of-sum over three axes.
d = jnp.arange(3 * 18, dtype=jnp.float32).reshape(3, 2, 3, 3)
d_norms = jnp.sqrt(jnp.sum(d**2, axis=(1, 2, 3)))
d_norms


Array([ 42.24926, 114.56439, 190.0763 ], dtype=float32)

In [104]:
# Number of entries greater than 1, as a Python int.
jnp.sum((x_jax > 1)).item()

4

In [107]:
# Frobenius / L2 norm over all entries of x_jax.
jnp.linalg.norm(x_jax)

Array(4.586922, dtype=float32)

In [None]:
# jax.Array has no torch-style `.to(...)` method; accessing `x_jax.to` would
# raise AttributeError and abort Run All. Device placement in JAX goes through
# `x_jax.to_device(...)` (used above) or `jax.device_put`.
# NOTE(review): expected to display False.
hasattr(x_jax, "to")

In [None]:
# Frobenius / L2 norm over all entries, mirroring jnp.linalg.norm(x_jax) above.
# torch.norm is deprecated; torch.linalg.norm gives the same value here.
torch.linalg.norm(x_torch)

tensor(4.5869)