In [13]:
import torch

In [296]:
chunk_duration = 500
num_channels = 4

In [297]:
import matplotlib.pyplot as plt

In [1]:
import imageio
import numpy as np

In [299]:
mask_path = "../data/temp_annotation_masks_new/videos/01_video_mask.tif"

In [300]:
mask = np.asarray(imageio.volread(mask_path))

In [301]:
mask.shape

(500, 64, 512)

In [302]:
mask_chunk = mask[:chunk_duration]

In [303]:
mask_chunk.shape

(500, 64, 512)

In [304]:
def get_new_voxel_label(voxel_seq):
    # voxel_seq is a vector of 'num_channels' elements
    # {0} -> 0
    # {0, i}, {i} -> i, i = 1,2,3
    # {0, 1, i}, {1, i} -> 1, i = 2,3
    # {0, 2 ,3}, {2, 3} -> 3
    #print(voxel_seq)

    if np.max(voxel_seq == 0):
        return 0
    elif 1 in voxel_seq:
        return 1
    elif 3 in voxel_seq:
        return 3
    else:
        return np.max(voxel_seq)

In [305]:
def shrink_mask(mask, num_channels):
    # input is an annotation mask with the number of channels of the unet
    # output is a shrinked mask where :
    # {0} -> 0
    # {0, i}, {i} -> i, i = 1,2,3
    # {0, 1, i}, {1, i} -> 1, i = 2,3
    # {0, 2 ,3}, {2, 3} -> 2
    # and each voxel in the output corresponds to 'num_channels' voxels in
    # the input

    assert mask.shape[0] % num_channels == 0, \
    "in shrink_mask the duration of the mask is not a multiple of num_channels"

    # get subtensor of duration 'num_channels'
    sub_masks = np.split(mask, mask.shape[0]//num_channels)
    
    #print(sub_masks[0].shape)
    #print(len(sub_masks))

    new_mask = []
    # for each subtensor get a single frame
    for sub_mask in sub_masks:
        new_frame = np.array([[get_new_voxel_label(sub_mask[:,y,x]) for x in range(sub_mask.shape[2])] for y in range(sub_mask.shape[1])])
        #print(new_frame.shape)
        new_mask.append(new_frame)

    new_mask = np.stack(new_mask)
    return new_mask

In [306]:
new_mask = shrink_mask(mask_chunk, num_channels)

In [307]:
new_mask.shape

(125, 64, 512)

In [308]:
np.unique(new_mask)

array([0, 1, 3, 4])

In [310]:
imageio.volwrite("test_new_mask.tif", np.uint8(new_mask))

In [6]:
a = np.array([[[1,2],[1,2],[1,2]],[[1,2],[1,2],[1,2]],[[1,2],[1,2],[1,2]]])

In [9]:
np.pad(a,((0,),(0,),(0,)), mode='constant').shape

(3, 3, 2)

In [8]:
a.shape

(3, 3, 2)