# RL implementation of ATLAS

In [1]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
import pandas as pd

In [12]:
class ATLASGame:
    def __init__(self, places):
        self.places = sorted([p.upper() for p in places])
        self.reset()

    def reset(self):
        self.used_places = set()
        self.current_place = None
        self.current_player = 0  # 0 = agent, 1 = opponent
        self.game_history = []
        return self

    def get_valid_moves(self):
        if self.current_place is None:
            return [p for p in self.places if p not in self.used_places]
        last_letter = self.current_place[-1].upper()  # Added .upper()
        return [p for p in self.places if p not in self.used_places and p[0].upper() == last_letter]  # Added .upper()

    def make_move(self, place):
        place = place.strip().upper()
        if place not in self.places:
            return False, f"{place} is invalid"
        if place in self.used_places:
            return False, f"{place} already used"

        valid_moves = self.get_valid_moves()
        if self.current_place is not None and place not in valid_moves:
            return False, f"{place} not valid now"

        self.used_places.add(place)
        self.current_place = place
        self.game_history.append((self.current_player, place))

        # Switch players
        self.current_player = 1 - self.current_player
        return True, "Move accepted"

    def is_terminal(self):
        return len(self.get_valid_moves()) == 0

In [13]:
class AtlasStateEncoder:
    def __init__(self, game):
        self.game = game
        self.n_places = len(game.places)
        self.letter_to_idx = {chr(i+65): i for i in range(26)}  # A-Z

    def encode(self):
        state = []

        # 1. Next required letter (one-hot)
        letter_vec = np.zeros(26, dtype=np.float32)
        if self.game.current_place:
            last_letter = self.game.current_place[-1].upper()
            letter_vec[self.letter_to_idx[last_letter]] = 1.0
        state.extend(letter_vec)

        # 2. Places used (binary)
        used_vec = np.array([1.0 if p in self.game.used_places else 0.0 for p in self.game.places], dtype=np.float32)
        state.extend(used_vec)

        # 3. Valid moves mask (binary)
        valid_moves = self.game.get_valid_moves()
        valid_vec = np.array([1.0 if p in valid_moves else 0.0 for p in self.game.places], dtype=np.float32)
        state.extend(valid_vec)

        return np.array(state, dtype=np.float32)


In [14]:
class AtlasEnv:
    def __init__(self, places):
        self.game = ATLASGame(places)
        self.encoder = AtlasStateEncoder(self.game)
        self.n_actions = len(places)

    def reset(self):
        self.game.reset()
        return self.encoder.encode()

    def step(self, action_idx):
        done = False
        reward = -0.01  # step penalty

        place = self.game.places[action_idx]
        success, msg = self.game.make_move(place)
        if not success:
            return self.encoder.encode(), -1.0, True

        if self.game.is_terminal():
            return self.encoder.encode(), 1.0, True

        # Opponent random move
        if self.game.current_player == 1:
            opp_moves = self.game.get_valid_moves()
            if opp_moves:
                opp_place = random.choice(opp_moves)
                self.game.make_move(opp_place)

        if self.game.is_terminal():
            return self.encoder.encode(), -0.5, True

        return self.encoder.encode(), reward, done


In [15]:
class DQN(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim)
        )

    def forward(self, x):
        return self.net(x)


In [16]:
class ReplayBuffer:
    def __init__(self, capacity=100_000):
        self.buffer = deque(maxlen=capacity)

    def push(self, s, a, r, s2, d):
        self.buffer.append((s, a, r, s2, d))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        s, a, r, s2, d = zip(*batch)
        return (
            torch.from_numpy(np.array(s)).float(),
            torch.from_numpy(np.array(a)).long(),
            torch.from_numpy(np.array(r)).float(),
            torch.from_numpy(np.array(s2)).float(),
            torch.from_numpy(np.array(d)).float()
        )

    def __len__(self):
        return len(self.buffer)


In [17]:
def save_checkpoint(path, model, target, optimizer, episode, epsilon):
    torch.save({
        "model": model.state_dict(),
        "target": target.state_dict(),
        "optimizer": optimizer.state_dict(),
        "episode": episode,
        "epsilon": epsilon,
    }, path)


