<a href="https://colab.research.google.com/github/itsmepriyabrata/priyabrata_ai_python/blob/main/Reinforcement_learning_part_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Policy Gradients**

In [None]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

# Shared CartPole environment for this cell: PolicyNetwork reads its
# observation/action space sizes, and the training loops below step it.
env = gym.make('CartPole-v0')

class PolicyNetwork(nn.Module):
    """Two-layer MLP that maps an observation to one raw score per action.

    Args:
        obs_dim: Size of the observation vector. When ``None`` it is read
            from the module-level ``env`` (``env.observation_space.shape[0]``).
        n_actions: Number of output units. When ``None`` it is read from the
            module-level ``env`` (``env.action_space.n``).
        hidden_dim: Width of the single hidden layer.
    """

    def __init__(self, obs_dim=None, n_actions=None, hidden_dim=128):
        super().__init__()
        # Fall back to the global environment only for the sizes that were not
        # given explicitly, so the network can also be built without `env`.
        if obs_dim is None:
            obs_dim = env.observation_space.shape[0]
        if n_actions is None:
            n_actions = env.action_space.n
        self.fc1 = nn.Linear(obs_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, n_actions)

    def forward(self, x):
        """Return unnormalized scores (logits); no softmax is applied here."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)

# Define the REINFORCE algorithm
class REINFORCE:
    """Policy-gradient agent driven by a network that outputs action logits.

    The network output is interpreted as *logits* (unnormalized scores), so
    negative outputs are valid; passing them as probabilities would be
    rejected by ``Categorical``.

    Args:
        policy_network: ``nn.Module`` mapping a state to one score per action.
        gamma: Discount factor used by :meth:`update_episode`.
    """

    def __init__(self, policy_network, gamma=0.99):
        self.policy_network = policy_network
        self.optimizer = optim.Adam(self.policy_network.parameters(), lr=0.001)
        self.gamma = gamma

    def _distribution(self, state):
        """Build the categorical action distribution for ``state``."""
        state = torch.as_tensor(state, dtype=torch.float)
        return Categorical(logits=self.policy_network(state))

    def _step(self, loss):
        """Run one optimizer step on ``loss``."""
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def act(self, state):
        """Sample an action index (Python int) for ``state``."""
        # Sampling needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            return self._distribution(state).sample().item()

    def update(self, state, action, reward, next_state):
        """Single-step update weighting the log-probability by ``reward``.

        ``next_state`` is accepted for interface compatibility but not used.
        Only the immediate reward is used here; see :meth:`update_episode`
        for the return-weighted update.
        """
        action = torch.as_tensor(action, dtype=torch.long)
        reward = torch.as_tensor(reward, dtype=torch.float)
        log_prob = self._distribution(state).log_prob(action)
        self._step(-log_prob * reward)

    def update_episode(self, states, actions, rewards):
        """Update from one full episode using discounted returns.

        Args:
            states: Sequence of states visited, in order.
            actions: Sequence of action indices taken in those states.
            rewards: Sequence of rewards received after each action.
        """
        if len(rewards) == 0:
            return
        # Accumulate G_t = r_t + gamma * G_{t+1} from the end of the episode.
        returns = []
        running = 0.0
        for reward in reversed(rewards):
            running = reward + self.gamma * running
            returns.append(running)
        returns.reverse()

        returns = torch.as_tensor(returns, dtype=torch.float)
        states = torch.as_tensor(np.asarray(states), dtype=torch.float)
        actions = torch.as_tensor(actions, dtype=torch.long)
        log_probs = self._distribution(states).log_prob(actions)
        self._step(-(log_probs * returns).sum())

class ActorCritic:
    """One-step actor-critic with a separate policy and value network.

    The policy network's output is treated as action logits. The value
    network is regressed onto the immediate reward and used as a baseline
    for the policy gradient; both networks are trained by one optimizer.

    Args:
        policy_network: ``nn.Module`` mapping a state to one score per action.
        value_network: ``nn.Module`` mapping a state to a value estimate. If
            it emits more than one output, the outputs are averaged.
    """

    def __init__(self, policy_network, value_network):
        self.policy_network = policy_network
        self.value_network = value_network
        # Optimize both networks; de-duplicate in case the same module (or
        # shared layers) is passed for both roles.
        unique_params = {}
        for net in (policy_network, value_network):
            for param in net.parameters():
                unique_params.setdefault(id(param), param)
        self.optimizer = optim.Adam(list(unique_params.values()), lr=0.001)

    def act(self, state):
        """Sample an action index (Python int) for ``state``."""
        state = torch.as_tensor(state, dtype=torch.float)
        # Sampling needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            logits = self.policy_network(state)
        return Categorical(logits=logits).sample().item()

    def update(self, state, action, reward, next_state):
        """Take one gradient step on the combined actor and critic loss.

        ``next_state`` is accepted for interface compatibility but not used.
        """
        state = torch.as_tensor(state, dtype=torch.float)
        action = torch.as_tensor(action, dtype=torch.long)
        reward = torch.as_tensor(reward, dtype=torch.float)

        dist = Categorical(logits=self.policy_network(state))
        log_prob = dist.log_prob(action)

        value = self.value_network(state)
        # The baseline must not receive gradient through the policy term.
        advantage = (reward - value).detach().mean()
        policy_loss = -log_prob * advantage
        value_loss = ((value - reward) ** 2).mean()

        loss = policy_loss + value_loss

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

# Train the REINFORCE agent online: one update per environment step,
# reporting the summed reward at the end of every episode.
reinforce = REINFORCE(PolicyNetwork())
for episode in range(1000):
    state, done, rewards = env.reset(), False, 0
    while True:
        action = reinforce.act(state)
        next_state, reward, done, _ = env.step(action)
        reinforce.update(state, action, reward, next_state)
        rewards += reward
        state = next_state
        if done:
            break
    print(f'Episode {episode+1}, Reward: {rewards}')

# Train the actor-critic agent online: one update per environment step,
# reporting the summed reward at the end of every episode.
actor_critic = ActorCritic(PolicyNetwork(), PolicyNetwork())
for episode in range(1000):
    state, done, rewards = env.reset(), False, 0
    while True:
        action = actor_critic.act(state)
        next_state, reward, done, _ = env.step(action)
        actor_critic.update(state, action, reward, next_state)
        rewards += reward
        state = next_state
        if done:
            break
    print(f'Episode {episode+1}, Reward: {rewards}')


**Actor-Critic methods**

In [None]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

# Shared CartPole environment for this cell: ActorCriticNetwork reads its
# observation/action space sizes, and the A2C/PPO training loops below step it.
env = gym.make('CartPole-v0')

class ActorCriticNetwork(nn.Module):
    """Shared-trunk network with a softmax policy head and a scalar value head.

    Args:
        obs_dim: Size of the observation vector. When ``None`` it is read
            from the module-level ``env`` (``env.observation_space.shape[0]``).
        n_actions: Number of discrete actions. When ``None`` it is read from
            the module-level ``env`` (``env.action_space.n``).
        hidden_dim: Width of the shared hidden layer.
    """

    def __init__(self, obs_dim=None, n_actions=None, hidden_dim=128):
        super().__init__()
        # Fall back to the global environment only for the sizes that were not
        # given explicitly, so the network can also be built without `env`.
        if obs_dim is None:
            obs_dim = env.observation_space.shape[0]
        if n_actions is None:
            n_actions = env.action_space.n
        self.fc1 = nn.Linear(obs_dim, hidden_dim)
        self.fc_actor = nn.Linear(hidden_dim, n_actions)
        self.fc_critic = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return ``(action_probs, value)`` for input ``x``.

        ``action_probs`` is a softmax over the last dimension; ``value`` keeps
        a trailing dimension of size 1.
        """
        hidden = torch.relu(self.fc1(x))
        action_probs = torch.softmax(self.fc_actor(hidden), dim=-1)
        value = self.fc_critic(hidden)
        return action_probs, value

