In [156]:
import os
import time

import matlab.engine
import numpy as np
import torch
from torch import nn
from torch.distributions import Normal
from torch.optim import Adam, SGD

In [158]:
# Mixing weight between the next value estimate and the next lambda-return
# (read by compute_lambda_returns_and_gae).
LAMBDA = 0.95
# Discount applied to the bootstrapped term in compute_lambda_returns_and_gae.
GAMMA = 0.99

# Learning rates passed to the Adam optimizers of the actor and the critic.
ACTOR_LR = 8e-4
CRITIC_LR = 4e-4

# Half-width of the ratio clipping interval: clamp(ratio, 1 - CLIP, 1 + CLIP).
CLIP = 0.2
# Weight of the entropy term subtracted from the actor loss.
ENTROPY_COEF = 2e-2
# Number of minibatch steps per PPO.update call, and transitions sampled per step.
BATCHES_PER_UPDATE = 64
BATCH_SIZE = 64

# Not referenced in the visible part of this notebook; presumably consumed by a
# training loop defined elsewhere — confirm.
EPISODES_PER_UPDATE = 20
ITERATIONS = 200


In [2]:
# Launch a MATLAB engine session; the cells below call MATLAB functions through `eng`.
eng = matlab.engine.start_matlab()

In [10]:
# Shut the MATLAB engine down.
# NOTE(review): in top-to-bottom order this runs before the cells that still use `eng`
# (its execution count, In [10], shows it was actually run later). Move it to the end
# of the notebook so that Restart & Run All works.
eng.quit()

In [3]:
# Sanity check that calls reach MATLAB: evaluate isprime(37) through the engine.
print(eng.isprime(37))

True


In [4]:
# Initial simulator state as returned by `default_state_PSS`, called through the engine
# (the function itself is not defined in this notebook).
state = eng.default_state_PSS()

In [21]:
# Wall-clock timing of 20 consecutive `sim_step` calls, each fed the state returned
# by the previous call.
# presumably the second argument (1.) is the control input or step size — confirm against sim_step
start_time = time.time()
for _ in range(20):
    state = eng.sim_step(state, 1.)
print("--- %s seconds ---" % (time.time() - start_time))

--- 163.74731183052063 seconds ---


In [5]:
# Same timing loop as the previous cell, with 100 steps instead of 20.
# NOTE(review): the saved output of this cell is a MatlabExecutionError
# ('Initial state vector "X0" must be a real vector of length 1'); re-run from a
# fresh kernel to check whether `state` is still valid input for `sim_step` here.
start_time = time.time()
for _ in range(100):
    state = eng.sim_step(state, 1.)
print("--- %s seconds ---" % (time.time() - start_time))

MatlabExecutionError: Initial state vector "X0" must be a real vector of length 1


In [67]:
# Random float32 dummy input of shape (1, 4, 60), used below to smoke-test the networks.
sample = torch.tensor(np.random.rand(1, 4, 60), dtype=torch.float)
sample.shape

torch.Size([1, 4, 60])

