<a href="https://colab.research.google.com/github/cafeblue999/test/blob/master/simple2_train_tpu_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
train.py
(説明略)
"""

# ===== Fixed definitions: environment-switching flags =====
# USE_TPU selects the torch_xla (TPU) code paths; when False the script
# uses CUDA if it is available and the CPU otherwise.
USE_TPU = True
# USE_COLAB selects the Google Drive mount and the Colab directory layout
# instead of the local Windows paths.
USE_COLAB = True

# ------------------------------
# 必要なライブラリのインポート
# ------------------------------
import os, re, pickle, zipfile, random, numpy as np, configparser, argparse, functools
from tqdm import tqdm
import torch, torch.nn as nn, torch.optim as optim, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.distributed as dist

# torch_xla is imported only when the script targets a TPU.
if USE_TPU:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_backend

if USE_COLAB:
    # Best-effort unmount of a previous Drive mount before remounting.
    # NOTE(review): os.system() reports failure through its return value
    # rather than by raising, so this except branch is unlikely to fire.
    try:
        os.system("fusermount -u /content/drive")
    except Exception as e:
        print("Google Drive unmount failed:", e)
    # Mount Google Drive; if google.colab cannot be imported this is
    # reported and the script carries on.
    try:
        from google.colab import drive
        drive.mount('/content/drive', force_remount=True)
    except ImportError:
        print("Google Colab module not found.")

# tqdm bar layout used by the training progress bar (fixed-width counters).
bar_fmt = "{l_bar}{bar}| {n:>6d}/{total:>6d} [{elapsed}<{remaining}, {rate_fmt}]"

# ==============================
# Device setup
# ==============================
if USE_TPU:
    device = xm.xla_device()
    # Initialise the "xla" process group only if none is active yet.
    if not dist.is_initialized():
        dist.init_process_group("xla", init_method='xla://')
else:
    # Prefer CUDA when present, otherwise run on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ==============================
# ディレクトリ設定
# ==============================
# Only the data root and the training-folder name depend on the environment;
# every other path is derived from the root in the same way.
if USE_COLAB:
    BASE_DIR = "/content/drive/My Drive/sgf"
    _train_subdir = "train_sgf_KK"
else:
    BASE_DIR = r"D:\igo\simple2"
    _train_subdir = "train_sgf"

TRAIN_SGF_DIR = os.path.join(BASE_DIR, _train_subdir)
VAL_SGF_DIR = os.path.join(BASE_DIR, "test")
TEST_SGFS_ZIP = os.path.join(VAL_SGF_DIR, "test_sgfs.zip")
MODEL_OUTPUT_DIR = os.path.join(BASE_DIR, "models")
CHECKPOINT_FILE = os.path.join(BASE_DIR, "checkpoint2.pt")

# Make sure the folder that receives saved models exists.
if not os.path.exists(MODEL_OUTPUT_DIR):
    os.makedirs(MODEL_OUTPUT_DIR)

# ==============================
# DummyLogger
# ==============================
from datetime import datetime, timedelta, timezone

# Log timestamps are rendered in a fixed UTC+9 zone labelled "JST".
JST = timezone(timedelta(hours=9), 'JST')


class DummyLogger:
    """Print-based stand-in for a logger with info/warning/error methods.

    Each call prints ``"<timestamp> <LEVEL>: <message>"``.  Extra positional
    and keyword arguments are forwarded to ``print`` unchanged.
    """

    _TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def _emit(self, level, message, *args, **kwargs):
        """Print one line tagged with the current JST time and ``level``."""
        stamp = datetime.now(JST).strftime(self._TIME_FORMAT)
        print(f"{stamp} {level}: {message}", *args, **kwargs)

    def info(self, message, *args, **kwargs):
        self._emit("INFO", message, *args, **kwargs)

    def warning(self, message, *args, **kwargs):
        self._emit("WARNING", message, *args, **kwargs)

    def error(self, message, *args, **kwargs):
        self._emit("ERROR", message, *args, **kwargs)


# Separate instances for SGF preprocessing and for training messages.
sgf_logger = DummyLogger()
train_logger = DummyLogger()

# ==============================
# 設定ファイル読み込み
# ==============================
def load_config(config_path):
    """Read the training hyper-parameters from an INI file.

    Every option has a default, so a missing file, section or key falls back
    to the built-in value instead of aborting.  ``factor`` previously had no
    default, which made a missing config file fatal; it now defaults to 0.1.

    Args:
        config_path: Path of the INI file.  A path that does not exist is
            skipped by ``ConfigParser.read`` and yields the defaults.

    Returns:
        A dict with the keys ``BOARD_SIZE``, ``HISTORY_LENGTH``,
        ``NUM_CHANNELS`` (2 * history + 1), ``NUM_ACTIONS`` (board area + 1),
        ``num_residual_blocks``, ``model_channels``, ``num_epochs``,
        ``batch_size``, ``learning_rate``, ``patience``, ``factor`` and
        ``number_max_files``.

    Raises:
        SystemExit: With status 1 when an option cannot be converted to its
            numeric type (the error is logged first).
    """
    config = configparser.ConfigParser()
    config.read(config_path)
    try:
        BOARD_SIZE = config.getint("BOARD", "board_size", fallback=19)
        HISTORY_LENGTH = config.getint("DATA", "history_length", fallback=8)
        # Two stone planes per history step plus one side-to-move plane.
        NUM_CHANNELS = 2 * HISTORY_LENGTH + 1
        # One action per board point plus one extra action for a pass.
        NUM_ACTIONS = BOARD_SIZE * BOARD_SIZE + 1
        num_residual_blocks = config.getint("MODEL", "num_residual_blocks", fallback=20)
        model_channels = config.getint("MODEL", "model_channels", fallback=256)
        num_epochs = config.getint("TRAIN", "num_epochs", fallback=100)
        batch_size = config.getint("TRAIN", "batch_size", fallback=256)
        learning_rate = config.getfloat("TRAIN", "learning_rate", fallback=0.001)
        patience = config.getint("TRAIN", "patience", fallback=10)
        factor = config.getfloat("TRAIN", "factor", fallback=0.1)
        number_max_files = config.getint("TRAIN", "number_max_files", fallback=256)
    except Exception as e:
        train_logger.error(f"Error reading configuration: {e}")
        # Raise SystemExit directly: the exit() helper is only installed by
        # the site module and is not guaranteed to exist.
        raise SystemExit(1)
    return {
        "BOARD_SIZE": BOARD_SIZE,
        "HISTORY_LENGTH": HISTORY_LENGTH,
        "NUM_CHANNELS": NUM_CHANNELS,
        "NUM_ACTIONS": NUM_ACTIONS,
        "num_residual_blocks": num_residual_blocks,
        "model_channels": model_channels,
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "learning_rate": learning_rate,
        "patience": patience,
        "factor": factor,
        "number_max_files": number_max_files,
    }

CONFIG_PATH = os.path.join(BASE_DIR, "config_py.ini")
config_params = load_config(CONFIG_PATH)

# Expose every setting under a module-level name of its own.
BOARD_SIZE, HISTORY_LENGTH, NUM_CHANNELS, NUM_ACTIONS = (
    config_params[_key]
    for _key in ("BOARD_SIZE", "HISTORY_LENGTH", "NUM_CHANNELS", "NUM_ACTIONS")
)
num_residual_blocks, model_channels = (
    config_params[_key] for _key in ("num_residual_blocks", "model_channels")
)
num_epochs, batch_size, learning_rate, patience, factor, number_max_files = (
    config_params[_key]
    for _key in ("num_epochs", "batch_size", "learning_rate",
                 "patience", "factor", "number_max_files")
)

# Echo the effective configuration, one "name: value" line per setting in
# the order the settings appear in the dictionary.
train_logger.info("==== Loaded Configuration ====")
train_logger.info(f"Config file: {CONFIG_PATH}")
for _name, _value in config_params.items():
    train_logger.info(f"{_name}: {_value}")
train_logger.info("===============================")

# ==============================
# ネットワーク定義
# ==============================
# (ここでは、ResidualBlock, DilatedResidualBlock, SelfAttention, EnhancedResNetPolicyValueNetwork の定義を省略せずそのまま置く)

class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers wrapped in an identity skip connection.

    The channel count and spatial size of the input are preserved.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        # Skip connection: add the block input before the final activation.
        hidden += x
        return F.relu(hidden)

class DilatedResidualBlock(nn.Module):
    """Residual block whose first 3x3 convolution is dilated.

    The first convolution uses ``dilation`` for both its dilation and its
    padding, which keeps the spatial size unchanged; the second convolution
    is an ordinary 3x3 with padding 1.
    """

    def __init__(self, channels, dilation=2):
        super().__init__()
        self.conv1 = nn.Conv2d(
            channels, channels, kernel_size=3, padding=dilation, dilation=dilation
        )
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        # Skip connection: add the block input before the final activation.
        hidden += x
        return F.relu(hidden)

class SelfAttention(nn.Module):
    """Self-attention over all spatial positions of a feature map.

    Queries and keys are 1x1 projections to ``in_channels // 8`` channels and
    values keep ``in_channels``.  The attended features are scaled by the
    learnable ``gamma`` (initialised to zero) and added back to the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        reduced = in_channels // 8
        self.query_conv = nn.Conv2d(in_channels, reduced, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, reduced, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        n, channels, height, width = x.shape
        positions = height * width
        # (n, positions, reduced) x (n, reduced, positions) -> pairwise scores.
        queries = self.query_conv(x).view(n, -1, positions).transpose(1, 2)
        keys = self.key_conv(x).view(n, -1, positions)
        weights = self.softmax(torch.bmm(queries, keys))
        values = self.value_conv(x).view(n, -1, positions)
        attended = torch.bmm(values, weights.transpose(1, 2))
        attended = attended.view(n, channels, height, width)
        return self.gamma * attended + x

class EnhancedResNetPolicyValueNetwork(nn.Module):
    """Residual tower with self-attention feeding a policy and a value head.

    Args:
        board_size: Edge length of the square board.
        num_channels: Width of the residual tower.
        num_residual_blocks: Number of blocks in the tower; every fourth
            block (indices 0, 4, 8, ...) is a dilated block.
        in_channels: Number of input feature planes.

    ``forward`` returns ``(policy, (value, margin))``: ``policy`` holds
    log-probabilities over ``board_size**2 + 1`` actions, ``value`` is
    squashed with tanh and ``margin`` is the raw second output of the
    value head.
    """

    def __init__(self, board_size, num_channels, num_residual_blocks, in_channels):
        super().__init__()
        self.board_size = board_size
        board_area = board_size * board_size
        self.conv_input = nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1)
        self.bn_input = nn.BatchNorm2d(num_channels)
        blocks = [
            DilatedResidualBlock(num_channels, dilation=2) if i % 4 == 0
            else ResidualBlock(num_channels)
            for i in range(num_residual_blocks)
        ]
        self.residual_blocks = nn.Sequential(*blocks)
        self.attention = SelfAttention(num_channels)
        self.conv_policy = nn.Conv2d(num_channels, 2, kernel_size=1)
        self.bn_policy = nn.BatchNorm2d(2)
        self.dropout_policy = nn.Dropout(p=0.5)
        # Size the policy output from this network's own board_size (one
        # action per point plus one more) rather than from a module global,
        # so the head always matches the constructor argument.
        self.fc_policy = nn.Linear(2 * board_area, board_area + 1)
        self.conv_value = nn.Conv2d(num_channels, 1, kernel_size=1)
        self.bn_value = nn.BatchNorm2d(1)
        self.fc_value1 = nn.Linear(board_area, 64)
        self.dropout_value = nn.Dropout(p=0.5)
        self.fc_value2 = nn.Linear(64, 2)

    def forward(self, x):
        x = F.relu(self.bn_input(self.conv_input(x)))
        x = self.residual_blocks(x)
        x = self.attention(x)
        # Policy head
        p = F.relu(self.bn_policy(self.conv_policy(x)))
        p = self.dropout_policy(p)
        p = p.view(p.size(0), -1)
        p = self.fc_policy(p)
        p = F.log_softmax(p, dim=1)
        # Value head: output 0 is the tanh-squashed value, output 1 the margin.
        v = F.relu(self.bn_value(self.conv_value(x)))
        v = v.view(v.size(0), -1)
        v = F.relu(self.fc_value1(v))
        v = self.dropout_value(v)
        out = self.fc_value2(v)
        value = torch.tanh(out[:, 0])
        margin = out[:, 1]
        return p, (value, margin)

