In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class RLNet(nn.Module):
  """Convolutional Q-network: RGB image batch -> one Q-value per discrete action.

  Five conv stages shrink the image to a 512 x 2 x 2 feature map, which a
  three-layer MLP ("Deep Q-Network" head) maps to ``num_actions`` outputs.

  MLP_1 is hard-wired to 2048 = 512 * 2 * 2 input features, so only input sizes
  whose conv stack ends in a 2 x 2 map work; (3, 224, 224), (3, 244, 244) and
  (3, 32, 32) all do.

  Args:
    num_actions: number of Q-values produced per sample (default 6). Layer
      attribute names are unchanged, so with the default value state_dicts saved
      from the original 6-action model still load.
  """

  def __init__(self, num_actions: int = 6):
    super().__init__()

    # Activation (tanh approximation of GELU), shared by every stage.
    self.gelu = nn.GELU(approximate='tanh')

    # ConvLayer1: 7x7 stride-2 conv, then 3x3 stride-2 max-pool.
    self.conv_1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(7,7), stride=(2,2), padding=(3,3))
    self.maxpool_1 = nn.MaxPool2d(kernel_size=(3,3), stride=(2,2), padding=(1,1))
    self.batchnorm_1 = nn.BatchNorm2d(64)

    # ConvLayer2
    self.conv_2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3,3), stride=(2,2), padding=(1,1))
    self.batchnorm_2 = nn.BatchNorm2d(128)
    self.maxpool_2 = nn.MaxPool2d(kernel_size=(3,3), stride=(2,2), padding=(1,1))

    # ConvLayer3: stride-1 conv with padding 3 on a 5x5 kernel slightly enlarges the map.
    self.conv_3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(5,5), stride=(1,1), padding=(3,3))
    self.maxpool_3 = nn.MaxPool2d(kernel_size=(3,3), stride=(2,2), padding=(1,1))
    self.batchnorm_3 = nn.BatchNorm2d(256)

    # ConvLayer4
    self.conv_4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3,3), stride=(2,2), padding=(3,3))
    self.maxpool_4 = nn.MaxPool2d(kernel_size=(3,3), stride=(2,2), padding=(1,1))
    self.batchnorm_4 = nn.BatchNorm2d(512)

    # ConvLayer5
    self.conv_5 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3,3), stride=(2,2), padding=(3,3))
    self.maxpool_5 = nn.MaxPool2d(kernel_size=(3,3), stride=(2,2), padding=(1,1))
    self.batchnorm_5 = nn.BatchNorm2d(512)

    # Deep Q-Network head; 2048 = 512 channels * 2 * 2 spatial positions.
    self.MLP_1 = nn.Linear(in_features=2048, out_features=512)
    self.MLP_2 = nn.Linear(in_features=512, out_features=128)
    self.MLP_3 = nn.Linear(in_features=128, out_features=num_actions)

  def forward(self, x):
    """Compute Q-values for a batch of images.

    Args:
      x: float tensor of shape (N, 3, H, W). BatchNorm2d requires the batch
        dimension, so a single image must be passed as (1, 3, H, W).

    Returns:
      Tensor of shape (N, num_actions): one row of Q-values per sample.
    """
    # NOTE(review): stages 1, 3 and 5 pool before BatchNorm while stages 2 and 4
    # normalise before pooling — confirm the inconsistency is intentional.

    # ConvLayer1
    x = self.conv_1(x)
    x = self.gelu(x)
    x = self.maxpool_1(x)
    x = self.batchnorm_1(x)

    # ConvLayer2
    x = self.conv_2(x)
    x = self.gelu(x)
    x = self.batchnorm_2(x)
    x = self.maxpool_2(x)

    # ConvLayer3
    x = self.conv_3(x)
    x = self.gelu(x)
    x = self.maxpool_3(x)
    x = self.batchnorm_3(x)

    # ConvLayer4
    x = self.conv_4(x)
    x = self.gelu(x)
    x = self.batchnorm_4(x)
    x = self.maxpool_4(x)

    # ConvLayer5
    # NOTE(review): unlike stages 1-4, no GELU follows conv_5 — confirm intentional.
    x = self.conv_5(x)
    x = self.maxpool_5(x)
    x = self.batchnorm_5(x)

    # Flatten every dimension except the batch dimension. Flattening dim 0 too
    # merged all samples into one vector: any batch with N > 1 produced
    # N * 2048 features and crashed in MLP_1.
    x = torch.flatten(x, start_dim=1)

    # DQN
    x = self.MLP_1(x)
    x = self.gelu(x)
    x = self.MLP_2(x)
    x = self.gelu(x)
    x = self.MLP_3(x)

    return x

