In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F 
import torchvision 
import numpy as np 
import PIL
from PIL import Image, ImageFilter 
import os 
import random
import glob
import pdb
import math
import random
import time
from torchvision.utils import save_image
import torchvision.transforms.functional as tf 

In [2]:
def get_mean_std(images_list, max_images=2001) :
    """Compute the standard deviation and mean over the pixels of a set of images.

    Each readable image is converted with ``tf.to_tensor`` and flattened; the
    statistics are taken over the concatenation of all flattened tensors, so
    the result is a single scalar pair (channels are not kept separate).

    Args:
        images_list: sequence of image file paths.
        max_images: stop after this many paths have been visited (readable or
            not). All pixels are held in memory at once, so an unbounded run
            can exhaust RAM. ``None`` disables the cap. The default reproduces
            the previously hard-coded limit.

    Returns:
        ``(std, mean)`` -- in that order -- exactly as returned by
        ``torch.std_mean``.

    Raises:
        ValueError: if no image could be read.
    """
    pixels = []
    total = len(images_list)
    for i, filepath in enumerate(images_list) :
        try :
            # The context manager releases the file handle; to_tensor forces
            # the lazy PIL load while the file is still open.
            with Image.open(filepath) as img :
                pixels.append(tf.to_tensor(img).view(-1))
        except (TypeError, OSError) as err :
            # PIL reports truncated / unidentifiable files as OSError.
            print(f'{filepath} could not be read ({err})')
        if i % 500 == 0 :
            print(f'{i}/{total}')
        if max_images is not None and i + 1 >= max_images :
            break # Out of memory..
    if not pixels :
        raise ValueError('get_mean_std: no readable images were provided')
    return torch.std_mean(torch.cat(pixels, dim=0), dim=0)
        

