This notebook is a modified copy of [this work](https://github.com/uber-research/ga-world-models) by Sebastian Risi and Kenneth O. Stanley.

In [None]:
import copy
from multiprocessing import Pool

import gym
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms

Some hyperparameters:

In [None]:
# Whether the V model's latent code is discretized before being used.
DISCRETE_VAE = True
# The discretized latent is wider than the continuous one.
if DISCRETE_VAE:
    LATENT_SIZE = 128
else:
    LATENT_SIZE = 32
ACTION_SIZE = 3  # steering, gas, brake
HIDDEN_SIZE = 256  # width of the M model's LSTM state
OBS_SIZE = 64  # frames are resized to OBS_SIZE x OBS_SIZE before encoding
SIZE = 64  # not referenced by the cells shown above
MUT_POW = 0.01  # standard deviation of the Gaussian mutation noise

### VAE (V model)

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder of the V model.

    Four kernel-4, stride-2 convolutions (3 -> 32 -> 64 -> 128 -> 256
    channels), each followed by ReLU, feed two linear heads that produce the
    mean and the log standard deviation of the latent Gaussian.

    For a 3x64x64 input the last feature map is 256x2x2, which is where the
    ``2 * 2 * 256`` input width of the heads comes from.
    """

    def __init__(self, latent_size):
        super().__init__()
        self.latent_size = latent_size

        # Attribute names double as state_dict keys (see load_solution), so
        # they are kept exactly as they are.
        self.conv1 = nn.Conv2d(3, 32, 4, stride=2)
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
        self.conv3 = nn.Conv2d(64, 128, 4, stride=2)
        self.conv4 = nn.Conv2d(128, 256, 4, stride=2)
        self.flatten = nn.Flatten()

        feature_width = 2 * 2 * 256
        self.fc_mu = nn.Linear(feature_width, latent_size)
        self.fc_logsigma = nn.Linear(feature_width, latent_size)

    def forward(self, x):
        """Return ``(mu, logsigma)``, each of shape (batch, latent_size)."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = F.relu(conv(features))
        features = self.flatten(features)
        return self.fc_mu(features), self.fc_logsigma(features)


class Decoder(nn.Module):
    """Transposed-convolution decoder of the V model.

    A linear layer lifts the latent vector to 1024 features. These are
    viewed as a 1024x1x1 map and upsampled by four stride-2 transposed
    convolutions (kernel sizes 5, 5, 6, 6) to a 3x64x64 image squashed by a
    sigmoid.
    """

    def __init__(self, latent_size):
        super().__init__()
        self.latent_size = latent_size

        self.fc1 = nn.Linear(latent_size, 1024)
        self.deconv1 = nn.ConvTranspose2d(1024, 128, 5, stride=2)
        self.deconv2 = nn.ConvTranspose2d(128, 64, 5, stride=2)
        self.deconv3 = nn.ConvTranspose2d(64, 32, 6, stride=2)
        self.deconv4 = nn.ConvTranspose2d(32, 3, 6, stride=2)

    def forward(self, x):
        """Decode a (batch, latent_size) tensor into a (batch, 3, 64, 64) image."""
        hidden = F.relu(self.fc1(x))
        # Treat the feature vector as a 1x1 spatial map for the deconvolutions.
        hidden = hidden.view(hidden.size(0), hidden.size(1), 1, 1)
        for deconv in (self.deconv1, self.deconv2, self.deconv3):
            hidden = F.relu(deconv(hidden))
        return torch.sigmoid(self.deconv4(hidden))


class VAE(nn.Module):
    """V model: an ``Encoder``/``Decoder`` pair forming a variational autoencoder."""

    def __init__(self, latent_size):
        super().__init__()
        self.encoder = Encoder(latent_size)
        self.decoder = Decoder(latent_size)

    @staticmethod
    def _sample(mu, logsigma):
        """Draw ``z = mu + exp(logsigma) * eps`` with ``eps ~ N(0, I)``."""
        sigma = logsigma.exp()
        eps = torch.randn_like(sigma)
        return eps.mul(sigma).add_(mu)

    def encode(self, x, reparameterize=False):
        """Encode ``x``: a sampled latent if ``reparameterize``, else the mean."""
        mu, logsigma = self.encoder(x)
        return self._sample(mu, logsigma) if reparameterize else mu

    def forward(self, x):
        """Return ``(reconstruction, mu, logsigma)`` using a sampled latent."""
        mu, logsigma = self.encoder(x)
        recon_x = self.decoder(self._sample(mu, logsigma))
        return recon_x, mu, logsigma

### MDN-RNN (M model)

In [None]:
class MDNRNNCell(nn.Module):
    """Single-step M model: an LSTM cell followed by a mixture-density head.

    Args:
        input_size: width of the per-step input fed to the LSTM cell.
        hidden_size: width of the LSTM hidden and cell states.
        output_size: dimensionality of each mixture component's mean/std.
        mixture_size: number of Gaussian mixture components.
    """

    def __init__(self, input_size, hidden_size, output_size, mixture_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.mixture_size = mixture_size

        # The cell consumes exactly ``input_size`` features per step.
        self.rnn = nn.LSTMCell(input_size, hidden_size)

        self.pi = nn.Linear(hidden_size, mixture_size)
        self.mu = nn.Linear(hidden_size, mixture_size * output_size)
        self.logsigma = nn.Linear(hidden_size, mixture_size * output_size)

    def forward(self, x, hidden):
        """Advance the cell by one step.

        Args:
            x: input of shape (batch, input_size).
            hidden: ``(h, c)`` pair, each (batch, hidden_size), or ``None``
                for zero states (as accepted by ``nn.LSTMCell``).

        Returns:
            ``(logpi, mu, sigma, (h, c))``. ``logpi`` is
            (batch, mixture_size) log mixture weights. ``mu`` and ``sigma``
            are (batch, mixture_size, output_size). ``(h, c)`` is the new
            LSTM state.
        """
        h, c = self.rnn(x, hidden)

        logpi = F.log_softmax(self.pi(h).view(-1, self.mixture_size), dim=1)
        mu = self.mu(h).view(-1, self.mixture_size, self.output_size)
        # The head predicts log standard deviations; exp keeps sigma positive.
        sigma = self.logsigma(h).exp().view(-1, self.mixture_size, self.output_size)

        return logpi, mu, sigma, (h, c)

### Controller

In [None]:
class Controller(nn.Module):
    """C model: a single linear layer mapping ``[z, h]`` to raw action values."""

    def __init__(self, latent_size, hidden_size, action_size):
        super().__init__()
        in_features = latent_size + hidden_size
        self.fc = nn.Linear(in_features, action_size)

    def forward(self, z, h):
        """Concatenate ``z`` and ``h`` along dim 1 and apply the linear layer.

        No output activation is applied, so the returned values are unbounded.
        """
        joined = torch.cat((z, h), dim=1)
        return self.fc(joined)

### Genetic Algorithm

In [None]:
class RolloutGenerator:
    """Evaluates one (V, M, C) agent on ``CarRacing-v0`` by running episodes.

    Args:
        time_limit: with early termination enabled, an episode is cut off
            once the step counter exceeds this value.
    """

    def __init__(self, time_limit):
        self.time_limit = time_limit

        # PIL round-trip so each frame is resized to OBS_SIZE x OBS_SIZE and
        # returned as a float tensor by ToTensor.
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((OBS_SIZE, OBS_SIZE)),
            transforms.ToTensor(),
        ])

        # environment (rollout() closes it and creates a fresh one per episode)
        self.env = gym.make("CarRacing-v0")

        # agent
        self.vae = VAE(LATENT_SIZE)
        self.mdnrnn = MDNRNNCell(LATENT_SIZE + ACTION_SIZE, HIDDEN_SIZE, LATENT_SIZE, 5)
        self.controller = Controller(LATENT_SIZE, HIDDEN_SIZE, ACTION_SIZE)

    def get_action(self, obs, hidden):
        """Pick an action for one frame and advance the M model by one step.

        Args:
            obs: a transformed frame with a leading batch dimension, as built
                in ``rollout``.
            hidden: ``(h, c)`` state of the MDN-RNN cell; ``h`` also feeds
                the controller.

        Returns:
            ``(action, hidden)``: the controller output and the updated
            ``(h, c)`` state.
        """
        with torch.no_grad():
            latent_mu = self.vae.encode(obs, reparameterize=False)

            if DISCRETE_VAE:
                # Squash to (-1, 1) and bucket with edges [0, 1]. Negative
                # values map to -1 and [0, 1) maps to 0; the +1 shifts the
                # codes to {0, 1}. A value of 2 appears only if tanh saturates
                # to exactly 1.0.
                latent_mu = torch.tanh(latent_mu)
                bins = np.array([-1.0, 0.0, 1.0])
                newdata = bins[np.digitize(latent_mu, bins[1:])] + 1
                latent_mu = torch.from_numpy(newdata).float()

            # The environment expects steering in [-1, 1], gas in [0, 1] and
            # brake in [0, 1]; the controller output is not squashed here.
            action = self.controller(latent_mu, hidden[0])

            rnn_input = torch.cat([latent_mu, action], dim=1)
            logpi, mu, sigma, hidden = self.mdnrnn(rnn_input, hidden)
            # NOTE[jinyeom]: the MDN head doesn't do anything...

        return action, hidden

    def rollout(self, render=False, early_termination=True):
        """Run one episode and return its summed reward (the fitness).

        Args:
            render: if True, render every step in ``"human"`` mode.
            early_termination: if True, stop once more than 20 steps had a
                negative reward or the step counter exceeds ``time_limit``.

        Returns:
            The total reward collected during the episode.
        """
        # Close the environment we currently hold before replacing it, so the
        # old instance is not leaked.
        if self.env is not None:
            self.env.close()
        self.env = gym.make("CarRacing-v0")

        try:
            obs = self.env.reset()
            hidden = (torch.zeros(1, HIDDEN_SIZE),  # h
                      torch.zeros(1, HIDDEN_SIZE))  # c

            fitness = 0  # reward sum
            neg_count = 0  # for early termination

            t = 0
            done = False

            # NOTE[jinyeom]: I don't think this is necessary?
            # self.env.render("rgb_array")

            while not done:
                if render:
                    self.env.render("human")

                obs = self.transform(obs).unsqueeze(0)
                action, hidden = self.get_action(obs, hidden)
                action = action.squeeze().cpu().numpy()

                obs, reward, done, _ = self.env.step(action)

                fitness += reward
                neg_count += int(reward < 0)

                # early termination for speeding up evaluation
                if early_termination and (neg_count > 20 or t > self.time_limit):
                    done = True

                t += 1
        finally:
            # Release the environment (and any render window) even if the
            # episode raised.
            self.env.close()

        return fitness