# ==============================
# SGFパーサー＆前処理関数
# ==============================
# One property: an upper-case identifier followed by a bracketed value.  A
# backslash escapes the next character, so "\]" does not end the value.
_SGF_PROPERTY_RE = re.compile(r'([A-Z]+)\[((?:\\.|[^\]\\])*)\]', re.DOTALL)


def _split_sgf_nodes(sgf_text):
    """Split SGF text on the ';' separators that lie outside property values."""
    parts = []
    current = []
    in_value = False
    escaped = False
    for ch in sgf_text:
        if in_value:
            current.append(ch)
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == ']':
                in_value = False
        elif ch == ';':
            parts.append(''.join(current))
            current = []
        else:
            current.append(ch)
            if ch == '[':
                in_value = True
    parts.append(''.join(current))
    return parts


def parse_sgf(sgf_text):
    """Parse SGF text into a root node and the list of following nodes.

    Nodes are separated by ';'.  A ';' inside a bracketed value (a comment,
    for example) and an escaped ``\\]`` no longer break a node apart; before,
    such text produced extra empty nodes.  Variation parentheses are not
    interpreted, so all nodes are returned as one flat sequence.

    Args:
        sgf_text: The SGF record as a string.

    Returns:
        ``{"root": props, "nodes": [props, ...]}`` where each ``props`` maps
        a property identifier (bytes) to a one-element list holding the raw
        value (bytes).  When an identifier occurs more than once in a node
        the last occurrence wins.

    Raises:
        ValueError: If the text contains no nodes.
    """
    sgf_text = sgf_text.strip()
    if sgf_text.startswith('(') and sgf_text.endswith(')'):
        sgf_text = sgf_text[1:-1]
    nodes = []
    for part in _split_sgf_nodes(sgf_text):
        if not part.strip():
            continue
        props = {}
        for m in _SGF_PROPERTY_RE.finditer(part):
            props[m.group(1).encode('utf-8')] = [m.group(2).encode('utf-8')]
        nodes.append(props)
    if not nodes:
        raise ValueError("No nodes found in SGF file")
    return {"root": nodes[0], "nodes": nodes[1:]}