class A2C:
    """Advantage actor-critic using one-step TD errors as advantages.

    Args:
        network: Module whose forward pass returns ``(action_probs, value)``.
    """

    def __init__(self, network):
        self.network = network
        self.optimizer = optim.Adam(self.network.parameters(), lr=0.001)
        self.gamma = 0.99

    def act(self, state):
        """Sample an action index (Python int) for a single ``state``."""
        state = torch.as_tensor(state, dtype=torch.float)
        # Sampling needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            action_probs, _ = self.network(state)
        return Categorical(action_probs).sample().item()

    def update(self, states, actions, rewards, next_states, dones=None):
        """Take one gradient step on a batch of transitions.

        Args:
            states: Batch of states, one row per transition.
            actions: Action index taken in each state.
            rewards: Reward received for each transition.
            next_states: Batch of successor states.
            dones: Optional per-transition terminal flags. Where true, the
                successor value is not bootstrapped. ``None`` treats every
                transition as non-terminal.
        """
        states = torch.as_tensor(np.asarray(states), dtype=torch.float)
        actions = torch.as_tensor(actions, dtype=torch.long)
        rewards = torch.as_tensor(rewards, dtype=torch.float)
        next_states = torch.as_tensor(np.asarray(next_states), dtype=torch.float)
        if dones is None:
            not_done = torch.ones_like(rewards)
        else:
            not_done = 1.0 - torch.as_tensor(dones, dtype=torch.float)

        action_probs, values = self.network(states)
        # Flatten (N, 1) -> (N,) so that adding the (N,) rewards does not
        # broadcast into an (N, N) matrix of mismatched pairs.
        values = values.reshape(-1)
        log_probs = Categorical(action_probs).log_prob(actions)

        with torch.no_grad():
            _, next_values = self.network(next_states)
            next_values = next_values.reshape(-1)

        td_targets = rewards + self.gamma * next_values * not_done
        td_errors = td_targets - values

        # The advantage is a constant for the actor term.
        actor_loss = -log_probs * td_errors.detach()
        critic_loss = td_errors ** 2

        loss = (actor_loss + critic_loss).mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

