In [1]:
import sys
import os
sys.path.append(os.path.abspath("../"))  # or "../../" depending on location

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
# from models.ResNet18 import get_resnet18_for_cifar10
from models.CNN import CNN

import os
from tqdm import tqdm

In [None]:
# data load
# 前処理（ToTensor + 正規化）
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.261))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.261))
])

# ダウンロード済みのデータを指定して読み込み
train_dataset = datasets.CIFAR10(
    root="../data/cifar10", 
    train=True, 
    download=False,  
    transform=transform_train
)

test_dataset = datasets.CIFAR10(
    root="../data/cifar10", 
    train=False, 
    download=False, 
    transform=transform_test
)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

subset_size = 10000
indices = np.random.choice(len(train_dataset), subset_size, replace=False)

# サブセットデータセットを作成
small_train_dataset = Subset(train_dataset, indices)
train_loader_small = DataLoader(small_train_dataset, batch_size=128, shuffle=True)


In [4]:
# train models
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")
model = CNN()
model = model.to(device)

Using device: mps


In [5]:

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 40

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, targets in tqdm(train_loader_small):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)
    print(f"[Epoch {epoch+1}] Train Loss: {avg_loss:.4f}")

    # 検証
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == targets).sum().item()
            total += targets.size(0)
    acc = 100. * correct / total
    print(f"Test Accuracy: {acc:.2f}%")


100%|██████████| 79/79 [00:02<00:00, 33.68it/s]


[Epoch 1] Train Loss: 0.4017
Test Accuracy: 39.04%


100%|██████████| 79/79 [00:01<00:00, 41.09it/s]


[Epoch 2] Train Loss: 0.3407
Test Accuracy: 46.19%


100%|██████████| 79/79 [00:01<00:00, 40.98it/s]


[Epoch 3] Train Loss: 0.3098
Test Accuracy: 48.87%


100%|██████████| 79/79 [00:01<00:00, 42.09it/s]


[Epoch 4] Train Loss: 0.2931
Test Accuracy: 52.17%


100%|██████████| 79/79 [00:02<00:00, 38.64it/s]


[Epoch 5] Train Loss: 0.2782
Test Accuracy: 54.15%


100%|██████████| 79/79 [00:01<00:00, 40.60it/s]


[Epoch 6] Train Loss: 0.2678
Test Accuracy: 57.13%


100%|██████████| 79/79 [00:01<00:00, 41.54it/s]


[Epoch 7] Train Loss: 0.2518
Test Accuracy: 57.36%


100%|██████████| 79/79 [00:01<00:00, 40.87it/s]


[Epoch 8] Train Loss: 0.2436
Test Accuracy: 57.31%


100%|██████████| 79/79 [00:01<00:00, 40.84it/s]


[Epoch 9] Train Loss: 0.2338
Test Accuracy: 61.96%


100%|██████████| 79/79 [00:01<00:00, 40.98it/s]


[Epoch 10] Train Loss: 0.2210
Test Accuracy: 62.46%


100%|██████████| 79/79 [00:01<00:00, 40.79it/s]


[Epoch 11] Train Loss: 0.2181
Test Accuracy: 59.48%


100%|██████████| 79/79 [00:01<00:00, 41.10it/s]


[Epoch 12] Train Loss: 0.2102
Test Accuracy: 64.47%


100%|██████████| 79/79 [00:01<00:00, 42.08it/s]


[Epoch 13] Train Loss: 0.2068
Test Accuracy: 65.58%


100%|██████████| 79/79 [00:01<00:00, 41.72it/s]


[Epoch 14] Train Loss: 0.1957
Test Accuracy: 65.54%


100%|██████████| 79/79 [00:02<00:00, 39.25it/s]


[Epoch 15] Train Loss: 0.1881
Test Accuracy: 66.65%


100%|██████████| 79/79 [00:01<00:00, 40.94it/s]


[Epoch 16] Train Loss: 0.1848
Test Accuracy: 65.85%


