In [None]:
# Fix for data/iterable_dataloader.py
# We need to align training with inference by using inferred opponent actions instead of ground truth.

from generals.core.action import Action

def infer_opponent_action(prev_obs_dict, curr_obs_dict, grid_size):
    """
    Infer opponent action based on visible changes, matching SOTAAgent logic.
    """
    # Extract masks
    prev_opp = prev_obs_dict['opponent_cells']
    cur_opp = curr_obs_dict['opponent_cells']
    
    # Visibility mask (derived from fog and structures in fog)
    # Note: In the dataloader, 'fog_cells' is 1 for fog, 0 for visible.
    fog = curr_obs_dict['fog_cells']
    sif = curr_obs_dict['structures_in_fog']
    visible = (fog == 0) & (sif == 0)
    
    # Newly visible opponent-owned cells (candidate destination).
    # Must be visible NOW, owned by opponent NOW, and NOT owned by opponent BEFORE.
    new_opp = visible & (cur_opp.astype(bool)) & (~prev_opp.astype(bool))
    new_positions = np.argwhere(new_opp)
    
    if new_positions.shape[0] != 1:
        return Action(to_pass=True)

    dest_r, dest_c = (int(new_positions[0][0]), int(new_positions[0][1]))

    # Candidate sources: adjacent cells that were opponent-owned previously.
    for direction, (dr, dc) in enumerate([(-1, 0), (1, 0), (0, -1), (0, 1)]):
        src_r, src_c = dest_r - dr, dest_c - dc
        if src_r < 0 or src_c < 0 or src_r >= grid_size or src_c >= grid_size:
            continue
        if prev_opp[src_r, src_c] == 1:
            # Found a plausible source
            return Action(to_pass=False, row=src_r, col=src_c, direction=direction, to_split=False)

    return Action(to_pass=True)


In [1]:
import argparse
import yaml
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from pathlib import Path
from tqdm import tqdm
import numpy as np
from datetime import datetime
import json

from data.dataloader import GeneralsReplayDataset
from data.iterable_dataloader import create_iterable_dataloader
from agents.network import SOTANetwork

try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False
    print("Warning: wandb not installed. Logging to wandb is disabled.")

  from pkg_resources import resource_stream, resource_exists


pygame 2.6.1 (SDL 2.28.4, Python 3.11.14)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [4]:
config_path = '/root/oyx_fork/configs/config_base.yaml'
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)


data_config = config['data']
train_config = config['training']

train_replays = int(data_config['max_replays'] * data_config['train_split'])
val_replays = data_config['max_replays'] - train_replays

train_loader = create_iterable_dataloader(
    data_dir=data_config['data_dir'],
    batch_size=train_config['batch_size'],
    grid_size=data_config['grid_size'],
    num_workers=train_config['num_workers'],
    max_replays=train_replays,
    min_stars=data_config['min_stars'],
    max_turns=data_config['max_turns'],
)

In [None]:
sample_batch = next(iter(train_loader))

In [20]:
# Unpack the batch
obs_batch, memory_batch, actions_batch, _ = sample_batch
obs, memory, actions = obs_batch[0], memory_batch[0], actions_batch[0]
print(f"Batch dimensions:")
print(f"  Obs: {obs.shape}")
print(f"  Memory: {memory.shape}")
print(f"  Actions: {actions.shape}")

Batch dimensions:
  Obs: torch.Size([15, 24, 24])
  Memory: torch.Size([18, 24, 24])
  Actions: torch.Size([5])


In [16]:
for i, data in enumerate(train_loader):
    if i == 0:
        print(data['obs']['screen'].shape)

TypeError: list indices must be integers or slices, not str