In [None]:
import sys

import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms

from models.resnet import *
from models.resnet import resnet18 as build_resnet18

## CIFAR-10の準備

In [None]:
# Raw CIFAR-10 splits, converted to tensors only (no augmentation or normalization yet).
data_path = "./data/"
cifar10_train = datasets.CIFAR10(
    root=data_path, train=True, transform=transforms.ToTensor(), download=True
)
cifar10_test = datasets.CIFAR10(
    root=data_path, train=False, transform=transforms.ToTensor(), download=True
)

In [None]:
# Pick one raw training sample (index 99) to look at in the next cell.
(img, label) = cifar10_train[99]

In [None]:
# Show the sample with its class name. Matplotlib expects (H, W, C) while the
# dataset tensor is channel-first, hence the permute.
fig, ax = plt.subplots()
ax.imshow(img.permute(1, 2, 0))
ax.set_title(cifar10_train.classes[label])
ax.axis("off")
plt.show()

## 画像のaugmentationと正規化を適用したうえで、CIFAR-10を再度読み込む

In [None]:
# Per-channel mean/std of the raw training images, consumed by Normalize below.
# Stacking along the default dim=0 gives (N, C, H, W). Reduce over every axis
# except the channel axis so each statistic belongs to exactly one channel.
# (`.view(3, -1)` on an (N, C, H, W) tensor would split the data into thirds by
# *image*, not by channel.)
train_img_stack = torch.stack([img_t for img_t, _ in cifar10_train])
mean = train_img_stack.mean(dim=(0, 2, 3))
std = train_img_stack.std(dim=(0, 2, 3))

In [None]:
# Training-time preprocessing: random crop (4px padding), horizontal flip and a
# random rotation of up to ±15° for augmentation, then tensor conversion and
# per-channel normalization with the `mean`/`std` computed in the previous cell.
transforms_train = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Re-open the training split, this time with the augmentation pipeline attached.
cifar10_train_aug = datasets.CIFAR10(
    root=data_path,
    train=True,
    transform=transforms_train,
    download=True,
)

## DataLoaderを定義する

In [None]:
# Mini-batches of 64 augmented training images, reshuffled every epoch.
cifar10_train_dataloader = DataLoader(
    dataset=cifar10_train_aug,
    batch_size=64,
    shuffle=True,
)

## 学習用のセットアップを行う
- モデルの定義
- 損失関数の定義
- Optimizerの定義

In [None]:
# モデルの定義
resnet18 = resnet18()
# 損失関数の定義
loss_fn = nn.CrossEntropyLoss()
# Optimizerの定義
optimizer = optim.SGD(resnet18.parameters(), lr=1e-2)

## 学習を実施する

In [None]:
n_epochs = 1

# Send batches to whatever device the model's parameters live on, so the loop
# can never hit a model/input device mismatch.
device = next(resnet18.parameters()).device
# Make sure the model is in training mode (matters for layers such as BatchNorm
# or Dropout, if the model contains them).
resnet18.train()

for n_epoch in range(n_epochs):
    running_loss = 0.0
    n_seen = 0
    for imgs, labels in cifar10_train_dataloader:
        imgs = imgs.to(device)
        labels = labels.to(device)
        outputs = resnet18(imgs)
        loss = loss_fn(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # loss_fn was built with the default reduction ("mean"), so `loss` is a
        # batch mean; weight it by batch size to get a per-sample epoch mean.
        running_loss += loss.item() * imgs.size(0)
        n_seen += imgs.size(0)

    print(f"epoch {n_epoch + 1}/{n_epochs}: mean train loss = {running_loss / max(n_seen, 1):.4f}")