# Model save

In [33]:
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F

from torchvision import datasets, transforms

In [34]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [35]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(ResidualBlock, self).__init__()
        
        self.in_channel, self.out_channel = in_channel, out_channel
        
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(out_channel, out_channel, kernel_size=1, padding=0)
        
        if in_channel != out_channel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=1, padding=0)
            )
        else:
            self.shortcut = nn.Sequential()
    
    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        out = F.relu(self.conv3(out))
        out = out + self.shortcut(x)
        return out

class ResNet(nn.Module):
    def __init__(self, color='gray'):
        super(ResNet, self).__init__()
        if color == "gray":
            self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        elif color == "rgb":
            self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
            
        self.resblock1 = ResidualBlock(32, 64)
        self.resblock2 = ResidualBlock(64, 64)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc1 = nn.Linear(64, 64)
        self.fc2 = nn.Linear(64, 10)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = self.resblock1(x)
        x = self.resblock2(x)
        x = self.avgpool(x)
        x = torch.flatten(x,1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = F.log_softmax(x, dim=1)
        return x

In [36]:
# class ResidualBlock(nn.Module):
#     def __init__(self, in_channel, out_channel):
#         super(ResidualBlock, self).__init__()

#         self.in_channel, self.out_channel = in_channel, out_channel

#         self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=1, padding=0)
#         self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1)
#         self.conv3 = nn.Conv2d(out_channel, out_channel, kernel_size=1, padding=0)
        
#         if in_channel != out_channel:
#             self.shortcut = nn.Sequential(
#                 nn.Conv2d(in_channel, out_channel, kernel_size =1, padding = 0)
#             )
#         else:
#             self.shortcut = nn.Sequential()

#     def forward(self, x):
#         out = F.relu(self.conv1(x))
#         out = F.relu(self.conv2(out))
#         out = F.relu(self.conv3(out))
#         out = out + self.shortcut(x)
#         return out

# class ResNet(nn.Module):
#     def __init__(self, color='gray'):
#         super(ResNet, self).__init__()
#         if color == 'gray':
#             self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
#                                     # 1 : 그레이, 단색밖에 없으니까
#         elif color == 'rgb':
#             self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)

#         self.resblock1 = ResidualBlock(32,64)
#         self.resblock2 = ResidualBlock(64,64)
        
#         self.avgpool = nn.AdaptiveAvgPool2d((1,1))
#         self.fc1 = nn.Linear(64, 64)
#         self.fc2 = nn.Linear(64, 10)

#     def forward(self, x):
#         x = F.relu(self.conv1(x))
#         x = F.max_pool2d(x, 2, 2)
#         x = self.resblock1(x)
#         x = self.resblock2(x)
#         x = self.avgpool(x)
#         x = torch.flatten(x,1)
#         x = F.relu(self.fc1(x))
#         x = self.fc2(x)
#         x = F.log_softmax(x, dim=1)
#         return x

# model = ResNet().to(device)

In [37]:
model = ResNet().to(device)

# weight만 저장

In [38]:
torch.save(model.state_dict(), 'model_weights.pth')

In [39]:
model.load_state_dict(torch.load('model_weights.pth'))

<All keys matched successfully>

# 구조도 함께 저장

In [40]:
torch.save(model, 'model.pth')

In [41]:
model = torch.load('model.pth')

# Training

In [42]:
batch_size = 32

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('dataset/', train=True, download=False,\
                    transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize(0.5, 0.5),
                    ])),
    batch_size = batch_size,
    shuffle=True
)

In [43]:
x, y = next(iter(train_loader))
x.shape

torch.Size([32, 1, 28, 28])

In [44]:
# from torch.optim.lr_scheduler import ReduceLROnPlateau

# optimizer = optim.SGD(model.parameters(), lr=0.003)
# scheduler = ReduceLROnPlateau(optimizer, model='min', verbose=True)

# def train_loop(dataloader, model, loss_fn, optimizer, scheduler, epoch):
#     model.train()
#     size = len(dataloader)
#     for batch, (x,y) in enumerate(dataloader):
#         x, y = x.to(device), y.to(device)

#         pred = model(x)
#         loss = loss_fn(pred, y)

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         if batch % 100 == 0:
#             loss = loss.item()
#             print(f'Epoch {epoch} : [{batch} / {size}] loss : {loss}')

