## Download & Import necessary libs

In [4]:
! pip install pytorch_lightning --quiet

from collections import OrderedDict

import numpy as np
import matplotlib.pyplot as plt


import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

import pytorch_lightning as pl

  "`pytorch_lightning.metrics.*` module has been renamed to `torchmetrics.*` and split off to its own package"


## PyTorch-Lightning Data Module
This class downloads the MNIST dataset and splits it into training, validation, and test sets (the validation set is not strictly necessary here).

In [5]:
class MNISTDataModule(pl.LightningDataModule):
    """Lightning data module that downloads MNIST and serves train/val/test loaders.

    Args:
        data_dir: Directory the MNIST files are downloaded to and read from.
        batch_size: Batch size used by all three dataloaders.
        num_workers: Number of worker processes used by all three dataloaders.
    """

    def __init__(self, data_dir: str = './', batch_size: int = 64, num_workers: int = 2):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transform = transforms.Compose([
            transforms.Resize((28,28)),
            transforms.ToTensor(),
            # presumably the MNIST per-pixel mean / std -- TODO confirm
            transforms.Normalize((0.1307,), (0.3081,))
        ])

        # self.dims is returned when you call dm.size()
        # Setting default dims here because we know them.
        # Could optionally be assigned dynamically in dm.setup()
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        """Download the train and test parts of MNIST into ``data_dir``."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` ('fit', 'test', or None for both)."""

        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            # Hold out a fixed number of samples for validation; the training
            # size is derived from the dataset length instead of being hard-coded.
            val_size = 5000
            train_size = len(mnist_full) - val_size
            self.mnist_train, self.mnist_val = random_split(mnist_full, [train_size, val_size])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the training loader (reshuffled every epoch)."""
        # shuffle=True: without it every epoch (and every GA fit) iterates the
        # training subset in exactly the same order.
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)

## PyTorch module
A very basic CNN for the multi-class classification task. The class must be instantiable without passing any arguments to the constructor.

In [6]:
class Model(nn.Module):
    """Small CNN classifier.

    Two 3x3 convolutions (ReLU after each), adaptive average pooling down to
    3x3, then a 512-unit hidden linear layer (ReLU) and a 10-way linear output.
    Takes no constructor arguments.
    """

    def __init__(self):
        super().__init__()

        # Convolutional part: 1 -> 64 -> 64 channels, spatial size preserved.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)

        # Pool every feature map to 3x3 and flatten to a vector per sample.
        self.adapt = nn.AdaptiveAvgPool2d(output_size=3)
        self.flatten = nn.Flatten()

        # Fully connected head.
        self.linear1 = nn.Linear(in_features=64 * 3 * 3, out_features=512)
        self.linear2 = nn.Linear(in_features=512, out_features=10)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for a batch of images."""
        features = x
        for conv in (self.conv1, self.conv2):
            features = torch.relu(conv(features))

        pooled = self.flatten(self.adapt(features))
        hidden = torch.relu(self.linear1(pooled))
        return self.linear2(hidden)

## A PyTorch-Lightning Module
It is used to train a PyTorch model. The constructor **must** reserve its first parameter for the PyTorch model.