class PPO:
    """Clipped-surrogate policy optimizer with one-step TD advantages.

    Args:
        network: Module whose forward pass returns ``(action_probs, value)``.
    """

    def __init__(self, network):
        self.network = network
        self.optimizer = optim.Adam(self.network.parameters(), lr=0.001)
        self.gamma = 0.99
        self.epsilon = 0.2

    def act(self, state):
        """Sample an action index (Python int) for a single ``state``."""
        state = torch.as_tensor(state, dtype=torch.float)
        # Sampling needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            action_probs, _ = self.network(state)
        return Categorical(action_probs).sample().item()

    def update(self, states, actions, rewards, next_states, old_log_probs,
               dones=None):
        """Take one gradient step on a batch of transitions.

        Args:
            states: Batch of states, one row per transition.
            actions: Action index taken in each state.
            rewards: Reward received for each transition.
            next_states: Batch of successor states.
            old_log_probs: Log-probability of each action under the policy
                that collected the data; floats or scalar tensors.
            dones: Optional per-transition terminal flags. Where true, the
                successor value is not bootstrapped. ``None`` treats every
                transition as non-terminal.
        """
        states = torch.as_tensor(np.asarray(states), dtype=torch.float)
        actions = torch.as_tensor(actions, dtype=torch.long)
        rewards = torch.as_tensor(rewards, dtype=torch.float)
        next_states = torch.as_tensor(np.asarray(next_states), dtype=torch.float)
        # Old log-probs are constants of the update: detach them and accept
        # either plain floats or (possibly graph-attached) scalar tensors.
        old_log_probs = torch.stack([
            torch.as_tensor(lp, dtype=torch.float).detach().reshape(())
            for lp in old_log_probs
        ])
        if dones is None:
            not_done = torch.ones_like(rewards)
        else:
            not_done = 1.0 - torch.as_tensor(dones, dtype=torch.float)

        action_probs, values = self.network(states)
        # Flatten (N, 1) -> (N,) so that adding the (N,) rewards does not
        # broadcast into an (N, N) matrix of mismatched pairs.
        values = values.reshape(-1)
        log_probs = Categorical(action_probs).log_prob(actions)

        with torch.no_grad():
            _, next_values = self.network(next_states)
            next_values = next_values.reshape(-1)

        td_targets = rewards + self.gamma * next_values * not_done
        td_errors = td_targets - values
        # Advantages are constants for the surrogate; without the detach the
        # actor term would push gradient into the critic head.
        advantages = td_errors.detach()

        ratios = torch.exp(log_probs - old_log_probs)
        surr1 = ratios * advantages
        surr2 = torch.clamp(ratios, 1 - self.epsilon, 1 + self.epsilon) * advantages

        actor_loss = -torch.min(surr1, surr2)
        critic_loss = td_errors ** 2

        loss = (actor_loss + critic_loss).mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

