In [None]:
import torch
from biotorch.models.weight_mirroring.resnet import resnet18, resnet34, resnet50, wide_resnet101_2
from biotorch.training.metrics import compute_angles_module, compute_matrix_angle

In [None]:
# Test the whole model: repeatedly call `mirror_weights` on a fixed random batch
# and periodically report the per-layer alignment returned by
# `compute_angles_module`.
N_MIRROR_STEPS = 10000
PRINT_EVERY = 100
DEVICE = 'cpu'  # set to 'cuda' to run on a GPU

torch.manual_seed(0)  # reproducible model initialisation and input batch
model = resnet18(pretrained=False, num_classes=10).to(DEVICE)

# torch.nn.Module has no `.device` attribute, so the device is chosen explicitly
# rather than read from `model.device`.
x = torch.randn((10, 3, 224, 224)).to(DEVICE)

for i in range(N_MIRROR_STEPS):
    model.mirror_weights(x,
                         mirror_learning_rate=0.01,
                         noise_amplitude=0.1,
                         growth_control=False,
                         damping_factor=0.5)
    # Only compute the alignment when it is reported; computing it on every
    # step and discarding it is wasted work.
    if i % PRINT_EVERY == 0:
        print(compute_angles_module(model))

# Alignment after the final mirroring step (same value the original loop left
# in `layers_alignment`).
layers_alignment = compute_angles_module(model)

In [None]:
# Display the per-layer alignment computed after the last mirroring step of the
# whole-model test (requires that cell to have been run first).
layers_alignment

In [None]:
# Test 1 convolution layer: feed scaled random noise through the model's `conv1`
# and update its backward weights with `update_B`, periodically reporting the
# angle between the forward and backward weights.
N_MIRROR_STEPS = 10000
PRINT_EVERY = 100
noise_amplitude = 0.1

torch.manual_seed(0)  # reproducible model initialisation and noise
model = resnet18(pretrained=False, num_classes=10)
conv_layer = model.conv1

# Mirroring is a local update rule; no autograd graph is needed.
with torch.no_grad():
    for i in range(N_MIRROR_STEPS):
        # Noise batch shaped (10, 3, 224, 224), matching the model's RGB input.
        input_noise = noise_amplitude * torch.randn((10, 3, 224, 224))
        output_noise = conv_layer(input_noise)
        conv_layer.update_B(input_noise,
                            output_noise,
                            mirror_learning_rate=0.01,
                            growth_control=False,
                            damping_factor=1.0)
        # Report periodically (and on the last step) instead of 10000 lines.
        if i % PRINT_EVERY == 0 or i == N_MIRROR_STEPS - 1:
            print(compute_matrix_angle(conv_layer.weight, conv_layer.weight_backward))

In [None]:
# Test 1 FC layer: feed scaled random noise through the model's final `fc`
# layer and update its backward weights with `update_B`, periodically reporting
# the angle between the forward and backward weights.
N_MIRROR_STEPS = 10000
PRINT_EVERY = 100
noise_amplitude = 0.1

torch.manual_seed(0)  # reproducible model initialisation and noise
model = resnet18(pretrained=False, num_classes=10)
linear_layer = model.fc

# Mirroring is a local update rule; no autograd graph is needed.
with torch.no_grad():
    for i in range(N_MIRROR_STEPS):
        # Batch of 12 noise vectors with 512 features (assumed to be the fc
        # layer's input width for resnet18 — TODO confirm if the model changes).
        input_noise = noise_amplitude * torch.randn(12, 512)
        output_noise = linear_layer(input_noise)
        linear_layer.update_B(input_noise,
                              output_noise,
                              mirror_learning_rate=0.01,
                              growth_control=False,
                              damping_factor=1.0)
        # Report periodically (and on the last step) instead of 10000 lines.
        if i % PRINT_EVERY == 0 or i == N_MIRROR_STEPS - 1:
            print(compute_matrix_angle(linear_layer.weight, linear_layer.weight_backward))