In [73]:
import torch 
import torchvision 
import plotly
import logging
import json
import os
import timeit
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, TensorDataset
import random
import concurrent.futures
from functools import partial
import time

In [74]:
BATCH_SIZE=64

# Make it repeatable, set seeds
random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [75]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# load MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_dataset, validate_dataset = torch.utils.data.random_split(train_dataset, [50000, 10000])
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
validate_loader = DataLoader(dataset=validate_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# # Create a random subset of the MNIST dataset and pin the memory for faster transfer to GPU
subset_size = 500
subset_indices = torch.randperm(len(train_dataset))[:subset_size]
subset_dataset = Subset(train_dataset, subset_indices)
subset_data_loader = DataLoader(subset_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)

In [76]:
def train(model, data_loader, optimizer, criterion, mask = None):
    results = []
    model.train()
    for data in data_loader:
        inputs, labels = data
        optimizer.zero_grad() # zero the parameter gradients
        predictions = model(inputs) # forward pass
        loss = criterion(predictions, labels) # calculate loss
        loss.backward() # backward pass
        optimizer.step() # update parameters

    return model

def evaluate(model, data_loader, mask = None):
      model.eval()
      with torch.inference_mode():
          correct = 0
          total = 0
          for inputs, labels in data_loader:
              inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
              outputs = model(inputs, mask)
              _, predicted = torch.max(outputs.data, 1)
              total += labels.size(0)
              correct += (predicted == labels).sum().item()

      return correct / total

In [77]:
DEBUG = bool(os.environ.get('DEBUG', 'True') == 'True')
log_level = os.environ.get('LOG_LEVEL', 'DEBUG' if DEBUG else 'INFO')
logging.basicConfig(level=logging.getLevelName(log_level), format="%(message)s", force=True)

In [78]:
# Crossover methods
def basic_crossover(couple):
    return couple[0] * couple[1]

def calculate_beta(q, eta):
    beta = torch.where(q <= 0.5, 
                       (2 * q) ** (1 / (eta + 1)), 
                       (1 / (2 * (1 - q))) ** (1 / (eta + 1)))
    return beta

def sbx_crossover(couple, probability=0.9, eta=1): # Takes like 1 minute per crossover - very slow.
    # Init empty offspring
    offspring1, offspring2 = {}, {}
    offspring1["hidden"], offspring2["hidden"] = [], []
    
    # Crossover for each layer, add to the offspring
    offspring1["input"], offspring2["input"] = sbx_crossover_per_layer(couple[0]["input"], couple[1]["input"], probability, eta)
    for i in range(min(len(couple[0]["hidden"]), len(couple[1]["hidden"]))):
        off1_hidden, off2_hidden = sbx_crossover_per_layer(couple[0]["hidden"][i], couple[1]["hidden"][i], probability, eta)
        offspring1["hidden"].append(off1_hidden)
        offspring2["hidden"].append(off2_hidden)
    offspring1["output"], offspring2["output"] = sbx_crossover_per_layer(couple[0]["output"], couple[1]["output"], probability, eta)
    return offspring1, offspring2

def sbx_crossover_per_layer(offspring1_layer_tensor, offspring2_layer_tensor, probability=0.9, eta=1):

    # Generate random numbers and calculate beta for the entire tensor
    rand = torch.rand(offspring1_layer_tensor.size(), device=offspring1_layer_tensor.device)
    beta = calculate_beta(rand, eta)

    # Apply SBX crossover using vectorized operations
    beta_1 = 0.5 * ((1 + beta) * offspring1_layer_tensor + (1 - beta) * offspring2_layer_tensor)
    beta_2 = 0.5 * ((1 - beta) * offspring1_layer_tensor + (1 + beta) * offspring2_layer_tensor)

    return beta_1, beta_2

def uninform_crossover(couple, probability=0.5):
    offspring1, offspring2 = {}, {}
    offspring1["hidden"], offspring2["hidden"] = [], []

    offspring1["input"], offspring2["input"] = uniform_crossover_per_layer(couple[0]["input"], couple[1]["input"], probability)
    for i in range(min(len(couple[0]["hidden"]), len(couple[1]["hidden"]))):
        off1_hidden, off2_hidden = uniform_crossover_per_layer(couple[0]["hidden"][i], couple[1]["hidden"][i], probability)
        offspring1["hidden"].append(off1_hidden)
        offspring2["hidden"].append(off2_hidden)
    offspring1["output"], offspring2["output"] = uniform_crossover_per_layer(couple[0]["output"], couple[1]["output"], probability)
    return offspring1, offspring2
    
def uniform_crossover_per_layer(offspring1_layer_tensor, offspring2_layer_tensor, probability=0.5):
    # Initialize the offspring tensors
    # Reshape the tensors for vectorized operations
    flat_tensor1 = offspring1_layer_tensor.view(-1)
    flat_tensor2 = offspring2_layer_tensor.view(-1)

    # Generate a random mask
    mask = torch.rand_like(flat_tensor1) < probability

    # Create offspring tensors using vectorized operations
    offspring1_tensor = torch.where(mask, flat_tensor2, flat_tensor1)
    offspring2_tensor = torch.where(mask, flat_tensor1, flat_tensor2)

    # Reshape back to original shape
    return offspring1_tensor.view_as(offspring1_layer_tensor), offspring2_tensor.view_as(offspring2_layer_tensor)

In [79]:
FLIP_BIT_PROBABILITY = 0.1
RANDOM_RESETTING_PROB = 0.15
CREEP_MUTATION_PROB = 0.1
GAUSS_MUTATION_PROB = 0.2

def random_mutation(mask):
    for key in mask:
        if isinstance(mask[key], list):
            mask[key] = [mutate_layer(layer) for layer in mask[key]]
        else:
            mask[key] = mutate_layer(mask[key])
    return mask

def mutate_layer(layer):
    mutation_choice = random.choices(
        ['flip_bit', 'random_reset', 'creep', 'none', 'guass'],
        weights=[FLIP_BIT_PROBABILITY, RANDOM_RESETTING_PROB, CREEP_MUTATION_PROB, 0.45, GAUSS_MUTATION_PROB])[0]

    if mutation_choice == 'flip_bit':
        return flip_bit_mutation_layer(layer)
    elif mutation_choice == 'random_reset':
        return random_resetting_mutation_layer(layer)
    elif mutation_choice == 'creep':
        return creep_mutation_layer(layer)
    elif mutation_choice == 'guass':
        return gaussian_mutation_layer(layer)
    else:  # 'none'
        return layer

def gaussian_mutation_layer(mask, mean=0.0, stddev=0.1):
    # Adds Gaussian noise to the elements of the mask
    noise = torch.normal(mean, stddev, size=mask.size(), device=mask.device)
    return mask + noise

def flip_bit_mutation_layer(mask, mutation_rate=0.01):
    # Randomly flip some of the elements in the mask
    flip_mask = torch.rand_like(mask) < mutation_rate
    return torch.where(flip_mask, 1 - mask, mask)

def random_resetting_mutation_layer(mask, mutation_rate=0.01):
    # Randomly reset a certain percentage of the elements in the mask to new values.
    random_values = torch.rand_like(mask)
    mutation_mask = torch.rand_like(mask) < mutation_rate
    return torch.where(mutation_mask, random_values, mask)

def creep_mutation_layer(mask, creep_rate=0.05, max_creep=0.1):
    # Increments or decrements the values of the mask by a small random value.
    creep = (2 * torch.rand_like(mask) - 1) * max_creep  # Values between -max_creep and +max_creep
    mutation_mask = torch.rand_like(mask) < creep_rate
    return torch.where(mutation_mask, mask + creep, mask)

In [80]:
## HYPERPARAMETERS
# number of masks
# number of top parents to keep
# the crossover operation
# the mutation operation
# the masking operation
# number of epochs
# length should be 2x the trained version, and width polynomial in width to get SLT

EPOCHS=2000
LEARNING_RATE=0.001
TRAIN=False
NUM_MASKS=100
assert NUM_MASKS % 2 == 0 # must be able to create groups of 2 parents
KEEP_BEST=10
NUM_CROSSOVER = 40
assert NUM_CROSSOVER % 2 == 0 # must be able to create couples for crossover. They'll be randomly grouped, it'll create 2 children per couple
NUM_HIDDEN_LAYERS = 3
HIDDEN_SIZE = 128
ELEMENTS_KEEP_INPUT_HIDDEN = 781*HIDDEN_SIZE*0.6 # 80% of the weights are kept
ELEMENTS_KEEP_HIDDEN_HIDDEN = HIDDEN_SIZE*HIDDEN_SIZE*0.6 # 80% of the weights are kept
ELEMENTS_KEEP_HIDDEN_OUTPUT = HIDDEN_SIZE*10*0.6 # 80% of the weights are kept
# Masking appears to be very optimized already
MAX_THREADS_MASK = 2
MAX_THREADS_CROSS = 4
# Mutation is adding noise to the mask, lots of threads
MAX_THREADS_MUTATE = 10
TRACK_PERF = True

class MLP(nn.Module):
    def __init__(self, input_size, num_hidden_layers, hidden_size, output_size):
        super(MLP, self).__init__()
        assert num_hidden_layers > 0
        self.input = nn.Linear(input_size, hidden_size)
        self.layers = nn.ModuleList()
        for i in range(num_hidden_layers):
            self.layers.append(nn.Linear(hidden_size, hidden_size))
        self.output = nn.Linear(hidden_size, output_size)

    def forward(self, x, mask = None):
        # This allows for a forward that applied the masks to a model's parameters but doesn't train the masks
        x = x.view(x.size(0), -1)


        if mask is not None:
            x = torch.nn.functional.linear(x, self.input.weight * mask['input'], self.input.bias)

            for layer, mask_layer in zip(self.layers, mask['hidden']):
                x = torch.nn.functional.linear(x, layer.weight * mask_layer, layer.bias)
                x = torch.relu(x)

            x = torch.nn.functional.linear(x, self.output.weight * mask['output'], self.output.bias)
        else:
            x = self.input(x)
            for layer in self.layers:
                x = layer(x)
                x = torch.relu(x)
            x = self.output(x)
            
        return torch.softmax(x, dim=1)

def random_masks(num, depth, width):
    if num <= 0:
        return None
    masks = []

    for _ in range(num):
        mask = {}
        mask['input'] = torch.rand(width, 784, device=DEVICE)
        mask['hidden'] = [torch.rand(width, width, device=DEVICE) for _ in range(depth)]
        mask['output'] = torch.rand(10, width, device=DEVICE)
        masks.append(mask)
    return masks

def convert_to_binary_mask(float_masks: dict[str, torch.tensor], elements_keep_input_hidden: int = ELEMENTS_KEEP_INPUT_HIDDEN, 
                           elements_keep_hidden_hidden: int = ELEMENTS_KEEP_HIDDEN_HIDDEN, elements_keep_hidden_output: int = ELEMENTS_KEEP_HIDDEN_OUTPUT):
    
    # Convert the masks for each layer to binary masks
    float_masks['input'] = convert_layer(float_masks['input'], elements_keep_input_hidden)
    for i in range(len(float_masks['hidden'])):
        float_masks['hidden'][i] = convert_layer(float_masks['hidden'][i], elements_keep_hidden_hidden)
    float_masks['output'] = convert_layer(float_masks['output'], elements_keep_hidden_output)
    return float_masks

def convert_layer(float_mask: torch.tensor, elements_keep: int):

    flattened_tensor = float_mask.view(-1)
    # Find the indices of the top k elements
    _, top_indices = torch.topk(flattened_tensor, int(elements_keep))
    zero_tensor = torch.zeros_like(flattened_tensor)
    zero_tensor[top_indices] = 1
    # Reshape back to the original shape
    return zero_tensor.reshape(float_mask.shape)
    

def eval_accuracy_of_mask(index_mask_tuple, train_loader, mask_model, ms):
    index, mask = index_mask_tuple
    binary_mask = ms(mask) # mask the model
    accuracy = evaluate(mask_model, train_loader, binary_mask)
    return index, accuracy

def parallel_crossover(couple, cx):
    off1, off2 = cx(couple)  # generate 2 children
    return off1, off2

def parallel_mutation(child, mt):
    return mt(child)

def train_masks(mask_model, train_loader, keep_best, depth, width, masks, cx, mt, ms):

    if (keep_best + 2 * NUM_CROSSOVER > NUM_MASKS):
        raise ValueError("Too many masks specified to keep")

    for epoch in range(EPOCHS):
        overall_start = time.time()
        accuracies = [None] * len(masks)

        # Evaluate the current generation of masks in seperate threads
        indexed_masks = list(enumerate(masks))
        start_accur_eval = time.time()
        # Use functools.partial to create a new function with frozen extra arguments, in this case the train_loader, mask_model, and ms
        partial_process_mask = partial(eval_accuracy_of_mask, train_loader=train_loader, mask_model=mask_model, ms=ms)
        with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_THREADS_MASK) as executor:
            # Use executor.map to process the (index, mask) tuples concurrently
            results = list(executor.map(partial_process_mask, indexed_masks))
        
        for index, accuracy in results:
            accuracies[index] = accuracy 
        end__accur_eval = time.time()

        # Find the indicies of the best masks that will survive the generation
        best_mask_indexes = np.argpartition(accuracies, -keep_best)[-keep_best:]
        best_masks = [masks[i] for i in best_mask_indexes]

        start_cross_evaluate = time.time()
        # Parallelize the crossover
        partial_process_crossover = partial(parallel_crossover, cx=cx)

        # Add the best masks to the selected masks for crossover. 
        # Then randomly select other masks to perform crossover on. Make sure to shuffle the list first
        selected_masks = best_masks
        selected_masks += random.sample(masks, NUM_CROSSOVER*2 - keep_best)
        random.shuffle(selected_masks)
        couples = [(selected_masks[i], selected_masks[i + 1]) for i in range(0, len(selected_masks) - 1, 2)]

        with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_THREADS_CROSS) as executor:
            results = executor.map(partial_process_crossover, couples)
        # Flatten the list of tuples to a single list of children
        children = [child for couple in results for child in couple]
        end_cross_evaluate = time.time()

        start_mutate = time.time()
        # Parallelize the Mutation
        partial_process_mutation = partial(parallel_mutation, mt=mt)
        with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_THREADS_MUTATE) as executor:
            new_masks = list(executor.map(partial_process_mutation, children))
        
        end_mutate = time.time()
        overall_end = time.time()
        print(f'Epoch {epoch + 1}/{EPOCHS}, Best Accuracy: {max(accuracies)}. All Accuracies: {accuracies}. Accuracy eval time: {end__accur_eval - start_accur_eval}, Crossover time: {end_cross_evaluate - start_cross_evaluate}, Mutation Time: {end_mutate - start_mutate}, Overall time: {overall_end - overall_start}')

        # Keep the 'keep-best' number of original masks, add the new generation of masks, and fill the rest with random masks if needed
        filler_masks = random_masks((NUM_MASKS - len(new_masks) - keep_best), depth, width)
        masks = best_masks + new_masks + filler_masks
        

