In [5]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from timm import create_model
from utils.image_utils import GameScreenshotDataset, get_training_transform, save_model
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import tqdm
import japanize_matplotlib

In [None]:
# Training hyperparameters
BATCH_SIZE = 32
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.05

# Seed PyTorch's global RNG early so repeated runs start from the same random
# state (e.g. DataLoader shuffling). GPU kernels may still be non-deterministic.
SEED = 42
torch.manual_seed(SEED)

# Device selection: use CUDA when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
else:
    print("CPU")

In [None]:
# Build the training dataset and wrap it in a shuffling batch loader
train_dataset = GameScreenshotDataset(
    "training",
    transform=get_training_transform(),
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

# Summarize what was loaded: class count, class names, and sample count
print(
    f"Number of classes: {len(train_dataset.classes)}",
    f"Classes: {train_dataset.classes}",
    f"Total training images: {len(train_dataset)}",
    sep="\n",
)


# Display sample images
def show_sample_images(dataset, num_images=5):
    """Plot the first ``num_images`` samples of ``dataset`` in a single row.

    Args:
        dataset: Indexable dataset whose items are ``(image, label)`` pairs and
            which exposes a ``classes`` sequence indexed by ``label``. Each
            image must be a channel-first tensor with 3 channels.
        num_images: Maximum number of samples to show. It is capped at
            ``len(dataset)``; nothing is drawn when the result is not positive.
    """
    num_images = min(num_images, len(dataset))
    if num_images <= 0:
        return

    # Per-channel statistics used to undo the normalization.
    # NOTE(review): assumes the dataset transform normalized with these exact
    # mean/std values -- confirm against get_training_transform().
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)

    # squeeze=False always yields a 2-D array of Axes, so num_images == 1 works too.
    fig, axes = plt.subplots(1, num_images, figsize=(15, 3), squeeze=False)
    for i, ax in enumerate(axes[0]):
        img, label = dataset[i]
        # Undo the normalization
        img = img * std + mean
        # imshow expects channel-last data; clip to the valid [0, 1] range.
        ax.imshow(img.permute(1, 2, 0).clip(0, 1))
        ax.set_title(f"Class: {dataset.classes[label]}")
        ax.axis("off")
    plt.tight_layout()
    plt.show()


# Preview the first training samples (num_images is left at its default of 5).
show_sample_images(train_dataset)

In [None]:
# Model initialization: build the named timm model with pretrained weights
# requested, size its classifier to the dataset's classes, and move it to the
# selected device.
model = create_model(
    "swin_base_patch4_window7_224",
    pretrained=True,
    num_classes=len(train_dataset.classes),
).to(device)

# Optimizer and loss function
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
criterion = nn.CrossEntropyLoss()

In [None]:
# Training loop
train_losses = []  # mean per-sample loss of each epoch
train_accuracies = []  # training accuracy of each epoch, in percent

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0  # sum of per-sample losses seen so far in this epoch
    correct = 0
    total = 0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so a smaller final batch does
        # not skew the average.
        # NOTE(review): assumes criterion uses mean reduction (the
        # nn.CrossEntropyLoss default) -- confirm if the criterion changes.
        num_samples = labels.size(0)
        running_loss += loss.item() * num_samples
        _, predicted = outputs.max(1)
        total += num_samples
        correct += predicted.eq(labels).sum().item()

        # Update the progress bar with the running averages over the samples
        # processed so far (not over the whole epoch, which would under-report
        # the loss until the last batch).
        progress_bar.set_postfix({"loss": running_loss / total, "acc": 100.0 * correct / total})

    epoch_loss = running_loss / total
    epoch_acc = 100.0 * correct / total

    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_acc)

    print(f"Epoch {epoch+1}: Loss = {epoch_loss:.4f}, Accuracy = {epoch_acc:.2f}%")

In [None]:
# Plot the loss and accuracy curves
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

# Use 1-based epoch numbers on the x-axis so the curves line up with the
# "Epoch N" lines printed during training.
loss_ax.plot(range(1, len(train_losses) + 1), train_losses)
loss_ax.set_title("Training Loss")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")

acc_ax.plot(range(1, len(train_accuracies) + 1), train_accuracies)
acc_ax.set_title("Training Accuracy")
acc_ax.set_xlabel("Epoch")
acc_ax.set_ylabel("Accuracy (%)")

fig.tight_layout()
plt.show()

In [None]:
# Save the trained model via the project's save_model helper
model_output_path = "models/trained_model_1.pickle"
save_model(model, model_output_path)
print("Model saved successfully!")

In [None]:
def test_model(model, test_image_path, class_names=None):
    """Classify a single image with ``model`` and display the prediction.

    The model is switched to eval mode, the image is loaded as RGB, run through
    ``GameScreenshotDataset.get_default_transform()`` and the model, and the
    top class with its softmax confidence is printed and shown as the title of
    the plotted image.

    Args:
        model: Model that maps an image batch to per-class scores.
        test_image_path: Path of the image file to classify.
        class_names: Optional sequence mapping class index to class name.
            Defaults to ``train_dataset.classes`` from the notebook scope.
    """
    if class_names is None:
        class_names = train_dataset.classes

    model.eval()
    transform = GameScreenshotDataset.get_default_transform()

    # Load and preprocess the image. Image.open is lazy and keeps the file
    # open; convert() returns a fully loaded copy, so the handle can be closed
    # as soon as the with-block exits.
    with Image.open(test_image_path) as source_image:
        image = source_image.convert("RGB")
    input_tensor = transform(image).unsqueeze(0).to(device)

    # Predict
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        # Take the index and its probability from the same tensor so they
        # always refer to the same class.
        top_probability, top_index = probabilities.max(dim=0)
        predicted_class = class_names[top_index.item()]
        confidence = top_probability.item()

    print(f"Predicted class: {predicted_class}")
    print(f"Confidence: {confidence:.2%}")

    # Show the image with the prediction as its title
    plt.imshow(image)
    plt.title(f"Prediction: {predicted_class} ({confidence:.2%})")
    plt.axis("off")
    plt.show()


# Test the model on a single image.
# A forward slash keeps the path valid on Linux/macOS as well as Windows; a
# backslash would be treated as part of the file name outside Windows.
test_image_path = "input/Screenshot_2024.12.03_21.07.52.109.png"  # path of the test image
test_model(model, test_image_path)