In [1]:
import numpy as np
import torch
import os
os.chdir("/home/jshenouda/break_test_llms")  # Change to the actual subdirectory name
from train_mnist import MLP
import torch.nn as nn
from torchvision import datasets, transforms

In [37]:
def modify_mlp(model, D_mats=None, max_exp=30):
    """Build a function-equivalent copy of an MLP whose hidden units are rescaled.

    With one scaling matrix D_k per hidden layer, the Linear weights are rewritten as

        W_0' = D_0^{-1} W_0,             b_0' = D_0^{-1} b_0
        W_k' = D_k^{-1} W_k D_{k-1},     b_k' = D_k^{-1} b_k
        W_L' = W_L D_{-1},               b_L' = b_L          (output layer)

    Since relu(D^{-1} z) = D^{-1} relu(z) when D is diagonal with positive
    entries, the rewritten network computes the same function as ``model`` while
    its individual weight matrices differ substantially.

    Args:
        model: MLP whose ``model`` attribute is a sequence of ``nn.Linear`` and
            ``nn.ReLU`` layers, and which exposes ``affine_layers``.
            NOTE(review): ``affine_layers`` is assumed to be the number of Linear
            layers (the last one is treated as the output layer) — confirm
            against ``train_mnist.MLP``.
        D_mats: optional list of square scaling matrices, one per hidden layer.
            For exact equivalence they should be diagonal with positive entries.
            When None, the k-th hidden layer of width n gets
            ``diag(exp(linspace(-max_exp, max_exp, n)))``.
        max_exp: largest exponent used when generating D_mats (default 30, the
            value that used to be hard-coded). Ignored when D_mats is given.

    Returns:
        (modified_model, D_mats): the new MLP (on the same device as ``model``'s
        first Linear layer) and the scaling matrices that were used.
    """
    linear_layers = [layer for layer in model.model if isinstance(layer, nn.Linear)]
    device = linear_layers[0].weight.device
    modified_model = MLP(model.model[0].in_features,
                         [layer.out_features for layer in linear_layers][:-1],
                         model.model[-1].out_features).to(device)

    if D_mats is None:
        # One scaling per hidden (ReLU) layer; exponents spread over
        # [-max_exp, max_exp] to distort cosine similarity as much as possible.
        num_hidden = sum(1 for layer in model.model if isinstance(layer, nn.ReLU))
        D_mats = [
            torch.diag(torch.exp(torch.linspace(-max_exp, max_exp, layer.out_features)))
            for layer in linear_layers[:num_hidden]
        ]

    with torch.no_grad():
        linear_idx = 0
        for orig_layer, new_layer in zip(model.model, modified_model.model):
            if not isinstance(orig_layer, nn.Linear):
                continue
            weight, bias = orig_layer.weight, orig_layer.bias
            # Bring scaling matrices onto the parameter's device/dtype so models
            # living on a GPU can be reparameterized too.
            target = dict(device=weight.device, dtype=weight.dtype)

            if linear_idx < model.affine_layers - 1:
                # Hidden layer: scale outputs by D_k^{-1} (computed once) and
                # undo the previous layer's scaling on the input side.
                d_inv = torch.inverse(D_mats[linear_idx]).to(**target)
                new_weight = d_inv @ weight
                if linear_idx > 0:
                    new_weight = new_weight @ D_mats[linear_idx - 1].to(**target)
                new_bias = d_inv @ bias
            else:
                # Output layer: only undo the last hidden layer's scaling.
                new_weight = weight @ D_mats[-1].to(**target)
                new_bias = bias

            # copy_ writes into the new model's own storage; rebinding `.data`
            # would make the output bias share memory with the original model.
            new_layer.weight.copy_(new_weight)
            new_layer.bias.copy_(new_bias)
            linear_idx += 1

    return modified_model, D_mats


In [38]:
input_dim = 28*28        # flattened 28x28 MNIST image
num_layers = 6
hidden_dims = [100]*num_layers  # 6 hidden layers of width 100
output_dim = 10

trained_model = MLP(input_dim, hidden_dims, output_dim)

# Load the trained weights. map_location="cpu" lets a checkpoint saved from a
# GPU run load on a CPU-only machine; the filename follows num_layers.
trained_model.load_state_dict(
    torch.load(f"mnist_mlp_{num_layers}_layers.pt", map_location="cpu")
)

<All keys matched successfully>

In [39]:
# Reparameterize the trained network; D_matrices are the scalings modify_mlp used.
modified_model, D_matrices = modify_mlp(model=trained_model)

In [40]:
# Normalization constants commonly used for MNIST (mean, std per channel).
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

# download=True lets Restart & Run All succeed when ../data is empty;
# it does nothing if the files are already present.
test_data = datasets.MNIST('../data', train=False, download=True,
                           transform=transform)

# All 10,000 test images in a single batch, in a fixed order.
test_loader = torch.utils.data.DataLoader(test_data, batch_size=10000, shuffle=False)

def ensure_models_equivalent(model1, model2, test_loader, device):
    """Measure how far two models' outputs diverge over every batch of a loader.

    Side effects: both models are moved to ``device`` and switched to eval mode.

    Args:
        model1, model2: modules producing same-shaped outputs for the same input.
        test_loader: iterable of ``(inputs, targets)`` batches; targets are ignored.
        device: torch device (or device string) to run on.

    Returns:
        0-dim tensor on ``device``: the sum over all batches of the squared
        differences ``(model1(x) - model2(x))**2``. For 2-D outputs this is the
        squared Frobenius norm of the stacked difference. It is 0 when the
        outputs match exactly or when the loader yields no batches.
    """
    model1.to(device)
    model2.to(device)
    model1.eval()
    model2.eval()

    # Accumulate over all batches so every sample counts, not just the last batch.
    total_sq_dist = torch.zeros((), device=device)
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            diff = model1(data) - model2(data)
            total_sq_dist = total_sq_dist + diff.pow(2).sum()

    return total_sq_dist


In [41]:
# Compare the original and reparameterized networks on the test set
# (prints the squared output gap; ~0 means they compute the same function).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
equivalence_gap = ensure_models_equivalent(trained_model, modified_model, test_loader, device)
print(equivalence_gap)

tensor(4.4379e-06, device='cuda:0')


In [42]:
def compare_weights(model1, model2):
    """Print and return the row-wise cosine similarity of matching Linear weights.

    Layers of ``model1.model`` and ``model2.model`` are paired by position. For
    each pair whose first layer is ``nn.Linear``, every row of the first weight
    matrix is compared with the same row of the second. A value of 1.0 means the
    rows point in the same direction, e.g. one is a positive rescaling of the other.

    Returns:
        list of 1-D tensors, one per Linear layer, each with one similarity per
        row (output unit). Callers that ignore the return value see only the
        printed output.
    """
    # Build the similarity op once rather than once per layer.
    cos_sim = nn.CosineSimilarity(dim=1)
    similarities = []
    with torch.no_grad():
        for layer1, layer2 in zip(model1.model, model2.model):
            if isinstance(layer1, nn.Linear):
                sim = cos_sim(layer1.weight, layer2.weight)
                print(sim)
                similarities.append(sim)
    return similarities

In [43]:
# Row-wise cosine similarity of each Linear layer's weights (printed per layer).
compare_weights(model1=trained_model, model2=modified_model);

tensor([1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+