<a href="https://colab.research.google.com/github/comb0703/practice/blob/main/ResNeXt29_(2x64d).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

In [2]:
# Select the compute device. Later cells move the model and every batch to
# `device`, so it must always be defined: use the GPU when CUDA is available,
# otherwise fall back to the CPU (much slower, but the notebook still runs).
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  print("warning: CUDA is not available, falling back to CPU")
  device = torch.device('cpu')

In [3]:
# CIFAR-10 input pipeline. Normalize((0.5,)*3, (0.5,)*3) maps each channel
# from ToTensor's [0, 1] range to [-1, 1].
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Training transform: random crop (with 4px padding) + horizontal flip as
# data augmentation.
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Test transform: no augmentation. Evaluating on randomly cropped/flipped
# images would make test accuracy noisy and not comparable across epochs.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

train_dataset = torchvision.datasets.CIFAR10(root='./data',
                                             train=True,
                                             download=True,
                                             transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data',
                                            train=False,
                                            download=True,
                                            transform=test_transform)

# Do not spawn more DataLoader workers than there are usable CPUs; asking for
# more triggered a DataLoader worker-count warning in the recorded run.
try:
  available_cpus = len(os.sched_getaffinity(0))
except AttributeError:  # sched_getaffinity is not available on every OS
  available_cpus = os.cpu_count() or 1
NUM_WORKERS = min(4, available_cpus)

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=64,
                                           shuffle=True,
                                           num_workers=NUM_WORKERS)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=64,
                                          shuffle=False,
                                          num_workers=NUM_WORKERS)

Files already downloaded and verified
Files already downloaded and verified


  cpuset_checked))


