In [None]:
cd ..

In [None]:
cd ..

In [None]:
cd ./disk1/colonoscopy_datasetv2/cropped

GPU 지정 & 모듈 임포트

In [None]:
import os

# Expose only physical GPU 1 (GPUs numbered in PCI bus order). This has to be
# in the environment before torch initialises CUDA, so it precedes the import.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": "1"})

import torch

# Inside this process the single visible GPU is cuda:0; otherwise use the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# Module imports
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torch import nn
from torchvision import transforms
import numpy as np

# Sanity check: prints whether PyTorch can see a CUDA device in this process.
print(torch.cuda.is_available())

Custom Dataset 구현

In [None]:
class CustomDataset(Dataset):
    """Image / segmentation-mask pairs read from one flat directory.

    Every file directly inside ``data_path`` (sub-directories are not searched)
    that PIL can open is classified by characters 4-7 of its file name: names
    with ``name[4:8] == 'MASK'`` are masks, all others are input images. Both
    lists are sorted and paired by position, so the i-th image is assumed to
    belong to the i-th mask -- this only holds if the naming scheme makes the
    two sorted lists line up (TODO confirm for new data).

    Args:
        data_path: directory holding both images and masks.
        transforms: callable applied to every image, and to every mask unless
            ``mask_transforms`` is given.
        mask_transforms: optional separate callable for masks (e.g. a resize
            with nearest-neighbour interpolation so mask values are not
            blended). ``None`` keeps the original behaviour of reusing
            ``transforms`` for masks.

    Attributes:
        img_files, img_masks: sorted image / mask paths.
        data_size, mask_size: their lengths.
    """

    def readData(self):
        """Scan ``self.data_path`` and split the readable image files into inputs and masks.

        Files PIL cannot open are skipped with a warning. A warning is also
        issued when the number of images and masks differ.

        Returns:
            ``(image_paths, mask_paths, len(image_paths), len(mask_paths))``,
            both path lists sorted in ascending order.

        Raises:
            FileNotFoundError: if ``self.data_path`` is not an existing directory.
        """
        import warnings

        try:
            # Top level only: the filenames of os.walk's first (dirpath, dirnames, filenames) entry.
            _, _, file_names = next(os.walk(self.data_path))
        except StopIteration:
            # os.walk yields nothing (instead of raising) for a missing path or a non-directory.
            raise FileNotFoundError(
                'data_path is not an existing directory: {}'.format(self.data_path)) from None

        all_files = []
        all_labels = []
        for name in file_names:
            path = os.path.join(self.data_path, name)
            # Image.open raises (OSError, or PIL's UnidentifiedImageError which
            # subclasses it) instead of returning None for unreadable files.
            # Skip those, and close the handle right away instead of leaving it open.
            try:
                with Image.open(path):
                    pass
            except OSError as err:
                warnings.warn('skipping unreadable file {}: {}'.format(path, err))
                continue
            # Characters 4-7 of the file name mark a mask.
            if name[4:8] == 'MASK':
                all_labels.append(path)
            else:
                all_files.append(path)

        # Ascending order: __getitem__ pairs the i-th image with the i-th mask.
        all_files.sort()
        all_labels.sort()

        if len(all_files) != len(all_labels):
            warnings.warn(
                '{} images but {} masks in {}; pairing by sorted position may be '
                'wrong or fail'.format(len(all_files), len(all_labels), self.data_path))

        return all_files, all_labels, len(all_files), len(all_labels)

    def __init__(self, data_path, transforms=None, mask_transforms=None):
        self.data_path = data_path
        self.transforms = transforms
        # None -> masks are passed through ``transforms`` as well.
        self.mask_transforms = mask_transforms
        self.img_files, self.img_masks, self.data_size, self.mask_size = self.readData()

    def __getitem__(self, index):
        """Return ``{'image': ..., 'mask': ...}`` for the index-th image/mask pair.

        Without transforms the values are the opened PIL images; otherwise they
        are whatever the transforms return.
        """
        image = Image.open(self.img_files[index])
        mask = Image.open(self.img_masks[index])

        if self.transforms is not None:
            image = self.transforms(image)
        mask_transforms = self.mask_transforms if self.mask_transforms is not None else self.transforms
        if mask_transforms is not None:
            mask = mask_transforms(mask)

        return {'image': image, 'mask': mask}

    def __len__(self):
        """Number of input images (the dataset is indexed by image)."""
        return len(self.img_files)

