# O(2)-Invariant Neural Network Example: Learning the Dot Product

We show how an $O(2)$-invariant neural network can learn the dot product of any two vectors in $\mathbb{R}^{2}$.

Since the dot product is bilinear, we can model it as a linear function $(\mathbb{R}^2)^{\otimes 2} \rightarrow \mathbb{R}$ applied to the tensor product of the two vectors.
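To make this concrete: for $x, y \in \mathbb{R}^2$, the tensor product in the standard basis is $x \otimes y = (x_1 y_1, x_1 y_2, x_2 y_1, x_2 y_2)$, and the dot product is the linear functional with weights $(1, 0, 0, 1)$ applied to it. A minimal NumPy sketch (the vectors `x`, `y` and the name `w` are illustrative and not part of the notebook):

```python
import numpy as np

# Illustrative vectors (hypothetical values, chosen only for this check).
x = np.array([1.0, 2.0])
y = np.array([3.0, -1.0])

# Tensor product in the standard basis: (x1*y1, x1*y2, x2*y1, x2*y2).
xy = np.kron(x, y)

# The dot product is the linear functional with weights (1, 0, 0, 1).
w = np.array([1.0, 0.0, 0.0, 1.0])

print(w @ xy, x @ y)
```

Both printed values agree, which is why a single learnable weight row on the tensor product is enough to represent the dot product.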

We generate synthetic training data and add scaled Gaussian noise to the labels.

The neural network architecture that we use comes from the paper "Brauer's Group Equivariant Neural Networks", which can be found at https://arxiv.org/abs/2212.08630.

In [1]:
import torch
import numpy as np
from torch import nn
from torch import optim
from torch.utils.data import TensorDataset, DataLoader

In [2]:
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Using a {device} device for training.")

Using a mps device for training.


### Part I: Define the $O(2)$-Equivariant Layers

We first define a learnable, linear, $O(2)$-equivariant layer of the form $(\mathbb{R}^2)^{\otimes 2} \rightarrow (\mathbb{R}^2)^{\otimes 2}$.

In [3]:
class O222Layer(torch.nn.Module):
    r"""
     - A learnable, linear, O(2)-equivariant layer of the form (R^2)^{\otimes 2} -> (R^2)^{\otimes 2}
    """
    def __init__(self):
        super().__init__()
        self.lambda1 = torch.nn.Parameter(torch.randn(()))
        self.lambda2 = torch.nn.Parameter(torch.randn(()))
        self.lambda3 = torch.nn.Parameter(torch.randn(()))
        self.dim = 2
        
    def forward(self, X):
        weight = torch.zeros((self.dim*self.dim,self.dim*self.dim)).to(device)

        #lambda_1
        for i in range(self.dim):
            for j in range(self.dim):
                for k in range(self.dim):
                    for l in range(self.dim):
                        if i == j and k == l:
                            weight[self.dim*i + j][self.dim*k + l] += self.lambda1

        #lambda_2
        for i in range(self.dim):
            for j in range(self.dim):
                for k in range(self.dim):
                    for l in range(self.dim):
                        if i == k and j == l:
                            weight[self.dim*i + j][self.dim*k + l] += self.lambda2

        #lambda_3
        for i in range(self.dim):
            for j in range(self.dim):
                for k in range(self.dim):
                    for l in range(self.dim):
                        if i == l and j == k:
                            weight[self.dim*i + j][self.dim*k + l] += self.lambda3 
        
        linear = torch.einsum('ij,kj->ki', weight, X)    # allows for batch processing
        return linear

We check that the weight matrix will be of the correct form, by setting $\lambda_1 = 1$, $\lambda_2 = 2$ and $\lambda_3 = 3$.

In [4]:
# A check that the weight matrix will be of the correct form

dim = 2
weight = torch.zeros((dim*dim,dim*dim)).to(device)

#lambda_1 = 1
for i in range(dim):
    for j in range(dim):
        for k in range(dim):
            for l in range(dim):
                if i == j and k == l:
                    weight[dim*i + j][dim*k + l] += 1
print(weight)

#lambda_2 = 2
for i in range(dim):
    for j in range(dim):
        for k in range(dim):
            for l in range(dim):
                if i == k and j == l:
                    weight[dim*i + j][dim*k + l] += 2
print(weight)

#lambda_3 = 3
for i in range(dim):
    for j in range(dim):
        for k in range(dim):
            for l in range(dim):
                if i == l and j == k:
                    weight[dim*i + j][dim*k + l] += 3            
print(weight)

tensor([[1., 0., 0., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [1., 0., 0., 1.]], device='mps:0')
tensor([[3., 0., 0., 1.],
        [0., 2., 0., 0.],
        [0., 0., 2., 0.],
        [1., 0., 0., 3.]], device='mps:0')
tensor([[6., 0., 0., 1.],
        [0., 2., 3., 0.],
        [0., 3., 2., 0.],
        [1., 0., 0., 6.]], device='mps:0')
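
The same matrix can be assembled without loops, and its equivariance checked numerically. The sketch below is an illustration rather than the notebook's implementation: it builds the three basis matrices with `np.einsum` (the names `D1`, `D2`, `D3`, `W` and the angle `theta` are assumptions made for this example) and confirms that $W$ commutes with $Q \otimes Q$ for a rotation and a reflection $Q \in O(2)$.

```python
import numpy as np

dim = 2
I = np.eye(dim)

# Basis matrices, with rows indexed by dim*i + j and columns by dim*k + l.
D1 = np.einsum('ij,kl->ijkl', I, I).reshape(dim * dim, dim * dim)  # i == j and k == l
D2 = np.einsum('ik,jl->ijkl', I, I).reshape(dim * dim, dim * dim)  # i == k and j == l
D3 = np.einsum('il,jk->ijkl', I, I).reshape(dim * dim, dim * dim)  # i == l and j == k

# Same coefficients as the check above: lambda_1 = 1, lambda_2 = 2, lambda_3 = 3.
W = 1 * D1 + 2 * D2 + 3 * D3
print(W)

theta = 0.7  # arbitrary angle, chosen for illustration
rotation = np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta),  np.cos(theta)]])
reflection = np.array([[1.0, 0.0],
                       [0.0, -1.0]])

for Q in (rotation, reflection):
    QQ = np.kron(Q, Q)  # action of Q on the tensor square of R^2
    print(np.allclose(W @ QQ, QQ @ W))
```

Precomputing the basis matrices once, instead of looping in every forward pass, is a design choice that would also carry over to the PyTorch layer.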




We now define a learnable, linear, $O(2)$-invariant layer of the form $(\mathbb{R}^2)^{\otimes 2} \rightarrow \mathbb{R}$.

In [5]:
class O220Layer(torch.nn.Module):
    r"""
     - A learnable, linear, O(2)-invariant layer of the form (R^2)^{\otimes 2} -> R
    """
    def __init__(self):
        super().__init__()
        self.lambda1 = torch.nn.Parameter(torch.randn(()))
        self.dim = 2
        
    def forward(self, X):
        weight = torch.zeros((1,self.dim*self.dim)).to(device)
        
        for k in range(self.dim):
            for l in range(self.dim):
                if k == l:
                    weight[0][self.dim*k + l] += self.lambda1
        
        linear = torch.einsum('ij,kj->ki', weight, X)    # allows for batch processing
        return linear

Again, we check that the weight matrix will be of the correct form, by setting $\lambda_1 = 5$.

In [6]:
# A check that the weight matrix will be of the correct form

dim = 2
weight2 = torch.zeros((1,dim*dim))

#lambda_1 = 5
for k in range(dim):
    for l in range(dim):
        if k == l:
            weight2[0][dim*k + l] += 5
print(weight2)

tensor([[5., 0., 0., 5.]])
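
Invariance of this layer can be checked numerically: applying any $Q \in O(2)$ to both vectors should leave the output unchanged. A minimal NumPy sketch, with illustrative vectors and angle (assumptions for this example only):

```python
import numpy as np

w = np.array([5.0, 0.0, 0.0, 5.0])  # weight row from the check above

# Illustrative vectors (hypothetical values).
x = np.array([0.3, -1.2])
y = np.array([2.0, 1.0])

theta = 1.1  # arbitrary angle, chosen for illustration
rotation = np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta),  np.cos(theta)]])
reflection = np.array([[0.0, 1.0],
                       [1.0, 0.0]])

before = w @ np.kron(x, y)
for Q in (rotation, reflection):
    after = w @ np.kron(Q @ x, Q @ y)
    print(np.isclose(before, after))
```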


We can now build a neural network composed of these $O(2)$-equivariant linear layers and pointwise non-linearities, for example:

In [7]:
class Orthogonal2NeuralNetwork(nn.Module):
    def __init__(self):
        super(Orthogonal2NeuralNetwork, self).__init__()
        self.layer1 = O222Layer()
        self.layer2 = O222Layer()
        self.layer3 = O220Layer()
        self.relu = torch.nn.ReLU()
        
    def forward(self, X):
        X = self.layer1(X)
        X = self.relu(X)
        X = self.layer2(X)
        X = self.relu(X)
        X = self.layer3(X)
        X = self.relu(X)
        return X
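
One point worth checking numerically is how a pointwise ReLU interacts with the $O(2)$ action in the standard basis. The sketch below (NumPy only; the vector `v` and the rotation `Q` are illustrative choices, not taken from the notebook) compares applying ReLU before and after a $90^\circ$ rotation acting on $(\mathbb{R}^2)^{\otimes 2}$. For this particular `v` the two orders disagree, so exact invariance of the full network should be verified rather than assumed.

```python
import numpy as np

def relu(v):
    """Pointwise ReLU on a NumPy array."""
    return np.maximum(v, 0.0)

# Rotation by 90 degrees: e1 -> e2, e2 -> -e1.
Q = np.array([[0.0, -1.0],
              [1.0,  0.0]])
QQ = np.kron(Q, Q)  # action of Q on the tensor square of R^2

# v = e1 tensor (e1 - e2), written in the standard basis.
v = np.kron(np.array([1.0, 0.0]), np.array([1.0, -1.0]))

relu_after_action = relu(QQ @ v)
action_after_relu = QQ @ relu(v)

print(relu_after_action)
print(action_after_relu)
print(np.allclose(relu_after_action, action_after_relu))
```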

### Part II: Training Data Generation

In this section, we generate synthetic data. We first create pairs of vectors in $\mathbb{R}^{2}$ (expressed in the standard basis of $\mathbb{R}^{2}$) and calculate their tensor products. These form the input data for the models.

In [8]:
#Data Generation
np.random.seed(42)
N = 6400
dim = 2

X_1 = np.random.randn(N, dim)
print(X_1)

X_2 = np.random.randn(N, dim)
print(X_2)

X = np.zeros((N, dim*dim))
print(X)

[[ 0.49671415 -0.1382643 ]
 [ 0.64768854  1.52302986]
 [-0.23415337 -0.23413696]
 ...
 [ 2.50323781  0.02447967]
 [-0.96266853 -0.63363987]
 [-2.25672511  1.60701901]]
[[ 1.59476199 -1.01550744]
 [-0.22788429  0.09560754]
 [-0.96035305 -1.75510203]
 ...
 [ 1.36867699  0.36583056]
 [-2.40627814  0.88349343]
 [-1.27230501  0.85917276]]
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 ...
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]


In [9]:
for i in range(0,N):
    X[i] = np.kron(X_1[i], X_2[i])
print(X)

[[ 0.79214085 -0.50441692 -0.22049865  0.14040843]
 [-0.14759804  0.06192391 -0.34707458  0.14561314]
 [ 0.22486991  0.41096306  0.22485414  0.41093425]
 ...
 [ 3.42612399  0.91576088  0.03350476  0.00895541]
 [ 2.31644825 -0.85051132  1.52471376 -0.55981666]
 [ 2.87124267 -1.93891674 -2.04461834  1.38070696]]
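
The row-wise Kronecker product can also be computed in one vectorised call. A sketch using `np.einsum`, checked against the explicit loop on a small random batch (the names `A`, `B`, the batch size and the seed are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)  # illustrative seed
n, dim = 5, 2
A = rng.standard_normal((n, dim))
B = rng.standard_normal((n, dim))

# Loop version, as in the cell above.
X_loop = np.zeros((n, dim * dim))
for i in range(n):
    X_loop[i] = np.kron(A[i], B[i])

# Vectorised version: one outer product per row, flattened to length dim*dim.
X_vec = np.einsum('ni,nj->nij', A, B).reshape(n, dim * dim)

print(np.allclose(X_loop, X_vec))
```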


We now calculate the dot products of the vectors, and add some scaled Gaussian noise. These will form the labels for the input data.

In [10]:
dotprod = np.sum(X_1*X_2, axis=1)
dotprod = np.reshape(dotprod, (N,1))
print(dotprod)

[[ 9.32549277e-01]
 [-1.98489742e-03]
 [ 6.35804157e-01]
 ...
 [ 3.43507940e+00]
 [ 1.75663159e+00]
 [ 4.25194962e+00]]


In [11]:
epsilon = 0.1 * np.random.randn(N, 1)
print(epsilon)

[[0.0901518 ]
 [0.03298926]
 [0.05436809]
 ...
 [0.07900326]
 [0.04654331]
 [0.06654187]]


In [12]:
## learn dot product (with noise)
Y = dotprod + epsilon
print(Y)

[[1.02270108]
 [0.03100436]
 [0.69017225]
 ...
 [3.51408266]
 [1.8031749 ]
 [4.31849149]]
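
Because the noise has standard deviation $0.1$, even a model that recovers the dot product exactly will have a mean squared error of about $0.1^2 = 0.01$ against these noisy labels. A quick sketch of that floor (it draws its own noise with an illustrative generator and seed, not the notebook's `epsilon`):

```python
import numpy as np

rng = np.random.default_rng(0)  # illustrative seed
n = 6400
noise = 0.1 * rng.standard_normal((n, 1))

# MSE of a perfect predictor against labels of the form dotprod + noise.
mse_floor = np.mean(noise ** 2)
print(mse_floor)
```

Training losses close to this value therefore indicate that a model has fitted the underlying dot product about as well as the noisy labels allow.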


Finally, we create tensors that live on the device, ready to be used for training our models.

In [13]:
X_tensor = torch.tensor(X, dtype = torch.float, requires_grad = True, device=device)
print(X_tensor)

tensor([[ 0.7921, -0.5044, -0.2205,  0.1404],
        [-0.1476,  0.0619, -0.3471,  0.1456],
        [ 0.2249,  0.4110,  0.2249,  0.4109],
        ...,
        [ 3.4261,  0.9158,  0.0335,  0.0090],
        [ 2.3164, -0.8505,  1.5247, -0.5598],
        [ 2.8712, -1.9389, -2.0446,  1.3807]], device='mps:0',
       requires_grad=True)


In [14]:
Y_tensor = torch.tensor(Y, dtype = torch.float, requires_grad = True, device=device)
print(Y_tensor)

tensor([[1.0227],
        [0.0310],
        [0.6902],
        ...,
        [3.5141],
        [1.8032],
        [4.3185]], device='mps:0', requires_grad=True)


### Part III: Model Training

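We create three models and train each of them on the training data: a standard linear layer, a single $O(2)$-invariant linear layer, and the three-layer network defined in Part I. We then evaluate them on test data to see how well they generalise to unseen data.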

In [15]:
model_linear = nn.Linear(4,1).to(device)
print(list(model_linear.parameters()))

[Parameter containing:
tensor([[-0.4635, -0.4597,  0.2921, -0.2624]], device='mps:0',
       requires_grad=True), Parameter containing:
tensor([-0.2768], device='mps:0', requires_grad=True)]


In [16]:
model_Orth_layer = O220Layer()
print(list(model_Orth_layer.parameters()))

[Parameter containing:
tensor(-0.8172, requires_grad=True)]


In [17]:
model_Orth_Network = Orthogonal2NeuralNetwork()
print(list(model_Orth_Network.parameters()))

[Parameter containing:
tensor(0.7488, requires_grad=True), Parameter containing:
tensor(-0.0502, requires_grad=True), Parameter containing:
tensor(0.1457, requires_grad=True), Parameter containing:
tensor(1.5516, requires_grad=True), Parameter containing:
tensor(0.1126, requires_grad=True), Parameter containing:
tensor(-0.3866, requires_grad=True), Parameter containing:
tensor(1.8287, requires_grad=True)]


We set the hyperparameters, create the optimizers, and build a data loader that iterates over the training data in mini-batches.

In [18]:
lr = 0.01
batch_size = 64
num_of_iterations = N // batch_size

torch.manual_seed(42)
loss_fn = nn.MSELoss(reduction='mean')
n_epochs = 20

In [19]:
optimizer_linear = optim.Adam(model_linear.parameters(), lr=lr)
optimizer_Orth_layer = optim.Adam(model_Orth_layer.parameters(), lr=lr)
optimizer_Orth_Network = optim.Adam(model_Orth_Network.parameters(), lr=lr)

In [20]:
train_dataset = TensorDataset(X_tensor,Y_tensor)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) # create your dataloader

In [21]:
def train_model(model, dataloader, n_epochs, loss_fn, optimizer, device):
    for epoch in range(n_epochs):
        model.train()
        for batch, (X_tensor_slice, Y_tensor_slice) in enumerate(dataloader):
            #Y_res = torch.zeros((batch_size,1)).to(device)

            Y_res = model(X_tensor_slice)

            loss = loss_fn(Y_res, Y_tensor_slice)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Print the loss and the current parameters every 20th batch.
            if (batch + 1) % 20 == 0:
                print("Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Params: {}"
                      .format(epoch + 1, n_epochs, batch + 1, len(dataloader), loss.item(), model.state_dict()))

    return None

We now train the three models.

In [22]:
train_model(model_linear, train_dataloader, n_epochs, loss_fn, optimizer_linear, device)

Epoch [1/20], Step [20/100], Loss: 3.9633, Params: OrderedDict([('weight', tensor([[-0.2746, -0.3218,  0.1696, -0.0757]], device='mps:0')), ('bias', tensor([-0.1086], device='mps:0'))])
Epoch [1/20], Step [40/100], Loss: 1.8973, Params: OrderedDict([('weight', tensor([[-0.0887, -0.2045,  0.0705,  0.0987]], device='mps:0')), ('bias', tensor([-0.0459], device='mps:0'))])
Epoch [1/20], Step [60/100], Loss: 1.4411, Params: OrderedDict([('weight', tensor([[ 0.0845, -0.0925,  0.0435,  0.2601]], device='mps:0')), ('bias', tensor([0.0054], device='mps:0'))])
Epoch [1/20], Step [80/100], Loss: 1.2745, Params: OrderedDict([('weight', tensor([[ 0.2300, -0.0342,  0.0107,  0.4031]], device='mps:0')), ('bias', tensor([0.0316], device='mps:0'))])
Epoch [1/20], Step [100/100], Loss: 0.3693, Params: OrderedDict([('weight', tensor([[ 0.3591,  0.0038, -0.0267,  0.5328]], device='mps:0')), ('bias', tensor([0.0242], device='mps:0'))])
Epoch [2/20], Step [20/100], Loss: 0.3803, Params: OrderedDict([('weight

Epoch [9/20], Step [80/100], Loss: 0.0086, Params: OrderedDict([('weight', tensor([[ 1.0005e+00,  4.5521e-04, -2.9567e-03,  1.0011e+00]], device='mps:0')), ('bias', tensor([8.4954e-05], device='mps:0'))])
Epoch [9/20], Step [100/100], Loss: 0.0100, Params: OrderedDict([('weight', tensor([[ 0.9998,  0.0049, -0.0021,  1.0005]], device='mps:0')), ('bias', tensor([-0.0018], device='mps:0'))])
Epoch [10/20], Step [20/100], Loss: 0.0089, Params: OrderedDict([('weight', tensor([[ 9.9990e-01,  5.1948e-04, -1.5743e-03,  1.0004e+00]], device='mps:0')), ('bias', tensor([0.0108], device='mps:0'))])
Epoch [10/20], Step [40/100], Loss: 0.0131, Params: OrderedDict([('weight', tensor([[0.9987, 0.0050, 0.0010, 1.0010]], device='mps:0')), ('bias', tensor([-0.0070], device='mps:0'))])
Epoch [10/20], Step [60/100], Loss: 0.0094, Params: OrderedDict([('weight', tensor([[ 0.9988, -0.0035, -0.0047,  1.0014]], device='mps:0')), ('bias', tensor([-0.0057], device='mps:0'))])
Epoch [10/20], Step [80/100], Loss: 

Epoch [18/20], Step [40/100], Loss: 0.0136, Params: OrderedDict([('weight', tensor([[0.9983, 0.0013, 0.0053, 1.0018]], device='mps:0')), ('bias', tensor([0.0008], device='mps:0'))])
Epoch [18/20], Step [60/100], Loss: 0.0114, Params: OrderedDict([('weight', tensor([[ 9.9896e-01, -1.3201e-03,  4.6454e-04,  9.9966e-01]], device='mps:0')), ('bias', tensor([0.0042], device='mps:0'))])
Epoch [18/20], Step [80/100], Loss: 0.0099, Params: OrderedDict([('weight', tensor([[ 1.0017e+00,  9.2280e-04, -6.9563e-03,  1.0005e+00]], device='mps:0')), ('bias', tensor([-0.0142], device='mps:0'))])
Epoch [18/20], Step [100/100], Loss: 0.0107, Params: OrderedDict([('weight', tensor([[ 1.0014, -0.0010,  0.0041,  0.9995]], device='mps:0')), ('bias', tensor([0.0051], device='mps:0'))])
Epoch [19/20], Step [20/100], Loss: 0.0119, Params: OrderedDict([('weight', tensor([[ 1.0002, -0.0017, -0.0033,  0.9998]], device='mps:0')), ('bias', tensor([-0.0003], device='mps:0'))])
Epoch [19/20], Step [40/100], Loss: 0.0
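
As a cross-check on the learned weights, which settle near $(1, 0, 0, 1)$ with a bias near $0$, ordinary least squares on the same kind of data gives the minimiser of the squared error directly. The sketch below regenerates data in the same way as Part II but with its own illustrative generator and seed, so the numbers will not match the notebook's exactly:

```python
import numpy as np

rng = np.random.default_rng(0)  # illustrative seed
n, dim = 6400, 2
A = rng.standard_normal((n, dim))
B = rng.standard_normal((n, dim))

# Inputs: row-wise tensor products; labels: noisy dot products.
X = np.einsum('ni,nj->nij', A, B).reshape(n, dim * dim)
Y = np.sum(A * B, axis=1, keepdims=True) + 0.1 * rng.standard_normal((n, 1))

# Append a column of ones so the last coefficient plays the role of the bias.
X_aug = np.hstack([X, np.ones((n, 1))])
coef, *_ = np.linalg.lstsq(X_aug, Y, rcond=None)
print(coef.ravel())
```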

In [23]:
train_model(model_Orth_layer, train_dataloader, n_epochs, loss_fn, optimizer_Orth_layer, device)

Epoch [1/20], Step [20/100], Loss: 4.1122, Params: OrderedDict([('lambda1', tensor(-0.6268))])
Epoch [1/20], Step [40/100], Loss: 6.3064, Params: OrderedDict([('lambda1', tensor(-0.4482))])
Epoch [1/20], Step [60/100], Loss: 4.6415, Params: OrderedDict([('lambda1', tensor(-0.2722))])
Epoch [1/20], Step [80/100], Loss: 1.3939, Params: OrderedDict([('lambda1', tensor(-0.1125))])
Epoch [1/20], Step [100/100], Loss: 1.8529, Params: OrderedDict([('lambda1', tensor(0.0395))])
Epoch [2/20], Step [20/100], Loss: 1.9431, Params: OrderedDict([('lambda1', tensor(0.1812))])
Epoch [2/20], Step [40/100], Loss: 0.6626, Params: OrderedDict([('lambda1', tensor(0.3115))])
Epoch [2/20], Step [60/100], Loss: 0.7677, Params: OrderedDict([('lambda1', tensor(0.4217))])
Epoch [2/20], Step [80/100], Loss: 0.5793, Params: OrderedDict([('lambda1', tensor(0.5136))])
Epoch [2/20], Step [100/100], Loss: 0.4147, Params: OrderedDict([('lambda1', tensor(0.5974))])
Epoch [3/20], Step [20/100], Loss: 0.2439, Params: Ord

Epoch [18/20], Step [60/100], Loss: 0.0086, Params: OrderedDict([('lambda1', tensor(1.0001))])
Epoch [18/20], Step [80/100], Loss: 0.0145, Params: OrderedDict([('lambda1', tensor(0.9993))])
Epoch [18/20], Step [100/100], Loss: 0.0122, Params: OrderedDict([('lambda1', tensor(1.0009))])
Epoch [19/20], Step [20/100], Loss: 0.0092, Params: OrderedDict([('lambda1', tensor(1.0019))])
Epoch [19/20], Step [40/100], Loss: 0.0093, Params: OrderedDict([('lambda1', tensor(1.0016))])
Epoch [19/20], Step [60/100], Loss: 0.0087, Params: OrderedDict([('lambda1', tensor(0.9999))])
Epoch [19/20], Step [80/100], Loss: 0.0109, Params: OrderedDict([('lambda1', tensor(0.9988))])
Epoch [19/20], Step [100/100], Loss: 0.0108, Params: OrderedDict([('lambda1', tensor(1.0002))])
Epoch [20/20], Step [20/100], Loss: 0.0091, Params: OrderedDict([('lambda1', tensor(0.9995))])
Epoch [20/20], Step [40/100], Loss: 0.0149, Params: OrderedDict([('lambda1', tensor(1.0006))])
Epoch [20/20], Step [60/100], Loss: 0.0119, Para

In [24]:
train_model(model_Orth_Network, train_dataloader, n_epochs, loss_fn, optimizer_Orth_Network, device)

Epoch [1/20], Step [20/100], Loss: 1.9416, Params: OrderedDict([('layer1.lambda1', tensor(0.6559)), ('layer1.lambda2', tensor(-0.1849)), ('layer1.lambda3', tensor(0.0296)), ('layer2.lambda1', tensor(1.4587)), ('layer2.lambda2', tensor(-0.0221)), ('layer2.lambda3', tensor(-0.5028)), ('layer3.lambda1', tensor(1.6413))])
Epoch [1/20], Step [40/100], Loss: 2.1068, Params: OrderedDict([('layer1.lambda1', tensor(0.6000)), ('layer1.lambda2', tensor(-0.2213)), ('layer1.lambda3', tensor(-0.0170)), ('layer2.lambda1', tensor(1.4027)), ('layer2.lambda2', tensor(-0.0585)), ('layer2.lambda3', tensor(-0.5493)), ('layer3.lambda1', tensor(1.4760))])
Epoch [1/20], Step [60/100], Loss: 0.7807, Params: OrderedDict([('layer1.lambda1', tensor(0.6190)), ('layer1.lambda2', tensor(-0.2393)), ('layer1.lambda3', tensor(-0.0571)), ('layer2.lambda1', tensor(1.4218)), ('layer2.lambda2', tensor(-0.0765)), ('layer2.lambda3', tensor(-0.5895)), ('layer3.lambda1', tensor(1.3331))])
Epoch [1/20], Step [80/100], Loss: 1.3

Epoch [6/20], Step [40/100], Loss: 1.1319, Params: OrderedDict([('layer1.lambda1', tensor(-0.0987)), ('layer1.lambda2', tensor(-1.5667)), ('layer1.lambda3', tensor(-0.7981)), ('layer2.lambda1', tensor(0.7040)), ('layer2.lambda2', tensor(-1.4039)), ('layer2.lambda3', tensor(-1.3304)), ('layer3.lambda1', tensor(1.0017))])
Epoch [6/20], Step [60/100], Loss: 0.8633, Params: OrderedDict([('layer1.lambda1', tensor(-0.1171)), ('layer1.lambda2', tensor(-1.6312)), ('layer1.lambda3', tensor(-0.8045)), ('layer2.lambda1', tensor(0.6856)), ('layer2.lambda2', tensor(-1.4684)), ('layer2.lambda3', tensor(-1.3368)), ('layer3.lambda1', tensor(1.0013))])
Epoch [6/20], Step [80/100], Loss: 1.0114, Params: OrderedDict([('layer1.lambda1', tensor(-0.1727)), ('layer1.lambda2', tensor(-1.7347)), ('layer1.lambda3', tensor(-0.8718)), ('layer2.lambda1', tensor(0.6301)), ('layer2.lambda2', tensor(-1.5719)), ('layer2.lambda3', tensor(-1.4042)), ('layer3.lambda1', tensor(1.0045))])
Epoch [6/20], Step [100/100], Loss

Epoch [11/20], Step [60/100], Loss: 0.9551, Params: OrderedDict([('layer1.lambda1', tensor(-0.3031)), ('layer1.lambda2', tensor(-2.8114)), ('layer1.lambda3', tensor(-1.3747)), ('layer2.lambda1', tensor(0.4997)), ('layer2.lambda2', tensor(-2.6486)), ('layer2.lambda3', tensor(-1.9070)), ('layer3.lambda1', tensor(1.0031))])
Epoch [11/20], Step [80/100], Loss: 0.8236, Params: OrderedDict([('layer1.lambda1', tensor(-0.3084)), ('layer1.lambda2', tensor(-2.8512)), ('layer1.lambda3', tensor(-1.3949)), ('layer2.lambda1', tensor(0.4944)), ('layer2.lambda2', tensor(-2.6884)), ('layer2.lambda3', tensor(-1.9272)), ('layer3.lambda1', tensor(1.0040))])
Epoch [11/20], Step [100/100], Loss: 1.5850, Params: OrderedDict([('layer1.lambda1', tensor(-0.3225)), ('layer1.lambda2', tensor(-2.8532)), ('layer1.lambda3', tensor(-1.3816)), ('layer2.lambda1', tensor(0.4802)), ('layer2.lambda2', tensor(-2.6904)), ('layer2.lambda3', tensor(-1.9139)), ('layer3.lambda1', tensor(1.0025))])
Epoch [12/20], Step [20/100], 

Epoch [16/20], Step [80/100], Loss: 1.3111, Params: OrderedDict([('layer1.lambda1', tensor(-0.8753)), ('layer1.lambda2', tensor(-3.9493)), ('layer1.lambda3', tensor(-2.0769)), ('layer2.lambda1', tensor(-0.0725)), ('layer2.lambda2', tensor(-3.7865)), ('layer2.lambda3', tensor(-2.6092)), ('layer3.lambda1', tensor(0.9995))])
Epoch [16/20], Step [100/100], Loss: 0.4830, Params: OrderedDict([('layer1.lambda1', tensor(-0.9143)), ('layer1.lambda2', tensor(-4.0837)), ('layer1.lambda3', tensor(-2.1144)), ('layer2.lambda1', tensor(-0.1116)), ('layer2.lambda2', tensor(-3.9209)), ('layer2.lambda3', tensor(-2.6468)), ('layer3.lambda1', tensor(1.0007))])
Epoch [17/20], Step [20/100], Loss: 0.8505, Params: OrderedDict([('layer1.lambda1', tensor(-0.9331)), ('layer1.lambda2', tensor(-4.1283)), ('layer1.lambda3', tensor(-2.1559)), ('layer2.lambda1', tensor(-0.1304)), ('layer2.lambda2', tensor(-3.9654)), ('layer2.lambda3', tensor(-2.6882)), ('layer3.lambda1', tensor(1.0011))])
Epoch [17/20], Step [40/100

Finally, we save the trained models for future use.

In [25]:
torch.save(model_linear.state_dict(), "linearNet.pth")
print("Saved PyTorch Model State to linearNet.pth")

torch.save(model_Orth_layer.state_dict(), "orthLayerNet.pth")
print("Saved PyTorch Model State to orthLayerNet.pth")

torch.save(model_Orth_Network.state_dict(), "orthNetworkNet.pth")
print("Saved PyTorch Model State to orthNetworkNet.pth")

Saved PyTorch Model State to linearNet.pth
Saved PyTorch Model State to orthLayerNet.pth
Saved PyTorch Model State to orthNetworkNet.pth


In [26]:
### THEN NEED TO TRAIN AND SAVE A NUMBER OF MODELS
### THEN NEED TO DO MULTIPLE LOSS RUNS AND GRAPH THEM, AND FIX THE TESTING BIT!

### Part IV: Test Data + Test Loss

We now generate some test data.

In [35]:
#Data Test Generation
# Note: seed 42 was also used for the training data, so X_1_test repeats the first M rows of X_1.
np.random.seed(42)
M = 3600

X_1_test = np.random.randn(M, dim)
print(X_1_test)

X_2_test = np.random.randn(M, dim)
print(X_2_test)

X_test = np.zeros((M, dim*dim))
print(X_test)

[[ 0.49671415 -0.1382643 ]
 [ 0.64768854  1.52302986]
 [-0.23415337 -0.23413696]
 ...
 [ 0.74744497  0.29787619]
 [ 0.45015563 -0.10435845]
 [ 1.77549085 -1.47919602]]
[[ 0.13965167  0.06502463]
 [ 0.81328041 -0.84385388]
 [-1.7436726   0.15682302]
 ...
 [ 0.29613359 -0.58815313]
 [ 1.45079947 -1.11714176]
 [ 0.26162526  0.24611771]]
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 ...
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]


In [36]:
for i in range(0,M):
    X_test[i] = np.kron(X_1_test[i], X_2_test[i])
print(X_test)

[[ 0.06936696  0.03229865 -0.01930884 -0.00899059]
 [ 0.5267524  -0.54655448  1.23865034 -1.28521465]
 [ 0.40828682 -0.03672064  0.4082582  -0.03671806]
 ...
 [ 0.22134356 -0.4396121   0.08821115 -0.17519681]
 [ 0.65308555 -0.50288766 -0.15140318  0.11658318]
 [ 0.46451326  0.43697974 -0.38699505 -0.36405633]]


#### Do I need to add some scaled Gaussian noise into the test data?

In [37]:
dotprod_test = np.sum(X_1_test*X_2_test, axis=1)
dotprod_test = np.reshape(dotprod_test, (M,1))
Y_tensor_test = torch.tensor(dotprod_test, dtype = torch.float, requires_grad = True, device=device)

We create a tensor from the test data, and place it on the device.

In [38]:
X_tensor_test = torch.tensor(X_test, dtype = torch.float, requires_grad = True, device=device)

We now compute the outputs of the three models on the test data, along with their test losses.

In [39]:
Y_res_linear = model_linear(X_tensor_test)
Y_res_Orth_layer = model_Orth_layer(X_tensor_test)
Y_res_Orth_Network = model_Orth_Network(X_tensor_test)

In [40]:
loss_linear = loss_fn(Y_res_linear, Y_tensor_test)
loss_Orth_layer = loss_fn(Y_res_Orth_layer, Y_tensor_test)
loss_Orth_Network = loss_fn(Y_res_Orth_Network, Y_tensor_test)
print("Standard Linear Layer loss: ", loss_linear)
print("Orthogonal Linear Layer loss: ",loss_Orth_layer)
print("Orthogonal NN Layer loss: ", loss_Orth_Network)

Standard Linear Layer loss:  tensor(1.1249e-05, device='mps:0', grad_fn=<MseLossBackward0>)
Orthogonal Linear Layer loss:  tensor(1.3396e-06, device='mps:0', grad_fn=<MseLossBackward0>)
Orthogonal NN Layer loss:  tensor(0.9235, device='mps:0', grad_fn=<MseLossBackward0>)
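
One plausible contributor to the much larger loss of the three-layer network is its final ReLU: the output can never be negative, whereas roughly half of the dot products are. Under that constraint the best achievable prediction for a target $t$ is $\max(t, 0)$, which leaves an error of $\min(t, 0)^2$. The sketch below estimates the resulting floor on the test MSE by regenerating the test labels as in the cells above (the names `targets`, `best_nonneg` and `mse_floor` are illustrative). For standard normal inputs in $\mathbb{R}^2$ the expected value of this floor is $1$.

```python
import numpy as np

np.random.seed(42)
M, dim = 3600, 2
X_1_test = np.random.randn(M, dim)
X_2_test = np.random.randn(M, dim)
targets = np.sum(X_1_test * X_2_test, axis=1)

# Best non-negative prediction for each target is max(target, 0).
best_nonneg = np.maximum(targets, 0.0)
mse_floor = np.mean((best_nonneg - targets) ** 2)

print(np.mean(targets < 0))  # fraction of negative dot products
print(mse_floor)             # lower bound on the MSE of any non-negative predictor
```

Removing the final activation would lift this constraint, though the losses reported above were obtained with it in place.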
