In [1]:
import glob
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data.dataset import Dataset
from PIL import Image
from tqdm.auto import tqdm
from torch.optim.adam import Adam
from torch.utils.data.dataloader import DataLoader
import torchvision.models as models
import random
import os
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
import datetime
import pytz
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import re
import torch.nn.functional as F

## 파라미터 설정

In [2]:
# Pick the compute device: CUDA when PyTorch can use it, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cpu'

In [3]:
# Disabled alternative that would select the "mps:0" device when the MPS backend
# is available. It lives inside a string literal, so none of it executes; the
# cell only evaluates to (and echoes) that string.
'''
# 맥 mps 설정
device = torch.device("mps:0" if torch.backends.mps.is_available() else "cpu")
print(f"현재 디바이스는 {device} 입니다.")
'''

'\n# 맥 mps 설정\ndevice = torch.device("mps:0" if torch.backends.mps.is_available() else "cpu")\nprint(f"현재 디바이스는 {device} 입니다.")\n'

In [4]:
# Experiment configuration shared by the cells below.
CFG = {
    # Side length (pixels) that every image / mask is resized to.
    'IMG_SIZE': 512,
    # Epoch count; in this notebook it only appears in output file names.
    'EPOCHS': 50,
    # Learning rate handed to the Adam optimizer.
    'LEARNING_RATE': 1e-4,
    # Batch size handed to the DataLoader.
    'BATCH_SIZE': 64,
    # Seed for the random number generators and the dataset shuffle.
    'SEED': 41,
    # Per-channel mean / std used by the normalisation transform.
    'MEAN': [0.485, 0.456, 0.406],
    'STD': [0.229, 0.224, 0.225],
    # Magnification tags; they are interpolated into data paths and file names.
    'train_magnification': "20X",
    'test_magnification': "20X",
}

In [5]:
# Current time in the Asia/Seoul timezone, rendered as a tag for output names.
kst = pytz.timezone('Asia/Seoul')
current_datetime = datetime.datetime.now(kst)
# %I is the 12-hour clock and %p the AM/PM marker; note the tag contains a ':'.
formatted_datetime = current_datetime.strftime("%Y_%m_%d_%I:%M_%p")
print(formatted_datetime)

2023_06_08_05:26_PM


In [6]:
# Server paths
# Run tag (train/test magnification + epoch count) reused by the names below.
output_name = f"train:{CFG['train_magnification']}_test:{CFG['test_magnification']}_epoch:{CFG['EPOCHS']}"
# Checkpoint file name: run tag plus the timestamp.
pth_name = f"/data/pthfile/{output_name}_({formatted_datetime}).pth"
score_path = "/data/output/Loss_Score"
figure_path = f"/data/output/figure/figure_{output_name}_({formatted_datetime})"
# Glob pattern for the test PNG files of the configured magnification.
test_data_path = f"/data/PDA_mask_img/test_mask/{CFG['test_magnification']}/**/*.png"

# Echo the resolved locations, one per line.
print(
    f"pth_name:{pth_name}",
    f"score_path:{score_path}",
    f"figure_path:{figure_path}",
    f"test_data_path:{test_data_path}",
    sep="\n",
)


pth_name:/data/pthfile/train:20X_test:20X_epoch:50_(2023_06_08_05:26_PM).pth
score_path:/data/output/Loss_Score
figure_path:/data/output/figure/figure_train:20X_test:20X_epoch:50_(2023_06_08_05:26_PM)
test_data_path:/data/PDA_mask_img/test_mask/20X/**/*.png


