### Working 

In [8]:
import gym
import gym_sokoban
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import matplotlib.pyplot as plt

# Define the DQN Network
class DQN(nn.Module):
    def __init__(self, input_shape, num_actions):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(input_shape[0], 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(self.feature_size(input_shape), 512)
        self.fc2 = nn.Linear(512, num_actions)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

    def feature_size(self, input_shape):
        return self.conv3(self.conv2(self.conv1(torch.zeros(1, *input_shape)))).view(1, -1).size(1)

# Replay Buffer
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
        return torch.stack(state), torch.tensor(action), torch.tensor(reward), torch.stack(next_state), torch.tensor(done)

    def __len__(self):
        return len(self.buffer)

# Helper Functions
def preprocess_state(state):
    state = np.transpose(state, (2, 0, 1))  # Change the shape from (H, W, C) to (C, H, W)
    return torch.tensor(state, dtype=torch.float32)

def compute_td_loss(batch_size):
    state, action, reward, next_state, done = replay_buffer.sample(batch_size)

    state = state.to(device)
    next_state = next_state.to(device)
    action = action.to(device)
    reward = reward.to(device)
    done = done.to(device)

    q_values = model(state)
    next_q_values = model(next_state)

    q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
    next_q_value = next_q_values.max(1)[0]
    expected_q_value = reward + gamma * next_q_value * (1 - done)

    loss = (q_value - expected_q_value.detach()).pow(2).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss

# Main Training Loop
env = gym.make('Sokoban-small-v1')
input_shape = (env.observation_space.shape[2], env.observation_space.shape[0], env.observation_space.shape[1])
num_actions = env.action_space.n

print("Input shape:", input_shape)  # Print the input shape to understand its dimensions

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = DQN(input_shape, num_actions).to(device)
optimizer = optim.Adam(model.parameters())
replay_buffer = ReplayBuffer(10000)
gamma = 0.99
batch_size = 32
epsilon_start = 1.0
epsilon_final = 0.01
epsilon_decay = 10000

epsilon_by_frame = lambda frame_idx: epsilon_final + (epsilon_start - epsilon_final) * np.exp(-1. * frame_idx / epsilon_decay)

num_frames = 1000000
losses = []
all_rewards = []
episode_reward = 0
max_reward = -float('inf')
early_stop_threshold = 100  # Number of episodes to check for convergence
convergence_reward_threshold = 0.9  # Threshold reward to consider as convergence

state = preprocess_state(env.reset()).unsqueeze(0).to(device)

for frame_idx in range(1, num_frames + 1):
    epsilon = epsilon_by_frame(frame_idx)
    if random.random() > epsilon:
        with torch.no_grad():
            q_value = model(state)
            action = q_value.max(1)[1].item()
    else:
        action = env.action_space.sample()

    next_state, reward, done, _ = env.step(action)
    next_state = preprocess_state(next_state).unsqueeze(0).to(device)
    replay_buffer.push(state.cpu(), action, reward, next_state.cpu(), done)

    state = next_state
    episode_reward += reward

    if done:
        state = preprocess_state(env.reset()).unsqueeze(0).to(device)
        all_rewards.append(episode_reward)
        episode_reward = 0

        # Early stopping check
        if len(all_rewards) > early_stop_threshold:
            mean_reward = np.mean(all_rewards[-early_stop_threshold:])
            if mean_reward > convergence_reward_threshold:
                print(f'Converged at frame {frame_idx}, mean reward: {mean_reward}')
                break

    if len(replay_buffer) > batch_size:
        loss = compute_td_loss(batch_size)
        losses.append(loss.item())

    if frame_idx % 10000 == 0:
        print(f'Frame: {frame_idx}, Reward: {np.mean(all_rewards[-10:])}')

# Save the trained model
torch.save(model.state_dict(), 'dqn_sokoban_model.pth')

plt.figure(figsize=(20, 5))
plt.subplot(131)
plt.title('Loss')
plt.plot(losses)
plt.subplot(132)
plt.title('Reward')
plt.plot(all_rewards)
plt.show()

# Loading the saved model
loaded_model = DQN(input_shape, num_actions).to(device)
loaded_model.load_state_dict(torch.load('dqn_sokoban_model.pth'))
loaded_model.eval()

# Example of using the loaded model
state = preprocess_state(env.reset()).unsqueeze(0).to(device)
done = False
while not done:
    with torch.no_grad():
        q_value = loaded_model(state)
        action = q_value.max(1)[1].item()
    state, reward, done, _ = env.step(action)
    state = preprocess_state(state).unsqueeze(0).to(device)
    env.render()


[[0 0 0 0 0 0 0]
 [0 1 1 2 0 0 0]
 [0 1 1 1 0 0 0]
 [0 1 1 1 0 0 0]
 [0 1 1 2 2 0 0]
 [0 1 1 1 1 0 0]
 [0 0 0 0 0 0 0]] [[0 0 0 0 0 0 0]
 [0 1 1 2 0 0 0]
 [0 1 4 1 0 0 0]
 [0 5 4 1 0 0 0]
 [0 1 4 2 2 0 0]
 [0 1 1 1 1 0 0]
 [0 0 0 0 0 0 0]] {(1, 3): (3, 2), (4, 3): (4, 2), (4, 4): (2, 2)}
Input shape: (3, 112, 112)
[[0 0 0 0 0 0 0]
 [0 1 1 2 0 0 0]
 [0 1 1 1 0 0 0]
 [0 1 1 1 0 0 0]
 [0 1 1 2 2 0 0]
 [0 1 1 1 1 0 0]
 [0 0 0 0 0 0 0]] [[0 0 0 0 0 0 0]
 [0 1 1 2 0 0 0]
 [0 1 4 1 0 0 0]
 [0 5 4 1 0 0 0]
 [0 1 4 2 2 0 0]
 [0 1 1 1 1 0 0]
 [0 0 0 0 0 0 0]] {(1, 3): (3, 2), (4, 3): (4, 2), (4, 4): (2, 2)}


RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [32, 1, 3, 112, 112]