## Config

In [None]:
class HyperParams:
    """Default settings shared by the rollout, model and controller cells."""

    # -- component selectors --
    vision = 'VAE'
    memory = 'RNN'
    controller = 'A3C'

    # -- dataset / checkpoint locations --
    data_dir = 'datasets'
    extra_dir = 'additional'
    extra = False  # when True the training cells read from extra_dir instead of data_dir
    ckpt_dir = 'ckpt'

    # -- observation geometry --
    img_height = img_width = 96
    img_channels = 3

    # -- batching --
    seq_len = 32
    batch_size = 2  # counted in sequences, i.e. batch_size * seq_len frames
    test_batch = 1
    n_sample = 64

    # -- model sizes --
    vsize = 128  # latent size of the vision model
    msize = 128  # size of the memory model
    asize = 3  # action size
    rnn_hunits = 256
    ctrl_hidden_dims = 512

    # -- logging / checkpoint cadence (in steps) --
    log_interval = 5000
    save_interval = 10000

    # -- controller-related thresholds --
    use_binary_feature = False
    score_cut = 300  # to save
    save_start_score = 100

    # -- rollout collection --
    max_ep = 1000
    n_rollout = 200
    seed = 0

    # -- DataLoader worker count --
    n_workers = 0

class RNNHyperParams:
    """Settings used while training the recurrent (memory) model."""

    # -- component selectors --
    vision = 'VAE'
    memory = 'RNN'

    # -- dataset / checkpoint locations --
    data_dir = 'datasets'
    extra_dir = 'additional'
    extra = False  # when True the training cell reads from extra_dir instead of data_dir
    ckpt_dir = 'ckpt'

    # -- observation geometry --
    img_height = img_width = 96
    img_channels = 3

    # -- batching --
    seq_len = 32
    batch_size = 1  # counted in sequences, i.e. batch_size * seq_len frames
    test_batch = 1
    n_sample = 64

    # -- model sizes --
    vsize = 128  # latent size of the vision model
    msize = 128  # size of the memory model
    asize = 3  # action size
    rnn_hunits = 256

    # -- schedule (in optimisation steps) --
    log_interval = 1000
    save_interval = 2000
    max_step = 100000

    # -- misc --
    n_workers = 0
    seed = 0

class VAEHyperParams:
    """Settings used while training the VAE (vision) model."""

    # -- component selector --
    vision = 'VAE'

    # -- dataset / checkpoint locations --
    data_dir = 'datasets'
    extra_dir = 'additional'
    extra = False  # when True the training cell reads from extra_dir instead of data_dir
    ckpt_dir = 'ckpt'

    # -- observation geometry --
    img_height = img_width = 96
    img_channels = 3

    # -- batching --
    batch_size = 64
    test_batch = 12
    n_sample = 64

    # -- model sizes --
    vsize = 128  # latent size of the vision model
    msize = 128  # size of the memory model
    asize = 3  # action size

    # -- schedule (in optimisation steps) --
    log_interval = 5000
    save_interval = 10000
    max_step = 2000000

    # -- DataLoader worker count --
    n_workers = 0

In [None]:
import os, sys, glob
from os.path import join, exists
from os import mkdir, unlink, listdir, getpid, makedirs

import numpy as np
import easydict
import cma
import gym
from datetime import datetime
from time import sleep
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.multiprocessing as multi
from torch.nn import functional as F
from torch.distributions.normal import Normal
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torch.multiprocessing import Process, Queue

In [None]:
# Bind the shared HyperParams settings to the module-level name `hp` that rollout() reads.
hp = HyperParams

def rollout():
    """Collect random-action CarRacing episodes and write them to ``hp.data_dir``.

    For each of ``hp.n_rollout`` episodes the environment is driven with
    actions drawn from ``env.action_space.sample()``. Two kinds of files are
    written with ``np.savez``:

    * ``rollout_{ep:03d}_{t:04d}.npz`` -- one transition per step
      (``obs``, ``action``, ``reward``, ``next_obs``, ``done``);
    * ``rollout_ep_{ep:03d}.npz`` -- the same fields stacked along a leading
      time axis for the whole episode.

    An episode stops as soon as the environment reports it is finished or
    after ``seq_len`` recorded steps, whichever comes first.
    """
    env = gym.make("CarRacing-v2")

    seq_len = 1000  # upper bound on recorded steps per episode
    max_ep = hp.n_rollout
    feat_dir = hp.data_dir

    os.makedirs(feat_dir, exist_ok=True)

    def step(action):
        """Step the env and normalise the result to (obs, reward, done).

        Older gym versions return ``(obs, reward, done, info)``; gym >= 0.26
        returns ``(obs, reward, terminated, truncated, info)``.
        """
        result = env.step(action)
        if len(result) == 5:
            obs, reward, terminated, truncated, _ = result
            return obs, reward, bool(terminated or truncated)
        obs, reward, done, _ = result
        return obs, reward, done

    try:
        for ep in range(max_ep):
            obs_lst, action_lst, reward_lst, next_obs_lst, done_lst = [], [], [], [], []
            env.reset()
            # One throw-away random step yields the first observation; its
            # done flag is deliberately ignored, as in the original code.
            obs, _, _ = step(env.action_space.sample())
            done = False
            t = 0

            # Was `not done or t < seq_len`, which kept stepping an already
            # finished environment until t reached seq_len.
            while not done and t < seq_len:
                t += 1

                action = env.action_space.sample()
                next_obs, reward, done = step(action)

                np.savez(
                    os.path.join(feat_dir, 'rollout_{:03d}_{:04d}'.format(ep, t)),
                    obs=obs,
                    action=action,
                    reward=reward,
                    next_obs=next_obs,
                    done=done,
                )

                obs_lst.append(obs)
                action_lst.append(action)
                reward_lst.append(reward)
                next_obs_lst.append(next_obs)
                done_lst.append(done)
                obs = next_obs

            np.savez(
                os.path.join(feat_dir, 'rollout_ep_{:03d}'.format(ep)),
                obs=np.stack(obs_lst, axis=0),  # (T, *obs.shape)
                action=np.stack(action_lst, axis=0),  # (T, *action.shape)
                reward=np.stack(reward_lst, axis=0),  # (T,)
                next_obs=np.stack(next_obs_lst, axis=0),  # (T, *obs.shape)
                done=np.stack(done_lst, axis=0),  # (T,)
            )
    finally:
        # Release the simulator even if an episode fails part-way.
        env.close()



