In [32]:
import torch
from torchvision import datasets, transforms
import torch.nn as nn
from torch.utils.data import Dataset
from tqdm import tqdm
import json
from pathlib import Path


def load_mnist_data(batch_size: int = 64, root: str = "../../data/"):
    """Return the MNIST training and test datasets.

    Every image is converted to a tensor and normalised with mean 0.5 and
    standard deviation 0.5. ``download=True`` is passed to torchvision for
    both splits.

    Args:
        batch_size: Unused. Kept only so existing calls keep working; this
            function returns datasets, not loaders, so batching is decided
            by whoever wraps the result in a ``DataLoader``.
        root: Directory handed to ``datasets.MNIST`` as the dataset root.
            Defaults to the path this notebook has always used.

    Returns:
        A ``(trainset, testset)`` tuple of ``torchvision.datasets.MNIST``.
    """
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )

    # Build both splits from one place so they can never end up with
    # different roots or different preprocessing.
    trainset, testset = (
        datasets.MNIST(root, download=True, train=is_train, transform=transform)
        for is_train in (True, False)
    )

    return trainset, testset


class SimpleMNISTModel(nn.Module):
    """Two-layer, bias-free perceptron: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of features per sample after flattening.
        hidden_size: Width of the hidden layer.
        output_size: Number of output scores per sample.
    """

    def __init__(self, input_size=784, hidden_size=128, output_size=10):
        super().__init__()
        # Layers are created in the same order as before (fc1, relu, fc2) so
        # parameter names and random initialisation are unaffected.
        self.fc1 = nn.Linear(
            in_features=input_size, out_features=hidden_size, bias=False
        )
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(
            in_features=hidden_size, out_features=output_size, bias=False
        )

    def forward(self, x):
        """Flatten everything after the batch dimension and score each sample."""
        flat = x.flatten(start_dim=1)
        hidden = self.relu(self.fc1(flat))
        scores = self.fc2(hidden)
        return scores

