In [147]:
import torch 
import torchvision 
import plotly
import logging
import json
import os
import timeit
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, TensorDataset
import random
import heapq

In [148]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Mini-batch size used when building DataLoaders.
BATCH_SIZE = 64

# Reproducibility: request deterministic cuDNN behaviour and fix the seeds of
# both PyTorch's and Python's random number generators.
torch.backends.cudnn.deterministic = True
torch.manual_seed(0)
random.seed(0)


In [149]:
# Preprocessing applied to every image: convert it to a tensor, then normalise
# its single channel with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# MNIST training split (downloaded into ./data when missing), divided at random
# into 50,000 training samples and 10,000 validation samples.
train_dataset, validate_dataset = torch.utils.data.random_split(
    datasets.MNIST(root='./data', train=True, download=True, transform=transform),
    [50000, 10000],
)
# MNIST test split, preprocessed the same way.
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
validate_loader = DataLoader(validate_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [150]:
def train(model, data_loader, optimizer, criterion, mask = None):
    """Run one optimisation pass of ``model`` over every batch in ``data_loader``.

    Args:
        model: ``nn.Module`` to optimise; it is switched to training mode.
        data_loader: iterable yielding ``(inputs, labels)`` batches.
        optimizer: optimizer whose ``zero_grad()`` / ``step()`` are called once per batch.
        criterion: callable ``criterion(predictions, labels)`` returning a loss
            that supports ``backward()``.
        mask: accepted so the signature mirrors ``evaluate``; it is not used here.

    Returns:
        The same ``model`` object, updated in place.
    """
    model.train()

    # Batches may live on a different device than the model (e.g. CPU batches
    # with a model that was moved to the GPU), so send each batch to wherever
    # the model's parameters are. A parameter-free model leaves batches as-is.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    for inputs, labels in data_loader:
        if device is not None:
            inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()  # clear gradients left over from the previous batch
        predictions = model(inputs)  # forward pass
        loss = criterion(predictions, labels)
        loss.backward()  # backward pass
        optimizer.step()  # parameter update

    return model

def evaluate(model, data_loader, mask = None):
    """Return the fraction of samples in ``data_loader`` that ``model`` labels correctly.

    Args:
        model: ``nn.Module`` called as ``model(inputs, mask)``; it is switched
            to evaluation mode.
        data_loader: iterable yielding ``(inputs, labels)`` batches.
        mask: forwarded unchanged as the second argument of ``model``.

    Returns:
        ``correct / total`` as a float, or ``0.0`` when ``data_loader`` yields
        no samples (instead of raising ``ZeroDivisionError``).
    """
    model.eval()

    # Evaluate on the device that holds the model's parameters; fall back to
    # the notebook-wide DEVICE for a model without parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE

    correct = 0
    total = 0
    with torch.inference_mode():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs, mask)
            predicted = outputs.argmax(dim=1)  # index of the highest score per sample
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        return 0.0
    return correct / total

In [151]:
# Debug mode is on unless the DEBUG environment variable is set to something
# other than "true" (compared case-insensitively).
DEBUG = os.environ.get('DEBUG', 'True').strip().lower() == 'true'
# LOG_LEVEL overrides the level implied by DEBUG. basicConfig accepts a level
# name directly; upper-casing lets values such as "debug" work as well.
log_level = os.environ.get('LOG_LEVEL', 'DEBUG' if DEBUG else 'INFO').upper()
logging.basicConfig(level=log_level, format="%(message)s", force=True)

# Size of the random subset of the training split used for quicker accuracy
# estimates.
subset_size = 10000
# Draw the subset. Nothing is moved to a device here; batches are transferred
# when they are consumed.
subset_indices = torch.randperm(len(train_dataset))[:subset_size]
subset_dataset = Subset(train_dataset, subset_indices)
# Pinned host memory only helps host-to-GPU copies, so request it only when
# running on CUDA (this also avoids a warning on CPU-only machines).
subset_data_loader = DataLoader(
    subset_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=(DEVICE.type == "cuda"),
)

In [154]:
## HYPERPARAMETERS
# Things that can be tuned in this experiment:
#   - number of masks
#   - number of top parents to keep
#   - the crossover operation
#   - the mutation operation
#   - the masking operation
#   - number of epochs
# Network length should be 2x that of the trained version, and its width
# polynomial in the trained width, to get SLT.