In [7]:
# local path
# Disabled local-machine variant of the server path cell, kept inside a string
# literal so that none of it executes (the cell only echoes the string).
# NOTE(review) before re-enabling this snippet:
#   - the CFG magnification values already end in "X" ("20X"), so the extra "X"
#     appended here would produce "20XX";
#   - `output_path` is printed but never assigned inside the snippet;
#   - the last print labels val_data_path as "test_data_path".
'''
pth_name=f"git_ignore/pthfile/train:{CFG['train_magnification']}X_test:{CFG['test_magnification']}X_epoch:{CFG['EPOCHS']}_({formatted_datetime}).pth"
output_name = f"train:{CFG['train_magnification']}X_test:{CFG['test_magnification']}X_epoch:{CFG['EPOCHS']}"
score_path = f"git_ignore/output/Loss_Score/score_{output_name}_({formatted_datetime})"
figure_path = f"git_ignore/output/figure/figure_{output_name}_({formatted_datetime})"
train_data_path = f"git_ignore/PDA_mask_img/train/{CFG['train_magnification']}X/**/*.png"
test_data_path = f"git_ignore/PDA_mask_img/test/{CFG['test_magnification']}X/**/*.png"
val_data_path = f"git_ignore/PDA_mask_img/validation/{CFG['train_magnification']}X/**/*.png"


print(f"pth_name:{pth_name}")
print(f"output_path:{output_path}")
print(f"figure_path:{figure_path}")
print(f"train_data_path:{train_data_path}")
print(f"test_data_path:{test_data_path}")
print(f"test_data_path:{val_data_path}")
'''

'\npth_name=f"git_ignore/pthfile/train:{CFG[\'train_magnification\']}X_test:{CFG[\'test_magnification\']}X_epoch:{CFG[\'EPOCHS\']}_({formatted_datetime}).pth"\noutput_name = f"train:{CFG[\'train_magnification\']}X_test:{CFG[\'test_magnification\']}X_epoch:{CFG[\'EPOCHS\']}"\nscore_path = f"git_ignore/output/Loss_Score/score_{output_name}_({formatted_datetime})"\nfigure_path = f"git_ignore/output/figure/figure_{output_name}_({formatted_datetime})"\ntrain_data_path = f"git_ignore/PDA_mask_img/train/{CFG[\'train_magnification\']}X/**/*.png"\ntest_data_path = f"git_ignore/PDA_mask_img/test/{CFG[\'test_magnification\']}X/**/*.png"\nval_data_path = f"git_ignore/PDA_mask_img/validation/{CFG[\'train_magnification\']}X/**/*.png"\n\n\nprint(f"pth_name:{pth_name}")\nprint(f"output_path:{output_path}")\nprint(f"figure_path:{figure_path}")\nprint(f"train_data_path:{train_data_path}")\nprint(f"test_data_path:{test_data_path}")\nprint(f"test_data_path:{val_data_path}")\n'

In [8]:
# Fix the random seeds
def seed_everything(seed):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU generator plus every CUDA device), exports ``PYTHONHASHSEED`` and
    configures cuDNN for reproducible execution.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. subprocesses);
    # the hash seed of the already running interpreter cannot be changed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible GPU, not just the current one; safe without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may select different kernels from run to run, which
    # defeats the deterministic setting above, so it has to be switched off.
    torch.backends.cudnn.benchmark = False

seed_everything(CFG['SEED']) # fix the seed using the configured value

## 데이터 경로지정

In [9]:
# All PNG files matched by the test glob, in lexicographic order.
test_path_list = sorted(glob.glob(test_data_path))
# Masks and images are paired purely by position: in the sorted listing each
# "...-labelled.png" mask is expected to sit directly before its image.
test_mask_path = test_path_list[0::2]
test_img_path = test_path_list[1::2]

# Positional pairing silently shifts every later pair when a single file is
# missing or an extra PNG is present, so verify the split before using it.
_mask_suffix = "-labelled.png"
_misplaced = (
    [p for p in test_mask_path if not p.endswith(_mask_suffix)]
    + [p for p in test_img_path if p.endswith(_mask_suffix)]
)
if len(test_mask_path) != len(test_img_path) or _misplaced:
    raise ValueError(
        f"Image/mask files under {test_data_path!r} do not alternate as expected: "
        f"{len(test_img_path)} images, {len(test_mask_path)} masks, "
        f"first misplaced file: {_misplaced[0] if _misplaced else None!r}"
    )

