In [1]:
import sys
import os
# Make the project root importable (for `models.CNN`). The path is resolved relative
# to the kernel's current working directory; use "../../" depending on location.
# Guarded so re-running this cell does not keep appending duplicate entries.
PROJECT_ROOT = os.path.abspath("../")
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
from models.CNN import CNN

import os
from tqdm import tqdm

In [3]:
# Data loading.
# Preprocessing: the training split gets random-crop + horizontal-flip augmentation;
# both splits are converted to tensors and normalized per channel.
# These mean/std values appear to be the commonly quoted CIFAR-10 training-set
# statistics (rounded). They are defined once so train and test cannot drift apart.
CIFAR10_MEAN = (0.491, 0.482, 0.447)
CIFAR10_STD = (0.247, 0.243, 0.261)
BATCH_SIZE = 128

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Load the already-downloaded data from `root`. With download=False the files must
# already be present there (torchvision presumably raises otherwise).
root = "../data"
train_dataset = datasets.CIFAR10(
    root=root,
    train=True,
    download=False,
    transform=transform_train,
)

test_dataset = datasets.CIFAR10(
    root=root,
    train=False,
    download=False,
    transform=transform_test,
)

# Shuffle only the training split; evaluation order does not matter.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)



In [4]:
# Model setup: choose a compute device and build the model.
# Device priority: Apple-silicon MPS, then CUDA, then CPU.
SEED = 42
# Seed torch's global RNG before constructing the model. This should make weight
# initialisation (presumably drawn from torch's RNG inside CNN) reproducible on
# Restart & Run All. It should also cover the later DataLoader shuffling and random
# augmentations, since they run in the main process. GPU kernels (e.g. cuDNN) may
# still introduce non-determinism.
torch.manual_seed(SEED)

if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")
model = CNN()
model = model.to(device)

Using device: cuda


In [None]:

# Train the model with Adam and report test-set accuracy after every epoch.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 20


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean training loss.

    The loss is averaged per sample: each batch's loss is weighted by its size.
    This keeps a smaller final batch from skewing the epoch average. It assumes
    ``criterion`` uses mean reduction, which is the ``nn.CrossEntropyLoss`` default.
    Returns 0.0 for an empty loader instead of dividing by zero.
    """
    model.train()
    total_loss = 0.0
    total_samples = 0
    for inputs, targets in tqdm(loader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_size = targets.size(0)
        total_loss += loss.item() * batch_size
        total_samples += batch_size
    return total_loss / max(total_samples, 1)


def evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` on ``loader``, in percent.

    Runs in eval mode without gradient tracking. Returns 0.0 for an empty loader.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            predicted = model(inputs).argmax(dim=1)
            correct += (predicted == targets).sum().item()
            total += targets.size(0)
    return 100. * correct / max(total, 1)


for epoch in range(num_epochs):
    avg_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    print(f"[Epoch {epoch+1}] Train Loss: {avg_loss:.4f}")

    # Validation (on the test split)
    acc = evaluate(model, test_loader, device)
    print(f"Test Accuracy: {acc:.2f}%")


100%|██████████| 391/391 [00:09<00:00, 40.98it/s]


[Epoch 1] Train Loss: 1.5912
Test Accuracy: 53.13%


100%|██████████| 391/391 [00:11<00:00, 33.59it/s]


[Epoch 2] Train Loss: 1.2222
Test Accuracy: 64.56%


100%|██████████| 391/391 [00:11<00:00, 33.14it/s]


[Epoch 3] Train Loss: 1.0400
Test Accuracy: 66.82%


100%|██████████| 391/391 [00:13<00:00, 28.60it/s]


[Epoch 4] Train Loss: 0.9349
Test Accuracy: 71.75%


100%|██████████| 391/391 [00:11<00:00, 34.55it/s]


[Epoch 5] Train Loss: 0.8618
Test Accuracy: 73.85%


100%|██████████| 391/391 [00:11<00:00, 32.66it/s]


[Epoch 6] Train Loss: 0.8106
Test Accuracy: 74.51%


100%|██████████| 391/391 [00:14<00:00, 27.41it/s]


[Epoch 7] Train Loss: 0.7685
Test Accuracy: 75.43%


100%|██████████| 391/391 [00:12<00:00, 32.20it/s]


[Epoch 8] Train Loss: 0.7368
Test Accuracy: 76.75%


100%|██████████| 391/391 [00:12<00:00, 30.97it/s]


[Epoch 9] Train Loss: 0.7005
Test Accuracy: 76.71%


100%|██████████| 391/391 [00:11<00:00, 32.66it/s]


[Epoch 10] Train Loss: 0.6773
Test Accuracy: 78.45%


100%|██████████| 391/391 [00:11<00:00, 32.62it/s]


[Epoch 11] Train Loss: 0.6540
Test Accuracy: 79.06%


100%|██████████| 391/391 [00:12<00:00, 31.20it/s]


[Epoch 12] Train Loss: 0.6314
Test Accuracy: 79.20%


100%|██████████| 391/391 [00:15<00:00, 25.74it/s]


[Epoch 13] Train Loss: 0.6098
Test Accuracy: 78.63%


 47%|████▋     | 185/391 [00:06<00:06, 30.22it/s]


KeyboardInterrupt: 

In [None]:
# Persist the trained weights. Create the target directory first; torch.save does
# not create missing parent directories.
# NOTE: tensors are saved on the training device (e.g. CUDA/MPS). To load on a
# CPU-only machine, pass map_location="cpu" to torch.load.
CHECKPOINT_PATH = "../models/pretrained/CNN_cifar10.pth"
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)
torch.save(model.state_dict(), CHECKPOINT_PATH)