# Entry point of this cell: seed NumPy's global RNG with hp.seed, then collect rollouts.
if __name__ == '__main__':
    np.random.seed(hp.seed)
    rollout()

In [None]:
# Shared image preprocessing pipeline used by the datasets, collate_fn and the
# rollout generator. Only ToTensor is active; the Normalize step is commented out.
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

class GameSceneDataset(torch.utils.data.Dataset):
    """Single-frame dataset over the per-step ``rollout_EEE_*.npz`` files.

    The sorted file list is split positionally: the first
    ``int(n * (1 - test_ratio))`` files form the training split and the rest
    form the test split.

    :args data_path: directory containing the ``.npz`` files
    :args training: select the training split when True, the test split otherwise
    :args test_ratio: fraction of files reserved for the test split
    """

    def __init__(self, data_path, training=True, test_ratio=0.01):
        self.fpaths = sorted(glob.glob(os.path.join(data_path, 'rollout_[0-9][0-9][0-9]_*.npz')))
        # NOTE(review): this reseeds NumPy's *global* RNG on every construction
        # although the split below is deterministic; kept to preserve behavior.
        np.random.seed(0)
        indices = np.arange(0, len(self.fpaths))
        n_trainset = int(len(indices)*(1.0-test_ratio))
        self.train_indices = indices[:n_trainset]
        self.test_indices = indices[n_trainset:]
        self.indices = self.train_indices if training else self.test_indices

    def __getitem__(self, idx):
        """Return the ``obs`` array of sample ``idx`` passed through ``transform``."""
        # The context manager closes the underlying .npz file handle right
        # after the array has been read, instead of leaving it to the GC.
        with np.load(self.fpaths[self.indices[idx]]) as npz:
            obs = npz['obs']
        return transform(obs)

    def __len__(self):
        """Number of files in the selected split."""
        return len(self.indices)

class GameEpisodeDataset(torch.utils.data.Dataset):
    """Episode dataset over ``rollout_ep_*.npz`` files, chopped into fixed-length sequences.

    The sorted file list is split positionally: the first
    ``int(n * (1 - test_ratio))`` files form the training split and the rest
    form the test split.

    :args data_path: directory containing the ``.npz`` files
    :args seq_len: number of consecutive frames per sequence
    :args seq_mode: stored on the instance; not used by the visible code
    :args training: select the training split when True, the test split otherwise
    :args test_ratio: fraction of files reserved for the test split
    """

    def __init__(self, data_path, seq_len=32, seq_mode=True, training=True, test_ratio=0.01):
        self.training = training
        self.fpaths = sorted(glob.glob(os.path.join(data_path, 'rollout_ep_*.npz')))
        # NOTE(review): this reseeds NumPy's *global* RNG on every construction
        # although the split below is deterministic; kept to preserve behavior.
        np.random.seed(0)
        indices = np.arange(0, len(self.fpaths))
        n_trainset = int(len(indices)*(1.0-test_ratio))
        self.train_indices = indices[:n_trainset]
        self.test_indices = indices[n_trainset:]
        self.indices = self.train_indices if training else self.test_indices
        self.seq_len = seq_len
        self.seq_mode = seq_mode

    def __getitem__(self, idx):
        """Return ``(obs, actions)`` for episode ``idx``.

        ``obs`` has shape ``(N_seq, seq_len, H, W, C)`` and ``actions`` has
        shape ``(N_seq, seq_len, n_actions)`` where ``N_seq = T // seq_len``;
        trailing frames that do not fill a whole sequence are dropped.
        """
        # Read both arrays, then close the .npz file handle immediately.
        with np.load(self.fpaths[self.indices[idx]]) as npz:
            obs = npz['obs']  # (T, H, W, C)
            actions = npz['action']  # (T, n_actions)

        T, H, W, C = obs.shape
        n_seq = T // self.seq_len
        end_seq = n_seq * self.seq_len  # last frame index that still fills a sequence

        obs = obs[:end_seq].reshape([n_seq, self.seq_len, H, W, C])
        actions = actions[:end_seq].reshape([n_seq, self.seq_len, actions.shape[-1]])
        return obs, actions

    def __len__(self):
        """Number of episodes in the selected split."""
        return len(self.indices)

def collate_fn(data):
    """Merge ``GameEpisodeDataset`` items into one batch of sequences.

    :args data: iterable of ``(obs, actions)`` pairs with shapes
        ``(N_seq, seq_len, H, W, C)`` and ``(N_seq, seq_len, n_actions)``

    :returns: ``(obs, actions)`` where ``obs`` stacks ``transform(frame)`` for
        every frame of every sequence (leading size ``sum(N_seq) * seq_len``)
        and ``actions`` is a float tensor of shape
        ``(sum(N_seq), seq_len, n_actions)``
    """
    obs, actions = zip(*data)
    # Concatenating along the sequence axis yields the same layout as the
    # previous np.array(...).reshape(...) when all episodes have equal N_seq,
    # and additionally works when episodes contribute different N_seq.
    obs = np.concatenate(obs, axis=0)  # (sum N_seq, seq_len, H, W, C)
    actions = np.concatenate(actions, axis=0)  # (sum N_seq, seq_len, n_actions)

    H, W, C = obs.shape[2:]
    frames = obs.reshape([-1, H, W, C])  # (sum N_seq * seq_len, H, W, C)
    obs = torch.stack([transform(frame) for frame in frames], dim=0)
    return obs, torch.tensor(actions, dtype=torch.float)