test_mask_path[:5]

['/data/PDA_mask_img/test_mask/20X/C3L-01637-21/C3L-01637-21 [d=1.01174,x=10360,y=3108,w=518,h=518]-labelled.png',
 '/data/PDA_mask_img/test_mask/20X/C3L-01637-21/C3L-01637-21 [d=1.01174,x=10360,y=3626,w=518,h=518]-labelled.png',
 '/data/PDA_mask_img/test_mask/20X/C3L-01637-21/C3L-01637-21 [d=1.01174,x=10360,y=4144,w=518,h=518]-labelled.png',
 '/data/PDA_mask_img/test_mask/20X/C3L-01637-21/C3L-01637-21 [d=1.01174,x=10360,y=4662,w=518,h=518]-labelled.png',
 '/data/PDA_mask_img/test_mask/20X/C3L-01637-21/C3L-01637-21 [d=1.01174,x=10878,y=3626,w=518,h=518]-labelled.png']

In [10]:
# Custom Dataset
class CustomDataset(Dataset):
    """Dataset of (image, mask) pairs read from two parallel lists of file paths.

    Both lists are shuffled once, at construction time, with the same seeded
    permutation, so pairs stay aligned and the order is reproducible.

    Args:
        img_path: Sequence of image file paths.
        mask_path: Sequence of mask file paths; entry ``i`` must belong to
            ``img_path[i]``.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"`` entries.
            When ``None`` the arrays are returned as loaded.

    Raises:
        ValueError: If ``img_path`` and ``mask_path`` differ in length.
    """

    def __init__(self, img_path, mask_path, transform = None):
        n_samples = len(img_path)
        if len(mask_path) != n_samples:
            raise ValueError(
                f"img_path has {n_samples} entries but mask_path has {len(mask_path)}"
            )

        # Shuffle up front with a fixed seed.
        # NOTE: this reseeds NumPy's global generator as a side effect.
        np.random.seed(CFG['SEED'])
        idxs = np.random.permutation(range(n_samples))

        self.image = np.array(img_path)[idxs]
        self.mask = np.array(mask_path)[idxs]
        self.transform = transform

    def __len__(self):
        return len(self.image) # number of samples

    def __getitem__(self, i):
        # Context managers close the underlying files as soon as the pixel
        # data has been copied into arrays.
        with Image.open(self.image[i]) as img_file:
            image = np.array(img_file)
        with Image.open(self.mask[i]) as mask_file:
            mask = np.array(mask_file)

        # transform defaults to None, so it must not be called unconditionally.
        if self.transform is None:
            return image, mask

        data = self.transform(image = image, mask = mask)
        return data["image"], data["mask"]

## 데이터 불러오기

In [11]:
# Preprocessing applied to every test sample (image and mask together).
test_transform = A.Compose([
        # Resize to IMG_SIZE x IMG_SIZE.
        A.Resize(CFG['IMG_SIZE'],CFG['IMG_SIZE']),
        # Normalise with the configured per-channel mean / std.
        A.Normalize(mean=CFG['MEAN'], std = CFG['STD']),
        # Convert to torch tensors. NOTE(review): per the albumentations docs,
        # transpose_mask=True moves a 3-D mask's channel axis first — confirm
        # against the installed version.
        ToTensorV2(transpose_mask=True)
])

In [12]:
# Test data: image / mask path lists wrapped with the test preprocessing.
test_set = CustomDataset(img_path = test_img_path,
                         mask_path= test_mask_path,
                         transform = test_transform)

In [13]:

