# Model Save

In [10]:
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

from torchvision import datasets, transforms

In [2]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [3]:
class ResidualBlock(nn.Module):
    """Three-convolution block (1x1 -> 3x3 -> 1x1) with an additive skip path.

    All convolutions use stride 1 with padding chosen so the spatial size is
    unchanged. When ``in_channel`` differs from ``out_channel`` the skip path
    is a 1x1 convolution mapping the input to ``out_channel`` channels;
    otherwise it is an empty ``nn.Sequential``, which returns its input as is.

    Args:
        in_channel: Number of channels of the incoming feature map.
        out_channel: Number of channels produced by the block.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()

        self.in_channel = in_channel
        self.out_channel = out_channel

        # Main path; layers are registered in the order they are applied.
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(out_channel, out_channel, kernel_size=1, padding=0)

        # Skip path: project with a 1x1 conv only when the channel counts differ.
        projection = []
        if in_channel != out_channel:
            projection.append(
                nn.Conv2d(in_channel, out_channel, kernel_size=1, padding=0)
            )
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Apply conv+ReLU three times, then add the skip path of ``x``."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))
        return features + self.shortcut(x)

class ResNet(nn.Module):
    """Small residual CNN that outputs log-probabilities over 10 classes.

    Pipeline: 3x3 conv -> ReLU -> 2x2 max-pool -> two ``ResidualBlock``s ->
    global average pool -> two fully connected layers -> log-softmax.

    Args:
        color: ``"gray"`` for 1-channel input or ``"rgb"`` for 3-channel input.

    Raises:
        ValueError: If ``color`` is neither ``"gray"`` nor ``"rgb"``.
    """

    def __init__(self, color='gray'):
        super().__init__()

        # Validate up front: previously an unknown value silently skipped the
        # creation of `conv1` and only failed later inside `forward`.
        if color == "gray":
            in_channels = 1
        elif color == "rgb":
            in_channels = 3
        else:
            raise ValueError(
                f"color must be 'gray' or 'rgb', got {color!r}"
            )

        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)

        self.resblock1 = ResidualBlock(32, 64)
        self.resblock2 = ResidualBlock(64, 64)

        # Global average pooling makes the classifier independent of the
        # spatial size of the feature map.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(64, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images ``x``."""
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = self.resblock1(x)
        x = self.resblock2(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # log_softmax output is meant to be paired with an NLL-style loss.
        x = F.log_softmax(x, dim=1)
        return x

In [4]:
# Build the network with its default (grayscale) input, then move it to `device`.
model = ResNet()
model = model.to(device)

## 1. 가중치(state_dict)만 저장

In [5]:
# Persist only the learned parameters (the state_dict), not the module object.
torch.save(obj=model.state_dict(), f='model_weights.pth')

In [6]:
# map_location remaps the stored tensors onto the current `device`, so weights
# written from a CUDA session can still be read on a CPU-only machine.
model.load_state_dict(torch.load('model_weights.pth', map_location=device))

<All keys matched successfully>

## 2. 모델 구조까지 함께 저장

In [7]:
# Pickle the whole nn.Module object (architecture reference + weights).
torch.save(model, "model.pth")

# NOTE(security): loading a fully pickled module un-pickles arbitrary Python
# objects, so only do this with files you trust. The model's class definition
# must also be available when loading. Recent PyTorch releases (2.6+) default
# torch.load to weights_only=True, in which case this call additionally needs
# weights_only=False.
# map_location places the restored tensors on the current `device`, so a file
# written from a CUDA session still loads on a CPU-only machine.
model_2 = torch.load('model.pth', map_location=device)
print(model_2)

ResNet(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (resblock1): ResidualBlock(
    (conv1): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
    (shortcut): Sequential(
      (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (resblock2): ResidualBlock(
    (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
    (shortcut): Sequential()
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc1): Linear(in_features=64, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=10, bias=True)
)


## 3. Training

In [8]:
batch_size = 32

# Convert images to tensors, then normalise with mean 0.5 and std 0.5.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# MNIST training split; downloaded into 'dataset/' if it is not there yet.
mnist_train = datasets.MNIST(
    'dataset/',
    train=True,
    download=True,
    transform=mnist_transform,
)

train_loader = torch.utils.data.DataLoader(
    mnist_train,
    batch_size=batch_size,
    shuffle=True,
)

In [9]:
# Pull a single mini-batch to inspect the shape of the tensors the loader yields.
batch_iterator = iter(train_loader)
x, y = next(batch_iterator)
x.shape

torch.Size([32, 1, 28, 28])

In [11]:
# Adam over every model parameter; the plateau scheduler lowers the learning
# rate when the monitored value (mode='min') stops decreasing.
optimizer = optim.Adam(params=model.parameters(), lr=0.03)
scheduler = ReduceLROnPlateau(optimizer=optimizer, mode='min', verbose=True)

def train_loop(dataloader, model, loss_fn, optimizer, scheduler, epoch):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        dataloader: Sized iterable yielding ``(inputs, targets)`` batches.
        model: Module to optimise; it is switched to training mode.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a
            scalar tensor.
        optimizer: Optimizer stepped once per batch.
        scheduler: Object whose ``step(metric)`` is called once, after the
            last batch, with that batch's loss.
        epoch: Epoch index; used only in the progress messages.

    Returns:
        float: Loss of the final batch of the epoch.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()
    size = len(dataloader)

    # Send each batch to wherever the model's parameters live, so the function
    # does not depend on a notebook-level global. The global `device` is only
    # consulted for a model that has no parameters at all.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    loss = None
    for batch, (x, y) in enumerate(dataloader):
        x, y = x.to(target_device), y.to(target_device)

        pred = model(x)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            # Do NOT rebind `loss` to a float here: it must stay a tensor so
            # the conversion after the loop works for every dataloader length.
            print(f"Epoch {epoch} : [{batch}/{size}] loss : {loss.item()}")

    if loss is None:
        raise ValueError("dataloader produced no batches; nothing to train on")

    # The scheduler is driven by the last batch's loss (same metric as before).
    final_loss = loss.item()
    scheduler.step(final_loss)

    return final_loss

# Run two epochs; `epoch` and `loss` keep the values from the last iteration.
for epoch in range(2):
    loss = train_loop(
        dataloader=train_loader,
        model=model,
        loss_fn=F.nll_loss,
        optimizer=optimizer,
        scheduler=scheduler,
        epoch=epoch,
    )
    print(f"epoch:{epoch} loss:{loss}")

Epoch 0 : [0/1875] loss : 2.295548439025879
Epoch 0 : [100/1875] loss : 1.8849818706512451
Epoch 0 : [200/1875] loss : 1.6921669244766235
Epoch 0 : [300/1875] loss : 1.697455883026123
Epoch 0 : [400/1875] loss : 1.63392174243927
Epoch 0 : [500/1875] loss : 0.9757788181304932
Epoch 0 : [600/1875] loss : 1.493109107017517
Epoch 0 : [700/1875] loss : 0.9486769437789917
Epoch 0 : [800/1875] loss : 1.2624579668045044
Epoch 0 : [900/1875] loss : 0.8160249590873718
Epoch 0 : [1000/1875] loss : 1.3707740306854248
Epoch 0 : [1100/1875] loss : 0.6726800203323364
Epoch 0 : [1200/1875] loss : 0.5573949813842773
Epoch 0 : [1300/1875] loss : 0.4942813515663147
Epoch 0 : [1400/1875] loss : 1.3128349781036377
Epoch 0 : [1500/1875] loss : 0.5236315727233887
Epoch 0 : [1600/1875] loss : 0.6328096389770508
Epoch 0 : [1700/1875] loss : 0.2989206314086914
Epoch 0 : [1800/1875] loss : 0.6298531293869019
epoch:0 loss:0.47387564182281494
Epoch 1 : [0/1875] loss : 0.4432920217514038
Epoch 1 : [100/1875] loss :

## 4. Save, Load and Resuming Training

In [14]:
# File that the training checkpoint is written to and later read back from.
checkpoint_path = "checkpoint.pth"

In [15]:
# Bundle everything needed to resume training into a single dictionary.
checkpoint_state = {
    "epoch": epoch,  # index of the last epoch that was run
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),  # optimizer's internal state
    "loss": loss,  # loss value returned for that epoch
}
torch.save(checkpoint_state, checkpoint_path)

In [16]:
# Recreate a fresh model and optimizer to restore the checkpoint into.
model = ResNet().to(device)
# The optimizer must be the same type as the one whose state was saved (Adam):
# an Adam state_dict carries Adam-specific hyperparameters and buffers that
# another optimizer class such as SGD cannot use.
optimizer = optim.Adam(model.parameters(), lr=0.03)

In [17]:
# map_location remaps the stored tensors onto the current `device`, so a
# checkpoint written from a CUDA session still loads on a CPU-only machine.
checkpoint = torch.load(checkpoint_path, map_location=device)
checkpoint

{'epoch': 1,
 'model_state_dict': OrderedDict([('conv1.weight',
               tensor([[[[ 1.2640e+00,  1.7569e-01, -1.6759e+00],
                         [ 2.0558e-01,  8.8453e-01, -2.7491e-01],
                         [-2.5114e-01,  9.1764e-01,  1.1310e+00]]],
               
               
                       [[[ 1.1152e+00,  6.5032e-02, -2.1101e-01],
                         [-4.2506e-01, -1.3728e+00,  1.4623e+00],
                         [ 1.1484e+00,  1.7362e-01,  2.5873e-01]]],
               
               
                       [[[-2.4723e-01, -6.2135e-01, -1.1641e+00],
                         [ 7.1118e-01, -2.3358e-01,  6.6371e-01],
                         [ 2.1636e-01, -2.7560e-01,  1.4702e+00]]],
               
               
                       [[[-8.1100e-01,  1.4209e+00,  5.2496e-01],
                         [-5.5043e-01,  4.5304e-01, -7.1943e-01],
                         [ 1.2039e-01,  1.3523e-01, -4.3572e-01]]],
               
               
        

In [18]:
# The checkpoint is a plain dict; list the entries it contains.
checkpoint.keys()

dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'loss'])

In [20]:
# Restore the model weights and the epoch the checkpoint was taken at.
model.load_state_dict(checkpoint["model_state_dict"])
epoch = checkpoint["epoch"]  # was misspelled `spoch`, so `epoch` was never restored

# Restore the optimizer's internal state and the recorded loss. The optimizer
# must be of the same type as the one that produced this state_dict.
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
loss = checkpoint['loss']

- 이렇게 불러온 epoch, 모델 가중치, 옵티마이저 상태, loss를 이용하면 중단된 지점부터 학습을 이어서 진행할 수 있다