# MAO Reinforcement Learning – **Classical DQN**
This notebook contains a fully working reference implementation of the MAO environment, two rule‑based agents, and a Deep Q‑Network learner based on PyTorch. Run all cells from top to bottom to train the agent; the hyper‑parameters can be adjusted in the final (training) cell.

In [2]:
import random, math, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
from collections import deque, namedtuple
# ---- Card helpers --------------------------------------------------
SUITS = "♠♥♦♣"
RANKS = "A23456789TJQK"

# Cards are numbered 0..51 rank-major: index = 4 * rank_position + suit_position.
IDX2CARD = {
    4 * rank_pos + suit_pos: rank_char + suit_char
    for rank_pos, rank_char in enumerate(RANKS)
    for suit_pos, suit_char in enumerate(SUITS)
}
CARD2IDX = {label: index for index, label in IDX2CARD.items()}


def rank(idx):
    """Return the rank character (one of ``RANKS``) of card index ``idx``."""
    return RANKS[idx // 4]


def suit(idx):
    """Return the suit character (one of ``SUITS``) of card index ``idx``."""
    return SUITS[idx % 4]


In [3]:
# ---- MAO Environment ---------------------------------------------
class MaoEnv:
    """Simplified MAO card game exposed through a gym-like reset/step API.

    Every player is dealt seven cards and one card starts the discard pile.
    On a turn the current player either plays a card that matches the rank or
    the suit of the top discard, or takes a penalty draw.  Playing an Ace
    makes the next player miss a turn, playing a King reverses the direction
    of play.  The episode ends when a player empties their hand.

    Actions 0..51 are card indices; action 52 means "draw".
    """

    DECK_SIZE = 52
    HAND_SIZE = 7
    DRAW_ACTION = 52

    def __init__(self, n_players=3, penalty_draw=1):
        """Create the table and deal the first game.

        Args:
            n_players: number of seats; at most 7 so that every hand and the
                initial discard can be dealt from a single deck.
            penalty_draw: cards drawn on a 'draw' or an illegal move.

        Raises:
            ValueError: if ``n_players`` cannot be dealt from one deck.
        """
        max_players = (self.DECK_SIZE - 1) // self.HAND_SIZE
        if not 1 <= n_players <= max_players:
            raise ValueError(
                f"n_players must be between 1 and {max_players}, got {n_players}")
        self.n_players, self.penalty_draw = n_players, penalty_draw
        # Must match _obs(): own-hand one-hot + top-card one-hot
        # + opponents' hand sizes + turn one-hot + direction flag.
        self.state_size = 2 * self.DECK_SIZE + (n_players - 1) + n_players + 1
        self.action_space = 52  # card actions only; 'draw' is the extra action index 52
        self.reset()
    # ...............................................................
    def reset(self, seed=None):
        """Shuffle, deal a new game and return the first observation.

        Passing ``seed`` reseeds the global ``random`` module.
        """
        if seed is not None: random.seed(seed)
        deck = list(range(self.DECK_SIZE)); random.shuffle(deck)
        k = self.HAND_SIZE
        self.hands = [deque(deck[i*k:(i+1)*k]) for i in range(self.n_players)]
        self.draw_pile = deque(deck[k*self.n_players:-1])
        self.discard   = [deck[-1]]
        self.current_player = 0
        self.direction = 1
        self.skip_next = False
        return self._obs()
    # ...............................................................
    def _obs(self):
        """Build the observation vector (length ``state_size``) for the current player."""
        pid = self.current_player
        hand_vec = np.zeros(52); hand_vec[list(self.hands[pid])] = 1
        top_vec  = np.zeros(52); top_vec[self.discard[-1]] = 1
        # Opponents' hand sizes, scaled by 1/20.
        counts   = np.array([len(h)/20 for i,h in enumerate(self.hands) if i!=pid])
        turn     = np.eye(self.n_players)[pid]
        dirflag  = np.array([int(self.direction==-1)])
        return np.concatenate([hand_vec, top_vec, counts, turn, dirflag])
    # ...............................................................
    def legal_actions(self, pid):
        """Return the cards in ``pid``'s hand that match the top discard by rank or suit."""
        top = self.discard[-1]
        return [c for c in self.hands[pid] if rank(c)==rank(top) or suit(c)==suit(top)]
    # ...............................................................
    def step(self, action):
        """Apply ``action`` for the current player.

        Returns:
            ``(obs, reward, done, info)``.  Reward is -1 for a draw or an
            illegal move, +1 for a legal play and 10 for emptying the hand.
            ``info["stalemate"]`` is set when the episode is ended because no
            card can be drawn and no player can play.
        """
        reward, done, info = 0.0, False, {}
        pid = self.current_player
        if action == self.DRAW_ACTION or action not in self.legal_actions(pid):
            self._penalize(pid)
            reward = -1.0
            # Without this check a blocked table would loop forever.
            if self._is_deadlocked():
                done = True
                info["stalemate"] = True
        else:
            self.hands[pid].remove(action)
            self.discard.append(action)
            reward = 1.0
            r = rank(action)
            if r == 'A': self.skip_next = True
            elif r == 'K': self.direction *= -1
            if len(self.hands[pid]) == 0:
                done, reward = True, 10.0
        if not done:
            self._advance()
        return self._obs(), reward, done, info
    # ...............................................................
    def _penalize(self, pid, n=None):
        """Make ``pid`` draw ``n`` cards (``penalty_draw`` when ``n`` is None).

        When the draw pile runs out, the discard pile (except its top card)
        is shuffled back in; drawing stops early if no cards are left at all.
        """
        if n is None:  # an explicit n=0 must mean "draw nothing"
            n = self.penalty_draw
        for _ in range(n):
            if not self.draw_pile:
                self._recycle_discard()
            if not self.draw_pile:
                break
            self.hands[pid].append(self.draw_pile.popleft())
    # ...............................................................
    def _recycle_discard(self):
        """Shuffle every discarded card except the top one back into the draw pile."""
        *buried, top = self.discard
        if not buried:
            return
        random.shuffle(buried)
        self.draw_pile.extend(buried)
        self.discard = [top]
    # ...............................................................
    def _is_deadlocked(self):
        """Return True when nothing can be drawn and no player holds a playable card."""
        if self.draw_pile or len(self.discard) > 1:
            return False
        return not any(self.legal_actions(p) for p in range(self.n_players))
    # ...............................................................
    def _advance(self):
        """Move the turn to the next seat, honouring direction and a pending skip."""
        step = self.direction
        nxt  = (self.current_player + step) % self.n_players
        if self.skip_next:
            nxt = (nxt + step) % self.n_players
            self.skip_next = False
        self.current_player = nxt


In [4]:
# ---- Replay Buffer ------------------------------------------------
Transition = namedtuple('T', ('s','a','r','s2','d'))


class ReplayBuffer:
    """Bounded FIFO memory of ``Transition`` records with uniform sampling."""

    def __init__(self, cap=100_000):
        self.cap = cap
        # A deque with maxlen drops the oldest entry once capacity is reached.
        self.buf = deque(maxlen=cap)

    def push(self, *args):
        """Store one transition given as ``(s, a, r, s2, d)``."""
        record = Transition(*args)
        self.buf.append(record)

    def sample(self, k):
        """Return ``k`` distinct stored transitions chosen at random."""
        picked = random.sample(self.buf, k)
        return picked

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buf)


