<a href="https://colab.research.google.com/github/hang-1n-there/Resnet_CIFA10/blob/main/resnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class BasicBlock(nn.Module):
    """Two-layer 3x3 residual block (used by ResNet-18/34).

    Args:
        in_channel: channels of the incoming feature map.
        out_channel: channels produced by both 3x3 convolutions.
        stride: stride of the first convolution (and of the projection
            shortcut); a value other than 1 downsamples spatially.

    The block outputs ``out_channel * mul`` channels.
    """

    mul = 1  # channel expansion factor of the block's output

    def __init__(self, in_channel, out_channel, stride=1):
        super(BasicBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)

        # conv2 consumes conv1's output, so its input width is out_channel;
        # downsampling already happened in conv1, so it keeps stride 1.
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)

        # Identity shortcut unless the output shape differs from the input
        # (spatial stride or channel count); then project with a 1x1 conv.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channel != out_channel * self.mul:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channel, out_channel * self.mul, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channel * self.mul),
            )

    def forward(self, x):
        result = F.relu(self.bn1(self.conv1(x)))
        # Chain through conv1's activations (not the raw input).
        result = self.bn2(self.conv2(result))
        result = result + self.shortcut(x)
        return F.relu(result)

class BottleNeck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block (used by ResNet-50/101/152).

    Args:
        in_channel: channels of the incoming feature map.
        out_channel: bottleneck width; the block emits ``out_channel * mul``
            channels.
        stride: stride applied by the first 1x1 convolution and by the
            projection shortcut.
    """

    mul = 4  # expansion of the final 1x1 conv relative to the bottleneck width

    def __init__(self, in_channel, out_channel, stride=1):
        super(BottleNeck, self).__init__()
        expanded = out_channel * self.mul

        # Reduce (and possibly downsample) with a 1x1 conv.
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)

        # Spatial 3x3 conv at the bottleneck width.
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)

        # Expand back out with a 1x1 conv.
        self.conv3 = nn.Conv2d(out_channel, expanded, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded)

        # Project the input only when its shape differs from the block output.
        needs_projection = stride != 1 or in_channel != expanded
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channel, expanded, kernel_size=1, stride=stride, padding=0, bias=False),
                nn.BatchNorm2d(expanded),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        identity = self.shortcut(x)

        out = x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = F.relu(norm(conv(out)))
        out = self.bn3(self.conv3(out))

        return F.relu(out + identity)

class Resnet(nn.Module):
    """ResNet: 7x7 stem, four residual stages, global average pool, linear head.

    Args:
        block: residual block class (e.g. BasicBlock or BottleNeck). It must
            expose a class attribute ``mul`` (output channel expansion) and be
            constructible as ``block(in_channel, out_channel, stride)``.
        num_blocks: number of blocks in each of the four stages.
        num_classes: number of output logits.

    Raises:
        ValueError: if ``num_blocks`` does not have exactly four entries.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(Resnet, self).__init__()

        if len(num_blocks) != 4:
            raise ValueError(f"num_blocks must have 4 entries, got {len(num_blocks)}")

        self.in_channel = 64

        # Stem: stride-2 7x7 conv + stride-2 max-pool (a 32x32 input becomes 8x8).
        # bias=False because the following BatchNorm has its own shift.
        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Each stage uses its own depth; stages 2-4 halve the spatial size.
        self.layer1 = self.make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self.make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self.make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self.make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # layer4 emits 512 * block.mul channels (512 for BasicBlock, 2048 for BottleNeck).
        self.linear = nn.Linear(512 * block.mul, num_classes)

    def make_layer(self, block, out_channel, num_blocks, stride):
        """Build one stage of ``num_blocks`` blocks and update ``self.in_channel``."""
        # Only the first block of a stage downsamples; the rest keep stride 1.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []

        for block_stride in strides:
            layers.append(block(self.in_channel, out_channel, block_stride))
            self.in_channel = block.mul * out_channel

        return nn.Sequential(*layers)

    def forward(self, x):
        result = F.relu(self.bn1(self.conv1(x)))
        result = self.maxpool1(result)
        result = self.layer1(result)
        result = self.layer2(result)
        result = self.layer3(result)
        result = self.layer4(result)
        result = self.avgpool(result)
        result = torch.flatten(result, 1)
        return self.linear(result)