# Number of training epochs, and of mask generations.
EPOCHS = 3
# Step size for the Adam optimizer of the trained model.
LEARNING_RATE = 1e-3
# Whether to train (and save) the trained model before the mask search.
TRAIN = False
# Size of the mask population.
NUM_MASKS = 26
# The population is chunked into groups of 2 parents, so it must be even.
assert NUM_MASKS % 2 == 0
# How many entries per layer stay switched on in a binary mask.
NUM_ELEMENTS_TO_KEEP = 20

class MLP(nn.Module):
    """Fully connected classifier whose hidden layers can optionally be masked.

    Layout: an input projection, ``num_hidden_layers`` square hidden layers
    (each followed by ReLU), and an output projection.

    Args:
        input_size: number of features per sample after flattening.
        num_hidden_layers: number of ``hidden_size x hidden_size`` layers; must be > 0.
        hidden_size: width of every hidden layer.
        output_size: number of output classes.
    """

    def __init__(self, input_size, num_hidden_layers, hidden_size, output_size):
        super().__init__()
        assert num_hidden_layers > 0
        self.input = nn.Linear(input_size, hidden_size)
        self.layers = nn.ModuleList(
            [nn.Linear(hidden_size, hidden_size) for _ in range(num_hidden_layers)]
        )
        self.output = nn.Linear(hidden_size, output_size)

    def forward(self, x, mask = None):
        """
        This allows for both a trainable MLP model and a non-training mask based model.

        Args:
            x: batch of inputs; everything after the first dimension is flattened.
            mask: optional indexable with one entry per hidden layer. ``mask[i]``
                is multiplied element-wise into the output of hidden layer ``i``
                before the ReLU. ``None`` runs the plain, unmasked network.

        Returns:
            Raw class scores (logits) of shape ``(batch, output_size)``. No
            softmax is applied: ``nn.CrossEntropyLoss`` expects unnormalised
            logits and the arg-max prediction is unaffected by softmax. Apply
            ``torch.softmax(..., dim=1)`` to the result if probabilities are needed.
        """
        x = x.view(x.size(0), -1)
        x = self.input(x)

        for i, layer in enumerate(self.layers):
            x = layer(x)
            if mask is not None:
                x = x * mask[i]
            x = torch.relu(x)

        return self.output(x)


def random_masks(num, depth, width, device = None):
    """Create ``num`` random float masks with values drawn uniformly from [0, 1).

    Args:
        num: number of masks; zero or a negative value yields an empty batch.
        depth: number of layers covered by each mask.
        width: number of entries per layer.
        device: device for the result; defaults to the notebook-wide ``DEVICE``.

    Returns:
        Tensor of shape ``(max(num, 0), depth, width)``. The empty case keeps
        the trailing dimensions and the device, so it can be concatenated
        with real masks.
    """
    if device is None:
        device = DEVICE
    return torch.rand(max(num, 0), depth, width, device=device)

def convert_to_binary_mask(float_mask_2d: torch.Tensor, num_elements_to_keep: int = NUM_ELEMENTS_TO_KEEP) -> torch.Tensor:
    """Turn a float mask into a 0/1 mask that keeps the largest entries of each row.

    Args:
        float_mask_2d: 2-D tensor of scores; the top entries are selected along dim 1.
        num_elements_to_keep: how many entries per row are set to 1. Values
            larger than the row length keep the whole row.

    Returns:
        A new tensor with the same shape, dtype and device as ``float_mask_2d``
        holding 1 at the selected positions and 0 elsewhere.
    """
    # topk raises when k exceeds the row length, so clamp k to the row width.
    k = min(num_elements_to_keep, float_mask_2d.size(1))
    top_indices_2d = float_mask_2d.topk(k, dim=1).indices
    # Start from all zeros, then switch on the top-k positions of every row.
    binary_mask_2d = torch.zeros_like(float_mask_2d)
    binary_mask_2d.scatter_(1, top_indices_2d, 1)
    return binary_mask_2d

def basic_crossover(couple):
    """Combine two parents into one child by multiplying them together.

    Args:
        couple: indexable holding the two parents at positions 0 and 1.

    Returns:
        The product ``couple[0] * couple[1]``.
    """
    first_parent = couple[0]
    second_parent = couple[1]
    return first_parent * second_parent

