### Imports and Setup

In [1]:
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

### Model Definition

In [2]:
class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.fc(x)

### Training

In [3]:
def train_dqn(env, model, episodes=5000, learning_rate=0.001, discount_factor=0.95, exploration_prob=1.0, exploration_decay=0.995, min_exploration=0.05):
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    loss_fn = nn.MSELoss()

    for episode in range(episodes):
        state = env.reset()[0]
        terminated = False

        while not terminated:
            with torch.no_grad():
                if np.random.uniform(0, 1) < exploration_prob:
                    action = env.action_space.sample()  # Explore
                else:
                    q_values = model(torch.tensor([state], dtype=torch.float32))
                    action = torch.argmax(q_values).item()  # Exploit

            next_state, reward, terminated, _, _ = env.step(action)

            # Update model
            target = reward
            if not terminated:
                with torch.no_grad():
                    target = reward + discount_factor * torch.max(model(torch.tensor([next_state], dtype=torch.float32)))

            q_values = model(torch.tensor([state], dtype=torch.float32))
            loss = loss_fn(q_values[0][action], torch.tensor([target], dtype=torch.float32))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            state = next_state

        exploration_prob = max(min_exploration, exploration_prob * exploration_decay)

        if episode % 100 == 0:
            print(f"Episode {episode} finished.")

# Train
env = gym.make('Taxi-v3')
input_dim = env.observation_space.n
output_dim = env.action_space.n
model = DQN(input_dim, output_dim)
train_dqn(env, model)


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x1 and 500x6)