In [None]:
cd ..

In [None]:
cd ..

In [None]:
cd ./disk1/colonoscopy_datasetv2/cropped

GPU 지정 & 모듈 임포트

In [None]:
import os

# Enumerate GPUs in PCI bus order and expose only GPU 0 to this process.
# CUDA reads these variables when it initialises, so they are set before torch
# is imported.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

import torch

# Fall back to the CPU when no GPU is visible.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# Module imports: PIL to read the image/mask files, torch data utilities for the
# dataset, loaders and train/val/test split, nn for the model heads and losses,
# and torchvision transforms for preprocessing.
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torch import nn
from torchvision import transforms

# Sanity check: is CUDA visible to this kernel?
print(torch.cuda.is_available())

Custom Dataset 구현

In [None]:
class CustomDataset(Dataset):
    """Paired image / mask segmentation dataset read from one flat directory.

    Every non-directory entry directly inside ``data_path`` that PIL can open is
    classified by file name: if characters 4..7 (``name[4:8]``) spell ``'MASK'``
    it is a mask, otherwise an input image. Both path lists are sorted and then
    paired by position, i.e. the i-th sorted image goes with the i-th sorted
    mask. NOTE(review): that relies on the dataset's naming scheme making the
    two sort orders line up -- confirm for new data.

    Args:
        data_path: directory holding the image and mask files.
        img_transforms: optional callable applied to each input PIL image.
        mask_transforms: optional callable applied to each mask PIL image.
    """

    def __init__(self, data_path, img_transforms=None, mask_transforms=None):
        self.data_path = data_path
        self.img_transforms = img_transforms
        self.mask_transforms = mask_transforms
        self.img_files, self.img_masks, self.data_size, self.mask_size = self.readData()

    def readData(self):
        """Scan ``self.data_path`` and split its readable image files into inputs and masks.

        Returns:
            ``(image_paths, mask_paths, n_images, n_masks)`` with both path lists
            sorted in ascending order.

        Raises:
            FileNotFoundError / NotADirectoryError: if ``data_path`` is not a
            readable directory (previously this surfaced as a bare StopIteration).
        """
        import warnings

        all_files = []
        all_labels = []

        # Top-level, non-directory entries only -- the same set os.walk reports
        # as "files" for the root directory.
        with os.scandir(self.data_path) as entries:
            file_names = [entry.name for entry in entries if not entry.is_dir()]

        for name in file_names:
            path = os.path.join(self.data_path, name)
            # Image.open never returns None; it raises for unreadable/non-image
            # files. Skip those (e.g. stray metadata files) instead of aborting,
            # and close the handle right away -- only the header is needed here.
            try:
                with Image.open(path):
                    pass
            except OSError:
                continue
            if name[4:8] == 'MASK':
                all_labels.append(path)
            else:
                all_files.append(path)

        # Ascending order so images and masks can be paired by index.
        all_files.sort()
        all_labels.sort()

        if len(all_files) != len(all_labels):
            warnings.warn(
                'found {} images but {} masks in {!r}; pairing images and masks '
                'by sorted index may be misaligned'.format(
                    len(all_files), len(all_labels), self.data_path),
                stacklevel=2)

        return all_files, all_labels, len(all_files), len(all_labels)

    @staticmethod
    def _load_image(path):
        """Open ``path`` with PIL, read its pixels and release the file handle."""
        with Image.open(path) as im:
            return im.copy()

    def __getitem__(self, index):
        """Return ``{'image': ..., 'mask': ...}`` for the ``index``-th pair, transformed if configured."""
        image = self._load_image(self.img_files[index])
        mask = self._load_image(self.img_masks[index])
        if self.img_transforms is not None:
            image = self.img_transforms(image)
        if self.mask_transforms is not None:
            mask = self.mask_transforms(mask)
        return {'image': image, 'mask': mask}

    def __len__(self):
        """Number of input images (pairs are indexed by image position)."""
        return len(self.img_files)

Dataset 불러오기 & 시각화

In [None]:
import matplotlib.pyplot as plt

# ImageNet statistics expected by the pretrained FCN-ResNet50 backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
# Nearest-neighbour resizing keeps the masks binary; the default (bilinear)
# filter would create fractional labels along object borders.
mask_transforms = transforms.Compose([
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])
total_dataset = CustomDataset('./ADC/', data_transforms, mask_transforms)
print('data size: {}'.format(total_dataset.data_size))
print('mask size: {}'.format(total_dataset.mask_size))


def denormalize(img_tensor):
    """Undo the ImageNet ``Normalize`` on a (3, H, W) tensor so it displays with its real colours."""
    mean = torch.tensor(IMAGENET_MEAN, device=img_tensor.device).view(-1, 1, 1)
    std = torch.tensor(IMAGENET_STD, device=img_tensor.device).view(-1, 1, 1)
    return (img_tensor * std + mean).clamp(0, 1)


