In [1]:
import os
import glob
import datetime
import time
from tqdm import tqdm

import numpy as np

import cv2
%matplotlib inline
from matplotlib import pyplot as plt
import skimage
from skimage import exposure, io



import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision.utils import save_image

# Keep cuDNN's autotuner off: benchmark mode picks convolution algorithms by
# timing them at run time, so the chosen kernels can differ between runs.
torch.backends.cudnn.benchmark = False

# Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices) so that
# weight initialisation, shuffling and augmentation are repeatable. Bit-exact
# GPU results would additionally need deterministic cuDNN kernels.
import random
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

import albumentations as A
from albumentations.pytorch import ToTensorV2

import segmentation_models_pytorch as smp


# Compute device: the default CUDA device when one is available, else the CPU.
dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Record the device and the PyTorch build in the notebook output.
for info in (dev, torch.__version__):
    print(info)



cuda
1.8.0+cu111


In [2]:
# --- Data locations ---------------------------------------------------------
# Directories that hold the *.tif training / validation images.
img_path_train = "../train/"
img_path_valid = "../val/"

# --- Image sizes ------------------------------------------------------------
# figure size
figsize = 512

# figure size in dataloader
figsize2 = 512

# Side length every image is resized to when it is loaded (figure size for
# the correction of the position); cutwidth is 1/20 of it.
figsize3 = 512
cutwidth = figsize3 // 20

# --- Training ---------------------------------------------------------------
batchsize = 4

# The training loop runs epochs epoch_init+1 .. epoch_max (inclusive).
epoch_init = 0
epoch_max = 10

# --- Model ------------------------------------------------------------------
# segmentation_models_pytorch encoder name (ResNet family: 18/34/50/101/152).
encoder = 'resnet18'
#encoder = 'resnet34'
#encoder = 'resnet50'
#encoder = 'resnet101'
#encoder = 'resnet152'

# Pretrained encoder weights. NOTE(review): 'ssl' / 'swsl' are only published
# for some encoders (e.g. resnet18, resnet50) -- check the smp encoder table
# before switching.
encoder_weights = 'imagenet'
#encoder_weights = 'ssl'
#encoder_weights = 'swsl'

In [3]:
def preprocess(img, mn=0, mx=100):
    """Min-max normalise an image, then contrast-stretch it between two percentiles.

    The image is first scaled by its own minimum and maximum. The intensities at
    the ``mn``-th and ``mx``-th percentiles then define the input window passed
    to ``skimage.exposure.rescale_intensity``.

    Args:
        img: image array of any numeric dtype; it is cast to float32.
        mn: lower percentile of the stretch window (0-100).
        mx: upper percentile of the stretch window (0-100).

    Returns:
        float32 array with the same shape as ``img``. A constant image (zero
        intensity range) yields an all-zero array instead of NaNs.
    """
    img = np.asarray(img, dtype=np.float32)
    lo = img.min()
    hi = img.max()
    if hi == lo:
        # Min-max scaling would divide by zero and fill the image with NaN.
        return np.zeros_like(img)
    img = (img - lo) / (hi - lo)

    p_lo, p_hi = np.percentile(img, (mn, mx))
    return exposure.rescale_intensity(img, in_range=(p_lo, p_hi)).astype(np.float32)

In [4]:
def preprocess_all(img1, img2):
    """Normalise a (truth, input) image pair and wrap each one as a tensor.

    ``img1`` uses the full 0-100 percentile window and ``img2`` the 1-80
    window, the same settings the data-loading cells use.

    Returns:
        Tuple ``(img1_t, img2_t)`` of float32 tensors with two leading
        singleton axes added, i.e. (1, 1, H, W) for 2-D inputs.
    """
    windows = ((img1, 0, 100), (img2, 1, 80))
    tensors = []
    for image, lo_pct, hi_pct in windows:
        normalised = preprocess(image, mn=lo_pct, mx=hi_pct)
        tensors.append(torch.from_numpy(normalised)[None, None])
    return tuple(tensors)

In [5]:
# Files are consumed in sorted order as triplets: the first file of each
# triplet is a ground-truth image, the next two are input images that are each
# paired with that same truth image. NOTE(review): this relies on the file
# naming making sorted order follow that layout -- confirm against the dataset.
imglist = sorted(glob.glob(img_path_train+"*.tif"))