In [4]:
class ResNeXt(nn.Module):
  """ResNeXt-29 (2x64d) image classifier.

  Architecture: a 3x3 stem conv (3 -> 64 channels) followed by nine
  bottleneck blocks in three stages of three. Each block is
  1x1 conv (reduce) -> 3x3 grouped conv (groups=2) -> 1x1 conv (expand),
  each followed by BatchNorm, with ReLU after the first two convs and after
  the residual addition. Stage output widths are 256, 512 and 1024 channels;
  the first block of stages 2 and 3 downsamples with stride 2 (both in the
  grouped conv and in its projection shortcut). Global average pooling and a
  linear layer produce the logits (1 stem + 27 block convs + 1 FC = 29).

  Layers are registered under the flat names conv1..conv28, bn1..bn28,
  shortcut1..shortcut9 and linear, in the same order as a hand-unrolled
  definition, so saved ``state_dict`` checkpoints remain loadable.

  Args:
    num_classes: number of output logits (default 10, for CIFAR-10).
  """

  # Per-block configuration for blocks 1..9:
  # (in_planes, group_width, out_planes, stride of the 3x3 conv and shortcut).
  # group_width = cardinality (2) * bottleneck width (64 / 128 / 256 per stage).
  _BLOCK_CONFIG = (
      (64, 128, 256, 1),      # block 1 (projection shortcut: 64 -> 256)
      (256, 128, 256, 1),     # block 2
      (256, 128, 256, 1),     # block 3
      (256, 256, 512, 2),     # block 4 (downsample, projection shortcut)
      (512, 256, 512, 1),     # block 5
      (512, 256, 512, 1),     # block 6
      (512, 512, 1024, 2),    # block 7 (downsample, projection shortcut)
      (1024, 512, 1024, 1),   # block 8
      (1024, 512, 1024, 1),   # block 9
  )
  _CARDINALITY = 2

  def __init__(self, num_classes=10):
    super(ResNeXt, self).__init__()

    self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(64)

    for block_idx, (in_planes, group_width, out_planes, stride) in enumerate(
        self._BLOCK_CONFIG, start=1):
      convs = (
          nn.Conv2d(in_planes, group_width, kernel_size=1, stride=1,
                    padding=0, bias=False),
          nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride,
                    padding=1, groups=self._CARDINALITY, bias=False),
          nn.Conv2d(group_width, out_planes, kernel_size=1, stride=1,
                    padding=0, bias=False),
      )
      # Register conv/bn pairs interleaved (convN, bnN, ...) to keep the
      # original module registration order.
      for layer_idx, conv in zip(self._layer_indices(block_idx), convs):
        setattr(self, 'conv%d' % layer_idx, conv)
        setattr(self, 'bn%d' % layer_idx, nn.BatchNorm2d(conv.out_channels))

      if stride != 1 or in_planes != out_planes:
        # Projection shortcut so the residual matches the branch's channel
        # count and spatial resolution.
        shortcut = nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                      bias=False),
            nn.BatchNorm2d(out_planes),
        )
      else:
        # Identity shortcut: an empty Sequential returns its input unchanged.
        shortcut = nn.Sequential()
      setattr(self, 'shortcut%d' % block_idx, shortcut)

    self.linear = nn.Linear(self._BLOCK_CONFIG[-1][2], num_classes)

  @staticmethod
  def _layer_indices(block_idx):
    """Return the conv/bn name suffixes of 1-based block ``block_idx``
    (block 1 -> 2, 3, 4; block 9 -> 26, 27, 28)."""
    first = 3 * block_idx - 1
    return range(first, first + 3)

  def _bottleneck(self, x, block_idx):
    """Apply bottleneck block ``block_idx`` (1-based) and its shortcut to ``x``."""
    indices = self._layer_indices(block_idx)
    out = x
    for position, layer_idx in enumerate(indices):
      out = getattr(self, 'bn%d' % layer_idx)(getattr(self, 'conv%d' % layer_idx)(out))
      if position < len(indices) - 1:
        # ReLU only after the reduce and grouped convs; the expand output is
        # activated after the residual addition below.
        out = F.relu(out)
    out = out + getattr(self, 'shortcut%d' % block_idx)(x)
    return F.relu(out)

  def forward(self, x):
    """Return class logits of shape (N, num_classes).

    Args:
      x: image batch of shape (N, 3, H, W), e.g. (N, 3, 32, 32) for CIFAR.
    """
    out = F.relu(self.bn1(self.conv1(x)))
    for block_idx in range(1, len(self._BLOCK_CONFIG) + 1):
      out = self._bottleneck(out, block_idx)
    # Global average pool. For 32x32 inputs the final map is 8x8, so this
    # equals avg_pool2d(out, 8), but it also works for other input sizes.
    out = F.adaptive_avg_pool2d(out, 1)
    out = out.view(out.size(0), -1)
    return self.linear(out)

In [5]:
# Model, loss function and optimizer.
# Initial learning rate for SGD (the scheduler cell decays it over epochs).
learning_rate = 0.1

# Module.to() returns the module itself, so build and move in one step.
net = ResNeXt().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=0.0001,
)

#scheduler
def adjust_learning_rate(optimizer, epoch, base_lr=None, milestones=(80, 120),
                         decay_factor=10):
    """Step-decay the learning rate of every param group in ``optimizer``.

    The LR is ``base_lr`` divided by ``decay_factor`` once for each milestone
    that ``epoch`` has reached (``epoch >= milestone``). With the defaults this
    is the original schedule: base LR before epoch 80, /10 from epoch 80,
    /100 from epoch 120.

    Args:
        optimizer: optimizer whose ``param_groups`` are updated in place.
        epoch: current (0-based) epoch index.
        base_lr: initial learning rate; when None, the notebook-level
            ``learning_rate`` is read at call time.
        milestones: epochs at which the LR is divided by ``decay_factor``.
        decay_factor: divisor applied at each reached milestone.

    Returns:
        The learning rate that was written into every param group.
    """
    lr = learning_rate if base_lr is None else base_lr
    for milestone in milestones:
        if epoch >= milestone:
            lr /= decay_factor
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