# Preview the first few image / mask pairs side by side.
to_image = transforms.ToPILImage()
n_show = min(5, len(total_dataset))
plt.figure(figsize=(8, 20))
for idx in range(n_show):
    sample = total_dataset[idx]  # load each sample only once
    plt.subplot(5, 2, 2 * idx + 1)
    plt.title(os.path.basename(total_dataset.img_files[idx]))
    plt.imshow(to_image(denormalize(sample['image'])))
    plt.subplot(5, 2, 2 * idx + 2)
    plt.title(os.path.basename(total_dataset.img_masks[idx]))
    plt.imshow(to_image(sample['mask']))
plt.show()


모듈 구현

In [None]:
# Training / evaluation helpers.
def train_epoch(dataloader, model, optimizer, aux_weight=0.1):
    """Run one training epoch and return the sample-weighted mean loss.

    Each batch is a dict with ``'image'`` and ``'mask'`` tensors. The model must
    return a dict of logits with ``'out'`` (main head) and ``'aux'`` (auxiliary
    head); the loss is ``(1 - aux_weight) * BCE(out) + aux_weight * BCE(aux)``.

    Args:
        dataloader: DataLoader yielding ``{'image', 'mask'}`` batches.
        model: segmentation model; batches are moved to its parameters' device.
        optimizer: optimizer over ``model``'s parameters.
        aux_weight: weight of the auxiliary-head loss (default 0.1, i.e. the
            original 0.9 / 0.1 split).

    Returns:
        Mean per-sample loss over ``dataloader.dataset`` (``nan`` if it is empty).
    """
    model.train()
    # BCEWithLogitsLoss fuses sigmoid + BCE in a numerically stable way, so raw
    # logits go straight in (no separate sigmoid module needed).
    crit = nn.BCEWithLogitsLoss()
    device = next(model.parameters()).device
    sum_loss = 0.0

    for item in dataloader:
        images = item['image'].to(device)
        labels = item['mask'].to(device)

        # One forward pass gives both heads. Calling model() twice doubled the
        # compute and, for models with BatchNorm, updated running stats twice.
        logits = model(images)
        main_loss = crit(logits['out'], labels)
        aux_loss = crit(logits['aux'], labels)
        loss = main_loss * (1 - aux_weight) + aux_loss * aux_weight

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so the final mean is per sample, not per batch.
        sum_loss += loss.item() * len(images)

    n_samples = len(dataloader.dataset)
    return sum_loss / n_samples if n_samples else float('nan')

def validate_epoch(dataloader, model, optimizer):
    """Evaluate the main-head BCE loss over ``dataloader`` without updating the model.

    Args:
        dataloader: DataLoader yielding ``{'image', 'mask'}`` batches.
        model: segmentation model returning a dict with ``'out'`` logits;
            batches are moved to its parameters' device.
        optimizer: unused; kept so existing calls keep working.

    Returns:
        Mean per-sample loss over ``dataloader.dataset`` (``nan`` if it is empty).
    """
    model.eval()
    # Logit-space BCE: numerically stable equivalent of sigmoid + BCELoss.
    crit = nn.BCEWithLogitsLoss()
    device = next(model.parameters()).device
    sum_loss = 0.0

    with torch.no_grad():
        for item in dataloader:
            images = item['image'].to(device)
            labels = item['mask'].to(device)
            loss = crit(model(images)['out'], labels)
            sum_loss += loss.item() * len(images)

    n_samples = len(dataloader.dataset)
    return sum_loss / n_samples if n_samples else float('nan')

def _load_model_for_inference(PATH, model_n, model_all):
    """Rebuild the pickled model ``model_n`` and load the weights from checkpoint ``model_all``."""
    # NOTE: torch.load unpickles arbitrary Python objects -- only load
    # checkpoints written by this notebook or another trusted source.
    model = torch.load(os.path.join(PATH, model_n))
    checkpoint = torch.load(os.path.join(PATH, model_all))
    model.load_state_dict(checkpoint['model'])
    return model


