# Training Logic

In [6]:
import torch
from torch import nn 
from torch import optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import datasets, transforms

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
batch_size = 32

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('dataset/', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize(mean=(0.5,), std=(0.5,))
                   ])),
    batch_size=batch_size,
    shuffle=True)

In [4]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(ResidualBlock, self).__init__()
        
        self.in_channel, self.out_channel = in_channel, out_channel
        
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(out_channel, out_channel, kernel_size=1, padding=0)
        
        if in_channel != out_channel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=1, padding=0)
            )
        else:
            self.shortcut = nn.Sequential()
    
    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        out = F.relu(self.conv3(out))
        out = out + self.shortcut(x)
        return out

class ResNet(nn.Module):
    def __init__(self, color='gray'):
        super(ResNet, self).__init__()
        if color == "gray":
            self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        elif color == "rgb":
            self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
            
        self.resblock1 = ResidualBlock(32, 64)
        self.resblock2 = ResidualBlock(64, 64)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc1 = nn.Linear(64, 64)
        self.fc2 = nn.Linear(64, 10)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = self.resblock1(x)
        x = self.resblock2(x)
        x = self.avgpool(x)
        x = torch.flatten(x,1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = F.log_softmax(x, dim=1)
        return x

In [5]:
model = ResNet().to(device)

##  Learning Rate Scheduler

In [9]:
optimizer = optim.Adam(model.parameters(), lr=0.03)

In [10]:
scheduler = ReduceLROnPlateau(optimizer, mode='min', verbose=True)

- mode는 loss 인 경우 "min", acc인 경우 "max"를 사용
- verbose=True: lr이 바뀔때마다 log를 보여준다

In [11]:
def train_loop(dataloader, model, loss_fn, optimizer, scheduler, epoch):
    model.train()
    size = len(dataloader)
    for batch, (x, y) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)
        
        pred = model(x)
        loss = loss_fn(pred, y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if batch % 100 == 0:
            loss = loss.item()
            print(f"Epoch {epoch} : [{batch}/{size}] loss : {loss}")
            
    scheduler.step(loss)
    
    return loss.item()

In [13]:
for epoch in range(2):
    loss = train_loop(train_loader, model, F.nll_loss, optimizer, scheduler, epoch)
    print(f"epoch:{epoch} loss:{loss}" )
    

Epoch 0 : [0/1875] loss : 2.243256092071533
Epoch 0 : [100/1875] loss : 1.508970022201538
Epoch 0 : [200/1875] loss : 1.23863685131073
Epoch 0 : [300/1875] loss : 1.3099850416183472
Epoch 0 : [400/1875] loss : 1.6059422492980957
Epoch 0 : [500/1875] loss : 1.469323754310608
Epoch 0 : [600/1875] loss : 0.9231387972831726
Epoch 0 : [700/1875] loss : 1.1293424367904663
Epoch 0 : [800/1875] loss : 1.134285807609558
Epoch 0 : [900/1875] loss : 0.821662187576294
Epoch 0 : [1000/1875] loss : 0.9763290286064148
Epoch 0 : [1100/1875] loss : 0.7143877744674683
Epoch 0 : [1200/1875] loss : 0.8270964026451111
Epoch 0 : [1300/1875] loss : 0.6810207962989807
Epoch 0 : [1400/1875] loss : 0.6565802693367004
Epoch 0 : [1500/1875] loss : 0.7728536128997803
Epoch 0 : [1600/1875] loss : 1.0831559896469116
Epoch 0 : [1700/1875] loss : 0.5612004399299622
Epoch 0 : [1800/1875] loss : 0.4417472183704376
epoch:0 loss:0.7592490911483765
Epoch 1 : [0/1875] loss : 0.654495358467102
Epoch 1 : [100/1875] loss : 0.5