In [None]:
# Running the model
import cv2
import torch
import torchvision.transforms as transforms

IMAGE_PATH = "image.jpg"

# cv2.imread signals a missing/unreadable file by returning None instead of
# raising, which would otherwise surface later as an opaque cvtColor error.
image = cv2.imread(IMAGE_PATH)
if image is None:
    raise FileNotFoundError(f"cv2.imread could not read {IMAGE_PATH!r}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; the model is fed RGB
image = cv2.resize(image, (244, 244))  # cv2 dsize is (width, height)

# Convert to PyTorch tensor: HWC uint8 [0, 255] -> CHW float32 [0, 1], plus a batch dim.
transform = transforms.ToTensor()
tensor = transform(image)
tensor = tensor.unsqueeze(0)

model = RLNet()
# Inference only: eval() makes BatchNorm use its running statistics instead of
# normalising over (and updating its stats from) this single image, and
# no_grad() skips building an autograd graph that is never used.
model.eval()
with torch.no_grad():
    q_values = model(tensor)
q_values

tensor([-0.0108,  0.0822, -0.0107,  0.0171,  0.0339, -0.0750],
       grad_fn=<ViewBackward0>)

In [None]:
from torchsummary import summary

# torchsummary runs the model on a small dummy batch (2 samples), so this also
# checks that RLNet handles batches larger than one.
# device="cpu" matches where RLNet() is created; torchsummary's default "cuda"
# builds CUDA inputs for this CPU model whenever a GPU is available.
input_shape = (3, 244, 244)
summary(RLNet(), input_shape, device="cpu")

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x4096 and 2048x512)

In [None]:
import numpy as np
import random
from collections import deque
import torch.optim as optim

# Define DQN Agent with Experience Replay Buffer
class DQNAgent:
    """Epsilon-greedy DQN agent backed by a uniform experience-replay buffer.

    The same ``RLNet`` both selects actions and supplies the bootstrap value
    max_a' Q(s', a'); there is no separate target network.

    Args:
        gamma: discount factor for the bootstrapped next-state value.
        lr: AdamW learning rate.
        epsilon: initial probability of taking a random action in ``act``.
        epsilon_decay: factor epsilon is multiplied by after each ``replay`` pass.
        buffer_size: replay capacity; the oldest transitions are dropped first.
        epsilon_min: exploration floor that epsilon is never decayed below
            (default 0.01).
    """

    # Size of the discrete action space; must equal the width of RLNet's output layer.
    n_actions = 6

    def __init__(self, gamma, lr, epsilon, epsilon_decay, buffer_size, epsilon_min=0.01):
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.memory = deque(maxlen=buffer_size)
        self.model = RLNet()
        self.optimizer = optim.AdamW(self.model.parameters(), lr=lr)

    @staticmethod
    def _to_tensor(state):
        """Return ``state`` as a float32 tensor, avoiding a copy when it already is one."""
        return torch.as_tensor(state, dtype=torch.float32)

    def act(self, state):
        """Pick an action index in [0, n_actions): random with probability epsilon, else greedy."""
        if np.random.rand() <= self.epsilon:
            return int(np.random.choice(self.n_actions))
        # Action selection never backpropagates, so skip building the graph.
        with torch.no_grad():
            q_values = self.model(self._to_tensor(state))
        return torch.argmax(q_values).item()

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def _q_vector(self, state):
        """Q-values for one stored state as a flat vector of length n_actions.

        reshape(-1) collapses a leading batch dimension of 1 so ``action`` can
        index the Q-values directly. Assumes each stored state holds a single
        observation — TODO confirm against the environment.
        """
        return self.model(self._to_tensor(state)).reshape(-1)

    def replay(self, batch_size):
        """Sample ``batch_size`` transitions and take one optimizer step per transition.

        Does nothing, and leaves epsilon untouched, until the buffer holds at
        least ``batch_size`` transitions. Afterwards epsilon is decayed towards
        ``epsilon_min``.
        """
        if len(self.memory) < batch_size:
            return
        loss_fn = nn.MSELoss()
        # NOTE(review): deque indexing is O(n) away from its ends, so sampling
        # from a large buffer costs O(batch_size * len(memory)).
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            target = float(reward)
            if not done:
                # The bootstrap value is a fixed regression target, not something to differentiate.
                with torch.no_grad():
                    target += self.gamma * self._q_vector(next_state).max().item()
            q_values = self._q_vector(state)
            # Regress only the taken action's Q-value: every other entry of the
            # target equals the (detached) prediction, so it contributes zero loss.
            target_q = q_values.detach().clone()
            target_q[action] = target
            self.optimizer.zero_grad()
            loss = loss_fn(q_values, target_q)
            loss.backward()
            self.optimizer.step()
        if self.epsilon > self.epsilon_min:
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