def _segmentation_metrics(dataloader, model, threshold, label_threshold=0.5):
    """Compute pixel accuracy and IoU of a binary segmentation model over ``dataloader``.

    Predictions are ``sigmoid(logits) >= threshold``. Ground-truth masks are
    binarised at ``label_threshold`` (assumes masks in [0, 1] with foreground
    near 1, e.g. 0/255 PNGs after ToTensor -- TODO confirm for new data).
    IoU is dataset-level: total intersection / total union over all pixels,
    so it does not depend on batch size or order.

    Returns:
        ``(pixel_accuracy_percent, iou, last_preds, last_masks)``. The last two
        are the boolean predictions and raw masks of the final batch, kept for
        plotting. Accuracy/IoU are ``nan`` when undefined; the tensors are
        ``None`` for an empty loader.
    """
    device = next(model.parameters()).device
    correct = total = intersection = union = 0
    last_preds = last_masks = None

    with torch.no_grad():  # the whole loop, not just the counter setup
        for item in dataloader:
            images = item['image'].to(device)
            masks = item['mask'].to(device)

            preds = torch.sigmoid(model(images)['out']) >= threshold
            truth = masks >= label_threshold

            hits = preds == truth
            correct += hits.sum().item()  # vectorised; no per-pixel Python loop
            total += hits.numel()         # real pixel count, not a hard-coded 224*224
            intersection += (preds & truth).sum().item()
            union += (preds | truth).sum().item()
            last_preds, last_masks = preds, masks

    pixel_accuracy = correct / total * 100 if total else float('nan')
    iou = intersection / union if union else float('nan')
    return pixel_accuracy, iou, last_preds, last_masks


def _plot_predictions(preds, masks, n_show=3):
    """Plot up to ``n_show`` predicted masks next to their ground-truth masks."""
    to_image = transforms.ToPILImage()
    n_rows = min(n_show, preds.shape[0])  # the last batch may be smaller than n_show
    plt.figure(figsize=(8, 4 * n_show))
    for i in range(n_rows):
        plt.subplot(n_show, 2, 2 * i + 1)
        plt.imshow(to_image(preds[i].float().cpu()))
        plt.title('predict')
        plt.subplot(n_show, 2, 2 * i + 2)
        plt.imshow(to_image(masks[i].cpu()))
        plt.title('mask')


def inference(dataloader, PATH, model_n, model_st, model_all, threshold=0.5, n_show=3):
    """Load a saved checkpoint, report pixel accuracy / IoU on ``dataloader`` and plot samples.

    Args:
        dataloader: test DataLoader yielding ``{'image', 'mask'}`` batches.
        PATH: directory holding the saved files.
        model_n: file name of the whole pickled model (architecture).
        model_st: accepted for backward compatibility but not read. Its weights
            were always overwritten by ``model_all``'s, which is what gets loaded.
        model_all: file name of the checkpoint dict with a ``'model'`` state dict.
        threshold: probability threshold for a foreground pixel.
        n_show: number of prediction / mask pairs to plot from the last batch.
    """
    model = _load_model_for_inference(PATH, model_n, model_all)
    model.eval()

    pixel_accuracy, iou_score, preds, masks = _segmentation_metrics(dataloader, model, threshold)

    n_images = len(dataloader.dataset)
    print('Test pixel accuracy of the model on the {} test images: {}%'.format(n_images, pixel_accuracy))
    print('IOU score of the model on the {} test images: {}'.format(n_images, iou_score))

    if preds is not None:
        _plot_predictions(preds, masks, n_show)


데이터셋 분할

In [None]:
# 80 % train / 10 % validation / remainder test. int() truncates, so any
# rounding leftovers land in the test split.
n_samples = total_dataset.data_size
train_size = int(n_samples * 0.8)
validation_size = int(n_samples * 0.1)
test_size = n_samples - (train_size + validation_size)

print('train size: {}'.format(train_size))
print('validation size: {}'.format(validation_size))
print('test size: {}'.format(test_size))

# A dedicated generator with a fixed seed gives the same split on every run.
split_rng = torch.Generator().manual_seed(42)
split_sizes = [train_size, validation_size, test_size]
train_dataset, validation_dataset, test_dataset = random_split(total_dataset, split_sizes, generator=split_rng)

FCN_ResNet50 모델 불러오기 & 훈련

In [None]:
import ssl

# torch.hub downloads over HTTPS. The workaround for this machine's certificate
# problem used to disable TLS verification for the whole process. Limit it to
# the download and restore the default afterwards, so later HTTPS requests are
# verified again. TODO: fix the CA bundle instead of skipping verification.
_default_https_context = ssl._create_default_https_context
ssl._create_default_https_context = ssl._create_unverified_context
try:
    model = torch.hub.load('pytorch/vision:v0.10.0', 'fcn_resnet50', pretrained=True)
finally:
    ssl._create_default_https_context = _default_https_context

# Fine-tuning: swap the final 1x1 conv of both heads for a single-channel
# (binary foreground) output.
num_classes = 1
num_ftrs1 = model.aux_classifier[4].in_channels
model.aux_classifier[4] = nn.Conv2d(num_ftrs1, num_classes, kernel_size=(1, 1), stride=(1, 1))
num_ftrs2 = model.classifier[4].in_channels
model.classifier[4] = nn.Conv2d(num_ftrs2, num_classes, kernel_size=(1, 1), stride=(1, 1))
# Module-level sigmoid kept as a name other cells may refer to.
sigmoid = nn.Sigmoid()

