In [1]:
from collections import namedtuple
import gym
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import TensorDataset, DataLoader

In [2]:
class policy_network(nn.Module):
    """Fully connected policy head.

    Stacks ``Linear -> ReLU`` pairs for every width in ``hidden_list`` and
    finishes with a ``Linear`` layer of ``action_size`` units; ``forward``
    turns those final activations into a probability distribution with a
    softmax over the last dimension.
    """

    def __init__(self, state_size, hidden_list, action_size):
        super().__init__()
        widths = [state_size, *hidden_list]
        stack = []
        # Consecutive width pairs define the hidden Linear layers.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        stack.append(nn.Linear(widths[-1], action_size))
        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        """Return action probabilities (rows sum to one) for input ``x``."""
        logits = x
        for module in self.layers:
            logits = module(logits)
        return F.softmax(logits, dim=-1)
        
class value_network(nn.Module):
    """Fully connected regressor with a linear (un-squashed) output.

    Hidden widths come from ``hidden_list``; each hidden ``Linear`` is
    followed by a ``ReLU`` and the final ``Linear`` maps to ``output_size``
    with no activation.
    """

    def __init__(self, state_size, hidden_list, output_size):
        super().__init__()
        dims = [state_size, *hidden_list, output_size]
        final_idx = len(dims) - 2
        self.layers = nn.ModuleList()
        for idx, (fan_in, fan_out) in enumerate(zip(dims, dims[1:])):
            self.layers.append(nn.Linear(fan_in, fan_out))
            # Every layer except the output layer is followed by a ReLU.
            if idx != final_idx:
                self.layers.append(nn.ReLU())

    def forward(self, x):
        """Apply the layer stack in order and return the raw output."""
        hidden = x
        for module in self.layers:
            hidden = module(hidden)
        return hidden
    
class transition_network(nn.Module):
    """MLP transition model: predicts ``output_size`` features from (state, action).

    The action is appended to the state before the first layer. When
    ``action_size == 2`` the action tensor itself is appended as one extra
    column; otherwise it is converted to a one-hot vector of width
    ``action_size``.

    Args:
        state_size: number of state features fed to the network.
        hidden_list: widths of the hidden layers (may be empty).
        action_size: number of discrete actions.
        output_size: number of features predicted by the network.
    """

    def __init__(self, state_size, hidden_list, action_size, output_size):
        super().__init__()
        self.action_dim = action_size
        self.output_size = output_size
        self.layers = nn.ModuleList()
        # Account for the action columns up front so that an empty
        # ``hidden_list`` still yields a first layer of the right width.
        prev_layer = state_size + (1 if action_size == 2 else action_size)
        for layer in hidden_list:
            self.layers.append(nn.Linear(prev_layer, layer))
            self.layers.append(nn.ReLU())
            prev_layer = layer
        self.layers.append(nn.Linear(prev_layer, output_size))

    def forward(self, x, a, eval_mode=False):
        """Predict the next features for states ``x`` under actions ``a``.

        Args:
            x: 2-D state tensor, one row per sample.
            a: actions, one row per sample with a single column. For
                ``action_size == 2`` it must be a tensor that can be
                concatenated with ``x``; otherwise it holds action indices
                (any numeric dtype, tensor or array-like).
            eval_mode: when true, the first ``3 * output_size`` columns of
                ``x`` are appended to the prediction.

        Returns:
            The prediction, optionally followed by the leading columns of ``x``.
        """
        if self.action_dim == 2:
            out = torch.cat((x, a), 1)
        else:
            # as_tensor/.to avoid the copy-construct warning that
            # torch.tensor(existing_tensor) raises; the one-hot buffer follows
            # x's dtype and device instead of being pinned to CPU float32.
            index = torch.as_tensor(a, device=x.device).to(torch.int64)
            one_hot = torch.zeros(x.size(0), self.action_dim, dtype=x.dtype, device=x.device)
            one_hot.scatter_(1, index, 1.0)
            out = torch.cat((x, one_hot), 1)
        for layer in self.layers:
            out = layer(out)
        if eval_mode:
            out = torch.cat([out, x[:, :3 * self.output_size]], 1)
        return out