Dataset 불러오기 & 시각화

In [None]:
import matplotlib.pyplot as plt

# Resize every image/mask to 256x256 and convert it to a float tensor in [0, 1].
check_data_transforms = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
total_dataset = CustomDataset('./ADC/', check_data_transforms)
print('data size: {}'.format(total_dataset.data_size))
print('mask size: {}'.format(total_dataset.mask_size))

# Preview up to 5 image/mask pairs: image in the left column, mask in the right,
# each titled with its file path.
n_preview = min(5, len(total_dataset))
to_image = transforms.ToPILImage()
plt.figure(figsize=(8, 20))
for idx in range(n_preview):
    # Read each sample once (the files are opened and transformed on every access).
    pair = total_dataset[idx]
    sample_image = to_image(pair['image'])
    sample_mask = to_image(pair['mask'])

    plt.subplot(5, 2, 2 * idx + 1)
    plt.title('{}'.format(total_dataset.img_files[idx]))
    plt.imshow(sample_image)
    plt.subplot(5, 2, 2 * idx + 2)
    plt.title('{}'.format(total_dataset.img_masks[idx]))
    plt.imshow(sample_mask)


Annotation(마스크) 값 확인

In [None]:
# Inspect one transformed sample: tensor shapes and the values stored in the mask.
sample = total_dataset[4]  # load once instead of re-reading the files for every print
print('img shape: {}'.format(sample['image'].shape))
print('mask shape: {}'.format(sample['mask'].shape))
print('check img value: {}'.format(sample['image']))
print('check mask value: {}'.format(sample['mask']))
# A strictly binary annotation contains only 0 and 1; other values would indicate
# the mask was blended (e.g. by Resize interpolation).
print('unique mask values: {}'.format(torch.unique(sample['mask'])))


Custom Network 구현