In [None]:
# Reproducibility: seed Python's `random` (replay sampling), NumPy (epsilon-greedy
# exploration) and torch (RLNet weight init). The environment's own randomness,
# if any, is not covered here.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# `env` is not created anywhere in this notebook. It must provide reset() -> state
# and step(action) -> (next_state, reward, done, info). Fail fast with a clear
# message instead of a bare NameError after the agent has been built.
if "env" not in globals():
    raise NameError("`env` is not defined: create the robot environment before running the training loop.")

agent = DQNAgent(lr=0.001, gamma=0.99, epsilon=1.0, epsilon_decay=0.995, buffer_size=10000)

# `done` is the reset condition: either the target (center) was reached or the
# state went out of distribution.

# Train the DQN agent with Experience Replay Buffer
batch_size = 32
num_episodes = 1000
for episode in range(num_episodes):
    state = env.reset()  # reset robot to a new pose (x, y, z, theta, phi, psi) -> capture image -> the image is the state
    total_reward = 0
    done = False
    while not done:
        # Discrete action index; presumably each index selects a change in one of
        # (x, y, z, theta, phi, psi) — confirm against env.step.
        action = agent.act(state)
        next_state, reward, done, _ = env.step(action)  # execute via MoveIt2, observe new state and reward
        agent.remember(state, action, reward, next_state, done)  # store transition in the replay buffer
        state = next_state
        total_reward += reward
        agent.replay(batch_size)
    print(f"Episode: {episode + 1}, Total Reward: {total_reward}")




In [None]:
# Saving Model
# Persist only the state_dict (parameters plus BatchNorm running statistics),
# not the pickled module object; restore it into a freshly built RLNet.
torch.save(obj=agent.model.state_dict(), f="model.pth")

In [None]:
# Loading Model
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code. map_location="cpu" lets a
# checkpoint saved from a GPU model load here, since RLNet() is built on the CPU.
model_load = RLNet()
model_load.load_state_dict(torch.load("model.pth", map_location="cpu", weights_only=True))
model_load.eval()  # inference mode: BatchNorm uses its stored running statistics

OrderedDict([('conv_1.weight',
              tensor([[[[-0.0716, -0.0078, -0.0580,  ...,  0.0032, -0.0812, -0.0240],
                        [ 0.0042,  0.0168, -0.0161,  ..., -0.0688, -0.0464,  0.0359],
                        [-0.0326, -0.0207,  0.0626,  ..., -0.0154, -0.0484,  0.0480],
                        ...,
                        [ 0.0219,  0.0423,  0.0745,  ...,  0.0711,  0.0249, -0.0476],
                        [-0.0792,  0.0580, -0.0089,  ..., -0.0554, -0.0750,  0.0387],
                        [ 0.0595, -0.0788,  0.0182,  ..., -0.0396,  0.0347, -0.0737]],
              
                       [[-0.0743, -0.0790,  0.0371,  ...,  0.0412,  0.0120,  0.0451],
                        [-0.0344, -0.0586, -0.0211,  ...,  0.0275,  0.0165, -0.0196],
                        [-0.0300, -0.0497,  0.0482,  ..., -0.0114, -0.0730, -0.0071],
                        ...,
                        [-0.0584, -0.0332, -0.0739,  ..., -0.0612, -0.0277, -0.0818],
                        [ 0.0026, -