def build_input_from_history(history, current_player, board_size, history_length):
    """Stack recent board snapshots into the network's input planes.

    For each of the last ``history_length`` positions, newest first, two
    float32 planes are produced: points equal to 1 and points equal to 2.
    Steps for which no snapshot exists contribute all-zero planes.  A final
    plane is all ones when ``current_player`` is 1 and all zeros otherwise.

    Returns:
        float32 array of shape
        ``(2 * history_length + 1, board_size, board_size)``.
    """
    shape = (board_size, board_size)
    blank = np.zeros(shape, dtype=np.float32)
    available = len(history)
    planes = []
    for steps_back in range(1, history_length + 1):
        snapshot = history[-steps_back] if steps_back <= available else blank
        planes += [
            (snapshot == 1).astype(np.float32),
            (snapshot == 2).astype(np.float32),
        ]
    if current_player == 1:
        to_move = np.ones(shape, dtype=np.float32)
    else:
        to_move = np.zeros(shape, dtype=np.float32)
    planes.append(to_move)
    return np.stack(planes, axis=0)

def apply_dihedral_transform(input_array, transform_id):
    """Rotate (and optionally mirror) the last two axes of a 3-D array.

    Ids below 4 rotate by ``transform_id`` quarter turns.  Ids of 4 and above
    first flip axis 2 and then rotate by ``transform_id - 4`` quarter turns.
    """
    if transform_id >= 4:
        source = np.flip(input_array, axis=2)
        quarter_turns = transform_id - 4
    else:
        source = input_array
        quarter_turns = transform_id
    return np.rot90(source, k=quarter_turns, axes=(1, 2))

def transform_policy(target_policy, transform_id, board_size):
    """Move a one-hot policy target to where a dihedral transform puts it.

    The largest entry of ``target_policy`` is taken as the played action.  If
    it is the last index (``board_size * board_size``) the input is returned
    as is.  Otherwise the point is marked on an empty board, the board is
    transformed, and a new one-hot vector for the resulting point is
    returned.
    """
    chosen = np.argmax(target_policy)
    if chosen == board_size * board_size:
        return target_policy
    row, col = divmod(chosen, board_size)
    # Single-plane board with the played point marked.
    marker = np.zeros((1, board_size, board_size), dtype=np.float32)
    marker[0, row, col] = 1.0
    moved = apply_dihedral_transform(marker, transform_id)[0]
    relabelled = np.zeros_like(target_policy)
    relabelled[np.argmax(moved)] = 1.0
    return relabelled