In [7]:
 class LitModel(pl.LightningModule):
    """Lightning wrapper that trains a wrapped PyTorch model with cross-entropy.

    Args:
        model: The PyTorch model to wrap. When omitted, a fresh ``Model()`` is
            created.
    """

    def __init__(self, model=None):
        super().__init__()

        if model is None:
            model = Model()
        self.model = model
        self.cost = nn.CrossEntropyLoss()

    def forward(self, z):
        """Delegate to the wrapped model."""
        return self.model(z)

    def _logged_loss(self, batch, metric_name):
        """Compute the loss on ``batch``, log it under ``metric_name`` and return it."""
        inputs, targets = batch
        predictions = self(inputs)
        loss = self.cost(predictions, targets)
        self.log(metric_name, loss)
        return loss

    def training_step(self, batch, batch_idx, ):
        return self._logged_loss(batch, 'train_loss')

    def validation_step(self, batch, batch_idx, ):
        return self._logged_loss(batch, 'validation_loss')

    def test_step(self, batch, batch_idx, ):
        return self._logged_loss(batch, 'test_loss')

    def configure_optimizers(self):
        """Use Adam (lr=1e-3) over all parameters, with no LR schedulers."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return [optimizer], []

## GA trainer
It consists of four module-level globals (`_Model`, `_LitModel`, `_loss`, `_modules`) and two classes (`Individual`, `GA`).

In [36]:
import numpy as np

import torch
import torch.nn as nn

import pytorch_lightning as pl
from pytorch_lightning import Trainer

from tqdm.notebook import tqdm 

from typing import List, Optional, ClassVar
from functools import reduce


## Global variables
# All four are overwritten by GA.__init__ and read by Individual.
_Model = None      # model class; Individual.load_model calls it with no arguments
_LitModel = None   # Lightning wrapper class; called with the model as first argument
_loss = ''         # key looked up in trainer.logged_metrics to compute the fitness
_modules = []      # dotted attribute paths of the sub-modules a solution toggles

class Individual:
    """One candidate solution of the genetic algorithm.

    ``solution`` is a boolean tensor with one entry per name in the
    module-level ``_modules`` list. ``True`` makes the parameters of that
    sub-module trainable (``requires_grad=True``); ``False`` freezes them.
    The weights belonging to the individual live in a file whose name is
    stored in ``self.path`` once ``save_model`` has been called.
    """

    def __init__(self, solution:Optional[torch.BoolTensor]=None):
        # Generate a random solution when none is supplied
        self.solution = torch.rand(len(_modules)) < .5 if solution is None else solution

    def mutate(self, proba=.1):
        """Flip each bit of the solution in place, independently with probability ``proba``."""
        mask = torch.rand(*self.solution.shape) < proba
        self.solution[mask] = self.solution[mask].logical_not()

    def crossover(self, other,):
        """Single-point crossover with ``other``.

        Returns:
            Two new ``Individual`` objects: (head of self + tail of other,
            head of other + tail of self). The parents are left unchanged.
        """
        n = self.solution.size(0)
        # Cut point in [1, n-1] so that both parents contribute at least one bit.
        pt = torch.randint(low=1, high=n, size=(1,)).item()

        def offspring(head, tail):
            return Individual(torch.hstack([head.solution[:pt], tail.solution[pt:]]))

        return offspring(self, other), offspring(other, self)

    def _disable_layers(self, model):
        """Set ``requires_grad`` on every listed sub-module of ``model`` according to the solution."""
        for trainable, path in zip(self.solution.tolist(), _modules):
            # Walk the dotted path (e.g. 'block.conv1') down from the model
            layer = reduce(getattr, path.split('.'), model)

            # Disable/Enable the layer
            for param in layer.parameters():
                param.requires_grad = trainable

    def save_model(self, model, id=None):
        """Write ``model``'s state dict to this individual's file.

        Args:
            model: Module whose ``state_dict()`` is saved.
            id: When given, the file becomes ``'individual {id}.pth'``. When
                omitted, the path assigned by an earlier call is reused, so
                re-saving after training does not redirect every individual to
                the same ``'individual None.pth'`` file. With no id and no
                earlier path the historical ``'individual None.pth'`` is used.
        """
        if id is not None or getattr(self, 'path', None) is None:
            self.path = f'individual {id}.pth'
        torch.save(model.state_dict(), self.path)

    def load_model(self):
        """Rebuild the model from ``self.path``.

        Returns:
            ``(model, litmodel)`` where ``model`` is a fresh ``_Model()`` with
            the saved weights and ``litmodel`` is ``_LitModel(model)``.
        """
        model = _Model()
        model.load_state_dict(torch.load(self.path))
        return model, _LitModel(model,)

    def fitness(self, trainer, dm):
        """Train and test this individual's model and return its fitness.

        The trained weights are written back to the individual's own file.

        Returns:
            ``1 / (loss + 1e-6)`` where ``loss`` is
            ``trainer.logged_metrics[_loss]`` after testing.
        """
        # Load model to RAM
        _, litmodel = self.load_model()

        # Set requires_grad=False wherever it needs to be
        self._disable_layers(litmodel.model)

        # Train & estimate performance
        trainer.fit(litmodel, dm)
        trainer.test(litmodel, verbose=False)

        # Save the model (no id: keeps the path assigned to this individual)
        self.save_model(litmodel.model)

        # Return the fitness = 1 / (loss + eps)
        return 1/(trainer.logged_metrics[_loss] + 1e-6)

#####################################################################################################################################################

class GA:
    """Genetic-algorithm driver that searches over which sub-modules to train.

    Each ``Individual`` holds a boolean mask over ``modules``. In every
    generation the current shared model is copied to each individual, each
    individual is trained and tested with its mask applied, the best model
    becomes the new shared model, and the population is renewed through
    roulette-wheel selection, single-point crossover and mutation.

    Args:
        Model: Model class; must be callable with no arguments.
        LitModel: Lightning module class taking the model as first argument.
        model: The initial model instance shared with all individuals.
        modules: Dotted attribute paths of the sub-modules that can be frozen.
        dm: Data module passed to ``Trainer.fit``.
        n_individuals: Population size. Must be a multiple of 4 (half of the
            population is selected, paired up, and each pair yields two parents
            plus two offspring) and at most ``2 ** len(modules)``.
            NOTE: the default of 10 is not a multiple of 4 and is rejected by
            the check below, so a valid value has to be passed explicitly.
        loss: Key of the logged metric used to compute the fitness.
        **kwargs: Forwarded to ``pytorch_lightning.Trainer``.
    """

    def __init__(self, Model, LitModel, model:nn.Module, modules:List[str], dm=None, n_individuals=10, loss:str='test_loss', **kwargs):

        assert n_individuals%4==0, 'The number of individuals must be divisible by 4.'
        assert n_individuals<=2**len(modules), f'The number of {n_individuals} can\'t exceed the number of all possibilities {2**len(modules)}'

        global _modules, _loss, _Model, _LitModel

        _modules = modules
        _loss = loss
        _Model = Model
        _LitModel = LitModel

        self.model = model
        # Kept so that warmup() can build its own trainer with a different
        # number of epochs.
        self._trainer_kwargs = dict(kwargs)
        self.trainer = Trainer(**kwargs)
        self.dm = dm

        self.individuals = [Individual() for _ in range(n_individuals)]

    def warmup(self, epochs):
        """Train the shared model for ``epochs`` epochs with every layer trainable.

        A dedicated trainer (same options as the main one, except
        ``max_epochs``) is used so that ``self.trainer`` keeps its own
        configuration and is not touched by the warm-up run.
        """
        # Make all model layers trainable
        for param in self.model.parameters():
            param.requires_grad = True

        warmup_kwargs = dict(self._trainer_kwargs, max_epochs=epochs)
        Trainer(**warmup_kwargs).fit(_LitModel(self.model), self.dm)

    def mutate(self, pb=.1):
        """Mutate every individual in place with per-bit probability ``pb``."""
        for individual in self.individuals:
            individual.mutate(pb)

    def save_models(self):
        """Copy the shared model's weights to every individual's file."""
        for index, individual in enumerate(self.individuals):
            individual.save_model(self.model, index)

    def evaluate(self):
        """Train/test every individual and return the list of fitness values."""
        fitnesses = []
        with tqdm(total=len(self.individuals), leave=False) as pbar:
            for individual in self.individuals:
                fitness = individual.fitness(self.trainer, self.dm)
                fitnesses.append(fitness)
                # Invert fitness = 1 / (loss + 1e-6) to display the loss
                pbar.set_postfix({'test loss': 1/fitness.item()-1e-6})
                pbar.update(1)

        return fitnesses

    def _roulette_wheel(self, fitnesses):
        """Sample half of the population, without replacement, proportionally to fitness.

        Args:
            fitnesses: 1-D tensor of positive fitness values.

        Returns:
            NumPy array with ``len(fitnesses) // 2`` distinct indices.
        """
        # Normalise in float64: float32 probabilities can miss the tolerance
        # np.random.choice applies to sum(p) == 1 and raise a ValueError.
        weights = np.asarray(fitnesses.tolist(), dtype=np.float64)
        probabilities = weights / weights.sum()
        n = probabilities.shape[0]

        return np.random.choice(n, size=n//2, replace=False, p=probabilities)

    def selection_crossover(self, fitnesses):
        """Replace the population by the selected parents and their offspring."""
        chosen = self._roulette_wheel(torch.tensor(fitnesses))
        population = []
        # Consecutive selected indices form a couple; each couple contributes
        # its two offspring followed by the two parents themselves.
        for i in range(0, len(chosen), 2):
            p1, p2 = chosen[i], chosen[i+1]
            off1, off2 = self.individuals[p1].crossover(self.individuals[p2])
            population.extend([off1, off2, self.individuals[p1], self.individuals[p2]])

        self.individuals = population

    def run(self, n_generations):
        """Run the genetic algorithm for ``n_generations`` generations.

        After each generation ``self.model`` holds the model loaded from the
        individual with the highest fitness.
        """
        with tqdm(total=n_generations) as pbar:
            for _generation in range(n_generations):
                # Save model
                self.save_models()

                # Train/fitness
                fitnesses = self.evaluate()

                # Retrieve and dispatch best model to others
                best = int(torch.tensor(fitnesses).argmax())
                best_fitness = fitnesses[best]
                self.model, _ = self.individuals[best].load_model()

                # Selection & Crossover
                self.selection_crossover(fitnesses)

                # Mutate
                self.mutate()

                # Replace duplicates: keep one copy of every distinct solution
                # and top the population up with random individuals until a
                # pass finds no duplicate.
                while True:
                    unique = {tuple(ind.solution.tolist()) for ind in self.individuals}
                    duplicates = len(self.individuals) - len(unique)
                    self.individuals = [Individual(solution=torch.BoolTensor(bits)) for bits in unique]
                    self.individuals.extend(Individual() for _extra in range(duplicates))
                    if duplicates == 0:
                        break

                pbar.update(1)
                pbar.set_postfix({'Best loss': 1/best_fitness.item() - 1e-6})


## Train a model using the GA strategy

In [37]:
# Data module shared by the GA run and by the baseline cell further down.
dm = MNISTDataModule()
dm.prepare_data()

# Only the wrapped `model.model` is handed to the GA below.
model = LitModel()

# Keyword arguments forwarded by GA.__init__ to pytorch_lightning.Trainer.
args = dict(gpus=-1, max_epochs=1, weights_summary=None, progress_bar_refresh_rate=0, auto_scale_batch_size=True)

# Search over 3 sub-modules (2**3 = 8 possible masks) with a population of 4.
ga = GA(Model, LitModel, model.model, ['conv1', 'conv2', 'linear1'], dm=dm, n_individuals=4, **args)

# Warm up the shared model with all layers trainable before the search.
ga.warmup(2)

# Run 2 generations of the genetic algorithm.
ga.run(2)

## Train a model using the straightforward approach

In [12]:
# Baseline: train a fresh LitModel for one epoch with every layer trainable.
# NOTE(review): relies on `dm` and `Trainer` defined in the cells above, so the
# notebook has to be executed top to bottom.
model2 = LitModel()
trainer = Trainer(gpus=-1, max_epochs=1)
trainer.fit(model2, dm)
trainer.test(model2, )