model.to(device)
print(model)

# Hyper-parameters
hy_batch = 32
hy_lr = 0.00001
hy_epoch = 100


In [None]:
# Checkpoint directory (relative to the working directory set at the top of the
# notebook); create it so torch.save cannot fail on a fresh machine.
PATH = '../../../home/bokyoungk/segmentation_models/resnet50/'
os.makedirs(PATH, exist_ok=True)

# Reshuffle the training set every epoch (seeded for reproducibility). The
# evaluation loaders keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=hy_batch, shuffle=True,
                          generator=torch.Generator().manual_seed(42))
val_loader = DataLoader(validation_dataset, batch_size=hy_batch, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=hy_batch, shuffle=False)
optimizer = torch.optim.Adam(model.parameters(), lr=hy_lr)

all_train_loss = []
all_val_loss = []
# Start from +inf so this run's best epoch is always checkpointed. The previous
# hard-coded 0.2765... (an earlier run's best) meant a fresh run could save
# nothing at all.
min_loss = float('inf')
best_epoch = None

for e in range(hy_epoch):
    epoch = e + 1
    print('----------------------epoch {}--------------------------'.format(epoch))
    train_loss = train_epoch(train_loader, model, optimizer)
    val_loss = validate_epoch(val_loader, model, optimizer)
    print('train loss: {}'.format(train_loss))
    print('val loss: {}'.format(val_loss))
    all_train_loss.append(train_loss)
    all_val_loss.append(val_loss)

    # Save the model whenever validation loss improves.
    if val_loss < min_loss:
        min_loss = val_loss
        best_epoch = epoch
        torch.save(model, PATH + 'min_model_{}th.pt'.format(epoch))  # whole pickled model
        torch.save(model.state_dict(), PATH + 'min_model_state_dict_{}th.pt'.format(epoch))
        torch.save({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, PATH + 'min_all_{}th.tar'.format(epoch))

print('best epoch: {} (val loss {})'.format(best_epoch, min_loss))



In [None]:
import numpy as np

# Loss curves. The x-axis comes from the recorded history rather than hy_epoch,
# so the cell still works if training was stopped early.
plt.figure(figsize=(8, 6))
plt.title('Training and validation loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.plot(np.arange(1, len(all_train_loss) + 1), all_train_loss, label='train loss')
plt.plot(np.arange(1, len(all_val_loss) + 1), all_val_loss, label='val loss')
plt.legend()
plt.show()

In [None]:
PATH = '../../../home/bokyoungk/segmentation_models/resnet50/'
# Epoch of the checkpoint to evaluate (presumably the best one from an earlier
# training run -- update after retraining).
EVAL_EPOCH = 16
model_n = 'min_model_{}th.pt'.format(EVAL_EPOCH)
model_st = 'min_model_state_dict_{}th.pt'.format(EVAL_EPOCH)
model_all = 'min_all_{}th.tar'.format(EVAL_EPOCH)
# A fixed order makes the reported metrics and the plotted samples
# reproducible from run to run.
test_loader = DataLoader(test_dataset, batch_size=hy_batch, shuffle=False)
# Inference
inference(test_loader, PATH, model_n, model_st, model_all)

에폭에 따른 변화 관찰

In [None]:
# Show one test prediction from a given epoch's checkpoint next to its mask.
EPOCH = 28
SAMPLE_IDX = 2  # position within the first test batch
threshold = 0.5

test_loader = DataLoader(test_dataset, batch_size=hy_batch, shuffle=False)

PATH = '../../../home/bokyoungk/segmentation_models/resnet50/'
model_n = 'min_model_{}th.pt'.format(EPOCH)
model_all = 'min_all_{}th.tar'.format(EPOCH)

# Architecture from the pickled model, weights from the checkpoint dict.
# NOTE: torch.load unpickles arbitrary objects -- trusted files only.
model = torch.load(PATH + model_n)
checkpoint = torch.load(PATH + model_all)
model.load_state_dict(checkpoint['model'])
model.eval()

with torch.no_grad():
    batch = next(iter(test_loader))  # first test batch only
    images = batch['image'].to(device)
    labels = batch['mask'].to(device)
    # The model outputs logits: convert to probabilities before thresholding.
    # Thresholding raw logits at 0.5 amounts to a probability cut-off of ~0.62.
    outputs = (torch.sigmoid(model(images)['out']) >= threshold).float()

to_image = transforms.ToPILImage()
plt.figure(figsize=(5, 3))
plt.title('epoch {} predict'.format(EPOCH))
plt.imshow(to_image(outputs[SAMPLE_IDX].cpu()))

plt.figure(figsize=(5, 3))
plt.title('mask image')
plt.imshow(to_image(labels[SAMPLE_IDX].cpu()))