In [43]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'

In [44]:
from torch.cuda import is_available

devi = "cuda" if is_available() else "cpu"
# device = "cpu"
print("==> Device:", devi)

==> Device: cuda


In [45]:
from torch import manual_seed
manual_seed(16)

<torch._C.Generator at 0x7fea3c5a66f0>

In [46]:
from torch import load, long
from torch.utils.data import Dataset, DataLoader
from torch import nn

class DS(Dataset):
    def __init__(self, maps, labels) -> None:
        self.maps = maps
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        X = self.maps[idx]
        #X = X.reshape(1, -1)
        X = X.unsqueeze(0)
        y = self.labels[idx]
        return X.to(devi, dtype=pt_float), y.to(devi, dtype=long)

# Data
X_train, y_train = load('/kaggle/input/pytorch-mnist/training.pt')
X_test, y_test = load('/kaggle/input/pytorch-mnist/test.pt')
X_train.shape, y_train.shape, X_test.shape, y_test.shape

(torch.Size([60000, 28, 28]),
 torch.Size([60000]),
 torch.Size([10000, 28, 28]),
 torch.Size([10000]))

In [47]:
from torch import float as pt_float, ones

class NET(nn.Module):
    def __init__(self, 
                 l1, a1, l2, a2, l3, a3):
        super().__init__()
        
        self.lin1 = nn.Sequential(
            nn.Flatten(),
            nn.LazyLinear(l1),
            nn.Dropout(0.5),
            nn.__getattribute__(a1)())
        
        self.lin2 = nn.Sequential(
            nn.LazyLinear(l2),
            nn.Dropout(0.5),
            nn.__getattribute__(a2)())

        self.lin3 = nn.Sequential(
            nn.LazyLinear(l3),
            nn.Dropout(0.5),
            nn.__getattribute__(a3)())

        self.out = nn.Sequential(
            nn.LazyLinear(10),
            nn.LogSoftmax(dim=-1))
        
        self.model = nn.Sequential(
            self.lin1,
            self.lin2,
            self.lin3,
            self.out
        )
                
    def forward(self, x):
        """Forward"""
        return self.model(x)
    
    def count_weights_biases(self):
        return int(sum(p.numel() for p in self.parameters() if p.requires_grad))
    
net = NET(10, 'SELU', 10, 'SELU', 10, 'SELU').to(devi)
f'Dry run: {net(ones(1, 1, 28, 28).to(devi, dtype=pt_float)).shape}'

'Dry run: torch.Size([1, 10])'