image_truth_train = []
image_input_train = []

truth_p = None
for idx, path in enumerate(tqdm(imglist)):
    img = cv2.resize(io.imread(path).astype(np.float32), (figsize3, figsize3))
    if idx % 3 == 0:
        # Normalise the truth image once; each pair gets its own copy below.
        truth_p = preprocess(img)
        continue
    image_truth_train.append(truth_p.copy())
    image_input_train.append(preprocess(img, mn=1, mx=80))



100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  7.56it/s]


In [6]:
# Same triplet layout as the training folder: in sorted order, every third file
# (starting with the first) is a ground-truth image and the following two are
# input images paired with it. NOTE(review): confirm the naming guarantees this.
imglist = sorted(glob.glob(img_path_valid+"*.tif"))

image_truth_valid = []
image_input_valid = []

truth_p = None
for idx, path in enumerate(tqdm(imglist)):
    img = cv2.resize(io.imread(path).astype(np.float32), (figsize3, figsize3))
    if idx % 3 == 0:
        # Normalise the truth image once; each pair gets its own copy below.
        truth_p = preprocess(img)
        continue
    image_truth_valid.append(truth_p.copy())
    image_input_valid.append(preprocess(img, mn=1, mx=80))
    


100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  7.93it/s]


In [7]:
# Training pipeline: random horizontal flip with probability 0.5, then
# conversion to a torch tensor.
train_steps = [A.HorizontalFlip(p=0.5), ToTensorV2()]
transform = A.Compose(train_steps)

# Validation pipeline: tensor conversion only, no augmentation.
trans2_valid = A.Compose(
    [ToTensorV2()]
)

In [8]:
class My_train_dataset(torch.utils.data.Dataset):
    """Paired (input, truth) training images with joint augmentation.

    Args:
        img_in: indexable sequence of input images.
        img_th: indexable sequence of truth images, aligned index-for-index
            with ``img_in``.
        transform: callable invoked as ``transform(image=..., mask=...)`` that
            returns a dict with ``'image'`` and ``'mask'`` tensors. Defaults to
            the notebook-level ``transform`` pipeline.

    Raises:
        ValueError: if ``img_in`` and ``img_th`` differ in length.
    """

    def __init__(self, img_in, img_th, transform=None):
        if len(img_in) != len(img_th):
            raise ValueError(
                "img_in and img_th must have the same length ({} != {})".format(
                    len(img_in), len(img_th)))
        self.img_in = img_in
        self.img_th = img_th
        if transform is None:
            # Fall back to the module-level training pipeline.
            transform = globals()["transform"]
        self.transform = transform

    def __len__(self):
        return len(self.img_in)

    def __getitem__(self, idx):
        # Image and truth go through one call so random flips stay in sync.
        transformed = self.transform(image=self.img_in[idx], mask=self.img_th[idx])
        x = transformed['image']
        # Add a leading channel axis to the mask so the target matches the
        # channel-first layout of the input (the mask is 2-D for 2-D images
        # with the default pipeline).
        y = transformed['mask'].unsqueeze(0)
        return x, y

# Training split: network inputs paired with their truth images.
train_dataset = My_train_dataset(img_in=image_input_train, img_th=image_truth_train)

In [9]:
class My_valid_dataset(torch.utils.data.Dataset):
    """Paired (input, truth) validation images, converted without augmentation.

    Args:
        img_in: indexable sequence of input images.
        img_th: indexable sequence of truth images, aligned index-for-index
            with ``img_in``.
        transform: callable invoked as ``transform(image=..., mask=...)`` that
            returns a dict with ``'image'`` and ``'mask'`` tensors. Defaults to
            the notebook-level ``trans2_valid`` pipeline (no augmentation).

    Raises:
        ValueError: if ``img_in`` and ``img_th`` differ in length.
    """

    def __init__(self, img_in, img_th, transform=None):
        if len(img_in) != len(img_th):
            raise ValueError(
                "img_in and img_th must have the same length ({} != {})".format(
                    len(img_in), len(img_th)))
        self.img_in = img_in
        self.img_th = img_th
        if transform is None:
            # Validation uses the non-augmenting pipeline, not the training one.
            transform = globals()["trans2_valid"]
        self.transform = transform

    def __len__(self):
        return len(self.img_in)

    def __getitem__(self, idx):
        # Same joint image/mask call as the training dataset, so any spatial
        # step in the pipeline would be applied identically to both.
        transformed = self.transform(image=self.img_in[idx], mask=self.img_th[idx])
        x = transformed['image']
        # Channel-first target, matching the input's leading channel axis.
        y = transformed['mask'].unsqueeze(0)
        return x, y

