## 전이학습(Transfer Learning)
전이학습이란 기존의 잘 알려진 데이터 혹은 사전학습된 모델을 업무 효율 증대나 도메인 확장을 위해 재사용하는 학습 방법을 의미한다. 전이학습은 인공지능 분야에서 매우 중요한 연구 주제 중 하나이며 다양한 방법론이 존재한다. 여기서는 잘 학습된 모델을 재사용하는 방법에 대해서 알아보자.

In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import torch.nn as nn
import torch.optim as optim
from tqdm import trange

### 1. GPU 연산 확인


In [3]:
# GPU vs. CPU
# Check which compute device is currently available: use the first CUDA GPU
# when there is one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


### 2. CIFAR10 데이터 불러오기

In [4]:
# CIFAR10: an image dataset with 10 classes
# 'plane','car','bird','cat','deer','dog','frog','horse','ship','truck'

# Load the data and define the preprocessing pipelines.
# Training pipeline: random 32x32 crop after padding 4 pixels on each side
# (data augmentation), conversion to a tensor, then per-channel normalization
# with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# Test pipeline: same tensor conversion and normalization, but no random
# augmentation so that evaluation sees the images unchanged.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Training split: downloaded into ./data on first use; shuffled every epoch.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=8, shuffle=True)

# Test split: shuffling brings no benefit for evaluation, so keep the order
# fixed to make evaluation runs reproducible.
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=test_transform)
testloader = DataLoader(testset, batch_size=8, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:01<00:00, 92520883.70it/s] 


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


### 3. Pretrained model 불러오기
파이토치에서는 다양한 사전 학습된 모델을 제공하고 있다.

In [5]:
# Load ResNet18.
# Passing the DEFAULT weights loads the ResNet18 architecture together with
# its pretrained IMAGENET1K_V1 parameters.
# Passing weights=None loads only the ResNet18 architecture.
# Both the model and the tensors need .to(device) for GPU computation.

model = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.DEFAULT
)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 121MB/s]


In [6]:
# Print the layer-by-layer structure of the loaded ResNet18.
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [7]:
# Show the model's first convolution layer before it is replaced.
model.conv1

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

In [8]:
# The model structure shows that the final output layer has 1000 nodes,
# because the model was pretrained on ImageNet, which has 1000 classes.
# The output layer therefore has to be changed to 10 nodes to match CIFAR10.
#
# The first convolution is also swapped from a 7x7/stride-2 layer to a
# 3x3/stride-1 layer (presumably to preserve resolution on the 32x32 CIFAR10
# images - confirm intent).
# NOTE: both replacement layers are newly constructed, so they do not carry
# the pretrained parameters of the layers they replace.

model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# Take the input width from the existing head instead of hard-coding 512.
model.fc = nn.Linear(model.fc.in_features, 10)
model = model.to(device)

In [9]:
# Print the model structure again to check the replaced layers.
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

### 4. 손실함수와 최적화 방법 정의

In [10]:
# Loss function: cross-entropy for multi-class classification.
criterion = nn.CrossEntropyLoss()
# Optimizer: Adam over all model parameters, with learning rate 1e-4 and
# weight decay 1e-2.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=1e-4,
    weight_decay=1e-2,
)

### 5. 사전학습 모델을 이용한 학습

In [11]:
# Fine-tune the model on the training set and checkpoint the best epoch.
num_epochs = 20
# Lowest mean training loss seen so far. Starting from infinity guarantees
# that the first epoch always writes a checkpoint, so the file exists even if
# the loss never falls below some fixed threshold.
ls = float('inf')
pbar = trange(num_epochs)

for epoch in pbar:
    # Explicitly select training mode so the batch-norm layers update their
    # statistics even if the model was switched to eval mode earlier.
    model.train()
    correct = 0
    total = 0
    running_loss = 0.0
    for data in trainloader:
        inputs, labels = data[0].to(device), data[1].to(device)

        # Standard optimization step: clear old gradients, forward pass,
        # loss, backward pass, parameter update.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accumulate statistics for the epoch summary.
        running_loss += loss.item()
        predicted = outputs.detach().argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # Mean of the per-batch losses and training accuracy in percent.
    cost = running_loss / len(trainloader)
    acc = 100 * correct / total

    # Keep only the parameters of the epoch with the lowest training loss.
    if cost < ls:
        ls = cost
        torch.save(model.state_dict(), './cifar10_resnet18.pth')
    pbar.set_postfix({'loss ': cost, 'train acc ': acc})


100%|██████████| 20/20 [33:57<00:00, 101.86s/it, loss =0.55, train acc =83.1]


### 6. 모델 평가

In [12]:
# Rebuild the same modified ResNet18 architecture without pretrained weights;
# every parameter is then overwritten by the saved checkpoint.
model = torchvision.models.resnet18(weights=None)
model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# Take the input width from the existing head instead of hard-coding 512.
model.fc = nn.Linear(model.fc.in_features, 10)
model = model.to(device)
# map_location places the stored tensors on the current device, so a
# checkpoint saved on a GPU can also be loaded on a CPU-only machine.
model.load_state_dict(torch.load('./cifar10_resnet18.pth', map_location=device))

<All keys matched successfully>

In [13]:
# Measure classification accuracy over the whole test loader.
# Evaluation mode plus no_grad: no gradient tracking during inference.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        images = images.to(device)
        labels = labels.to(device)
        # The predicted class is the index of the largest output per sample.
        predicted = model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print("Accuracy of the network on the 10000 test images: %d %%" % accuracy)

Accuracy of the network on the 10000 test images: 83 %


In [14]:
import os
import shutil

# Move the saved checkpoint into the models directory under /content/drive.
# NOTE(review): training saves to './cifar10_resnet18.pth', so this absolute
# source path assumes the working directory is /content - confirm.
source_path = '/content/cifar10_resnet18.pth'
destination_path = '/content/drive/MyDrive/Pytorch/models'

# Create the target directory first. If it were missing, shutil.move would
# either fail or rename the checkpoint to a file literally called "models"
# instead of placing it inside that directory.
os.makedirs(destination_path, exist_ok=True)
shutil.move(source_path, destination_path)

'/content/drive/MyDrive/Pytorch/models/cifar10_resnet18.pth'