In [5]:
import os
import re
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import pyscreeze
import uuid
from PIL import Image

In [6]:
# Piece classification model

DATA_DIR = "./data/piece-images"
TORCH_MODEL_PATH = "./models/piece_classifier.pth"
IMAGE_SIZE = 64

# Class index -> label name: the sorted entries of DATA_DIR whose names contain
# no "." (so dotted names such as "x.png" or ".DS_Store" are left out).
idx_to_label = [
    entry
    for entry in sorted(os.listdir(DATA_DIR))
    if "." not in entry
]

class Net(nn.Module):
    """CNN that scores an image against each piece label.

    Three Conv -> BatchNorm -> ReLU -> MaxPool stages followed by a single
    linear layer. Sub-module names and their order are unchanged, so the
    state_dict keys (``conv1.0.weight`` ... ``fc1.bias``) stay the same.

    Args:
        num_classes: Number of output scores. ``None`` (the default) uses
            ``len(idx_to_label)``, looked up when the model is constructed.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        if num_classes is None:
            # Default: one output per entry of the module-level idx_to_label.
            num_classes = len(idx_to_label)
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, 3),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        # For a 64x64 input the spatial size goes 64 -> 62 -> 31 -> 29 -> 14
        # -> 12 -> 6, so the flattened feature size is 256 * 6 * 6 = 9216.
        self.fc1 = nn.Linear(9216, num_classes)

    def forward(self, x):
        """Return one raw score per class for each image in the batch ``x``."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # Keep the batch dimension, flatten everything else for the linear layer.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        return x