100%|██████████| 79/79 [00:01<00:00, 41.07it/s]


[Epoch 17] Train Loss: 0.1795
Test Accuracy: 66.38%


100%|██████████| 79/79 [00:01<00:00, 40.80it/s]


[Epoch 18] Train Loss: 0.1761
Test Accuracy: 68.36%


100%|██████████| 79/79 [00:01<00:00, 40.76it/s]


[Epoch 19] Train Loss: 0.1720
Test Accuracy: 66.64%


100%|██████████| 79/79 [00:01<00:00, 40.29it/s]


[Epoch 20] Train Loss: 0.1658
Test Accuracy: 69.40%


100%|██████████| 79/79 [00:01<00:00, 40.30it/s]


[Epoch 21] Train Loss: 0.1618
Test Accuracy: 68.95%


100%|██████████| 79/79 [00:01<00:00, 40.71it/s]


[Epoch 22] Train Loss: 0.1606
Test Accuracy: 70.12%


100%|██████████| 79/79 [00:01<00:00, 40.94it/s]


[Epoch 23] Train Loss: 0.1538
Test Accuracy: 69.73%


100%|██████████| 79/79 [00:01<00:00, 41.46it/s]


[Epoch 24] Train Loss: 0.1527
Test Accuracy: 71.08%


100%|██████████| 79/79 [00:01<00:00, 40.39it/s]


[Epoch 25] Train Loss: 0.1481
Test Accuracy: 71.16%


100%|██████████| 79/79 [00:01<00:00, 41.68it/s]


[Epoch 26] Train Loss: 0.1478
Test Accuracy: 70.31%


100%|██████████| 79/79 [00:01<00:00, 41.01it/s]


[Epoch 27] Train Loss: 0.1436
Test Accuracy: 70.59%


100%|██████████| 79/79 [00:01<00:00, 41.66it/s]


[Epoch 28] Train Loss: 0.1420
Test Accuracy: 70.95%


100%|██████████| 79/79 [00:01<00:00, 42.01it/s]


[Epoch 29] Train Loss: 0.1399
Test Accuracy: 71.08%


100%|██████████| 79/79 [00:01<00:00, 41.28it/s]


[Epoch 30] Train Loss: 0.1387
Test Accuracy: 72.33%


100%|██████████| 79/79 [00:01<00:00, 40.66it/s]


[Epoch 31] Train Loss: 0.1313
Test Accuracy: 71.62%


100%|██████████| 79/79 [00:01<00:00, 41.12it/s]


[Epoch 32] Train Loss: 0.1305
Test Accuracy: 71.63%


100%|██████████| 79/79 [00:01<00:00, 41.49it/s]


[Epoch 33] Train Loss: 0.1258
Test Accuracy: 70.66%


100%|██████████| 79/79 [00:01<00:00, 40.93it/s]


[Epoch 34] Train Loss: 0.1279
Test Accuracy: 71.75%


100%|██████████| 79/79 [00:01<00:00, 41.08it/s]


[Epoch 35] Train Loss: 0.1229
Test Accuracy: 71.87%


100%|██████████| 79/79 [00:01<00:00, 41.13it/s]


[Epoch 36] Train Loss: 0.1223
Test Accuracy: 71.65%


100%|██████████| 79/79 [00:01<00:00, 40.94it/s]


[Epoch 37] Train Loss: 0.1205
Test Accuracy: 72.03%


100%|██████████| 79/79 [00:01<00:00, 40.48it/s]


[Epoch 38] Train Loss: 0.1192
Test Accuracy: 72.96%


100%|██████████| 79/79 [00:01<00:00, 41.34it/s]


[Epoch 39] Train Loss: 0.1130
Test Accuracy: 72.73%


100%|██████████| 79/79 [00:01<00:00, 39.66it/s]


[Epoch 40] Train Loss: 0.1094
Test Accuracy: 73.60%


In [None]:
torch.save(model.state_dict(), "../models/pretrained/CNN_cifar10_small.pth")