def ResNet18():
    """Build ResNet-18: BasicBlock stages of depth 2, 2, 2, 2."""
    stage_depths = [2, 2, 2, 2]
    return Resnet(BasicBlock, stage_depths)

def ResNet34():
    """Build ResNet-34: BasicBlock stages of depth 3, 4, 6, 3."""
    stage_depths = [3, 4, 6, 3]
    return Resnet(BasicBlock, stage_depths)

def ResNet50():
    """Build ResNet-50: BottleNeck stages of depth 3, 4, 6, 3."""
    stage_depths = [3, 4, 6, 3]
    return Resnet(BottleNeck, stage_depths)

def ResNet101():
    """Build ResNet-101: BottleNeck stages of depth 3, 4, 23, 3."""
    stage_depths = [3, 4, 23, 3]
    return Resnet(BottleNeck, stage_depths)

def ResNet152():
    """Build ResNet-152: BottleNeck stages of depth 3, 8, 36, 3."""
    stage_depths = [3, 8, 36, 3]
    return Resnet(BottleNeck, stage_depths)

In [None]:
import torchvision
import torch.optim as optim
import torchvision.transforms as transforms
import os
import torchvision.models as models

In [None]:
# Simple Learning Rate Scheduler
def lr_scheduler(optimizer, epoch, base_lr=None, milestones=(50, 100), decay_factor=10):
    """Step-decay schedule: set the learning rate of every param group for ``epoch``.

    The rate is ``base_lr`` divided by ``decay_factor`` once for every
    milestone that ``epoch`` has reached.

    Args:
        optimizer: optimizer whose ``param_groups`` are updated in place.
        epoch: zero-based epoch index.
        base_lr: starting learning rate. When None, the module-level
            ``learning_rate`` global is used (the original behaviour).
        milestones: epochs at which the rate is divided by ``decay_factor``.
        decay_factor: divisor applied at each reached milestone.

    Returns:
        The learning rate that was written into the optimizer.
    """
    lr = learning_rate if base_lr is None else base_lr
    for milestone in milestones:
        if epoch >= milestone:
            # Repeated division (not multiplication by 0.1) matches the original values exactly.
            lr /= decay_factor
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

# Xavier
def init_weights(m):
    """Xavier-uniform initialisation for ``nn.Linear`` layers; intended for ``model.apply``.

    Weights are drawn with ``xavier_uniform_`` and the bias (if the layer has
    one) is set to 0.01. Modules of any other type are left untouched.
    """
    if isinstance(m, nn.Linear):
        # In-place variant; the un-suffixed ``xavier_uniform`` is deprecated.
        torch.nn.init.xavier_uniform_(m.weight)
        # Linear(bias=False) stores bias=None, so guard before filling.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.01)

In [None]:
# CIFAR-10 input pipelines. Training adds light augmentation (pad-and-crop plus
# horizontal flip); neither pipeline applies Normalize, only ToTensor().
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),  # pad 4 px per side, crop back to 32x32
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)
transform_test = transforms.Compose([transforms.ToTensor()])

# Both splits are stored under ./data and downloaded if missing.
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=256, shuffle=True, num_workers=8
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=256, shuffle=False, num_workers=8
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:04<00:00, 40829185.89it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified




In [None]:
# Use the GPU when one is present; fall back to CPU so the notebook still runs without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ResNet50()

In [None]:
# nn.Module.apply returns the module itself, so initialisation and the device move chain.
model = model.apply(init_weights).to(device)

  torch.nn.init.xavier_uniform(m.weight)


In [None]:
# Optimisation hyper-parameters.
learning_rate = 0.1
num_epoch = 150
model_name = 'model.pth'  # path the best checkpoint is saved to

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=0.0001,
)

# Loss/accuracy bookkeeping; best_acc tracks the best validation accuracy seen so far.
train_loss = valid_loss = correct = total_cnt = best_acc = 0

In [None]:
# Train
from tqdm import tqdm

# Per-epoch mean losses, recorded so the curves can be plotted afterwards.
train_loss_history = []
valid_loss_history = []

