In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import torch.optim as optim
from scipy.ndimage import gaussian_filter

In [2]:
# For M
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 784)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# For M' and M'' and M*
class Net_sub(nn.Module):
    def __init__(self):
        super(Net_sub, self).__init__()
        self.fc1 = nn.Linear(784, 64) # subbed from gaussian
        self.fc2 = nn.Linear(64, 64)  # changed to match dimension
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 784)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
  

In [3]:
M = Net()
M.load_state_dict(torch.load('model.pth'))
M.eval()

Net(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=10, bias=True)
)

### Creating $S_{64\times128}$ Num of iterations=1 gives best result

In [5]:
def create_S(W, output_size, num_iterations=1, dtype=torch.float32):
    # Convert to tensor
    S = torch.tensor(W, dtype=dtype)
    print(S.shape)
    # Add dimensions to S to make it a 4D tensor if it's not already
    while len(S.shape) < 4:
        S = S.unsqueeze(0)

    # Ensure output_size is a list so we can modify it if necessary
    output_size = list(output_size)

    # If S is a 1D tensor, adjust output_size to match
    if len(W.shape) == 1 and len(output_size)==1:
        output_size = [output_size[0], 1]
    # Calculate the step size for each dimension
    step_size = [(S.shape[i] - output_size[i]) // num_iterations for i in range(len(output_size))]
    if len(W.shape) == 1:
        s_shape = W.shape[0]
    else:
        s_shape = S.shape[-2]
    print(s_shape)
    for i in range(num_iterations):
        # Apply Gaussian filter
        sigma = 1.0*num_iterations/(i+1)
        S_np = S.numpy()  # Convert to NumPy array
        S_filtered = gaussian_filter(S_np, sigma=sigma)
        S = torch.from_numpy(S_filtered)  # Convert back to tensor
        # Calculate the target size for this iteration
        if i < num_iterations - 1:
            target_zero = s_shape + step_size[0]
            target_size = [target_zero, output_size[1] ]
            s_shape = target_zero
        else:
            target_size = output_size
        print(target_size)
        # Create adaptive max pooling layer
        pool = nn.AdaptiveMaxPool2d(target_size)

        # Apply adaptive max pooling
        S = pool(S)
    # Remove the added dimensions
    S = S.squeeze()

    return S
    # sigma = 1.0
    # S = gaussian_filter(S, sigma=sigma)
    # S = torch.tensor(S, dtype=dtype)
    # S = S.unsqueeze(0).unsqueeze(0)
    # pool = nn.AdaptiveMaxPool2d(output_size)
    # S = pool(S)
    # S = S.squeeze()

A_dash_weights_L1 = create_S(W = M.fc1.weight.data.clone().detach(), output_size= (64, 784))
A_dash_bias_L1 = create_S(M.fc1.bias.data.clone().detach(), (64, 1))
A_dash_weights_L2 = (create_S((M.fc2.weight.data.clone().detach()).T, (64, 64))).T

print(A_dash_weights_L1.shape, A_dash_bias_L1.shape, A_dash_weights_L2.shape)

torch.Size([128, 784])
128
[64, 784]
torch.Size([128])
128
[64, 1]
torch.Size([128, 64])
128
[64, 64]
torch.Size([64, 784]) torch.Size([64]) torch.Size([64, 64])


  S = torch.tensor(W, dtype=dtype)


In [33]:
# A_dash_weights_L1 = S @ M.fc1.weight.data.clone()
# A_dash_bias_L1 = S @ M.fc1.bias.data.clone()

# A_dash_weights_L2 = S @ (M.fc2.weight.data.clone()).T
# A_dash_bias_L2 = S.T @ (M.fc2.bias.data.clone())

In [6]:
M_dash = Net_sub()

M_dash.fc1.weight.data = A_dash_weights_L1
M_dash.fc1.bias.data = A_dash_bias_L1

M_dash.fc2.weight.data = A_dash_weights_L2
M_dash.fc2.bias.data = M.fc2.bias.data.clone()

M_dash.fc3.weight.data = M.fc3.weight.data.clone()
M_dash.fc3.bias.data = M.fc3.bias.data.clone()


In [7]:
print(A_dash_weights_L1.shape, A_dash_bias_L1.shape)
print(A_dash_weights_L2.shape)

torch.Size([64, 784]) torch.Size([64])
torch.Size([64, 64])


In [8]:
print(M_dash.fc1.weight.data.shape, M_dash.fc1.bias.data.shape)
print(M_dash.fc2.weight.data.shape, M_dash.fc2.bias.data.shape)

torch.Size([64, 784]) torch.Size([64])
torch.Size([64, 64]) torch.Size([64])


In [9]:
M_dash.eval()

Net_sub(
  (fc1): Linear(in_features=784, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=10, bias=True)
)

In [10]:
# Load the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

In [11]:
M_dash.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        output = M_dash(data)
        _, predicted = torch.max(output.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

accuracy = 100 * correct / total

print(f'Test Accuracy model: {accuracy}%')

Test Accuracy model: 17.83%


In [21]:
import copy
M_double_dash = copy.deepcopy(M_dash)

In [22]:
M_double_dash

Net_sub(
  (fc1): Linear(in_features=784, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=10, bias=True)
)

In [23]:
# Train the model M'' on the MNIST dataset
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(M_double_dash.parameters(), lr=0.01)
num_epochs = 10
for epoch in range(num_epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = M_double_dash(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

    # Print the training loss for each epoch
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')

Epoch 1/10, Loss: 0.22690068185329437
Epoch 2/10, Loss: 0.655265748500824
Epoch 3/10, Loss: 0.13585880398750305
Epoch 4/10, Loss: 0.41211843490600586
Epoch 5/10, Loss: 0.18461140990257263
Epoch 6/10, Loss: 0.07055684924125671
Epoch 7/10, Loss: 0.4024240970611572
Epoch 8/10, Loss: 0.2851245105266571
Epoch 9/10, Loss: 0.12699361145496368
Epoch 10/10, Loss: 0.2197505682706833


In [15]:
M_double_dash.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        output = M_double_dash(data)
        _, predicted = torch.max(output.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

accuracy = 100 * correct / total

print(f'Test Accuracy model: {accuracy}%')

Test Accuracy model: 96.82%


In [16]:
M_star = Net_sub()
# Train the model M* on the MNIST dataset
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(M_star.parameters(), lr=0.01)
num_epochs = 10
for epoch in range(num_epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = M_star(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

    # Print the training loss for each epoch
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')


Epoch 1/10, Loss: 0.44507694244384766
Epoch 2/10, Loss: 0.15758569538593292
Epoch 3/10, Loss: 0.18744738399982452
Epoch 4/10, Loss: 0.09961222857236862
Epoch 5/10, Loss: 0.2557902932167053
Epoch 6/10, Loss: 0.12377199530601501
Epoch 7/10, Loss: 0.15048006176948547
Epoch 8/10, Loss: 0.13790775835514069
Epoch 9/10, Loss: 0.04519626498222351
Epoch 10/10, Loss: 0.06047726795077324


In [17]:
# Evaluate the model M* on the test set
M_star.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        output = M_star(data)
        _, predicted = torch.max(output.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

accuracy = 100 * correct / total
print(f'Test Accuracy: {accuracy}%')


Test Accuracy: 95.95%