In [6]:
# train, test function
def train(epoch):
    """Run one training epoch over ``train_loader`` and print progress.

    Uses the notebook-level ``net``, ``criterion``, ``optimizer``,
    ``train_loader`` and ``device``.

    Args:
        epoch: epoch index, used only for logging.

    Returns:
        ``(accuracy_percent, average_loss_per_sample)`` for the epoch.
    """
    print('\n[ Train epoch: %d ]' % epoch)
    net.train()
    loss_sum = 0.0
    correct = 0
    total = 0
    for batch_idx, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()

        outputs = net(images)
        # nn.CrossEntropyLoss() (default reduction='mean') already returns the
        # per-sample average over the batch; do not divide by batch size again.
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        batch_correct = outputs.argmax(1).eq(labels).sum().item()
        # Weight each batch mean by its size so the epoch average is exact even
        # when the final batch is smaller.
        loss_sum += loss.item() * batch_size
        correct += batch_correct
        total += batch_size

        if batch_idx % 100 == 0:
            print('\nCurrent batch:', str(batch_idx))
            print('Current train accuracy:', str(batch_correct / batch_size))
            print('Current train average loss:', loss.item())

    accuracy = 100. * correct / total
    average_loss = loss_sum / total
    print('\nTrain accuracy:', accuracy)
    print('Train average loss:', average_loss)
    return accuracy, average_loss


def test(epoch):
    """Evaluate ``net`` on ``test_loader``, print metrics and save a checkpoint.

    Uses the notebook-level ``net``, ``criterion``, ``test_loader`` and
    ``device``. The checkpoint file is overwritten every call.

    Args:
        epoch: epoch index, used for logging and stored in the checkpoint.

    Returns:
        ``(accuracy_percent, average_loss_per_sample)`` on the test set.
    """
    print('\n[ Test epoch: %d ]' % epoch)
    net.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    # Evaluation needs no gradients; without no_grad() every test batch would
    # record an autograd graph and waste memory.
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = net(images)
            batch_size = labels.size(0)
            # criterion returns a batch mean; weight it by batch size so the
            # final division by `total` yields a true per-sample average.
            loss_sum += criterion(outputs, labels).item() * batch_size
            correct += outputs.argmax(1).eq(labels).sum().item()
            total += batch_size

    accuracy = 100. * correct / total
    average_loss = loss_sum / total
    print('\nTest accuracy:', accuracy)
    print('Test average loss:', average_loss)

    # The model uses cardinality 2 (groups=2) with bottleneck width 64, i.e.
    # ResNeXt29 2x64d; name the checkpoint accordingly.
    checkpoint_dir = 'checkpoint'
    os.makedirs(checkpoint_dir, exist_ok=True)
    state = {
        'net': net.state_dict(),
        'epoch': epoch,
        'acc': accuracy,
    }
    torch.save(state, os.path.join(checkpoint_dir, 'ResNeXt29_2x64d.pt'))
    print('Model Saved')
    return accuracy, average_loss

In [7]:
# Main training loop with TensorBoard logging.
NUM_EPOCHS = 300

writer = SummaryWriter()
try:
    for epoch in range(NUM_EPOCHS):
        adjust_learning_rate(optimizer, epoch)
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        train_metrics = train(epoch)
        test_metrics = test(epoch)

        # train()/test() may return (accuracy, loss); anything that is not
        # returned (None) is simply not logged.
        for split, metrics in (('train', train_metrics), ('test', test_metrics)):
            if metrics is not None:
                split_accuracy, split_loss = metrics
                writer.add_scalar(split + '/accuracy', split_accuracy, epoch)
                writer.add_scalar(split + '/loss', split_loss, epoch)
finally:
    # Flush pending events even if the run is interrupted (e.g. KeyboardInterrupt).
    writer.close()

KeyboardInterrupt: ignored