# Playground script to test functions and inspect variables

## Logical Data Generator

In [27]:
from lja.data_generators.logical_data_generator import LogicalDataGenerator

In [28]:
data = LogicalDataGenerator(size=10000)

In [30]:
data.data.type()

'torch.FloatTensor'

In [10]:
data.label

tensor([0, 0, 1,  ..., 1, 1, 1])

## Trained Logical Model

In [11]:
from lja.managers.training_manager import LogicalNetworkTrainingManager
import torch

In [57]:
training_exp = LogicalNetworkTrainingManager()
model = training_exp.net

[Loading configurations]
device : auto
results_dir : results/
data_dir : data/
model_save_interval : 10
print_log_interval : 100
num_models_to_keep : 2
networks : {'general': {'optimizer': 'sgd', 'lr': 0.005, 'momentum': 0.9, 'dampening': 0.1, 'weight_decay': 0.001, 'training_results_dir': 'results/network_training/'}, 'logical': {'load_model_name': 'model_logical', 'num_epochs': 100, 'batch_size': 16, 'sizes': [4, 128, 64, 1]}, 'mnist': {'load_model_name': None, 'num_epochs': 1000, 'batch_size': 32, 'sizes': [784, 1024, 1024, 512, 10]}}
Loading model from results/network_training/logical/model_logical
Loaded model from results/network_training/logical/model_logical.


In [58]:
model.act(torch.tensor([-1,0,0.5,1]))

tensor([0.0000, 0.0000, 0.5000, 1.0000])

In [59]:
a = torch.tensor([[0,0,0,0],
                 [0,0,0,1],
                 [0,1,0,0],
                 [0,1,0,1],
                 [1,1,0,1],
                 [1,1,1,1]], dtype=torch.float)
model.forward(a)

tensor([0.0035, 0.9818, 0.9826, 0.0091, 0.9836, 0.0060],
       grad_fn=<SqueezeBackward0>)

In [61]:
model

NLayerPerceptron(
  (nets): ModuleList(
    (0): Linear(in_features=4, out_features=128, bias=True)
    (1): Linear(in_features=128, out_features=64, bias=True)
    (2): Linear(in_features=64, out_features=1, bias=True)
  )
  (act): ReLU(inplace=True)
  (last_act): Sigmoid()
)

In [56]:
for name, param in training_exp.net.named_parameters():
    print('')
    print(name)
    print(param)


nets.0.weight
Parameter containing:
tensor([[ 4.2409e-02,  6.8853e-02, -1.0490e-03,  8.0594e-02],
        [-4.8075e-02,  4.8030e-02, -4.8034e-02,  4.8017e-02],
        [ 4.0556e-01, -4.0587e-01,  4.0567e-01,  4.0549e-01],
        [ 2.2055e-01,  2.2056e-01,  2.2067e-01, -2.2066e-01],
        [ 5.3728e-02,  3.0757e-02,  3.2658e-03,  4.2683e-02],
        [-1.7807e-02, -4.5326e-03, -1.5698e-02,  2.5639e-02],
        [-3.0509e-02, -2.3602e-02, -2.3606e-02, -2.5987e-02],
        [ 4.4490e-01, -4.4485e-01,  4.4487e-01, -4.4502e-01],
        [ 3.6469e-02,  2.8512e-03, -1.6449e-02,  2.5637e-02],
        [ 1.7520e-02,  4.5237e-02,  1.6359e-02,  2.6000e-03],
        [ 6.8966e-02, -6.8870e-02,  6.9059e-02,  6.9022e-02],
        [ 4.6023e-02, -6.3105e-03,  5.2140e-02,  1.7301e-02],
        [ 2.9156e-01, -2.9145e-01, -2.9130e-01, -2.9139e-01],
        [-1.6460e-01,  1.6452e-01,  1.6450e-01,  1.6454e-01],
        [ 4.1319e-02, -4.1316e-02,  4.1302e-02, -4.1289e-02],
        [ 2.0497e-01, -2.0494e-01

## Extract Jacobian

In [148]:
def get_jacobian_stats(jacobian, model, x):
    
    #print(layer.weight.shape)
    #print(layer.bias.shape)
    #print(jacobian.shape)
    
    num_matches = torch.isclose(
        layer.weight, jacobian, atol=1e-2
    ).sum()
    frac_match = num_matches / sum(p.numel() for p in layer.parameters())
    print("Fraction matched in jacobian and weight mat: %f" % frac_match.item())

    y = layer.forward(x)
    y_prime = jacobian @ x + layer.bias
    diff = torch.abs(y - y_prime)
    print("Diff between y and y_prime: %f" % diff.sum().item())
    

In [155]:
# 1. Create input and array
x = training_exp.train_dataset.data[3]


# 2. Extract specific Layer 
layer, act = model.get_layer_and_act(0)

# 4. Define network function
net_func = lambda inp: act(layer.forward(inp))

# 3. Calculate jacobian
jacobian = torch.autograd.functional.jacobian(net_func, x, create_graph=True)

In [156]:
get_jacobian_stats(jacobian, model, x)

Fraction matched in jacobian and weight mat: 0.379687
Diff between y and y_prime: 8.138474


In [157]:
x

tensor([0., 1., 0., 1.])