<a href="https://colab.research.google.com/github/newmantic/Q_Learning/blob/main/Q_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import random

class QLearningAgent:
    """Tabular Q-learning agent with an epsilon-greedy behaviour policy.

    Action values live in a dense ``(n_states, n_actions)`` NumPy array that
    starts at zero. States and actions are integer indices into that array.

    Args:
        n_states: Number of discrete states (rows of the Q-table).
        n_actions: Number of discrete actions (columns of the Q-table).
        alpha: Learning rate (step size) of the TD update.
        gamma: Discount factor applied to the bootstrapped next-state value.
        epsilon: Probability that :meth:`choose_action` picks a uniformly
            random action instead of the greedy one.
    """

    def __init__(self, n_states, n_actions, alpha=0.1, gamma=0.99, epsilon=0.1):
        self.n_states = n_states
        self.n_actions = n_actions
        self.alpha = alpha  # Learning rate
        self.gamma = gamma  # Discount factor
        self.epsilon = epsilon  # Exploration rate
        self.q_table = np.zeros((n_states, n_actions))  # Initialize Q-table with zeros

    def choose_action(self, state):
        """Select an action for ``state`` with the epsilon-greedy rule.

        Args:
            state: Row index into the Q-table.

        Returns:
            A plain ``int`` action index. With probability ``epsilon`` it is
            drawn uniformly from ``range(n_actions)``. Otherwise it is the
            index of the largest Q-value in ``state``'s row; ties go to the
            lowest index, as with ``np.argmax``.
        """
        if random.uniform(0, 1) < self.epsilon:
            return random.randrange(self.n_actions)  # Explore: random action
        return int(np.argmax(self.q_table[state]))  # Exploit: greedy action

    def update(self, state, action, reward, next_state, done=False):
        """Apply one Q-learning (off-policy TD(0)) update in place.

        Q[s, a] += alpha * (r + gamma * max_a' Q[s', a'] - Q[s, a])

        Args:
            state: Index of the state the action was taken in.
            action: Index of the action taken.
            reward: Scalar reward observed for the transition.
            next_state: Index of the resulting state.
            done: True if ``next_state`` is terminal. A terminal transition
                has no future return, so the target is just ``reward`` and
                ``next_state``'s row is not bootstrapped from. Defaults to
                False for backward compatibility.
        """
        future_value = 0.0 if done else np.max(self.q_table[next_state])
        td_target = reward + self.gamma * future_value
        td_error = td_target - self.q_table[state, action]
        self.q_table[state, action] += self.alpha * td_error

    def get_q_table(self):
        """Return the underlying Q-table array (not a copy)."""
        return self.q_table

In [2]:
class GridEnvironment:
    """Deterministic rectangular gridworld with four movement actions.

    Positions are ``(row, col)`` tuples. Agents see them as row-major integer
    indices ``row * n_cols + col``. Moves that would leave the grid are
    clamped, so the agent stays on the edge.

    Actions: 0 = up (row - 1), 1 = down (row + 1), 2 = left (col - 1),
    3 = right (col + 1).

    Rewards: +1 for a step that lands on ``goal_state``, -1 for every other
    step. ``done`` is True exactly when the new position is the goal.
    ``step`` does not refuse calls after ``done``; callers should ``reset()``.

    Args:
        grid_size: ``(n_rows, n_cols)`` of the grid.
        start_state: ``(row, col)`` the agent is placed at on ``reset()``.
        goal_state: ``(row, col)`` of the goal cell.
    """

    # Action index -> (row delta, column delta).
    _MOVES = {
        0: (-1, 0),  # up
        1: (1, 0),  # down
        2: (0, -1),  # left
        3: (0, 1),  # right
    }

    def __init__(self, grid_size=(5, 5), start_state=(0, 0), goal_state=(4, 4)):
        self.grid_size = grid_size
        self.start_state = start_state
        self.goal_state = goal_state
        self.state = start_state

    def reset(self):
        """Move the agent back to ``start_state`` and return its index."""
        self.state = self.start_state
        return self._get_state_index(self.state)

    def step(self, action):
        """Apply ``action`` and return ``(next_state_index, reward, done)``.

        Raises:
            ValueError: If ``action`` is not one of 0, 1, 2, 3. Previously an
                unknown action was silently treated as a no-op that still
                cost -1, which hid caller bugs.
        """
        if action not in self._MOVES:
            raise ValueError(f"invalid action {action!r}; expected one of 0, 1, 2, 3")
        d_row, d_col = self._MOVES[action]
        row, col = self.state
        # Clamp to the grid so walls are bumped into rather than crossed.
        row = min(max(row + d_row, 0), self.grid_size[0] - 1)
        col = min(max(col + d_col, 0), self.grid_size[1] - 1)

        self.state = (row, col)
        done = self.state == self.goal_state
        reward = 1 if done else -1
        return self._get_state_index(self.state), reward, done

    def _get_state_index(self, state):
        """Convert a ``(row, col)`` position to its row-major integer index."""
        return state[0] * self.grid_size[1] + state[1]