for epoch in tqdm(range(num_epoch)):
    print(f"====== { epoch+1} epoch of { num_epoch } ======")
    model.train()
    lr_scheduler(optimizer, epoch)
    train_loss = 0
    valid_loss = 0
    correct = 0
    total_cnt = 0

    for step, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        logits = model(inputs)
        # CrossEntropyLoss() uses mean reduction, so this is already a per-sample average.
        loss = loss_fn(logits, targets)
        loss.backward()
        optimizer.step()

        batch_size = targets.size(0)
        # Weight by batch size so the running mean stays exact when the last batch is smaller.
        train_loss += loss.item() * batch_size
        total_cnt += batch_size
        _, predict = logits.max(1)
        correct += predict.eq(targets).sum().item()

        if step % 100 == 0 and step != 0:
            print(f"\n====== { step } Step of { len(train_loader) } ======")
            print(f"Train Acc : { correct / total_cnt }")
            print(f"Train Loss : { train_loss / total_cnt }")

    train_loss /= total_cnt
    train_loss_history.append(train_loss)

    correct = 0
    total_cnt = 0

    # Validate
    model.eval()
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            batch_size = targets.size(0)
            # .item() keeps the accumulator a Python float rather than a device tensor.
            valid_loss += loss_fn(logits, targets).item() * batch_size
            total_cnt += batch_size
            _, predict = logits.max(1)
            correct += predict.eq(targets).sum().item()

    valid_acc = correct / total_cnt
    valid_loss /= total_cnt
    valid_loss_history.append(valid_loss)
    print(f"\nValid Acc : { valid_acc }")
    print(f"Valid Loss : { valid_loss }")

    if valid_acc > best_acc:
        best_acc = valid_acc
        torch.save(model, model_name)
        print("Model Saved!")
            print("Model Saved!")

  0%|          | 0/150 [00:00<?, ?it/s]



  self.pid = os.fork()



Train Acc : 0.15485767326732675
Train Loss : 0.008431224152445793


  1%|          | 1/150 [01:50<4:34:43, 110.63s/it]


Valid Acc : 0.2226
Valid Loss : 0.00843049120157957
Model Saved!

Train Acc : 0.23851330445544555
Train Loss : 0.007928382605314255

Valid Acc : 0.2789
Valid Loss : 0.007953212596476078


  1%|▏         | 2/150 [03:44<4:38:05, 112.74s/it]

Model Saved!

Train Acc : 0.2995049504950495
Train Loss : 0.007257702760398388

Valid Acc : 0.3567
Valid Loss : 0.006878793239593506


  2%|▏         | 3/150 [05:44<4:43:42, 115.80s/it]

Model Saved!

Train Acc : 0.35852413366336633
Train Loss : 0.006846493575721979

Valid Acc : 0.4222
Valid Loss : 0.0062881819903850555


  3%|▎         | 4/150 [07:44<4:46:11, 117.61s/it]

Model Saved!

Train Acc : 0.3961556311881188
Train Loss : 0.006119542755186558


  3%|▎         | 5/150 [09:43<4:45:31, 118.15s/it]


Valid Acc : 0.39
Valid Loss : 0.007040169555693865

Train Acc : 0.43579826732673266
Train Loss : 0.005680623464286327

Valid Acc : 0.4534
Valid Loss : 0.00650556618347764


  4%|▍         | 6/150 [11:44<4:45:15, 118.86s/it]

Model Saved!

Train Acc : 0.46905940594059403
Train Loss : 0.005557815078645945

Valid Acc : 0.4939
Valid Loss : 0.005538885947316885


  5%|▍         | 7/150 [13:44<4:44:15, 119.27s/it]

Model Saved!

Train Acc : 0.4985689975247525
Train Loss : 0.00528485095128417

Valid Acc : 0.5174
Valid Loss : 0.005402952898293734


  5%|▌         | 8/150 [15:44<4:43:20, 119.72s/it]

Model Saved!

Train Acc : 0.5216584158415841
Train Loss : 0.0052812485955655575

Valid Acc : 0.5387
Valid Loss : 0.005211806856095791


  6%|▌         | 9/150 [17:45<4:42:07, 120.05s/it]

Model Saved!

Train Acc : 0.5505105198019802
Train Loss : 0.005494485609233379