In [None]:
# Simple network: a small two-level encoder/decoder with skip connections.
class SimpleNet(nn.Module):
    """Small U-Net-like segmentation network.

    Two conv + max-pool encoder stages, a bottleneck, and two nearest-neighbour
    upsampling decoder stages whose inputs are concatenated with the matching
    encoder features (skip connections). The output has ``num_classes``
    channels passed through a sigmoid and the same spatial size as the input.

    Args:
        num_classes: number of output channels (1 for binary segmentation).

    NOTE: the submodule names (enc_*/dec_*/output) are the ``state_dict`` keys
    of saved checkpoints, and whole pickled models (``torch.load`` of a saved
    model) use this class's methods with their stored attributes. So do not
    rename submodules, and do not make ``forward`` depend on attributes that
    ``__init__`` did not already create.
    """

    def __init__(self, num_classes=1):
        super(SimpleNet, self).__init__()

        self.num_classes = num_classes

        # encoder
        self.enc_1 = self.conv_module(3,32)
        self.enc_2 = nn.Dropout2d(0.2)
        self.enc_3 = self.conv_module(32,32) # conv1 (full resolution skip)
        self.enc_4 = nn.MaxPool2d(kernel_size=2)
        self.enc_5 = self.conv_module(32,64)
        self.enc_6 = nn.Dropout2d(0.2)
        self.enc_7 = self.conv_module(64,64) # conv2 (1/2 resolution skip)
        self.enc_8 = nn.MaxPool2d(kernel_size=2)

        # decoder
        self.dec_1 = self.conv_module(64,128)
        self.dec_2 = nn.Dropout2d(0.2)
        self.dec_3 = self.conv_module(128,64) # conv3
        self.dec_4 = self.conv_module(128,64) # concat upsampling(conv3) + conv2
        self.dec_5 = nn.Dropout2d(0.2)
        self.dec_6 = self.conv_module(64,32) # conv4
        self.dec_7 = self.conv_module(64,32) # concat upsampling(conv4) + conv1
        self.dec_8 = nn.Dropout2d(0.2)
        self.dec_9 = self.conv_module(32,32) # conv5

        self.output = nn.Conv2d(32, num_classes, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def conv_module(self, in_num, out_num):
        """3x3 convolution (padding 1, size-preserving) followed by ReLU."""
        layer = nn.Sequential(nn.Conv2d(in_num, out_num, kernel_size=3, padding=1),
                            nn.ReLU())

        return layer

    @staticmethod
    def _upsample_like(x, ref):
        """Nearest-neighbour upsample ``x`` to the spatial size of ``ref``.

        When ``ref`` has even height/width this is exactly a x2 nearest
        upsampling. Sizing from ``ref`` also covers odd sizes, where max-pooling
        rounded down and a fixed x2 would not match the skip tensor.
        """
        return nn.functional.interpolate(x, size=tuple(ref.shape[-2:]), mode='nearest')

    def forward(self, x):
        """Map an (N, 3, H, W) batch to (N, num_classes, H, W) sigmoid probabilities.

        H and W must be at least 4 (two 2x2 max-pools).
        """
        # encoder
        out = self.enc_1(x)
        out = self.enc_2(out)
        conv1 = self.enc_3(out)
        out = self.enc_4(conv1)
        out = self.enc_5(out)
        out = self.enc_6(out)
        conv2 = self.enc_7(out)
        out = self.enc_8(conv2)

        # bottleneck
        out = self.dec_1(out)
        out = self.dec_2(out)
        conv3 = self.dec_3(out)

        # decoder stage 1: upsample to conv2's size and fuse with the conv2 skip.
        # (functional interpolate instead of building a deprecated
        # nn.UpsamplingNearest2d module on every call)
        up1 = self._upsample_like(conv3, conv2)
        out = torch.cat((up1, conv2), dim=1)
        out = self.dec_4(out)
        out = self.dec_5(out)
        conv4 = self.dec_6(out)

        # decoder stage 2: upsample to conv1's size and fuse with the conv1 skip.
        up2 = self._upsample_like(conv4, conv1)
        out = torch.cat((up2, conv1), dim=1)
        out = self.dec_7(out)
        out = self.dec_8(out)
        out = self.dec_9(out)
        out = self.output(out)
        out = self.sigmoid(out)

        return out

In [None]:
from torchsummary import summary

# Build the network on the selected device and print a layer-by-layer summary
# for a 3x256x256 input. torchsummary's summary() defaults to device="cuda";
# pass the actual device type so this cell also works on the CPU fallback.
model = SimpleNet().to(device)
summary(model, (3, 256, 256), device=device.type)

학습 / 검증 / 추론 함수 구현

In [None]:
# 모듈 구현
from torch import logical_and


def train_epoch(dataloader, model, optimizer):
    model.train()
    sum_loss = 0
    crit = nn.BCELoss()
   

    for item in dataloader:
        images = item['image'].to(device)
        labels = item['mask'].to(device)

        outputs = model(images)
        loss = crit(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        sum_loss += loss.item() * len(images)

    # print('loss: {}'.format(loss.item()))
    # print('outputs: {}'.format(outputs[0]))
    # print('labels: {}'.format(labels[0]))
    
    return sum_loss / len(dataloader.dataset)

def validate_epoch(dataloader, model, optimizer):
    model.eval()
    sum_loss = 0
    crit = nn.BCELoss()

    with torch.no_grad():
        for item in dataloader:
            images = item['image'].to(device)
            labels = item['mask'].to(device)

            outputs = model(images)
            loss = crit(outputs, labels)

            sum_loss += loss.item() * len(images)

    return sum_loss / len(dataloader.dataset)

def _load_pickled(path):
    """``torch.load`` a file written by ``torch.save``, mapping tensors onto ``device``.

    ``weights_only=False`` keeps the pre-2.6 PyTorch default so whole pickled
    models (``torch.save(model, ...)``) still load on newer versions.
    NOTE(security): unpickling can execute arbitrary code -- only use this on
    checkpoints you produced or otherwise trust.
    """
    try:
        return torch.load(path, map_location=device, weights_only=False)
    except TypeError:
        # PyTorch < 1.13 has no ``weights_only`` argument.
        return torch.load(path, map_location=device)


def _plot_predictions(images, outputs, labels, max_rows=3):
    """Plot up to ``max_rows`` (image, prediction, mask) triplets, one row per sample."""
    to_image = transforms.ToPILImage()
    plt.figure(figsize=(8, 12))
    panels = ((images, 'image'), (outputs, 'predict'), (labels, 'mask'))
    # The batch passed in can hold fewer than ``max_rows`` images (e.g. the last batch).
    for row in range(min(max_rows, len(images))):
        for col, (batch, title) in enumerate(panels):
            plt.subplot(max_rows, 3, row * 3 + col + 1)
            plt.imshow(to_image(batch[row]))
            plt.title(title)


def inference(dataloader, PATH, model_n, model_st, model_all):
    """Evaluate a saved checkpoint on ``dataloader``, print metrics and plot a few predictions.

    Args:
        dataloader: yields dicts with 'image' and 'mask' tensors.
        PATH: directory prefix, concatenated directly with the file names.
        model_n: file holding the whole pickled model.
        model_st: file holding that model's ``state_dict``.
        model_all: file holding a dict whose 'model' entry is a ``state_dict``;
            loaded last, so its weights are the ones evaluated.

    Prints pixel accuracy (%) and the per-batch mean IoU over the whole loader,
    then plots up to 3 samples of the last batch.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model = _load_pickled(PATH + model_n)
    model.load_state_dict(_load_pickled(PATH + model_st))
    checkpoint = _load_pickled(PATH + model_all)
    model.load_state_dict(checkpoint['model'])
    model.to(device)
    model.eval()

    threshold = 0.5
    total = 0
    correct = 0
    iou_score = 0.0
    cnt = 0
    # The whole evaluation loop runs without building autograd graphs.
    with torch.no_grad():
        for item in dataloader:
            cnt += 1
            images = item['image'].to(device)
            labels = item['mask'].to(device)

            # Hard 0/1 predictions (float so they can be plotted like the masks).
            outputs = (model(images) > threshold).float()
            pred = outputs.bool()
            # Masks may hold intermediate values (e.g. after bilinear resizing),
            # so binarise them at the same threshold before comparing.
            target = labels > threshold

            # pixel accuracy: vectorised count over every mask element
            total += target.numel()
            correct += torch.eq(pred, target).sum().item()

            # per-batch IoU; an empty union (empty prediction and empty mask)
            # is a perfect match rather than 0/0 = NaN
            intersection = torch.logical_and(target, pred).sum().item()
            union = torch.logical_or(target, pred).sum().item()
            iou_score += (intersection / union) if union > 0 else 1.0

    if cnt == 0:
        raise ValueError('dataloader yielded no batches; nothing to evaluate')

    iou_score = iou_score / cnt
    pixel_accuracy = correct / total * 100
    print('Test pixel accuracy of the model on the {} test images: {}%'.format(len(dataloader.dataset), pixel_accuracy))
    print('IOU score of the model on the {} test images: {}'.format(len(dataloader.dataset), iou_score))

    # Visualise the first samples of the last batch.
    _plot_predictions(images, outputs, labels)


데이터셋 분할

In [None]:
# 80 % train / 10 % validation of the images; the test split takes the remainder
# so the three sizes always add up to len(total_dataset).
train_size, validation_size = (int(total_dataset.data_size * ratio) for ratio in (0.8, 0.1))
test_size = total_dataset.data_size - (train_size + validation_size)

print('train size: {}'.format(train_size))
print('validation size: {}'.format(validation_size))
print('test size: {}'.format(test_size))

# Fixed seed -> the same split on every run.
split_generator = torch.Generator().manual_seed(42)
train_dataset, validation_dataset, test_dataset = random_split(
    total_dataset,
    [train_size, validation_size, test_size],
    generator=split_generator,
)

모델 훈련

In [None]:
# Hyper-parameters
hy_batch = 32
hy_lr = 0.00001
hy_epoch = 80

# Reshuffle the training split every epoch (seeded so runs are repeatable).
# Validation/test order does not change their metrics, so keep it fixed.
train_loader = DataLoader(train_dataset, batch_size=hy_batch, shuffle=True,
                          generator=torch.Generator().manual_seed(42))
validation_loader = DataLoader(validation_dataset, batch_size=hy_batch, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=hy_batch, shuffle=False)


In [None]:
# Directory checkpoints are written to / read from.
PATH = '../../../home/bokyoungk/segmentation_models/net1/'

# Set to True to save a checkpoint whenever the validation loss beats min_loss.
SAVE_CHECKPOINTS = False

optimizer = torch.optim.Adam(model.parameters(), lr=hy_lr)

# To resume from an earlier checkpoint instead, uncomment below. It must stay
# after the line above so the resumed optimizer state is not replaced.
# model_n = 'min_model_98th.pt'
# model_st = 'min_model_state_dict_98th.pt'
# model_all = 'min_all_98th.tar'
#
# model = torch.load(PATH+model_n)
# optimizer = torch.optim.Adam(model.parameters(), lr=hy_lr)
# model.load_state_dict(torch.load(PATH+model_st))
#
# checkpoint = torch.load(PATH+model_all)
# model.load_state_dict(checkpoint['model'])
# optimizer.load_state_dict(checkpoint['optimizer'])

# Training
# Presumably the best validation loss of an earlier run; only epochs that beat
# it are checkpointed.
min_loss = 0.510586462020874
all_train_loss = []
all_val_loss = []

for e in range(hy_epoch):
    print('---------------------------epoch {}-------------------------------'.format(e+1))
    train_loss = train_epoch(train_loader, model, optimizer)
    validation_loss = validate_epoch(validation_loader, model, optimizer)
    print('train loss= {}'.format(train_loss))
    print('validation loss= {}'.format(validation_loss))
    all_train_loss.append(train_loss)
    all_val_loss.append(validation_loss)

    # Save the model (three formats) when the validation loss improves.
    if SAVE_CHECKPOINTS and validation_loss < min_loss:
        min_loss = validation_loss
        torch.save(model, PATH + 'min_model_{}th.pt'.format(e+1))  # whole pickled model
        torch.save(model.state_dict(), PATH + 'min_model_state_dict_{}th.pt'.format(e+1))
        torch.save({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }, PATH + 'min_all_{}th.tar'.format(e+1))

In [None]:
import numpy as np
# Plot the per-epoch train / validation loss. The x-axis follows the number of
# recorded epochs, so the plot still works if training was interrupted early.
x = np.arange(1, len(all_train_loss) + 1)
plt.figure(figsize=(8, 6))
plt.title('Train / validation loss per epoch')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.plot(x, all_train_loss, label='train loss')
plt.plot(x, all_val_loss, label='val loss')
plt.legend()
plt.show()

In [None]:
# Checkpoint from epoch 69 (file names match the training loop's checkpoint naming).
PATH = '../../../home/bokyoungk/segmentation_models/net1/'
model_n, model_st, model_all = (
    'min_model_69th.pt',
    'min_model_state_dict_69th.pt',
    'min_all_69th.tar',
)

# Evaluate it on the held-out test split (prints metrics, plots predictions).
inference(test_loader, PATH, model_n=model_n, model_st=model_st, model_all=model_all)

에폭에 따른 변화 관찰

In [None]:
# Compare one test prediction of the epoch-144 checkpoint with its ground-truth mask.
PATH = '../../../home/bokyoungk/segmentation_models/net1/'
model_n = 'min_model_144th.pt'
model_st = 'min_model_state_dict_144th.pt'
model_all = 'min_all_144th.tar'
# Position, within the first test batch, of the sample to display.
sample_idx = 3

# map_location lets GPU-saved checkpoints load on the selected device (including CPU).
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# rejects the whole pickled model in model_n; pass weights_only=False there
# (trusted files only).
model = torch.load(PATH + model_n, map_location=device)
model.load_state_dict(torch.load(PATH + model_st, map_location=device))
checkpoint = torch.load(PATH + model_all, map_location=device)
model.load_state_dict(checkpoint['model'])
model.to(device)
model.eval()
threshold = 0.5

# Only the first test batch is needed.
item = next(iter(test_loader))
with torch.no_grad():
    images = item['image'].to(device)
    origins = images
    labels = item['mask'].to(device)
    # Hard 0/1 prediction (a value exactly at the threshold maps to 0 instead of staying 0.5).
    outputs = (model(images) > threshold).float()

# Show a single sample
to_image = transforms.ToPILImage()
plt.figure(figsize=(5,3))
plt.title('epoch 144 predict')
output = to_image(outputs[sample_idx])
plt.imshow(output)

plt.figure(figsize=(5,3))
plt.title('mask image')
label = to_image(labels[sample_idx])
plt.imshow(label)

# Uncomment to also show an input image:
# plt.figure(figsize=(5,3))
# plt.title('image')
# origin = to_image(origins[1])
# plt.imshow(origin)