## VAE

In [None]:
# Rebind the module-level settings name `hp` to the shared HyperParams for this cell.
hp = HyperParams

class VAE(nn.Module):
    """Variational auto-encoder assembled from ``Encoder`` and ``Decoder``.

    :args latent_dims: size of the latent vector
    :args img_channels: number of image channels fed to / produced by the model
    """

    def __init__(self, latent_dims, img_channels=3):
        super().__init__()
        self.encoder = Encoder(img_channels, latent_dims)
        self.decoder = Decoder(img_channels, latent_dims)

    def forward(self, x):
        """Encode ``x``, sample a latent and decode it.

        :returns: ``(reconstruction, mu, logvar)``
        """
        mu, logvar = self.encoder(x)
        recon = self.decoder(self.reparam(mu, logvar))
        return recon, mu, logvar

    def reparam(self, mu, logvar):
        """Reparameterisation trick: ``mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(logvar * 0.5)
        noise = torch.randn_like(std)
        return noise * std + mu

class Encoder(nn.Module):
    """Convolutional encoder producing the mean and log-variance of the latent.

    Four stride-2 convolutions halve the spatial size each time
    (96 -> 48 -> 24 -> 12 -> 6 for 96x96 inputs); the flattened
    ``6*6*256`` features are projected to ``2 * latent_dims`` values.

    :args in_channels: number of input image channels
    :args latent_dims: size of the latent vector
    """

    def __init__(self, in_channels, latent_dims):
        super().__init__()
        self.latent_dims = latent_dims

        widths = [in_channels, 32, 64, 128, 256]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Conv2d(c_in, c_out, 3, stride=2, padding=1))
            layers.append(nn.LeakyReLU())
        self.encoder = nn.Sequential(*layers)  # output: (B, 256, 6, 6)

        self.fc = nn.Linear(6*6*256, latent_dims*2)
        self.softplus = nn.Softplus()  # not used in forward()

    def forward(self, x):
        """Return ``(mu, logvar)``, each of shape ``(B, latent_dims)``."""
        features = self.encoder(x)
        flat = features.view(features.size(0), -1)  # (B, 6*6*256)
        stats = self.fc(flat)  # (B, 2 * latent_dims)
        mu, logvar = torch.split(stats, self.latent_dims, dim=1)
        return mu, logvar

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent vector to an image in [0, 1].

    The latent is projected to 1024 features, viewed as ``(B, 256, 2, 2)`` and
    upsampled 2 -> 6 -> 12 -> 24 -> 48 -> 96 before a final sigmoid.

    :args out_channels: number of channels of the generated image
    :args latent_dims: size of the latent vector
    """

    def __init__(self, out_channels, latent_dims):
        super().__init__()
        self.fc = nn.Linear(latent_dims, 1024)

        # (in_channels, out_channels, kernel_size) of the hidden up-sampling stages
        stages = [(256, 128, 6), (128, 64, 4), (64, 32, 4), (32, 32, 4)]
        layers = []
        for c_in, c_out, kernel in stages:
            layers.append(nn.ConvTranspose2d(c_in, c_out, kernel, stride=2, padding=1))
            layers.append(nn.LeakyReLU())
        layers.append(nn.ConvTranspose2d(32, out_channels, 4, stride=2, padding=1))
        layers.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*layers)

    def forward(self, z):
        """Decode latent ``z`` of shape ``(B, latent_dims)`` into an image batch."""
        seed = self.fc(z)
        seed = seed.view(seed.size(0), -1, 2, 2)  # (B, 256, 2, 2)
        return self.decoder(seed)


def vae_loss(recon_x, x, mu, logvar):
    """ VAE loss function

    :args recon_x: reconstruction produced by the decoder
    :args x: target image batch, same shape as ``recon_x``
    :args mu: latent means
    :args logvar: latent log-variances

    :returns: ``(total, recon, kld)`` where ``recon`` is the summed squared
        error, ``kld`` is the KL divergence to a unit Gaussian summed over
        all elements, and ``total = recon + kld``
    """
    # reduction='sum' is the supported spelling of the deprecated
    # size_average=False; the value is identical.
    recon = nn.functional.mse_loss(recon_x, x, reduction='sum')
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kld, recon, kld

In [None]:
# Switch the module-level settings to the VAE training configuration.
hp = VAEHyperParams

# Placeholder; the __main__ guard of this cell assigns 'cuda' or 'cpu' before train() runs.
DEVICE = None

