<a href="https://colab.research.google.com/github/jetsonai/Working-R-Ssaem/blob/main/CNN/%5B4%5D_Transfer_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **사전 학습된 모델을 활용한 전이학습 (Transfer Learning with Pretrained Network)**

## 1. 라이브러리 불러오기

In [None]:
!pip3 install torchmetrics

In [None]:
import random
from os import makedirs, listdir
from os.path import isdir, join

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
from torch import nn
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights

from torchmetrics import Accuracy
from torchsummary import summary
from tqdm import tqdm

## 2. 시드 고정

In [None]:
def fix_seed(seed):
    """Seed every RNG used in this notebook and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch's CPU and CUDA
    generators. It then makes cuDNN choose deterministic kernels and turns off
    its benchmark autotuner.

    Args:
        seed: Integer seed applied to all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # Explicit CUDA seeding: current device first, then every visible device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

## 3. 손실 값 및 평가 지표의 평균 계산 (AverageMeter)

In [None]:
class AverageMeter:
    """Track the most recent value and the weighted running mean of a metric.

    ``update(val, n)`` treats ``val`` as the mean of ``n`` samples, so ``avg``
    is the sample-weighted average of everything recorded since the last
    ``reset()``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the average, the running sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` samples and refresh ``avg``."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

## 4. 학습/검증 데이터셋 다운로드

In [None]:
!wget https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip
!unzip -qq cats_and_dogs_filtered.zip

## 5. Custom Dataset 생성

In [None]:
########## Custom Dataset ##########
class PyTorchCustomDataset(Dataset):
    """Image-classification dataset laid out as ``root_dir/<class_name>/<image>``.

    Each sub-directory of ``root_dir`` is one class. Class indices follow the
    sorted order of the sub-directory names. Items are ``(image, label)``
    pairs: the image is converted to RGB and, if given, passed through
    ``transform``; the label is an int64 tensor.

    Args:
        root_dir: Directory whose sub-directories hold the images of each class.
        transform: Optional callable applied to each PIL image.
    """

    def __init__(self, root_dir="cats_and_dogs_filtered/train", transform=None):
        self.image_abs_path = root_dir
        self.transform = transform
        # Only sub-directories are classes. A stray file (e.g. .DS_Store) would
        # otherwise become a bogus label and make the listdir() below fail.
        self.label_list = sorted(
            name for name in listdir(self.image_abs_path)
            if isdir(join(self.image_abs_path, name))
        )
        self.x_list = []
        self.y_list = []
        for label_index, label_str in enumerate(self.label_list):
            img_path = join(self.image_abs_path, label_str)
            # listdir() order depends on the filesystem. Sort so the sample order,
            # and therefore the seeded shuffling, is reproducible everywhere.
            for img in sorted(listdir(img_path)):
                self.x_list.append(join(img_path, img))
                self.y_list.append(label_index)

    def __len__(self):
        return len(self.x_list)

    def __getitem__(self, idx):
        # The context manager closes the file handle. convert() loads the pixels
        # and returns an independent copy (a plain copy when already RGB).
        with Image.open(self.x_list[idx]) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, torch.tensor(self.y_list[idx], dtype=torch.long)

    # NOTE: the two methods below are ordinary methods despite their dunder-style
    # names; the names are kept so existing callers keep working.
    def __save_label_map__(self, dst_text_path="label_map.txt"):
        """Write the class names to ``dst_text_path``, one per line, in label-index order."""
        with open(dst_text_path, "w", encoding="utf-8") as f:
            f.writelines(label + "\n" for label in self.label_list)

    def __num_classes__(self):
        """Return the number of classes (sub-directories) found under ``root_dir``."""
        return len(self.label_list)

## Pretrained Model Parameter 전체를 학습 (Unfrozen Backbone)

In [None]:
########## Training Code ##########
def transfer_learning_unfrozen(model, img_channels=3, img_size=224, num_classes=2,
                               lr=1e-4, total_epochs=20, seed=42, batch_size=16,
                               src="cats_and_dogs_filtered"):
    """Fine-tune every layer of a pretrained classifier on an image-folder dataset.

    Setup:
    - Replaces the model's ``fc`` layer with a new ``nn.Linear`` that has
      ``num_classes`` outputs.
    - Makes all parameters trainable.
    - Trains with plain SGD and cross-entropy.

    After each epoch the model is evaluated on the validation split. The best
    (by validation accuracy) and latest weights are saved under
    ``ckpt/backbone_unfrozen``. Loss/accuracy curves are saved under
    ``result/backbone_unfrozen``.

    Args:
        model: Network with a replaceable ``fc`` linear head (e.g. a torchvision
            ResNet). It is modified in place.
        img_channels: Number of input channels; only used for the model summary.
        img_size: Images are resized to ``img_size x img_size``.
        num_classes: Number of outputs of the new classifier head.
        lr: SGD learning rate.
        total_epochs: Number of training epochs.
        seed: Seed passed to ``fix_seed`` before loader creation and head creation.
        batch_size: Mini-batch size for both loaders.
        src: Dataset root containing ``train`` and ``validation`` sub-directories.
    """
    # Load Dataset (normalised with the ImageNet RGB statistics)
    imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomApply([transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                                       saturation=0.2, hue=0.1)], p=0.8),
        transforms.ToTensor(),
        imagenet_normalize,
    ])
    test_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        imagenet_normalize,
    ])

    # Create Custom Dataset Instance
    train_dataset = PyTorchCustomDataset(join(src, "train"), train_transform)
    test_dataset = PyTorchCustomDataset(join(src, "validation"), test_transform)

    # Create DataLoader
    fix_seed(seed)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=4, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=4)

    # Check Device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Current Device : {device}")

    # Re-seed so the new classifier head gets the same initialisation every run.
    fix_seed(seed)

    # Unfreeze CNN Backbone
    for param in model.parameters():
        param.requires_grad = True

    # Replace Linear Layer (Customize Classifier)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model = model.to(device)

    # torchsummary runs a dummy forward pass on random input. Run it in eval mode
    # so it does not update the pretrained BatchNorm running statistics. Pass the
    # device explicitly, because torchsummary defaults to "cuda".
    model.eval()
    summary(model, (img_channels, img_size, img_size), device=device)

    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    metric = Accuracy("multiclass", num_classes=num_classes).to(device)

    # Per-epoch meters and per-epoch history for the plots
    train_loss, train_acc = AverageMeter(), AverageMeter()
    test_loss, test_acc = AverageMeter(), AverageMeter()
    train_loss_list, train_acc_list = [], []
    test_loss_list, test_acc_list = [], []

    # Create Directory
    ckpt_dir, graph_dir = "ckpt/backbone_unfrozen", "result/backbone_unfrozen"
    makedirs(ckpt_dir, exist_ok=True)
    makedirs(graph_dir, exist_ok=True)

    best_acc = 0

    for epoch in range(total_epochs):
        # ---------------- Training Phase ----------------
        train_bar = tqdm(train_loader)
        train_loss.reset()
        train_acc.reset()
        model.train()

        for img, label in train_bar:
            img, label = img.to(device), label.to(device)

            # Update all weights (backbone + classifier)
            optimizer.zero_grad()
            pred = model(img)
            loss = criterion(pred, label)
            loss.backward()
            optimizer.step()

            # Batch accuracy. Both meters are weighted by batch size so the epoch
            # value is a per-sample mean.
            acc = metric(pred, label)
            n_samples = img.size(0)
            train_loss.update(loss.item(), n_samples)
            train_acc.update(acc.item(), n_samples)

            train_bar.set_description(desc=f"[Train] [{epoch+1}/{total_epochs}] < Loss:{train_loss.avg:.4f} | Acc.:{train_acc.avg:.4f} >")

        train_loss_list.append(train_loss.avg)
        train_acc_list.append(train_acc.avg)

        # ---------------- Evaluation Phase ----------------
        test_bar = tqdm(test_loader)
        test_loss.reset()
        test_acc.reset()
        model.eval()

        with torch.no_grad():
            for img, label in test_bar:
                img, label = img.to(device), label.to(device)

                # Forward pass only (no weight update)
                pred = model(img)
                loss = criterion(pred, label)
                acc = metric(pred, label)

                # The test loader keeps its last (possibly smaller) batch, so
                # weight by batch size to get the true per-sample average.
                n_samples = img.size(0)
                test_loss.update(loss.item(), n_samples)
                test_acc.update(acc.item(), n_samples)

                test_bar.set_description(desc=f"[Test] [{epoch+1}/{total_epochs}] < Loss:{test_loss.avg:.4f} | Acc.:{test_acc.avg:.4f} >")

        test_loss_list.append(test_loss.avg)
        test_acc_list.append(test_acc.avg)

        # Save Network: best by test accuracy, plus the latest epoch
        if test_acc.avg > best_acc:
            best_acc = test_acc.avg
            torch.save(model.state_dict(), f"{ckpt_dir}/best.pth")
        torch.save(model.state_dict(), f"{ckpt_dir}/latest.pth")

        epoch_axis = np.arange(epoch + 1)

        # Plot Training vs. Test Loss Graph (overwritten every epoch)
        plt.clf()
        plt.plot(epoch_axis, train_loss_list, label="Training Loss")
        plt.plot(epoch_axis, test_loss_list, label="Test Loss")
        plt.title("Loss (Training vs. Test)")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.legend(loc="best")
        plt.savefig(f"{graph_dir}/loss.png")

        # Plot Training vs. Test Accuracy Graph (overwritten every epoch)
        plt.clf()
        plt.plot(epoch_axis, train_acc_list, label="Training Accuracy")
        plt.plot(epoch_axis, test_acc_list, label="Test Accuracy")
        plt.title("Accuracy (Training vs. Test)")
        plt.xlabel("Epoch")
        plt.ylabel("Accuracy")
        plt.legend(loc="best")
        plt.savefig(f"{graph_dir}/accuracy.png")

### 훈련 진행

In [None]:
# Fine-tune all layers of an ImageNet-pretrained ResNet-18.
model = resnet18(weights=ResNet18_Weights.DEFAULT)
transfer_learning_unfrozen(model=model)

## Pretrained Model Parameter 일부만 학습 (Frozen Backbone)

In [None]:
########## Training Code ##########
def transfer_learning_frozen(model, img_channels=3, img_size=224, num_classes=2,
                             lr=1e-2, total_epochs=10, seed=42, batch_size=16,
                             src="cats_and_dogs_filtered"):
    """Train only a new classifier head on top of a frozen pretrained backbone.

    Setup:
    - Sets ``requires_grad=False`` on every existing parameter.
    - Replaces ``fc`` with a new ``nn.Linear`` that has ``num_classes``
      outputs; being created after the freeze, it is the only trainable part.
    - Trains with plain SGD and cross-entropy.

    After each epoch the model is evaluated on the validation split. The best
    (by validation accuracy) and latest weights are saved under
    ``ckpt/backbone_frozen``. Loss/accuracy curves are saved under
    ``result/backbone_frozen``.

    Args:
        model: Network with a replaceable ``fc`` linear head (e.g. a torchvision
            ResNet). It is modified in place.
        img_channels: Number of input channels; only used for the model summary.
        img_size: Images are resized to ``img_size x img_size``.
        num_classes: Number of outputs of the new classifier head.
        lr: SGD learning rate.
        total_epochs: Number of training epochs.
        seed: Seed passed to ``fix_seed`` before loader creation and head creation.
        batch_size: Mini-batch size for both loaders.
        src: Dataset root containing ``train`` and ``validation`` sub-directories.
    """
    # Load Dataset (normalised with the ImageNet RGB statistics)
    imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomApply([transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                                       saturation=0.2, hue=0.1)], p=0.8),
        transforms.ToTensor(),
        imagenet_normalize,
    ])
    test_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        imagenet_normalize,
    ])

    # Create Custom Dataset Instance
    train_dataset = PyTorchCustomDataset(join(src, "train"), train_transform)
    test_dataset = PyTorchCustomDataset(join(src, "validation"), test_transform)

    # Create DataLoader
    fix_seed(seed)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=4, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=4)

    # Check Device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Current Device : {device}")

    # Re-seed so the new classifier head gets the same initialisation every run.
    fix_seed(seed)

    # Freeze CNN Backbone
    for param in model.parameters():
        param.requires_grad = False

    # Replace Linear Layer. The new head is created after the freeze, so its
    # parameters are the only ones with requires_grad=True.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model = model.to(device)

    # torchsummary runs a dummy forward pass on random input. Run it in eval mode
    # so it does not update the pretrained BatchNorm running statistics. Pass the
    # device explicitly, because torchsummary defaults to "cuda".
    model.eval()
    summary(model, (img_channels, img_size, img_size), device=device)

    # Optimise only the trainable (classifier) parameters.
    optimizer = optim.SGD([p for p in model.parameters() if p.requires_grad], lr=lr)
    criterion = nn.CrossEntropyLoss()
    metric = Accuracy("multiclass", num_classes=num_classes).to(device)

    # Per-epoch meters and per-epoch history for the plots
    train_loss, train_acc = AverageMeter(), AverageMeter()
    test_loss, test_acc = AverageMeter(), AverageMeter()
    train_loss_list, train_acc_list = [], []
    test_loss_list, test_acc_list = [], []

    # Create Directory
    ckpt_dir, graph_dir = "ckpt/backbone_frozen", "result/backbone_frozen"
    makedirs(ckpt_dir, exist_ok=True)
    makedirs(graph_dir, exist_ok=True)

    best_acc = 0

    for epoch in range(total_epochs):
        # ---------------- Training Phase ----------------
        train_bar = tqdm(train_loader)
        train_loss.reset()
        train_acc.reset()
        # NOTE(review): train mode still lets the backbone's BatchNorm layers
        # update their running statistics (buffers, not parameters). For a fully
        # frozen feature extractor, switch those layers to eval() after this call.
        model.train()

        for img, label in train_bar:
            img, label = img.to(device), label.to(device)

            # Update Classifier Weights (backbone parameters are frozen)
            optimizer.zero_grad()
            pred = model(img)
            loss = criterion(pred, label)
            loss.backward()
            optimizer.step()

            # Batch accuracy. Both meters are weighted by batch size so the epoch
            # value is a per-sample mean.
            acc = metric(pred, label)
            n_samples = img.size(0)
            train_loss.update(loss.item(), n_samples)
            train_acc.update(acc.item(), n_samples)

            train_bar.set_description(desc=f"[Train] [{epoch+1}/{total_epochs}] < Loss:{train_loss.avg:.4f} | Acc.:{train_acc.avg:.4f} >")

        train_loss_list.append(train_loss.avg)
        train_acc_list.append(train_acc.avg)

        # ---------------- Evaluation Phase ----------------
        test_bar = tqdm(test_loader)
        test_loss.reset()
        test_acc.reset()
        model.eval()

        with torch.no_grad():
            for img, label in test_bar:
                img, label = img.to(device), label.to(device)

                # Forward pass only (no weight update)
                pred = model(img)
                loss = criterion(pred, label)
                acc = metric(pred, label)

                # The test loader keeps its last (possibly smaller) batch, so
                # weight by batch size to get the true per-sample average.
                n_samples = img.size(0)
                test_loss.update(loss.item(), n_samples)
                test_acc.update(acc.item(), n_samples)

                test_bar.set_description(desc=f"[Test] [{epoch+1}/{total_epochs}] < Loss:{test_loss.avg:.4f} | Acc.:{test_acc.avg:.4f} >")

        test_loss_list.append(test_loss.avg)
        test_acc_list.append(test_acc.avg)

        # Save Network: best by test accuracy, plus the latest epoch
        if test_acc.avg > best_acc:
            best_acc = test_acc.avg
            torch.save(model.state_dict(), f"{ckpt_dir}/best.pth")
        torch.save(model.state_dict(), f"{ckpt_dir}/latest.pth")

        epoch_axis = np.arange(epoch + 1)

        # Plot Training vs. Test Loss Graph (overwritten every epoch)
        plt.clf()
        plt.plot(epoch_axis, train_loss_list, label="Training Loss")
        plt.plot(epoch_axis, test_loss_list, label="Test Loss")
        plt.title("Loss (Training vs. Test)")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.legend(loc="best")
        plt.savefig(f"{graph_dir}/loss.png")

        # Plot Training vs. Test Accuracy Graph (overwritten every epoch)
        plt.clf()
        plt.plot(epoch_axis, train_acc_list, label="Training Accuracy")
        plt.plot(epoch_axis, test_acc_list, label="Test Accuracy")
        plt.title("Accuracy (Training vs. Test)")
        plt.xlabel("Epoch")
        plt.ylabel("Accuracy")
        plt.legend(loc="best")
        plt.savefig(f"{graph_dir}/accuracy.png")

### 훈련 진행

In [None]:
# Train only a new classifier head on a frozen ImageNet-pretrained ResNet-18.
model = resnet18(weights=ResNet18_Weights.DEFAULT)
transfer_learning_frozen(model=model)