In [1]:
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F

from torchvision import datasets, transforms

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(ResidualBlock, self).__init__()

        self.in_channel, self.out_channel = in_channel, out_channel

        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(out_channel, out_channel, kernel_size=1, padding=0)

        if in_channel != out_channel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=1, padding=0)
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        out = F.relu(self.conv3(out))
        out = out + self.shortcut(x)
        return out

class ResNet(nn.Module):
    def __init__(self, color='gray'):
        super(ResNet, self).__init__()
        if color == "gray":
            self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        elif color == "rgb":
            self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)

        self.resblock1 = ResidualBlock(32, 64)
        self.resblock2 = ResidualBlock(64, 64)

        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc1 = nn.Linear(64, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = self.resblock1(x)
        x = self.resblock2(x)
        x = self.avgpool(x)
        x = torch.flatten(x,1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = F.log_softmax(x, dim=1)
        return x

In [4]:
model = ResNet().to(device)

## Weight만 저장

In [5]:
torch.save(model.state_dict(), 'model_weights.pth')

In [6]:
model.load_state_dict(torch.load('model_weights.pth'))

<All keys matched successfully>

## 구조도 함께 저장

In [7]:
torch.save(model, 'model.pth')

In [8]:
model = torch.load('model.pth')

## Training

In [9]:
batch_size = 32

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('dataset/', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize(mean=(0.5,), std=(0.5,))
                   ])),
    batch_size=batch_size,
    shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 108665976.42it/s]


Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 49003112.39it/s]


Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 24260752.45it/s]


Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3941760.56it/s]


Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw



In [10]:
x, y = next(iter(train_loader))
x.shape

torch.Size([32, 1, 28, 28])

In [12]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

opt = optim.Adam(model.parameters(), lr = 0.03)
scheduler = ReduceLROnPlateau(opt, mode = 'min', verbose = True)


# training loop

def train_loop(dataloader, model, loss_fn, optimizer, scheduler, epoch):
  model.train()
  size = len(dataloader)

  for batch, (x, y) in enumerate (dataloader):
    x, y = x.to(device), y.to(device)

    pred = model(x)
    loss = loss_fn(pred, y)

    optimizer.zero_grad()

    loss.backward()

    optimizer.step()

    if batch % 100 == 0:
      print(f"Epoch {epoch} : [{batch} / {size}] loss : {loss.item()}")

  scheduler.step(loss)
  return loss.item()

for epoch in range(1):
  loss = train_loop(train_loader, model, F.nll_loss, opt, scheduler, epoch)
  print(f"Epoch : {epoch} loss : {loss}")

Epoch 0 : [0 / 1875] loss : 2.294445514678955
Epoch 0 : [100 / 1875] loss : 1.6751128435134888
Epoch 0 : [200 / 1875] loss : 1.7363073825836182
Epoch 0 : [300 / 1875] loss : 2.0579750537872314
Epoch 0 : [400 / 1875] loss : 1.9238131046295166
Epoch 0 : [500 / 1875] loss : 1.4457284212112427
Epoch 0 : [600 / 1875] loss : 1.9327447414398193
Epoch 0 : [700 / 1875] loss : 1.2465046644210815
Epoch 0 : [800 / 1875] loss : 1.1064250469207764
Epoch 0 : [900 / 1875] loss : 1.1474356651306152
Epoch 0 : [1000 / 1875] loss : 1.0431838035583496
Epoch 0 : [1100 / 1875] loss : 0.9327719807624817
Epoch 0 : [1200 / 1875] loss : 0.9456428289413452
Epoch 0 : [1300 / 1875] loss : 0.7241479754447937
Epoch 0 : [1400 / 1875] loss : 0.9882229566574097
Epoch 0 : [1500 / 1875] loss : 0.471072793006897
Epoch 0 : [1600 / 1875] loss : 0.6657662391662598
Epoch 0 : [1700 / 1875] loss : 0.8462094664573669
Epoch 0 : [1800 / 1875] loss : 1.0661183595657349
Epoch : 0 loss : 0.43183913826942444


## Save, Load and Resuming Training

In [13]:
checkpoint_path = 'checkpoint.pth'

torch.save({
    'epoch' : epoch,
    'model_state_dict' : model.state_dict(),
    'optimizer_state_dict' : opt.state_dict(),
    'loss' : loss
    }, checkpoint_path)

In [14]:
model = ResNet().to(device)
optimizer = optim.SGD(model.parameters(), lr = 0.03)

In [15]:
checkpoint = torch.load(checkpoint_path)
checkpoint

{'epoch': 0,
 'model_state_dict': OrderedDict([('conv1.weight',
               tensor([[[[-0.4626, -0.1580, -0.2987],
                         [ 0.3756,  0.2549,  0.2500],
                         [-0.0184, -0.0073,  0.0830]]],
               
               
                       [[[ 0.0772, -0.2546, -0.0114],
                         [-0.5085,  0.2473,  0.2900],
                         [ 0.0534,  0.4324, -0.2734]]],
               
               
                       [[[ 0.2429, -0.0906, -0.3864],
                         [ 0.1794, -0.1416,  0.7130],
                         [-0.2229, -0.1571,  0.5845]]],
               
               
                       [[[-0.5760,  0.3702,  0.0943],
                         [ 0.3789, -0.0807, -0.0463],
                         [ 0.2128, -0.0099,  0.1147]]],
               
               
                       [[[ 0.6718, -0.7348, -0.8526],
                         [ 0.1229, -0.4172, -1.0173],
                         [ 0.7653,  0.8311, 

In [16]:
checkpoint.keys()

dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'loss'])

In [17]:
model.load_state_dict(checkpoint['model_state_dict'])
epoch = checkpoint['epoch']
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
loss = checkpoint['loss']

# 다음에 학습할 때 위 정보를 이어서 학습할 수 있게 됨!