In [48]:
class GA_Pytorch():
    def __init__(self, 
                 params, 
                 eval_func,
                 eval_weights,
                 X_train,
                 X_test,
                 y_train,
                 y_test,
                 batch_size=64,
                 lr=0.0001,
                 sel_tournsize=2, 
                 cx_uniform_prob=0.5, 
                 mut_shuffle_idx_prob=0.1, 
                 n_pop=50, 
                 n_gen=20, 
                 n_hof=5, 
                 cx_prob=0.5, 
                 mut_prob=0.1, 
                 n_jobs=1
                ):
        self.params = params
        self.eval_func = eval_func
        self.eval_weights = eval_weights
        
        self.X_train = X_train
        self.X_test = X_test
        self.y_train = y_train
        self.y_test = y_test
        self.batch_size = batch_size
        self.lr = lr
        
        self.sel_tournsize = sel_tournsize
        self.cx_uniform_prob = cx_uniform_prob
        self.mut_shuffle_idx_prob = mut_shuffle_idx_prob
        self.n_pop = n_pop
        self.n_gen = n_gen
        self.n_hof = n_hof
        self.cx_prob = cx_prob
        self.mut_prob = mut_prob
        
        self.n_jobs = n_jobs

        self._pad_params()
        self._create_fitness_and_indiv()
        self._register_indiv_and_pop_generators()
        self._register_eval_func()
        self._register_selection_crossover_mutation_methods()

    def _pad_params(self):
        """Pad params for crossover shuffle idx method"""
        assert isinstance(self.params, dict), 'Params must be a dict, i.e. estimator.get_params()'
        params_count = {k: len(v) for k,v in self.params.items()}
        max_length, max_key = -99, ''
        for k, v in params_count.items():
            if v <= max_length:
                continue
            else:
                max_key = k
                max_length = v
        assert isinstance(max_length, int), 'The max length between all params must be an int'
        # cycle through params for max length param, otherwise infinite cycle
        values_padded = (cycle(v) if k!=max_key else v for k,v in self.params.items())
        values_padded = zip(*values_padded)  # ('a', 1, 14), ('b', 2, 16), ('c', 3, 16) ...
        values_padded = zip(*values_padded)  # ('a', 'b', 'c'), (1, 2, 3), (14, 15, 16)...
        padded_params = {}
        for k, v in zip(self.params, values_padded):
            padded_params[k] = v
        self.padded_params = padded_params
        print('Params padded')

    def _create_fitness_and_indiv(self):
        """Create GA individual and fitness entities (classes)"""
        ga_cr.create('Fitness', ga_b.Fitness, weights=self.eval_weights)
        ga_cr.create('Individual', list, fitness=ga_cr.Fitness)
        print('GA entities created')

    def _gen_params_to_ga(self):
        """Generate index for each param for individual"""
        max_dict = len(self.padded_params)
        max_length = len(list(self.padded_params.values())[0])
        idxs = [randint(0, max_length-1) for _ in range(max_dict)]
        return idxs
    
    def _register_indiv_and_pop_generators(self):
        """Register GA individual and population generators"""
        self.tb = ga_b.Toolbox()

        if self.n_jobs > 1:
            from multiprocessing import Pool
            pool = Pool()
            self.tb.register("map", pool.map)

        self.tb.register("individual", ga_t.initIterate, ga_cr.Individual, self._gen_params_to_ga)
        self.tb.register("population", ga_t.initRepeat, list, self.tb.individual)
        print('GA entities\' methods registered')
        
    def _register_eval_func(self):
        """Set GA evaluate individual function"""
        self.tb.register("evaluate",
                        self.eval_func,
                        padded_params=self.padded_params,
                        X_train=self.X_train,
                        X_test=self.X_test, 
                        y_train=self.y_train, 
                        y_test=self.y_test,
                        batch_size=self.batch_size,
                        lr=self.lr)
        #print(list(self.tb.evaluate(indiv) for indiv in self.tb.population(3)))
        print('GA eval function registered')
    
    def _register_selection_crossover_mutation_methods(self):
        self.tb.register("select", ga_t.selTournament, tournsize=self.sel_tournsize)
        self.tb.register("mate", ga_t.cxUniform, indpb=self.cx_uniform_prob)
        self.tb.register("mutate", ga_t.mutShuffleIndexes, indpb=self.mut_shuffle_idx_prob)
        print('GA sel-cx-mut methods registered')
        
    def run_ga_search(self):
        """GA Search"""
        pop = self.tb.population(n=self.n_pop)
        hof = ga_t.HallOfFame(self.n_hof)

        # Stats stdout
        #stats = ga_t.Statistics(lambda ind: ind.fitness.values )
        stats1 = ga_t.Statistics(lambda ind: ind.fitness.values[0] )
        stats2 = ga_t.Statistics(lambda ind: ind.fitness.values[1] )
        stats3 = ga_t.Statistics(lambda ind: ind.fitness.values[2] )
        stats = ga_t.MultiStatistics(accuracy=stats1, risk=stats2, complexity=stats3)
        stats.register("avg", mean)
        #stats.register("std", np.std)
        #stats.register("min", np.min)
        #stats.register("max", np.max)

        # History
        #hist = tools.History()
        #toolbox.decorate("select", hist.decorator)
        #tb.decorate("mate", hist.decorator)
        #tb.decorate("mutate", hist.decorator)
        #hist.update(pop)

        # GA Run
        pop, log = ga_algo.eaSimple(pop, self.tb, cxpb=self.cx_prob, 
                                    mutpb=self.mut_prob, ngen=self.n_gen, 
                                    stats=stats, halloffame=hof, verbose=True)
        
        # Convert back params
        hof_ = {}
        for i in range(self.n_hof):
            hof_['hof_' + str(i)] = self._ga_to_params(hof[i])

        return pop, log, hof_
    
    def _ga_to_params(self, idx_params):
        """Convert back idx to params"""
        res = {}
        for (k,v), idx in zip(self.padded_params.items(), idx_params):
            res[k] = v[idx]
        return res

In [49]:
from numpy import mean, linspace

net_params = {
    'l1': linspace(1,100,100).astype(int),
    'a1': ['ReLU', 'CELU', 'SELU', 'ELU', 'Softsign'],
    'l2': linspace(1,100,100).astype(int),
    'a2': ['ReLU', 'CELU', 'SELU', 'ELU', 'Softsign'],
    'l3': linspace(1,100,100).astype(int),
    'a3': ['ReLU', 'CELU', 'SELU', 'ELU', 'Softsign'],
}

