# Assignment 16: Policy Gradient

## 1) REINFORCE using pytorch interface

Let's build a class that is at the same time a PyTorch DNN (with a softmax output layer) and a Policy. REINFORCE will then simply compute traces and update the weights.

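For the weight update, we can use the SGD optimizer with learning rate $\alpha$ and perform backward passes on the loss $-\gamma^t G_t \log \pi(s_t, a_t)$; each SGD step then applies the update $\theta \leftarrow \theta + \alpha \gamma^t G_t \nabla_\theta \log \pi(s_t, a_t)$ from the course's pseudocode.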


In [5]:
import torch
from rl.distribution import Choose, Distribution
from rl.markov_decision_process import MarkovDecisionProcess, Policy

In [92]:
class SimpleSigmoidModel(torch.nn.Module):
    """Fully connected network with sigmoid hidden layers and a softmax head.

    The network is ``n_layers - 1`` hidden stages, each a square
    ``Linear(input_size, input_size)`` followed by ``Sigmoid``, and then one
    output stage ``Linear(input_size, num_actions)`` followed by a softmax
    over dimension 1. With ``n_layers <= 1`` only the output stage is built.

    Args:
        n_layers: total number of linear layers, output layer included.
        input_size: width of the input and of every hidden layer.
        num_actions: width of the softmax output.
    """

    def __init__(self, n_layers, input_size, num_actions):
        super().__init__()
        stages = []
        for _ in range(n_layers - 1):
            # The Sequential(Sequential(Linear), Sigmoid) nesting is kept on
            # purpose: it determines the parameter names in state_dict().
            linear_bloc = torch.nn.Sequential(
                torch.nn.Linear(input_size, input_size)
            )
            stages.append(torch.nn.Sequential(linear_bloc, torch.nn.Sigmoid()))
        output_head = torch.nn.Sequential(
            torch.nn.Linear(input_size, num_actions),
            torch.nn.Softmax(dim=1),
        )
        stages.append(output_head)
        self.model = torch.nn.Sequential(*stages)

    def forward(self, x):
        """Return the softmax output of the network for input ``x``."""
        return self.model(x)



class torchPolicyDiscrete(torch.nn.Module, Policy):
    """Softmax policy over a finite action set, backed by a torch network.

    The object is both a ``torch.nn.Module`` (its weights are trained with an
    SGD optimizer it owns) and a ``Policy``.

    Args:
        n_layers: number of linear layers of the underlying
            ``SimpleSigmoidModel``.
        feature_extractors: callables mapping a state to one scalar feature
            each; their count is the input width of the network.
        action_space: sized, re-iterable collection of hashable actions; the
            i-th action corresponds to the i-th softmax output.
        learning_rate: step size of the SGD optimizer.
    """

    def __init__(self, n_layers, feature_extractors, action_space, learning_rate):
        super().__init__()
        self.model = SimpleSigmoidModel(n_layers, len(feature_extractors), len(action_space))
        self.feature_extractors = feature_extractors
        self.action_space = action_space
        # Reverse lookup: action -> index of its softmax output.
        self.action_indexes = {a: i for i, a in enumerate(self.action_space)}
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=learning_rate)

    def forward(self, s_tensor):
        """Return action probabilities for an already-featurised state tensor."""
        return self.model(s_tensor)

    def extract_features(self, s):
        """Return the features of state ``s`` as a ``(1, n_features)`` tensor."""
        # Explicit floating dtype: integer-valued features would otherwise
        # yield an integer tensor that the linear layers cannot consume.
        return torch.tensor(
            [[phi_i(s) for phi_i in self.feature_extractors]],
            dtype=torch.get_default_dtype(),
        )

    @torch.no_grad()
    def distribution_action(self, s):
        """Return the action distribution of the policy in state ``s``.

        Gradients are disabled: this is only used to sample actions.
        """
        # The network consumes features, not raw states; use the same
        # preprocessing as update_params so acting and training agree.
        probas = self.model(self.extract_features(s)).numpy().flatten()
        # NOTE(review): confirm that Choose honours the per-action weights
        # passed here. If it samples uniformly from its options, a weighted
        # distribution (e.g. a categorical one) is required instead.
        return Choose(dict(zip(self.action_space, probas)))

    def act(self, s):
        """Sample one action for state ``s``."""
        # NOTE(review): this returns a sampled action; check whether the
        # Policy base class / mdp.simulate_actions expect a distribution.
        return self.distribution_action(s).sample()

    def update_params(self, s, a, G, gamma, t):
        """Apply one REINFORCE gradient step for the visit of ``(s, a)`` at time ``t``.

        Args:
            s: state visited at step ``t``.
            a: action taken in ``s``; must belong to ``action_space``.
            G: return observed from step ``t`` onwards.
            gamma: discount factor.
            t: index of the step inside the episode.

        Raises:
            KeyError: if ``a`` is not in ``action_space``.
        """
        a_index = self.action_indexes[a]
        s_tensor = self.extract_features(s)
        pi_s_a = self.forward(s_tensor)[0, a_index]
        # Minimising this loss with SGD (learning rate alpha) performs
        # theta <- theta + alpha * gamma^t * G * grad log pi(s, a).
        loss = -(gamma**t) * G * torch.log(pi_s_a)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()


def sample_one_episode_SAG(policy : torchPolicyDiscrete, mdp: MarkovDecisionProcess, init_state_distrib : Distribution, gamma : float):
    """Simulate one episode and return its ``(state, action, return)`` triples.

    Args:
        policy: policy used to pick actions during the simulation.
        mdp: process to simulate; ``mdp.simulate_actions`` must yield
            ``(state, action, next_state, reward)`` items that can be
            unpacked as a 4-tuple.
        init_state_distrib: distribution of the starting state.
        gamma: discount factor applied to future rewards.

    Returns:
        A list with one ``(s, a, G)`` per step, in time order, where ``G`` is
        the discounted sum of the rewards from that step to the end of the
        episode. Empty if the simulation yields no step.
    """
    from itertools import accumulate

    # The whole episode is materialised, so the simulation must terminate.
    steps = list(mdp.simulate_actions(start_states=init_state_distrib, policy=policy))

    # Returns obey G_t = r_t + gamma * G_{t+1}; accumulate them backwards
    # from the last reward, then flip back into time order.
    rewards_backward = [r for (_s, _a, _snext, r) in reversed(steps)]
    returns_backward = list(
        accumulate(rewards_backward, lambda G_next, r: r + G_next * gamma)
    )
    return [
        (s, a, G)
        for (s, a, _snext, _r), G in zip(steps, reversed(returns_backward))
    ]

def update_policy_with_episode(policy : torchPolicyDiscrete, sag_seq, gamma : float):
    """Run one policy update per step of an episode, in time order.

    Args:
        policy: object whose ``update_params`` is called for every step.
        sag_seq: iterable of ``(state, action, return)`` triples.
        gamma: discount factor forwarded to every update.
    """
    for step_index, visit in enumerate(sag_seq):
        state, action, episode_return = visit
        policy.update_params(state, action, episode_return, gamma, step_index)


def REINFORCE(policy : torchPolicyDiscrete,gamma : float, mdp : MarkovDecisionProcess, init_state_distrib : Distribution, n_episodes : int):
    """Train ``policy`` in place by sampling and replaying ``n_episodes`` episodes.

    Each iteration samples one episode with the current policy and then
    applies one update per step of that episode.

    Args:
        policy: policy to train; it is modified in place.
        gamma: discount factor.
        mdp: process the episodes are sampled from.
        init_state_distrib: distribution of the starting state of an episode.
        n_episodes: number of episodes to sample.
    """
    for _episode in range(n_episodes):
        # Sample with the current weights, then learn from that episode.
        update_policy_with_episode(
            policy,
            sample_one_episode_SAG(policy, mdp, init_state_distrib, gamma),
            gamma,
        )