# ==============================
# 盤面クラス
# ==============================
class Board:
    """Square Go board stored as an int8 grid: 0 empty, 1 black, 2 white."""

    # Row/column offsets of the four orthogonal neighbours.
    _OFFSETS = ((-1, 0), (1, 0), (0, -1), (0, 1))

    def __init__(self, size):
        self.size = size
        self.board = np.zeros((size, size), dtype=np.int8)

    def neighbors(self, row, col):
        """Yield the on-board points orthogonally adjacent to (row, col)."""
        limit = self.size
        for d_row, d_col in self._OFFSETS:
            r, c = row + d_row, col + d_col
            if r < 0 or c < 0 or r >= limit or c >= limit:
                continue
            yield (r, c)

    def get_group(self, row, col):
        """Return ``(points, liberties)`` of the chain containing (row, col).

        ``points`` is a list of connected same-valued points and
        ``liberties`` is the set of empty points adjacent to any of them.
        """
        color = self.board[row, col]
        members = []
        liberties = set()
        seen = set()
        pending = [(row, col)]
        while pending:
            point = pending.pop()
            if point in seen:
                continue
            seen.add(point)
            members.append(point)
            for adjacent in self.neighbors(*point):
                occupant = self.board[adjacent]
                if occupant == 0:
                    liberties.add(adjacent)
                elif occupant == color and adjacent not in seen:
                    pending.append(adjacent)
        return members, liberties

    def _remove_if_captured(self, row, col):
        """Clear the chain at (row, col) when it has no liberties left."""
        members, liberties = self.get_group(row, col)
        if not liberties:
            for point in members:
                self.board[point] = 0

    def play(self, move, color):
        """Place a stone and resolve captures.

        ``color`` equal to ``'b'`` places black (1); any other value places
        white (2).  Adjacent opposing chains without liberties are removed
        first, then the played chain itself if it is left without liberties.

        Raises:
            ValueError: If the target point is already occupied.
        """
        row, col = move
        if self.board[row, col] != 0:
            raise ValueError("Illegal move: position already occupied")
        stone = 1 if color == 'b' else 2
        opponent = 3 - stone
        self.board[row, col] = stone
        for adjacent in self.neighbors(row, col):
            if self.board[adjacent] == opponent:
                self._remove_if_captured(*adjacent)
        self._remove_if_captured(row, col)

# ==============================
# Datasetクラス
# ==============================
class AlphaZeroSGFDatasetPreloaded(Dataset):
    """Dataset over an in-memory list of pre-built training samples.

    Each sample is a 4-tuple ``(input, policy, value, margin)``.  Items are
    returned as float32 tensors, with the input reshaped to
    ``(NUM_CHANNELS, BOARD_SIZE, BOARD_SIZE)``.
    """

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        features, policy, value, margin = self.samples[idx]
        board_tensor = torch.tensor(features, dtype=torch.float32)
        board_tensor = board_tensor.view(NUM_CHANNELS, BOARD_SIZE, BOARD_SIZE)
        targets = tuple(
            torch.tensor(target, dtype=torch.float32)
            for target in (policy, value, margin)
        )
        return (board_tensor, *targets)

# ==============================
# SGFからサンプル生成関数
# ==============================
def _decode_sgf_move(move_vals, sz):
    """Decode an SGF move value into ``(row, col)``, or ``None`` for a pass.

    A missing or empty value is a pass, as is ``tt`` on boards up to 19x19.

    Raises:
        ValueError: If the value is too short or lies outside the board.
            Negative offsets are rejected explicitly because they would
            otherwise wrap around when used as array indices.
    """
    if not move_vals or move_vals[0] == b"":
        return None
    text = move_vals[0].decode('utf-8')
    if text == "tt" and sz <= 19:
        return None
    if len(text) < 2:
        raise ValueError(f"malformed move {text!r}")
    col = ord(text[0]) - ord('a')
    row = ord(text[1]) - ord('a')
    if not (0 <= row < sz and 0 <= col < sz):
        raise ValueError(f"move {text!r} is outside the {sz}x{sz} board")
    return row, col


def process_sgf_to_samples_from_text(sgf_src, board_size, history_length, augment_all):
    """Turn one SGF record into a list of training samples.

    One position is emitted per node after the root, before that node's move
    is applied to the board.  Each position is emitted under all 8 dihedral
    transforms when ``augment_all`` is true, otherwise under one transform
    drawn with ``np.random.randint(0, 8)``.

    Args:
        sgf_src: SGF text.
        board_size: Expected board size; also used when SZ is absent or
            unparsable.  Records with a different SZ are skipped, because
            their samples would not fit the fixed-size dataset tensors.
        history_length: Number of past positions encoded in the input.
        augment_all: Emit all 8 symmetries instead of a random one.

    Returns:
        A list of ``(flat_input, policy, value, margin)`` tuples of float32
        arrays; empty when the text cannot be parsed or the size mismatches.
        ``value`` is 1.0 for a result starting with ``B+``, -1.0 for ``W+``
        and 0.0 otherwise.  ``margin`` is the number after the first two
        characters of the result, or 0.0 when that is not a number.
    """
    samples = []
    try:
        sgf_data = parse_sgf(sgf_src)
    except Exception as e:
        sgf_logger.error(f"Error processing SGF text: {e}")
        return samples
    root = sgf_data["root"]
    try:
        sz = int(root.get(b'SZ')[0].decode('utf-8'))
    except Exception:
        sz = board_size
    if sz != board_size:
        sgf_logger.warning(f"Skipping SGF with board size {sz} (expected {board_size})")
        return samples
    result_prop = root.get(b'RE')
    result_str = result_prop[0].decode('utf-8') if result_prop else "不明"
    if result_str.startswith("B+"):
        target_value = 1.0
    elif result_str.startswith("W+"):
        target_value = -1.0
    else:
        target_value = 0.0
    try:
        target_margin = float(result_str[2:]) if result_str[2:] else 0.0
    except Exception:
        target_margin = 0.0
    board_obj = Board(sz)
    history_boards = [board_obj.board.copy().astype(np.float32)]
    pass_index = sz * sz
    current_player = 1
    # NOTE(review): colours are assumed to alternate strictly, starting with
    # black.  A node that has no property for the expected colour is recorded
    # as a pass - confirm this is intended for handicap or setup nodes.
    for node in sgf_data["nodes"]:
        move_vals = node.get(b'B' if current_player == 1 else b'W')
        input_tensor = build_input_from_history(history_boards, current_player, sz, history_length)
        try:
            move = _decode_sgf_move(move_vals, sz)
        except Exception as e:
            sgf_logger.warning(f"Error parsing move in SGF text: {e}")
            move = None
        target_policy = np.zeros(pass_index + 1, dtype=np.float32)
        if move is None:
            target_policy[pass_index] = 1.0
        else:
            target_policy[move[0] * sz + move[1]] = 1.0
        transforms = range(8) if augment_all else [np.random.randint(0, 8)]
        for t in transforms:
            inp = apply_dihedral_transform(input_tensor, t)
            pol = transform_policy(target_policy, t, sz)
            samples.append((
                inp.flatten(),
                pol,
                np.array([target_value], dtype=np.float32),
                np.array([target_margin], dtype=np.float32)
            ))
        if move is not None:
            try:
                board_obj.play(move, 'b' if current_player == 1 else 'w')
                history_boards.append(board_obj.board.copy().astype(np.float32))
            except Exception as e:
                # An illegal move leaves the board and the history unchanged.
                sgf_logger.warning(f"Error updating board from SGF text: {e}")
        current_player = 2 if current_player == 1 else 1
    return samples