def net_eval_indiv(individual, padded_params, X_train, X_test, y_train, y_test, batch_size, lr):
    """Evaluate individual's genes (estimator's params)"""

    # Params
    indiv_params = {k : list(v)[idx] for (k,v), idx in zip(padded_params.items(), individual)}
    
    # Net
    net = NET(**indiv_params).to(devi)
    net(ones(1,1,28,28).to(devi))
    
    # Optimizer
    optimizer = Adam(net.parameters(), lr=lr)
    criterion = nn.NLLLoss()
    
    # Train
    train_ds = DS(X_train, y_train)  # TODO refactor out
    train_dl = DataLoader(train_ds,
                        batch_size=batch_size,
                        shuffle=True,
                        drop_last=True,
                         )
    
    for epoch in range(1):
        #running_loss = []
        train_correct = 0
        train_total = 0
        for i, (inputs, labels) in enumerate(train_dl):
            if i <= 100:
                #inputs = inputs.cuda()
                #labels = labels.cuda()
                outputs = net(inputs)

                optimizer.zero_grad()
                loss = criterion(outputs, labels).mean()
                loss.backward()
                optimizer.step()

                # print statistics
                #running_loss.append(loss.item())
                _, predicted = pt_max(outputs.data, 1)
                train_total += labels.size(0)
                train_correct += (predicted == labels).sum().item()
            else:
                break
        #print(f'TRAIN {train_correct / train_total * 100:^5.2f} %', end=' ')
        
    # Eval
    with no_grad():
        net = net.eval()
        test_ds = DS(X_test, y_test)  # TODO refactor out
        test_dl = DataLoader(test_ds,
                            batch_size=batch_size,
                            shuffle=True,
                            drop_last=True)
        #running_loss = []
        test_correct = 0
        test_total = 0
        for i, (inputs, labels) in enumerate(test_dl):
            if i <= 50:
                #inputs = inputs.cuda()
                #labels = labels.cuda()
                outputs = net(inputs)

                # print statistics
                #running_loss.append(loss.item())
                _, predicted = pt_max(outputs.data, 1)
                test_total += labels.size(0)
                test_correct += (predicted == labels).sum().item()
                test_accuracy = test_correct / test_total * 100
            else:
                break
        #print(f'TEST {test_accuracy:^5.2f} %')
        
    # Risk
    risk = median(prod(net(inputs).exp()*10, dim=1))
    if isnan(risk):
        risk = 10
    else:
        risk = float(risk)
        
    # Complexity
    compl = net.count_weights_biases()

    return (test_accuracy, risk, compl,)

net_weights = (1, -1, -1)

In [50]:
from itertools import cycle
from deap import creator as ga_cr, base as ga_b, algorithms as ga_algo, tools as ga_t
from random import randint
from numpy import mean
from torch.optim import Adam
from torch import max as pt_max, no_grad, median, prod, isnan

net_ga_params = GA_Pytorch(net_params, 
                           net_eval_indiv, 
                           net_weights,
                           X_train, 
                           X_test, 
                           y_train, 
                           y_test)
pop, log, hof = net_ga_params.run_ga_search()

Params padded
GA entities created
GA entities' methods registered
GA eval function registered
GA sel-cx-mut methods registered
   	      	       accuracy       	      complexity      	             risk             
   	      	----------------------	----------------------	------------------------------
gen	nevals	avg    	gen	nevals	avg    	gen	nevals	avg     	gen	nevals
0  	50    	18.2518	0  	50    	43541.3	0  	50    	0.667339	0  	50    
1  	19    	21.0901	1  	19    	50628.3	1  	19    	0.616537	1  	19    
2  	28    	22.5141	2  	28    	56270.8	2  	28    	0.567989	2  	28    
3  	28    	23.9767	3  	28    	60439.1	3  	28    	0.572068	3  	28    
4  	38    	25.2065	4  	38    	56892  	4  	38    	0.611387	4  	38    
5  	30    	28.4786	5  	30    	59972.6	5  	30    	0.651424	5  	30    
6  	26    	32.1471	6  	26    	61109.6	6  	26    	0.688419	6  	26    
7  	31    	34.0705	7  	31    	61497.8	7  	31    	0.72142 	7  	31    
8  	24    	36.7953	8  	24    	65190.9	8  	24    	0.737874	8  	24    
9  	29 

In [52]:
list(hof.values())

[{'l1': 86, 'a1': 'Softsign', 'l2': 92, 'a2': 'CELU', 'l3': 99, 'a3': 'SELU'},
 {'l1': 86, 'a1': 'Softsign', 'l2': 92, 'a2': 'ELU', 'l3': 86, 'a3': 'SELU'},
 {'l1': 86, 'a1': 'Softsign', 'l2': 92, 'a2': 'CELU', 'l3': 86, 'a3': 'SELU'},
 {'l1': 86, 'a1': 'Softsign', 'l2': 92, 'a2': 'CELU', 'l3': 86, 'a3': 'SELU'},
 {'l1': 86, 'a1': 'Softsign', 'l2': 70, 'a2': 'CELU', 'l3': 99, 'a3': 'SELU'}]

In [53]:
from joblib import dump, load

hof
dump(hof, 'best_params.json')

{'hof_0': {'l1': 86,
  'a1': 'Softsign',
  'l2': 92,
  'a2': 'CELU',
  'l3': 99,
  'a3': 'SELU'},
 'hof_1': {'l1': 86,
  'a1': 'Softsign',
  'l2': 92,
  'a2': 'ELU',
  'l3': 86,
  'a3': 'SELU'},
 'hof_2': {'l1': 86,
  'a1': 'Softsign',
  'l2': 92,
  'a2': 'CELU',
  'l3': 86,
  'a3': 'SELU'},
 'hof_3': {'l1': 86,
  'a1': 'Softsign',
  'l2': 92,
  'a2': 'CELU',
  'l3': 86,
  'a3': 'SELU'},
 'hof_4': {'l1': 86,
  'a1': 'Softsign',
  'l2': 70,
  'a2': 'CELU',
  'l3': 99,
  'a3': 'SELU'}}

['best_params.json']