In [None]:
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
from torch.optim.optimizer import Optimizer, required

import tqdm as tqdm
from copy import deepcopy
import matplotlib.pyplot as plt 

## ++ Добавить LOGGER !!

In [None]:
# INIT NETS

num_nets = 3

nets = []

for i in range(num_nets):
  net = torchvision.models.mobilenet_v2(pretrained=True)
  num_ftrs = net.last_channel

  classes = ('plane', 'car', 'bird', 'cat',
            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

  net.classifier = nn.Linear(num_ftrs, len(classes))
  net.cuda()
  nets.append(net)

In [None]:
# net_one = deepcopy(net)
# net_one.cuda()
# net_two = deepcopy(net)
# net_two.cuda()
# net_three = deepcopy(net)
# net_three.cuda()
# print('')




In [None]:
import itertools 
from copy import deepcopy, copy
import numpy as np
from tqdm import tqdm
import torch



class CrossN:
    def __init__(self, n_combinations = 2, pmix = 0.5, mode = 'mask'):
        self.p = pmix
        self.n_combinations = n_combinations
        self.layers = []
        self.mode = mode
        
    def make_ancestry(self, nets):
        n = []
        abc = ['A','B','C','D','E','F']
        for counter, net in enumerate(nets):
            if not 'ancestry' in net.__dir__():
                if len(abc) > counter:
                    net.ancestry = '[{}]'.format(abc[counter])
                else: 
                    net.ancestry = str(counter)
            n.append(net)
        return n 
 
    def make_combinations(self, nets):
        nets = self.make_ancestry(nets)
        parents = [n.ancestry for n in nets]
        pairs = itertools.product(parents,repeat = 
                                           min(len(parents), 
                                              self.n_combinations))

        return pairs
    
    def find_layers(self, model, keyword = 'weight'):
        model_state_dict = model.state_dict()
        self.layers = [layer_name for layer_name in model_state_dict 
                       if keyword in layer_name]
        
    
    def switch_weights(self, parentA, parentB):
        
        # abundant copying may lead to unnessary memmory usage - need to optimize
        child = deepcopy(parentA) 
        parent_B = deepcopy(parentB)
        parentB_state_dict = parent_B.state_dict()
        # p is responsible for percantage or False / Zero values, values not to be replaced
        p = self.p
        
        if parentA.ancestry != parent_B.ancestry:
            
            child_params = child.state_dict()
            
            if self.mode == 'mask':
              for layer in self.layers:
                  w = parentB_state_dict[layer]
                  shape = w.shape
                  
                  # creating the mask with tensor does not cause shape mismatch problem 
                  #mask =  np.random.choice([False, True], size=(shape), p=[1-p, p])
                  mask = torch.cuda.FloatTensor(shape).uniform_() > p
                  # switching weights
                  try:
                      child_params[layer][mask]=w[mask]
                      #child_params[layer][mask]=child_params[layer][mask]=w[mask].view(shape)
                  except Exception as e:
                      print('skipping layer: ',layer, e)
              #HOWTO initiate a new model in any other way?
            
            elif self.mode == 'addition': 
              for layer in self.layers:
                  w = parentB_state_dict[layer]
                  try:
                      child_params[layer]+=p*w
                  except Exception as e:
                      print('skipping layer: ',layer, e)

            
            child.load_state_dict(child_params)
          
        child.ancestry +=',{}.{}'.format(round(1-p, 2), parent_B.ancestry.replace(',',' '))
        return child
    
    
    def breed(self, nets: list()):
        nets = deepcopy(nets)
        self.find_layers(nets[0])
        families = self.make_combinations(nets)
        children = []
        nets_named = {n.ancestry:n for n in nets}
        for family in families:
            parents = [nets_named[f] for f in family]
            child = self.switch_weights(*parents)
            yield child
    
    def history(self, net):
        for i, p in enumerate(net.ancestry.split(',')):
            print('{}+{}'.format('  '*i,p))


cross = CrossN(pmix=0.0001, mode='addition')

In [70]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=512,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=512,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [None]:
def validation(net, testloader):
    correct = 0
    total = 0
    with torch.no_grad():
        # try filter by loss too
        for data in testloader:
                images, labels = data
                images = images.cuda()
                labels = labels.cuda()
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    return 100 * correct / total

In [None]:
def train_one_epoch(net, trainloader, testloader):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    loss = 0.0
    #print('training: {}'.format(net.ancestry))
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs = inputs.cuda()
        labels = labels.cuda()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    val_score = validation(net, testloader)
    print('[%d, %5d] loss: %.3f validation score: %.2f' %
          (epoch + 1, i + 1, loss, val_score))
    return net, loss

In [None]:
# FIX FILTERING !!!

def train_n_breed(nets, testloader):
    trained_nets = []

    for i, net in enumerate(nets):
        print('Training child : {}'.format(i+1))
        net, loss = train_one_epoch(net, trainloader, testloader)
        trained_nets.append(net)
    best_score = 0
    best_nets = dict()
    print('BREEDING & FILTERING')
    for n in cross.breed(trained_nets):
        
        score = validation(n, testloader)
        print('SCORE: {}, NAME: {}'.format(score, n.ancestry ))

        if len(best_nets) < 3:
          best_nets[score] = n

        else:

          if min(best_nets.keys()) < score:
              #print(' ')
              #print('switch')
              del best_nets[min(best_nets.keys())]
              best_nets[score] = n
              #best_nets.pop(0)
    best_score = max(best_nets.keys())
    print('Breeding scores:',list(best_nets.keys()))
    print('Len nets after breeding: ',len(best_nets))
    return best_nets.values(), best_score

In [51]:
list(cross.make_combinations(nets))

[('[A]', '[A]'),
 ('[A]', '[B]'),
 ('[A]', '[C]'),
 ('[B]', '[A]'),
 ('[B]', '[B]'),
 ('[B]', '[C]'),
 ('[C]', '[A]'),
 ('[C]', '[B]'),
 ('[C]', '[C]')]

In [52]:
{n.ancestry:n for n in nets}.keys()

dict_keys(['[A]', '[B]', '[C]'])

In [None]:
num_epochs = 10
scores = []

# nets = deepcopy([  net_one, 
#           net_two, 
#           net_three ])

for epoch in range(num_epochs):
    best_nets, best_score = train_n_breed(nets, testloader)
    nets = best_nets

    print('EPOCH:', epoch)
    print('Best_score after breeding {}'.format(best_score))
    for i, n in enumerate(nets):
        print('Child: {}'.format(i+1))
        cross.history(n)

Training child : 1
[1,    98] loss: 1.234 validation score: 58.67
Training child : 2
[1,    98] loss: 1.164 validation score: 58.01
Training child : 3
[1,    98] loss: 1.184 validation score: 59.08
BREEDING & FILTERING
SCORE: 58.67, NAME: [A],1.0.[A]
SCORE: 58.68, NAME: [A],1.0.[B]
SCORE: 58.68, NAME: [A],1.0.[C]
SCORE: 58.01, NAME: [B],1.0.[A]
SCORE: 58.01, NAME: [B],1.0.[B]
SCORE: 58.01, NAME: [B],1.0.[C]
SCORE: 59.09, NAME: [C],1.0.[A]
SCORE: 59.09, NAME: [C],1.0.[B]
SCORE: 59.08, NAME: [C],1.0.[C]
Breeding scores: [58.68, 59.09, 59.08]
Len nets after breeding:  3
EPOCH: 0
Best_score after breeding 59.09
Child: 1
+[A]
  +1.0.[C]
Child: 2
+[C]
  +1.0.[B]
Child: 3
+[C]
  +1.0.[C]
Training child : 1
[2,    98] loss: 0.938 validation score: 65.22
Training child : 2
[2,    98] loss: 0.979 validation score: 65.79
Training child : 3
[2,    98] loss: 0.934 validation score: 65.81
BREEDING & FILTERING
SCORE: 65.22, NAME: [A],1.0.[C],1.0.[A] 1.0.[C]
SCORE: 65.22, NAME: [A],1.0.[C],1.0.[C] 1.0