In [None]:
import numpy as np
import torch as t
import matplotlib.pyplot as plt

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = t.device("cuda" if t.cuda.is_available() else "cpu")
print(device)

In [None]:
# Size of the random set-cover instance:
#   PROBLEM_DIMENSION - number of elements to cover (columns of `samples`)
#   NUM_SAMPLE        - number of candidate sets (rows of `samples`; one gene each)
PROBLEM_DIMENSION, NUM_SAMPLE = 30, 50

In [None]:
def check_sets_feasibility(samples):
    """Return a 0-d boolean tensor that is True when every column (element)
    of `samples` is contained in at least one row (candidate set)."""
    coverage_per_element = samples.sum(dim=0)
    return (coverage_per_element != 0).all()

In [None]:
# Candidate sets: row i of `samples` marks which of the PROBLEM_DIMENSION
# elements set i contains; each membership is drawn independently with
# probability THRESHOLD.
THRESHOLD = 0.3
SEED = 0
assert THRESHOLD > 0, "THRESHOLD must be positive, otherwise no set covers anything"
t.manual_seed(SEED)  # seed the global generators so the whole run is reproducible

samples = t.rand(NUM_SAMPLE, PROBLEM_DIMENSION, device=device) < THRESHOLD
# Rejection-sample until every element lies in at least one set; an instance
# with an uncoverable element has no solution.
while not check_sets_feasibility(samples):
    samples = t.rand(NUM_SAMPLE, PROBLEM_DIMENSION, device=device) < THRESHOLD

In [None]:
# Stop here if some element belongs to no candidate set: such an instance has no cover.
assert check_sets_feasibility(samples), "Problem not solvable"

In [None]:
# Display the membership matrix: NUM_SAMPLE rows (sets) x PROBLEM_DIMENSION columns (elements).
samples

