In [None]:
import os

# Root directory that find_latest_checkpoint() walks recursively; it is used to
# locate both the Rust wheel and the model checkpoint.
BASE_DIR = "/kaggle/input/threes-binary"

def find_latest_checkpoint(base_dir, filename="threes_pro_latest.pth"):
    """Return the path of the highest-versioned copy of ``filename`` under ``base_dir``.

    The tree is walked recursively. For every directory containing ``filename``
    the version is the deepest purely numeric directory name on the path
    *between* ``base_dir`` and the file (e.g. ``.../default/<version>/``).
    Numeric components of ``base_dir`` itself or of its ancestors are ignored,
    so a base path such as ``/data/2024/models`` does not leak ``2024`` in as a
    version. A copy that is not inside any numeric directory counts as version
    ``-1``: it is returned only when no versioned copy exists.

    Args:
        base_dir: Directory to search.
        filename: Exact file name to look for.

    Returns:
        Path to the selected file, built as ``os.path.join(root, filename)``.

    Raises:
        FileNotFoundError: If ``filename`` does not occur anywhere under ``base_dir``.
    """
    candidates = []

    # Walk the whole directory tree.
    for root, _dirs, files in os.walk(base_dir):
        if filename not in files:
            continue
        # Only the part of the path below base_dir can carry a version.
        rel_parts = os.path.relpath(root, base_dir).split(os.sep)
        version = next((int(p) for p in reversed(rel_parts) if p.isdigit()), -1)
        candidates.append((version, os.path.join(root, filename)))

    if not candidates:
        raise FileNotFoundError(f"Kh√¥ng t√¨m th·∫•y {filename}")

    # Highest version wins; on ties max() keeps the first one encountered.
    return max(candidates, key=lambda item: item[0])[1]

In [None]:
RUST_BINARY = find_latest_checkpoint(BASE_DIR, "threes_game-0.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl")
!pip install --force-reinstall "{RUST_BINARY}"

In [None]:
import threes_rs

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import math
import os
import time
from collections import deque
import threes_rs  # Rust binding

# --- CONFIGURATION ---
NUM_ENVS = 128          # Number of environments stepped in parallel
BATCH_SIZE = 256        # Training batch size
GAMMA = 0.99            # Discount factor
LR = 1e-4               # Learning rate
TARGET_UPDATE = 5000    # Target-network sync interval, measured in environment steps
MEMORY_SIZE = 500000    # Replay buffer capacity
EPS_DECAY = 500000      # Epsilon decay constant (larger = slower decay = longer exploration)

CHECKPOINT_SAVE = "/kaggle/working/threes_pro_latest.pth"
try:
    CHECKPOINT_FILE = find_latest_checkpoint(BASE_DIR)
except FileNotFoundError:
    # No published checkpoint yet (first run). Point at the save location so
    # load_checkpoint() can either resume from this session's file or, if it
    # does not exist, start from scratch instead of crashing here.
    CHECKPOINT_FILE = CHECKPOINT_SAVE

print("Using checkpoint:", CHECKPOINT_FILE)

# Maps the 13 tile values (1 ... 3072) to indices 0..12
TILE_TYPES = [1, 2, 3, 6, 12, 24, 48, 96, 192, 384, 768, 1536, 3072]
TILE_MAP = {v: i for i, v in enumerate(TILE_TYPES)}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- HELPER CLASSES ---