In [3]:
def train_q_learning_agent(n_episodes=1000, max_steps=100, seed=None):
    """Train a QLearningAgent on the default 5x5 GridEnvironment.

    Args:
        n_episodes: Number of training episodes (default 1000).
        max_steps: Per-episode step cap (default 100). An episode also ends
            as soon as the environment reports ``done``.
        seed: Optional seed for Python's ``random`` module, which drives the
            agent's epsilon-greedy exploration. With a seed, a run is
            reproducible. ``None`` (the default) leaves the global RNG
            untouched, matching the previous behaviour.

    Returns:
        The trained agent.
    """
    if seed is not None:
        random.seed(seed)

    env = GridEnvironment()
    n_states = env.grid_size[0] * env.grid_size[1]
    n_actions = 4  # up, down, left, right
    agent = QLearningAgent(n_states, n_actions, alpha=0.1, gamma=0.99, epsilon=0.1)

    for _ in range(n_episodes):
        state = env.reset()
        for _ in range(max_steps):
            action = agent.choose_action(state)
            next_state, reward, done = env.step(action)
            # The loop breaks as soon as the goal is reached, so the goal
            # state is never the `state` argument of an update. Its Q-row
            # therefore presumably stays at its zero initialisation, and
            # bootstrapping from it adds nothing.
            agent.update(state, action, reward, next_state)
            state = next_state
            if done:
                break

    return agent

# Train an agent with the default settings, then show the action values it learned.
q_agent = train_q_learning_agent()
learned_q = q_agent.get_q_table()
print("Q-Table after training:")
print(learned_q)

Q-Table after training:
[[-6.65911812 -5.86127292 -6.65292129 -5.86126876]
 [-5.5848814  -4.91040888 -6.17007447 -4.9104042 ]
 [-4.35109491 -3.94993745 -4.59279438 -3.94993862]
 [-3.27692686 -2.97977144 -3.42833482 -2.97976824]
 [-2.24416921 -1.99979915 -2.53769354 -2.25008515]
 [-5.92644641 -4.91041623 -5.3518503  -4.91041329]
 [-4.92473888 -3.94998548 -4.51099641 -3.94998416]
 [-3.96694546 -2.97979774 -3.83900484 -2.97979793]
 [-2.79360396 -1.99980037 -2.75761127 -1.99980042]
 [-1.98385788 -1.0099     -2.14265483 -1.68590391]
 [-4.77181011 -3.94995835 -4.37068715 -3.94995507]
 [-4.05950036 -2.97979977 -3.75109316 -2.97979942]
 [-3.20213917 -1.99980071 -2.47652596 -1.99980068]
 [-2.26478406 -1.00989999 -1.87139936 -1.00989999]
 [-1.16024816 -0.01       -1.3233302  -0.71928029]
 [-3.75638273 -2.97985191 -3.41157967 -2.97979651]
 [-2.76582827 -1.99997796 -3.04961108 -1.999801  ]
 [-1.94791025 -1.01251205 -2.15638366 -1.0099    ]
 [-1.34099146 -0.08843837 -1.5807369  -0.01      ]
 [-0.95