# AlexNet

In [3]:
import torch
import torch.nn as nn

IMAGE_SIZE = 227  # Default square input side length; with this size the feature map is 256 x 6 x 6.

class AlexNet(nn.Module):
  """AlexNet-style CNN with a BatchNorm layer after every convolution.

  Args:
    num_classes: Number of logits produced by the final linear layer.
    image_size: Side length of the square 3-channel inputs. It is used only to
      size the first fully connected layer, so inputs given to ``forward`` must
      have this spatial size. Defaults to ``IMAGE_SIZE``.
  """

  def __init__(self, num_classes=10, image_size=IMAGE_SIZE):
    super(AlexNet, self).__init__()

    # Feature extraction: five conv layers (each followed by BatchNorm + ReLU)
    # with max pooling after the 1st, 2nd and 5th.
    self.features = nn.Sequential(
        nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0),
        nn.BatchNorm2d(96),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=3, stride=2),

        nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),
        nn.BatchNorm2d(256),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=3, stride=2),

        nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(384),
        nn.ReLU(inplace=True),

        nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(384),
        nn.ReLU(inplace=True),

        nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(256),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=3, stride=2)
    )

    # Infer the flattened feature size with one probe forward pass.
    # In training mode the BatchNorm layers would fold this all-zeros probe
    # into their running statistics (and bump num_batches_tracked), so run
    # the probe in eval mode and without autograd, then restore the mode.
    was_training = self.features.training
    self.features.eval()
    try:
      with torch.no_grad():
        probe = torch.zeros(1, 3, image_size, image_size)
        flatten_dim = self.features(probe)[0].numel()
    finally:
      self.features.train(was_training)

    # Classification head: flatten, then three fully connected layers.
    self.classifier = nn.Sequential(
        nn.Flatten(),

        nn.Linear(flatten_dim, 4096),
        nn.ReLU(inplace=True),
        nn.Dropout(0.5),

        nn.Linear(4096, 4096),
        nn.ReLU(inplace=True),
        nn.Dropout(0.5),

        nn.Linear(4096, num_classes)
    )

  def forward(self, x):
    """Map a (N, 3, image_size, image_size) batch to (N, num_classes) logits."""
    x = self.features(x)
    x = self.classifier(x)
    return x

In [4]:
from torch.utils.data import Dataset
import numpy as np
import cv2
from sklearn.utils import shuffle

BATCH_SIZE = 64  # NOTE(review): not referenced in this cell; presumably used by a DataLoader later.

class Cifar10Dataset(Dataset):
  """Map-style dataset that resizes images and returns them channel-first.

  Per item: optional augmentation, resize to ``image_size`` x ``image_size``,
  optional preprocessing, then a (H, W, C) -> (C, H, W) transpose and a cast
  to float32.

  Args:
    images: Indexable collection of (H, W, C) images (presumably a NumPy array
      of shape (N, H, W, C) — confirm against the caller).
    labels: Optional indexable labels aligned with ``images``. When None,
      items are returned as the image alone.
    image_size: Target side length passed to ``cv2.resize``.
    augmentor: Optional callable invoked as ``augmentor(image=...)`` whose
      result is indexed with ``'image'`` (albumentations-style interface).
    preprocess_function: Optional callable applied to the resized image.
    shuffle_data: If True and ``labels`` is given, shuffle images and labels
      together once at construction time. Ignored when ``labels`` is None.
  """

  def __init__(self, images, labels=None, image_size=IMAGE_SIZE, augmentor=None, preprocess_function=None, shuffle_data=False):
    # Must be ``self.images``: __len__/__getitem__ and the shuffle below read that name.
    self.images = images
    self.labels = labels
    self.image_size = image_size
    self.augmentor = augmentor
    self.preprocess_function = preprocess_function

    if shuffle_data and labels is not None:
      # sklearn's shuffle applies one permutation to both arrays, keeping pairs aligned.
      self.images, self.labels = shuffle(self.images, self.labels)

  def __len__(self):
    return len(self.images)

  def __getitem__(self, idx):
    """Return ``(image, label)``, or just ``image`` when no labels were given."""
    image = self.images[idx]
    label = self.labels[idx] if self.labels is not None else None

    if self.augmentor is not None:
      # Albumentations-style transforms only accept data as keyword arguments.
      image = self.augmentor(image=image)['image']

    image = cv2.resize(image, (self.image_size, self.image_size))

    if self.preprocess_function is not None:
      image = self.preprocess_function(image)

    # Reorder channels (h, w, c) -> (c, h, w) as PyTorch conv layers expect.
    image = np.transpose(image, (2, 0, 1)).astype(np.float32)

    if label is not None:
      return image, label
    else:
      return image


In [None]:
from torchvision.datasets import CIFAR10
from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.nn.functional as F

# Fetch (downloading into ./ if missing) the CIFAR-10 train split, then the test split.
train_data, test_data = [
    CIFAR10(root='./', train=is_train, download=True) for is_train in (True, False)
]

# Copy the raw images into NumPy arrays.
train_images, test_images = (np.array(split.data) for split in (train_data, test_data))
# Labels become (N, 1) column vectors.
train_labels, test_labels = (
    np.array(split.targets).reshape(-1, 1) for split in (train_data, test_data)
)