### 전이학습

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# 이미지 데이터셋, 전처리, 전이학습 모듈
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torchvision.models import resnet18, ResNet18_Weights

from torchinfo import summary
from torchmetrics.functional.classification import multiclass_accuracy

  warn(


In [2]:
### ===> 데이터 준비
img_dir = '../data/img'

### ===> Resnet 전처리
# resize_size=[256] interpolation=InterpolationMode.BILINEAR
# central crop of crop_size = [224]
# first rescaled to [0.0, 1.0] and then
# normalized using mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
preprocessing = transforms.Compose(transforms=[
    transforms.Resize(size=256, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])

### ===> 이미지 데이터셋 생성
imgDS = ImageFolder(root=img_dir, transform=preprocessing)

print(imgDS.classes)

['kakarote', 'vegeta']


In [3]:
### ===> 데이터 로더 생성
imgDL = DataLoader(imgDS, batch_size=3, shuffle=True, drop_last=True)
# for (img, label) in imgDL:
#     print(img.shape, label)

In [4]:
### ==> 모델 설계 및 설정

### 사전학습된 모델 인스턴스 생성
res_model = resnet18(weights=ResNet18_Weights.DEFAULT)

### 전결합층 변경
### in_features : FeatureMap에서 받은 피쳐 수, out_features : 출력/분류 클래스 수
res_model.fc = nn.Linear(in_features=512, out_features=3)

In [5]:
summary(model=res_model
        , input_size=(3,3,224,224)
        )

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [3, 3]                    --
├─Conv2d: 1-1                            [3, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [3, 64, 112, 112]         128
├─ReLU: 1-3                              [3, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [3, 64, 56, 56]           --
├─Sequential: 1-5                        [3, 64, 56, 56]           --
│    └─BasicBlock: 2-1                   [3, 64, 56, 56]           --
│    │    └─Conv2d: 3-1                  [3, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-2             [3, 64, 56, 56]           128
│    │    └─ReLU: 3-3                    [3, 64, 56, 56]           --
│    │    └─Conv2d: 3-4                  [3, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-5             [3, 64, 56, 56]           128
│    │    └─ReLU: 3-6                    [3, 64, 56, 56]           --
│

In [6]:
### Resnet18 Feature Module 파라미터 requires_grad = True ==> False 비활성화
for name, param in res_model.named_parameters():
    print(name, param.requires_grad, end='  ===>    ')
    param.requires_grad = False
    print(param.requires_grad)
    
### Resnet18 Full Connected Module 파리머터 requires_grad = False ==> True 활성화
for name, param in res_model.fc.named_parameters():
    print(name, param.requires_grad, end='  ===>    ')
    param.requires_grad = True
    print(param.requires_grad)

conv1.weight True  ===>    False
bn1.weight True  ===>    False
bn1.bias True  ===>    False
layer1.0.conv1.weight True  ===>    False
layer1.0.bn1.weight True  ===>    False
layer1.0.bn1.bias True  ===>    False
layer1.0.conv2.weight True  ===>    False
layer1.0.bn2.weight True  ===>    False
layer1.0.bn2.bias True  ===>    False
layer1.1.conv1.weight True  ===>    False
layer1.1.bn1.weight True  ===>    False
layer1.1.bn1.bias True  ===>    False
layer1.1.conv2.weight True  ===>    False
layer1.1.bn2.weight True  ===>    False
layer1.1.bn2.bias True  ===>    False
layer2.0.conv1.weight True  ===>    False
layer2.0.bn1.weight True  ===>    False
layer2.0.bn1.bias True  ===>    False
layer2.0.conv2.weight True  ===>    False
layer2.0.bn2.weight True  ===>    False
layer2.0.bn2.bias True  ===>    False
layer2.0.downsample.0.weight True  ===>    False
layer2.0.downsample.1.weight True  ===>    False
layer2.0.downsample.1.bias True  ===>    False
layer2.1.conv1.weight True  ===>    False


In [7]:
### ===> 학습 준비
optimizer = optim.Adam(params=res_model.parameters())
loss_fn = nn.CrossEntropyLoss() # 손실 함수 정의
EPOCHS = 3

In [8]:
torch.nn.Flatten()

Flatten(start_dim=1, end_dim=-1)