In [1]:
import skimage

In [2]:
# Later cells call `rescale(..., multichannel=True)`; keep the installed
# scikit-image version on record next to those results.
SKIMAGE_VERSION = skimage.__version__
SKIMAGE_VERSION

'0.15.0'

In [3]:
import matplotlib.pyplot as plt
import numpy as np
import os
from skimage.transform import rescale, resize, downscale_local_mean
from skimage import io, transform

In [4]:
import torch
import torch.nn as nn
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
import torch.utils.data as utils

In [6]:
# Base data directory (absolute path; change it when running on another machine).
DATA_DIR_DEEPTHOUGHT = "/storage/yw18581/data"

# Sub-directory holding the train / validation / test archives.
train_test = os.path.join(DATA_DIR_DEEPTHOUGHT, "train_validation_test")
data_dir = DATA_DIR_DEEPTHOUGHT

In [7]:
# Load the cleaned train+validation arrays.
# - The path is built from `train_test` (configuration cell) rather than
#   repeating the absolute directory.
# - Indexing the archive reads each array fully into memory, so the .npz file
#   can be closed as soon as both have been read.
with np.load(os.path.join(train_test, "Xy_train+val_clean_300_24_10_25.npz")) as data:
    x = data["x"]
    y = data["y"]



In [17]:
# Wrap the raw arrays as tensors (torch.from_numpy shares memory with x / y).
tensor_x = torch.from_numpy(x)
tensor_y = torch.from_numpy(y)

# `utils` is bound twice in the import cell (torchvision.utils, then
# torch.utils.data), so use the fully-qualified name instead of relying on
# whichever binding happened to run last.
xy_dataset = torch.utils.data.TensorDataset(tensor_x, tensor_y)

In [22]:
class UNetDataset(Dataset):
    """Map-style dataset pairing each input image with its mask.

    Parameters
    ----------
    X : indexable
        Input images; ``X[idx]`` is one sample.
    Y : indexable
        Masks aligned with ``X``; ``Y[idx]`` belongs to ``X[idx]``.
    transform : callable, optional
        Applied to the ``{'image': ..., 'masks': ...}`` dict before it is
        returned; skipped when falsy.
    """

    def __init__(self, X, Y, transform=None):
        self._X, self._Y = X, Y
        self.transform = transform

    def __len__(self):
        # Length is taken from the inputs; X and Y are expected to be aligned.
        return len(self._X)

    def __getitem__(self, idx):
        item = dict(image=self._X[idx], masks=self._Y[idx])
        if not self.transform:
            return item
        return self.transform(item)


class ChannelsFirst:
    """Move the channel axis of a sample's image and mask from last to first.

    NumPy images are laid out ``H x W x C`` while torch expects ``C x H x W``.
    ``np.moveaxis(arr, -1, 0)`` moves only the channel axis, so height and
    width keep their order. By contrast, ``swapaxes(2, 0)`` would yield
    ``C x W x H``, silently transposing every image and mask.

    Returns a new dict with keys ``'image'`` and ``'masks'``.
    """

    def __call__(self, sample):
        image, mask = sample['image'], sample['masks']
        # (..., C) -> (C, ...): the spatial axes stay in their original order.
        image = np.moveaxis(image, -1, 0)
        mask = np.moveaxis(mask, -1, 0)
        return {'image': image,
                'masks': mask}



```python

resizer = transforms.Resize(350)
out_image = resizer(image)

```

```python 

from skimage.transform import rescale
from functools import partial

resizer = partial(rescale, scale=0.25, anti_aliasing=True, multichannel=True)
out_image = resizer(image)

```

In [9]:
from functools import partial

In [10]:

class Rescale:
    """Rescale both the image and the mask of a sample by a constant factor.

    Parameters
    ----------
    scale : int or float
        Strictly positive factor applied by ``skimage.transform.rescale``
        (e.g. 0.25 shrinks each spatial side to a quarter). Integer factors
        are accepted as well as floats.

    Raises
    ------
    TypeError
        If ``scale`` is not a real number.
    ValueError
        If ``scale`` is not strictly positive.
    """

    def __init__(self, scale):
        # Explicit checks rather than `assert`:
        # - `assert` statements are stripped under `python -O`.
        # - The old isinstance(float) check also rejected valid ints like 1 or 2.
        if isinstance(scale, bool) or not isinstance(
                scale, (int, float, np.integer, np.floating)):
            raise TypeError(f"scale must be a real number, got {type(scale).__name__}")
        if scale <= 0:
            raise ValueError(f"scale must be positive, got {scale}")
        self.output_scale = float(scale)

    def __call__(self, sample):
        image, mask = sample['image'], sample['masks']

        # multichannel=True: the last axis is treated as channels and is not rescaled.
        # NOTE(review): the mask goes through the same anti-aliased interpolation
        # as the image; if masks are binary this yields fractional values —
        # confirm that is intended (order=0, anti_aliasing=False keeps labels crisp).
        out_image = rescale(image, scale=self.output_scale,
                            anti_aliasing=True, multichannel=True)
        out_mask = rescale(mask, scale=self.output_scale,
                           anti_aliasing=True, multichannel=True)

        return {'image': out_image,
                'masks': out_mask}