class DataAugmenter:
    """Expands transitions 8x using the symmetries of the 4x4 board.

    Boards are flattened row-major (index = row * 4 + col). Each entry of
    ``BOARD_PERMS`` is a gather index, applied as ``new_board = old_board[perm]``.
    The entry at the same position in ``ACTION_PERMS`` maps an action taken on
    the original board to the equivalent action on the transformed board, using
    the encoding 0:Up, 1:Down, 2:Left, 3:Right (assumed to match the encoding
    used by the environment).

    The two lists must stay aligned index-by-index: a board permutation paired
    with the wrong action permutation produces transitions whose action label
    does not match the board change.
    """
    def __init__(self):
        # Pre-computed permutations for 4x4 board (flattened to 16).
        # new[r][c] is read from the old cell given in the trailing comment.
        self.BOARD_PERMS = [
            np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), # Id            old[r][c]
            np.array([12, 8, 4, 0, 13, 9, 5, 1, 14, 10, 6, 2, 15, 11, 7, 3]), # Rot90         old[3-c][r]
            np.array([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]), # Rot180        old[3-r][3-c]
            np.array([3, 7, 11, 15, 2, 6, 10, 14, 1, 5, 9, 13, 0, 4, 8, 12]), # Rot270        old[c][3-r]
            np.array([3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12]), # FlipX         old[r][3-c]
            np.array([0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]), # FlipMainDiag  old[c][r]
            np.array([12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]), # FlipY         old[3-r][c]
            np.array([15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0]), # FlipAntiDiag  old[3-c][3-r]
        ]

        # Action Mappings: 0:Up, 1:Down, 2:Left, 3:Right
        # ACTION_PERMS[i][a] is the action on the transformed board that is
        # equivalent to action `a` on the original board under BOARD_PERMS[i].
        # The diagonal entries are written out explicitly: composing
        # Rot90 after FlipX yields the ANTI-diagonal flip (not the main one),
        # so deriving them by composition is easy to pair with the wrong board.
        self.ACTION_PERMS = [
            np.array([0, 1, 2, 3]), # Id
            np.array([3, 2, 0, 1]), # Rot90
            np.array([1, 0, 3, 2]), # Rot180
            np.array([2, 3, 1, 0]), # Rot270
            np.array([0, 1, 3, 2]), # FlipX:        Left <-> Right
            np.array([2, 3, 0, 1]), # FlipMainDiag: Up <-> Left, Down <-> Right
            np.array([1, 0, 2, 3]), # FlipY:        Up <-> Down
            np.array([3, 2, 1, 0]), # FlipAntiDiag: Up <-> Right, Down <-> Left
        ]

    def augment_batch(self, states, actions, rewards, next_states, dones):
        """Return the batch together with its 7 symmetric copies.

        Args:
            states: Array of shape (N, 16 + H); the first 16 columns are the
                flattened board, the remaining columns are the hint encoding.
            actions: Sequence of N integer actions in 0..3.
            rewards: Sequence of N rewards.
            next_states: Array shaped like ``states``.
            dones: Sequence of N done flags.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)`` with
            N * 8 rows each. Rows are grouped by symmetry: the first N rows are
            the untouched originals, followed by N rows per further symmetry.
        """
        # Split states into the board part and the hint part. Hints do not
        # depend on board orientation, so they are copied through unchanged.
        hints = states[:, 16:]
        next_hints = next_states[:, 16:]
        boards = states[:, :16]
        next_boards = next_states[:, :16]

        aug_states, aug_actions, aug_next_states = [], [], []
        actions_np = np.array(actions)

        # Generate the 8 variations (identity first).
        for perm_b, perm_a in zip(self.BOARD_PERMS, self.ACTION_PERMS):
            # 1. Transform boards (column gather on every row at once).
            b_sym = boards[:, perm_b]
            nb_sym = next_boards[:, perm_b]

            # 2. Relabel actions for the transformed orientation.
            a_sym = perm_a[actions_np]

            # 3. Reconstruct full states (board + hint).
            aug_states.append(np.concatenate([b_sym, hints], axis=1))
            aug_actions.append(a_sym)
            aug_next_states.append(np.concatenate([nb_sym, next_hints], axis=1))

        # Rewards and dones are invariant under symmetry; np.tile repeats them
        # in the same [sym0 rows, sym1 rows, ...] order as the concatenations.
        return (
            np.concatenate(aug_states, axis=0),
            np.concatenate(aug_actions, axis=0),
            np.tile(rewards, 8),
            np.concatenate(aug_next_states, axis=0),
            np.tile(dones, 8)
        )