# Batched loader over the test set (no shuffle argument is passed).
# NOTE(review): the prediction loop visible in this notebook indexes test_set
# directly, so this loader appears to be unused — confirm.
test_loader = DataLoader(test_set, batch_size = CFG["BATCH_SIZE"])

In [14]:
# Report how many image/mask pairs the test set holds.
print(f"test_data : {len(test_set)}")

test_data : 81


## 학습

In [15]:
# Modeling
class UNet(nn.Module):
    """U-Net style segmentation network built on a torchvision ResNet-18 encoder.

    The encoder stages feed a three-step decoder; each step upsamples by 2 with
    a transposed convolution, concatenates the matching encoder feature map and
    fuses the result with a 3x3 convolution + ReLU. The final 1x1 convolution
    yields ``num_classes`` channels of raw logits, which are bilinearly resized
    to the spatial size of the input.

    Args:
        num_classes: Number of output channels (logit maps).
    """

    def __init__(self, num_classes):
        super(UNet, self).__init__()
        # NOTE(review): pretrained=True fetches ImageNet weights even when a
        # checkpoint is loaded afterwards; newer torchvision releases prefer the
        # `weights=` argument. Left unchanged to keep construction behaviour.
        self.encoder = models.resnet18(pretrained=True)
        # Not used in forward(); kept so the module's attributes stay the same.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU(inplace=True)
        self.upconv1 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.upconv3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        """Return per-pixel logits with the same height and width as ``x``."""
        # Remember the input resolution so the logits can be resized to it
        # (previously hard-coded to 512x512, which only matched 512x512 inputs).
        out_size = tuple(x.shape[-2:])

        # Encoder
        x1 = self.encoder.conv1(x)
        x1 = self.encoder.bn1(x1)
        x1 = self.encoder.relu(x1)
        x1 = self.encoder.maxpool(x1)

        x2 = self.encoder.layer1(x1)
        x3 = self.encoder.layer2(x2)
        x4 = self.encoder.layer3(x3)
        x5 = self.encoder.layer4(x4)

        # Decoder: upsample, concatenate the skip connection, fuse.
        x = self.upconv1(x5)
        x = torch.cat((x, x4), dim=1)
        x = self.relu(self.conv1(x))

        x = self.upconv2(x)
        x = torch.cat((x, x3), dim=1)
        x = self.relu(self.conv2(x))

        x = self.upconv3(x)
        x = torch.cat((x, x2), dim=1)
        x = self.relu(self.conv3(x))

        x = self.conv4(x)

        # Resize the logits back to the input resolution.
        x = nn.functional.interpolate(x, size=out_size, mode='bilinear', align_corners=False)

        return x