In [33]:
# Load a JSON (text) file
def load_data_from_json(weights_path: Path) -> list:
    """Parse the JSON file at ``weights_path`` and return its content.

    Args:
        weights_path: Location of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file. For the weight dumps
        used in this notebook that is a flat list of numbers.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is text; pin the encoding so the result does not depend on the
    # platform's default locale.
    with open(weights_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def load_model() -> SimpleMNISTModel:
    """Build a ``SimpleMNISTModel`` and fill it with the weights stored on disk.

    Reads ``avgDecryptedFC1.json`` and ``avgDecryptedFC2.json`` from
    ``../../weights/mnist/``, prints the shape of each loaded tensor, and
    copies the values into ``fc1`` and ``fc2``.

    The FC1 file must contain exactly as many values as ``fc1.weight``. The
    FC2 file may contain more values than ``fc2.weight`` needs; only the
    leading ones are used, matching the previous ``[:1280]`` slice.

    Returns:
        The model with both weight matrices replaced.

    Raises:
        ValueError: If either file holds too few values (or, for FC1, a
            different number of values) for the layer it feeds.
    """
    fc1 = load_data_from_json(Path("../../weights/mnist/avgDecryptedFC1.json"))
    fc2 = load_data_from_json(Path("../../weights/mnist/avgDecryptedFC2.json"))

    # Create model with same architecture
    model = SimpleMNISTModel()

    # Force float32 so a file that happens to contain only integers does not
    # produce an int64 tensor.
    fc1_weights = torch.tensor(fc1, dtype=torch.float32)
    fc2_weights = torch.tensor(fc2, dtype=torch.float32)

    print(fc1_weights.shape)
    print(fc2_weights.shape)

    # Take the target shapes from the model instead of repeating 128/784/10
    # here, so the loader cannot drift from the architecture.
    fc1_shape = model.fc1.weight.shape
    fc2_shape = model.fc2.weight.shape

    if fc1_weights.numel() != fc1_shape.numel():
        raise ValueError(
            f"FC1 file holds {fc1_weights.numel()} values, "
            f"expected {fc1_shape.numel()} for shape {tuple(fc1_shape)}"
        )
    if fc2_weights.numel() < fc2_shape.numel():
        raise ValueError(
            f"FC2 file holds {fc2_weights.numel()} values, "
            f"expected at least {fc2_shape.numel()} for shape {tuple(fc2_shape)}"
        )

    # Copy in place under no_grad rather than rebinding `.data`, which
    # bypasses autograd's bookkeeping.
    with torch.no_grad():
        model.fc1.weight.copy_(fc1_weights.reshape(fc1_shape))
        model.fc2.weight.copy_(
            fc2_weights.flatten()[: fc2_shape.numel()].reshape(fc2_shape)
        )

    return model

In [34]:
# Sanity check: after loading, the second layer's weight should have been
# reshaped to (10, 128).
load_model().fc2.weight.data.shape

torch.Size([100352])
torch.Size([1280])


torch.Size([10, 128])

In [35]:
def evaluate_model(test_set: Dataset, device='cpu'):
    """Load the stored model and report its accuracy on ``test_set``.

    Args:
        test_set: Dataset yielding ``(inputs, label)`` pairs.
        device: Device on which inference runs. The model and every batch
            are moved there.

    Returns:
        Accuracy as a percentage in the range 0-100. The value is also
        printed with two decimals.

    Raises:
        ValueError: If ``test_set`` yields no samples.
    """
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=64, shuffle=False, num_workers=2
    )
    # The batches are moved to `device` below, so the model has to live there
    # as well; otherwise any non-CPU device fails with a device mismatch.
    model = load_model().to(device)
    model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0

    with torch.no_grad():  # Disable gradient computation
        for inputs, labels in tqdm(test_loader, desc="Evaluating"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # Index of the highest score per sample is the predicted class.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError.
        raise ValueError("Cannot compute accuracy: test_set is empty")

    accuracy = 100 * correct / total
    print(f'Accuracy on test set: {accuracy:.2f}%')
    return accuracy

In [36]:
# Keep only the test split; the training split returned first is discarded.
_, test_set = load_mnist_data()

def include_digits(dataset, included_digits):
    """Return the subset of ``dataset`` whose label is in ``included_digits``.

    Args:
        dataset: Indexable dataset whose items are ``(sample, label)`` pairs.
        included_digits: Container of labels to keep.

    Returns:
        A ``torch.utils.data.Subset`` over the matching indices, in their
        original order.
    """
    # Fast path: torchvision-style datasets expose every label in `targets`,
    # so the labels can be read without loading and transforming each image.
    # It is only used when no `target_transform` would alter the label and
    # `targets` lines up one-to-one with the dataset.
    labels = getattr(dataset, "targets", None)
    has_label_transform = getattr(dataset, "target_transform", None) is not None
    if labels is None or has_label_transform or len(labels) != len(dataset):
        # Generic path: ask the dataset for each sample, as before.
        labels = (dataset[idx][1] for idx in range(len(dataset)))
    elif hasattr(labels, "tolist"):
        # Tensor/array -> plain Python values, so the `in` test compares the
        # same kind of value that `dataset[idx][1]` returns for MNIST.
        labels = labels.tolist()

    including_indices = [
        idx for idx, label in enumerate(labels) if label in included_digits
    ]
    return torch.utils.data.Subset(dataset, including_indices)

# Three label-restricted views of the test split, one per digit group.
testset_137 = include_digits(dataset=test_set, included_digits=[1, 3, 7])
testset_258 = include_digits(dataset=test_set, included_digits=[2, 5, 8])
testset_469 = include_digits(dataset=test_set, included_digits=[4, 6, 9])

In [48]:
# Show which distinct labels occur in the full test split.
print(torch.unique(test_set.targets))

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])


In [49]:
# Spot-check the first ten samples of the {1, 3, 7} subset: every printed
# target should be 1, 3 or 7.
for i in range(10):
    data, target = testset_137[i]
    print(f"Example {i}: Target = {target}")

Example 0: Target = 7
Example 1: Target = 1
Example 2: Target = 1
Example 3: Target = 1
Example 4: Target = 7
Example 5: Target = 3
Example 6: Target = 7
Example 7: Target = 1
Example 8: Target = 3
Example 9: Target = 1


In [50]:
# Baseline: accuracy over the full test split.
evaluate_model(test_set)

torch.Size([100352])
torch.Size([1280])


Evaluating: 100%|██████████| 157/157 [00:03<00:00, 51.43it/s]

Accuracy on test set: 82.75%





82.75

In [51]:
# Accuracy restricted to test samples labelled 1, 3 or 7.
evaluate_model(testset_137)

torch.Size([100352])
torch.Size([1280])


Evaluating: 100%|██████████| 50/50 [00:01<00:00, 45.60it/s]

Accuracy on test set: 89.91%





89.91490702804917

In [52]:
# Accuracy restricted to test samples labelled 2, 5 or 8.
evaluate_model(testset_258)

torch.Size([100352])
torch.Size([1280])


Evaluating: 100%|██████████| 46/46 [00:01<00:00, 37.64it/s]

Accuracy on test set: 74.36%





74.36162870945479

In [53]:
# Accuracy restricted to test samples labelled 4, 6 or 9.
evaluate_model(testset_469)

torch.Size([100352])
torch.Size([1280])


Evaluating: 100%|██████████| 47/47 [00:00<00:00, 49.54it/s]

Accuracy on test set: 77.62%





77.61953204476093