<a href="https://colab.research.google.com/github/jonbaer/googlecolab/blob/master/EvoNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Backpropagation vs. Evolutionary Strategy on GPU

* Backpropagation baseline:
 * Number of epochs: 10
 * Final accuracy: 97%
 * Seconds per epoch: 9

* Evolutionary Strategy:
 * Number of epochs: 10
 * Final accuracy: 90%
 * Seconds per epoch: 9

**The Evolutionary Strategy is much slower on CPU: it evaluates the loss of every individual in the population in parallel, a workload that benefits greatly from a GPU. Run this notebook on a GPU to reproduce the results above.**

## Common setup

In [None]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
import time

# Use the GPU when one is available; the ES timings above assume one.
use_cuda = torch.cuda.is_available()
device = 'cuda' if use_cuda else 'cpu'

# Shared hyperparameters. ``lr`` is rebound later for the evolutionary run.
batch_size = 32
val_batch_size = 32
random_seed = 1337
lr = 1E-3
epochs = 10

# Seed Python's and PyTorch's RNGs so runs are reproducible.
random.seed(random_seed)
torch.manual_seed(random_seed)
if use_cuda:
    torch.cuda.manual_seed(random_seed)

# Images are only converted with ToTensor (no normalization).
transform = transforms.Compose([transforms.ToTensor()])

train = datasets.MNIST('../data', train=True, download=True, transform=transform)
val = datasets.MNIST('../data', train=False, transform=transform)

# Pinned host memory only helps when batches are copied to a CUDA device.
train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True, pin_memory=use_cuda)
val_loader = torch.utils.data.DataLoader(val, batch_size=val_batch_size, shuffle=False, pin_memory=use_cuda)

@torch.inference_mode()
def evaluate(model: nn.Module, loader=None):
    """Return ``(mean_loss, accuracy)`` of ``model`` over a labelled loader.

    Args:
        model: classifier whose output is scored with ``F.cross_entropy``
            against the targets and whose argmax over the last dim is the
            predicted class.
        loader: iterable of ``(inputs, targets)`` batches. Defaults to the
            module-level ``val_loader`` when omitted (original behavior).

    Returns:
        The mean per-sample cross-entropy and the fraction of samples whose
        argmax prediction equals the target.

    Raises:
        ValueError: if the loader yields no samples.

    The model is left in eval mode; training code must call ``model.train()``.
    """
    if loader is None:
        loader = val_loader
    model.eval()
    total = 0
    loss_sum = 0.0
    correct = 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        logits = model(inputs)
        # Sum (not mean) per batch so the final division weights every sample equally,
        # even when the last batch is smaller.
        loss_sum += F.cross_entropy(logits, targets, reduction='sum').item()
        correct += (logits.argmax(dim=-1) == targets).sum().item()
        total += inputs.size(0)

    if total == 0:
        raise ValueError('evaluate() received an empty loader')
    return loss_sum / total, correct / total

## Train with Backpropagation

In [None]:
class BackpropModel(nn.Module):
    """Bias-free 784-64-10 MLP with a SiLU hidden activation, trained by backprop."""

    def __init__(self):
        super().__init__()
        # Creation order fixes which seeded random numbers each layer's init consumes.
        self.fc1 = nn.Linear(28 * 28, 64, bias=False)
        self.fc2 = nn.Linear(64, 10, bias=False)

    def forward(self, x: torch.Tensor):
        """Map a batch to 10 logits; every dim after the first is flattened."""
        features = torch.flatten(x, start_dim=1)
        hidden = F.silu(self.fc1(features))
        return self.fc2(hidden)

# Baseline: the MLP defined above, optimized with AdamW on the device chosen earlier.
model = BackpropModel().to(device)
optim = torch.optim.AdamW(model.parameters(), lr=lr)

def train_for_epoch():
    """Run one backpropagation pass over ``train_loader``, updating ``model`` via ``optim``."""
    model.train()
    for images, labels in train_loader:
        optim.zero_grad()
        images = images.to(device)
        labels = labels.to(device)
        batch_loss = F.cross_entropy(model.forward(images), labels)
        batch_loss.backward()
        optim.step()

epoch = 0
loss, accuracy = evaluate(model)
print(f'epoch {epoch} | loss: {loss:.4f} | accuracy: {accuracy:.2%}')