def load_checkpoint(path, model, target, optimizer, device):
    checkpoint = torch.load(path, map_location=device)

    model.load_state_dict(checkpoint["model"])
    target.load_state_dict(checkpoint["target"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    start_episode = checkpoint["episode"] + 1
    epsilon = checkpoint["epsilon"]

    return start_episode, epsilon


In [19]:
def train():
    CKPT_PATH = "checkpoints/atlas_dqn_checkpoint.pt"
    places = pd.read_csv("../../data/countries.csv")["Country"].values.tolist()

    env = AtlasEnv(places)
    state_dim = len(env.reset())
    action_dim = env.n_actions

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    model = DQN(state_dim, action_dim).to(device)
    target = DQN(state_dim, action_dim).to(device)
    target.load_state_dict(model.state_dict())

    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    buffer = ReplayBuffer()
    gamma = 0.99
    batch_size = 64

    epsilon = 1.0
    epsilon_min = 0.005
    epsilon_decay = 0.9995
    start_episode = 1

    if os.path.exists(CKPT_PATH):
        start_episode, epsilon = load_checkpoint(CKPT_PATH, model, target, optimizer, device)
        print(f"Resuming from episode {start_episode}, epsilon={epsilon:.3f}")

    for episode in range(start_episode, start_episode + 501):
        state = env.reset()
        done = False
        total_reward = 0

        while not done:
            if random.random() < epsilon:
                action = random.randrange(action_dim)
            else:
                with torch.no_grad():
                    q = model(torch.tensor(state, dtype=torch.float32, device=device))
                    valid_mask = torch.tensor(state[-action_dim:], device=device)
                    q[valid_mask == 0] = -1e9
                    action = torch.argmax(q).item()

            next_state, reward, done = env.step(action)
            buffer.push(state, action, reward, next_state, done)

            state = next_state
            total_reward += reward

            if len(buffer) >= batch_size:
                s, a, r, s2, d = buffer.sample(batch_size)
                s = s.to(device, dtype=torch.float32)
                a = a.to(device)
                r = r.to(device)
                s2 = s2.to(device, dtype=torch.float32)
                d = d.to(device)

                q_vals = model(s).gather(1, a.unsqueeze(1)).squeeze(1)
                with torch.no_grad():
                    q2 = target(s2).max(1)[0]
                    target_q = r + gamma * q2 * (1 - d.float())

                loss = F.smooth_l1_loss(q_vals, target_q)
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

        epsilon = max(epsilon_min, epsilon * epsilon_decay)

        if episode % 100 == 0:
            target.load_state_dict(model.state_dict())
            print(f"Episode {episode}, reward {total_reward:.2f}, epsilon {epsilon:.3f}")

        if episode % 500 == 0:
            target.load_state_dict(model.state_dict())
            save_checkpoint(CKPT_PATH, model, target, optimizer, episode, epsilon)
            print(f"Checkpoint saved at episode {episode}")

    return model


In [20]:
trained_model = train()

Using device: cuda
Resuming from episode 19001, epsilon=0.005
Episode 19100, reward -0.58, epsilon 0.005
Episode 19200, reward 0.87, epsilon 0.005
Episode 19300, reward -0.59, epsilon 0.005
Episode 19400, reward -0.64, epsilon 0.005
Episode 19500, reward 0.89, epsilon 0.005
Checkpoint saved at episode 19500


In [30]:
def play_with_model(env, model, human_first=True):
    state = env.reset()
    done = False
    current_player = 0 if human_first else 1  # 0 = human, 1 = model
    last_played = None

    while not done:
        print(f"\nCurrent player: {'Human' if current_player == 0 else 'AI'}")
        
        if last_played:
            last_char = last_played[-1].upper()
            print(f"Previous place: {last_played}")
            print(f"Your place must START with: {last_char}")
        else:
            print("First move - any place is valid")
        
        if current_player == 0:
            # Human turn
            valid_moves = env.game.get_valid_moves()
            if last_played:
                print(f"\nValid moves starting with '{last_played[-1].upper()}': {valid_moves[:10]}" + (" ..." if len(valid_moves) > 10 else ""))
            else:
                print(f"\nValid moves: {valid_moves[:10]}" + (" ..." if len(valid_moves) > 10 else ""))
            
            move = input("Your move: ").strip().upper()
            if move not in valid_moves:
                print(f"Illegal move! '{move}' is not valid.")
                print("Game over - you lose!")
                done = True
                break
            
            action = env.game.places.index(move)
            next_state, reward, done = env.step(action)
            print(f"✓ You played: {move}")
            last_played = move
        else:
            # AI turn
            device = next(model.parameters()).device
            state_tensor = torch.tensor(state, dtype=torch.float32, device=device)
            
            # Get valid moves from the game
            valid_moves = env.game.get_valid_moves()
            
            with torch.no_grad():
                q = model(state_tensor)
                valid_mask = torch.tensor(state[-env.n_actions:], dtype=torch.bool, device=device)
                
                # Only consider valid actions
                q_masked = q.clone()
                q_masked[~valid_mask] = float('-inf')
                action = torch.argmax(q_masked).item()
            
            ai_place = env.game.places[action]
            
            # Check if AI's move is actually valid
            if ai_place not in valid_moves:
                print(f"✓ AI plays: {ai_place}")
                print(f"Illegal move! AI played '{ai_place}' but it should start with '{last_played[-1].upper()}'")
                print("Game over - Human wins!")
                done = True
                break
            
            next_state, reward, done = env.step(action)
            print(f"✓ AI plays: {ai_place}")
            last_played = ai_place
            
            if done and reward < 0:
                print("AI made an illegal move - you win!")
                break

        state = next_state
        current_player = 1 - current_player

    print("\nGame over!")
    if last_played:
        print(f"Final place: {last_played}")

In [31]:
places = pd.read_csv("../../data/countries.csv")["Country"].values.tolist()
env = AtlasEnv(places)
play_with_model(env, trained_model, human_first=True)


Current player: Human
First move - any place is valid

Valid moves: ['AFGHANISTAN', 'ALBANIA', 'ALGERIA', 'ANDORRA', 'ANGOLA', 'ANTIGUA AND BARBUDA', 'ARGENTINA', 'ARMENIA', 'AUSTRALIA', 'AUSTRIA'] ...
✓ You played: AUSTRIA

Current player: AI
Previous place: AUSTRIA
Your place must START with: A
✓ AI plays: NORWAY

Current player: Human
Previous place: NORWAY
Your place must START with: Y

Valid moves starting with 'Y': ['NAMIBIA', 'NAURU', 'NEPAL', 'NETHERLANDS', 'NEW ZEALAND', 'NICARAGUA', 'NIGER', 'NIGERIA', 'NORTH KOREA', 'NORTH MACEDONIA']
Illegal move! 'OKAY' is not valid.
Game over - you lose!

Game over!
Final place: NORWAY