def basic_mutatation(mask):
    """Perturb ``mask`` with uniform noise in [0, 0.1).

    One coin flip decides the direction for the whole mask: the noise tensor is
    either subtracted from every entry or added to every entry.

    Args:
        mask: tensor to mutate; it is not modified in place.

    Returns:
        A new tensor of the same shape as ``mask``.
    """
    noise = torch.rand_like(mask) * 0.1
    subtract = random.random() < 0.5
    return mask - noise if subtract else mask + noise

def eval_accuracy_of_mask(mask_model, mask, ms):
    """Score one float mask on the module-level ``train_loader``.

    Args:
        mask_model: model passed on to ``evaluate``.
        mask: float mask to score.
        ms: masking operation that converts ``mask`` into the mask handed to the model.

    Returns:
        The accuracy reported by ``evaluate`` for the converted mask.
    """
    return evaluate(mask_model, train_loader, ms(mask))

def train_masks(mask_model, train_loader, keep_best, depth, width, masks, cx, mt, ms):
    """Evolve a population of float masks for ``mask_model`` over ``EPOCHS`` generations.

    Each generation:
      1. scores every mask with ``evaluate`` after converting it with ``ms``;
      2. keeps the ``keep_best`` highest-scoring masks unchanged;
      3. chunks the population into ``NUM_MASKS // 2`` groups of parents,
         crosses each group with ``cx`` and mutates the child with ``mt``;
      4. tops the population back up to ``NUM_MASKS`` with fresh random masks.

    Args:
        mask_model: model whose hidden layers are masked during evaluation.
        train_loader: loader the masks are scored on.
        keep_best: number of top masks carried over unchanged each generation.
        depth: number of layers per mask (used for the random filler masks).
        width: entries per layer (used for the random filler masks).
        masks: tensor holding the initial population, one mask per first-dim entry.
        cx: crossover operation, called with one group of parents.
        mt: mutation operation, called with one child.
        ms: masking operation converting a float mask into the mask given to the model.

    Returns:
        The population after the last generation. It has not been scored yet.
        Previously nothing was returned, so the evolved masks were lost.

    Raises:
        ValueError: if the survivors plus the children exceed ``NUM_MASKS``.
    """
    for epoch in range(EPOCHS):
        # Evaluate the current generation of masks
        accuracies = [evaluate(mask_model, train_loader, ms(mask)) for mask in masks]
        logging.info(f"Accuracies: {accuracies}")

        # Find the indices of the best masks that will survive the generation
        if keep_best > 0:
            best_mask_indexes = np.argpartition(accuracies, -keep_best)[-keep_best:]
        else:
            # argpartition(..., -0)[-0:] would select every mask instead of none
            best_mask_indexes = np.empty(0, dtype=np.int64)
        survivor_index = torch.as_tensor(best_mask_indexes, dtype=torch.long, device=masks.device)

        # Crossover and mutation
        couples = masks.chunk(NUM_MASKS // 2)  # chunk into groups of 2 parents
        children = [cx(couple) for couple in couples]  # one child per group
        new_masks = [mt(child) for child in children]  # mutate the children

        if len(best_mask_indexes) + len(new_masks) > NUM_MASKS:
            raise ValueError("Too many masks specified to keep")

        # Survivors first, then the new generation, then random filler masks
        next_generation = [masks[survivor_index], torch.stack(new_masks, dim=0)]
        filler_count = NUM_MASKS - len(new_masks) - len(best_mask_indexes)
        if filler_count > 0:
            # Skipped when nothing is missing, so no empty tensor is concatenated
            next_generation.append(random_masks(filler_count, depth, width))
        masks = torch.cat(next_generation, dim=0)

    return masks

def run(SHOW_PLOTS):
    """Optionally train a small MLP, then run the mask search on a randomly initialised MLP.

    When the module-level ``TRAIN`` flag is set, a network with one hidden layer
    of width 2 is trained for ``EPOCHS`` epochs, evaluated on the validation and
    test loaders, and its weights and test accuracy are written to
    ``./train_model.pt`` and ``accuracy.json``. Afterwards a fresh mask network
    is created and ``train_masks`` is run on ``subset_data_loader``.

    Args:
        SHOW_PLOTS: when true (and ``TRAIN`` is set), show a line plot of the
            per-epoch validation accuracy.
    """
    num_hidden_layers = 3
    hidden_size = 100
    keep_best = 10  # number of top masks carried over unchanged each generation

    # TODO: the total number of neurons in the target network (train_model), should
    # be equal to the number of unmasked neurons in the masked model
    # e.g. 1*2 = 2 hidden neurons in target network, 3*8*sparsity == 2 in mask network
    # with 2l and polynomial width

    if TRAIN:
        logging.info('Training...')
        criterion = nn.CrossEntropyLoss()
        train_accuracies = []
        start_time = timeit.default_timer()
        train_model = MLP(input_size=784, num_hidden_layers=1, hidden_size=2, output_size=10)
        train_model.to(DEVICE)

        optimizer = torch.optim.Adam(train_model.parameters(), lr=LEARNING_RATE)

        for epoch in range(EPOCHS):
            model = train(train_model, train_loader, optimizer, criterion)
            accuracy = evaluate(model, validate_loader)
            train_accuracies.append(accuracy)
            print(f'Epoch {epoch + 1}/{EPOCHS}, Accuracy: {accuracy}')

        end_time = timeit.default_timer()
        print(f'Total training time: {end_time - start_time}')

        if SHOW_PLOTS:
            fig = px.line(x=range(1, EPOCHS + 1), y=train_accuracies)
            fig.show()

        logging.info('Testing...')
        test_accuracy = evaluate(train_model, test_loader)
        print(f'Test Accuracy: {test_accuracy}')

        # save model in case we want to use it again, and accuracy for a stop condition
        torch.save(train_model.state_dict(), './train_model.pt')
        with open('accuracy.json', 'w', encoding='utf-8') as f:
            # json.dump needs the target file object; without it the call raises TypeError
            json.dump(test_accuracy, f)

    mask_model = MLP(input_size=784, num_hidden_layers=num_hidden_layers, hidden_size=hidden_size, output_size=10)
    mask_model.to(DEVICE)

    masks = random_masks(NUM_MASKS, num_hidden_layers, hidden_size)
    logging.debug(f"Masks Shape: {masks.shape}")

    # TODO: replace these with whatever functions you want
    # takes a couple (tensor of size 2) and generates a child (singular tensor)
    cx = basic_crossover

    # Takes a mask and mutates it
    mt = basic_mutatation

    # Convert mask of floats to a binary mask
    def ms(float_mask):
        return convert_to_binary_mask(float_mask, NUM_ELEMENTS_TO_KEEP)

    train_masks(mask_model, subset_data_loader, keep_best, num_hidden_layers, hidden_size, masks, cx, mt, ms)


# Entry point: run the experiment with plotting enabled (plots are only
# produced inside run() when the TRAIN flag is set).
if __name__ == "__main__":
    run(SHOW_PLOTS=True)

Masks Shape: torch.Size([26, 3, 100])
Accuracies: [0.1049, 0.0991, 0.105, 0.0926, 0.1083, 0.1049, 0.0991, 0.0999, 0.1049, 0.1049, 0.0921, 0.1393, 0.0991, 0.1055, 0.1049, 0.0991, 0.1049, 0.0991, 0.1049, 0.1049, 0.1049, 0.1049, 0.1247, 0.1049, 0.1049, 0.1049]
Accuracies: [0.1083, 0.1247, 0.1393, 0.1011, 0.0991, 0.1049, 0.0991, 0.1049, 0.1049, 0.0991, 0.1049, 0.0991, 0.1049, 0.1049, 0.1049, 0.0968, 0.1107, 0.1049, 0.1049, 0.1049, 0.1049, 0.0991, 0.1049, 0.1049, 0.0991, 0.1096]
Accuracies: [0.1107, 0.1393, 0.1247, 0.1049, 0.1182, 0.1049, 0.0991, 0.1049, 0.0991, 0.1049, 0.0911, 0.1048, 0.1049, 0.1049, 0.1049, 0.1019, 0.0991, 0.1049, 0.1049, 0.0737, 0.1049, 0.0991, 0.1049, 0.1, 0.0991, 0.1048]
