In [1]:
import sys
sys.path.append('..')
from utils.dataloader_image_classification import ImageTransform, make_datapath_list, HymenopteraDataset

In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from tqdm import tqdm

In [2]:
# Build the file lists for both splits in one pass; 'train' is resolved first,
# then 'val', so any output printed by the helper keeps its order.
train_list, val_list = (
    make_datapath_list(phase=split) for split in ('train', 'val')
)

./data/hymenoptera_data/train/**/*.jpg
./data/hymenoptera_data/val/**/*.jpg


In [7]:
# Preprocessing settings handed to ImageTransform(size, mean, std).
# NOTE(review): these mean/std values look like the usual ImageNet
# normalization constants - confirm against ImageTransform.
size, mean, std = 224, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
batch_size = 32

# One dataset per split, each with its own transform instance.
train_dataset = HymenopteraDataset(
    file_list=train_list,
    transform=ImageTransform(size, mean, std),
    phase='train',
)
val_dataset = HymenopteraDataset(
    file_list=val_list,
    transform=ImageTransform(size, mean, std),
    phase='val',
)

# Only the training loader shuffles; validation keeps a fixed order.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
)

# Phase name -> loader lookup consumed by the training loop.
dataloaders_dict = dict(train=train_dataloader, val=val_dataloader)

In [12]:
# Build VGG16; `use_pretrained` is forwarded as the `pretrained` flag.
use_pretrained = True
net = models.vgg16(pretrained=use_pretrained)

# Swap the last classifier layer for a fresh linear layer with 2 outputs
# (4096 input features, matching the layer it replaces in the printed model).
net.classifier[6] = nn.Linear(4096, 2)

# Put the network in training mode. As the final expression of the cell,
# this call also displays the module structure.
net.train()

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [13]:
# Loss function passed to train_model; it is called as criterion(outputs, labels).
criterion = nn.CrossEntropyLoss()

In [17]:
# Three parameter groups, filled below and later given separate learning rates.
params_to_update_1, params_to_update_2, params_to_update_3 = [], [], []

# Group 1 is matched by substring ('features' anywhere in the parameter name);
# groups 2 and 3 are matched by exact parameter name.
update_param_names_1 = ['features']
update_param_names_2 = [
    'classifier.0.weight', 'classifier.0.bias',
    'classifier.3.weight', 'classifier.3.bias',
]
update_param_names_3 = ['classifier.6.weight', 'classifier.6.bias']

for name, param in net.named_parameters():
    # Pick the destination list (None means "freeze") and the log line.
    if update_param_names_1[0] in name:
        target, message = params_to_update_1, f'Added into params_to_update_1: {name}'
    elif name in update_param_names_2:
        target, message = params_to_update_2, f'Added into params_to_update_2: {name}'
    elif name in update_param_names_3:
        target, message = params_to_update_3, f'Added into params_to_update_3: {name}'
    else:
        target, message = None, f'Will not be trained: {name}'

    # Parameters that matched a group stay trainable; everything else is frozen.
    param.requires_grad = target is not None
    if target is not None:
        target.append(param)
    print(message)

Added into params_to_update_1: features.0.weight
Added into params_to_update_1: features.0.bias
Added into params_to_update_1: features.2.weight
Added into params_to_update_1: features.2.bias
Added into params_to_update_1: features.5.weight
Added into params_to_update_1: features.5.bias
Added into params_to_update_1: features.7.weight
Added into params_to_update_1: features.7.bias
Added into params_to_update_1: features.10.weight
Added into params_to_update_1: features.10.bias
Added into params_to_update_1: features.12.weight
Added into params_to_update_1: features.12.bias
Added into params_to_update_1: features.14.weight
Added into params_to_update_1: features.14.bias
Added into params_to_update_1: features.17.weight
Added into params_to_update_1: features.17.bias
Added into params_to_update_1: features.19.weight
Added into params_to_update_1: features.19.bias
Added into params_to_update_1: features.21.weight
Added into params_to_update_1: features.21.bias
Added into params_to_update_

In [19]:
# SGD with one learning rate per parameter group: the smallest step for the
# 'features' group and the largest for the replaced classifier.6 layer.
# Momentum is shared by all groups.
optimizer = optim.SGD(
    [
        dict(params=params_to_update_1, lr=1e-4),
        dict(params=params_to_update_2, lr=5e-4),
        dict(params=params_to_update_3, lr=1e-3),
    ],
    momentum=0.9,
)

In [24]:
def train_model(net, dataloaders_dict, criterion, optimizer, num_epochs):
    """Alternate training and validation passes, printing loss/accuracy per phase.

    Args:
        net: Model to optimize. It is switched to ``train()`` or ``eval()``
            mode at the start of every phase.
        dataloaders_dict: Mapping with ``'train'`` and ``'val'`` keys. Each
            value must yield ``(inputs, labels)`` batches and expose a sized
            ``.dataset`` attribute.
        criterion: Callable invoked as ``criterion(outputs, labels)`` that
            returns a scalar loss tensor. The running loss multiplies it by
            the batch size, which assumes a per-batch mean - TODO confirm
            for custom criteria.
        optimizer: Optimizer stepped once per training batch.
        num_epochs: Number of epochs to run.

    Returns:
        None. Metrics are only printed.

    Notes:
        Training is skipped in the first epoch, so the first validation line
        reports the model as it was passed in. A phase whose dataset is empty
        is reported and skipped instead of failing on a division by zero.
    """
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-----------')

        for phase in ['train', 'val']:
            is_training = phase == 'train'
            if is_training:
                net.train()
            else:
                net.eval()

            # Deliberate: no training in the first epoch, validation only.
            if (epoch == 0) and is_training:
                continue

            dataloader = dataloaders_dict[phase]
            num_samples = len(dataloader.dataset)
            if num_samples == 0:
                # Nothing to average over; say so instead of dividing by zero.
                print(f'{phase}: no samples, skipping')
                continue

            epoch_loss = 0.0
            epoch_corrects = 0

            for inputs, labels in tqdm(dataloader):
                optimizer.zero_grad()

                # Gradients are tracked only while training.
                with torch.set_grad_enabled(is_training):
                    outputs = net(inputs)
                    loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)

                    if is_training:
                        loss.backward()
                        optimizer.step()

                # Accumulate plain Python numbers so the totals are valid even
                # when a loader yields no batches.
                epoch_loss += loss.item() * inputs.size(0)
                epoch_corrects += int((preds == labels).sum())

            epoch_loss = epoch_loss / num_samples
            epoch_acc = epoch_corrects / num_samples

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

In [25]:
# Run two epochs. train_model skips training in the first epoch, so the first
# validation line shows the model before any update from this call.
num_epochs = 2
train_model(
    net=net,
    dataloaders_dict=dataloaders_dict,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
)

  0%|                                                                                                                                                                           | 0/5 [00:00<?, ?it/s]

Epoch 1/2
-----------


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:51<00:00, 10.30s/it]
  0%|                                                                                                                                                                           | 0/8 [00:00<?, ?it/s]

val Loss: 0.8237 Acc: 0.4444
Epoch 2/2
-----------


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [07:07<00:00, 53.45s/it]
  0%|                                                                                                                                                                           | 0/5 [00:00<?, ?it/s]

train Loss: 0.4884 Acc: 0.7449


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:55<00:00, 11.19s/it]

val Loss: 0.1896 Acc: 0.9412





In [26]:
# Persist only the network's state_dict to disk.
save_path = './weights_fine_tuning.pth'
torch.save(obj=net.state_dict(), f=save_path)