In [None]:
# Weights & Biases client used for experiment tracking throughout the notebook.
import wandb
# Called up-front, before any run is created, so a missing credential surfaces
# here rather than in the middle of the setup cells.
wandb.login()

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

In [None]:
class Policy(nn.Module):
    """Gaussian policy network: maps a state to the mean and standard deviation
    of a diagonal Normal distribution over actions.

    Args:
        in_dim: size of the state (input feature) vector.
        out_dim: size of the action vector.
        hidden_size: width of the two hidden layers.
        mean_scale: bound of the action mean; the mean lies in
            (-mean_scale, mean_scale).
        min_std: constant added to the standard deviation so it never reaches 0.
    """

    def __init__(self, in_dim, out_dim, hidden_size=32, mean_scale=1, min_std=1e-4):
        super().__init__()

        # Shared two-layer ELU trunk.
        self.linear = nn.Sequential(
            nn.Linear(in_dim, hidden_size),
            nn.ELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ELU(),
        )
        # One head producing 2 * out_dim values: first half -> mean, second half -> std.
        self.action = nn.Linear(hidden_size, 2 * out_dim)

        self._mean_scale = mean_scale
        self._min_std = min_std

    def forward(self, x):
        """Return ``(action_mean, action_std_dev)`` for the state(s) ``x``.

        ``x`` has the state features on its last axis; any number of leading
        batch axes (including none) is accepted and preserved in both outputs.
        """
        features = self.linear(x)
        raw_output = self.action(features)
        # Split on the last (feature) axis rather than on axis 1 so unbatched
        # and multi-batch-axis inputs work as well as the usual 2-D batch.
        raw_mean, raw_std = torch.chunk(raw_output, 2, dim=-1)

        # Scaled tanh keeps the mean inside (-mean_scale, mean_scale).
        action_mean = self._mean_scale * torch.tanh(raw_mean / self._mean_scale)
        # Softplus is non-negative; min_std keeps the scale strictly positive.
        action_std_dev = F.softplus(raw_std) + self._min_std
        return action_mean, action_std_dev

In [None]:
from __future__ import division
import os, sys
import warnings
# Select SDL's "dummy" video driver — presumably so any SDL-based rendering
# works on a machine without a display (TODO confirm it is still needed).
os.environ["SDL_VIDEODRIVER"] = "dummy"
# Make the parent directory importable; the `src.*` imports further down in
# this cell are presumably resolved through it.
sys.path.append("..")
# NOTE(review): silences *all* warnings, including deprecation notices.
warnings.filterwarnings('ignore')

import psutil
import gc
import gym
#import dmc2gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
import torch.optim as optim
from itertools import count
from gym import spaces
from torch.autograd import Variable
from tqdm import tqdm_notebook as tqdm
import matplotlib.pyplot as plt
import time
import glfw
from itertools import count
import copy

from src.networks import FeedForwardNet, FeedForwardNet_v2, EnvModel, TransitionModel, Critic
from src.replay_buffer import Replay_buffer
from src.utils import fanin_init, weights_init_normal, Average, freeze, unfreeze
import src.longenvs

In [None]:
# Experiment-level settings. They are also injected into the notebook's global
# namespace below, so later cells read them as plain variables.
exp_config = dict(
    max_episodes=2000,
    max_steps=5000,  # environment steps gathered per training episode
    domain_name="Humanoid-v4",  # previously tried: "HalfCheetah-v4"
    seed=86723146,
    free_runs=10,  # warm-up: the training loop acts randomly while episode < free_runs
    terminate_when_unhealthy=False,
    eval_interval=100,
)

# Agent hyper-parameters; LPModel exposes every key as an attribute.
agent_config = dict(
    policy_update_iterations=100,
    potential_update_iterations=100,
    model_update_iterations=1000,
    critic_update_iteration=700,
    top_perc=0.5,  # fraction of best-by-reward transitions handed to the replay buffer
    max_buffer=100000,  # replay buffer capacity (transitions)

    batch_size=1000,
    lr=3e-4,
    hidden_size=512,
    model_hidden_size=512,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    max_steps=exp_config['max_steps'],
)


