In [None]:
def Helmholz_loss(u,a,boundary,omega=5*torch.pi/2,p=1,D=1):
    '''Calculates the PDE loss for the Poisson equation
    Input: u  Output of network, Shape = (Batch_size,Grid_size,Grid_size)
           a  Input  of network, Shape = (Batch_size,Grid_size,Grid_size)
           boundary Boundary conditon of u, float
           omega wave number
           p  Do we use L1 or L2 errors? Default: L1
           D  Period of Domain
    Warning: Input f and Output u should not be normalized!'''

    Laplace_u=Laplace(u,D=D)

    if p == 1:
      loss = torch.nn.L1Loss()
    elif p == 2:
      loss = torch.nn.MSELoss()

    loss_pde=loss(Laplace_u,-omega**2*a**2*u)

    #Add boundary loss: u=0 on boundary(Domain)
    boundary_lossx_0=loss(u[:,0,:], boundary*torch.ones_like(u[:,0,:]))
    boundary_lossx_D=loss(u[:,-1,:],boundary*torch.ones_like(u[:,-1,:]))
    boundary_lossy_D=loss(u[:,:,-1],boundary*torch.ones_like(u[:,:,1]))
    boundary_lossy_0=loss(u[:,:,0], boundary*torch.ones_like(u[:,:,0]))
    boundary_loss=0.25*(boundary_lossx_0+boundary_lossy_0+boundary_lossx_D+boundary_lossy_D)
    loss=0.5*(loss_pde+boundary_loss)

    return loss

In [None]:
import torch
import numpy as np
def Poisson_pde_loss(u,f,p,D=2):
    '''Calculates the pde loss for the Poisson equation'''
    s=u.size(-1)
    
    u_hat=torch.fft.fft2(u,dim=[-2,-1])
    assert (u.device==u_hat.device)
    k_max=s//2
    
    k_x = torch.cat((torch.arange(start=0, end=k_max, step=1, device=u.device),
                     torch.arange(start=-k_max, end=0, step=1, device=u.device)), 0).reshape(s, 1).repeat(1, s).reshape(1,s,s)
    k_y = torch.cat((torch.arange(start=0, end=k_max, step=1, device=u.device),
                     torch.arange(start=-k_max, end=0, step=1, device=u.device)), 0).reshape(1, s).repeat(s, 1).reshape(1,s,s)
    
    
    Laplace_u_hat =-4*(torch.pi/D)**2*(k_x**2+k_y**2)*u_hat


    Laplace_u=torch.fft.irfft2(Laplace_u_hat[:, :, :k_max + 1], dim=[-2, -1])

    if p == 1:
      loss = torch.nn.L1Loss()
    elif p == 2:
      loss = torch.nn.MSELoss()
    
    epsilon=1e-10

    loss_pde = loss(-Laplace_u, f)/(loss(torch.zeros_like(f).to(u.device), f) + epsilon)
    
    return loss_pde, Laplace_u


In [61]:
#Check if loss function works correctly

def calculate_f(x, y, K, a, r):
    result = 0
    for i in range(len(a)):
        for j in range(len(a[i])):
            result += a[i][j] *  ((i)**2 + (j)**2)**0.5 * np.sin(np.pi * (i) * x) * np.sin(np.pi * (j) * y)
    return (np.pi  / ( K**2))* result


def calculate_u(x, y, K, a, r):
    result = 0
    for i in range(len(a)):
        for j in range(len(a[i])):
            if (i**2 + j**2) == 0:
                result += 0  
            else:
                result += a[i][j] * ((i)**2 + (j)**2)**(-1/2) * np.sin(np.pi * (i) * x) * np.sin(np.pi * (j) * y)
    return (1 / (np.pi * K**2)) * result

s=64
f=torch.zeros((1,1,s,s))
u=torch.zeros((1,1,s,s))
error=torch.rand((1,1,s,s))*1e-5
#Really senitive to error