# Validation split: network inputs paired with their truth images.
valid_dataset = My_valid_dataset(img_in=image_input_valid, img_th=image_truth_valid)

In [10]:
# Sanity check: number of (input, truth) pairs in each split.
print(len(train_dataset), len(valid_dataset))

2 2


In [11]:
# Reshuffle the training pairs every epoch so batches are not always drawn in
# file order; the seeded generator keeps that order reproducible.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batchsize,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)
# Validation stays in file order, one pair per batch.
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=1)

In [12]:
# U-Net with a single-channel input and a single-channel output map. Passing
# encoder_weights explicitly makes the configuration cell's setting take
# effect (it was previously ignored in favour of the library default).
model = smp.Unet(encoder, encoder_weights=encoder_weights, in_channels=1, classes=1)

In [13]:
# Pixel-wise mean absolute error between the prediction and the truth image.
loss_fn = nn.L1Loss(reduction='mean')
# Adam over all model parameters.
optimizer = optim.Adam(params=model.parameters(), lr=0.0001)

In [14]:
def train_one_epoch(model, loader, loss_fn, optimizer, dev):
    """Run one optimisation pass over ``loader``.

    Puts the model in training mode, steps ``optimizer`` once per batch and
    returns the loss averaged over batches (each batch weighted equally).
    """
    model.train()
    total = 0.
    for inputs, labels in loader:
        inputs, labels = inputs.to(dev), labels.to(dev)
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total += loss.item()
    return total / len(loader)


def evaluate(model, loader, loss_fn, dev):
    """Return the batch-averaged loss over ``loader`` with gradients disabled.

    Leaves the model in eval mode; the caller switches it back to training.
    """
    model.eval()
    total = 0.
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(dev), labels.to(dev)
            total += loss_fn(model(inputs), labels).item()
    return total / len(loader)


# Averaging divides by the number of batches, so fail early with a clear
# message instead of a ZeroDivisionError when a split found no images.
if len(train_loader) == 0 or len(valid_loader) == 0:
    raise RuntimeError(
        "train_loader / valid_loader is empty: check img_path_train and img_path_valid")

res_c_tloss = []
res_c_vloss = []

model = model.to(dev)
optimizer.zero_grad()

# Best validation loss so far, the epoch it was reached in and a copy of the
# weights at that epoch (kept in memory only; nothing is written to disk).
c_vloss_best = float("inf")
best_epoch = None
best_state = None

for epoch in range(epoch_init + 1, epoch_max + 1):
    t_loss = train_one_epoch(model, train_loader, loss_fn, optimizer, dev)
    v_loss = evaluate(model, valid_loader, loss_fn, dev)

    res_c_tloss.append(t_loss)
    res_c_vloss.append(v_loss)
    print('Epoch {} {}: {:.4} {:.4} {:.4}'.format(epoch, datetime.datetime.now(), optimizer.param_groups[0]['lr'], t_loss, v_loss))

    if v_loss < c_vloss_best:
        c_vloss_best = v_loss
        best_epoch = epoch
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

    model.train()
    

Epoch 1 2022-07-26 14:51:45.485583: 0.0001 0.7412 0.2703
Epoch 2 2022-07-26 14:51:45.756867: 0.0001 0.6967 0.3771
Epoch 3 2022-07-26 14:51:46.036672: 0.0001 0.6608 0.4503
Epoch 4 2022-07-26 14:51:46.305525: 0.0001 0.6476 0.5204
Epoch 5 2022-07-26 14:51:46.577028: 0.0001 0.6144 0.5834
Epoch 6 2022-07-26 14:51:46.847632: 0.0001 0.559 0.6137
Epoch 7 2022-07-26 14:51:47.130790: 0.0001 0.5632 0.6261
Epoch 8 2022-07-26 14:51:47.405477: 0.0001 0.5106 0.6288
Epoch 9 2022-07-26 14:51:47.672136: 0.0001 0.5028 0.6222
Epoch 10 2022-07-26 14:51:47.953168: 0.0001 0.4664 0.6197
