In [1]:
import random
import os
from time import time

import click
import numpy as np

import torch
from torch.utils.tensorboard import SummaryWriter

from neutorch.model.IsoRSUNet import UNetModel
from neutorch.model.io import save_chkpt, log_tensor
from neutorch.model.loss import BinomialCrossEntropyWithLogits
from neutorch.dataset.affinity import Dataset

%load_ext autoreload
%autoreload 2

If you use TorchIO for your research, please cite the following paper:
Pérez-García et al., TorchIO: a Python library for efficient loading,
preprocessing, augmentation and patch-based sampling of medical images
in deep learning. Credits instructions: https://torchio.readthedocs.io/#credits



In [9]:
path = '../../data'
patch_size = (6,64,64)
length = 200000
dataset = Dataset(path, patch_size=patch_size, length=length)

In [9]:
for i in range(3):
    x, y = dataset[i]
    print(x.shape,y.shape)

print(len(dataset))

(1, 6, 64, 64) (13, 6, 64, 64)
(1, 6, 64, 64) (13, 6, 64, 64)
(1, 6, 64, 64) (13, 6, 64, 64)
200000


In [16]:
import torch 
from einops import rearrange

def merge(x):

        B, D, H, W, C = x.shape

        x0 = x[:, :, 0::2, 0::2, :]  # B D H/2 W/2 C
        x1 = x[:, :, 1::2, 0::2, :]  # B D H/2 W/2 C
        x2 = x[:, :, 0::2, 1::2, :]  # B D H/2 W/2 C
        x3 = x[:, :, 1::2, 1::2, :]  # B D H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B D H/2 W/2 4*C


        # reduce
        x = x[:,:,:,:,:C*2]

        return x

      
def expand(x):
        B, D, H, W, C = x.shape

        # expansion(x)  # B L 2C
        x = torch.cat([x,x],-1)


        # c = 2*C//4
        # # get mutliple pixel repesentation from channels
        # c0 = x[:, :, :, :, :c]       # B D H W 2C/4
        # c1 = x[:, :, :, :, c:2*c]    # B D H W 2C/4
        # c2 = x[:, :, :, :, 2*c:3*c]  # B D H W 2C/4
        # c3 = x[:, :, :, :, 3*c:]     # B D H W 2C/4

        # # insert side by side into new array
        # # maybe there is better way to do this that doesnt init new array
        # device = x.get_device()
        # if device < 0:
        #     device = None
        # new_x = torch.zeros(
        #     (B, D, H*2, W*2, c), device=device)  # B 2H 2W C/2

        # new_x[:, :, 0::2, 0::2, :] = c0
        # new_x[:, :, 1::2, 0::2, :] = c1
        # new_x[:, :, 0::2, 1::2, :] = c2
        # new_x[:, :, 1::2, 1::2, :] = c3

        return x

In [12]:
D = 1
H = 4
W = 4
C = 1
B = 1
input = torch.arange(0,D*H*W*C*B)
input = torch.reshape(input,(B,D,H,W,C))
print(input[0,:,:,:,0], input.shape)

merged = merge(input)
print(merged.shape)
expanded = expand(merged)
print(expanded.shape)

print(expanded[0,:,:,:,0], expanded.shape)


tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]]) torch.Size([1, 1, 4, 4, 1])
torch.Size([1, 1, 2, 2, 2])
torch.Size([1, 1, 4, 4, 1])
tensor([[[ 0.,  0.,  2.,  2.],
         [ 4.,  4.,  6.,  6.],
         [ 8.,  8., 10., 10.],
         [12., 12., 14., 14.]]]) torch.Size([1, 1, 4, 4, 1])


In [17]:
D = 1
H = 4
W = 4
C = 1
B = 1
input = torch.arange(0,D*H*W*C*B)
input = torch.reshape(input,(B,D,H,W,C))
print(input[0,:,:,:,0], input.shape)

merged = merge(input)
print(merged.shape)
expanded = expand(merged)
print(expanded.shape)

print(expanded[0,:,:,:,0], expanded.shape)


tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]]) torch.Size([1, 1, 4, 4, 1])
torch.Size([1, 1, 2, 2, 2])
torch.Size([1, 1, 4, 4, 1])
tensor([[[ 0,  4,  2,  6],
         [ 0,  4,  2,  6],
         [ 8, 12, 10, 14],
         [ 8, 12, 10, 14]]]) torch.Size([1, 1, 4, 4, 1])


In [14]:

zz = torch.zeros((4,4,2))
zz[:,:,0] = torch.reshape(torch.arange(0,16),(4,4))
zz[:,:,1] = torch.reshape(torch.arange(0,16),(4,4))
print(zz[:,:,1])
print(zz[:,:,0])
zz = torch.moveaxis(zz, -1, 0)
print(zz.shape)
print(zz[0,:,:])
print(zz[1,:,:])

tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])
torch.Size([2, 4, 4])
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])


tensor([[0., 1., 2., 3.],
        [0., 1., 2., 3.]])
tensor([[4., 5., 6., 7.],
        [4., 5., 6., 7.]])
