In [None]:
import os

# Move into the cropped dataset directory; every later relative path
# (CustomDataset('./'), checkpoint paths) is resolved from here. Guarded so that
# re-running this cell does not apply the relative path a second time from
# inside the dataset directory.
if os.path.basename(os.getcwd()) != 'cropped':
    os.chdir('../../disk1/colonoscopy_datasetv2/cropped')
print(os.getcwd())

GPU 지정 & 라이브러리

In [None]:
import os

# Enumerate GPUs in PCI bus order and expose only GPU 0 to this process.
# These variables must be set before torch initialises CUDA.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": "0"})

import torch

# Use the (single visible) GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
from PIL import Image
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
from sklearn.metrics import classification_report, confusion_matrix

Custom Dataset 구현

In [None]:
class CustomDataset(Dataset):
    """Image-classification dataset read from ``<dataset_path>/<class_dir>/<file>``.

    The label of an image comes from the first three characters of its class
    directory (relative to ``dataset_path``): NOR -> 0, ADC -> 1, HGD -> 2,
    LGD -> 3. Files with ``'MASK'`` at characters 4-7 of their name are skipped
    (presumably segmentation masks -- confirm against the dataset layout).

    Each item is a dict ``{'image': <transformed RGB image>, 'label': <int>}``.
    """

    # Class-directory prefix -> integer label.
    LABEL_MAP = {'NOR': 0, 'ADC': 1, 'HGD': 2, 'LGD': 3}

    def readImg(self):
        """Collect image paths from every class directory under ``dataset_path``.

        Returns:
            tuple: (sorted list of image paths, number of class directories,
            number of image paths).

        Raises:
            Whatever ``PIL.Image.open`` raises (e.g. ``UnidentifiedImageError``)
            for a non-mask file that is not a readable image.
        """
        all_img_files = []

        # Hidden directories (e.g. Jupyter's .ipynb_checkpoints) are not classes.
        class_names = [d for d in next(os.walk(self.dataset_path))[1]
                       if not d.startswith('.')]

        for class_name in class_names:
            img_dir = os.path.join(self.dataset_path, class_name)
            for img in next(os.walk(img_dir))[2]:
                if img[4:8] == 'MASK':
                    continue
                img_path = os.path.join(img_dir, img)
                # Opening validates that the file is an image; the context
                # manager closes the handle right away instead of leaking one
                # open file per image.
                with Image.open(img_path):
                    pass
                all_img_files.append(img_path)

        all_img_files.sort()

        return all_img_files, len(class_names), len(all_img_files)

    def __init__(self, dataset_path, img_transforms=None):
        self.dataset_path = dataset_path
        self.img_transforms = img_transforms
        self.img_files, self.num_classes, self.num_images = self.readImg()

    def __getitem__(self, index):
        """Load image ``index`` as RGB, apply the transforms and attach its label.

        Raises:
            ValueError: if the image's class-directory prefix is not in ``LABEL_MAP``.
        """
        img_path = self.img_files[index]
        # The path relative to dataset_path starts with the class directory.
        # For dataset_path './' this equals img_path[2:5], but it also works
        # for dataset paths of any length.
        class_prefix = os.path.relpath(img_path, self.dataset_path)[:3]
        label = self.LABEL_MAP.get(class_prefix)
        if label is None:
            raise ValueError('Unknown class prefix {!r} for image {}'.format(class_prefix, img_path))

        # convert() loads the pixel data, so the file can be closed immediately.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.img_transforms is not None:
            image = self.img_transforms(image)

        return {'image': image, 'label': label}

    def __len__(self):
        return self.num_images


Focal Loss