def run(SHOW_PLOTS):

    # TODO: the total number of neurons in the target network (train_model), should
    # be equal to the number of unmasked neurons in the masked model
    # e.g. 1*2 = 2 hidden neurons in target network, 3*8*sparsity == 2 in mask network
    # with 2l and polynomial width
    criterion = nn.CrossEntropyLoss()

    if TRAIN:
        logging.info('Training...')
        train_accuracies = []
        start_time = timeit.default_timer()
        train_model = MLP(input_size=784, num_hidden_layers=1, hidden_size=2, output_size=10)
        train_model.to(DEVICE)

        optimizer = torch.optim.Adam(train_model.parameters(), lr=LEARNING_RATE)

        for epoch in range(EPOCHS):
            model = train(train_model, train_loader, optimizer, criterion)
            accuracy = evaluate(model, validate_loader)
            train_accuracies.append(accuracy)
            print(f'Epoch {epoch + 1}/{EPOCHS}, Accuracy: {accuracy}')

        end_time = timeit.default_timer()
        print(f'Total training time: {end_time - start_time}')

        if SHOW_PLOTS:
            fig = px.line(x=range(1, EPOCHS + 1), y=train_accuracies)
            fig.show()

        logging.info('Testing...')
        test_accuracy = evaluate(train_model, test_loader)
        print(f'Test Accuracy: {test_accuracy}')

        # save model in case we want to use it again, and accuracy for a stop condition
        torch.save(train_model.state_dict(), './train_model.pt')
        with open('accuracy.json', 'w', encoding='utf-8') as f:
            json.dump(test_accuracy)

    mask_model = MLP(input_size=784, num_hidden_layers=NUM_HIDDEN_LAYERS, hidden_size=HIDDEN_SIZE, output_size=10)
    mask_model.to(DEVICE)

    masks = random_masks(NUM_MASKS, NUM_HIDDEN_LAYERS, HIDDEN_SIZE)
    logging.info(f"Initial mask shape, Input: {masks[0]['input'].shape}, Hidden: {masks[0]['hidden'][0].shape}, Output: {masks[0]['output'].shape}")

    # TODO: replace these with whatever functions you want
    # Takes a couple of masks and returns 2 offspring
    cx = lambda a : uninform_crossover(a)

    # Takes a mask and mutates it
    mt = lambda a : random_mutation(a)

    # Convert mask of floats to a binary mask
    ms = lambda a: convert_to_binary_mask(a, ELEMENTS_KEEP_INPUT_HIDDEN, ELEMENTS_KEEP_HIDDEN_HIDDEN, ELEMENTS_KEEP_HIDDEN_OUTPUT)

    train_masks(mask_model, subset_data_loader, KEEP_BEST, NUM_HIDDEN_LAYERS, HIDDEN_SIZE, masks, cx, mt, ms)


run(SHOW_PLOTS=True)

Initial mask shape, Input: torch.Size([128, 784]), Hidden: torch.Size([128, 128]), Output: torch.Size([10, 128])


Epoch 1/2000, Best Accuracy: 0.118. All Accuracies: [0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.108, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.118, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.106, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098]. Accuracy eval time: 10.639198541641235, Crossover time: 0.011903762817382812, Mutation Time: 0.012896060943603516, Overall time: 10.663998365402222
Epoch 2/2000, Best Accuracy: 0.13. All Accuracies: [0.098, 0.098, 0.098, 0.098, 0.098, 0.118, 0.09

KeyboardInterrupt: 