Valid Acc : 0.5431
Valid Loss : 0.005167281720787287


  7%|▋         | 10/150 [19:46<4:40:23, 120.17s/it]

Model Saved!

Train Acc : 0.5743734529702971
Train Loss : 0.004282016772776842


  7%|▋         | 11/150 [21:46<4:38:22, 120.16s/it]


Valid Acc : 0.5174
Valid Loss : 0.005740209016948938

Train Acc : 0.6001314975247525
Train Loss : 0.004051889758557081

Valid Acc : 0.576
Valid Loss : 0.0048752436414361


  8%|▊         | 12/150 [23:46<4:36:32, 120.24s/it]

Model Saved!

Train Acc : 0.6243038366336634
Train Loss : 0.0036234608851373196

Valid Acc : 0.5949
Valid Loss : 0.0046732304617762566


  9%|▊         | 13/150 [25:47<4:35:11, 120.52s/it]

Model Saved!

Train Acc : 0.6372215346534653
Train Loss : 0.0035028306301683187

Valid Acc : 0.6283
Valid Loss : 0.00435118842869997


  9%|▉         | 14/150 [27:48<4:33:19, 120.59s/it]

Model Saved!

Train Acc : 0.6532332920792079
Train Loss : 0.003292716108262539


 10%|█         | 15/150 [29:49<4:31:21, 120.60s/it]


Valid Acc : 0.6245
Valid Loss : 0.0044210124760866165

Train Acc : 0.6687422648514851
Train Loss : 0.0038868633564561605

Valid Acc : 0.6423
Valid Loss : 0.0042333174496889114


 11%|█         | 16/150 [31:50<4:29:32, 120.69s/it]

Model Saved!

Train Acc : 0.6834003712871287
Train Loss : 0.003354523563757539

Valid Acc : 0.65
Valid Loss : 0.0041038827039301395


 11%|█▏        | 17/150 [33:50<4:27:30, 120.68s/it]

Model Saved!

Train Acc : 0.6988319925742574
Train Loss : 0.00355724454857409

Valid Acc : 0.6771
Valid Loss : 0.0037550346460193396


 12%|█▏        | 18/150 [35:51<4:25:49, 120.83s/it]

Model Saved!

Train Acc : 0.7166615099009901
Train Loss : 0.003138939617201686


 13%|█▎        | 19/150 [37:52<4:23:40, 120.77s/it]


Valid Acc : 0.6737
Valid Loss : 0.003878289368003607

Train Acc : 0.724319306930693
Train Loss : 0.0032069210428744555

Valid Acc : 0.7164
Valid Loss : 0.0033200737088918686


 13%|█▎        | 20/150 [39:52<4:21:20, 120.62s/it]

Model Saved!

Train Acc : 0.7443920173267327
Train Loss : 0.002944500185549259

Valid Acc : 0.7189
Valid Loss : 0.003456716425716877


 14%|█▍        | 21/150 [41:53<4:19:32, 120.71s/it]

Model Saved!

Train Acc : 0.7583152846534653
Train Loss : 0.002741477685049176

Valid Acc : 0.7418
Valid Loss : 0.0029967024456709623


 15%|█▍        | 22/150 [43:55<4:18:05, 120.98s/it]

Model Saved!

Train Acc : 0.7662824876237624
Train Loss : 0.002762447576969862


 15%|█▍        | 22/150 [45:35<4:25:14, 124.33s/it]


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

# Plot per-epoch loss curves. If the per-epoch histories
# (train_loss_history / valid_loss_history) were not recorded, fall back to the
# last scalar losses so the cell still runs. float() also converts a
# one-element (possibly CUDA) tensor to a plain number matplotlib can draw.
try:
    train_curve = [float(v) for v in train_loss_history]
    valid_curve = [float(v) for v in valid_loss_history]
except NameError:
    train_curve = [float(train_loss)]
    valid_curve = [float(valid_loss)]

plt.figure(figsize=(10, 8))
# Separate x-ranges: an interrupted run can have one more train point than valid points.
plt.plot(range(1, len(train_curve) + 1), train_curve, marker='o', label='train')
plt.plot(range(1, len(valid_curve) + 1), valid_curve, marker='o', label='valid')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(loc='upper left')
plt.grid(True)
plt.show()