In [None]:
class FocalLoss(nn.modules.loss._WeightedLoss):
    """Multi-class focal loss.

    Per sample: ``FL = -w[y] * (1 - p_y) ** gamma * log(p_y)``, where ``p_y`` is
    the softmax probability of the target class and ``w`` the optional class
    weights. With ``gamma=0`` this equals ``F.cross_entropy`` (weighted or not).

    Args:
        weight: optional 1-D tensor of per-class weights, as in ``F.cross_entropy``.
        gamma: focusing parameter; larger values down-weight well-classified samples.
        reduction: ``'mean'`` | ``'sum'`` | ``'none'``.
    """

    def __init__(self, weight=None, gamma=2, reduction='mean'):
        # _WeightedLoss registers `weight` as a buffer and stores `reduction`.
        super(FocalLoss, self).__init__(weight, reduction=reduction)
        self.gamma = gamma

    def forward(self, input, target):
        # The modulating factor must use each sample's own p_y, so compute the
        # unweighted per-sample CE (= -log p_y) and reduce only at the end.
        log_pt = -F.cross_entropy(input, target, reduction='none')
        pt = log_pt.exp()
        loss = -((1 - pt) ** self.gamma) * log_pt

        if self.weight is not None:
            sample_weight = self.weight.to(loss.dtype)[target]
            loss = loss * sample_weight

        if self.reduction == 'none':
            return loss
        if self.reduction == 'sum':
            return loss.sum()
        # 'mean': normalise by the total sample weight, as F.cross_entropy does.
        if self.weight is not None:
            return loss.sum() / sample_weight.sum()
        return loss.mean()

학습/검증 함수 구현 (train_epoch, val_epoch)

In [None]:
def train_epoch(model, dataloader, optimizer):
    """Run one optimisation pass over ``dataloader`` using focal loss.

    Batches are dicts with ``'image'`` and ``'label'`` tensors; both are moved
    to the global ``device`` before the forward pass.

    Returns:
        Mean training loss per sample over ``dataloader.dataset``.
    """
    model.train()
    criterion = FocalLoss()
    running_total = 0

    for batch in dataloader:
        inputs = batch['image'].to(device)
        targets = batch['label'].to(device)

        batch_loss = criterion(model(inputs), targets)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Weight by batch size so a short final batch is not over-counted.
        running_total += batch_loss.item() * len(inputs)

    return running_total / len(dataloader.dataset)