# perf_counter() is monotonic and high-resolution, unlike time.time(), so the
# per-epoch timings cannot be skewed by system clock adjustments.
for epoch in range(1, epochs + 1):
    t0 = time.perf_counter()
    train_for_epoch()
    if device == "cuda":
        # CUDA work is queued asynchronously; wait for it before reading the clock.
        torch.cuda.synchronize()
    dt = time.perf_counter() - t0
    loss, accuracy = evaluate(model)
    print(f'epoch {epoch} | loss: {loss:.4f} | accuracy: {accuracy:.2%} | seconds per epoch: {dt:.3f}')

epoch 0 | loss: 2.3049 | accuracy: 6.89%
epoch 1 | loss: 0.1997 | accuracy: 94.11% | seconds per epoch: 8.490
epoch 2 | loss: 0.1356 | accuracy: 95.90% | seconds per epoch: 8.893
epoch 3 | loss: 0.1042 | accuracy: 96.85% | seconds per epoch: 8.903
epoch 4 | loss: 0.0964 | accuracy: 97.16% | seconds per epoch: 8.888
epoch 5 | loss: 0.0911 | accuracy: 97.19% | seconds per epoch: 8.697
epoch 6 | loss: 0.0894 | accuracy: 97.12% | seconds per epoch: 8.677
epoch 7 | loss: 0.0845 | accuracy: 97.48% | seconds per epoch: 9.611
epoch 8 | loss: 0.0728 | accuracy: 97.52% | seconds per epoch: 9.800
epoch 9 | loss: 0.0768 | accuracy: 97.68% | seconds per epoch: 8.868
epoch 10 | loss: 0.0787 | accuracy: 97.62% | seconds per epoch: 8.842


## Train with Evolutionary Strategy

In [None]:
# Evolutionary-strategy hyperparameters. This rebinds ``lr``: here it is the
# standard deviation of the mutation noise passed to ``next_generation``.
lr = 2.7e-3                  # mutation std before per-epoch scaling
population_size = 64         # offspring sampled per generation
generations_per_batch = 2    # sample/select/mate rounds per training batch
num_parents_for_mating = 4   # lowest-loss offspring averaged into the new weights

class EvoLinear(nn.Module):
    """Bias-free linear layer optimized by an evolutionary strategy, not gradients.

    The current weight is a buffer initialised to zeros. ``next_generation``
    samples a population of perturbed copies ("offspring"). While offspring
    exist, ``forward`` evaluates the input under every offspring at once.
    ``mate`` replaces the weight with the mean of the selected offspring, and
    ``reset`` drops the population so ``forward`` uses the single weight again.
    """

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.weight: torch.Tensor
        # A buffer (not a Parameter): it follows .to(device) and is saved in
        # state_dict, but is not returned by .parameters().
        self.register_buffer('weight', torch.zeros(out_features, in_features))
        # (population_size, 1, out_features, in_features) once sampled. The
        # singleton axis broadcasts against the batch axis in ``forward``.
        self.offspring: torch.Tensor | None = None

    def next_generation(self, population_size: int, lr: float):
        """Sample ``population_size`` offspring around the current weight.

        Every offspring element is drawn from Normal(weight, lr), so ``lr`` is
        the standard deviation of the mutation noise.
        """
        out_features, in_features = self.weight.size()
        mean = self.weight.expand(population_size, 1, out_features, in_features)
        self.offspring = torch.normal(mean, std=lr)

    def mate(self, parents: list[int]):
        """Set the weight to the element-wise mean of the selected offspring.

        Args:
            parents: indices into the current population.

        Raises:
            RuntimeError: if no population is sampled (``next_generation`` was
                never called, or ``reset`` was called since).
            ValueError: if ``parents`` is empty. Averaging zero offspring
                would silently turn every weight into NaN.
        """
        if self.offspring is None:
            raise RuntimeError('mate() called before next_generation()')
        if len(parents) == 0:
            raise ValueError('mate() needs at least one parent')
        # Assigning to a registered buffer name replaces the buffer entry
        # (nn.Module.__setattr__), rather than creating a plain attribute.
        self.weight = self.offspring[parents, 0, :, :].mean(0, keepdim=False)

    def reset(self):
        """Discard the population so ``forward`` uses the current weight."""
        self.offspring = None

    def forward(self, x: torch.Tensor):
        """Apply the layer.

        With a population sampled, returns (population_size, batch, out_features).
        A 2-D input (batch, in_features) is shared by all offspring. A 3-D
        input carries one slice per offspring on its leading axis. Without a
        population, this is a plain ``F.linear`` with the current weight.
        """
        if self.offspring is not None:
            if x.dim() == 2:
                x = x.unsqueeze(0)
            # e = offspring, b = batch, i = in, o = out; size-1 axes broadcast.
            return torch.einsum('ebi,eboi->ebo', x, self.offspring)
        return F.linear(x, self.weight)