K=3
a = 2 * torch.rand((K,K)) - 1
r=0.5
D=2
for x in range(0,s):
    for y in range(0,s):
        f[0,0,x,y]=calculate_f(D*x/s,D* y/s, K, a, r)
        u[0,0,x,y]=calculate_u(D*x/s,D* y/s, K, a, r)

print(calculate_f(D*0/s,D* 0/s, K, a, r)-calculate_f(D*(s+1)/s,D* (s+1)/s, K, a, r))
u=u

loss,u_lap=Poisson_pde_loss(u.squeeze(1),f.squeeze(1),1,D=D)

print(loss)
y=4
import matplotlib.pyplot as plt
#plt.plot(np.arange(0,s),u[0,0,:,y]-error[0,0,:,y],label='u_error')
plt.plot(np.arange(0,s),u[0,0,:,y],label='u')
plt.plot(np.arange(0,s),f[0,0,:,y],label='f')
plt.plot(np.arange(0,s),u_lap[0,:,y],label='u_lab')
plt.legend()

print(u_lap[0,:,y]/f[0,0,:,y])

tensor(0.0140)


IndexError: too many indices for tensor of dimension 3

In [177]:
def Wave_pde_loss(u,u0,c=0.1,T=5,p=1):
    '''Calculates the PDE loss for the Wave equation
       Input:  u   Output of network, Shape = (Batch_size,Spatial_grid_size,Spatial_grid_size,Temporal_grid_size)
               u0  Input  of network, Shape = (Batch_size,Spatial_grid_size,Spatial_grid_size)
               p   Do we use L1 i.e (p==1) or L2 i.e (p==2) errors? Default: L1'''
   
    Temporal_grid_size=u.size(1)
    Spatial_grid_size=u.size(-1)
    #Calculate derivative using the fact FT(f')=2*pi*i*k*FT(f) (Where we use torch, Bogdan uses scypi)
    u_hat=torch.fft.fft2(u,dim=[-2,-1])
    print(u_hat.shape)

    assert (u.device==u_hat.device) #Need to be same device, can only be checked on GPU

    #Doesnt this make it a grid based loss? (No it shouldnt!)
    k_max=Spatial_grid_size//2
    dt=T/(Temporal_grid_size-1)
    
    k_x = torch.cat((torch.arange(start=0, end=k_max, step=1, device=u.device),
                     torch.arange(start=-k_max, end=0, step=1, device=u.device)), 0).reshape(Spatial_grid_size, 1).repeat(1, Spatial_grid_size).reshape(1,1,Spatial_grid_size,Spatial_grid_size)
    k_y = torch.cat((torch.arange(start=0, end=k_max, step=1, device=u.device),
                     torch.arange(start=-k_max, end=0, step=1, device=u.device)), 0).reshape(1, Spatial_grid_size).repeat(Spatial_grid_size, 1).reshape(1,1,Spatial_grid_size,Spatial_grid_size)
    
    ux_hat  = 2j *torch.pi*k_x*u_hat
    uxx_hat = 2j *torch.pi*k_x*ux_hat
    
    uy_hat  = 2j*torch.pi*k_y*u_hat
    uyy_hat = 2j*torch.pi*k_y*uy_hat

    
    uxx = torch.fft.irfft2(uxx_hat[:, :,:, :k_max + 1],   dim=[-2, -1])
    uyy = torch.fft.irfft2(uyy_hat[:, :,:, :k_max + 1],   dim=[-2, -1])

    #Calculate Laplace-Operator for u
    Du=(uyy+uxx)/4
   
    utt = (u[:, 2:,...] - 2.0*u[:, 1:-1,...] + u[:, :-2,...]) / (dt**2)

    if p == 1:
      loss = torch.nn.L1Loss()
    elif p == 2:
      loss = torch.nn.MSELoss()
    
    epsilon=1e-10

    #Here I am not sure if I should use the relative loss
    loss_pde = loss(Du[:, 1:-1,:],c**2* utt)
      
    print(loss_pde)
    #Add relative boundary loss: u=0 on boundary(Domain)
    inital_loss=loss(u0, u[:,0,...])/(loss(torch.zeros_like(u0).to(u0.device), u0)+epsilon)

    return 0.5*(loss_pde+inital_loss)