def val_epoch(model, dataloader):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Returns:
        tuple: (mean focal loss per sample, accuracy in percent) over
        ``dataloader.dataset``.
    """
    model.eval()
    criterion = FocalLoss()
    n_samples = len(dataloader.dataset)
    loss_total = 0
    n_correct = 0

    with torch.no_grad():
        for batch in dataloader:
            inputs = batch['image'].to(device)
            targets = batch['label'].to(device)

            logits = model(inputs)
            loss_total += criterion(logits, targets).item() * len(inputs)

            predicted = torch.max(logits.data, 1)[1]
            n_correct += (predicted == targets).sum().item()

        accuracy = n_correct / n_samples * 100

        return loss_total / n_samples, accuracy
    

데이터셋 분할

In [None]:
# Resize to VGG16's 224x224 input and normalise with the ImageNet channel
# statistics that torchvision's pretrained weights expect.
# NOTE: checkpoints trained without this Normalize step expect raw [0, 1]
# inputs; re-train before evaluating them with these transforms.
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
total_dataset = CustomDataset('./', img_transforms=data_transforms)

# total images: 653, total classes: 4
print('total images: {}'.format(total_dataset.num_images))
print('total classes: {}'.format(total_dataset.num_classes))

# Count images per class
# nor0, adc1, hgd2, lgd3 = 0, 0, 0, 0
# for idx, item in enumerate(total_dataset):
#     label = item['label']
#     if label==0: nor0 += 1
#     elif label==1: adc1 += 1
#     elif label==2: hgd2 += 1
#     else: lgd3 += 1
# print("nor: {}\nadc: {}\nhgd: {}\nlgd: {}".format(nor0,adc1,hgd2,lgd3))

# 80 / 10 / 10 split; the test split absorbs the rounding remainder.
train_size = int(total_dataset.num_images * 0.8)
val_size = int(total_dataset.num_images * 0.1)
test_size = total_dataset.num_images - train_size - val_size
print('train size: {}\nvalidation_size: {}\ntest_size: {}'.format(train_size, val_size, test_size))

# Fixed seed so every run puts the same images in each split.
train_dataset, val_dataset, test_dataset = random_split(
    total_dataset, [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42))

VGG16 모델 불러오기 & 학습 설정 (하이퍼파라미터, 옵티마이저, DataLoader)

In [None]:
from torchvision import models
import ssl

# Load ImageNet-pretrained VGG16. Certificate verification is disabled only for
# this download and restored afterwards, so later HTTPS requests in the kernel
# are verified again.
# SECURITY: the unverified context accepts any certificate; prefer fixing the
# host's CA bundle and removing this workaround.
_saved_https_context = ssl._create_default_https_context
ssl._create_default_https_context = ssl._create_unverified_context
try:
    pre_model = models.vgg16(pretrained=True)
finally:
    ssl._create_default_https_context = _saved_https_context
# scr_model = models.vgg16(pretrained=False)

# Hyperparameters
hy_batch = 64
hy_epoch = 100
hy_lr = 0.00001

# Fine-tuning: replace the final classifier layer so it outputs num_classes logits.
num_classes = total_dataset.num_classes
num_ftrs = pre_model.classifier[6].in_features

pre_model.classifier[6] = nn.Linear(num_ftrs, num_classes)
# Move with .to(device) only: a bare .cuda() would fail on CPU-only machines
# even though `device` falls back to the CPU. DataParallel is kept so saved
# state_dict keys match the checkpoints written and loaded later.
pre_model = nn.DataParallel(pre_model).to(device)
# scr_model.classifier[6] = nn.Linear(num_ftrs,num_classes)
# scr_model = nn.DataParallel(scr_model).to(device)

optimizer_pre = torch.optim.Adam(pre_model.parameters(), lr=hy_lr)
# optimizer_scr = torch.optim.Adam(scr_model.parameters(), lr=hy_lr)

# Reshuffle training batches every epoch (seeded for reproducibility);
# validation/test loaders keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=hy_batch, shuffle=True,
                          generator=torch.Generator().manual_seed(42))
val_loader = DataLoader(val_dataset, batch_size=hy_batch, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=hy_batch, shuffle=False)

In [None]:
# Display the model structure: the DataParallel-wrapped VGG16 whose last classifier layer was replaced above.
pre_model

모델 훈련

In [None]:

PATH_pre = '../../../home/bokyoungk/classification_models/vgg16_pretrained/'
PATH_scr = '../../../home/bokyoungk/classification_models/vgg16_scratch/'
# pre_model = torch.load(PATH_pre+'checkpoint/model_11.pt')
# pre_model.load_state_dict(torch.load(PATH_pre+'checkpoint/model_state_11.pt'))
# checkpoint = torch.load(PATH_pre+'checkpoint/all_11.tar')
# pre_model.load_state_dict(checkpoint['model'])

# scr_model = torch.load(PATH_scr+'checkpoint/model_11.pt')
# scr_model.load_state_dict(torch.load(PATH_scr+'checkpoint/model_state_11.pt'))
# checkpoint = torch.load(PATH_scr+'checkpoint/all_11.tar')
# scr_model.load_state_dict(checkpoint['model'])

# torch.save does not create directories; make sure the best-model folder exists.
os.makedirs(PATH_pre + 'min_loss', exist_ok=True)
# os.makedirs(PATH_scr + 'min_loss', exist_ok=True)

# Start at +inf so the first epoch is always saved; a fixed threshold would
# silently skip every save if validation loss never dropped below it.
min_loss_pre = float('inf')
best_epoch_pre = None
# min_loss_scr = float('inf')

all_train_loss_pre = []
all_val_loss_pre = []
all_accuracy_pre = []
# all_train_loss_scr = []
# all_val_loss_scr = []
# all_accuracy_scr = []


for e in range(0, hy_epoch):
    print('------------------------------------------------epoch {}/{}---------------------------------------------------'.format(e+1,hy_epoch))
    train_loss_pre = train_epoch(pre_model, train_loader, optimizer_pre)
    val_loss_pre, val_acc_pre = val_epoch(pre_model, val_loader)
    # train_loss_scr = train_epoch(scr_model,train_loader,optimizer_scr)
    # val_loss_scr, val_acc_scr = val_epoch(scr_model,val_loader)
    print('train loss pretrained: {}, val loss pretrained: {}, val acc pretrained: {}'.format(train_loss_pre,val_loss_pre,val_acc_pre))
    # print('train loss scratch: {}, val loss scratch: {}, val acc scratch: {}'.format(train_loss_scr,val_loss_scr,val_acc_scr))

    all_train_loss_pre.append(train_loss_pre)
    all_val_loss_pre.append(val_loss_pre)
    all_accuracy_pre.append(val_acc_pre)

    # all_train_loss_scr.append(train_loss_scr)
    # all_val_loss_scr.append(val_loss_scr)
    # all_accuracy_scr.append(val_acc_scr)

    # Save whenever validation loss reaches a new minimum.
    if min_loss_pre > val_loss_pre:
        min_loss_pre = val_loss_pre
        best_epoch_pre = e + 1
        torch.save(pre_model,PATH_pre+'min_loss/focal_min_model_{}.pt'.format(e+1))
        torch.save(pre_model.state_dict(),PATH_pre+'min_loss/focal_min_model_state_{}.pt'.format(e+1))
        torch.save({
            'model':pre_model.state_dict(),
            'optimizer':optimizer_pre.state_dict()
        },PATH_pre+'min_loss/focal_min_all_{}.tar'.format(e+1))

    # if min_loss_scr > val_loss_scr:
    #     min_loss_scr = val_loss_scr
    #     torch.save(scr_model,PATH_scr+'min_loss/cr_min_model_{}.pt'.format(e+1))
    #     torch.save(scr_model.state_dict(),PATH_scr+'min_loss/cr_min_model_state_{}.pt'.format(e+1))
    #     torch.save({
    #         'model':scr_model.state_dict(),
    #         'optimizer':optimizer_scr.state_dict()
    #     },PATH_scr+'min_loss/cr_min_all_{}.tar'.format(e+1))

    # checkpoint
    # if (e+1)%10==0:
    #     torch.save(pre_model,PATH_pre+'checkpoint/model_{}.pt'.format(e+1))
    #     torch.save(pre_model.state_dict(),PATH_pre+'checkpoint/model_state_{}.pt'.format(e+1))
    #     torch.save({
    #         'model':pre_model.state_dict(),
    #         'optimizer':optimizer_pre.state_dict()
    #     },PATH_pre+'checkpoint/all_{}.tar'.format(e+1))

    # if (e+1)%10==0:
    #     torch.save(scr_model,PATH_scr+'checkpoint/focal_model_{}.pt'.format(e+1))
    #     torch.save(scr_model.state_dict(),PATH_scr+'checkpoint/focal_model_state_{}.pt'.format(e+1))
    #     torch.save({
    #         'model':scr_model.state_dict(),
    #         'optimizer':optimizer_scr.state_dict()
    #     },PATH_scr+'checkpoint/focal_all_{}.tar'.format(e+1))

# Report which epoch's min_loss checkpoint to evaluate.
print('best epoch (pretrained): {}, min val loss: {}'.format(best_epoch_pre, min_loss_pre))

결과 시각화

In [None]:
# 시각화
# Visualisation
def plot_history(title, train_loss, val_loss, accuracy):
    """Plot train/val loss and validation accuracy side by side in one figure.

    The x-axis is derived from the recorded history, so a run stopped before
    ``hy_epoch`` epochs still plots correctly.
    """
    epochs = np.arange(1, len(train_loss) + 1)
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 8))

    ax_loss.set_title(title)
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.plot(epochs, train_loss, label='train loss')
    ax_loss.plot(epochs, val_loss, label='val loss')
    ax_loss.legend()

    ax_acc.set_title(title)
    ax_acc.set_xlabel('Epoch')
    ax_acc.set_ylabel('Accuracy(%)')
    ax_acc.plot(epochs, accuracy)

    # Show once, after both panels exist; calling plt.show() between them
    # would render the accuracy panel in a separate, default-sized figure.
    plt.show()


plot_history('Pretrained-VGG16', all_train_loss_pre, all_val_loss_pre, all_accuracy_pre)
# plot_history('Scratch-VGG16', all_train_loss_scr, all_val_loss_scr, all_accuracy_scr)


모델 성능 평가

In [None]:
PATH_pre = '../../../home/bokyoungk/classification_models/vgg16_pretrained/'
PATH_scr = '../../../home/bokyoungk/classification_models/vgg16_scratch/'

# Epoch number of the saved best-val-loss checkpoint to evaluate.
BEST_EPOCH_PRE = 10
# Names for labels 0..3, in the order CustomDataset assigns them.
CLASS_NAMES = ['NOR', 'ADC', 'HGD', 'LGD']
LABEL_IDS = list(range(len(CLASS_NAMES)))

# inference
true_labels = []
pre_labels = []
# scr_labels = []

# The .pt model, .pt state dict and .tar saved at the same epoch hold the same
# weights, so loading the .tar state dict into the DataParallel-wrapped model
# built in the setup cell is sufficient (and avoids unpickling a whole model).
checkpoint = torch.load(PATH_pre + 'min_loss/focal_min_all_{}.tar'.format(BEST_EPOCH_PRE), map_location=device)
pre_model.load_state_dict(checkpoint['model'])

# scr_model = torch.load(PATH_scr+'min_loss/cr_min_model_35.pt')
# scr_model.load_state_dict(torch.load(PATH_scr+'min_loss/cr_min_model_state_35.pt'))
# checkpoint = torch.load(PATH_scr+'min_loss/cr_min_all_35.tar')
# scr_model.load_state_dict(checkpoint['model'])

pre_model.eval()
# scr_model.eval()

correct_pre = 0
# correct_scr = 0
total = len(test_loader.dataset)
with torch.no_grad():
    for item in test_loader:
        images = item['image'].to(device)
        labels = item['label'].to(device)

        outputs_pre = pre_model(images)
        # outputs_scr = scr_model(images)
        _, predict_pre = torch.max(outputs_pre, 1)
        # _, predict_scr = torch.max(outputs_scr, 1)
        correct_pre += (predict_pre == labels).sum().item()
        # correct_scr += (predict_scr == labels).sum().item()
        # Collect plain Python ints (one host copy per batch) instead of 0-d device tensors.
        true_labels.extend(labels.cpu().tolist())
        pre_labels.extend(predict_pre.cpu().tolist())
        # scr_labels.extend(predict_scr.cpu().tolist())

print('Test accuracy of the pre-trained VGG16 on the {} test images: {}%'.format(total, 100*correct_pre/total))
# print('Test accuracy of the scratch VGG16 on the {} test images: {}%'.format(total, 100*correct_scr/total))

# labels= keeps a row/column for every class even if it is absent from the small test split.
print('Pre-trained VGG16')
print(classification_report(true_labels, pre_labels, labels=LABEL_IDS, target_names=CLASS_NAMES))
# print('Scratch VGG16')
# print(classification_report(true_labels, scr_labels, labels=LABEL_IDS, target_names=CLASS_NAMES))

print('Pre-trained VGG16')
print(confusion_matrix(true_labels, pre_labels, labels=LABEL_IDS))
# print('Scratch VGG16')
# print(confusion_matrix(true_labels, scr_labels, labels=LABEL_IDS))