In [88]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import torch.optim as optim
from scipy.ndimage import gaussian_filter

In [89]:
# Architecture of the original model M.
class Net(nn.Module):
    """Original MLP M for flattened 28x28 inputs: 784 -> 128 -> 64 -> 10.

    The attribute names fc1/fc2/fc3 are the state-dict keys that
    ``load_state_dict`` matches against the checkpoint, so they must not change.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as before so that random init
        # and state-dict ordering are unchanged.
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return unnormalised logits of shape (N, 10); x is flattened to (N, 784)."""
        h = x.view(-1, 784)
        for hidden_layer in (self.fc1, self.fc2):
            h = torch.relu(hidden_layer(h))
        return self.fc3(h)

# Architecture shared by the compressed models (M', M'' and M*).
class Net_sub(nn.Module):
    """Compressed MLP: 784 -> 64 -> 64 -> 10.

    fc1 is 64 wide so it can hold the 64-column Gaussian sketch of M.fc1, and
    fc2 consequently takes 64 inputs instead of M's 128. fc3 has the same shape
    as M.fc3.
    """

    def __init__(self):
        super().__init__()
        # Same creation order as Net so parameter init order is unchanged.
        self.fc1 = nn.Linear(784, 64)  # holds the sketched first layer
        self.fc2 = nn.Linear(64, 64)   # input width reduced to match fc1
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return unnormalised logits of shape (N, 10); x is flattened to (N, 784)."""
        h = x.view(-1, 784)
        for hidden_layer in (self.fc1, self.fc2):
            h = torch.relu(hidden_layer(h))
        return self.fc3(h)
  

In [90]:
# Load the original (uncompressed) network M from its saved state dict.
# map_location='cpu' lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine; the evaluation later in this notebook feeds CPU tensors anyway.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (recent PyTorch versions also accept weights_only=True to restrict this).
M = Net()
M.load_state_dict(torch.load('model.pth', map_location='cpu'))
M.eval()

Net(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=10, bias=True)
)

In [91]:
def create_S(W, b=None, output_size = (None, None), num_iterations=1, dtype=torch.float32,
             generator=None, verbose=True):
    """Orthonormalise a Gaussian random sketch of a layer's weights.

    With a bias ``b``:
        ``W`` (shape ``(out, in)``) and ``b`` (shape ``(out,)``) are stacked into
        ``S = [W.T; b]`` of shape ``(in + 1, out)``. A Gaussian ``Omega`` of shape
        ``(out, k)`` is drawn and the reduced QR of ``Y = S @ Omega`` is taken.
        The ``(in + 1, k)`` factor ``Q`` is returned split into its first ``in``
        rows (weight part) and its last row (bias part).

    Without a bias:
        ``Omega`` has shape ``(W.shape[1], k)`` and the reduced QR of
        ``Y = Omega @ W.T`` (shape ``(k, W.shape[0])``) is taken. ``Q`` has shape
        ``(k, min(k, W.shape[0]))`` and is returned as-is.
        NOTE(review): this branch sketches from the left, unlike the bias branch.
        When ``k <= W.shape[0]``, ``Q`` is a square orthogonal matrix. Confirm this
        is intended; it is kept because callers rely on the returned shape.

    Args:
        W: 2-D weight tensor.
        b: optional 1-D bias of length ``W.shape[0]``.
        output_size: only ``output_size[0]`` (the sketch width ``k``) is used;
            the second entry is ignored.
        num_iterations: currently unused; kept for interface compatibility.
        dtype: dtype of the Gaussian sketch. Inputs are cast to it so the matmul
            does not fail on a dtype mismatch.
        generator: optional ``torch.Generator`` for a reproducible sketch. By
            default the global RNG is used.
        verbose: when True, print intermediate shapes (the original behaviour).

    Returns:
        ``(Q[:-1], Q[-1])`` when ``b`` is given, otherwise ``Q``.

    Raises:
        ValueError: if ``output_size[0]`` is None.
    """
    k = output_size[0]
    if k is None:
        raise ValueError("output_size[0] (sketch width) must be given")

    if b is not None:
        # Append the bias as an extra row so weights and bias are sketched jointly.
        S = torch.cat((W.T, b.unsqueeze(0)), dim=0).to(dtype)
        if verbose:
            print(S.shape)

        Omega = torch.randn(S.shape[1], k, dtype=dtype, device=S.device, generator=generator)
        Y = S @ Omega
        Q, _ = torch.linalg.qr(Y)  # reduced QR: Q has orthonormal columns

        if verbose:
            print(Q.shape)
        return Q[:-1], Q[-1]  # weight rows and bias row, separately

    W = W.to(dtype)
    Omega = torch.randn(W.shape[1], k, dtype=dtype, device=W.device, generator=generator)
    if verbose:
        print(W.shape, Omega.shape)
    Y = Omega @ W.T
    if verbose:
        print(f"Y shape: {Y.shape} ")
    Q, _ = torch.linalg.qr(Y)

    if verbose:
        print(Q.shape)
    return Q

# Seed the global RNG so the Gaussian sketches (and hence M') are reproducible
# under Restart & Run All.
torch.manual_seed(0)

# Layer 1: sketch [W1.T; b1] down to 64 columns. create_S returns the weight part
# in (in, out) = (784, 64) layout plus the bias row, separately.
W = M.fc1.weight.detach().clone()
b = M.fc1.bias.detach().clone()
A_dash_weights_L1, A_dash_bias_L1 = create_S(W, b, (64, 784))
print("A' bias L1 shape",A_dash_bias_L1.shape)

# Layer 2: no bias. create_S returns a (64, 64) matrix, which is transposed here.
W2_T = M.fc2.weight.detach().clone().T
A_dash_weights_L2 = create_S(W2_T, b=None, output_size=(64, 64)).T

torch.Size([785, 128])
torch.Size([785, 64])
A' bias L1 shape torch.Size([64])
torch.Size([128, 64]) torch.Size([64, 64])
Y shape: torch.Size([64, 128]) 
torch.Size([64, 64])


In [92]:
# M': a Net_sub (first hidden layer 64 wide instead of 128). Its freshly
# initialised parameters are overwritten with the sketched layers in a later cell.
M_dash = Net_sub()

In [95]:
# Shape nn.Linear expects for fc1 vs. the sketched factor; the factor is stored
# in (in, out) layout, which is why it is transposed when assigned.
print(M_dash.fc1.weight.data.shape, A_dash_weights_L1.shape, sep="\n")

torch.Size([64, 784])
torch.Size([784, 64])


In [85]:
# A_dash_weights_L2 is assigned to M_dash.fc2.weight untransposed, so compare it
# directly. The previous `.T` only printed the matching shape because the
# matrix happens to be square.
print(A_dash_weights_L2.shape)
print(M_dash.fc2.weight.data.shape)

torch.Size([64, 64])
torch.Size([64, 64])


In [96]:


# Assemble M' from the sketched layers plus M's fc2 bias and fc3 parameters.
# load_state_dict (strict by default) copies each tensor into the existing
# parameter and raises on a missing key or any shape mismatch. Assigning to
# `.data` would silently swap in a tensor of whatever shape it is given.
# NOTE(review): fc1 now holds orthonormal QR basis vectors rather than a
# projection of M.fc1's weights, and fc2's weight is the 64x64 orthogonal Q from
# create_S's no-bias branch. The fc2/fc3 parameters copied from M were trained
# on different hidden features, which likely explains the ~6% test accuracy
# reported below. Confirm the intended reconstruction.
M_dash.load_state_dict({
    'fc1.weight': A_dash_weights_L1.T,  # (784, 64) -> nn.Linear's (out, in) = (64, 784)
    'fc1.bias': A_dash_bias_L1,
    'fc2.weight': A_dash_weights_L2,
    'fc2.bias': M.fc2.bias.detach(),
    'fc3.weight': M.fc3.weight.detach(),
    'fc3.bias': M.fc3.bias.detach(),
});

In [98]:
# Report the sketched factors' shapes next to the parameters they were loaded into.
print(*(t.shape for t in (A_dash_weights_L1, A_dash_bias_L1)))
print(A_dash_weights_L2.shape)
print(*(p.shape for p in (M_dash.fc1.weight.data, M_dash.fc1.bias.data)))
print(*(p.shape for p in (M_dash.fc2.weight.data, M_dash.fc2.bias.data)))

# Each factor must match its target parameter. The layer-1 weight factor is in
# (in, out) layout, so it is compared transposed.
assert M_dash.fc1.bias.data.shape == A_dash_bias_L1.shape
assert M_dash.fc1.weight.data.shape == A_dash_weights_L1.T.shape
assert M_dash.fc2.weight.data.shape == A_dash_weights_L2.shape

torch.Size([784, 64]) torch.Size([64])
torch.Size([64, 64])
torch.Size([64, 784]) torch.Size([64])
torch.Size([64, 64]) torch.Size([64])


In [99]:
# Load MNIST as tensors normalised with a fixed mean/std
# (the commonly quoted MNIST training-set statistics).
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(MNIST_MEAN, MNIST_STD)])

# Train split first, then test split; download=True fetches them under 'data' if missing.
train_dataset, test_dataset = (
    datasets.MNIST('data', train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)

# Only the training split is shuffled, so evaluation order is fixed.
# NOTE(review): train_loader is not used by any cell shown here.
train_loader, test_loader = (
    torch.utils.data.DataLoader(split, batch_size=64, shuffle=shuffle)
    for split, shuffle in ((train_dataset, True), (test_dataset, False))
)

In [101]:
# Top-1 accuracy of M' on the MNIST test split (CPU, no gradient tracking).
M_dash.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        output = M_dash(data)
        # argmax over the class logits. This also avoids binding `_` at module
        # level, which IPython uses for the last cell output.
        predicted = output.argmax(dim=1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

if total == 0:
    raise ValueError("test_loader yielded no samples; cannot compute accuracy")
accuracy = 100 * correct / total

print(f'Test Accuracy model: {accuracy}%')

Test Accuracy model: 5.86%