In [None]:
class Population:
    """Genetic-algorithm population for the set-cover instance given by ``samples``.

    Each genome is a boolean mask over the rows of ``samples`` (the candidate
    sets); a True gene selects that set. Fitness is
    ``uncovered_elements * (genome_len + 1) + used_sets``, which is meant to be
    minimised (construct with ``is_highest_best=False``): the penalty weight
    exceeds the largest possible ``used_sets``, so any genome that leaves an
    element uncovered scores worse than every full cover.

    Args:
        population_len: number of genomes.
        genome_len: genes per genome; must equal ``samples.shape[-2]`` (number of sets).
        samples: tensor of shape (genome_len, num_elements); row i marks the
            elements covered by set i.
        is_highest_best: if True, selection probability is proportional to
            fitness; otherwise proportional to 1 / fitness.
        genomes: optional initial genomes of shape (population_len, genome_len);
            random if omitted. Converted to bool on ``samples.device``.
        crossover: ``None`` or ``"uniform"`` (default), or ``"one_point"``.
        mutation_rate: per-gene bit-flip probability applied every generation.

    Raises:
        ValueError: on mismatched shapes or an unknown ``crossover`` name.
    """

    def __init__(self, population_len, genome_len, samples, is_highest_best=True,
                 genomes=None, crossover=None, mutation_rate=0.03):
        if samples.dim() < 2 or samples.size(-2) != genome_len:
            raise ValueError(
                f"samples must have shape (genome_len={genome_len}, num_elements), "
                f"got {tuple(samples.shape)}")
        self.population_len = population_len
        self.genome_len = genome_len
        self.generation = -1
        self.is_highest_best = is_highest_best
        self.mutation_rate = mutation_rate
        # Keep every tensor on the problem's device rather than a notebook global.
        self.device = samples.device
        if genomes is None:
            genomes = t.rand(population_len, genome_len, device=self.device) >= 0.5
        else:
            genomes = genomes.to(device=self.device, dtype=t.bool)
            if tuple(genomes.shape) != (population_len, genome_len):
                raise ValueError(
                    f"genomes must have shape ({population_len}, {genome_len}), "
                    f"got {tuple(genomes.shape)}")
        self.genomes = genomes
        self.fitness = None
        self.probability = None

        crossovers = {
            None: self.uniform_crossover,
            "uniform": self.uniform_crossover,
            "one_point": self.one_point_crossover,
        }
        if crossover not in crossovers:
            raise ValueError(
                f"Unknown crossover {crossover!r}; expected 'uniform' or 'one_point'")
        self.crossover_function = crossovers[crossover]
        # expand() returns a broadcast view (no copy): every genome sees the same sets.
        self.samples = samples.expand(self.population_len, -1, -1)

        self.updatePopulation()

    def __str__(self):
        strs = list()
        strs.append(f'Generation: {self.generation}')
        strs.append(f'Genomes: {self.genomes}')
        strs.append(f'Best fitness: {self.get_best_fitness()}, of id: {self.get_best_id()}')
        return '\n'.join(strs)

    def updatePopulation(self):
        """Advance the generation counter and recompute fitness and probabilities."""
        self.generation += 1
        self.set_fitness()
        self.set_probability()

    def get_phenotype(self):
        """Return, per genome, how many selected sets contain each element.

        Shape (population_len, num_elements); a 0 marks an uncovered element.
        """
        return t.mul(self.samples, self.genomes.unsqueeze(-1)).sum(dim=1)

    def set_fitness(self):
        """Set ``self.fitness = uncovered * (genome_len + 1) + used_sets`` per genome.

        ``used_sets`` is at most ``genome_len``, so the weight ``genome_len + 1``
        guarantees every infeasible genome scores higher (worse) than any cover.
        """
        coverage = self.get_phenotype()
        uncovered = (coverage == 0).sum(dim=1)
        used_samples = self.genomes.sum(dim=1)
        self.fitness = uncovered * (self.genome_len + 1) + used_samples

    def set_probability(self):
        """Turn fitness into selection probabilities that sum to 1.

        Proportional to fitness when ``is_highest_best``, else to 1 / fitness.
        NOTE: a zero fitness sum (or a zero fitness when minimising) yields NaN/inf.
        """
        if self.is_highest_best:
            self.probability = self.fitness / t.sum(self.fitness)
        else:
            self.probability = 1 / self.fitness
            self.probability.div_(t.sum(self.probability))

    def get_best_id(self):
        """Index of the genome with the highest selection probability."""
        return t.argmax(self.probability)

    def get_best_fitness(self):
        """Fitness of the genome returned by ``get_best_id``."""
        return self.fitness[self.get_best_id()]

    def get_best_genome(self):
        """Genome (gene mask) returned by ``get_best_id``."""
        return self.genomes[self.get_best_id(), :]

    def evolve(self):
        """Produce the next generation: crossover, mutation, then re-evaluation."""
        self.crossover_function()
        self.mutation()
        self.updatePopulation()

    def evolve_for_generations(self, generations):
        """Run ``evolve`` ``generations`` times, then print the population summary."""
        for _ in range(generations):
            self.evolve()
        print(self)

    def get_parents(self):
        """Sample two distinct parents per child, proportionally to ``self.probability``.

        Returns:
            (p1, p2): genome tensors of shape (population_len, genome_len).
        """
        # multinomial samples without replacement by default, so each row's parents differ.
        parents = self.probability.expand(self.population_len, self.population_len).multinomial(2)
        p1 = self.genomes[parents[:, 0], :]
        p2 = self.genomes[parents[:, 1], :]
        return p1, p2

    def one_point_crossover(self):
        """Replace the genomes with one-point crossover children.

        Each child copies the genes before a random cut from its first parent and
        the rest from its second. The cut lies in [1, genome_len - 1], so both
        parents contribute whenever genome_len >= 2.
        """
        p1, p2 = self.get_parents()
        high = max(self.genome_len, 2)
        cuts = t.randint(1, high, (self.population_len, 1), device=self.device)
        positions = t.arange(self.genome_len, device=self.device)
        self.genomes = t.where(positions < cuts, p1, p2)

    def uniform_crossover(self):
        """Replace the genomes with children taking each gene from either parent with p=0.5."""
        p1, p2 = self.get_parents()
        mask = t.rand(self.population_len, self.genome_len, device=self.device) >= 0.5
        self.genomes = t.where(mask, p1, p2)

    def mutation(self):
        """Flip each gene independently with probability ``mutation_rate``."""
        flips = t.rand(self.population_len, self.genome_len, device=self.device) < self.mutation_rate
        self.genomes = self.genomes ^ flips

In [None]:
population_len = 10
population = Population(
    population_len,
    NUM_SAMPLE,              # one gene per candidate set
    samples,
    is_highest_best=False,   # lower fitness is preferred (probability ~ 1 / fitness)
)
print(samples)
print(population)

In [None]:
# Evolve for a fixed budget of generations; prints the final population summary.
population.evolve_for_generations(generations=500)

In [None]:
# Per-element coverage count of the current best genome; a 0 would mark an
# uncovered element. Index by the best id, since the population only holds
# `population_len` genomes.
population.get_phenotype()[population.get_best_id()]