In [16]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed over all elements of a batch at once.

    ``forward`` applies a sigmoid to ``inputs`` (so they are expected to be raw
    logits), flattens both tensors and returns ``1 - dice`` where
    ``dice = (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``.
    """

    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return the scalar Dice loss.

        Args:
            inputs: Tensor of logits.
            targets: Tensor with the same number of elements as ``inputs``.
            smooth: Additive smoothing term that keeps the ratio defined when
                both tensors sum to zero.
        """
        # Remove this line if the model output has already passed through a sigmoid.
        inputs = torch.sigmoid(inputs)

        # reshape() instead of view(): view raises on non-contiguous tensors
        # (e.g. a permuted mask), reshape copies only when it has to.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice = (2.*intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

        return 1 - dice

In [17]:
# Training parameters
model = UNet(num_classes=1).to(device)
# NOTE(review): presumably wrapped in DataParallel so the state_dict keys match
# a checkpoint saved from a DataParallel model — confirm.
model = nn.DataParallel(model)
# NOTE(review): optimizer, scheduler and criterion are not used by the cells
# visible in this notebook; presumably carried over from the training notebook.
optimizer = Adam(params = model.parameters(), lr = CFG["LEARNING_RATE"])
# mode='max': the scheduler halves the learning rate when the monitored value stops increasing.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, threshold_mode='abs', min_lr=1e-8, verbose=True)
criterion = DiceLoss().to(device)



In [18]:
def dice_score(pred, target, smooth=1e-6):
    """Dice coefficient between two tensors.

    Computes ``(2 * sum(pred * target) + smooth) / (sum(pred) + sum(target) + smooth)``.
    ``smooth`` keeps the ratio defined when both inputs sum to zero.
    """
    overlap = (pred * target).sum()
    total = pred.sum() + target.sum()
    return (2.0 * overlap + smooth) / (total + smooth)

In [19]:
# Running-average bookkeeping
class AverageMeter:
    """Tracks the latest value, the weighted sum, the count and the mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Return every statistic to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.count += n
        self.sum += val * n
        self.val = val
        self.avg = self.sum / self.count

In [20]:
class EarlyStop:
    """Early-stopping tracker for a score where higher is better.

    Each call compares the new score with the best one seen so far. A score
    below ``best_score + delta`` counts as "no improvement"; after ``patience``
    consecutive such calls ``early_stop`` is set to True. Any other score
    becomes the new best and resets the counter.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        delta: Minimum increase over the best score that counts as improvement.
    """

    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.counter = 0
        self.early_stop = False
        # np.inf instead of np.Inf: the capitalised alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf

    def __call__(self, val_score):
        """Update the tracker with the latest validation score."""
        if self.best_score is None:
            # First observation: nothing to compare against yet.
            self.best_score = val_score
        elif val_score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = val_score
            self.counter = 0

In [21]:
# Running averages for loss and score, plus an early-stopping tracker.
# NOTE(review): none of these objects are used by the cells visible in this
# notebook; presumably carried over from the training notebook — confirm.
loss_meter = AverageMeter()
score_meter = AverageMeter()
early_stopping = EarlyStop(patience = 20, delta = 0)

## Test

In [22]:
# Restore the trained weights; map_location places the loaded tensors on `device`.
# NOTE(review): this path is relative ("data/pthfile/...") whereas pth_name is
# built as an absolute "/data/pthfile/..." path — confirm the working directory.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
model.load_state_dict(torch.load("data/pthfile/train:20X_test:20X_epoch:50_(2023_06_07_03:24_PM).pth", map_location=device))

<All keys matched successfully>

In [26]:
# cv2.imwrite does not create missing folders (it just returns False), so make
# sure the prediction directory exists before writing into it.
os.makedirs('data/output/pred', exist_ok=True)

# Switch to inference mode: the encoder contains BatchNorm layers, which would
# otherwise normalise each single-image batch with its own statistics and keep
# updating their running averages while predicting.
model.eval()

# Predict every test tile one at a time and store the binarised mask as a PNG.
for i in range(len(test_set)):

    data, label = test_set[i]
    label = torch.squeeze(label)

    with torch.no_grad():
        out = model(torch.unsqueeze(data, dim=0).to(device))
    # Logits -> probabilities -> hard 0/1 mask at a 0.5 threshold.
    out = torch.squeeze(out).sigmoid().to('cpu')
    pred = torch.ge(out, 0.5).float().to('cpu')
    # Scale to 0/255 so the mask is visible as an 8-bit image.
    pred_array = pred.numpy() * 255
    pred_array = pred_array.astype(np.uint8)

    # The output file is named after the "x=<int>,y=<int>" part of the mask path.
    # NOTE(review): tiles from different slides that share coordinates would
    # overwrite each other — confirm that a single slide is processed per run.
    coordinates = re.search(r'x=(\d+),y=(\d+)', test_set.mask[i])
    if coordinates is None:
        raise ValueError(f"No 'x=..,y=..' coordinates found in mask path: {test_set.mask[i]}")

    pred_file = f'data/output/pred/{coordinates[0]}.png'
    if not cv2.imwrite(pred_file, pred_array):
        raise OSError(f"Could not write prediction image: {pred_file}")