class reward_network(nn.Module):
    """MLP reward model: predicts one scalar per (state, action) pair.

    The action is appended to the state before the first layer. When
    ``action_size == 2`` the action tensor itself is appended as one extra
    column; otherwise it is converted to a one-hot vector of width
    ``action_size``.

    Args:
        state_size: number of state features fed to the network.
        hidden_list: widths of the hidden layers (may be empty).
        action_size: number of discrete actions.
    """

    def __init__(self, state_size, hidden_list, action_size):
        super().__init__()
        self.action_dim = action_size
        self.layers = nn.ModuleList()
        # Account for the action columns up front so that an empty
        # ``hidden_list`` still yields a first layer of the right width.
        prev_layer = state_size + (1 if action_size == 2 else action_size)
        for layer in hidden_list:
            self.layers.append(nn.Linear(prev_layer, layer))
            self.layers.append(nn.ReLU())
            prev_layer = layer
        self.layers.append(nn.Linear(prev_layer, 1))

    def forward(self, x, a):
        """Return the predicted reward, one row with a single column per sample.

        Args:
            x: 2-D state tensor, one row per sample.
            a: actions, one row per sample with a single column. For
                ``action_size == 2`` it must be a tensor that can be
                concatenated with ``x``; otherwise it holds action indices
                (any numeric dtype, tensor or array-like).
        """
        if self.action_dim == 2:
            x = torch.cat((x, a), 1)
        else:
            # as_tensor/.to avoid the copy-construct warning that
            # torch.tensor(existing_tensor) raises; the one-hot buffer follows
            # x's dtype and device instead of being pinned to CPU float32.
            index = torch.as_tensor(a, device=x.device).to(torch.int64)
            one_hot = torch.zeros(x.size(0), self.action_dim, dtype=x.dtype, device=x.device)
            one_hot.scatter_(1, index, 1.0)
            x = torch.cat((x, one_hot), 1)
        for layer in self.layers:
            x = layer(x)
        return x
    