def get_model():
    """Build a ``Net`` and load the weights stored at ``TORCH_MODEL_PATH``.

    Returns:
        The model, switched to evaluation mode.
    """
    model = Net()
    # Net() is created on the CPU, so map the checkpoint tensors to the CPU as
    # well; without map_location a checkpoint saved from a CUDA device cannot
    # be loaded on a machine that has no GPU.
    # NOTE(review): torch.load unpickles the file - only load checkpoints you
    # trust (consider weights_only=True if the installed torch supports it).
    state_dict = torch.load(TORCH_MODEL_PATH, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model

def predict(images: list[Image.Image]) -> list[str]:
    """Classify each image and return its label from ``idx_to_label``.

    Every image is resized to IMAGE_SIZE x IMAGE_SIZE, converted to a tensor
    and normalised before being passed to the model returned by
    ``get_model()``.

    Args:
        images: Images to classify. The normalisation below has three
            channels, so each image must be 3-channel (e.g. RGB).

    Returns:
        One label per input image, in input order. An empty input gives ``[]``.
    """
    if not images:
        # torch.stack() raises on an empty sequence; nothing to classify.
        return []

    transform = transforms.Compose(
        [
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            # NOTE(review): presumably the per-channel mean / std of the
            # training images - confirm against the training code.
            transforms.Normalize(
                (0.57665089, 0.5822121, 0.54763596),
                (0.18085433, 0.21391266, 0.23309964)
            )
        ]
    )
    inputs = torch.stack([transform(image) for image in images])

    model = get_model()
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(inputs)
        pred = torch.argmax(output, dim=1)
    return [idx_to_label[index] for index in pred]


In [7]:
# Game screen capture

# Size (in cells) of the board that is about to be captured
WIDTH, HEIGHT = 9, 9

# Position and size of the 9x9 board
LEFT, TOP = 812, 515
SIZE = 1314
# Edge length of one cell of the 9x9 board
UNIT = SIZE // 9

def capture(width = WIDTH, height = HEIGHT) -> Image.Image:
    """Take a screenshot and crop out the board area.

    The crop is centred on the SIZE x SIZE square whose top-left corner is
    (LEFT, TOP). The board is drawn with cells of UNIT pixels and at most 9
    columns; when it has more than 9 rows and more rows than columns, the
    height is capped at 9 cells instead and the width is scaled to match.

    Args:
        width: Number of board columns.
        height: Number of board rows.

    Returns:
        The cropped screenshot.
    """
    board_width = UNIT * min(width, 9)
    board_height = int(board_width / width * height)
    height_limited = height > 9 and height > width
    if height_limited:
        board_height = UNIT * min(height, 9)
        board_width = int(board_height / height * width)

    # Centre of the 9x9 board area on screen.
    center_x = LEFT + SIZE / 2
    center_y = TOP + SIZE / 2
    half_width = board_width / 2
    half_height = board_height / 2
    box = (
        int(center_x - half_width),
        int(center_y - half_height),
        int(center_x + half_width),
        int(center_y + half_height),
    )
    screenshot = pyscreeze.screenshot()
    return screenshot.crop(box)

In [10]:
# Save piece images

# Root directory; each cropped piece is saved in a sub-directory named after
# its predicted label.
PIECE_IMAGE_DIR = "./data/sample-piece-images"

# Board layout template: one character per cell, one text line per board row.
# Cells written as " " or "_" are skipped when pieces are cropped; every other
# character (here ".") marks a cell whose image is cropped and saved.
# NOTE(review): " " and "_" are treated identically by the cropping code;
# presumably they stand for different kinds of empty cell - confirm.
colors = """
_... ..._
.........
.........
.........
.........
.........
.........
.........
 ... ...
""".strip()

# (x, y) cell coordinates of every template cell that is neither " " nor "_".
positions: list[tuple[int, int]] = []
for y, line in enumerate(colors.splitlines()):
    for x, token in enumerate(line):
        if token in (" ", "_"):
            continue
        positions.append((x, y))

board_image = capture(WIDTH, HEIGHT)
# Edge length of one cell, derived from the captured image itself.
unit = board_image.width / WIDTH
piece_images = [
    board_image.crop(
        (unit * x, unit * y, unit * (x + 1), unit * (y + 1))
    ).convert('RGB')
    for x, y in positions
]
piece_labels = predict(piece_images)

# Store each crop under a directory named after its predicted label.
for label, image in zip(piece_labels, piece_images):
    label_dir = f"{PIECE_IMAGE_DIR}/{label}"
    os.makedirs(label_dir, exist_ok=True)
    image.save(f"{label_dir}/{uuid.uuid4()}.png")

In [28]:
# Mikan classification model

MIKAN_DATA_DIR = "./data/mikans"
MIKAN_TORCH_MODEL_PATH = "./models/mikan_classifier.pth"
MIKAN_IMAGE_SIZE = 96

# Class index -> label name: the sorted entries of MIKAN_DATA_DIR whose names
# start with a digit.
mikan_idx_to_label = [
    entry
    for entry in sorted(os.listdir(MIKAN_DATA_DIR))
    if re.match(r"\d+", entry)
]

class Net(nn.Module):
    """CNN that scores an image against each mikan label.

    Four 5x5 Conv -> BatchNorm -> ReLU stages (the 2nd and 4th followed by a
    2x2 max-pool) and a single linear layer. Sub-module names and their order
    are unchanged, so the state_dict keys stay the same.

    NOTE: this notebook also defines a piece-classifier class called ``Net``;
    once this cell has run, the name ``Net`` refers to this architecture.

    Args:
        num_classes: Number of output scores. ``None`` (the default) uses
            ``len(mikan_idx_to_label)``, looked up when the model is
            constructed.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        if num_classes is None:
            # Default: one output per entry of the module-level mikan_idx_to_label.
            num_classes = len(mikan_idx_to_label)
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, 5),
            nn.BatchNorm2d(16),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 16, 5),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(16, 32, 5),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(32, 32, 5),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        # For a 96x96 input the spatial size goes 96 -> 92 -> 88 -> 44 -> 40
        # -> 36 -> 18, so the flattened feature size is 32 * 18 * 18 = 10368.
        self.fc1 = nn.Linear(10368, num_classes)

    def forward(self, x):
        """Return one raw score per class for each image in the batch ``x``."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        # Keep the batch dimension, flatten everything else for the linear layer.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        return x

def get_model():
    """Build a ``Net`` and load the weights stored at ``MIKAN_TORCH_MODEL_PATH``.

    NOTE: this notebook also defines a piece-classifier ``get_model``; once
    this cell has run, the name ``get_model`` refers to this function.

    Returns:
        The model, switched to evaluation mode.
    """
    model = Net()
    # Net() is created on the CPU, so map the checkpoint tensors to the CPU as
    # well; without map_location a checkpoint saved from a CUDA device cannot
    # be loaded on a machine that has no GPU.
    # NOTE(review): torch.load unpickles the file - only load checkpoints you
    # trust (consider weights_only=True if the installed torch supports it).
    state_dict = torch.load(MIKAN_TORCH_MODEL_PATH, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model

def predict(images: list[Image.Image]) -> list[str]:
    """Classify each image and return its label from ``mikan_idx_to_label``.

    Every image is resized to MIKAN_IMAGE_SIZE x MIKAN_IMAGE_SIZE, converted
    to a tensor and normalised before being passed to the model returned by
    ``get_model()``.

    NOTE: this notebook also defines a piece-classifier ``predict``; once this
    cell has run, the name ``predict`` refers to this function.

    Args:
        images: Images to classify. The normalisation below has three
            channels, so each image must be 3-channel (e.g. RGB).

    Returns:
        One label per input image, in input order. An empty input gives ``[]``.
    """
    if not images:
        # torch.stack() raises on an empty sequence; nothing to classify.
        return []

    transform = transforms.Compose(
        [
            transforms.Resize((MIKAN_IMAGE_SIZE, MIKAN_IMAGE_SIZE)),
            transforms.ToTensor(),
            # NOTE(review): presumably the per-channel mean / std of the
            # training images - confirm against the training code.
            transforms.Normalize(
                (0.63526792, 0.57570206, 0.48665065),
                (0.22508891, 0.20648694, 0.26393888)
            )
        ]
    )
    inputs = torch.stack([transform(image) for image in images])

    model = get_model()
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(inputs)
        pred = torch.argmax(output, dim=1)
    return [mikan_idx_to_label[index] for index in pred]


In [30]:
# Save mikan images

# Root directory; each cropped image is saved in a sub-directory named after
# its predicted label.
MIKAN_IMAGE_DIR = "./data/mikans/raw"

# Board layout templates: one character per cell, one text line per board row.
# A character matching [1-9a-z] marks the top-left cell of a 2x2-cell region
# that is cropped and saved; all other characters (".", " ", "_") are ignored
# by the cropping code.
# NOTE(review): which digit/letter is used is not read by the code; it is
# presumably an annotation for the author - confirm.
mikan_1_1 = """
...   ...
.........
.........
 .......
  5..5.
  .. ..
 ... ...
a......a.
.........
""".strip()

mikan_1_2 = """
_.......
... . ...
.........
 a....a.
 ..   ..
f... ..f.
.........
 a....a.
 .......
""".strip()

mikan_3 = """
k. ... k.
.. ... ..
   ...
.........
.... ....
.........
   ...
u. ... u.
.. ... ..
""".strip()

mikan_4_1 = """
_.......
.........
.........
 a. . a.
 .. . ..
.........
.........
f. ... f.
.. ... ..
""".strip()

mikan_4_2 = """
.. ... ..
.. ... ..
... . ...
 .......
f......f.
.........
   ...
 p. . p.
 .. . ..
""".strip()

# Template used by the cropping code below; assign a different template here
# to capture another layout.
mikan = mikan_4_2

# (x, y) cell coordinates of every template cell marked with [1-9a-z]; each one
# is the top-left cell of a 2x2-cell region.
positions: list[tuple[int, int]] = []
for y, line in enumerate(mikan.splitlines()):
    for x, token in enumerate(line):
        if not re.match(r"[1-9a-z]", token):
            continue
        positions.append((x, y))

board_image = capture(9, 9)
images = [
    board_image.crop(
        (UNIT * x, UNIT * y, UNIT * (x + 2), UNIT * (y + 2))
    ).convert('RGB')
    for x, y in positions
]
labels = predict(images)

# Store each crop under a directory named after its predicted label.
for label, image in zip(labels, images):
    label_dir = f"{MIKAN_IMAGE_DIR}/{label}"
    os.makedirs(label_dir, exist_ok=True)
    image.save(f"{label_dir}/{uuid.uuid4()}.png")



In [32]:
# Save 3D-printer images

# Directory that receives every cropped image (no per-label sub-directories).
PRINTER_IMAGE_DIR = "./data/cap-printer"

# Board layout template (8 rows): one character per cell, one text line per
# board row. A character matching [1-9a-z] marks the top-left cell of a
# 2x2-cell region that is cropped and saved; all other characters are ignored
# by the cropping code.
printer = """
_... ..._
.........
.........
..5. 5...
.... ....
.........
.........
...   ...
""".strip()

# (x, y) cell coordinates of every template cell marked with [1-9a-z]; each one
# is the top-left cell of a 2x2-cell region.
positions: list[tuple[int, int]] = []
for y, line in enumerate(printer.splitlines()):
    for x, token in enumerate(line):
        if re.match(r"[1-9a-z]", token):
            positions.append((x, y))

# The printer template is 9 columns by 8 rows.
board_image = capture(9, 8)
images = [
    board_image.crop((UNIT * x, UNIT * y, UNIT * (x + 2), UNIT * (y + 2))).convert('RGB')
    for x, y in positions
]

# The target directory is the same for every image, so create it once instead
# of once per loop iteration.
os.makedirs(PRINTER_IMAGE_DIR, exist_ok=True)
for image in images:
    image.save(f"{PRINTER_IMAGE_DIR}/{uuid.uuid4()}.png")