In [None]:
class Individual:
    """A GA population member: one agent (V, M and C models) plus fitness bookkeeping.

    Args:
        time_limit: step limit forwarded to the RolloutGenerator.
        mut_mode: 0 perturbs all three models in ``mutate`` (MUT-ALL). 1
            perturbs one randomly chosen model (MUT-MOD). Any other value
            makes ``mutate`` a no-op.
    """

    def __init__(self, time_limit, mut_mode):
        self.time_limit = time_limit
        self.mut_mode = mut_mode

        self.r_gen = RolloutGenerator(time_limit)
        # Pending AsyncResult handles from the last run_solution call.
        self.async_results = []
        # Cache: num_evals -> (mean_fitness, std_fitness).
        self.calculated_results = {}

        self.fitness = None
        self.is_elite = None

    def run_solution(self, pool, num_evals=5, early_termination=True, force_eval=False):
        """Schedule ``num_evals`` rollouts on ``pool`` unless a cached result exists.

        Args:
            pool: object providing ``apply_async`` (e.g. ``multiprocessing.Pool``).
            num_evals: number of rollouts to schedule; also the cache key.
            early_termination: forwarded to ``RolloutGenerator.rollout``.
            force_eval: drop any cached result for ``num_evals`` first.
        """
        if force_eval:
            # remove existing results, so that it can be evaluated again
            self.calculated_results.pop(num_evals, None)

        if num_evals not in self.calculated_results:
            self.async_results = [
                pool.apply_async(self.r_gen.rollout, args=(False, early_termination))
                for _ in range(num_evals)
            ]

    def evaluate_solution(self, num_evals):
        """Return ``(mean_fitness, std_fitness)`` of the rollouts for ``num_evals``.

        The cached value is used if present. Otherwise this blocks on the
        results scheduled by ``run_solution`` and caches them. It also sets
        ``self.fitness`` to the *negated* mean, so an ascending sort on
        ``fitness`` puts the highest-reward individual first.
        """
        if num_evals in self.calculated_results:
            mean_fitness, std_fitness = self.calculated_results[num_evals]
        else:
            results = [t.get() for t in self.async_results]
            mean_fitness = np.mean(results)
            std_fitness = np.std(results)
            self.calculated_results[num_evals] = (mean_fitness, std_fitness)

        self.fitness = -mean_fitness

        return mean_fitness, std_fitness

    def load_solution(self, filename):
        """Load V, M and C weights from a file with "vae", "mdnrnn" and "controller" entries."""
        # torch.load unpickles the file: only load checkpoints from trusted sources.
        state_dicts = torch.load(filename)
        self.r_gen.vae.load_state_dict(state_dicts["vae"])
        self.r_gen.mdnrnn.load_state_dict(state_dicts["mdnrnn"])
        self.r_gen.controller.load_state_dict(state_dicts["controller"])

    def clone_individual(self):
        """Return a new Individual holding deep copies of this one's models and its fitness."""
        child_solution = Individual(self.time_limit, self.mut_mode)

        child_solution.fitness = self.fitness
        child_solution.r_gen.controller = copy.deepcopy(self.r_gen.controller)
        child_solution.r_gen.vae = copy.deepcopy(self.r_gen.vae)
        # Must target ``mdnrnn`` (the attribute rollouts use); otherwise the
        # child silently keeps its own randomly initialized M model.
        child_solution.r_gen.mdnrnn = copy.deepcopy(self.r_gen.mdnrnn)

        return child_solution

    def mutate_params(self, params):
        """Add N(0, MUT_POW^2) noise in place to every floating-point tensor in ``params``.

        ``params`` is expected to be a module's ``state_dict()``. Its tensors
        share storage with the module's parameters, so the in-place add
        mutates the module itself.
        """
        for tensor in params.values():
            if not tensor.is_floating_point():
                # Integer buffers (e.g. counters) cannot take float noise.
                continue
            noise = np.random.normal(0, 1, tensor.size()) * MUT_POW
            tensor += torch.from_numpy(noise).float()

    def mutate(self):
        """Perturb the agent's parameters according to ``self.mut_mode``."""
        r_gen = self.r_gen
        if self.mut_mode == 0:  # MUT-ALL: perturb every model
            for model in (r_gen.controller, r_gen.vae, r_gen.mdnrnn):
                self.mutate_params(model.state_dict())
        elif self.mut_mode == 1:  # MUT-MOD: perturb one randomly chosen model
            choice = np.random.randint(3)
            model = (r_gen.vae, r_gen.mdnrnn, r_gen.controller)[choice]
            self.mutate_params(model.state_dict())

In [None]:
class GA:
    def __init__(self, pop_size, num_workers, topk, num_evals, mut_mode):
        self.pop_size = pop_size
        self.truncation_threshold = pop_size // 2
        self.num_workers = num_workers
        self.topk = topk
        self.num_evals = num_evals
        self.mut_mode = mut_mode
        
        self.population = [Individual(timelimit, mut_mode) for _ in range(pop_size)]
        
    def run(self, num_gens):
        P = self.population
        Q = []  # offsprings
        
        while True:
            pool = Pool(self.num_workers)
            
            for ind in P:
                ind.run_solution(pool, num_evals=1, force_eval=True)
                
            fitness = []

            for ind in P:
                ind.is_elite = False
                f, _ = ind.evaluate_solution(1)
                fitness.append(f)
                
            P = sorted(P, key=lambda ind: ind.fitness)