class VAE(nn.Module):
    """Single-hidden-layer variational autoencoder.

    The encoder maps ``state`` inputs through one ReLU layer to a mean and a
    log-variance of size ``latent``; the decoder maps a latent vector back
    through one ReLU layer and a sigmoid output of size ``state``.

    Attribute names (``e1``, ``e21``, ``e22``, ``d1``, ``d2``) are kept as-is
    so that previously saved state dicts continue to load.
    """

    def __init__(self, state, hidden, latent):
        super().__init__()
        self.e1 = nn.Linear(state, hidden)
        self.e21 = nn.Linear(hidden, latent)   # mean head
        self.e22 = nn.Linear(hidden, latent)   # log-variance head
        self.d1 = nn.Linear(latent, hidden)
        self.d2 = nn.Linear(hidden, state)
        self.latent = latent

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``."""
        out = F.relu(self.e1(x))
        return self.e21(out), self.e22(out)

    def reparametrize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        # randn_like follows std's dtype and device; the legacy
        # torch.FloatTensor constructor was fixed to CPU float32.
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map a latent vector to a reconstruction in ``(0, 1)``."""
        out = F.relu(self.d1(z))
        return torch.sigmoid(self.d2(out))

    def forward(self, x, encoding_only=False, training=True):
        """Encode ``x`` and optionally reconstruct it.

        Args:
            x: input batch.
            encoding_only: when true, return only the posterior mean ``mu``.
            training: when true, decode a sampled latent; otherwise decode
                the posterior mean.

        Returns:
            ``mu`` if ``encoding_only`` is true, else ``(decoded, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        if encoding_only:
            # Skip sampling and decoding: their result would be discarded,
            # and sampling would needlessly consume random numbers.
            return mu
        z = self.reparametrize(mu, logvar) if training else mu
        return self.decode(z), mu, logvar
        
def vae_loss(decoded_, x, mu, logvar, beta=1):
    """Return the summed beta-VAE loss: reconstruction BCE + ``beta`` * KL.

    Args:
        decoded_: reconstruction with values in [0, 1].
        x: target with the same shape as ``decoded_`` and values in [0, 1].
        mu: posterior means.
        logvar: posterior log-variances.
        beta: weight applied to the KL term.

    Returns:
        Scalar tensor; both terms are summed (not averaged) over all elements.
    """
    # reduction='sum' is the supported spelling of the deprecated
    # size_average=False.
    reconstruction = nn.BCELoss(reduction='sum')(decoded_, x)
    # KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction + beta * kl_divergence

In [3]:
class rollouts_wm(object):
    """Rollout storage for on-policy value targets plus a replay buffer.

    Three stores are maintained:

    * ``trajectory`` - transitions of the episode currently being collected;
    * ``rollout_memory`` / ``rollout_values`` - FIFO queues of transitions and
      their discounted-return targets, consumed in order by
      ``sample_data(True)``;
    * ``experience_buffer`` - the most recent ``experience_lenght``
      transitions, sampled uniformly by ``sample_data(False)``.

    Args:
        batch_size: number of transitions returned by ``sample_data``.
        discount: discount factor used for return targets.
        experience_lenght: capacity of ``experience_buffer``.
    """

    def __init__(self, batch_size, discount, experience_lenght):
        self.batch_size = batch_size
        self.rollout_memory = []
        self.rollout_values = []
        self.batch = []
        self.discount = discount
        self.transition = namedtuple('transition', ('state', 'action', 'next_state', 'reward', 'terminal'))
        self.experience_buffer = []
        self.experience_lenght = experience_lenght
        # Created here so push_to_trajectory works even before init_episode.
        self.trajectory = []

    def init_episode(self):
        """Start a new, empty trajectory."""
        self.trajectory = []

    def push_to_trajectory(self, state, action, next_state, reward, terminal):
        """Append one transition to the current trajectory."""
        self.trajectory.append(self.transition(state, action, next_state, reward, terminal))

    def monte_carlo(self, value_network=None):
        """Return the discounted return of every step of the current trajectory.

        With ``value_network`` given, the return is bootstrapped with the
        network's value of the last transition's ``next_state`` (evaluated
        without gradients); otherwise the tail value is zero.

        Computed in one backward pass, ``G_t = r_t + discount * G_{t+1}``,
        instead of re-summing the remaining rewards for every step.

        Raises:
            IndexError: if ``value_network`` is given and the trajectory is empty.
        """
        if value_network is None:
            running = 0
        else:
            final_state = self.trajectory[-1].next_state
            with torch.no_grad():
                running = value_network.forward(final_state)
        values = []
        for step in reversed(self.trajectory):
            running = step.reward + self.discount * running
            values.append(running)
        values.reverse()
        return values

    def push_to_memory(self, value_network=None):
        """Move the current trajectory and its return targets into the stores.

        ``value_network`` is forwarded to ``monte_carlo``. The replay buffer
        is trimmed from the front so that it keeps exactly the newest
        ``experience_lenght`` transitions.
        """
        values = self.monte_carlo(value_network)
        self.rollout_memory = self.rollout_memory + self.trajectory
        self.rollout_values = self.rollout_values + values
        self.experience_buffer = self.experience_buffer + self.trajectory
        overflow = len(self.experience_buffer) - self.experience_lenght
        if overflow > 0:
            del self.experience_buffer[:overflow]

    def sample_data(self, policy_mode=True):
        """Return a training batch.

        Args:
            policy_mode: when true, pop the oldest ``batch_size`` entries from
                the rollout queues and return ``(state, values)``; when false,
                draw ``batch_size`` transitions uniformly without replacement
                from the replay buffer and return
                ``(state, action, reward, new_state)``.

        Raises:
            ValueError: in replay mode when the buffer holds fewer than
                ``batch_size`` transitions (raised by ``random.sample``).
        """
        if policy_mode:
            batch = self.transition(*zip(*self.rollout_memory[:self.batch_size]))
            self.rollout_memory = self.rollout_memory[self.batch_size:]
            values = torch.cat(self.rollout_values[:self.batch_size])
            self.rollout_values = self.rollout_values[self.batch_size:]
            state = torch.cat(batch.state)
            return state, values
        sample_ = random.sample(self.experience_buffer, self.batch_size)
        batch = self.transition(*zip(*sample_))
        # The stored ``terminal`` field is not applied to the rewards here.
        state = torch.cat(batch.state)
        action = torch.cat(batch.action)
        reward = torch.cat(batch.reward)
        new_state = torch.cat(batch.next_state)
        return state, action, reward, new_state

In [4]:
class wmpg(object):
    """Policy-gradient agent trained on returns imagined by a learned world model.

    The agent owns a policy network, a value network, a transition model and
    a reward model, plus a frozen encoder that turns preprocessed frames into
    latent vectors. Real transitions are stored in a ``rollouts_wm`` memory;
    the world model is fitted on replayed transitions, the value network on
    discounted returns, and the policy on action values obtained by rolling
    the world model forward for ``imagination_horizon`` steps.

    Args:
        state_size: width of the state vector fed to the networks.
        action_size: number of discrete actions.
        batch_size: batch size used for every update.
        discount: discount factor.
        policy_net, value_net, transition_net, reward_net: the networks; each
            must expose ``forward`` and ``parameters``.
        encoder: object with ``forward(x, encoding_only, training)`` and a
            ``latent`` attribute.
        value_updates: value-network gradient steps per ``train_value`` call.
        policy_updates: outer iterations per ``train_networks`` call.
        wm_updates: world-model gradient steps per ``train_wm`` call.
        imagination_horizon: number of imagined steps.
        imagination_lambda: mixing coefficient of the lambda-return; ``1``
            selects the plain n-step return.
        wm_memory: capacity of the replay buffer.
        policy_lr, value_lr, transition_lr, reward_lr: learning rates.
        clip: max gradient norm for policy/value updates, or ``None``.
        scheduler_step: ``StepLR`` step size (gamma 0.99), or ``None`` for no
            schedulers.
        termination_length: episodes with at least this many steps get their
            return targets bootstrapped with the value network; ``None``
            disables bootstrapping.
        k: ``None`` selects ``train_policy`` (expectation over all actions),
            ``1`` selects ``train_policy_1`` (one sampled action). Any other
            value performs no policy update.
        entropy_coef: weight of the entropy bonus, or ``None``.
    """

    def __init__(self, state_size, action_size, batch_size, discount, policy_net, value_net, transition_net, reward_net, encoder,
                 value_updates=3, policy_updates=3, wm_updates=3, imagination_horizon=5, imagination_lambda=0.75,
                 wm_memory=2000, policy_lr=0.0025, value_lr=0.0025, transition_lr=0.005, reward_lr=0.005, clip=None, 
                 scheduler_step=None, termination_length=None, k=None, entropy_coef=None):
        self.memory = rollouts_wm(batch_size, discount, wm_memory)
        self.policy_net = policy_net
        self.value_net = value_net
        self.transition_net = transition_net
        self.reward_net = reward_net
        self.o_p = optim.RMSprop(self.policy_net.parameters(), lr=policy_lr, eps=1e-5)
        self.o_v = optim.RMSprop(self.value_net.parameters(), lr=value_lr, eps=1e-5)
        self.o_t = optim.Adam(self.transition_net.parameters(), lr=transition_lr)
        self.o_r = optim.Adam(self.reward_net.parameters(), lr=reward_lr)
        self.loss_v = torch.nn.L1Loss()
        self.loss_t = torch.nn.L1Loss()
        self.loss_r = torch.nn.L1Loss()
        self.state_size = state_size
        self.action_size = action_size
        self.batch_size = batch_size
        self.discount = discount
        self.value_updates = int(value_updates)
        self.wm_updates = int(wm_updates)
        self.policy_updates = int(policy_updates)
        self.imagination_horizon = imagination_horizon
        self.imagination_lambda = imagination_lambda
        self.cum_rewards = 0
        self.clip = clip
        if scheduler_step is not None:
            self.scheduler_t = optim.lr_scheduler.StepLR(self.o_t, step_size=scheduler_step, gamma=0.99)
            self.scheduler_r = optim.lr_scheduler.StepLR(self.o_r, step_size=scheduler_step, gamma=0.99)
            self.scheduler_v = optim.lr_scheduler.StepLR(self.o_v, step_size=scheduler_step, gamma=0.99)
            self.scheduler_p = optim.lr_scheduler.StepLR(self.o_p, step_size=scheduler_step, gamma=0.99)
        self.scheduler_step = scheduler_step
        self.termination_length = termination_length
        self.k = k
        self.encoder = encoder
        self.entropy_coef = entropy_coef

    def train_wm(self):
        """Run ``wm_updates`` L1-loss steps on the transition and reward models."""
        for _ in range(self.wm_updates):
            states, actions, rewards, new_states = self.memory.sample_data(False)
            self.o_t.zero_grad()
            transitions = self.transition_net.forward(states, actions)
            # The transition model predicts only the leading columns of the
            # next state; slice the target to the prediction's width rather
            # than a hard-coded 16.
            loss_t_ = self.loss_t(transitions, new_states[:, :transitions.size(1)])
            loss_t_.backward()
            self.o_t.step()
            self.o_r.zero_grad()
            rewards_ = self.reward_net.forward(states, actions)
            loss_r_ = self.loss_r(rewards_, rewards)
            loss_r_.backward()
            self.o_r.step()

    def train_value(self, states, values):
        """Run ``value_updates`` L1-loss steps of the value network towards ``values``."""
        for _ in range(self.value_updates):
            self.o_v.zero_grad()
            values_ = self.value_net.forward(states)
            loss_v_ = self.loss_v(values_, values)
            loss_v_.backward()
            if self.clip is not None:
                torch.nn.utils.clip_grad_norm_(self.value_net.parameters(), self.clip)
            self.o_v.step()

    def _policy_step(self, loss_pol, probabilities):
        """Subtract the optional entropy bonus, backpropagate, clip and step the policy."""
        if self.entropy_coef is not None:
            loss_pol = loss_pol - self.entropy_coef * self.entropy(probabilities)
        loss_pol.backward()
        if self.clip is not None:
            torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), self.clip)
        self.o_p.step()

    def train_policy(self, states):
        """One policy step maximising the expected imagined value over all actions.

        For horizons other than 1 the policy-weighted mean of the imagined
        values (computed with detached probabilities) is subtracted as a
        baseline.
        """
        self.o_p.zero_grad()
        probabilities = self.policy_net.forward(states)
        q_values = self.imagine_values(states)
        if self.imagination_horizon != 1:
            baseline = torch.sum(q_values * probabilities.detach(), dim=1).unsqueeze(-1)
            expected_values = torch.sum((q_values - baseline) * probabilities, dim=1)
        else:
            expected_values = torch.sum(q_values * probabilities, dim=1)
        self._policy_step(-torch.mean(expected_values, dim=0), probabilities)

    def schedulers_step(self):
        """Advance all four learning-rate schedulers, if they were created."""
        if self.scheduler_step is not None:
            self.scheduler_t.step()
            self.scheduler_r.step()
            self.scheduler_v.step()
            self.scheduler_p.step()

    def train_networks(self):
        """Pop one on-policy batch and run ``policy_updates`` rounds of updates.

        Each round trains the world model, the value network and the policy
        (variant chosen by ``k``); schedulers are stepped once at the end.
        """
        state_batch, value_batch = self.memory.sample_data()
        for _ in range(self.policy_updates):
            self.train_wm()
            self.train_value(state_batch, value_batch)
            if self.k is None:
                self.train_policy(state_batch)
            if self.k == 1:
                self.train_policy_1(state_batch)
        self.schedulers_step()

    def _imagined_returns(self, states, actions):
        """Roll the world model forward and return one imagined return per row.

        The first imagined step uses ``actions``; later steps sample actions
        from the policy at the imagined state. With ``imagination_lambda == 1``
        the result is the discounted reward sum plus the discounted value of
        the final imagined state; otherwise it is the lambda-weighted mix of
        the 1..H-step bootstrapped returns. Runs entirely without gradients.

        An action is sampled after every imagined step (including the last)
        so the sequence of random draws matches one draw per step.
        """
        batch = states.size(0)
        transitions = states
        running_rewards = torch.zeros([batch, 1], dtype=torch.float32)
        running_lambdas = torch.zeros([batch, 1], dtype=torch.float32)
        use_lambda = self.imagination_lambda != 1
        with torch.no_grad():
            for j in range(self.imagination_horizon):
                rewards = self.reward_net.forward(transitions, actions)
                running_rewards += self.discount**j * rewards
                transitions = self.transition_net.forward(transitions, actions, True)
                if use_lambda:
                    values = running_rewards + self.discount**(j+1) * self.value_net.forward(transitions)
                    if (j+1) == self.imagination_horizon:
                        # The last n-step return takes all remaining weight.
                        running_lambdas += self.imagination_lambda**(j)*values
                    else:
                        running_lambdas += (1 - self.imagination_lambda)*self.imagination_lambda**j*values
                probabilities = self.policy_net.forward(transitions).detach().numpy()
                actions = self.sample_from_matrix(probabilities)
            if not use_lambda:
                running_rewards += self.discount**(self.imagination_horizon) * self.value_net.forward(transitions)
        return running_lambdas if use_lambda else running_rewards

    def imagine_values(self, states):
        """Return imagined action values, one column per action, for ``states``."""
        batch = states.size(0)
        q_values = torch.zeros((batch, self.action_size), dtype=torch.float32)
        for i in range(self.action_size):
            first_actions = torch.zeros([batch, 1], dtype=torch.float32)+i
            q_values[:,i] = self._imagined_returns(states, first_actions).squeeze()
        return q_values

    def train_policy_1(self, states):
        """One score-function policy step using a single sampled action per state.

        The advantage is the imagined return of the sampled action minus the
        value network's estimate of the state.
        """
        self.o_p.zero_grad()
        probabilities = self.policy_net.forward(states)
        actions = self.sample_from_matrix(probabilities.detach().numpy())
        # actions is already a tensor; cast it instead of re-wrapping it.
        log_probabilities = torch.log(probabilities.gather(1, actions.long()))
        q_values = self.imagine_values_1(states, actions)
        with torch.no_grad():
            baseline = self.value_net.forward(states).detach()
        expected_values = torch.sum((q_values - baseline) * log_probabilities, dim=1)
        self._policy_step(-torch.mean(expected_values, dim=0), probabilities)

    def imagine_values_1(self, states, actions):
        """Return the imagined return, as a single column, of taking ``actions`` in ``states``."""
        q_values = torch.zeros((states.size(0), 1), dtype=torch.float32)
        q_values[:,0] = self._imagined_returns(states, actions).squeeze()
        return q_values

    def sample_from_matrix(self, probs):
        """Sample one action index per row of a probability matrix.

        Args:
            probs: 2-D NumPy array whose rows are action distributions.

        Returns:
            Float tensor with one column holding the sampled indices.
        """
        cumulative = probs.cumsum(axis=1)
        uniform_samples = np.random.rand(len(cumulative), 1)
        hits = uniform_samples < cumulative
        # If rounding leaves a row's cumulative sum below the draw, no entry
        # is hit; fall back to the last action rather than argmax's index 0.
        samples = np.where(hits.any(axis=1), hits.argmax(axis=1), hits.shape[1] - 1).astype(np.float32)
        samples = torch.tensor(samples.reshape(np.shape(samples)[0],1))
        return samples.float()

    def preprocess(self, state):
        """Crop, downsample and binarise a frame; return it flattened as float64.

        Rows 35..194 are kept, every second row/column of channel 0 is taken,
        pixel values 144 and 109 are zeroed (presumably background colours -
        confirm against the environment) and everything else becomes 1. The
        input array is left unmodified.
        """
        # Copy so the in-place masking below does not write through the
        # slicing views into the caller's frame.
        state = state[35:195:2, ::2, 0].copy()
        state[state == 144] = 0
        state[state == 109] = 0
        state[state != 0] = 1
        state = self.enlarge_ball(state)
        # np.float was removed from NumPy; np.float64 is the same dtype.
        return state.astype(np.float64).ravel()

    def enlarge_ball(self, image):
        """Return a copy of ``image`` with small isolated blobs thickened.

        A set pixel whose left and right neighbours are unset and whose lower
        neighbour is set is expanded to three columns, covering two rows
        above it and three rows below it where those rows exist. Indices
        assume an 80x80 image.
        """
        image = np.copy(image)
        for i in range(1,79):
            for j in range(1,79):
                if image[i,j]==1:
                    if image[i,j+1]==0 and image[i,j-1]==0:
                        if image[i+1,j]==1:
                            image[i,j-1] = 1
                            image[i+1,j-1] = 1
                            image[i,j+1] = 1
                            image[i+1,j+1] = 1
                            if i!=0:
                                image[i-1,j-1:j+2] = 1
                                if i!=1:
                                    image[i-2,j-1:j+2] = 1
                            if i!=78:
                                image[i+2,j-1:j+2] = 1
                                if i!=77:
                                    image[i+3,j-1:j+2] = 1
        return image

    def training(self, env, episodes):
        """Interact with ``env`` for ``episodes`` episodes, training along the way.

        The state is the concatenation of the encodings of the current and
        three previous frames. Policy outputs are shifted by +2 before being
        sent to the environment. After a step with reward +1 or -1, the next
        18 transitions are not stored. Networks are trained whenever the
        rollout queue holds more than ``batch_size`` transitions.

        NOTE(review): uses the classic gym API (``reset()`` returning the
        observation, ``step()`` returning a 4-tuple) - confirm the installed
        gym version.

        Returns:
            NumPy array with the undiscounted reward of every episode.
        """
        results = np.zeros(episodes)
        skip_idx = 19
        for i in range(episodes):
            observation = env.reset()
            self.memory.init_episode()
            curr_observation = self.preprocess(observation)
            curr_observation = torch.tensor(curr_observation, dtype=torch.float32).reshape(1, 6400)
            prev_observation1 = torch.zeros((1, self.encoder.latent), dtype=torch.float32)
            prev_observation2 = torch.zeros((1, self.encoder.latent), dtype=torch.float32)
            prev_observation3 = torch.zeros((1, self.encoder.latent), dtype=torch.float32)
            episode_reward = 0
            steps = 0
            with torch.no_grad():
                curr_observation = self.encoder.forward(curr_observation, True, False).detach()
            state = torch.cat([curr_observation, prev_observation1, prev_observation2, prev_observation3], 1)
            while True:
                with torch.no_grad():
                    probabilities = self.policy_net.forward(state).detach().numpy()
                    action = torch.tensor(np.random.choice(self.action_size, p=probabilities.flatten()), dtype=torch.float32).reshape(1,1)
                action_translated = action + 2
                new_observation, reward, terminal, _ = env.step(int(action_translated.item()))
                new_observation = self.preprocess(new_observation)
                new_observation = torch.tensor(new_observation, dtype=torch.float32).reshape(1, 6400)
                with torch.no_grad():
                    new_observation = self.encoder.forward(new_observation, True, False).detach()
                new_state = torch.cat([new_observation, curr_observation, prev_observation1, prev_observation2], 1)
                episode_reward += reward
                steps += 1
                # Stored as a continuation mask: 0 on the terminal step, 1 otherwise.
                done = torch.zeros([1, 1], dtype=torch.float32) if terminal else torch.ones([1, 1], dtype=torch.float32)
                if reward == 1 or reward == -1:
                    skip_idx = 0
                reward = torch.zeros([1,1], dtype=torch.float32) + reward
                # new_state is already a tensor; convert/reshape it instead of
                # re-wrapping it with torch.tensor.
                new_state = new_state.detach().to(torch.float32).reshape(1, self.state_size)
                if skip_idx > 18 or skip_idx == 0:
                    self.memory.push_to_trajectory(state, action, new_state, reward, done)
                if len(self.memory.rollout_memory) > self.memory.batch_size:
                    self.train_networks()
                state = new_state
                prev_observation3 = prev_observation2
                prev_observation2 = prev_observation1
                prev_observation1 = curr_observation
                curr_observation = new_observation
                skip_idx += 1
                if terminal:
                    results[i] = episode_reward
                    self.cum_rewards += episode_reward
                    if self.termination_length is not None:
                        if steps < self.termination_length:
                            self.memory.push_to_memory()
                        else:
                            self.memory.push_to_memory(self.value_net)
                    if self.termination_length is None:
                        self.memory.push_to_memory()
                    print("\rEp: {} Online reward: {:.2f}; Steps: {}".format(i + 1, episode_reward, steps), end="")
                    break
        return np.array(results)

    def entropy(self, p_matrix):
        """Return the mean row-wise Shannon entropy (in nats) of ``p_matrix``."""
        # Clamp before the log so an exact zero probability contributes 0
        # instead of 0 * -inf = NaN.
        log_probs = torch.log(p_matrix.clamp_min(1e-8))
        entropy = torch.sum(-p_matrix*log_probs, 1)
        return torch.mean(entropy)

In [None]:
# --- Experiment configuration and launch ------------------------------------
# Builds the environment, the networks and the agent, runs training, and
# saves the per-episode rewards to 'wmpg_pong.csv'.
env = gym.make("PongDeterministic-v4")
# NOTE(review): sets a `frameskip` attribute on the object returned by
# gym.make; whether this changes the env's actual frame skip depends on the
# gym version/wrappers - confirm.
env.frameskip = 4
latent = 16          # VAE latent size; the agent state is 4 stacked latents (latent*4)
action_size = 2      # policy outputs 0/1; the agent adds 2 before calling env.step
SEED = 10
batch_size = 512
discount = 0.99
hidden_p = [512]     # hidden widths of the policy network
hidden_v = [512]     # hidden widths of the value network
# NOTE(review): 1028 may be a typo for 1024 - confirm.
hidden_t = [1028]    # hidden widths of the transition network
hidden_r = [1028]    # hidden widths of the reward network
i_v = 1              # passed as value_updates
i_g = 5              # passed as policy_updates
i_wm = 3             # passed as wm_updates
horizon = 5          # passed as imagination_horizon
lambda_ = 0.99       # passed as imagination_lambda
lr_p = 0.001
lr_v = 0.001
lr_t = 0.002
lr_r = 0.002
episodes = 1000
clip = 1             # max gradient norm for policy/value updates
step = 10            # passed as scheduler_step (StepLR step size)
termination_length = 100   # episodes with >= this many steps bootstrap returns with the value net
k = 1                # k == 1 makes train_networks call train_policy_1

# Seed every RNG before the networks are constructed so that weight
# initialisation and action sampling are repeatable.
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

# Encoder weights are read from a local checkpoint file.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
encoder = VAE(6400, 512, latent)
encoder.load_state_dict(torch.load('VAE_MLP_512_16'))

policy_net = policy_network(latent*4, hidden_p, action_size)
value_net = value_network(latent*4, hidden_v, 1)
reward_net = reward_network(latent*4, hidden_r, action_size)
transition_net = transition_network(latent*4, hidden_t, action_size, latent)
agent = wmpg(4*latent, action_size, batch_size, discount, policy_net, value_net, transition_net, reward_net, encoder, value_updates=i_v, 
             policy_updates=i_g, wm_updates=i_wm, imagination_horizon=horizon, imagination_lambda=lambda_, wm_memory=5000, policy_lr=lr_p, 
             value_lr=lr_v, transition_lr=lr_t, reward_lr=lr_r, clip=clip, scheduler_step=step, termination_length=termination_length, k=k, entropy_coef=0.01)
results = agent.training(env, episodes)
np.savetxt('wmpg_pong.csv', results, delimiter=',')



Ep: 1 Online reward: -20.00; Steps: 1041



Ep: 3 Online reward: -21.00; Steps: 764