In [1]:
import random

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchinfo

import GridWorld_env  # local env package (presumably registers 'GridWorld_env/GridWorld' with gymnasium)

In [2]:
class Conv_NN(nn.Module):
    """3-D convolutional Q-network: three conv/BatchNorm stages (two residual) and a linear head.

    Args:
        input_dim: Side length of the cubic input grid. Inputs are expected as
            ``(batch, in_channels, input_dim, input_dim, input_dim)`` because the
            linear head is sized for exactly ``input_dim ** 3`` voxels.
        action_dim: Number of outputs (one Q-value per action).
        in_channels: Channels of the input observation (default 3, the value
            previously hard-coded).
        hidden_channels: Width of every convolutional layer (default 64).
    """

    def __init__(self, input_dim, action_dim, in_channels=3, hidden_channels=64):
        super().__init__()
        # kernel 3 / stride 1 / padding 1 keeps the spatial size unchanged, which is
        # what makes the residual additions and the fc1 input size line up.
        self.conv1 = nn.Conv3d(in_channels, hidden_channels, kernel_size=3, stride=1, padding=1)
        self.batchnorm1 = nn.BatchNorm3d(hidden_channels)
        self.conv2 = nn.Conv3d(hidden_channels, hidden_channels, kernel_size=3, stride=1, padding=1)
        self.batchnorm2 = nn.BatchNorm3d(hidden_channels)
        self.conv3 = nn.Conv3d(hidden_channels, hidden_channels, kernel_size=3, stride=1, padding=1)
        self.batchnorm3 = nn.BatchNorm3d(hidden_channels)

        self.fc1 = nn.Linear(hidden_channels * input_dim * input_dim * input_dim, action_dim)

    def forward(self, x):
        """Map a batch of observations to Q-values of shape ``(batch, action_dim)``."""
        x = F.relu(self.batchnorm1(self.conv1(x)))
        # Residual stages: the conv output is added to its input before BatchNorm.
        x = F.relu(self.batchnorm2(self.conv2(x) + x))
        x = F.relu(self.batchnorm3(self.conv3(x) + x))
        # Flatten all but the batch dimension without building a module per call.
        x = torch.flatten(x, start_dim=1)
        return self.fc1(x)

# Shared GridWorld environment (deep_Q_learning below reads this global) and a
# quick shape check of the Q-network on a batch of 8 observations.
new_env = gym.make('GridWorld_env/GridWorld', dimension_size=4)
test = Conv_NN(4, 7)
new_env.reset()
state, _, _, _, _ = new_env.step(0)
state = torch.tensor(state).float()

# torchinfo is imported in the top import cell.
torchinfo.summary(test, (8, 3, 4, 4, 4))

  logger.warn(
  logger.warn(
  logger.warn(


Layer (type:depth-idx)                   Output Shape              Param #
Conv_NN                                  [8, 7]                    --
├─Conv3d: 1-1                            [8, 64, 4, 4, 4]          5,248
├─BatchNorm3d: 1-2                       [8, 64, 4, 4, 4]          128
├─Conv3d: 1-3                            [8, 64, 4, 4, 4]          110,656
├─BatchNorm3d: 1-4                       [8, 64, 4, 4, 4]          128
├─Conv3d: 1-5                            [8, 64, 4, 4, 4]          110,656
├─BatchNorm3d: 1-6                       [8, 64, 4, 4, 4]          128
├─Linear: 1-7                            [8, 7]                    28,679
Total params: 255,623
Trainable params: 255,623
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 116.23
Input size (MB): 0.01
Forward/backward pass size (MB): 1.57
Params size (MB): 1.02
Estimated Total Size (MB): 2.60

In [51]:
def deep_Q_learning():
    # Initialised replay buffer
    replay_buffer = []
    gamma = 0.995
    # Intialised Q network
    Q_network = Conv_NN(4, 7)
    optimizer = optim.SGD(Q_network.parameters(), lr=0.01, momentum=0.9)
    EPISODES = 100
    # Loop over episodes
    for episode in range(EPISODES):
        
        # Reset environment
        new_env.reset()
        
        # Loop over steps
        for timestep in range(2000):
            # Get state
            Q_network.zero_grad()
            state = new_env.unwrapped.get_obs()
            state = torch.tensor(state).unsqueeze(0).float()
            
            q_values = Q_network(state)
            action = torch.argmax(q_values)
            max_q = torch.max(q_values)
            
            next_state, reward, done, _, _ = new_env.step(action)
            
            replay_buffer.append((state.squeeze(0), max_q, action, reward, next_state, done))
            
            minibatch = random.choices(replay_buffer, k=32)
            # Form minibatch
            batch_state = torch.stack([i[0] for i in minibatch])
            batch_max_q = torch.stack([i[1] for i in minibatch])
            batch_action = torch.stack([i[2] for i in minibatch])
            batch_reward = torch.stack([i[3] for i in minibatch])
            batch_next_state = torch.stack([torch.tensor(i[4]).float() for i in minibatch])
            batch_done = torch.stack([i[5] for i in minibatch])
            
            
            # Form target y by boostrapping
            target_y = batch_reward.float() + gamma * torch.max(Q_network(batch_next_state).detach(), dim=1)[0] * (1 - batch_done.float())
            target_y.requires_grad = False
            
            y_hat = Q_network(batch_state).max(1)[0]
            loss = F.mse_loss(y_hat, target_y)
            loss.backward()
            optimizer.step()
            
        

In [52]:
# Launch training with the default settings (trailing ';' keeps the cell output clean).
deep_Q_learning();

KeyboardInterrupt: 