In [1]:
import random
import os
from time import time

import click
import numpy as np

import torch
from torch.utils.tensorboard import SummaryWriter

from neutorch.model.IsoRSUNet import UNetModel
from neutorch.model.io import save_chkpt, log_tensor
from neutorch.model.loss import BinomialCrossEntropyWithLogits
from neutorch.dataset.affinity import Dataset

%load_ext autoreload
%autoreload 2

If you use TorchIO for your research, please cite the following paper:
Pérez-García et al., TorchIO: a Python library for efficient loading,
preprocessing, augmentation and patch-based sampling of medical images
in deep learning. Credits instructions: https://torchio.readthedocs.io/#credits



In [9]:
path = '../../data'
patch_size = (6,64,64)
length = 200000
dataset = Dataset(path, patch_size=patch_size, length=length)

In [9]:
for i in range(3):
    x, y = dataset[i]
    print(x.shape,y.shape)

print(len(dataset))

(1, 6, 64, 64) (13, 6, 64, 64)
(1, 6, 64, 64) (13, 6, 64, 64)
(1, 6, 64, 64) (13, 6, 64, 64)
200000


In [10]:
import torch 

def merge(x,H,W):

        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        # reduce
        x = x[:,:,:C*2]

        return x


def expand(x, H, W):
        B, L, C = x.shape
        # expansion(x)  # B L 2C
        x = torch.cat([x,x],-1)

        x = x.view(B, H, W, C*2)  # B H W 2C

        # padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        c = 2*C//4
        # get mutliple pixel repesentation from channels
        c0 = x[:, :, :, :c]  # B H W 2C/4
        c1 = x[:, :, :, c:2*c]  # B H W 2C/4
        c2 = x[:, :, :, 2*c:3*c]  # B H W 2C/4
        c3 = x[:, :, :, 3*c:]  # B H W 2C/4

        # # insert side by side into new array
        new_x = torch.zeros((B, H*2, W*2, c))  # B 2H 2W C/2

        new_x[:, 0::2, 0::2, :] = c0
        new_x[:, 1::2, 0::2, :] = c1
        new_x[:, 0::2, 1::2, :] = c2
        new_x[:, 1::2, 1::2, :] = c3

        x = x.view(B, -1, C//2)  # B L C/2
        return x

In [11]:
H = 4
W = 4
C = 1
B = 1
input = torch.arange(0,H*W*C*B)
input = torch.reshape(input,(B,H,W,C))
print(input[0,:,:,0], input.shape)
input = input.view(C,-1,B)
merged = merge(input,H,W)

print(merged.shape)
print(merged[0,:,0])
print(merged[0,:,1])

expanded = expand(merged,2,2)
print(expanded.shape)
expanded = expanded.view(C,H,W,B)
print(expanded[0,:,:,0], expanded.shape)


tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]]) torch.Size([1, 4, 4, 1])
torch.Size([1, 4, 2])
tensor([ 0,  2,  8, 10])
tensor([ 4,  6, 12, 14])
torch.Size([1, 8, 1])


RuntimeError: shape '[1, 4, 4, 1]' is invalid for input of size 8

In [3]:

zz = torch.zeros((4,4,2))
zz[:,:,0] = torch.reshape(torch.arange(0,16),(4,4))
zz[:,:,1] = torch.reshape(torch.arange(0,16),(4,4))
print(zz[:,:,1])
print(zz[:,:,0])
zz = torch.moveaxis(zz, -1, 0)
print(zz.shape)
print(zz[0,:,:])
print(zz[1,:,:])

tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])
torch.Size([2, 4, 4])
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])


tensor([[0., 1., 2., 3.],
        [0., 1., 2., 3.]])
tensor([[4., 5., 6., 7.],
        [4., 5., 6., 7.]])