In [189]:
def calculate_u0(x, y, K, a):
    result = 0
    for i in range(len(a)):
        for j in range(len(a[i])):
            if (i**2 + j**2) == 0:
                result += 0  
            else:
               result += a[i][j] * (i**2 + j**2)**(-1) * np.sin(np.pi * i * x) * np.sin(np.pi * j * y)
    return (1 / (K**2))* result
def calculate_u_wave(x, y, t, K, a,c):
    result = 0
    for i in range(len(a)):
        for j in range(len(a[i])):
            if (i**2 + j**2) == 0:
                result += 0  
            else:
                result += a[i][j] * (i**2 + j**2)**(-1) * np.sin(np.pi * i * x) * np.sin(np.pi * j * y)*np.cos(np.pi *c* np.sqrt(i**2+j**2) * t)
    return (1 / (K**2))* result
s=32
T=10
u0=torch.zeros((1,1,s,s))
u=torch.zeros((1,T,s,s))
error=torch.rand((1,1,s,s))*1e-10
K=3
a = 2 * torch.rand((K,K)) - 1
c=0.1
r=0.5
for x in range(-s//2,s//2):
    for y in range(-s//2,s//2):  
        u0[0,0,x+s//2,y+s//2]=calculate_u0(x/s*2, y/s*2, K, a)
        for t in range(T):
          u[0,t,x+s//2,y+s//2]=calculate_u_wave(x/s*2, y/s*2,t/1000, K, a, c)
#u=u+error

loss=Wave_pde_loss(u,u0.squeeze(1),2)

print(loss)

torch.Size([1, 10, 32, 32])
tensor(0.4671)
tensor(0.2335)


In [204]:
x = torch.rand(1,4,4,2)
s=4//2
m=torch.fft.fft2(x,dim=[-3,-2])
x_h=torch.fft.irfft2(m[:,:,:s+1,:],dim=[-3,-2])
print(torch.norm(x-x_h))

c=torch.rand(10,10)
a=(2*torch.pi*c)**2
b=(c)**2*(2*torch.pi)**2
a-b

tensor(1.5699e-07)


tensor([[4.7684e-07, 0.0000e+00, 0.0000e+00, 1.9073e-06, 2.3842e-07, 2.3842e-07,
         1.9073e-06, 7.4506e-09, 5.9605e-08, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 3.8147e-06, 5.9605e-08, 2.3842e-07, 5.9605e-08,
         3.8147e-06, 0.0000e+00, 1.9073e-06, 2.3283e-10],
        [2.8610e-06, 9.5367e-07, 0.0000e+00, 2.3842e-07, 0.0000e+00, 1.9073e-06,
         3.8147e-06, 5.7220e-06, 4.7684e-07, 4.7684e-07],
        [1.9073e-06, 0.0000e+00, 3.8147e-06, 9.5367e-07, 4.6566e-10, 4.7684e-07,
         3.8147e-06, 1.9073e-06, 9.5367e-07, 9.5367e-07],
        [2.3842e-07, 9.5367e-07, 9.5367e-07, 1.9073e-06, 9.5367e-07, 7.6294e-06,
         9.5367e-07, 2.3842e-07, 0.0000e+00, 1.1642e-10],
        [7.6294e-06, 1.9073e-06, 3.8147e-06, 3.8147e-06, 5.7220e-06, 4.7684e-07,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 2.9802e-08],
        [0.0000e+00, 1.9073e-06, 2.9802e-08, 9.5367e-07, 1.9073e-06, 2.9802e-08,
         1.9073e-06, 7.6294e-06, 5.9605e-08, 0.0000e+00],
        [4.7684e-07, 1.1921

In [11]:
import torch
import torch.nn as nn

# Create a sample model
model = nn.Sequential(
    nn.Linear(10, 5),
    nn.ReLU(),
    nn.Linear(5, 2),
)

num_layers_to_freeze=0.25
# Print the initial gradient status


# Freeze specific layers or parameters
for i, param in enumerate(model.parameters()):
    if i < num_layers_to_freeze*len(list(model.parameters())):
        param.requires_grad = False
 


# Put the model in evaluation mode
model.eval()

# Print the updated gradient status
for name, param in model.named_parameters():
    print(f"Parameter: {name}, Requires Grad: {param.requires_grad}")

Parameter: 0.weight, Requires Grad: False
Parameter: 0.bias, Requires Grad: True
Parameter: 2.weight, Requires Grad: True
Parameter: 2.bias, Requires Grad: True


In [8]:
from collections import defaultdict

# Define your parameter names
parameter_names = [
"project.bias" ,
"project.convolution.weight",
"project.convolution.bias",
"lift.bias",
"lift.convolution.weight",
"lift.convolution.bias",
"cont_conv_layers.0.bias",
"cont_conv_layers.0.convolution.weight",
"cont_conv_layers.0.convolution.bias",
"cont_conv_layers.1.bias",
"cont_conv_layers.1.convolution.weight",
"cont_conv_layers.1.convolution.bias",
"cont_conv_layers.2.bias",
"cont_conv_layers.2.convolution.weight",
"cont_conv_layers.2.convolution.bias",
"cont_conv_layers.3.bias",
"cont_conv_layers.3.convolution.weight",
"cont_conv_layers.3.convolution.bias",
"cont_conv_layers_invariant.0.bias",
"cont_conv_layers_invariant.0.convolution.weight",
"cont_conv_layers_invariant.0.convolution.bias",
"cont_conv_layers_invariant.1.bias",
"cont_conv_layers_invariant.1.convolution.weight",
"cont_conv_layers_invariant.1.convolution.bias",
"cont_conv_layers_invariant.2.bias",
"cont_conv_layers_invariant.2.convolution.weight",
"cont_conv_layers_invariant.2.convolution.bias",
"cont_conv_layers_invariant.3.bias",
"cont_conv_layers_invariant.3.convolution.weight",
"cont_conv_layers_invariant.3.convolution.bias",
"res_batch_norm_in.0.weight",
"res_batch_norm_in.0.bias",
"res_batch_norm_in.1.weight",
"res_batch_norm_in.1.bias",
"res_batch_norm_in.2.weight",
"res_batch_norm_in.2.bias",
"batch_norm.0.weight",
"batch_norm.0.bias",
"batch_norm.1.weight",
"batch_norm.1.bias",
"batch_norm.2.weight",
"batch_norm.2.bias",
"batch_norm.3.weight",
"batch_norm.3.bias",
"batch_norm_inv.0.weight",
"batch_norm_inv.0.bias",
"batch_norm_inv.1.weight",
"batch_norm_inv.1.bias",
"batch_norm_inv.2.weight",
"batch_norm_inv.2.bias",
"batch_norm_inv.3.weight",
"batch_norm_inv.3.bias",
"resnet_blocks.0.cont_conv.0.bias",
"resnet_blocks.0.cont_conv.0.convolution.weight",
"resnet_blocks.0.cont_conv.0.convolution.bias",
"resnet_blocks.0.cont_conv.1.bias",
"resnet_blocks.0.cont_conv.1.convolution.weight",
"resnet_blocks.0.cont_conv.1.convolution.bias",
"resnet_blocks.1.cont_conv.0.bias",
"resnet_blocks.1.cont_conv.0.convolution.weight",
"resnet_blocks.1.cont_conv.0.convolution.bias",
"resnet_blocks.1.cont_conv.1.bias",
"resnet_blocks.1.cont_conv.1.convolution.weight",
"resnet_blocks.1.cont_conv.1.convolution.bias",
"resnet_blocks.2.cont_conv.0.bias",
"resnet_blocks.2.cont_conv.0.convolution.weight",
"resnet_blocks.2.cont_conv.0.convolution.bias",
"resnet_blocks.2.cont_conv.1.bias",
"resnet_blocks.2.cont_conv.1.convolution.weight",
"resnet_blocks.2.cont_conv.1.convolution.bias",
"resnet_blocks.3.cont_conv.0.bias",
"resnet_blocks.3.cont_conv.0.convolution.weight",
"resnet_blocks.3.cont_conv.0.convolution.bias",
"resnet_blocks.3.cont_conv.1.bias",
"resnet_blocks.3.cont_conv.1.convolution.weight",
"resnet_blocks.3.cont_conv.1.convolution.bias",
"resnet_blocks.4.cont_conv.0.bias",
"resnet_blocks.4.cont_conv.0.convolution.weight",
"resnet_blocks.4.cont_conv.0.convolution.bias",
"resnet_blocks.4.cont_conv.1.bias",
"resnet_blocks.4.cont_conv.1.convolution.weight",
"resnet_blocks.4.cont_conv.1.convolution.bias",
"resnet_blocks.5.cont_conv.0.bias",
"resnet_blocks.5.cont_conv.0.convolution.weight",
"resnet_blocks.5.cont_conv.0.convolution.bias",
"resnet_blocks.5.cont_conv.1.bias",
"resnet_blocks.5.cont_conv.1.convolution.weight",
"resnet_blocks.5.cont_conv.1.convolution.bias",
"resnet_blocks.6.cont_conv.0.bias",
"resnet_blocks.6.cont_conv.0.convolution.weight",
"resnet_blocks.6.cont_conv.0.convolution.bias",
"resnet_blocks.6.cont_conv.1.bias",
"resnet_blocks.6.cont_conv.1.convolution.weight",
"resnet_blocks.6.cont_conv.1.convolution.bias",
"resnet_blocks.7.cont_conv.0.bias",
"resnet_blocks.7.cont_conv.0.convolution.weight",
"resnet_blocks.7.cont_conv.0.convolution.bias",
"resnet_blocks.7.cont_conv.1.bias",
"resnet_blocks.7.cont_conv.1.convolution.weight",
"resnet_blocks.7.cont_conv.1.convolution.bias",
"resnet_blocks.8.cont_conv.0.bias",
"resnet_blocks.8.cont_conv.0.convolution.weight",
"resnet_blocks.8.cont_conv.0.convolution.bias",
"resnet_blocks.8.cont_conv.1.bias",
"resnet_blocks.8.cont_conv.1.convolution.weight",
"resnet_blocks.8.cont_conv.1.convolution.bias",
"resnet_blocks.9.cont_conv.0.bias",
"resnet_blocks.9.cont_conv.0.convolution.weight",
"resnet_blocks.9.cont_conv.0.convolution.bias",
"resnet_blocks.9.cont_conv.1.bias",
"resnet_blocks.9.cont_conv.1.convolution.weight",
"resnet_blocks.9.cont_conv.1.convolution.bias",
"resnet_blocks.10.cont_conv.0.bias",
"resnet_blocks.10.cont_conv.0.convolution.weight",
"resnet_blocks.10.cont_conv.0.convolution.bias",
"resnet_blocks.10.cont_conv.1.bias",
"resnet_blocks.10.cont_conv.1.convolution.weight",
"resnet_blocks.10.cont_conv.1.convolution.bias",
"resnet_blocks.11.cont_conv.0.bias",
"resnet_blocks.11.cont_conv.0.convolution.weight",
"resnet_blocks.11.cont_conv.0.convolution.bias",
"resnet_blocks.11.cont_conv.1.bias",
"resnet_blocks.11.cont_conv.1.convolution.weight",
"resnet_blocks.11.cont_conv.1.convolution.bias",
"resnet_blocks.12.cont_conv.0.bias",
"resnet_blocks.12.cont_conv.0.convolution.weight",
"resnet_blocks.12.cont_conv.0.convolution.bias",
"resnet_blocks.12.cont_conv.1.bias",
"resnet_blocks.12.cont_conv.1.convolution.weight",
"resnet_blocks.12.cont_conv.1.convolution.bias",
"resnet_blocks.13.cont_conv.0.bias",
"resnet_blocks.13.cont_conv.0.convolution.weight",
"resnet_blocks.13.cont_conv.0.convolution.bias",
"resnet_blocks.13.cont_conv.1.bias",
"resnet_blocks.13.cont_conv.1.convolution.weight",
"resnet_blocks.13.cont_conv.1.convolution.bias"   ,
"resnet_blocks.14.cont_conv.0.bias",
"resnet_blocks.14.cont_conv.0.convolution.weight",
"resnet_blocks.14.cont_conv.0.convolution.bias",
"resnet_blocks.14.cont_conv.1.bias",
"resnet_blocks.14.cont_conv.1.convolution.weight",
"resnet_blocks.14.cont_conv.1.convolution.bias",
"resnet_blocks.15.cont_conv.0.bias",
"resnet_blocks.15.cont_conv.0.convolution.weight",
"resnet_blocks.15.cont_conv.0.convolution.bias",
"resnet_blocks.15.cont_conv.1.bias",
"resnet_blocks.15.cont_conv.1.convolution.weight",
"resnet_blocks.15.cont_conv.1.convolution.bias",
"resnet_blocks.16.cont_conv.0.bias",
"resnet_blocks.16.cont_conv.0.convolution.weight",
"resnet_blocks.16.cont_conv.0.convolution.bias",
"resnet_blocks.16.cont_conv.1.bias",
"resnet_blocks.16.cont_conv.1.convolution.weight",
"resnet_blocks.16.cont_conv.1.convolution.bias",
"resnet_blocks.17.cont_conv.0.bias",
"resnet_blocks.17.cont_conv.0.convolution.weight",
"resnet_blocks.17.cont_conv.0.convolution.bias",
"resnet_blocks.17.cont_conv.1.bias",
"resnet_blocks.17.cont_conv.1.convolution.weight",
"resnet_blocks.17.cont_conv.1.convolution.bias",
"resnet_blocks.18.cont_conv.0.bias",
"resnet_blocks.18.cont_conv.0.convolution.weight",
"resnet_blocks.18.cont_conv.0.convolution.bias",
"resnet_blocks.18.cont_conv.1.bias",
"resnet_blocks.18.cont_conv.1.convolution.weight",
"resnet_blocks.18.cont_conv.1.convolution.bias",
"resnet_blocks.19.cont_conv.0.bias",
"resnet_blocks.19.cont_conv.0.convolution.weight",
"resnet_blocks.19.cont_conv.0.convolution.bias",
"resnet_blocks.19.cont_conv.1.bias",
"resnet_blocks.19.cont_conv.1.convolution.weight",
"resnet_blocks.19.cont_conv.1.convolution.bias",
"resnet_blocks.20.cont_conv.0.bias",
"resnet_blocks.20.cont_conv.0.convolution.weight",
"resnet_blocks.20.cont_conv.0.convolution.bias",
"resnet_blocks.20.cont_conv.1.bias",
"resnet_blocks.20.cont_conv.1.convolution.weight",
"resnet_blocks.20.cont_conv.1.convolution.bias",
"resnet_blocks.21.cont_conv.0.bias",
"resnet_blocks.21.cont_conv.0.convolution.weight",
"resnet_blocks.21.cont_conv.0.convolution.bias",
"resnet_blocks.21.cont_conv.1.bias",
"resnet_blocks.21.cont_conv.1.convolution.weight",
"resnet_blocks.21.cont_conv.1.convolution.bias",
"resnet_blocks.22.cont_conv.0.bias",
"resnet_blocks.22.cont_conv.0.convolution.weight",
"resnet_blocks.22.cont_conv.0.convolution.bias",
"resnet_blocks.22.cont_conv.1.bias",
"resnet_blocks.22.cont_conv.1.convolution.weight",
"resnet_blocks.22.cont_conv.1.convolution.bias",
"resnet_blocks.23.cont_conv.0.bias",
"resnet_blocks.23.cont_conv.0.convolution.weight",
"resnet_blocks.23.cont_conv.0.convolution.bias",
"resnet_blocks.23.cont_conv.1.bias",
"resnet_blocks.23.cont_conv.1.convolution.weight",
"resnet_blocks.23.cont_conv.1.convolution.bias",
"resnet_blocks.24.cont_conv.0.bias",
"resnet_blocks.24.cont_conv.0.convolution.weight",
"resnet_blocks.24.cont_conv.0.convolution.bias",
"resnet_blocks.24.cont_conv.1.bias",
"resnet_blocks.24.cont_conv.1.convolution.weight",
"resnet_blocks.24.cont_conv.1.convolution.bias",
"resnet_blocks.25.cont_conv.0.bias",
"resnet_blocks.25.cont_conv.0.convolution.weight",
"resnet_blocks.25.cont_conv.0.convolution.bias",
"resnet_blocks.25.cont_conv.1.bias",
"resnet_blocks.25.cont_conv.1.convolution.weight",
"resnet_blocks.25.cont_conv.1.convolution.bias",
"resnet_blocks.26.cont_conv.0.bias",
"resnet_blocks.26.cont_conv.0.convolution.weight",
"resnet_blocks.26.cont_conv.0.convolution.bias",
"resnet_blocks.26.cont_conv.1.bias",
"resnet_blocks.26.cont_conv.1.convolution.weight",
"resnet_blocks.26.cont_conv.1.convolution.bias",
"resnet_blocks.27.cont_conv.0.bias",
"resnet_blocks.27.cont_conv.0.convolution.weight",
"resnet_blocks.27.cont_conv.0.convolution.bias",
"resnet_blocks.27.cont_conv.1.bias",
"resnet_blocks.27.cont_conv.1.convolution.weight",
"resnet_blocks.27.cont_conv.1.convolution.bias",
"resnet_blocks.28.cont_conv.0.bias",
"resnet_blocks.28.cont_conv.0.convolution.weight",
"resnet_blocks.28.cont_conv.0.convolution.bias",
"resnet_blocks.28.cont_conv.1.bias",
"resnet_blocks.28.cont_conv.1.convolution.weight",
"resnet_blocks.28.cont_conv.1.convolution.bias",
"resnet_blocks.29.cont_conv.0.bias",
"resnet_blocks.29.cont_conv.0.convolution.weight",
"resnet_blocks.29.cont_conv.0.convolution.bias",
"resnet_blocks.29.cont_conv.1.bias",
"resnet_blocks.29.cont_conv.1.convolution.weight",
"resnet_blocks.29.cont_conv.1.convolution.bias "
]
#parameter_names=["resnet_blocks","lift,res_batch_norm_in","cont_conv_layers_invariant","cont_conv_layers","lift","project"]
# Group parameter names by their common prefixes

layer_count_by_prefix = {}

for name in parameter_names:
    parts = name.split(".")[0]  # Get the part before the dot
    count = name.split(".")[1]
    if parts in layer_count_by_prefix:
        layer_count_by_prefix[parts] += 1
    else:
        layer_count_by_prefix[parts] = 1

print(layer_count_by_prefix)
# Print the number of layers in each group
for prefix, count in layer_count_by_prefix.items():
    print(f"Prefix: {prefix}, Number of Layers: {count}")



 



{'project': 3, 'lift': 3, 'cont_conv_layers': 12, 'cont_conv_layers_invariant': 12, 'res_batch_norm_in': 6, 'batch_norm': 8, 'batch_norm_inv': 8, 'resnet_blocks': 180}
Prefix: project, Number of Layers: 3
Prefix: lift, Number of Layers: 3
Prefix: cont_conv_layers, Number of Layers: 12
Prefix: cont_conv_layers_invariant, Number of Layers: 12
Prefix: res_batch_norm_in, Number of Layers: 6
Prefix: batch_norm, Number of Layers: 8
Prefix: batch_norm_inv, Number of Layers: 8
Prefix: resnet_blocks, Number of Layers: 180