# Expose max_episodes, max_steps, domain_name, seed, ... as module-level names.
globals().update(exp_config)
# The keyword is only forwarded when early termination is being disabled, so
# environments that do not accept it can still be built with the flag set to True.
env = (
    gym.make(domain_name)
    if terminate_when_unhealthy
    else gym.make(domain_name, terminate_when_unhealthy=False)
)


def seed_everything(seed, target_env=None):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global RNG, PyTorch (CPU and all
    CUDA devices) and the environment, and asks cuDNN for deterministic
    kernels.

    Args:
        seed: integer seed (must be valid for ``np.random.seed``).
        target_env: object exposing ``seed(int)``. Defaults to the
            module-level ``env`` so existing ``seed_everything(seed)`` calls
            keep working.
    """
    # Function-scope import: `random` is not imported at the top of the notebook.
    import random

    random.seed(seed)
    np.random.seed(seed)
    # Only affects interpreters started after this point; the hash seed of the
    # running process was fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    # Covers every visible GPU; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Auto-tuning may pick different kernels between runs, defeating determinism.
    torch.backends.cudnn.benchmark = False

    environment = env if target_env is None else target_env
    environment.seed(seed)

# Fix all global RNGs, then create a separate generator that supplies the
# per-reset environment seeds (decoupled from the global NumPy stream).
seed_everything(seed)
env_reset_rng = np.random.default_rng(seed=seed * 3)

# Identifier used for the W&B run name and for checkpoint file names.
agent_name = f'{domain_name}_{seed}_k'

# Problem dimensions read off the environment's spaces.
state_dim, action_dim = env.observation_space.shape[0], env.action_space.shape[0]
max_action = float(env.action_space.high[0])

# A checkpoint is written whenever the mean rollout reward beats the best
# logged score by more than best_log_step.
best_log_score, best_log_step = 0, 100

# Output locations (note: the plots directory is only named here, not created).
directory = './data/'
model_directory, plot_directory = directory + 'models/', directory + 'plots/'
for _output_dir in (directory, model_directory):
    os.makedirs(_output_dir, exist_ok=True)

# Monotonic counters feeding the custom "Steps" axis of the W&B logs.
train_ctr, eval_ctr = count(), count()

run = wandb.init(
    name=agent_name,
    project='puterman-rl',
    save_code=True,
    config=dict(exp_config, **agent_config),
)

In [None]:
class LPModel(object):
    """Agent bundling a stochastic policy, a potential network, a learned
    transition model, a reward critic and a replay buffer.

    Every key of ``config_dict`` becomes an instance attribute (``lr``,
    ``batch_size``, ``device``, ``max_steps``, ``hidden_size``, ...).

    The class relies on names defined elsewhere in the notebook: ``Policy``,
    the ``src`` helpers (``FeedForwardNet``, ``TransitionModel``, ``Critic``,
    ``Replay_buffer``, ``freeze``, ``unfreeze``), the checkpoint folder
    ``directory`` and the seed generator ``env_reset_rng``.

    Replay-buffer tuples are ordered ``(state, next_state, action, reward, done)``.
    """

    def __init__(self, state_dim, action_dim, config_dict):
        # Copy the hyper-parameters instead of aliasing the caller's dict.
        # Aliasing made the shared config double as the instance namespace, so
        # networks, optimizers and learning-rate changes leaked back into it.
        self.__dict__.update(config_dict)

        self.action_dim = action_dim
        self.state_dim = state_dim
        self.num_model_update_iteration = 0

        self.policy = Policy(state_dim, action_dim, self.hidden_size).to(self.device)
        self.policy_optimizer = torch.optim.AdamW(self.policy.parameters(), self.lr)

        # NOTE(review): constructed with (state_dim, action_dim, hidden) but only
        # ever called with a single next-state tensor — confirm FeedForwardNet's signature.
        self.potential = FeedForwardNet(state_dim, action_dim, int(self.hidden_size)).to(self.device)
        self.potential_optimizer = torch.optim.AdamW(self.potential.parameters(), self.lr)

        # The transition model and critic use a fixed 1e-3 learning rate, independent of self.lr.
        self.model = TransitionModel(state_dim, action_dim, 2*self.model_hidden_size, self.model_hidden_size).to(self.device)
        self.model_optimizer = optim.AdamW(self.model.parameters(), lr=1e-3)
        self.model_loss = nn.MSELoss()

        self.critic = Critic(state_dim, action_dim, 2*self.model_hidden_size).to(self.device)
        self.critic_optimizer = optim.AdamW(self.critic.parameters(), lr=1e-3)

        self.replay_buffer = Replay_buffer(self.max_buffer, top_perc=self.top_perc)

    def select_action(self, state, random_prob = 0):
        """Sample one action for a single ``state`` and return it as a flat NumPy array.

        The action is drawn from the policy's Normal distribution and clamped
        to [-1, 1]. ``random_prob`` is accepted for interface compatibility but
        is not used.
        """
        state = torch.from_numpy(np.array(state, dtype=np.float32)).reshape(1, -1).to(self.device)
        # Acting never needs gradients; avoid building an autograd graph per step.
        with torch.no_grad():
            action_mean, action_std_dev = self.policy(state)
            action = Normal(action_mean, action_std_dev).rsample()
            action = torch.clamp(action, -1, 1)
        return action.cpu().numpy().flatten()

    def _sample_fit_batch(self, it):
        """Batch for critic/model fitting: the most recent transitions on every
        100th iteration, a regular buffer sample otherwise."""
        if it % 100 == 0:
            return self.replay_buffer.sample_last(self.max_steps, self.max_steps)
        return self.replay_buffer.sample(self.batch_size)

    def _apply_lr(self):
        """Push the current ``self.lr`` into the policy and potential optimizers."""
        for optimizer in (self.policy_optimizer, self.potential_optimizer):
            for group in optimizer.param_groups:
                group['lr'] = self.lr

    def critic_update(self):
        """Regress the critic onto observed rewards; returns the list of per-iteration losses."""
        closses = []
        unfreeze(self.critic)
        for it in range(self.critic_update_iteration):
            x, _, u, r, _ = self._sample_fit_batch(it)
            state = torch.FloatTensor(x).to(self.device)
            action = torch.FloatTensor(u).to(self.device)
            reward = torch.FloatTensor(r).to(self.device)

            current_Q = self.critic(state, action)
            # NOTE(review): assumes the critic output and `reward` have matching
            # shapes; otherwise mse_loss broadcasts — confirm.
            critic_loss = F.mse_loss(current_Q, reward)
            # Detach so the stored history does not keep every autograd graph alive.
            closses.append(critic_loss.detach())

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()
        return closses

    def env_update(self):
        """Fit the transition model to (state, action) -> next_state; returns per-iteration losses."""
        mlosses = []
        unfreeze(self.model)
        for it in range(self.model_update_iterations):
            x, y, u, _, _ = self._sample_fit_batch(it)
            state = torch.FloatTensor(x).to(self.device)
            action = torch.FloatTensor(u).to(self.device)
            next_state = torch.FloatTensor(y).to(self.device)

            state_ = self.model(state, action)
            loss = self.model_loss(state_, next_state)

            # Detach so the stored history does not keep every autograd graph alive.
            mlosses.append(loss.detach())

            self.model_optimizer.zero_grad()
            loss.backward()
            self.model_optimizer.step()
            self.num_model_update_iteration += 1
        return mlosses

    def update(self, max_kl=1):
        """Run the policy updates, then the potential updates with a KL safeguard.

        Phase 1 maximises the critic's predicted reward of sampled actions.
        Phase 2 trains the potential on real next states versus next states
        predicted by the transition model under the current policy. After each
        potential step the KL divergence between the pre-update policy and the
        current policy is measured on the most recent transitions; a large KL
        restores the snapshot taken before the last policy step and, early on,
        shrinks the learning rate.

        Args:
            max_kl: KL threshold that triggers the rollback logic.

        Returns:
            ``(action, ploss, floss, prev_step_kl)`` — the last sampled action
            batch, the detached policy and potential loss histories, and the
            last accepted per-sample KL values.
        """
        ploss = []
        floss = []
        prev_step_kl = np.array([0])
        prev_policy = copy.deepcopy(self.policy)
        sd = self.state_dim
        for it in range(self.policy_update_iterations):
            # state_dict() returns references to the live tensors, so a real
            # snapshot needs a deep copy; otherwise "restoring" it is a no-op.
            prev_step_policy = copy.deepcopy(self.policy.state_dict())
            prev_step_opt = copy.deepcopy(self.policy_optimizer.state_dict())

            unfreeze(self.policy); freeze(self.potential)

            # Every 5th iteration uses the most recent transitions with fewer
            # action samples per state; otherwise a reward-sorted batch.
            if it % 5 == 0:
                bs = self.max_steps
                x, y, _, _, _ = self.replay_buffer.sample_last(bs, self.max_steps)
                qs = 20
            else:
                bs = self.batch_size
                x, y, _, _, _ = self.replay_buffer.sample_r_sorted(bs)
                qs = 100

            # Repeat every (state, next_state) pair qs times so that qs actions
            # are sampled per state. Uses self.* (not the global `agent`).
            next_state_w = torch.FloatTensor(y).to(self.device).reshape(bs, 1, sd).repeat(1, qs, 1).reshape(-1, sd)
            state = torch.FloatTensor(x).to(self.device).reshape(bs, 1, sd).repeat(1, qs, 1).reshape(-1, sd)
            action_mean, action_std_dev = self.policy(state)
            action = Normal(action_mean, action_std_dev).rsample()
            # Currently unused by the loss (the potential term below is disabled).
            next_state = self.model(state, action)
            r = self.critic(state, action)

            P_loss = -r.mean() #- self.potential(next_state).mean()
            self.policy_optimizer.zero_grad()
            P_loss.backward()
            self.policy_optimizer.step()
            ploss.append(P_loss.detach())

        for it in range(self.potential_update_iterations):
            freeze(self.policy); unfreeze(self.potential)

            gc.collect(); torch.cuda.empty_cache()

            # `state` / `next_state_w` are the batch of the last policy iteration.
            with torch.no_grad():
                action_mean, action_std_dev = self.policy(state)
                action = Normal(action_mean, action_std_dev).rsample()
                next_state = self.model(state, action)
            f_loss = self.potential(next_state_w).mean() - self.potential(next_state.detach()).mean()
            f_loss = -f_loss
            self.potential_optimizer.zero_grad()
            f_loss.backward()
            self.potential_optimizer.step()
            floss.append(f_loss.detach())

            gc.collect(); torch.cuda.empty_cache()

            # KL between the pre-update and current policy, used to roll back
            # and adjust the learning rate for stability.
            with torch.no_grad():
                recent_x = self.replay_buffer.sample_last(self.max_steps, self.max_steps)[0]
                # Separate name: reusing `state` here silently swapped the
                # batch used by the next potential iteration.
                kl_state = torch.FloatTensor(recent_x).to(self.device)
                mean0, std0 = prev_policy(kl_state)
                mean1, std1 = self.policy(kl_state)
                kl = torch.log(std1) - torch.log(std0) + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5
                kl = kl.sum(1, keepdim=True)
                gc.collect(); torch.cuda.empty_cache()

                mean_kl = kl.mean()
                if mean_kl > max_kl:
                    if it > 20:
                        self.policy.load_state_dict(prev_step_policy)
                        self.policy_optimizer.load_state_dict(prev_step_opt)
                        break
                    if mean_kl > 2*max_kl:
                        self.policy.load_state_dict(prev_step_policy)
                        self.policy_optimizer.load_state_dict(prev_step_opt)
                        self.lr *= 0.8
                        self._apply_lr()
                        break
                # NOTE(review): this multiplies lr by 1.25 on *every* iteration
                # past 80, compounding quickly — confirm this is intended.
                if it > 80:
                    self.lr *= 10/8
                    self._apply_lr()

                prev_step_kl = kl
        print(f'Policy updates:\t {it}, LR:\t {self.lr}, KL:\t {prev_step_kl.mean()}')

        return action, ploss, floss, prev_step_kl

    def evaluate(self, env):
        """Roll out 10 episodes with sampled actions; print and return (min, mean, max) episode reward."""
        eval_rewards = []
        for i_episode in range(10):
            ep_reward = 0
            state = env.reset(seed=int(env_reset_rng.integers(np.iinfo(np.int64).max)))
            while True:
                action = self.select_action(state)
                s_, r, done, info = env.step(action)
                ep_reward += r
                if done:
                    break
                state = s_
            eval_rewards.append(ep_reward)
        eval_rewards = np.array(eval_rewards)
        # Distinct names: the original locals shadowed the builtins min/max.
        reward_min, reward_mean, reward_max = eval_rewards.min(), eval_rewards.mean(), eval_rewards.max()
        print(f'Eval: min={reward_min}\tmean={reward_mean}\tmax={reward_max}')
        return reward_min, reward_mean, reward_max

    def save(self):
        """Write the policy, potential and transition-model weights under ``directory``.

        NOTE(review): the critic is not checkpointed.
        """
        torch.save(self.policy.state_dict(), directory + 'policy.pth')
        torch.save(self.potential.state_dict(), directory + 'potential.pth')
        torch.save(self.model.state_dict(), directory + 'model.pth')
        print("====================================")
        print("Model has been saved...")
        print("====================================")

    def load(self):
        """Load the weights written by ``save`` onto ``self.device``."""
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
        self.policy.load_state_dict(torch.load(directory + 'policy.pth', map_location=self.device))
        self.potential.load_state_dict(torch.load(directory + 'potential.pth', map_location=self.device))
        self.model.load_state_dict(torch.load(directory + 'model.pth', map_location=self.device))
        print("====================================")
        print("models has been loaded...")
        # This closing banner was dedented to module level and therefore ran
        # when the class cell executed instead of when load() was called.
        print("====================================")

In [None]:
# Re-seed all RNGs and rebuild the reset-seed generator right before the agent
# is created — presumably so that re-running from this cell reproduces the same
# initial weights and environment resets regardless of what executed earlier.
seed_everything(seed)
env_reset_rng = np.random.default_rng(seed=3*seed)

In [None]:
# Build the agent from the environment dimensions and the agent hyper-parameters.
agent = LPModel(state_dim, action_dim, agent_config)
# Checkpoint file names derived from the run identifier.
model_name = f'{agent_name}_model.pt'
policy_name = f'{agent_name}_best_policy.pt'

In [None]:
# Plot every logged metric against a custom "Steps" axis.
wandb.define_metric("Steps")
wandb.define_metric("*", step_metric="Steps")

total_step = 0
# Placeholder losses / KL reported during warm-up, before agent.update() has run.
ploss, floss, kl = [torch.tensor([0])], [torch.tensor([0])], torch.FloatTensor([0])
# NOTE(review): this table grows by one row per run and is re-logged every episode.
runs_data = wandb.Table(columns=["run_actions", "run_rewards"])
all_run_rewards = []

for episode in tqdm(range(max_episodes)):
    ep_rewards = []
    ep_steps = 0
    act_tmp = []
    rwd_tmp = []
    # Collect runs until at least max_steps environment steps were gathered.
    while ep_steps < max_steps:
        run_reward = 0
        act_tmp = []
        rwd_tmp = []

        state = env.reset(seed=int(env_reset_rng.integers(np.iinfo(np.int64).max)))

        for run_step in range(1000):  # for cheetah 10000
            if episode < free_runs:
                # Warm-up: uniform random actions within the action bounds.
                action = np.random.uniform(
                    low=float(env.action_space.low[0]),
                    high=float(env.action_space.high[0]),
                    size=(action_dim),
                )
            else:
                action = agent.select_action(state)
            next_state, reward, done, info = env.step(action)
            agent.replay_buffer.push(
                (state, next_state, action, reward, np.float32(done))
            )

            act_tmp.append(action)
            rwd_tmp.append(reward)

            state = next_state
            run_reward += reward

            total_step += 1

            if done:
                break
        # run_step is the zero-based index of the last executed step, so this
        # run contributed run_step + 1 transitions. Adding run_step alone
        # under-counted and triggered an extra run per episode.
        ep_steps += run_step + 1
        runs_data.add_data(np.array(act_tmp), np.array(rwd_tmp))
        ep_rewards.append(run_reward)

    # Action trace of the last run of this episode (one line per action dimension).
    plt.figure(figsize=(5, 2))
    plt.plot(act_tmp[:1000])
    plt.title(f"Actions, last run of episode {episode}")
    plt.xlabel("step")
    plt.ylabel("action")
    plt.show()
    print("Total T:{} Run Reward: \t{}".format(total_step, ep_rewards))

    ep_rewards = np.array(ep_rewards)
    all_run_rewards.append(ep_rewards)
    mean_reward = ep_rewards.mean()

    # Checkpoint the policy when the mean rollout reward improves by more than best_log_step.
    if mean_reward > best_log_score + best_log_step:
        best_log_score = mean_reward
        save_path = f"{model_directory}best_policy_{run.id}_{episode}_{int(mean_reward)}.pth"
        torch.save(agent.policy.state_dict(), save_path)
        run.log_model(path=save_path)

    closs = agent.critic_update()
    mloss = agent.env_update()
    # The policy/potential are only trained once the warm-up episodes are over.
    if episode >= free_runs:
        _, ploss, floss, kl = agent.update()

    # Reduce to plain Python floats once; reused for printing and logging.
    ploss_avg = Average(ploss).item()
    floss_avg = Average(floss).item()
    closs_avg = Average(closs).item()
    mloss_avg = Average(mloss).item()
    kl_mean = float(kl.mean())

    print(
        "Episode: \t{}  Ploss: \t{} Floss: \t{} KL: \t{} Closs: \t{} Mloss: \t{}".format(
            episode,
            ploss_avg,
            floss_avg,
            kl_mean,
            closs_avg,
            mloss_avg,
        )
    )
    wandb.log(
        {
            "Rollout/min": ep_rewards.min(),
            "Rollout/mean": mean_reward,
            "Rollout/max": ep_rewards.max(),
            "Ploss": ploss_avg,
            "Floss": floss_avg,
            "KL": kl_mean,
            "Closs": closs_avg,
            "Mloss": mloss_avg,
            "Episode_data": runs_data,
            "Steps": next(train_ctr),
        }
    )
    # Periodic evaluation is currently disabled; kept for reference.
    # if episode % eval_interval == 0:
    #     eval_min, eval_mean, eval_max = agent.evaluate(env)
    #     wandb.log({
    #         'Eval/min': eval_min,
    #         'Eval/mean': eval_mean,
    #         'Eval/max': eval_max,
    #         'Steps': next(eval_ctr)
    #         })
    #     print("--------------------------------")

# Release cached memory once training has finished.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Close the W&B run opened by wandb.init(...) in the configuration cell.
wandb.finish()