In [1]:
import numpy as np
import pandas as pd
from model import CNN
from pysr import PySRRegressor
from torchvision import datasets
from torchvision.transforms import ToTensor
import torch



Detected IPython. Loading juliacall extension. See https://juliapy.github.io/PythonCall.jl/stable/compat/#IPython


### Load Model and get Kernels

In [2]:
cnn = CNN()
cnn.load_state_dict(torch.load('cnn.pt'))
print(cnn)

for name, param in cnn.named_parameters():
    if name == 'conv1.weight':
        print(f"amount of kernels of Conv1: {param.shape}")
        kernels1 = param
    if name == 'conv2.weight':
        print(f"amount of kernels of Conv2: {param.shape}")
print(f"kernels of first layer:\n{kernels1}")

  cnn.load_state_dict(torch.load('cnn.pt'))


CNN(
  (conv1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (relu1): ReLU()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (relu2): ReLU()
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (out): Linear(in_features=1568, out_features=10, bias=True)
)
amount of kernels of Conv1: torch.Size([16, 1, 5, 5])
amount of kernels of Conv2: torch.Size([32, 16, 5, 5])
kernels of first layer:
Parameter containing:
tensor([[[[-0.2335, -0.2440, -0.9097, -1.0703, -0.8168],
          [-0.0917,  0.0114, -0.2991, -0.3230,  0.6103],
          [ 0.0865,  0.0287,  0.0781,  0.3111,  0.9103],
          [ 0.0553, -0.7562, -0.8900, -0.9722, -0.9644],
          [ 0.0107, -0.6425, -0.8105, -0.5706, -1.1133]]],


        [[[-0.0280, -0.6446, -1.0783, -0.0510, -0.1909],
          [-1.2921, -1.5182, -0.0516,  0.5699, -0.5167],
          [-

### Load Dataset and get results

In [3]:
test_data = datasets.MNIST(root='data', train=False, transform=ToTensor(),)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=10, shuffle=True, num_workers=1)
samples, labels = next(iter(test_loader))

cnn.eval()
with torch.no_grad():
    results = cnn(samples)

### Prepare Data for PySR
Extract the 5x5 submatrices (incl. padding) from images to use them as input

In [4]:
# Take every image and split it into 5x5 submatrices => np.array.shape = (7840, 25)
# 25 <- flattened 5x5 patch
# 7840 <- 28 * 28 patches per image * 10 images (batch_size)
kernel_size = 5
X = None
for x in samples:
    x = torch.nn.functional.pad(input=x[0], pad=(2, 2, 2, 2), mode="constant", value=0)
    for i, j in np.ndindex((x.size()[0] - kernel_size + 1, x.size()[1] - kernel_size + 1)):
        slice = x[i:i + kernel_size, j:j + kernel_size]
        if X is None:
            X = np.array([slice.numpy().flatten()])
        else:
            X = np.concatenate((X, [slice.numpy().flatten()]))

print(X.shape)

# Get the result for every 5x5 submatrix for each kernel => np.array.shape = (16, 7840)
# 16 <- amount of kernels in the first layer
y = results['relu1'].numpy().transpose(1, 0, 2, 3).reshape(16, 7840)
print(y.shape)

(7840, 25)
(16, 7840)


### Symbolic Regression
#### Over all 16 Kernels

In [None]:
regr_functions = pd.DataFrame()
regr_functions.index.names = ['complexity']
for i in range(16):
    regr = PySRRegressor(
        niterations=40,
        binary_operators=["+", "*", "-", "/"],
        unary_operators=[
            "cos",
            "exp",
            "sin",
            "square",
            "cube",
            "inv(x) = 1/x",  # Julia syntax
        ],
        extra_sympy_mappings={"inv": lambda x: 1 / x},  # Sympy syntax
        elementwise_loss="loss(prediction, target) = (prediction - target)^2",  # Julia syntax
        warm_start=False,
        verbosity=0,
        temp_equation_file=True,
    )

    regr.fit(X, y[i])
    # print(regr.equations_)
    regr_functions.insert(loc=i, column=f'Kernel {i}', value=regr.equations_['equation'])
    print(regr_functions)

In [None]:
regr_functions.to_csv('regression_conv1_relu1.csv')

#### Over the first kernel multiple times

In [5]:
regr_stability = pd.DataFrame()
regr_stability.index.names = ['complexity']
for i in range(10):
    regr = PySRRegressor(
        niterations=40,
        binary_operators=["+", "*", "-", "/"],
        unary_operators=[
            "cos",
            "exp",
            "sin",
            "square",
            "cube",
            "inv(x) = 1/x",  # Julia syntax
        ],
        extra_sympy_mappings={"inv": lambda x: 1 / x},  # Sympy syntax
        elementwise_loss="loss(prediction, target) = (prediction - target)^2",  # Julia syntax
        warm_start=False,
        verbosity=0,
        temp_equation_file=True,
    )

    regr.fit(X, y[0])
    # print(regr.equations_)
    regr_stability.insert(loc=i, column=f'Iteration {i}', value=regr.equations_['equation'])
    print(regr_stability)



                                                  Iteration 0
complexity                                                   
0                                                          x2
1                                           cube(0.102201656)
2                                            0.00848108 * x14
3                                   (x14 - x19) * 0.011500353
4                                   cube(x13 / exp(exp(x17)))
5                                cube(sin(x9) / exp(exp(x4)))
6                               cube(x9 / exp(x4 + exp(x23)))
7                       square(x9 / exp(cube(x4 + exp(x20))))
8                       cube(x9 / cube((x17 + exp(x4)) + x1))
9                  cube(sin(x9 / cube(x17 + (x4 + exp(x1)))))
10                cube(x9 / cube(x17 + (exp(x4 + x2) + x24)))
11          square(x9 / square(cube(((exp(x2) + x4) + x24)...
12          square(sin(x9 / cube((x17 + square(exp(x1) + x...
13          sin(square(x9) / cube(cube(((x4 + cube(exp(x2)...
14      



                                                  Iteration 0  \
complexity                                                      
0                                                          x2   
1                                           cube(0.102201656)   
2                                            0.00848108 * x14   
3                                   (x14 - x19) * 0.011500353   
4                                   cube(x13 / exp(exp(x17)))   
5                                cube(sin(x9) / exp(exp(x4)))   
6                               cube(x9 / exp(x4 + exp(x23)))   
7                       square(x9 / exp(cube(x4 + exp(x20))))   
8                       cube(x9 / cube((x17 + exp(x4)) + x1))   
9                  cube(sin(x9 / cube(x17 + (x4 + exp(x1)))))   
10                cube(x9 / cube(x17 + (exp(x4 + x2) + x24)))   
11          square(x9 / square(cube(((exp(x2) + x4) + x24)...   
12          square(sin(x9 / cube((x17 + square(exp(x1) + x...   
13          sin(square(x9



                                                  Iteration 0  \
complexity                                                      
0                                                          x2   
1                                           cube(0.102201656)   
2                                            0.00848108 * x14   
3                                   (x14 - x19) * 0.011500353   
4                                   cube(x13 / exp(exp(x17)))   
5                                cube(sin(x9) / exp(exp(x4)))   
6                               cube(x9 / exp(x4 + exp(x23)))   
7                       square(x9 / exp(cube(x4 + exp(x20))))   
8                       cube(x9 / cube((x17 + exp(x4)) + x1))   
9                  cube(sin(x9 / cube(x17 + (x4 + exp(x1)))))   
10                cube(x9 / cube(x17 + (exp(x4 + x2) + x24)))   
11          square(x9 / square(cube(((exp(x2) + x4) + x24)...   
12          square(sin(x9 / cube((x17 + square(exp(x1) + x...   
13          sin(square(x9



                                                  Iteration 0  \
complexity                                                      
0                                                          x2   
1                                           cube(0.102201656)   
2                                            0.00848108 * x14   
3                                   (x14 - x19) * 0.011500353   
4                                   cube(x13 / exp(exp(x17)))   
5                                cube(sin(x9) / exp(exp(x4)))   
6                               cube(x9 / exp(x4 + exp(x23)))   
7                       square(x9 / exp(cube(x4 + exp(x20))))   
8                       cube(x9 / cube((x17 + exp(x4)) + x1))   
9                  cube(sin(x9 / cube(x17 + (x4 + exp(x1)))))   
10                cube(x9 / cube(x17 + (exp(x4 + x2) + x24)))   
11          square(x9 / square(cube(((exp(x2) + x4) + x24)...   
12          square(sin(x9 / cube((x17 + square(exp(x1) + x...   
13          sin(square(x9



                                                  Iteration 0  \
complexity                                                      
0                                                          x2   
1                                           cube(0.102201656)   
2                                            0.00848108 * x14   
3                                   (x14 - x19) * 0.011500353   
4                                   cube(x13 / exp(exp(x17)))   
5                                cube(sin(x9) / exp(exp(x4)))   
6                               cube(x9 / exp(x4 + exp(x23)))   
7                       square(x9 / exp(cube(x4 + exp(x20))))   
8                       cube(x9 / cube((x17 + exp(x4)) + x1))   
9                  cube(sin(x9 / cube(x17 + (x4 + exp(x1)))))   
10                cube(x9 / cube(x17 + (exp(x4 + x2) + x24)))   
11          square(x9 / square(cube(((exp(x2) + x4) + x24)...   
12          square(sin(x9 / cube((x17 + square(exp(x1) + x...   
13          sin(square(x9



                                                  Iteration 0  \
complexity                                                      
0                                                          x2   
1                                           cube(0.102201656)   
2                                            0.00848108 * x14   
3                                   (x14 - x19) * 0.011500353   
4                                   cube(x13 / exp(exp(x17)))   
5                                cube(sin(x9) / exp(exp(x4)))   
6                               cube(x9 / exp(x4 + exp(x23)))   
7                       square(x9 / exp(cube(x4 + exp(x20))))   
8                       cube(x9 / cube((x17 + exp(x4)) + x1))   
9                  cube(sin(x9 / cube(x17 + (x4 + exp(x1)))))   
10                cube(x9 / cube(x17 + (exp(x4 + x2) + x24)))   
11          square(x9 / square(cube(((exp(x2) + x4) + x24)...   
12          square(sin(x9 / cube((x17 + square(exp(x1) + x...   
13          sin(square(x9



                                                  Iteration 0  \
complexity                                                      
0                                                          x2   
1                                           cube(0.102201656)   
2                                            0.00848108 * x14   
3                                   (x14 - x19) * 0.011500353   
4                                   cube(x13 / exp(exp(x17)))   
5                                cube(sin(x9) / exp(exp(x4)))   
6                               cube(x9 / exp(x4 + exp(x23)))   
7                       square(x9 / exp(cube(x4 + exp(x20))))   
8                       cube(x9 / cube((x17 + exp(x4)) + x1))   
9                  cube(sin(x9 / cube(x17 + (x4 + exp(x1)))))   
10                cube(x9 / cube(x17 + (exp(x4 + x2) + x24)))   
11          square(x9 / square(cube(((exp(x2) + x4) + x24)...   
12          square(sin(x9 / cube((x17 + square(exp(x1) + x...   
13          sin(square(x9



                                                  Iteration 0  \
complexity                                                      
0                                                          x2   
1                                           cube(0.102201656)   
2                                            0.00848108 * x14   
3                                   (x14 - x19) * 0.011500353   
4                                   cube(x13 / exp(exp(x17)))   
5                                cube(sin(x9) / exp(exp(x4)))   
6                               cube(x9 / exp(x4 + exp(x23)))   
7                       square(x9 / exp(cube(x4 + exp(x20))))   
8                       cube(x9 / cube((x17 + exp(x4)) + x1))   
9                  cube(sin(x9 / cube(x17 + (x4 + exp(x1)))))   
10                cube(x9 / cube(x17 + (exp(x4 + x2) + x24)))   
11          square(x9 / square(cube(((exp(x2) + x4) + x24)...   
12          square(sin(x9 / cube((x17 + square(exp(x1) + x...   
13          sin(square(x9



                                                  Iteration 0  \
complexity                                                      
0                                                          x2   
1                                           cube(0.102201656)   
2                                            0.00848108 * x14   
3                                   (x14 - x19) * 0.011500353   
4                                   cube(x13 / exp(exp(x17)))   
5                                cube(sin(x9) / exp(exp(x4)))   
6                               cube(x9 / exp(x4 + exp(x23)))   
7                       square(x9 / exp(cube(x4 + exp(x20))))   
8                       cube(x9 / cube((x17 + exp(x4)) + x1))   
9                  cube(sin(x9 / cube(x17 + (x4 + exp(x1)))))   
10                cube(x9 / cube(x17 + (exp(x4 + x2) + x24)))   
11          square(x9 / square(cube(((exp(x2) + x4) + x24)...   
12          square(sin(x9 / cube((x17 + square(exp(x1) + x...   
13          sin(square(x9



                                                  Iteration 0  \
complexity                                                      
0                                                          x2   
1                                           cube(0.102201656)   
2                                            0.00848108 * x14   
3                                   (x14 - x19) * 0.011500353   
4                                   cube(x13 / exp(exp(x17)))   
5                                cube(sin(x9) / exp(exp(x4)))   
6                               cube(x9 / exp(x4 + exp(x23)))   
7                       square(x9 / exp(cube(x4 + exp(x20))))   
8                       cube(x9 / cube((x17 + exp(x4)) + x1))   
9                  cube(sin(x9 / cube(x17 + (x4 + exp(x1)))))   
10                cube(x9 / cube(x17 + (exp(x4 + x2) + x24)))   
11          square(x9 / square(cube(((exp(x2) + x4) + x24)...   
12          square(sin(x9 / cube((x17 + square(exp(x1) + x...   
13          sin(square(x9

In [6]:
regr_stability.to_csv('stability_conv1_relu1_kernel1.csv')