In [5]:
# ---- DQN Agent ----------------------------------------------------
class QNet(nn.Module):
    """Fully connected network with two 256-unit ReLU hidden layers.

    Maps a state vector of size ``state`` to ``actions`` Q-values.
    """

    def __init__(self, state, actions):
        super().__init__()
        width = 256
        # Layer names and creation order are kept as-is so parameter names
        # and seeded initialisation do not change.
        self.fc1 = nn.Linear(state, width)
        self.fc2 = nn.Linear(width, width)
        self.out = nn.Linear(width, actions)

    def forward(self, x):
        """Return the Q-values for input ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.out(hidden)

class DQNAgent:
    """Epsilon-greedy DQN learner with a replay buffer and a target network."""

    DRAW_ACTION = 52  # action index used when no card can be played

    def __init__(self, state_size, action_size, lr=1e-4, gamma=0.99,
                 eps_start=1.0, eps_end=0.1, eps_decay=5e-5, target_sync=500):
        """Build the online/target networks, optimiser and replay memory.

        Args:
            state_size: length of the observation vector.
            action_size: number of Q-values the network outputs.
            lr: Adam learning rate.
            gamma: discount factor.
            eps_start, eps_end, eps_decay: epsilon is lowered by ``eps_decay``
                on every ``act`` call until it reaches ``eps_end``.
            target_sync: the target network is refreshed when the number of
                ``act`` calls is a multiple of this value.
        """
        self.q = QNet(state_size, action_size)
        self.target = QNet(state_size, action_size)
        self.target.load_state_dict(self.q.state_dict())
        self.opt = torch.optim.Adam(self.q.parameters(), lr=lr)
        self.mem = ReplayBuffer()
        self.gamma, self.eps, self.eps_end, self.eps_decay = gamma, eps_start, eps_end, eps_decay
        self.target_sync = target_sync
        self.steps = 0
    # ...............................................................
    def act(self, obs, legal):
        """Choose an action for ``obs`` restricted to the ``legal`` card indices.

        Returns the draw action when ``legal`` is empty; otherwise a random
        legal card with probability epsilon, else the legal card with the
        highest Q-value.
        """
        self.steps += 1
        self.eps = max(self.eps_end, self.eps - self.eps_decay)
        if random.random() < self.eps:
            return random.choice(legal) if legal else self.DRAW_ACTION
        if not legal:
            # Nothing playable: drawing is the only valid move.  Taking an
            # argmax over a fully masked vector would pick an illegal card.
            return self.DRAW_ACTION
        with torch.no_grad():
            qvals = self.q(torch.tensor(obs, dtype=torch.float32))
        # Keep only the legal entries; everything else can never win the argmax.
        idx = torch.tensor(legal, dtype=torch.long)
        masked = torch.full_like(qvals, float("-inf"))
        masked[idx] = qvals[idx]
        return int(torch.argmax(masked).item())
    # ...............................................................
    def learn(self, batch=256):
        """Run one gradient step on a sampled mini-batch.

        Does nothing until the replay memory holds at least ``batch``
        transitions.
        """
        if len(self.mem) < batch: return
        tr = Transition(*zip(*self.mem.sample(batch)))
        s  = torch.tensor(np.stack(tr.s), dtype=torch.float32)
        a  = torch.tensor(tr.a, dtype=torch.long).unsqueeze(1)  # gather needs int64 indices
        r  = torch.tensor(tr.r, dtype=torch.float32)
        s2 = torch.tensor(np.stack(tr.s2), dtype=torch.float32)
        d  = torch.tensor(tr.d, dtype=torch.float32)
        # squeeze(1) keeps the batch dimension even when batch == 1.
        q_pred = self.q(s).gather(1, a).squeeze(1)
        with torch.no_grad():
            q_next = self.target(s2).max(1)[0]
            y = r + self.gamma * q_next * (1-d)
        loss = F.mse_loss(q_pred, y)
        self.opt.zero_grad(); loss.backward(); self.opt.step()
        if self.steps % self.target_sync == 0:
            self.target.load_state_dict(self.q.state_dict())


In [None]:
# ---- Training Loop -----------------------------------------------
env = MaoEnv()
# One Q-value per card action plus one for the extra 'draw' action (index 52).
agent = DQNAgent(env.state_size, env.action_space + 1)
EPISODES = 5_000
# Safety cap on turns per episode, so a game in which nobody can shed cards
# cannot stall training forever.  A capped episode is simply abandoned; its
# last transition is stored with done=False because it is not a terminal state.
MAX_TURNS = 2_000
for ep in range(EPISODES):
    obs = env.reset()
    for _ in range(MAX_TURNS):
        legal = env.legal_actions(env.current_player)
        action = agent.act(obs, legal)
        next_obs, reward, done, _info = env.step(action)
        agent.mem.push(obs, action, reward, next_obs, done)
        agent.learn()
        obs = next_obs
        if done:
            break
    if ep % 250 == 0: print(f"Episode {ep}, ε={agent.eps:.2f}")
print("Training finished.")

Episode 0, ε=1.00