# Train the A2C agent online: each step is fed to `update` as a batch of
# one transition; the summed reward is reported per episode.
a2c = A2C(ActorCriticNetwork())
for episode in range(1000):
    state, done, rewards = env.reset(), False, 0
    while True:
        action = a2c.act(state)
        next_state, reward, done, _ = env.step(action)
        a2c.update([state], [action], [reward], [next_state])
        rewards += reward
        state = next_state
        if done:
            break
    print(f'Episode {episode+1}, Reward: {rewards}')

# Train the PPO agent: collect one full episode, then run a single update
# on the collected batch.
ppo = PPO(ActorCriticNetwork())
for episode in range(1000):
    state = env.reset()
    done = False
    # Keep the running total separate from the per-step reward list; sharing
    # one name made `rewards += reward` fail on a list.
    total_reward = 0
    states, actions, rewards, next_states, old_log_probs = [], [], [], [], []
    while not done:
        action = ppo.act(state)
        next_state, reward, done, _ = env.step(action)
        states.append(state)
        actions.append(action)
        rewards.append(reward)
        next_states.append(next_state)
        # Record the behaviour policy's log-probability as a plain float so
        # it is a constant during the update.
        with torch.no_grad():
            action_probs, _ = ppo.network(torch.tensor(state, dtype=torch.float))
        dist = Categorical(action_probs)
        old_log_probs.append(
            dist.log_prob(torch.tensor(action, dtype=torch.long)).item()
        )
        state = next_state
        total_reward += reward
    ppo.update(states, actions, rewards, next_states, old_log_probs)
    print(f'Episode {episode+1}, Reward: {total_reward}')


Monte Carlo Tree Search

In [None]:
import random
import math

class Node:
    """Search-tree node holding visit statistics for one game state.

    Args:
        state: The game state this node represents.
        parent: Parent node, or ``None`` for the root.
        action: Action applied to the parent's state to reach this node, or
            ``None`` when unknown (e.g. for the root).
    """

    def __init__(self, state, parent=None, action=None):
        self.state = state
        self.parent = parent
        self.action = action
        self.children = []
        self.visits = 0
        self.value = 0

    def add_child(self, child):
        """Append ``child`` to this node's children."""
        self.children.append(child)

    def update(self, value):
        """Record one visit and add ``value`` to the accumulated total."""
        self.visits += 1
        self.value += value

    def get_value(self):
        """Return the mean backed-up value, or 0.0 if never visited."""
        if self.visits == 0:
            return 0.0
        return self.value / self.visits

    def get_ucb(self, parent, exploration=1.0):
        """Return the UCB score of this node relative to ``parent``.

        Unvisited nodes score ``+inf`` so each child is tried once before
        any sibling is revisited (and no division by zero occurs).
        """
        if self.visits == 0:
            return math.inf
        # max(..., 1) keeps log() defined if the parent has no visits yet.
        parent_visits = max(parent.visits, 1)
        bonus = math.sqrt(math.log(parent_visits) / self.visits)
        return self.get_value() + exploration * bonus


class MCTS:
    """Monte Carlo Tree Search over states exposing ``is_terminal()``,
    ``get_actions()``, ``apply_action(action)`` and ``get_value()``.
    """

    def __init__(self, root_state):
        self.root = Node(root_state)

    def select(self, node):
        """Descend from ``node`` by maximum UCB until reaching a leaf."""
        while node.children:
            parent = node
            node = max(parent.children, key=lambda child: child.get_ucb(parent))
        return node

    def expand(self, node):
        """Add one child per legal action; terminal nodes are left as-is."""
        if node.state.is_terminal():
            return
        for action in node.state.get_actions():
            child_state = node.state.apply_action(action)
            node.add_child(Node(child_state, node, action))

    def simulate(self, node):
        """Play uniformly random actions from ``node`` to a terminal state
        and return that state's value.
        """
        state = node.state
        while not state.is_terminal():
            state = state.apply_action(random.choice(state.get_actions()))
        return state.get_value()

    def backpropagate(self, node, value):
        """Add ``value`` to ``node`` and every ancestor up to the root."""
        while node is not None:
            node.update(value)
            node = node.parent

    def run(self, iterations):
        """Perform ``iterations`` select / expand / simulate / backup rounds."""
        for _ in range(iterations):
            leaf = self.select(self.root)
            self.expand(leaf)
            value = self.simulate(leaf)
            self.backpropagate(leaf, value)

    def get_action(self, state=None):
        """Return the action leading to the most-visited child of the root.

        Args:
            state: Ignored; kept so existing ``get_action(state)`` calls
                keep working.

        Returns:
            The chosen action, or ``None`` if the root state is terminal.
        """
        if not self.root.children:
            # No search has been run yet: expand so there is something to pick.
            self.expand(self.root)
        if not self.root.children:
            return None
        best = max(self.root.children, key=lambda child: child.visits)
        return best.action