# ==============================
# データセットの保存／読み込み
# ==============================
def save_dataset(samples, output_file):
    """Pickle ``samples`` to ``output_file`` and log the destination."""
    with open(output_file, mode="wb") as sink:
        pickle.Pickler(sink).dump(samples)
    sgf_logger.info(f"Saved dataset to {output_file}")

def load_dataset(output_file):
    """Load and return the pickled samples stored in ``output_file``."""
    # NOTE: unpickling can execute arbitrary code; only load dataset files
    # that this script wrote itself.
    with open(output_file, mode="rb") as source:
        loaded = pickle.Unpickler(source).load()
    sgf_logger.info(f"Loaded dataset from {output_file}")
    return loaded

# ==============================
# Test用データセット生成（zip利用）
# ==============================
def prepare_test_dataset(sgf_dir, board_size, history_length, augment_all, output_file):
    """Build the test samples, or load them from ``output_file`` if present.

    When no cached pickle exists, the SGF files of ``sgf_dir`` whose names do
    not contain "analyzed" are archived into ``TEST_SGFS_ZIP`` (unless that
    archive already exists), every archived game is converted to samples in
    sorted name order, and the result is pickled to ``output_file``.

    NOTE(review): ``augment_all`` is accepted but not used; games are always
    processed with ``augment_all=False``.

    Returns:
        The list of samples.
    """
    def _wanted(name):
        return name.endswith('.sgf') and "analyzed" not in name.lower()

    if os.path.exists(output_file):
        sgf_logger.info(f"Test dataset pickle {output_file} already exists. Loading it directly...")
        return load_dataset(output_file)

    if os.path.exists(TEST_SGFS_ZIP):
        sgf_logger.info(f"Zip archive {TEST_SGFS_ZIP} already exists. Loading from it...")
    else:
        sgf_logger.info(f"Creating zip archive {TEST_SGFS_ZIP} from SGF files in {sgf_dir} ...")
        sources = [os.path.join(sgf_dir, entry) for entry in os.listdir(sgf_dir) if _wanted(entry)]
        with zipfile.ZipFile(TEST_SGFS_ZIP, 'w') as archive:
            for path in sources:
                archive.write(path, arcname=os.path.basename(path))
        sgf_logger.info(f"Zip archive created: {TEST_SGFS_ZIP}")

    all_samples = []
    with zipfile.ZipFile(TEST_SGFS_ZIP, 'r') as archive:
        members = sorted(entry for entry in archive.namelist() if _wanted(entry))
        sgf_logger.info(f"TEST: Total SGF files in zip to process: {len(members)}")
        for member in tqdm(members, desc="Processing TEST SGF files"):
            try:
                text = archive.read(member).decode('utf-8')
                all_samples.extend(
                    process_sgf_to_samples_from_text(text, board_size, history_length, augment_all=False)
                )
            except Exception as e:
                # One unreadable game must not abort the whole dataset build.
                sgf_logger.error(f"Error processing {member} from zip: {e}")
    save_dataset(all_samples, output_file)
    sgf_logger.info(f"TEST: Saved test dataset (total samples: {len(all_samples)}) to {output_file}")
    return all_samples

# ==============================
# グローバル変数：未処理のSGFファイルリスト
# ==============================
# SGF files of the current shuffled pass that have not been handed out yet.
remaining_sgf_files = []