In [3]:
def sync_transform(*images) :
    """Apply one identical random augmentation to several PIL images.

    Intended for an (image, label) pair that must stay pixel-aligned: every
    random decision is drawn once and applied to all inputs.

    Pipeline: pad height to 1024 -> random horizontal flip -> random rotation
    and scale (single affine call) -> pad 64 px on every side -> crop back to
    the padded input size (centered half of the time, random otherwise).

    Args:
        *images: PIL images that share the same size; width must be 512 and
            height at most 1024 (enforced with assertions).

    Returns:
        list of PIL images, each 512 wide and 1024 high.
    """
    w, h = images[0].size # PIL size is (width, height); assuming w=512, h<=1024
    assert w == 512
    if h < 1024 :
        # Pad to a height of 1024. tf.pad takes (left, top, right, bottom),
        # so only the top and bottom entries may be non-zero here.
        diff = 1024 - h
        images = [tf.pad(image, (0, diff // 2, 0, diff - diff // 2)) for image in images]
    w, h = images[0].size
    assert h == 1024

    # random horizontal flip
    if random.random() < 0.5 :
        images = [tf.hflip(image) for image in images]

    # random rotation
    angle = 0
    if random.random() < 0.5 :
        angle = random.randint(-15, 15)

    # random scale
    scale = 1
    if random.random() < 0.5 :
        scale = random.uniform(7/8, 9/8)

    # Rotation and scale are applied together in a single resampling step.
    # NOTE(review): newer torchvision releases name this keyword
    # `interpolation`; confirm against the installed version.
    images = [tf.affine(image, angle=angle, scale=scale, translate=(0, 0), shear=0,
                        resample=PIL.Image.BILINEAR) for image in images]

    # Add a 64 px border so the crop window below can be shifted.
    images = [tf.pad(image, 64) for image in images]

    W, H = images[0].size
    assert H >= h
    assert W >= w

    h_diff = H - h
    w_diff = W - w

    if random.random() < 0.5 :
        # Center crop with 50% chance
        h_start, w_start = h_diff // 2, w_diff // 2
    else :
        h_start, w_start = random.randint(0, h_diff), random.randint(0, w_diff)

    # tf.crop takes (top, left, height, width).
    images = [tf.crop(image, h_start, w_start, h, w) for image in images]

    return images
    

In [4]:
def integer_to_channels(target) :
    """Turn a single-channel class-index image into a 3-channel mask image.

    The input is converted with ``tf.to_tensor``, multiplied by 255 and cast
    to int to obtain class indices. Channel 0/1/2 of the result is 255 where
    the index equals 1/2/3 respectively and 0 elsewhere; any other index
    (including 0) is 0 in all channels.

    Args:
        target: image whose tensor form has exactly one channel.

    Returns:
        PIL image built from an (h, w, 3) uint8 array.
    """
    class_tensor = (tf.to_tensor(target) * 255).int()
    n_channels, height, width = class_tensor.shape
    assert n_channels == 1
    class_map = class_tensor.squeeze(0).numpy() # (h,w)

    masks = np.zeros((height, width, 3), dtype=np.uint8)
    for channel_idx, class_id in enumerate((1, 2, 3)) :
        masks[:, :, channel_idx] = (class_map == class_id).astype(np.uint8) * 255

    return Image.fromarray(masks)

In [5]:
class oct_dataset(object) :
    """Map-style dataset of images paired with same-named label images.

    Every file name found in ``data_path`` is paired with the file of the
    same name in ``label_path``. Items are ``(image, label, image_name)``:

    * ``image``: tensor of the median-filtered image, normalized with
      ``NORM_MEAN`` / ``NORM_STD``.
    * ``label``: tensor of the 3-channel mask from ``integer_to_channels``,
      multiplied by 255 after ``ToTensor``.
    * ``image_name``: base name of the image file.

    Args:
        data_path: directory containing the input images.
        label_path: directory containing the label images.
        sync_transform: optional callable ``(image, label) -> (image, label)``
            applied to both PIL images together.
        transform: optional callable applied to the image only; it runs
            before the PIL median filter, so it must return a PIL image.
    """

    # Dataset statistics used for input normalization.
    NORM_MEAN = 0.1410
    NORM_STD = 0.0941

    def __init__(self, data_path='./oct_data/images', label_path='./oct_data/labels',
                 sync_transform = None, transform=None) :
        self.transform = transform
        self.sync_transform = sync_transform
        self.totensor = torchvision.transforms.ToTensor()
        self.data = self._paired_paths(data_path, label_path)

    @staticmethod
    def _paired_paths(data_path, label_path) :
        """Return (image_path, label_path) tuples for every entry of data_path, sorted by name."""
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping stable across runs and machines.
        return [(os.path.join(data_path, filename), os.path.join(label_path, filename))
                for filename in sorted(os.listdir(data_path))]

    def __len__(self) :
        return len(self.data)

    def __getitem__(self, index) :
        image_path, label_path = self.data[index]
        image_name = os.path.basename(image_path)

        image = Image.open(image_path)
        label = Image.open(label_path)
        label = integer_to_channels(label)

        if self.sync_transform is not None :
            image, label = self.sync_transform(image, label)
        if self.transform is not None :
            image = self.transform(image)

        # Median filtering is applied to the input image only, never the mask.
        image = image.filter(ImageFilter.MedianFilter(size = 5))

        image = self.totensor(image)
        label = self.totensor(label)

        image = tf.normalize(image, self.NORM_MEAN, self.NORM_STD, inplace=True)

        # Scale the mask tensor back up by 255 after ToTensor.
        label = label*255

        return image, label, image_name

In [6]:
# Build the augmented dataset and dump one sample to disk for a visual check.
dataset = oct_dataset(sync_transform=sync_transform)
image, target, name = dataset[1000]
print(name)
for sample_tensor, out_file in ((image, 'test_i.jpg'), (target, 'test_t.jpg')) :
    save_image(sample_tensor, out_file)

FileNotFoundError: [WinError 3] 지정된 경로를 찾을 수 없습니다: './oct_data/images'

In [8]:
#std, mean = get_mean_std(glob.glob('./oct_data/images/*.png'))
#print(std, mean) # 0.0941, 0.1410 for 2000 samples 

In [None]:
# Load a single label image and convert it to a tensor for ad-hoc inspection.
img = Image.open('./oct_data/labels/190.png')
tensor = tf.to_tensor(img)

In [9]:
# from : https://github.com/milesial/Pytorch-UNet

class DoubleConv(nn.Module):
    """Two consecutive stages of 3x3 convolution -> batch norm -> ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the second stage.
        mid_channels: channels between the two stages; falls back to
            ``out_channels`` when omitted (or otherwise falsy).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        stages = []
        for stage_in, stage_out in ((in_channels, hidden), (hidden, out_channels)):
            stages += [
                # padding=1 with a 3x3 kernel keeps the spatial size unchanged
                nn.Conv2d(stage_in, stage_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(stage_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.double_conv(x)
        return out


class Down(nn.Module):
    """Encoder step: 2x2 max-pooling followed by a DoubleConv block."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        pooled_features = self.maxpool_conv(x)
        return pooled_features


class Up(nn.Module):
    """Decoder step: upsample, align with the skip tensor, concatenate, DoubleConv.

    Args:
        in_channels: channels of the concatenated (skip + upsampled) tensor.
        out_channels: channels produced by the DoubleConv block.
        bilinear: when true, upsample by bilinear interpolation and let the
            DoubleConv reduce channels through ``in_channels // 2``;
            otherwise upsample with a transposed convolution that halves
            the channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # interpolation has no weights, so the conv block reduces the channels
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        upsampled = self.up(x1)

        # Axes 2 and 3 are treated as height and width.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        pad_left, pad_top = gap_w // 2, gap_h // 2

        # F.pad order for the last two axes: (left, right, top, bottom)
        upsampled = F.pad(upsampled, [pad_left, gap_w - pad_left,
                                      pad_top, gap_h - pad_top])

        # skip features first, then the upsampled ones, along the channel axis
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected
    
    
class UNet(nn.Module):
    """U-Net with four down-sampling and four up-sampling stages.

    Args:
        n_channels: channels of the input image.
        n_classes: channels of the returned logits.
        bilinear: forwarded to every ``Up`` block; when true the deepest
            encoder stage and the first three decoder stages use half the
            channel count.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        factor = 1
        if bilinear:
            factor = 2

        # encoder
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # decoder
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)

        # head
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder: keep every intermediate output for the skip connections.
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        decoded = self.down4(skip4)

        # Decoder: consume the skips from deepest to shallowest.
        decoder_steps = (
            (self.up1, skip4),
            (self.up2, skip3),
            (self.up3, skip2),
            (self.up4, skip1),
        )
        for up_block, skip in decoder_steps:
            decoded = up_block(decoded, skip)

        return self.outc(decoded)

In [11]:
# Single-channel input, three output channels (one per mask channel).
model = UNet(1, 3)
criterion = torch.nn.BCELoss()

# Shape smoke test on one sample. No gradients are needed here, so skip
# building the autograd graph to save memory.
with torch.no_grad() :
    output = model(dataset[0][0].unsqueeze(0))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)
print(output.shape)
del output

torch.Size([1, 3, 1024, 512])


In [13]:
# Run a single batch through the model and print its loss.
for i, (image, target, name) in enumerate(dataloader) :
    output = model(image)
    # The dataset multiplies the mask tensor by 255 after ToTensor, while
    # BCELoss requires targets in [0, 1]; rescale before computing the loss.
    # torch.sigmoid replaces the deprecated F.sigmoid.
    loss = criterion(torch.sigmoid(output), target / 255.0)
    print(loss)
    break

tensor(0.6245, grad_fn=<BinaryCrossEntropyBackward0>)




In [None]:
# NOTE(review): `output_` is not assigned in any visible cell; this only runs
# if it was left in the kernel by a cell that is no longer shown -- confirm.
output_.save('test.jpg')

In [None]:
# Scratch check: compare random floats with their rounded integer values.
x = 10 * torch.randn(10)
rounded = torch.round(x).int()
print(x, rounded)