In [15]:
class ToTensor:
    """Turn the NumPy ``'image'`` and ``'masks'`` entries of a sample into torch tensors.

    ``torch.from_numpy`` shares memory with the source arrays and keeps their
    dtype (float64 stays float64). Returns a new dict with the same two keys.
    """

    def __call__(self, sample):
        # Fetch both entries first, then convert each one.
        arrays = {key: sample[key] for key in ('image', 'masks')}
        return {key: torch.from_numpy(arr) for key, arr in arrays.items()}

In [16]:
# Per-sample pipeline: downscale by 1/4, move channels first, convert to tensors.
pipeline_steps = [Rescale(0.25), ChannelsFirst(), ToTensor()]
composed = transforms.Compose(pipeline_steps)

In [24]:
# Training dataset: raw arrays plus the per-sample transform pipeline.
train_dataset = UNetDataset(X=x, Y=y, transform=composed)

In [25]:
# Shuffled batches of 4 samples, loaded by 4 worker processes.
dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4,
)

In [26]:
# Peek at the first five batches to check image tensor shape and dtype.
for _, batch in zip(range(5), dataloader):
    print(batch['image'].size(), batch['image'].dtype)

torch.Size([4, 1, 350, 350]) torch.float64
torch.Size([4, 1, 350, 350]) torch.float64
torch.Size([4, 1, 350, 350]) torch.float64
torch.Size([4, 1, 350, 350]) torch.float64
torch.Size([4, 1, 350, 350]) torch.float64


In [27]:
# Same rescale settings as the Rescale(0.25) transform, applied to whole arrays
# outside the Dataset pipeline. multichannel=True keeps the last axis as channels.
resizer = partial(rescale, anti_aliasing=True, multichannel=True, scale=0.25)

In [28]:
# Downscale every input sample and stack the results into one array.
x1 = np.asarray([resizer(frame) for frame in x])

In [29]:
# Sanity check: one resized frame per input, channel axis still last.
np.shape(x1)

(960, 350, 350, 1)

In [30]:
# Move the channel axis to position 1 for PyTorch.
# NOTE(review): swapping axes 1 and 3 turns (N, H, W, C) into (N, C, W, H), so
# height and width end up transposed. x1.transpose(0, 3, 1, 2) would give
# (N, C, H, W). The shape only looks right because the frames are square.
# y_resize is built the same way, so x and y stay aligned with each other;
# change both together if fixing this.
x_resize = np.swapaxes(x1, 1, 3)
x_resize.shape

(960, 1, 350, 350)

In [42]:
# Scratch check: np.fromiter materialises an iterator into a 1-D integer array.
m = map(lambda value: value, range(10))
np.fromiter(m, dtype=int)

array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [43]:
# Resize the masks with the same settings as the images, then move channels to axis 1.
# NOTE(review): as for x_resize, swapping axes 1 and 3 yields (N, C, W, H),
# i.e. H and W transposed; fix together with x_resize.
# NOTE(review): `resizer` uses anti-aliased interpolation. If y holds binary
# masks, the result contains fractional values — confirm that is intended.
resized_masks = np.asarray([resizer(mask) for mask in y])
y_resize = np.swapaxes(resized_masks, 1, 3)

y_resize.shape

(960, 1, 350, 350)

In [44]:
# Save the resized, channels-first arrays in the same directory as the source
# archive. The path is built from `train_test` instead of a second hard-coded
# absolute path.
np.savez_compressed(
    os.path.join(train_test, "Xy_train+val_clean_300_24_10_25_resized_ch_first.npz"),
    x=x_resize,
    y=y_resize,
)

In [None]:
# Scratch check: shape of the first sample's channel-0 frame rescaled by 1/4 as a 2-D image.
rescale(x[0, ..., 0], 0.25, anti_aliasing=True).shape

In [None]:
# Rescale channel 0 of every sample by 1/4, treating each as a plain 2-D frame.
# Iterating over x[..., 0] walks the sample axis, one (H, W) frame at a time.
# No per-sample print: it would emit one line per sample (960 here) and bury
# the notebook output.
reshaped = np.asarray([rescale(frame, 1. / 4., anti_aliasing=True) for frame in x[..., 0]])