#     scheduler.step(loss)

#     return loss.item()

# for epoch in range(1):
#     loss = train_loop(train_loader, model, F.nll_loss, optimizer, scheduler, epoch)
#     print(f'epoch : {epoch} loss : {loss}')

In [45]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

optimizer = optim.SGD(model.parameters(), lr=0.03)
scheduler = ReduceLROnPlateau(optimizer, mode = 'min', verbose=True)

def train_loop(dataloader, model, loss_fn, optimizer, scheduler, epoch):
    model.train()
    size = len(dataloader)


    for batch, (x, y) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)

        pred = model(x)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()   # 초기화

        loss.backward()     # 미분값 계산

        optimizer.step

        if batch % 100 == 0:
            loss = loss.item()
            print(f'Epoch {epoch} : [{batch}/{size}] loss : {loss}')

    scheduler.step(loss)

    return loss.item()

for epoch in range(10):
    loss = train_loop(train_loader, model, F.nll_loss, optimizer, scheduler, epoch)
    print(f'epoch : {epoch} loss : {loss}')

Epoch 0 : [0/1875] loss : 2.3018546104431152
Epoch 0 : [100/1875] loss : 2.302274465560913
Epoch 0 : [200/1875] loss : 2.283839464187622
Epoch 0 : [300/1875] loss : 2.324542284011841
Epoch 0 : [400/1875] loss : 2.300628662109375
Epoch 0 : [500/1875] loss : 2.315134286880493
Epoch 0 : [600/1875] loss : 2.3080620765686035
Epoch 0 : [700/1875] loss : 2.297044038772583
Epoch 0 : [800/1875] loss : 2.299560785293579
Epoch 0 : [900/1875] loss : 2.3014931678771973
Epoch 0 : [1000/1875] loss : 2.2910258769989014
Epoch 0 : [1100/1875] loss : 2.2976624965667725
Epoch 0 : [1200/1875] loss : 2.3103346824645996
Epoch 0 : [1300/1875] loss : 2.3117756843566895
Epoch 0 : [1400/1875] loss : 2.3235435485839844
Epoch 0 : [1500/1875] loss : 2.2999377250671387
Epoch 0 : [1600/1875] loss : 2.302978515625
Epoch 0 : [1700/1875] loss : 2.320695400238037
Epoch 0 : [1800/1875] loss : 2.2903568744659424
epoch : 0 loss : 2.2842984199523926
Epoch 1 : [0/1875] loss : 2.3145430088043213
Epoch 1 : [100/1875] loss : 2.3

# Save, Load and Resuming Training

In [46]:
checkpoint_path = 'checkpoint.pth'

In [47]:
torch.save({
    'epoch' : epoch,
    'model_state_dict' : model.state_dict(),
    'optimizer_state_dict' : optimizer.state_dict(),
    'loss': loss
}, checkpoint_path)

In [48]:
model =ResNet().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.003)

In [49]:
checkpoint = torch.load(checkpoint_path)
checkpoint

{'epoch': 9,
 'model_state_dict': OrderedDict([('conv1.weight',
               tensor([[[[-0.2442, -0.0999,  0.0123],
                         [ 0.1136,  0.1975,  0.1067],
                         [-0.2070, -0.2523, -0.1994]]],
               
               
                       [[[ 0.2914, -0.3256, -0.2369],
                         [-0.2804, -0.2985, -0.2861],
                         [ 0.1401, -0.0921, -0.0584]]],
               
               
                       [[[ 0.1760, -0.2031,  0.2694],
                         [ 0.1439,  0.1016, -0.3137],
                         [-0.1393,  0.0312, -0.0743]]],
               
               
                       [[[-0.2223, -0.0884,  0.2614],
                         [-0.1359,  0.2702, -0.3200],
                         [ 0.0572, -0.2822,  0.0945]]],
               
               
                       [[[-0.3177,  0.1940,  0.0651],
                         [ 0.1900,  0.0911, -0.0599],
                         [ 0.3013, -0.1271, 

In [50]:
checkpoint.keys()

dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'loss'])

In [53]:
# 다음에 학습할 때 해당 정보들을 가지고 이어서 진행하게 된다.
model.load_state_dict(checkpoint['model_state_dict'])
epoch = checkpoint['epoch']
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
loss = checkpoint['loss']