def prepare_train_dataset_cycle(sgf_dir, board_size, history_length, augment_all, max_files):
    """Build the samples for one training cycle.

    Takes up to ``max_files`` SGF files (names not containing "analyzed")
    from a shuffled, non-repeating queue of the files in ``sgf_dir``.  The
    queue is rebuilt and reshuffled whenever it is empty.  Samples from all
    selected files are pooled and shuffled before being returned.
    """
    global remaining_sgf_files
    if not remaining_sgf_files:
        candidates = [
            os.path.join(sgf_dir, entry)
            for entry in os.listdir(sgf_dir)
            if entry.endswith('.sgf') and "analyzed" not in entry.lower()
        ]
        random.shuffle(candidates)
        remaining_sgf_files = candidates
        sgf_logger.info("Regenerated the random order of all SGF files.")

    if len(remaining_sgf_files) >= max_files:
        selected_files = remaining_sgf_files[:max_files]
        remaining_sgf_files = remaining_sgf_files[max_files:]
        sgf_logger.info(f"Selected {len(selected_files)} SGF files.")
    else:
        # Fewer files left than requested: use them all and empty the queue.
        selected_files, remaining_sgf_files = remaining_sgf_files, []
        sgf_logger.info(f"Remaining SGF files less than max_files ({max_files}). Processing {len(selected_files)} files.")

    all_samples = []
    for path in selected_files:
        try:
            with open(path, "r", encoding="utf-8") as handle:
                text = handle.read()
            all_samples.extend(
                process_sgf_to_samples_from_text(text, board_size, history_length, augment_all)
            )
        except Exception as e:
            sgf_logger.error(f"Error processing file {path}: {e}")
    random.shuffle(all_samples)
    sgf_logger.info(f"Training dataset cycle created. Total samples: {len(all_samples)}")
    return all_samples

def load_training_dataset(sgf_dir, board_size, history_length, augment_all, max_files):
    """Build one cycle of training samples and wrap them in a dataset.

    Generates samples from up to ``max_files`` SGF files of ``sgf_dir`` in a
    single call and returns them as an ``AlphaZeroSGFDatasetPreloaded``.
    """
    return AlphaZeroSGFDatasetPreloaded(
        prepare_train_dataset_cycle(sgf_dir, board_size, history_length, augment_all, max_files)
    )

# ==============================
# 訓練ループ用関数（1エポック分）
# ==============================
def train_one_iteration(model, train_loader, optimizer, device):
    """Train ``model`` for one pass over ``train_loader``.

    The loss is the policy cross-entropy plus the value and margin MSE
    terms, weighted 0.1 and 0.0001.  The mean policy accuracy is printed
    every 100 batches, and per-term averages are logged at the end (the
    value and margin averages are reported already weighted).

    Args:
        model: Network returning ``(log_policy, (value, margin))``.
        train_loader: Iterable of ``(boards, policies, values, margins)``.
        optimizer: Optimizer stepping the model parameters.
        device: Device the batches are moved to.

    Returns:
        The mean total loss per batch, or 0.0 when the loader yields no
        batches (previously this case raised ZeroDivisionError).
    """
    model.train()
    total_loss = 0.0
    total_policy_loss = 0.0
    total_value_loss = 0.0
    total_margin_loss = 0.0
    num_batches = 0
    overall_correct = 0
    overall_samples = 0
    value_loss_coefficient = 0.1
    margin_loss_coefficient = 0.0001
    print_interval = 100
    accumulated_accuracy = 0.0
    group_batches = 0
    for boards, target_policies, target_values, target_margins in tqdm(train_loader, desc="Training", bar_format=bar_fmt):
        boards = boards.to(device)
        target_policies = target_policies.to(device)
        target_values = target_values.to(device)
        target_margins = target_margins.to(device)
        optimizer.zero_grad()
        pred_policy, (pred_value, pred_margin) = model(boards)
        # pred_policy holds log-probabilities, so this is the cross-entropy
        # against the target distribution, averaged over the batch.
        policy_loss = -torch.sum(target_policies * pred_policy) / boards.size(0)
        value_loss = F.mse_loss(pred_value.view(-1), target_values.view(-1))
        margin_loss = F.mse_loss(pred_margin.view(-1), target_margins.view(-1))
        loss = policy_loss + value_loss_coefficient * value_loss + margin_loss_coefficient * margin_loss
        loss.backward()
        optimizer.step()
        if USE_TPU:
            xm.mark_step()
        total_loss += loss.item()
        total_policy_loss += policy_loss.item()
        total_value_loss += value_loss.item()
        total_margin_loss += margin_loss.item()
        num_batches += 1
        batch_pred = pred_policy.argmax(dim=1)
        batch_target = target_policies.argmax(dim=1)
        # Compare once and reuse the mask for both accuracy figures.
        matches = batch_pred == batch_target
        batch_accuracy = matches.float().mean().item()
        overall_correct += matches.sum().item()
        overall_samples += boards.size(0)
        accumulated_accuracy += batch_accuracy
        group_batches += 1
        if num_batches % print_interval == 0:
            avg_accuracy = accumulated_accuracy / group_batches
            start_batch = num_batches - group_batches + 1
            end_batch = num_batches
            print(f"Batch {start_batch:4d}～{end_batch:4d} policy accuracy average: {avg_accuracy:6.4f}")
            accumulated_accuracy = 0.0
            group_batches = 0
        del boards, target_policies, target_values, target_margins
    if num_batches == 0:
        # Nothing was trained, so there are no averages to report.
        train_logger.warning("Training loader produced no batches; skipping loss summary.")
        return 0.0
    if group_batches > 0:
        avg_accuracy = accumulated_accuracy / group_batches
        print(f"Other ({group_batches} batch) policy accuracy average: {avg_accuracy:6.4f}")
    if overall_samples > 0:
        overall_accuracy = overall_correct / overall_samples
        print(f"Overall policy accuracy of the latest model state in this training loop: {overall_accuracy:6.4f}")
    else:
        overall_accuracy = 0.0
    avg_loss = total_loss / num_batches
    avg_policy_loss = total_policy_loss / num_batches
    avg_value_loss = value_loss_coefficient * total_value_loss / num_batches
    avg_margin_loss = margin_loss_coefficient * total_margin_loss / num_batches
    train_logger.info(f"Training iteration total average loss: {avg_loss:.5f}")
    train_logger.info(f"Training iteration average policy loss: {avg_policy_loss:.5f}")
    train_logger.info(f"Training iteration average value loss: {avg_value_loss:.5f}")
    train_logger.info(f"Training iteration average margin loss: {avg_margin_loss:.5f}")
    train_logger.info(f"Training iteration overall policy accuracy: {overall_accuracy:.5f}")
    return avg_loss