def train():
    """Train the VAE on single frames, resuming from the newest checkpoint if any.

    Checkpoints live in ``<hp.ckpt_dir>/vae`` and are named ``<N>k.pth.tar``
    where ``N`` is the global step divided by 1000. Every ``hp.log_interval``
    steps the model is evaluated on the test split and a line is appended to
    ``train.log``; every ``hp.save_interval`` steps a checkpoint holding the
    model and optimizer state is written. Training stops once ``global_step``
    reaches ``hp.max_step``.

    :raises ValueError: if the training split yields no batches
    """
    def ckpt_thousands(path):
        """Return the integer N encoded in a checkpoint named '<N>k.pth.tar'."""
        return int(os.path.basename(path).split('.')[0][:-1])

    global_step = 0
    model = VAE(hp.vsize).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    ckpt_dir = os.path.join(hp.ckpt_dir, 'vae')
    sample_dir = os.path.join(ckpt_dir, 'samples')

    # Resume from the newest checkpoint. Sorting numerically matters: a plain
    # string sort ranks '1000k' before '990k'.
    ckpts = sorted(glob.glob(os.path.join(ckpt_dir, '*k.pth.tar')), key=ckpt_thousands)
    if ckpts:
        ckpt = ckpts[-1]
        # map_location lets a checkpoint written on GPU be resumed on CPU.
        vae_state = torch.load(ckpt, map_location=DEVICE)
        model.load_state_dict(vae_state['model'])
        if 'optimizer' in vae_state:
            # The optimizer state is saved below, so restore it on resume too.
            optimizer.load_state_dict(vae_state['optimizer'])
        global_step = ckpt_thousands(ckpt) * 1000
        print('Loaded vae ckpt {}'.format(ckpt))

    data_path = hp.data_dir if not hp.extra else hp.extra_dir
    dataset = GameSceneDataset(data_path)
    loader = DataLoader(
        dataset, batch_size=hp.batch_size, shuffle=True,
        num_workers=hp.n_workers,
    )
    testset = GameSceneDataset(data_path, training=False)
    test_loader = DataLoader(testset, batch_size=hp.test_batch, shuffle=False, drop_last=True)

    if len(loader) == 0:
        # Without this guard the outer while-loop below would spin forever.
        raise ValueError('No training frames found in {}'.format(data_path))

    os.makedirs(sample_dir, exist_ok=True)

    while global_step < hp.max_step:
        for idx, obs in enumerate(tqdm(loader, total=len(loader))):
            x = obs.to(DEVICE)
            x_hat, mu, logvar = model(x)

            loss, recon_loss, kld = vae_loss(x_hat, x, mu, logvar)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if global_step % hp.log_interval == 0:
                # Logged values are the test-split means returned by evaluate().
                recon_loss, kld = evaluate(test_loader, model, sample_dir, global_step)
                now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                with open(os.path.join(ckpt_dir, 'train.log'), 'a') as f:
                    log = '{} || Step: {}, loss: {:.4f}, kld: {:.4f}\n'.format(now, global_step, recon_loss, kld)
                    f.write(log)

            if global_step % hp.save_interval == 0:
                d = {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(
                    d, os.path.join(ckpt_dir, '{:03d}k.pth.tar'.format(global_step//1000))
                )
            global_step += 1

            if global_step >= hp.max_step:
                # Stop mid-epoch instead of overshooting the step budget.
                break

def evaluate(test_loader, model, sample_dir=None, global_step=0):
    """Evaluate the VAE on ``test_loader`` and optionally dump sample images.

    :args test_loader: iterable of image batches
    :args model: the VAE; switched to eval mode during the call and back to
        train mode before returning
    :args sample_dir: directory for the ``*-random.png``, ``*-xhat.png`` and
        ``*-x.png`` grids; when None no images are written
    :args global_step: used only to name the image files

    :returns: ``(mean reconstruction loss, mean KLD)`` over the test batches
    """
    model.eval()
    total_recon_loss = []
    total_kld_loss = []
    n_sample = hp.n_sample
    # Fixed-size buffers: slots not filled by a test batch stay zero.
    c_x = torch.zeros([n_sample, 3, 96, 96])
    c_x_hat = torch.zeros([n_sample, 3, 96, 96])
    with torch.no_grad():
        for idx, obs in enumerate(test_loader):
            x = obs.to(DEVICE)
            x_hat, mu, logvar = model(x)
            _, recon_loss, kld = vae_loss(x_hat, x, mu, logvar)

            if idx < n_sample:
                # Keep the first image of each batch and its reconstruction.
                c_x[idx] = x[0]
                c_x_hat[idx] = x_hat[0]
            total_recon_loss.append(recon_loss.item())
            total_kld_loss.append(kld.item())

        if sample_dir is not None:
            # Decode latents drawn from the prior for a qualitative check.
            z = torch.randn([n_sample, hp.vsize]).to(DEVICE)
            x_rand = model.decoder(z)

    if sample_dir is not None:
        # Previously the default sample_dir=None crashed in os.path.join.
        tag = '{:04d}k'.format(global_step//1000)
        save_image(x_rand, os.path.join(sample_dir, tag + '-random.png'))
        save_image(c_x_hat, os.path.join(sample_dir, tag + '-xhat.png'))
        save_image(c_x, os.path.join(sample_dir, tag + '-x.png'))
    model.train()
    return np.mean(total_recon_loss), np.mean(total_kld_loss)


# Entry point of this cell: use the GPU when torch reports one, otherwise the CPU,
# then run VAE training.
if __name__ == '__main__':
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    train()

## LSTM

In [None]:
# Rebind the module-level settings name `hp` to the shared HyperParams for this cell.
hp = HyperParams

class RNN(nn.Module):
    """LSTM that predicts the next latent from the current latent and action.

    :args n_latents: size of the latent vector (also the output size)
    :args n_actions: size of the action vector
    :args n_hiddens: hidden size of the LSTM
    """

    def __init__(self, n_latents, n_actions, n_hiddens):
        super().__init__()
        input_size = n_latents + n_actions
        self.rnn = nn.LSTM(input_size, n_hiddens, batch_first=True)
        # projection from the LSTM output to the next latent
        self.fc = nn.Linear(n_hiddens, n_latents)

    def forward(self, states):
        """Run the LSTM from its default initial state.

        :returns: ``(prediction, None, None)``
        """
        outputs, _ = self.rnn(states)
        return self.fc(outputs), None, None

    def infer(self, states, hidden):
        """Run the LSTM from ``hidden`` and also return the updated hidden state.

        :returns: ``(prediction, None, None, next_hidden)``
        """
        outputs, next_hidden = self.rnn(states, hidden)
        return self.fc(outputs), None, None, next_hidden

In [None]:
# Switch the module-level settings to the RNN training configuration.
hp = RNNHyperParams

# Placeholder; the __main__ guard of this cell assigns 'cuda' or 'cpu' before train() runs.
DEVICE = None

def train():
    """Train the RNN to predict the next VAE latent from (latent, action) pairs.

    The newest VAE checkpoint in ``<hp.ckpt_dir>/vae`` is loaded and kept
    frozen; the RNN resumes from the newest checkpoint in
    ``<hp.ckpt_dir>/rnn`` if one exists. Checkpoints are named
    ``<N>k.pth.tar`` where ``N`` is the global step divided by 1000.
    Training stops once ``global_step`` reaches ``hp.max_step``.

    :raises FileNotFoundError: if no VAE checkpoint is available
    :raises ValueError: if the training split yields no batches
    """
    def ckpt_thousands(path):
        """Return the integer N encoded in a checkpoint named '<N>k.pth.tar'."""
        return int(os.path.basename(path).split('.')[0][:-1])

    def newest(subdir):
        """Return the highest-numbered checkpoint in hp.ckpt_dir/subdir, or None."""
        # Numeric ordering: a plain string sort ranks '1000k' before '990k'.
        found = sorted(
            glob.glob(os.path.join(hp.ckpt_dir, subdir, '*k.pth.tar')),
            key=ckpt_thousands)
        return found[-1] if found else None

    global_step = 0

    # Load the pretrained VAE (frozen feature extractor).
    vae = VAE(hp.vsize).to(DEVICE)
    ckpt = newest('vae')
    if ckpt is None:
        raise FileNotFoundError(
            'No VAE checkpoint found in {}'.format(os.path.join(hp.ckpt_dir, 'vae')))
    # map_location lets a checkpoint written on GPU be loaded on CPU.
    vae_state = torch.load(ckpt, map_location=DEVICE)
    vae.load_state_dict(vae_state['model'])
    vae.eval()
    print('Loaded vae ckpt {}'.format(ckpt))

    rnn = RNN(hp.vsize, hp.asize, hp.rnn_hunits).to(DEVICE)
    optimizer = torch.optim.Adam(rnn.parameters(), lr=1e-4)
    ckpt = newest('rnn')
    if ckpt is not None:
        rnn_state = torch.load(ckpt, map_location=DEVICE)
        rnn.load_state_dict(rnn_state['model'])
        if 'optimizer' in rnn_state:
            # The optimizer state is saved below, so restore it on resume too.
            optimizer.load_state_dict(rnn_state['optimizer'])
        global_step = ckpt_thousands(ckpt) * 1000
        print('Loaded rnn ckpt {}'.format(ckpt))

    data_path = hp.data_dir if not hp.extra else hp.extra_dir
    dataset = GameEpisodeDataset(data_path, seq_len=hp.seq_len)
    loader = DataLoader(
        dataset, batch_size=1, shuffle=True, drop_last=True,
        num_workers=hp.n_workers, collate_fn=collate_fn
    )
    testset = GameEpisodeDataset(data_path, seq_len=hp.seq_len, training=False)
    test_loader = DataLoader(
        testset, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn
    )

    if len(loader) == 0:
        # Without this guard the outer while-loop below would spin forever.
        raise ValueError('No training episodes found in {}'.format(data_path))

    ckpt_dir = os.path.join(hp.ckpt_dir, 'rnn')
    sample_dir = os.path.join(ckpt_dir, 'samples')
    os.makedirs(sample_dir, exist_ok=True)

    l1 = nn.L1Loss()

    while global_step < hp.max_step:
        with tqdm(enumerate(loader), total=len(loader), ncols=70, leave=False) as t:
            t.set_description('Step {}'.format(global_step))
            for idx, (obs, actions) in t:
                obs, actions = obs.to(DEVICE), actions.to(DEVICE)
                with torch.no_grad():
                    latent_mu, latent_var = vae.encoder(obs)  # (n_seq*T, vsize)
                    # The posterior mean is used as the latent (no sampling).
                    z = latent_mu.view(-1, hp.seq_len, hp.vsize)  # (n_seq, T, vsize)

                # Inputs are steps 0..T-2, targets are the latents of steps 1..T-1.
                next_z = z[:, 1:, :]
                z, actions = z[:, :-1, :], actions[:, :-1, :]
                states = torch.cat([z, actions], dim=-1)  # (n_seq, T-1, vsize+asize)
                x, _, _ = rnn(states)

                loss = l1(x, next_z)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                global_step += 1

                if global_step % hp.log_interval == 0:
                    eval_loss = evaluate(test_loader, vae, rnn, global_step)
                    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                    with open(os.path.join(ckpt_dir, 'train.log'), 'a') as f:
                        log = '{} || Step: {}, train_loss: {:.4f}, loss: {:.4f}\n'.format(now, global_step, loss.item(), eval_loss)
                        f.write(log)

                    # Visualise one sequence; fall back to the last one when the
                    # batch holds fewer than three sequences.
                    S = min(2, x.size(0) - 1)
                    with torch.no_grad():  # preview only, no graph needed
                        y = vae.decoder(x[S, :, :])
                        v = vae.decoder(next_z[S, :, :])
                    # Frames matching next_z[S]: sequence S starts at S*seq_len
                    # in the flattened obs and the targets skip its first frame.
                    # (The old slice obs[S:S+seq_len-1] showed other frames.)
                    start = S * hp.seq_len + 1
                    frames = obs[start:start + hp.seq_len - 1]
                    save_image(y, os.path.join(sample_dir, '{:04d}-rnn.png'.format(global_step)))
                    save_image(v, os.path.join(sample_dir, '{:04d}-vae.png'.format(global_step)))
                    save_image(frames, os.path.join(sample_dir, '{:04d}-obs.png'.format(global_step)))

                if global_step % hp.save_interval == 0:
                    d = {
                        'model': rnn.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    }
                    torch.save(
                        d, os.path.join(ckpt_dir, '{:03d}k.pth.tar'.format(global_step//1000))
                    )

                if global_step >= hp.max_step:
                    # Stop mid-epoch instead of overshooting the step budget.
                    break

def evaluate(test_loader, vae, rnn, global_step=0):
    """Compute the mean next-latent L1 loss of ``rnn`` over ``test_loader``.

    :args test_loader: iterable of ``(obs, actions)`` batches
    :args vae: VAE whose encoder provides the latents
    :args rnn: model under evaluation; put in eval mode during the call and
        back in train mode before returning
    :args global_step: accepted for symmetry with the caller; not used

    :returns: mean of the per-batch L1 losses
    """
    criterion = nn.L1Loss()
    losses = []

    rnn.eval()
    with torch.no_grad():
        for obs, actions in test_loader:
            obs = obs.to(DEVICE)
            actions = actions.to(DEVICE)

            # The posterior mean is used as the latent (no sampling).
            latent_mu, _ = vae.encoder(obs)  # (n_seq*T, vsize)
            latents = latent_mu.view(-1, hp.seq_len, hp.vsize)  # (n_seq, T, vsize)

            # Inputs are steps 0..T-2, targets are the latents of steps 1..T-1.
            targets = latents[:, 1:, :]
            inputs = torch.cat([latents[:, :-1, :], actions[:, :-1, :]], dim=-1)
            predicted, _, _ = rnn(inputs)

            losses.append(criterion(predicted, targets).item())
    rnn.train()

    return np.mean(losses)


# Entry point of this cell: pick the device, seed NumPy's global RNG with hp.seed,
# then run RNN training.
if __name__ == '__main__':
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    np.random.seed(hp.seed)
    train()

## Controller

In [None]:
""" Define controller """

class Controller(nn.Module):
    """Single linear layer mapping concatenated features to an action vector.

    :args latents: size of the latent input
    :args recurrents: size of the recurrent-state input
    :args actions: size of the produced action vector
    """

    def __init__(self, latents, recurrents, actions):
        super().__init__()
        in_features = latents + recurrents
        self.fc = nn.Linear(in_features, actions)

    def forward(self, *inputs):
        """Concatenate ``inputs`` along dim 1 and apply the linear layer."""
        features = torch.cat(inputs, dim=1)
        action = self.fc(features)
        return action

In [None]:
def flatten_parameters(params):
    """ Concatenate parameters into one flat NumPy vector.

    :args params: iterable of parameters (as returned by module.parameters())

    :returns: 1D np array holding every parameter's values, each flattened and
        appended in iteration order
    """
    pieces = []
    for p in params:
        # detach() drops the autograd link; view(-1) flattens without copying
        pieces.append(p.detach().view(-1))
    flat = torch.cat(pieces, dim=0)
    return flat.cpu().numpy()

def unflatten_parameters(params, example, device):
    """ Cut a flat parameter vector back into tensors shaped like ``example``.

    :args params: parameters as a single 1D np array
    :args example: iterable of parameters (as returned by module.parameters()),
        consulted only for their sizes
    :args device: where to store the unflattened parameters

    :returns: list of tensors, one per element of ``example``, in order
    """
    flat = torch.Tensor(params).to(device)
    shaped = []
    offset = 0
    for ref in example:
        count = ref.numel()
        chunk = flat[offset:offset + count]
        shaped.append(chunk.view(ref.size()))
        offset += count
    return shaped

def load_parameters(params, controller):
    """ Copy a flat parameter vector into ``controller`` in place.

    :args params: parameters as a single 1D np array
    :args controller: module whose parameters are overwritten
    """
    # The first parameter tells us which device the controller lives on.
    device = next(controller.parameters()).device
    new_values = unflatten_parameters(params, controller.parameters(), device)

    for target, source in zip(controller.parameters(), new_values):
        target.data.copy_(source)

In [None]:
"""
Training a linear controller on latent + recurrent state
with CMAES.

This is a bit complex. num_workers slave threads are launched
to process a queue filled with parameters to be evaluated.
"""

# Settings for this cell come from the shared HyperParams.
hp = HyperParams

# Multiprocessing context using the "spawn" start method; queues and worker
# processes below are created from it.
ctx = multi.get_context("spawn")
# NOTE(review): this binds the Queue factory itself (no call) and the name is
# not referenced again in the visible code.
queue = ctx.Queue

# Run options for the CMA-ES controller search.
args = easydict.EasyDict({
    "logdir" : 'ckpt',
    "n_samples" : 4,
    "pop_size" : 4,
    "target_return" : 950,
    "display" : True,
    "max_workers" : 32
})

# multiprocessing variables
n_samples = args.n_samples
pop_size = args.pop_size
# never start more workers than there are rollouts per CMA generation
num_workers = min(args.max_workers, n_samples * pop_size)
# passed to RolloutGenerator as its time_limit
time_limit = 1000


# create the tmp dir if it does not exist, otherwise empty it
# (workers write their redirected stdout/stderr files here)
tmp_dir = join(args.logdir, 'tmp')
if not exists(tmp_dir):
    makedirs(tmp_dir)
else:
    for fname in listdir(tmp_dir):
        unlink(join(tmp_dir, fname))

# create the controller checkpoint dir if it does not exist
ctrl_dir = join(args.logdir, 'cma')
if not exists(ctrl_dir):
    makedirs(ctrl_dir)


################################################################################
#                           Thread routines                                    #
################################################################################
def slave_routine(p_queue, r_queue, e_queue, p_index):
    """ Worker routine.

    Workers interact with p_queue, the parameters queue, r_queue, the result
    queue and e_queue the end queue. They pull parameters from p_queue, execute
    the corresponding rollout, then place the result in r_queue.

    Each parameter has its own unique id. Parameters are pulled as tuples
    (s_id, params) and results are pushed as (s_id, result).  The same
    parameter can appear multiple times in p_queue, displaying the same id
    each time.

    As soon as e_queue is non empty, the worker terminates.

    The process index p_index is currently unused: every worker runs on
    'cuda:0' when CUDA is available and on the CPU otherwise.

    :args p_queue: queue containing couples (s_id, parameters) to evaluate
    :args r_queue: where to place results (s_id, results)
    :args e_queue: as soon as not empty, terminate
    :args p_index: the process index
    """
    # Imported locally because the module-level name `queue` is rebound to
    # ctx.Queue in this file.
    from queue import Empty

    sys.stdout = sys.__stdout__
    sys.stdout.write("hello")
    # init routine
    #gpu = p_index % torch.cuda.device_count()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # redirect streams to per-process files in tmp_dir
    sys.stdout = open(join(tmp_dir, str(getpid()) + '.out'), 'a')
    sys.stderr = open(join(tmp_dir, str(getpid()) + '.err'), 'a')

    with torch.no_grad():
        r_gen = RolloutGenerator(args.logdir, device, time_limit)

        while e_queue.empty():
            # A timed get() replaces the old `empty()` check followed by a
            # blocking `get()`: with several workers, another process could
            # take the item in between and leave this one blocked forever,
            # never noticing e_queue.
            try:
                s_id, params = p_queue.get(timeout=.1)
            except Empty:
                continue
            r_queue.put((s_id, r_gen.rollout(params)))


################################################################################
#                Define queues and start workers                               #
################################################################################
# Queues shared with the workers: parameters to evaluate, rollout results,
# and the end-of-processing signal.
p_queue = ctx.Queue()
r_queue = ctx.Queue()
e_queue = ctx.Queue()

# Launch num_workers processes running slave_routine.
# NOTE(review): the Process handles are not kept, so they cannot be joined later.
for p_index in range(num_workers):
    ctx.Process(target=slave_routine, args=(p_queue, r_queue, e_queue, p_index)).start()


################################################################################
#                           Evaluation                                         #
################################################################################
def evaluate(solutions, results, rollouts=100):
    """ Re-evaluate the best solution of the current CMA generation.

    The solution with the smallest entry in ``results`` is pushed ``rollouts``
    times onto p_queue and the values coming back on r_queue are collected.

    :args solutions: CMA set of solutions
    :args results: corresponding results
    :args rollouts: number of rollouts

    :returns: (best solution, mean of the collected values, their standard
        deviation)
    """
    best_guess = solutions[np.argmin(results)]

    # one queue entry per requested rollout, all carrying the same parameters
    for s_id in range(rollouts):
        p_queue.put((s_id, best_guess))

    print("Evaluating...")

    restimates = []
    for _ in tqdm(range(rollouts)):
        # poll until a worker has delivered the next result
        while r_queue.empty():
            sleep(.1)
        _, value = r_queue.get()
        restimates.append(value)

    return best_guess, np.mean(restimates), np.std(restimates)

################################################################################
#                           Launch CMA                                         #
################################################################################
# Template controller: provides the initial parameter vector for CMA-ES and is
# the object whose state_dict is saved as the current best.
controller = Controller(hp.vsize, hp.msize, hp.asize)  # dummy instance

# define current best and load parameters
# cur_best holds *minus* the best reward so far (lower is better), or None.
cur_best = None
ctrl_file = join(ctrl_dir, 'best.tar')
print("Attempting to load previous best...")
if exists(ctrl_file):
    state = torch.load(ctrl_file, map_location='cpu')
    cur_best = - state['reward']
    controller.load_state_dict(state['state_dict'])
    print("Previous best was {}...".format(-cur_best))

parameters = controller.parameters()
es = cma.CMAEvolutionStrategy(flatten_parameters(parameters),  # initial solution
                              0.1,                                # initial standard deviation
                              {'popsize': pop_size})               # population size

epoch = 0
log_step = 3  # evaluate and possibly save every log_step generations
while not es.stop():
    if cur_best is not None and - cur_best > args.target_return:
        print("Already better than target, breaking...")
        break

    r_list = [0] * pop_size  # result list
    solutions = es.ask()

    # push parameters to queue: each candidate is evaluated n_samples times
    for s_id, s in enumerate(solutions):
        for _ in range(n_samples):
            p_queue.put((s_id, s))

    # retrieve results
    if args.display:
        pbar = tqdm(total=pop_size * n_samples)

    for _ in range(pop_size * n_samples):
        while r_queue.empty():
            sleep(.1)
        r_s_id, r = r_queue.get()
        # running average over the n_samples rollouts of candidate r_s_id
        r_list[r_s_id] += r / n_samples

        if args.display:
            pbar.update(1)

    if args.display:
        pbar.close()

    es.tell(solutions, r_list)
    es.disp()


    # evaluation and saving
    if epoch % log_step == log_step - 1:
        best_params, best, std_best = evaluate(solutions, r_list)
        print("Current evaluation: {}".format(best))
        # `is None` rather than `not cur_best`: a previous best of exactly 0.0
        # is a valid value and must not be overwritten by a worse result.
        if cur_best is None or cur_best > best:
            cur_best = best
            print("Saving new best with value {}+-{}...".format(-cur_best, std_best))
            load_parameters(best_params, controller)
            torch.save(
                {'epoch': epoch,
                 'reward': - cur_best,
                 'state_dict': controller.state_dict()},
                join(ctrl_dir, 'best.tar'))
        if - best > args.target_return:
            # report the reward itself, like the "Saving new best" message
            print("Terminating controller training with value {}...".format(-best))
            break


    epoch += 1

es.result_pretty()
# any item on e_queue tells the workers to stop
e_queue.put('EOP')

In [None]:
class RolloutGenerator(object):
    """ Utility to generate rollouts.

    Encapsulate everything that is needed to generate rollouts in the TRUE ENV
    using a controller with previously trained VAE and RNN.

    :attr vae: VAE model loaded from mdir/vae
    :attr rnn: RNN model loaded from mdir/rnn
    :attr controller: Controller, either loaded from mdir/cma or randomly
        initialized
    :attr env: instance of the CarRacing-v2 gym environment
    :attr device: device used to run VAE, RNN and Controller
    :attr time_limit: rollouts have a maximum of time_limit timesteps
    """
    def __init__(self, mdir, device, time_limit,
                 vae_ckpt='002k.pth.tar', rnn_ckpt='001k.pth.tar'):
        """ Build vae, rnn, controller and environment.

        :args mdir: directory holding the 'vae', 'rnn' and 'cma' sub-directories
        :args device: device on which the models are placed
        :args time_limit: maximum number of timesteps per rollout
        :args vae_ckpt: file name of the VAE checkpoint inside mdir/vae
        :args rnn_ckpt: file name of the RNN checkpoint inside mdir/rnn
        """
        # Plain paths. The previous `x = []; x = x.append(path)` left every
        # variable set to None because list.append returns None.
        vae_file = join(mdir, 'vae', vae_ckpt)
        rnn_file = join(mdir, 'rnn', rnn_ckpt)
        cma_file = join(mdir, 'cma', 'best.tar')

        assert exists(vae_file) and exists(rnn_file),\
            "Either vae or mdrnn is untrained."

        vae_state, rnn_state = [
            torch.load(fname, map_location=device)
            for fname in (vae_file, rnn_file)]

        # The training cells store only 'model' and 'optimizer', so 'epoch'
        # and 'precision' are reported only when a checkpoint provides them.
        for m, s in (('VAE', vae_state), ('RNN', rnn_state)):
            print("Loading {} at epoch {} "
                  "with test loss {}".format(
                      m, s.get('epoch', 'n/a'), s.get('precision', 'n/a')))

        self.vae = VAE(hp.vsize, 3).to(device)
        self.vae.load_state_dict(self._weights(vae_state))

        # NOTE(review): the RNN training cell builds RNN(..., hp.rnn_hunits)
        # while this class (and the Controller) use hp.msize. The two settings
        # must be equal for the checkpoint to load -- confirm the intended size.
        self.rnn = RNN(hp.vsize, hp.asize, hp.msize).to(device)
        # Keys are loaded unchanged: nn.LSTM parameter names end in '_l0', and
        # the former k.strip('_l0') removed those characters from both ends.
        self.rnn.load_state_dict(self._weights(rnn_state))

        self.controller = Controller(hp.vsize, hp.msize, hp.asize).to(device)

        # load controller if it was previously saved
        if exists(cma_file):
            cma_state = torch.load(cma_file, map_location=device)
            print("Loading Controller with reward {}".format(
                cma_state['reward']))
            self.controller.load_state_dict(cma_state['state_dict'])

        self.env = gym.make('CarRacing-v2', render_mode='human')
        self.device = device

        self.time_limit = time_limit

    @staticmethod
    def _weights(state):
        """ Return the weights stored under 'model', else under 'state_dict'. """
        return state['model'] if 'model' in state else state['state_dict']

    def get_action_and_transition(self, obs, hidden):
        """ Get action and transition.

        Encode obs to latent using the VAE encoder, compute the controller
        action from the latent and the current hidden state, then advance the
        RNN by one step on the (latent, action) pair.

        :args obs: current observation (1 x C x H x W) torch tensor
        :args hidden: LSTM state, tuple (h, c) of (1 x 1 x msize) torch tensors

        :returns: (action, next_hidden)
            - action: 1D np array
            - next_hidden: tuple (h, c) of (1 x 1 x msize) torch tensors
        """
        latent_mu, _ = self.vae.encoder(obs)  # (1, vsize)
        # Drop the num_layers axis of h so it can be concatenated along dim 1.
        action = self.controller(latent_mu, hidden[0].squeeze(0))  # (1, asize)
        # Same (latent, action) ordering as the RNN training inputs; the added
        # axis is the length-1 time dimension expected with batch_first=True.
        state = torch.cat([latent_mu, action], dim=-1).unsqueeze(1)
        _, _, _, next_hidden = self.rnn.infer(state, hidden)
        return action.squeeze().cpu().numpy(), next_hidden

    def rollout(self, params, render=False):
        """ Execute a rollout and returns minus cumulative reward.

        Load :params: into the controller and execute a single rollout. This
        is the main API of this class.

        :args params: parameters as a single 1D np array, or None to keep the
            controller's current parameters
        :args render: when True, call env.render() after every step

        :returns: minus cumulative reward
        """
        # copy params into the controller
        if params is not None:
            load_parameters(params, self.controller)

        obs = self.env.reset()
        if isinstance(obs, tuple):
            # gym >= 0.26 returns (observation, info) from reset()
            obs = obs[0]

        # This first render is required !
        self.env.render()

        # LSTM initial state (h, c), each (num_layers=1, batch=1, msize)
        hidden = tuple(
            torch.zeros(1, 1, hp.msize).to(self.device)
            for _ in range(2))

        cumulative = 0
        i = 0
        while True:
            obs = transform(obs).unsqueeze(0).to(self.device)
            action, hidden = self.get_action_and_transition(obs, hidden)

            step_result = self.env.step(action)
            if len(step_result) == 5:
                # gym >= 0.26: (obs, reward, terminated, truncated, info)
                obs, reward, terminated, truncated, _ = step_result
                done = terminated or truncated
            else:
                obs, reward, done, _ = step_result

            if render:
                self.env.render()

            cumulative += reward
            if done or i > self.time_limit:
                return - cumulative
            i += 1
In [None]:
""" Test controller """

# Options for the controller test run.
args = easydict.EasyDict({
    "logdir" : 'ckpt'
})

cma_file = join(args.logdir, 'cma', 'best.tar')

assert exists(cma_file),\
    "Controller was not trained..."

# Fall back to the CPU when CUDA is unavailable, like the training cells do;
# a hard-coded 'cuda' device fails on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

generator = RolloutGenerator(args.logdir, device, 1000)

# params=None keeps the controller parameters loaded from best.tar
with torch.no_grad():
    generator.rollout(None)