In [None]:
from visualization.utils import display_img
from skimage import io, exposure
import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt
import time

## Visualization

In [None]:
stpt_img = io.imread('data/train/STPT/S014_Z00.tif')
# plt.imshow(img, cmap='gray')

imc_img = io.imread('data/train/IMC/SECTION_01/131Xe.tif')

In [None]:
imc_img = torch.from_numpy(imc_img.astype(np.int16))
imc_img = torch.unsqueeze(imc_img, 0)

stpt_img = torch.from_numpy(stpt_img).permute((2,0,1))

In [None]:
stpt_img = torchvision.transforms.Resize(18720)(stpt_img)
combine = torch.cat((stpt_img, imc_img), 0)
combine.shape

In [None]:
x = [torch.split(combine, 4)]
y = map(list, zip(*x))

In [None]:
type(x[0])

## Concatenating IMC images

In [10]:
import os, shutil
import torch
import cv2 as cv
import numpy as np
import time
from torch.multiprocessing import Pool, set_start_method
from skimage import io
from torchvision import transforms

In [2]:
def process_imc_image(file_name, plot=False, bits=8, v=[0, 256]):
    # read image file
    img = cv.imread(file_name, cv.IMREAD_UNCHANGED)

    # normalize image
    norm_img = img.copy()
    cv.normalize(img, norm_img, alpha=0, beta=2**bits - 1, norm_type=cv.NORM_MINMAX)

    # Apply log transformation method
    c = (2**bits - 1) / np.log(1 + np.max(norm_img))
    log_image = c * (np.log(norm_img + 1))
    
    if plot:
        plt.imshow(log_image, cmap='gray', vmin=v[0], vmax=v[1])
    else:
        # Specify the data type so that
        # float value will be converted to int
        return torch.from_numpy(log_image.astype('uint8'))
    
def process_stpt_image(file_name):
    img = io.imread(file_name)
    return torch.from_numpy(img.astype('uint8'))

    
imc_section_01_folder = 'data/train/IMC/SECTION_01/'
imc_img_paths = [os.path.join(imc_section_01_folder, imc_img)
            for imc_img in os.listdir(imc_section_01_folder)
            if imc_img.endswith('.tif')]

stpt_img_paths = [os.path.join('data/train/STPT/',
                               'S{0}_Z{1}.tif'.format(str(1).zfill(3),
                                                  optical_section.zfill(2)))
                  for optical_section in ['0', '1']]

In [3]:
print('LOADING IMAGES')
start = time.time()
with Pool(maxtasksperchild=100) as p:
    imc_imgs = list(p.imap(process_imc_image, imc_img_paths))
    stpt_imgs = list(p.imap(process_stpt_image, stpt_img_paths))
end = time.time()
print('Loading STPT images took', end-start, 'seconds')

LOADING IMAGES
Loading STPT images took 41.762733697891235 seconds


In [4]:
# postprocess loaded images
print('process imc')
imc_imgs = [torch.unsqueeze(img, 0) for img in imc_imgs] # add an extra dimesion for channel
print('concat imc')
imc_imgs_cat = torch.cat(imc_imgs, 0) # (40, 18720, 18720)

print('permute stpt')
stpt_imgs = [img.permute((2,0,1)) for img in stpt_imgs] # (C,H,W) tensor
print('cat stpt')
stpt_imgs_cat = torch.cat(stpt_imgs, 0) # concatenate two stpt images (8, 20800, 20800)

process imc
concat imc
permute stpt
cat stpt


In [8]:
imc_imgs[0].shape[1]

18720

In [None]:
# ====== TRANSFORMS ======
print('PERFORMING TRANSFORMS')

stpt_imgs_cat = transforms.Resize(imc_imgs[0].shape[1])(stpt_imgs_cat)  # make STPT img same size as IMC (..., 18720, 18720)
combine = torch.cat((imc_imgs_cat, stpt_imgs_cat), 0) # combine imc and stpt -> (48, 18720, 18720)

# obtain a batch of random crops
img_set = [transforms.Compose([transforms.RandomCrop(256)])(combine) for i in range(64)]

# separate imc and stpt -> (40, 18720, 18720), (8, 18720, 18720)
imc_imgs = [torch.split(img, 40)[0] for img in img_set]
stpt_imgs = [torch.split(img, 40)[1] for img in img_set]

In [19]:
print('imc',imc_imgs[0].shape)
print('stpt', stpt_imgs[0].shape)

imc torch.Size([40, 256, 256])
stpt torch.Size([8, 256, 256])


## Batches

In [7]:
import torch
imgs = [[torch.ones((8, 256, 256)) for i in range(64)]]
targets = [[torch.ones((40, 256, 256)) for i in range(64)]]

In [11]:
x = [torch.ones((8, 256,256)) for i in range(64)]
y = [torch.ones((40, 256,256)) for i in range(64)]
batch = (x, y)

In [17]:
print(len(batch))
print(len(batch[0]))
print(len(batch[1]))
print(batch[0][0].shape)
print(batch[1][0].shape)

2
64
64
torch.Size([8, 256, 256])
torch.Size([40, 256, 256])


In [23]:
imc = torch.stack(batch[0]).cuda().double()
imc.type()

'torch.cuda.DoubleTensor'