# ==============================
# チェックポイント保存＆復元
# ==============================
def save_checkpoint(model, optimizer, epoch, best_val_loss, epochs_no_improve, best_policy_accuracy, checkpoint_file, device):
    """Write model, optimizer and progress counters to ``checkpoint_file``.

    ``device`` is accepted but not used here.
    """
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        best_val_loss=best_val_loss,
        epochs_no_improve=epochs_no_improve,
        best_policy_accuracy=best_policy_accuracy,
    )
    torch.save(state, checkpoint_file)
    train_logger.info(f"Checkpoint saved at epoch {epoch} to {checkpoint_file}")

def recursive_to(data, device):
    """Move every tensor nested in dicts and lists of ``data`` to ``device``.

    Dicts and lists are rebuilt as new containers; any other value
    (including tuples) is returned unchanged.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        moved = {}
        for key, value in data.items():
            moved[key] = recursive_to(value, device)
        return moved
    if isinstance(data, list):
        return list(recursive_to(element, device) for element in data)
    return data

def load_checkpoint(model, optimizer, checkpoint_file, device):
    """Restore model and optimizer state from ``checkpoint_file`` if it exists.

    The file is read onto the CPU first and its tensors are then moved to
    ``device`` before being loaded.

    Returns:
        ``(epoch, best_policy_accuracy)`` from the checkpoint, or
        ``(0, 0.0)`` when no checkpoint file is found.
    """
    if not os.path.exists(checkpoint_file):
        train_logger.info("No checkpoint found. Starting from scratch.")
        return 0, 0.0

    checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu"))
    weights = {}
    for name, tensor in checkpoint['model_state_dict'].items():
        weights[name] = tensor.to(device)
    model.load_state_dict(weights)
    optimizer.load_state_dict(recursive_to(checkpoint['optimizer_state_dict'], device))
    epoch = checkpoint['epoch']
    # Checkpoints without a stored accuracy fall back to 0.0.
    best_policy_accuracy = checkpoint.get('best_policy_accuracy', 0.0)
    train_logger.info(f"Checkpoint loaded from {checkpoint_file} at epoch {epoch}")
    return epoch, best_policy_accuracy

# ==============================
# メイン用：訓練用データセットのキャッシュ読み込み
# ==============================
def load_training_dataset(sgf_dir, board_size, history_length, augment_all, max_files):
    """Return a preloaded dataset holding one cycle of training samples.

    This definition rebinds the function of the same name defined earlier in
    the file; both build the dataset in the same way.
    """
    cycle_samples = prepare_train_dataset_cycle(
        sgf_dir, board_size, history_length, augment_all, max_files
    )
    return AlphaZeroSGFDatasetPreloaded(cycle_samples)

# ==============================
# TPU分散環境で動作するメイン処理
# ==============================
def _mp_fn(rank):
    """Run the endless train / validate / checkpoint loop on one process.

    Sets up the device, loads or builds the test set, creates the model,
    optimizer and LR scheduler, resumes from ``CHECKPOINT_FILE`` and from
    the best ``model_<accuracy>.pt`` file in ``MODEL_OUTPUT_DIR``, then
    repeats: one training pass, validation, model saving, scheduler step
    and checkpoint.  The loop has no exit condition.

    Args:
        rank: Process index supplied by the spawner; not used here.

    NOTE(review): ``validate_model``, ``save_best_model`` and
    ``save_inference_model`` are not defined in the visible part of this
    file - confirm they are provided elsewhere.
    """
    if USE_TPU:
        if not dist.is_initialized():
            dist.init_process_group("xla", init_method='xla://')
        device = xm.xla_device()
        train_logger.info("Running on TPU device: {}".format(device))
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        train_logger.info("Running on device: {}".format(device))
    test_dataset_pickle = os.path.join(VAL_SGF_DIR, "test_dataset.pkl")
    test_samples = prepare_test_dataset(VAL_SGF_DIR, BOARD_SIZE, HISTORY_LENGTH, True, test_dataset_pickle)
    test_dataset = AlphaZeroSGFDatasetPreloaded(test_samples)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    model = EnhancedResNetPolicyValueNetwork(
        board_size=BOARD_SIZE,
        num_channels=model_channels,
        num_residual_blocks=num_residual_blocks,
        in_channels=NUM_CHANNELS
    ).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # The scheduler is stepped with the validation policy accuracy, which
    # should increase, so it must watch for a plateau in 'max' mode.  In
    # 'min' mode every improvement counted as a bad epoch.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=patience, factor=factor)
    start_epoch, best_policy_accuracy = load_checkpoint(model, optimizer, CHECKPOINT_FILE, device)
    for f in os.listdir(MODEL_OUTPUT_DIR):
        if not (f.startswith("model_") and f.endswith(".pt")):
            continue
        try:
            acc = float(f[len("model_"):-len(".pt")])
        except ValueError:
            # The name does not encode an accuracy; not a best-model file.
            continue
        if acc <= best_policy_accuracy:
            continue
        best_model_file = os.path.join(MODEL_OUTPUT_DIR, f)
        try:
            model.load_state_dict(torch.load(best_model_file, map_location=device))
        except Exception as e:
            train_logger.warning("Could not restore model from {}: {}".format(best_model_file, e))
            continue
        # Raise the bar only after the weights were actually loaded.
        best_policy_accuracy = acc
        train_logger.info("Restored best model with policy accuracy {:.5f} from {}".format(acc, best_model_file))
    train_logger.info("Initial best_policy_accuracy: {:.5f}".format(best_policy_accuracy))
    current_lr = optimizer.param_groups[0]['lr']
    train_logger.info("Current learning rate : {:.8f}".format(current_lr))
    training_dataset = load_training_dataset(TRAIN_SGF_DIR, BOARD_SIZE, HISTORY_LENGTH, augment_all=True, max_files=number_max_files)
    epoch = start_epoch
    while True:
        train_loader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True)
        train_one_iteration(model, train_loader, optimizer, device)
        epoch += 1
        policy_accuracy = validate_model(model, test_loader, device)
        if policy_accuracy > best_policy_accuracy:
            best_policy_accuracy = save_best_model(model, policy_accuracy, device, best_policy_accuracy)
        else:
            save_inference_model(model, device, "inference2_model_tmp.pt")
        # epoch was already incremented above, so log it as is; this matches
        # the epoch number written to the checkpoint below.
        lr_before = optimizer.param_groups[0]['lr']
        train_logger.info("Epoch {} - Before scheduler.step(): lr = {:.8f}".format(epoch, lr_before))
        scheduler.step(policy_accuracy)
        lr_after = optimizer.param_groups[0]['lr']
        train_logger.info("Epoch {} - After scheduler.step(): lr = {:.8f}".format(epoch, lr_after))
        # These two fields are not tracked by this loop; placeholders are
        # passed to satisfy save_checkpoint's parameters.
        dummy_best_val_loss = 0.0
        dummy_epochs_no_improve = 0
        save_checkpoint(model, optimizer, epoch, dummy_best_val_loss, dummy_epochs_no_improve, best_policy_accuracy, CHECKPOINT_FILE, device)
        train_logger.info("Iteration completed. Restarting next iteration...\n")

if __name__ == "__main__":
    # Command-line options; arguments the parser does not recognise are
    # collected in `unknown` instead of causing an error.
    # NOTE(review): args.config and args.checkpoint are only checked here;
    # the module-level CONFIG_PATH and CHECKPOINT_FILE are what is used.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        help="Path to configuration file",
        default=os.path.join(BASE_DIR, "config_py.ini"),
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        help="Path to checkpoint file",
        default=CHECKPOINT_FILE,
    )
    args, unknown = parser.parse_known_args()
    config_missing = not os.path.exists(args.config)
    if config_missing:
        sgf_logger.warning("Config file not found. Using default hyperparameters.")
    train_logger.info("=== Starting Training and Validation Loop ===")
    if not USE_TPU:
        _mp_fn(0)
    else:
        import torch_xla.distributed.xla_multiprocessing as xmp
        # A single process is spawned.
        nprocs = 1
        xmp.spawn(_mp_fn, args=(), nprocs=nprocs)


Mounted at /content/drive
2025-04-13 15:04:07 INFO: ==== Loaded Configuration ====
2025-04-13 15:04:07 INFO: Config file: /content/drive/My Drive/sgf/config_py.ini
2025-04-13 15:04:07 INFO: BOARD_SIZE: 19
2025-04-13 15:04:07 INFO: HISTORY_LENGTH: 8
2025-04-13 15:04:07 INFO: NUM_CHANNELS: 17
2025-04-13 15:04:07 INFO: NUM_ACTIONS: 362
2025-04-13 15:04:07 INFO: num_residual_blocks: 20
2025-04-13 15:04:07 INFO: model_channels: 256
2025-04-13 15:04:07 INFO: num_epochs: 1000
2025-04-13 15:04:07 INFO: batch_size: 256
2025-04-13 15:04:07 INFO: learning_rate: 0.001
2025-04-13 15:04:07 INFO: patience: 1
2025-04-13 15:04:07 INFO: factor: 0.8
2025-04-13 15:04:07 INFO: number_max_files: 256
2025-04-13 15:04:07 INFO: === Starting Training and Validation Loop ===
2025-04-13 15:04:07 INFO: Running on TPU device: xla:0
2025-04-13 15:04:07 INFO: Test dataset pickle /content/drive/My Drive/sgf/test/test_dataset.pkl already exists. Loading it directly...
2025-04-13 15:04:23 INFO: Loaded dataset from /cont

Training:   4%|▍         |    100/  2521 [01:49<35:15,  1.14it/s]

Batch    1～ 100 policy accuracy average: 0.0083


Training:   8%|▊         |    200/  2521 [03:19<34:51,  1.11it/s]

Batch  101～ 200 policy accuracy average: 0.0112


Training:  12%|█▏        |    300/  2521 [04:49<33:22,  1.11it/s]

Batch  201～ 300 policy accuracy average: 0.0194


Training:  16%|█▌        |    400/  2521 [06:19<31:53,  1.11it/s]

Batch  301～ 400 policy accuracy average: 0.0316


Training:  20%|█▉        |    500/  2521 [07:50<30:32,  1.10it/s]

Batch  401～ 500 policy accuracy average: 0.0446


Training:  24%|██▍       |    600/  2521 [09:20<29:03,  1.10it/s]

Batch  501～ 600 policy accuracy average: 0.0443


Training:  28%|██▊       |    700/  2521 [10:51<27:25,  1.11it/s]

Batch  601～ 700 policy accuracy average: 0.0468


Training:  32%|███▏      |    800/  2521 [12:21<25:25,  1.13it/s]

Batch  701～ 800 policy accuracy average: 0.0507


Training:  36%|███▌      |    900/  2521 [13:50<23:53,  1.13it/s]

Batch  801～ 900 policy accuracy average: 0.0619


Training:  40%|███▉      |   1000/  2521 [15:20<22:10,  1.14it/s]

Batch  901～1000 policy accuracy average: 0.0583


Training:  44%|████▎     |   1100/  2521 [16:47<20:46,  1.14it/s]

Batch 1001～1100 policy accuracy average: 0.0663


Training:  48%|████▊     |   1200/  2521 [18:16<19:54,  1.11it/s]

Batch 1101～1200 policy accuracy average: 0.0672


Training:  50%|█████     |   1268/  2521 [19:18<18:53,  1.11it/s]