# --- NEURAL NETWORK (CNN ARCHITECTURE) ---
class DQN(nn.Module):
    """Q-network: embeds the 16 board cells, runs two 2x2 convolutions over the
    resulting 4x4 grid, appends the 13 hint features and outputs 4 action values.
    """
    def __init__(self):
        super().__init__()
        # One 64-dim vector per cell code; codes are clamped to 0..15 in forward().
        self.embedding = nn.Embedding(16, 64)

        # 4x4 -> 3x3 -> 2x2 spatially, so the flattened output has 256 * 2 * 2 = 1024 features.
        conv_layers = [
            nn.Conv2d(64, 128, kernel_size=2, stride=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=2, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.conv_net = nn.Sequential(*conv_layers)

        # Head over conv features (1024) concatenated with the hint vector (13).
        head_layers = [
            nn.Linear(1024 + 13, 512),
            nn.ReLU(),
            nn.Linear(512, 4),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, state):
        """Map a (Batch, 29) state tensor to (Batch, 4) action values.

        Columns 0..15 hold the board cell codes, columns 16..28 the hint features.
        """
        cell_codes, hint_vec = state[:, :16], state[:, 16:]

        # Integer codes, forced into the embedding's valid range.
        tokens = torch.clamp(cell_codes.long(), min=0, max=15)

        # (Batch, 16, 64) -> (Batch, 4, 4, 64) -> channels-first (Batch, 64, 4, 4)
        grid = self.embedding(tokens).view(-1, 4, 4, 64).permute(0, 3, 1, 2)

        features = self.conv_net(grid)
        return self.fc(torch.cat((features, hint_vec.float()), dim=1))

def prepare_state_batch(boards_flat, hint_sets):
    """Build the network input array from per-environment boards and hint sets.

    Args:
        boards_flat: One flattened board per environment.
        hint_sets: One iterable of hinted tile values per environment.

    Returns:
        float32 array whose rows are the board values followed by a 13-wide
        multi-hot vector, with a 1.0 at TILE_MAP[tile] for every hinted tile
        that is present in TILE_MAP (unknown tiles are ignored).
    """
    board_arr = np.array(boards_flat, dtype=np.float32)
    hint_arr = np.zeros((len(boards_flat), 13), dtype=np.float32)

    for row, tiles in enumerate(hint_sets):
        for tile in tiles:
            col = TILE_MAP.get(tile)
            if col is not None:
                hint_arr[row, col] = 1.0

    return np.concatenate([board_arr, hint_arr], axis=1)

# --- REPLAY BUFFER ---
class ReplayBuffer:
    """Bounded FIFO store of (state, action, reward, next_state, done) tuples.

    Once ``capacity`` items are held, each new item evicts the oldest one.
    """
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push_batch(self, states, actions, rewards, next_states, dones):
        """Append one transition per row of ``states``, in row order."""
        self.buffer.extend(
            (states[k], actions[k], rewards[k], next_states[k], dones[k])
            for k in range(len(states))
        )

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            Tensors on ``device``: states (float32), actions (long),
            rewards (float32), next states (float32), dones (float32).
        """
        def as_tensor(values, dtype):
            return torch.tensor(values, dtype=dtype, device=device)

        picked = random.sample(self.buffer, batch_size)
        obs, act, rew, nxt, fin = zip(*picked)
        # States are stacked through NumPy first so the rows form one 2-D array.
        return (as_tensor(np.array(obs), torch.float32),
                as_tensor(act, torch.long),
                as_tensor(rew, torch.float32),
                as_tensor(np.array(nxt), torch.float32),
                as_tensor(fin, torch.float32))

def save_checkpoint(episode, steps, policy_net, optimizer, memory):
    """Write the training state to CHECKPOINT_SAVE and print a confirmation.

    Stored keys: 'episode', 'steps', 'model_state' and 'optimizer_state'.
    ``memory`` is accepted but not persisted: the replay buffer is large, so
    storing it (e.g. as ``list(memory.buffer)`` under a 'memory' key) is left out.
    """
    payload = dict(
        episode=episode,
        steps=steps,
        model_state=policy_net.state_dict(),
        optimizer_state=optimizer.state_dict(),
    )
    torch.save(payload, CHECKPOINT_SAVE)
    print(f"üíæ Checkpoint saved at ep {episode}")

def load_checkpoint(policy_net, optimizer, target_net, memory):
    """Restore training state from CHECKPOINT_FILE, if that file exists.

    Loads the policy weights and optimizer state, then copies the policy
    weights into the target network. ``memory`` is accepted but not used.

    Returns:
        ``(episode, steps)`` as stored in the checkpoint, or ``(0, 0)`` when
        CHECKPOINT_FILE does not exist.
    """
    if not os.path.exists(CHECKPOINT_FILE):
        return 0, 0

    print("üîÑ Loading checkpoint...")
    # NOTE(review): weights_only=False performs a full unpickle, which can run
    # arbitrary code. Only load checkpoint files from a trusted source.
    saved = torch.load(CHECKPOINT_FILE, map_location=device, weights_only=False)
    policy_net.load_state_dict(saved['model_state'])
    optimizer.load_state_dict(saved['optimizer_state'])
    target_net.load_state_dict(policy_net.state_dict())
    return saved['episode'], saved['steps']

# --- MAIN TRAINING LOOP ---
if __name__ == "__main__":
    # Vectorized Double-DQN training loop. Runs until interrupted; progress is
    # persisted through save_checkpoint() every 5000 finished episodes.

    # Online (policy) network and target network; the target starts as an exact copy.
    policy_net = DQN().to(device)
    target_net = DQN().to(device)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    optimizer = optim.Adam(policy_net.parameters(), lr=LR)
    memory = ReplayBuffer(MEMORY_SIZE)
    criterion = nn.HuberLoss()
    augmenter = DataAugmenter()

    # Resume counters and weights if a checkpoint exists ((0, 0) otherwise).
    start_episode, total_steps = load_checkpoint(policy_net, optimizer, target_net, memory)
    # Steps restored from the checkpoint were not executed in this session, so
    # they are excluded from the FPS figure (otherwise FPS is inflated after a resume).
    session_start_steps = total_steps

    # Vectorized environment (Rust binding)
    vec_env = threes_rs.ThreesVecEnv(NUM_ENVS)

    # Rolling metrics over the last WINDOW_SIZE entries
    WINDOW_SIZE = 100
    metrics = {
        'r': deque(maxlen=WINDOW_SIZE), # Episode rewards
        'm': deque(maxlen=WINDOW_SIZE), # Episode move counts
        'l': deque(maxlen=WINDOW_SIZE), # Loss per optimizer step
        'q': deque(maxlen=WINDOW_SIZE), # Mean Q-value per optimizer step
        't': deque(maxlen=WINDOW_SIZE)  # Max board value per episode
    }

    # Per-environment accumulators for the episode currently in progress
    env_curr = {
        'moves': np.zeros(NUM_ENVS),
        'rewards': np.zeros(NUM_ENVS),
        'max_tile': np.zeros(NUM_ENVS)
    }

    print(f"üöÄ Starting VECTORIZED training with {NUM_ENVS} envs...")
    start_time = time.time()

    # Initial observation
    raw_boards = vec_env.reset()
    hint_sets = vec_env.get_hint_sets()
    states = prepare_state_batch(raw_boards, hint_sets)
    episode_count = start_episode

    while True:
        # One loop iteration advances every environment by one step.
        total_steps += NUM_ENVS
        # Exponential epsilon decay from 1.0 towards a floor of 0.05.
        eps = 0.05 + (0.95) * math.exp(-1. * total_steps / EPS_DECAY)

        # 1. SELECT ACTION (Batch)
        valid_masks = vec_env.valid_moves_batch()
        valid_masks_t = torch.tensor(valid_masks, device=device, dtype=torch.bool)

        with torch.no_grad():
            states_t = torch.tensor(states, dtype=torch.float32, device=device)
            q_values_all = policy_net(states_t)
            q_values_all[~valid_masks_t] = -float('inf') # Action masking: invalid moves can never win argmax

            # NOTE(review): the explore/exploit draw is made once per iteration,
            # so all environments explore or all exploit together. Confirm this
            # is intended rather than a per-environment draw.
            if random.random() > eps:
                actions = q_values_all.argmax(dim=1).cpu().numpy().tolist()
            else:
                # Uniform sample among each environment's valid moves
                # (falls back to action 0 when no move is valid).
                actions = []
                for mask in valid_masks:
                    valid_indices = [j for j, v in enumerate(mask) if v]
                    actions.append(random.choice(valid_indices) if valid_indices else 0)

        # 2. STEP (parallel Rust env)
        next_boards, rewards, dones = vec_env.step(actions)
        next_hint_sets = vec_env.get_hint_sets()
        next_states = prepare_state_batch(next_boards, next_hint_sets)

        # 3. AUGMENT & STORE: every real transition is stored with its 7 symmetric copies
        aug_batch = augmenter.augment_batch(states, actions, rewards, next_states, dones)
        memory.push_batch(*aug_batch)

        states = next_states

        # 4. LOGGING & METRICS
        for i in range(NUM_ENVS):
            env_curr['moves'][i] += 1
            env_curr['rewards'][i] += rewards[i]
            env_curr['max_tile'][i] = max(env_curr['max_tile'][i], max(next_boards[i]))

            if dones[i]:
                episode_count += 1
                metrics['r'].append(env_curr['rewards'][i])
                metrics['m'].append(env_curr['moves'][i])
                metrics['t'].append(env_curr['max_tile'][i])

                # Reset the accumulators for this environment's next episode
                env_curr['moves'][i] = 0
                env_curr['rewards'][i] = 0
                env_curr['max_tile'][i] = 0

                # Console log every 100 finished episodes
                if episode_count % 100 == 0:
                    elapsed = time.time() - start_time
                    # Throughput of this session only.
                    fps = (total_steps - session_start_steps) / max(elapsed, 1e-9)
                    print(f"Ep {episode_count:6d} | Steps: {total_steps:9d} | "
                          f"Avg R: {np.mean(metrics['r']):6.2f} | "
                          f"Moves: {np.mean(metrics['m']):4.1f} | "
                          f"MaxTile: {np.mean(metrics['t']):4.0f} | "
                          f"Loss: {np.mean(metrics['l']) if metrics['l'] else 0:.4f} | "
                          f"Q: {np.mean(metrics['q']) if metrics['q'] else 0:.2f} | "
                          f"FPS: {fps:.1f}")

                # Save checkpoint every 5000 finished episodes
                if episode_count % 5000 == 0:
                    save_checkpoint(episode_count, total_steps, policy_net, optimizer, memory)

        # 5. TRAIN (Intensive)
        if len(memory.buffer) >= BATCH_SIZE:
            # Train 8 times per step to match data generation speed
            for _ in range(8):
                transitions = memory.sample(BATCH_SIZE)
                b_s, b_a, b_r, b_ns, b_d = transitions

                # Q(s, a) of the actions that were actually taken
                q_eval = policy_net(b_s).gather(1, b_a.unsqueeze(1)).squeeze(1)

                with torch.no_grad():
                    # Double DQN: the policy net picks the next action, the target net scores it.
                    next_actions = policy_net(b_ns).argmax(1).unsqueeze(1)
                    next_q = target_net(b_ns).gather(1, next_actions).squeeze(1)
                    next_q[b_d.bool()] = 0.0   # no bootstrap beyond terminal transitions
                    expected_q = b_r + (GAMMA * next_q)

                loss = criterion(q_eval, expected_q)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                metrics['l'].append(loss.item())
                metrics['q'].append(q_eval.mean().item())

        # Sync the target net whenever total_steps has just crossed a multiple of
        # TARGET_UPDATE. total_steps advances by NUM_ENVS per iteration, so an
        # exact `== 0` test would usually miss the multiple.
        if total_steps % TARGET_UPDATE < NUM_ENVS:
            target_net.load_state_dict(policy_net.state_dict())