In [137]:
class OSBlock(nn.Module):
    """Multi-kernel 1-D convolution block with reflect padding.

    Three stages are applied in sequence. In each stage the input goes through
    several parallel ``Conv1d`` layers with different kernel sizes, the branch
    outputs are concatenated along the channel axis, and the result is passed
    through ``BatchNorm1d`` followed by ``ReLU``. Stages one and two use one
    4-channel branch per entry of ``kernels``; stage three uses kernel sizes
    1 and 3 with ``out_channels // 2`` channels each.

    Every branch pads by ``kernel // 2``, so odd kernel sizes preserve the
    sequence length. An even kernel size produces a sequence one element
    longer, which cannot be concatenated with the odd-kernel branches.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels of the output. Must be even because
            the last stage splits it across two branches.
        kernels: Kernel sizes of the parallel branches in stages one and two.
    """

    def __init__(self, in_channels, out_channels, kernels=(1, 3, 5, 7, 11)):
        super().__init__()
        assert out_channels % 2 == 0, "Number of out channels should be even!"
        # Copy into a fresh list so the instance never aliases the caller's
        # sequence (or a shared default).
        self.kernels = list(kernels)

        branch_channels = 4
        stage_channels = branch_channels * len(self.kernels)

        def make_branches(n_in, n_out, sizes):
            # One length-preserving convolution per kernel size.
            return nn.ModuleList([
                nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=size,
                          padding=size // 2, padding_mode='reflect')
                for size in sizes
            ])

        self.convs1 = make_branches(in_channels, branch_channels, self.kernels)
        self.convs2 = make_branches(stage_channels, branch_channels, self.kernels)
        self.batchnorm1 = nn.Sequential(nn.BatchNorm1d(stage_channels), nn.ReLU())
        self.batchnorm2 = nn.Sequential(nn.BatchNorm1d(stage_channels), nn.ReLU())
        self.convs3 = make_branches(stage_channels, out_channels // 2, [1, 3])
        self.batchnorm3 = nn.Sequential(nn.BatchNorm1d(out_channels), nn.ReLU())

    @staticmethod
    def _apply_branches(branches, x):
        """Run every branch on ``x`` and concatenate the results along the channel axis."""
        return torch.concat([branch(x) for branch in branches], dim=-2)

    def forward(self, state):
        """Apply the three multi-kernel stages and return the final activation."""
        x = self.batchnorm1(self._apply_branches(self.convs1, state))
        x = self.batchnorm2(self._apply_branches(self.convs2, x))
        return self.batchnorm3(self._apply_branches(self.convs3, x))


In [140]:
# Block mapping 4 input channels to 8 output channels.
blk = OSBlock(4, 8)

In [141]:
# Forward pass on the dummy input; the displayed shape lets us check that the
# sequence length (60) is preserved and the channel count becomes 8.
res = blk(sample)
res.shape

torch.Size([1, 8, 60])

In [142]:
# A second block can be chained on the first one's output (8 -> 4 channels).
OSBlock(8,4)(res).shape

torch.Size([1, 4, 60])

In [124]:
def count_parameters(model):
    """Return the total number of scalar entries in the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

In [125]:
# Trainable-parameter count of the 4 -> 8 block created above.
count_parameters(blk)

2884

In [132]:
# Trainable-parameter count of an actor with 5 action dimensions.
# NOTE(review): `Actor` is only defined in a later cell of this notebook, so this cell
# raises NameError on a fresh Restart & Run All; move it below the class definitions.
actor = Actor(5, [1, 1, 1, 1,1])
count_parameters(actor)

8439

In [134]:
# Adaptive average pooling of the block output down to a sequence length of 10.
# NOTE(review): the saved output shows 4 channels although `res` is recorded elsewhere
# with 8 channels, so this output looks stale — re-run in order.
pool = nn.AdaptiveAvgPool1d(10)
pool(res).shape

torch.Size([1, 4, 10])

In [143]:
# Shape of the first block's output.
res.shape

torch.Size([1, 8, 60])

In [161]:
def compute_lambda_returns_and_gae(trajectory, gamma=None, lam=None):
    """Compute lambda-returns and advantage estimates for one trajectory.

    The trajectory is walked backwards. For every step the lambda-return is

        G_t = r_t + gamma * ((1 - lam) * V_{t+1} + lam * G_{t+1})

    and the advantage is ``G_t - V_t``. Both ``V`` and ``G`` after the last
    transition are taken to be 0, i.e. nothing is bootstrapped past the end of
    the trajectory.

    Args:
        trajectory: Sequence of 5-tuples. The third element is read as the
            reward and the fifth as the value estimate; the first, second and
            fourth elements are passed through to the output unchanged.
        gamma: Discount factor. Defaults to the module-level ``GAMMA``.
        lam: Lambda mixing weight. Defaults to the module-level ``LAMBDA``.

    Returns:
        A list with one ``(s, a, p, lambda_return, advantage)`` tuple per input
        transition, in the original order, where ``s``, ``a`` and ``p`` are the
        first, second and fourth elements of the input transition.
    """
    # Resolve the defaults lazily so the module constants are only needed when
    # the caller does not supply its own values.
    gamma = GAMMA if gamma is None else gamma
    lam = LAMBDA if lam is None else lam

    lambda_returns = []
    advantages = []
    next_return = 0.
    next_value = 0.
    for _, _, reward, _, value in reversed(trajectory):
        next_return = reward + gamma * (next_value * (1 - lam) + next_return * lam)
        next_value = value
        lambda_returns.append(next_return)
        advantages.append(next_return - value)

    # The lists were filled back-to-front; restore chronological order.
    lambda_returns.reverse()
    advantages.reverse()

    # Each transition contains state, action, old action log-probability,
    # lambda-return (the critic's regression target) and advantage estimate.
    return [(s, a, p, ret, adv)
            for (s, a, _, p, _), ret, adv in zip(trajectory, lambda_returns, advantages)]


class OSBlock(nn.Module):
    """Multi-kernel 1-D convolution block with zero padding.

    NOTE(review): this redefinition shadows the ``OSBlock`` from the earlier
    cell, which is identical except that it uses ``padding_mode='reflect'``.
    Keep only one of the two definitions.

    Three stages are applied in sequence. In each stage the input goes through
    several parallel ``Conv1d`` layers with different kernel sizes, the branch
    outputs are concatenated along the channel axis, and the result is passed
    through ``BatchNorm1d`` followed by ``ReLU``. Stages one and two use one
    4-channel branch per entry of ``kernels``; stage three uses kernel sizes
    1 and 3 with ``out_channels // 2`` channels each.

    Every branch pads by ``kernel // 2``, so odd kernel sizes preserve the
    sequence length. An even kernel size produces a sequence one element
    longer, which cannot be concatenated with the odd-kernel branches.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels of the output. Must be even because
            the last stage splits it across two branches.
        kernels: Kernel sizes of the parallel branches in stages one and two.
    """

    def __init__(self, in_channels, out_channels, kernels=(1, 3, 5, 7, 11)):
        super().__init__()
        assert out_channels % 2 == 0, "Number of out channels should be even!"
        # Copy into a fresh list so the instance never aliases the caller's
        # sequence (or a shared default).
        self.kernels = list(kernels)

        per_branch = 4
        concatenated = per_branch * len(self.kernels)

        def conv_bank(n_in, n_out, sizes):
            # One length-preserving, zero-padded convolution per kernel size.
            return nn.ModuleList([
                nn.Conv1d(in_channels=n_in, out_channels=n_out,
                          kernel_size=size, padding=size // 2)
                for size in sizes
            ])

        self.convs1 = conv_bank(in_channels, per_branch, self.kernels)
        self.convs2 = conv_bank(concatenated, per_branch, self.kernels)
        self.batchnorm1 = nn.Sequential(nn.BatchNorm1d(concatenated), nn.ReLU())
        self.batchnorm2 = nn.Sequential(nn.BatchNorm1d(concatenated), nn.ReLU())
        self.convs3 = conv_bank(concatenated, out_channels // 2, [1, 3])
        self.batchnorm3 = nn.Sequential(nn.BatchNorm1d(out_channels), nn.ReLU())

    @staticmethod
    def _apply_bank(bank, x):
        """Run every convolution of ``bank`` on ``x`` and concatenate along the channel axis."""
        return torch.concat([conv(x) for conv in bank], dim=-2)

    def forward(self, state):
        """Apply the three multi-kernel stages and return the final activation."""
        hidden = self.batchnorm1(self._apply_bank(self.convs1, state))
        hidden = self.batchnorm2(self._apply_bank(self.convs2, hidden))
        return self.batchnorm3(self._apply_bank(self.convs3, hidden))


class Encoder(nn.Module):
    """Two stacked ``OSBlock`` stages followed by a flattening linear layer.

    The second block receives the sum of the first block's output and the
    original input (a skip connection), so both blocks keep ``in_channels``
    channels. The result is flattened per batch element and passed through a
    square ``Linear`` + ``ReLU`` of width ``hidden_size * in_channels``; the
    flattened size therefore has to equal that product.
    """

    def __init__(self, in_channels, hidden_size):
        super().__init__()
        self.enc1 = OSBlock(in_channels, in_channels)
        self.enc2 = OSBlock(in_channels, in_channels)
        self.hidden_size = hidden_size
        flat_width = hidden_size * in_channels
        self.linear = nn.Sequential(nn.Linear(flat_width, flat_width), nn.ReLU())

    def forward(self, state):
        """Encode ``state`` into a ``(batch, hidden_size * in_channels)`` feature matrix."""
        first_stage = self.enc1(state)
        # Skip connection: the raw input is added back before the second block.
        second_stage = self.enc2(first_stage + state)
        batch_size = second_stage.shape[0]
        flattened = second_stage.view(batch_size, -1)
        return self.linear(flattened)


class Actor(nn.Module):
    """Gaussian policy with a state-independent log standard deviation.

    The state is encoded by ``Encoder(4, 60)`` into 240 features, from which a
    small MLP predicts the mean of a diagonal Normal distribution. The raw
    sample is squashed with a sigmoid and multiplied by ``action_scaler``, so
    every action component lies between 0 and its scale.

    NOTE(review): the mean head ends with a ReLU, so the predicted mean can
    never be negative. Confirm this is intended; otherwise drop the last ReLU.

    Args:
        action_dim: Number of action components.
        action_scaler: Per-component multiplier applied after the sigmoid;
            anything accepted by ``torch.tensor`` that broadcasts against
            ``(batch, action_dim)``.
    """

    def __init__(self, action_dim, action_scaler):
        super().__init__()
        self.encoder = Encoder(4, 60)
        self.mean = torch.nn.Sequential(
            torch.nn.Linear(240, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, action_dim),
            torch.nn.ReLU(),
        )
        # Registered as a buffer so that `.to(device)` moves it together with
        # the parameters. It is non-persistent to keep the state_dict layout
        # unchanged, so existing checkpoints still load.
        self.register_buffer(
            "action_scaler",
            torch.tensor(action_scaler, dtype=torch.float32),
            persistent=False,
        )
        # One shared log-sigma per action component, independent of the state.
        self.sigma = nn.Parameter(torch.zeros(action_dim))

    def _distribution(self, state):
        """Return the Normal distribution over raw (pre-sigmoid) actions for ``state``."""
        latent = self.encoder(state)
        return Normal(self.mean(latent), torch.exp(self.sigma))

    def compute_proba(self, state, action):
        """Return the log-probability of ``action`` and the action distribution.

        Args:
            state: Batch of states accepted by the encoder.
            action: Raw (pre-sigmoid) actions, as returned second by ``act``.

        Returns:
            ``(log_prob, distribution)`` where ``log_prob`` is summed over the
            last (action) axis.
        """
        # No sample is drawn here, so evaluating log-probabilities does not
        # consume random numbers.
        distribution = self._distribution(state)
        return distribution.log_prob(action).sum(-1), distribution

    def act(self, state):
        """Sample an action for ``state``.

        Returns:
            ``(scaled_action, raw_action, distribution)``: the sigmoid-squashed
            and scaled action, the raw Gaussian sample it was derived from,
            and the distribution the sample was drawn from.
        """
        distribution = self._distribution(state)
        raw_action = distribution.sample()
        scaled_action = torch.sigmoid(raw_action) * self.action_scaler
        return scaled_action, raw_action, distribution


class Critic(nn.Module):
    """State-value network: ``Encoder(4, 60)`` followed by a 240 -> 64 -> 1 MLP.

    NOTE(review): the head ends with a ReLU, so predicted values are never
    negative. Confirm that returns in this task cannot be negative.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder(4, 60)
        head_layers = [
            nn.Linear(240, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.ReLU(),
        ]
        self.mean = nn.Sequential(*head_layers)

    def get_value(self, state):
        """Return the value estimate for each state, with a trailing axis of size 1."""
        features = self.encoder(state)
        estimate = self.mean(features)
        return estimate


class PPO:
    """Proximal Policy Optimization agent made of an ``Actor`` and a ``Critic``.

    Args:
        action_scaler: Per-component action scale forwarded to ``Actor``.
        action_dim: Number of action components.
        device: Device on which the networks are trained.
    """

    def __init__(self, action_scaler, action_dim, device='cpu'):
        self.device = device
        self.actor = Actor(action_dim, action_scaler).to(self.device)
        self.critic = Critic().to(self.device)
        self.actor_optim = Adam(self.actor.parameters(), ACTOR_LR, amsgrad=True)
        self.critic_optim = Adam(self.critic.parameters(), CRITIC_LR, amsgrad=True)

    def _module_device(self, module):
        """Return the device ``module`` currently lives on (``self.device`` if it has no parameters)."""
        first_param = next(module.parameters(), None)
        return self.device if first_param is None else first_param.device

    def _as_float_tensor(self, values, shape=None):
        """Convert a sequence of per-transition values into a float32 tensor on ``self.device``."""
        array = np.asarray(values)
        if shape is not None:
            array = array.reshape(shape)
        return torch.as_tensor(array, dtype=torch.float32, device=self.device)

    def update(self, trajectories):
        """Run ``BATCHES_PER_UPDATE`` minibatch steps on the actor and the critic.

        Args:
            trajectories: List of trajectories, each a list of
                ``(state, action, old_log_prob, target_value, advantage)``
                transitions.

        Returns:
            ``(mean_actor_loss, mean_critic_loss)`` over all minibatch steps.

        Raises:
            ValueError: If ``trajectories`` contains no transition.
        """
        transitions = [t for traj in trajectories for t in traj]  # Flatten the list of trajectories
        if not transitions:
            raise ValueError("update() needs at least one transition")
        n_transitions = len(transitions)

        state, action, old_prob, target_value, advantage = zip(*transitions)
        state = self._as_float_tensor(state)
        action = self._as_float_tensor(action)
        # Per-transition scalars are forced to shape (N,). If they arrived as
        # (N, 1), the ratio and the MSE below would silently broadcast to
        # (B, B). A size mismatch fails loudly in the reshape instead.
        old_prob = self._as_float_tensor(old_prob, (n_transitions,))
        target_value = self._as_float_tensor(target_value, (n_transitions,))
        advantage = np.asarray(advantage).reshape(n_transitions)
        advantage = self._as_float_tensor((advantage - advantage.mean()) / (advantage.std() + 1e-8))

        mse = nn.MSELoss()
        actor_loss_ls = []
        critic_loss_ls = []

        for _ in range(BATCHES_PER_UPDATE):
            # Minibatch sampled uniformly with replacement.
            idx = torch.randint(0, n_transitions, (BATCH_SIZE,)).to(self.device)
            s = state[idx]
            a = action[idx]
            op = old_prob[idx]  # Log-probability of the action under the data-collecting policy
            v = target_value[idx]  # Estimated by lambda-returns
            adv = advantage[idx]  # Estimated by generalized advantage estimation

            # Actor: clipped surrogate objective plus entropy bonus.
            log_prob, distribution = self.actor.compute_proba(s, a)
            ratio = torch.exp(log_prob - op)
            surr1 = ratio * adv
            surr2 = torch.clamp(ratio, 1 - CLIP, 1 + CLIP) * adv
            actor_loss = (-torch.min(surr1, surr2)).mean() - ENTROPY_COEF * distribution.entropy().mean()
            actor_loss_ls.append(actor_loss.item())
            self.actor_optim.zero_grad()
            actor_loss.backward()
            self.actor_optim.step()

            # Critic: regression onto the lambda-returns. Only the trailing
            # axis is squeezed so that a batch of one keeps its batch axis.
            critic_value = self.critic.get_value(s)
            critic_loss = mse(critic_value.squeeze(-1), v)
            critic_loss_ls.append(critic_loss.item())
            self.critic_optim.zero_grad()
            critic_loss.backward()
            self.critic_optim.step()
        return np.mean(actor_loss_ls), np.mean(critic_loss_ls)

    def get_value(self, state):
        """Return the critic's value estimate for a single state as a Python float."""
        with torch.no_grad():
            # Tensor.to is not in-place, so the result must be kept. The target
            # is the critic's current device, which differs from self.device
            # after perform().
            state = state.to(self._module_device(self.critic))
            value = self.critic.get_value(state)
        return value.item()

    def act(self, state):
        """Sample an action without tracking gradients.

        Returns:
            ``(action, pure_action, log_prob)``: the scaled action, the raw
            Gaussian sample, and the log-probability of that sample summed
            over the action axis.
        """
        with torch.no_grad():
            state = state.to(self._module_device(self.actor))
            action, pure_action, distr = self.actor.act(state)
            log_prob = distr.log_prob(pure_action).sum(-1)
        return action, pure_action, log_prob

    def save(self, name="agent.pkl", folder=""):
        """Save the actor weights to ``folder/name`` and the critic weights to ``folder/critic_<name>``."""
        torch.save(self.actor.state_dict(), os.path.join(folder, name))
        torch.save(self.critic.state_dict(), os.path.join(folder, 'critic_' + name))

    def load(self, name="agent.pkl", folder=""):
        """Load weights written by ``save`` and switch both networks to eval mode."""
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be restored on this agent's device.
        actor_state = torch.load(os.path.join(folder, name), map_location=self.device)
        critic_state = torch.load(os.path.join(folder, 'critic_' + name), map_location=self.device)
        self.actor.load_state_dict(actor_state)
        self.critic.load_state_dict(critic_state)
        self.actor.eval()
        self.critic.eval()
        self.actor.to(self.device)
        self.critic.to(self.device)

    def perform(self):
        """Move both networks to the CPU and put them in eval mode."""
        self.actor.to('cpu')
        self.critic.to('cpu')
        self.actor.eval()
        self.critic.eval()

    def train_(self):
        """Move both networks back to ``self.device`` and put them in training mode."""
        self.actor.to(self.device)
        self.critic.to(self.device)
        self.actor.train()
        self.critic.train()


In [162]:
# Agent with 5 action components, each scaled by 1; `device` keeps its default ('cpu').
ppo = PPO(action_dim=5, action_scaler=[1, 1, 1, 1, 1])

In [163]:
# Smoke test: sample one action for the dummy input. The result is
# (scaled action, raw Gaussian sample, summed log-probability of the sample).
ppo.act(sample)

(tensor([[0.6797, 0.2996, 0.8412, 0.3858, 0.8816]]),
 tensor([[ 0.7526, -0.8494,  1.6669, -0.4649,  2.0077]]),
 tensor([-8.6462]))