class EvoModel(nn.Module):
    """784-64-10 MLP with a SiLU hidden activation, built from ``EvoLinear`` layers.

    The population methods below forward each call to every ``EvoLinear``
    submodule, so all layers share one population and one parent selection.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = EvoLinear(28 * 28, 64)
        self.fc2 = EvoLinear(64, 10)

    def forward(self, x: torch.Tensor):
        hidden = F.silu(self.fc1.forward(x.flatten(1)))
        return self.fc2.forward(hidden)

    def _evo_layers(self):
        """Return a generator over every ``EvoLinear`` submodule, in traversal order."""
        return (m for m in self.modules() if isinstance(m, EvoLinear))

    def next_generation(self, population_size: int, lr: float):
        """Sample a fresh population of offspring in every evolvable layer."""
        for layer in self._evo_layers():
            layer.next_generation(population_size, lr)

    def mate(self, parents: list[int]):
        """Average the same offspring indices into the weights of every layer."""
        for layer in self._evo_layers():
            layer.mate(parents)

    def reset(self):
        """Drop the population in every layer so ``forward`` uses the current weights."""
        for layer in self._evo_layers():
            layer.reset()

# Evolutionary model: same architecture, weights start at zero (see EvoLinear).
model = EvoModel().to(device)

@torch.inference_mode()
def evolve_for_epoch(lr: float):
    """Run one epoch of evolutionary search over ``train_loader``.

    For each batch, ``generations_per_batch`` rounds are run. Each round:
      * samples ``population_size`` offspring with mutation std ``lr``;
      * scores every offspring by its mean cross-entropy on the batch;
      * mates the ``num_parents_for_mating`` lowest-loss offspring into the new
        weights.

    No gradients are needed, hence ``torch.inference_mode``.
    """
    model.eval()
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        # Every offspring is scored against the same labels.
        tiled_labels = labels.expand(population_size, -1).flatten()
        for _ in range(generations_per_batch):
            model.next_generation(population_size, lr)
            logits = model.forward(images)  # (population, batch, classes)
            per_sample = F.cross_entropy(logits.flatten(0, 1), tiled_labels, reduction='none')
            fitness = per_sample.unflatten(0, (population_size, labels.size(0))).mean(dim=-1)
            best = torch.topk(fitness, k=num_parents_for_mating, largest=False)
            model.mate(best.indices.tolist())

epoch = 0
model.reset()
loss, accuracy = evaluate(model)
print(f'epoch {epoch} | loss: {loss:.4f} | accuracy: {accuracy:.2%}')

# perf_counter() is monotonic and high-resolution, unlike time.time(), so the
# per-epoch timings cannot be skewed by system clock adjustments.
for epoch in range(1, epochs + 1):
    t0 = time.perf_counter()
    # The mutation std shrinks as 1/epoch.
    evolve_for_epoch(lr / epoch)
    if device == "cuda":
        # CUDA work is queued asynchronously; wait for it before reading the clock.
        torch.cuda.synchronize()
    dt = time.perf_counter() - t0
    # Drop the last sampled population so evaluation uses the mated weights.
    model.reset()
    loss, accuracy = evaluate(model)
    print(f'epoch {epoch} | loss: {loss:.4f} | accuracy: {accuracy:.2%} | seconds per epoch: {dt:.3f}')

epoch 0 | loss: 2.3026 | accuracy: 9.80%
epoch 1 | loss: 0.4850 | accuracy: 86.40% | seconds per epoch: 8.426
epoch 2 | loss: 0.4059 | accuracy: 88.01% | seconds per epoch: 9.119
epoch 3 | loss: 0.3815 | accuracy: 88.97% | seconds per epoch: 8.995
epoch 4 | loss: 0.3671 | accuracy: 89.26% | seconds per epoch: 9.017
epoch 5 | loss: 0.3601 | accuracy: 89.43% | seconds per epoch: 9.042
epoch 6 | loss: 0.3499 | accuracy: 89.68% | seconds per epoch: 8.669
epoch 7 | loss: 0.3453 | accuracy: 89.84% | seconds per epoch: 8.629
epoch 8 | loss: 0.3410 | accuracy: 89.93% | seconds per epoch: 10.034
epoch 9 | loss: 0.3361 | accuracy: 90.14% | seconds per epoch: 9.454
epoch 10 | loss: 0.3329 | accuracy: 90.34% | seconds per epoch: 8.972
