In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gym
import random
from collections import deque

class QNetwork(nn.Module):
    """Three-layer MLP mapping a state vector to one Q-value per discrete action.

    The layer attribute names (``fc1``/``fc2``/``fc3``) define the keys of the
    saved ``state_dict``, so they must stay stable for checkpoints to load.

    Args:
        state_dim: Size of the last dimension of the input state tensor.
        action_dim: Number of outputs (one Q-value per action).
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        # Creation order fixes both parameter order and the RNG draws used for init.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        """Return raw (unbounded) Q-values; ReLU is applied only to hidden layers."""
        hidden = state
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        return self.fc3(hidden)


In [4]:
class MMD_DQN_Agent:
    """DQN agent whose TD loss is augmented with a Gaussian-kernel MMD term.

    The policy network is trained on ``MSE(Q, target) + mmd_lambda * MMD(Q, target)``
    using minibatches sampled uniformly from a bounded replay buffer. The target
    network is hard-copied from the policy network every ``target_update_interval``
    episodes (see ``train``).

    Args:
        state_dim: Length of the flat observation vector.
        action_dim: Number of discrete actions.
        hidden_dim: Hidden-layer width passed to ``QNetwork``.
        gamma: Discount factor used in the TD target.
        lr: Adam learning rate.
        batch_size: Minibatch size; learning starts once the buffer holds this many
            transitions.
        buffer_size: Maximum number of transitions kept in the replay buffer.
        tau: Soft-update coefficient. NOTE: currently unused -- the target network is
            hard-copied instead of Polyak-averaged.
        mmd_lambda: Weight of the MMD term in the loss.
        device: Torch device (or device string) for both networks. Defaults to CUDA
            when available, otherwise CPU.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64, gamma=0.99, lr=0.001,
                 batch_size=64, buffer_size=100000, tau=0.001, mmd_lambda=0.1, device=None):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.batch_size = batch_size
        self.buffer_size = buffer_size
        self.tau = tau
        self.mmd_lambda = mmd_lambda
        # Resolve the device here so the class does not depend on a module-level
        # ``device`` that is only defined in a later cell.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        # A deque with maxlen evicts the oldest transition once the buffer is full.
        self.memory = deque(maxlen=buffer_size)

        self.policy_net = QNetwork(state_dim, action_dim, hidden_dim).to(self.device)
        self.target_net = QNetwork(state_dim, action_dim, hidden_dim).to(self.device)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)

        self.update_target_net()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _default_env(env):
        """Return ``env``, or the notebook-level global ``env`` when ``env`` is None."""
        if env is not None:
            return env
        try:
            return globals()["env"]
        except KeyError:
            raise NameError("no `env` argument given and no global `env` is defined") from None

    @staticmethod
    def _reset_env(env):
        """Reset ``env`` and return only the observation.

        Accepts both the legacy gym API (``reset() -> obs``) and the gym>=0.26 /
        gymnasium API (``reset() -> (obs, info)``).
        """
        result = env.reset()
        if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
            return result[0]
        return result

    @staticmethod
    def _step_env(env, action):
        """Step ``env``; return ``(next_state, reward, terminated, episode_over)``.

        ``terminated`` is the flag stored for the TD target; ``episode_over`` ends the
        interaction loop. Under the legacy 4-tuple API both equal ``done``. Under the
        5-tuple API a truncation (e.g. time limit) ends the episode but is not treated
        as terminal, so the target still bootstraps from the next state.
        """
        result = env.step(action)
        if len(result) == 5:
            next_state, reward, terminated, truncated, _ = result
            return next_state, reward, terminated, terminated or truncated
        next_state, reward, done, _ = result
        return next_state, reward, done, done

    # ------------------------------------------------------------- agent logic

    def update_target_net(self):
        """Hard-copy the policy network's weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def select_action(self, state, epsilon):
        """Epsilon-greedy action selection.

        With probability ``epsilon`` returns a uniformly random action index in
        ``[0, action_dim)``; otherwise returns the argmax of the policy network's
        Q-values for ``state``.
        """
        if random.random() < epsilon:
            return random.randrange(self.action_dim)
        state_t = torch.as_tensor(np.asarray(state, dtype=np.float32), device=self.device).unsqueeze(0)
        with torch.no_grad():
            q_values = self.policy_net(state_t)
        return q_values.argmax().item()

    def store_transition(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` tuple to the buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def sample_batch(self):
        """Uniformly sample ``batch_size`` transitions and stack them into tensors.

        Returns:
            ``(states, actions, rewards, next_states, dones)`` on ``self.device`` with
            dtypes float32, int64, float32, float32, float32. Assuming each stored
            state is a flat vector, ``states``/``next_states`` are ``(B, state_dim)``
            and the others are ``(B,)``.
        """
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Stack into a single ndarray first: building a tensor directly from a list
        # of ndarrays is very slow (PyTorch emits a UserWarning about it).
        def to_tensor(values, dtype):
            return torch.as_tensor(np.array(values, dtype=dtype), device=self.device)

        return (
            to_tensor(states, np.float32),
            to_tensor(actions, np.int64),
            to_tensor(rewards, np.float32),
            to_tensor(next_states, np.float32),
            to_tensor(dones, np.float32),
        )

    def compute_mmd_loss(self, q_values, target_q_values):
        """Biased MMD^2 estimate between two batches under a unit-bandwidth Gaussian kernel.

        Both inputs are flattened to ``(batch, -1)``. With
        ``k(a, b) = exp(-||a - b||^2 / 2)`` the result is
        ``mean k(x, x') + mean k(y, y') - 2 * mean k(x, y)`` (diagonal terms included).
        Both inputs must have the same first dimension.
        """
        batch_size = q_values.size(0)
        q_values = q_values.view(batch_size, -1)
        target_q_values = target_q_values.view(batch_size, -1)

        # Gram matrices; squared distances follow from ||a||^2 + ||b||^2 - 2 a.b
        xx = torch.mm(q_values, q_values.t())
        yy = torch.mm(target_q_values, target_q_values.t())
        zz = torch.mm(q_values, target_q_values.t())
        rx = xx.diag().unsqueeze(0).expand_as(xx)
        ry = yy.diag().unsqueeze(0).expand_as(yy)

        dxx = rx.t() + rx - 2. * xx
        dyy = ry.t() + ry - 2. * yy
        dxy = rx.t() + ry - 2. * zz

        return torch.exp(-0.5 * dxx).mean() + torch.exp(-0.5 * dyy).mean() - 2. * torch.exp(-0.5 * dxy).mean()

    def train(self, num_episodes=1000, epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995,
              env=None, target_update_interval=10):
        """Run epsilon-greedy training episodes, learning after every environment step.

        Args:
            num_episodes: Number of episodes to run.
            epsilon_start: Initial exploration rate.
            epsilon_end: Lower bound for the exploration rate.
            epsilon_decay: Multiplicative decay applied to epsilon after each episode.
            env: Environment to interact with; defaults to the global ``env``.
            target_update_interval: The target network is hard-copied after every
                episode whose 0-based index is a multiple of this value.
        """
        env = self._default_env(env)
        epsilon = epsilon_start
        for episode in range(num_episodes):
            state = self._reset_env(env)
            total_reward = 0
            episode_over = False

            while not episode_over:
                action = self.select_action(state, epsilon)
                next_state, reward, terminated, episode_over = self._step_env(env, action)
                self.store_transition(state, action, reward, next_state, terminated)
                state = next_state
                total_reward += reward

                if len(self.memory) >= self.batch_size:
                    self.learn()

            print(f"Episode {episode + 1}, Total Reward: {total_reward}")

            epsilon = max(epsilon_end, epsilon_decay * epsilon)
            if episode % target_update_interval == 0:
                self.update_target_net()

    def test(self, num_episodes=10, epsilon_end=0.01, env=None):
        """Evaluate the current policy with a fixed small exploration rate.

        Does not learn and does not write to the replay buffer.

        Args:
            num_episodes: Number of evaluation episodes.
            epsilon_end: Exploration rate used for every action.
            env: Environment to evaluate in; defaults to the global ``env``.
        """
        env = self._default_env(env)
        for episode in range(num_episodes):
            state = self._reset_env(env)
            total_reward = 0
            episode_over = False

            while not episode_over:
                action = self.select_action(state, epsilon_end)
                state, reward, _, episode_over = self._step_env(env, action)
                total_reward += reward

            print(f"Episode {episode + 1}, Total Reward: {total_reward}")

    def learn(self):
        """Take one optimizer step on a sampled minibatch.

        The target is ``r + gamma * (1 - done) * max_a' Q_target(s', a')``, computed
        without gradients. The loss is MSE plus ``mmd_lambda`` times the MMD term.
        """
        states, actions, rewards, next_states, dones = self.sample_batch()

        # Q(s, a) for the actions actually taken, shape (B,)
        q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            max_next_q_values = self.target_net(next_states).max(1)[0]
            target_q_values = rewards + (1 - dones) * self.gamma * max_next_q_values

        mmd_loss = self.compute_mmd_loss(q_values, target_q_values)
        loss = nn.functional.mse_loss(q_values, target_q_values) + self.mmd_lambda * mmd_loss

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def save(self, filepath):
        """Save the policy network's ``state_dict`` to ``filepath``."""
        torch.save(self.policy_net.state_dict(), filepath)

    def load(self, filepath):
        """Load policy weights from ``filepath`` and sync the target network.

        ``map_location`` remaps tensors onto this agent's device, so a checkpoint
        saved on GPU also loads on a CPU-only machine.

        NOTE(review): ``torch.load`` unpickles; only load checkpoints you trust.
        """
        self.policy_net.load_state_dict(torch.load(filepath, map_location=self.device))
        self.update_target_net()

# Device configuration (CUDA when available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint produced by a previous training run
CHECKPOINT_PATH = 'mmd_dqn_lunar_lander.pth'

# Environment
env = gym.make('LunarLander-v2', render_mode='human')
try:
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    # Agent: restore trained weights, then run the evaluation episodes
    agent = MMD_DQN_Agent(state_dim, action_dim)
    agent.load(CHECKPOINT_PATH)
    agent.test()
finally:
    # Always release the render window, even if loading or evaluation raises.
    env.close()

  deprecation(
  deprecation(
  if not isinstance(terminated, (bool, np.bool8)):


Episode 1, Total Reward: 238.97379897785518
Episode 2, Total Reward: 242.79604109939947
Episode 3, Total Reward: 233.17176969274757
Episode 4, Total Reward: 223.60623937338454
Episode 5, Total Reward: 192.90729660138348
Episode 6, Total Reward: 236.17795370044072
Episode 7, Total Reward: 273.44636343669504
Episode 8, Total Reward: 247.31269793905588
Episode 9, Total Reward: 312.87615338018776
Episode 10, Total Reward: 267.93871862286846


: 