# Example usage:
class Game:
    def __init__(self):
        self.state = 'start'

    def get_actions(self):
        if self.state == 'start':
            return ['A', 'B']
        elif self.state == 'A':
            return ['C', 'D']
        elif self.state == 'B':
            return ['E', 'F']
        elif self.state == 'C':
            return ['G', 'H']
        elif self.state == 'D':
            return ['I', 'J']
        elif self.state == 'E':
            return ['K', 'L']
        elif self.state == 'F':
            return ['M', 'N']
        elif self.state == 'G':
            return ['O', 'P']
        elif self.state == 'H':
            return ['Q', 'R']
        elif self.state == 'I':
            return ['S', 'T']
        elif self.state == 'J':
            return ['U', 'V']
        elif self.state == 'K':
            return ['W', 'X']
        elif self.state == 'L':
            return ['Y', 'Z']
        elif self.state == 'M':
            return ['AA', 'AB']
        elif self.state == 'N':
            return ['AC', 'AD']
        elif self.state == 'O':
            return ['AE', 'AF']
        elif self.state == 'P':
            return ['AG', 'AH']
        elif self.state == 'Q':
            return ['AI', 'AJ']
        elif self.state == 'R':
            return ['AK', 'AL']
        elif self.state == 'S':
            return ['AM', 'AN']
        elif self.state == 'T':
            return ['AO', 'AP']
        elif self.state == 'U':
            return ['AQ', 'AR']
        elif self.state == 'V':
            return ['AS', 'AT']
        elif self.state == 'W':
            return ['AU', 'AV']
        elif self.state == 'X':
            return ['AW', 'AX']
        elif self.state == 'Y':
            return ['AY', 'AZ']
        elif self.state == 'Z':
            return ['BA', 'BB']
        elif self.state == 'AA':
            return ['BC', 'BD']
        elif self.state == 'AB':
            return ['BE', 'BF']
        elif self.state == 'AC':
            return ['BG', 'BH']
        elif self.state == 'AD':
            return ['BI', 'BJ']
        elif self.state == 'AE':
            return ['BK', 'BL']
        elif self.state == 'AF':
            return ['BM', 'BN']
        elif self.state == 'AG':
            return ['BO', 'BP']
        elif self.state == 'AH':
            return ['BQ', 'BR']
        elif self.state == 'AI':
            return ['BS', 'BT']
        elif self.state == 'AJ':
            return ['BU', 'BV']
        elif self.state == 'AK':
            return ['BW', 'BX']
        elif self.state == 'AL':
            return ['BY', 'BZ']
        elif self.state == 'AM':
            return ['CA', 'CB']
        elif self.state == 'AN':
            return ['CC', 'CD']
        elif self.state == 'AO':
            return ['CE', 'CF']
        elif self.state == 'AP':
            return ['CG', 'CH']
        elif self.state == 'AQ':
            return ['CI', 'CJ']
        elif self.state == 'AR':
            return ['CK', 'CL']
        elif self.state == 'AS':
            return ['CM', 'CN']
        elif self.state == 'AT':
            return ['CO', 'CP']
        elif self.state == 'AU':
            return ['CQ', 'CR']
        elif self.state == 'AV':
            return ['CS', 'CT']
        elif self.state == 'AW':
            return ['CU', 'CV']
        elif self.state == 'AX':
            return ['CW', 'CX']
        elif self.state == 'AY':
            return ['CY', 'CZ']
        elif self.state == 'AZ':
            return ['DA', 'DB']
        elif self.state == 'BA':
            return ['DC', 'DD']
        elif self.state == 'BB':
            return ['DE', 'DF']
        elif self.state == 'BC':
            return ['DG', 'DH']
        elif self.state == 'BD':
            return ['DI', 'DJ']
        elif self.state == 'BE':
            return ['DK', 'DL']
        elif self.state == 'BF':
            return ['DM', 'DN']
        elif self.state == 'BG':
            return ['DO', 'DP']
        elif self.state == 'BH':
            return ['DQ', 'DR']
        elif self.state == 'BI':
            return ['DS', 'DT']
        elif self.state == 'BJ':
            return ['DU', 'DV']
        elif self.state == 'BK':
            return ['DW', 'DX']
        elif self.state == 'BL':
            return ['DY', 'DZ']
        elif self.state == 'BM':
            return ['EA', 'EB']
        elif self.state == 'BN':
            return ['EC', 'ED']
        elif self.state == 'BO':
            return ['EE', 'EF']
        elif self.state == 'BP':
            return ['EG', 'EH']
        elif self.state == 'BQ':
            return ['EI', 'EJ']
        elif self.state == 'BR':
            return ['EK', 'EL']
        elif self.state == 'BS':
            return ['EM', 'EN']
        elif self.state == 'BT':
            return ['EO', 'EP']
        elif self.state == 'BU':
            return ['EQ', 'ER']
        elif self.state == 'BV':
            return ['ES', 'ET']
        elif self.state == 'BW':
            return ['EU', 'EV']
        elif self.state == 'BX':
            return ['EW', 'EX']
        elif self.state == 'BY':
            return ['EY', 'EZ']
        elif self.state == 'BZ':
            return ['FA', 'FB']
        elif self.state == 'BA':
            return ['FC', 'FD']
        elif self.state == 'BB':
            return ['FE', 'FF']
        elif self.state == 'BC':
            return ['FG', 'FH']
        elif self.state == 'BD':
            return ['FI', 'FJ']
        elif self.state == 'BE':
            return ['FK', 'FL']
        elif self.state == 'BF':
            return ['FM', 'FN']
        elif self.state == 'BG':
            return ['FO', 'FP']
        elif self.state == 'BH':
            return ['FQ', 'FR']
        elif self.state == 'BI':
            return ['FS', 'FT']
        elif self.state == 'BJ':
            return ['FU', 'FV']
        elif self.state == 'BK':
            return ['FW', 'FX']
        elif self.state == 'BL':
            return ['FY', 'FZ']
        elif self.state == 'BM':
            return ['GA', 'GB']
        elif self.state == 'BN':
            return ['GC', 'GD']
        elif self.state == 'BO':
            return ['GE', 'GF']
        elif self.state == 'BP':
            return ['GG', 'GH']
        elif self.state == 'BQ':
            return ['GI', 'GJ']
        elif self.state == 'BR':
            return ['GK', 'GL']
        elif self.state == 'BS':
            return ['GM', 'GN']
        elif self.state == 'BT':
            return ['GO', 'GP']
        elif self.state == 'BU':
            return ['GQ', 'GR']
        elif self.state == 'BV':
            return ['GS', 'GT']
        elif self.state == 'BW':
            return ['GU', 'GV']
        elif self.state == 'BX':
            return ['GW', 'GX']
        elif self.state == 'BY':
            return ['GY', 'GZ']
        elif self.state == 'BZ':
            return ['HA', 'HB']
        elif self.state == 'BA':
            return ['HC', 'HD']
        elif self.state == 'BB':
            return ['HE', 'HF']
        elif self.state == 'BC':
            return ['HG', 'HH']
        elif self.state == 'BD':
            return ['HI', 'HJ']
        elif self.state == 'BE':
            return ['HK', 'HL']
        elif self.state == 'BF':
            return ['HM', 'HN']
        elif self.state == 'BG':
            return ['HO', 'HP']
        elif self.state == 'BH':
            return ['HQ', 'HR']
        elif self.state == 'BI':
            return ['HS', 'HT']
        elif self.state == 'BJ':
            return ['HU', 'HV']
        elif self.state == 'BK':
            return ['HW', 'HX']
        elif self.state == 'BL':
            return ['HY', 'HZ']
        elif self.state == 'BM':
            return ['IA', 'IB']
        elif self.state == 'BN':
            return ['IC', 'ID']
        elif self.state == 'BO':
            return