In [1]:
import torch
import numpy as np
from utils import sample_indices

In [2]:
def indices_to_mask(indices, size=128):
    """Build a boolean mask that is True exactly at ``indices``.

    Args:
        indices: Positions to switch on; anything usable as a tensor index,
            e.g. the slice-index lists returned by ``sample_indices``.
            An empty list yields an all-False mask.
        size: Length of the mask (default 128).

    Returns:
        A ``torch.bool`` tensor of length ``size``.
    """
    flags = torch.zeros(size, dtype=torch.bool)
    flags[indices] = True
    return flags

In [4]:
batch_size = 3  # leading (batch) dimension of the random toy `batch` tensor

In [5]:
# Random stand-in batch: `batch_size` tensors of shape (128, 128, 128); the first of
# those three axes is the one selected by the slice masks in the cells below.
batch = torch.rand((batch_size, 128, 128, 128))

In [6]:
# Sanity check: iterating over `batch` yields the same per-item tensors as
# indexing `batch[i]` directly.
for i, item in zip(range(len(batch)), batch):
    print(item.shape)
    print(torch.all(item == batch[i]))

torch.Size([128, 128, 128])
tensor(True)
torch.Size([128, 128, 128])
tensor(True)
torch.Size([128, 128, 128])
tensor(True)


In [57]:
# NOTE(review): sample_indices() is unpacked into three values here, but the
# "NEW DESIGN" cell further down unpacks only two (condition, target) from the
# same function — both cannot match the current utils.sample_indices; confirm.
condition, target, encoding = sample_indices()

In [58]:
condition  # display the sampled condition slice indices

[80, 71, 61, 17, 74, 78, 83, 103, 52]

In [59]:
target  # display the sampled target slice indices

[49]

In [60]:
# Boolean masks (default length 128) over slice positions for the sampled indices.
condition_mask, target_mask = (indices_to_mask(idx) for idx in (condition, target))

In [61]:
# Slices that are either condition or target slices.
combined_mask = torch.logical_or(condition_mask, target_mask)

In [66]:
# Shape of the first item restricted to slices that are neither condition nor target.
batch[0][combined_mask.logical_not()].shape

torch.Size([118, 128, 128])

OLD DESIGN - TOO MANY INPUT/OUTPUT CHANNELS

In [None]:
# NOTE(review): this cell was never executed (In [None]). `noise_scheduler`, `noise`
# and `timesteps` are not defined in any cell visible here, and sample_indices() is
# unpacked into three values while the "NEW DESIGN" cell unpacks two, so the cell
# will not run as-is.
input_tensor = torch.zeros_like(batch)
timesteps_tensor = torch.zeros(batch_size)  # TODO(review): allocated but never filled below
encoding_tensor = torch.zeros(batch_size, 128)

for i, image in enumerate(batch):
    # Iterating a tensor yields views into `batch`; work on a copy so the in-place
    # noising/zeroing below does not silently overwrite `batch` for later cells.
    image = image.clone()
    condition, target, encoding = sample_indices()

    condition_mask = indices_to_mask(condition)
    target_mask = indices_to_mask(target)
    combined_mask = condition_mask | target_mask

    # Noise the target slices, and zero every slice that is neither condition nor target.
    image[target_mask] = noise_scheduler.add_noise(image[target_mask], noise, timesteps)
    image[~combined_mask] = 0

    input_tensor[i] = image
    encoding_tensor[i] = encoding


In [75]:
# One random slice index in [0, 128), returned as a 0-d tensor.
torch.randint(low=0, high=128, size=())

tensor(110)

NEW DESIGN

We may need to change the encoding to a length-20 vector of slice indices.

Or the encoding should also record whether each slice is a condition or a target slice — e.g. pad both index lists to 20 and concatenate them into a length-40 vector of indices? This seems wrong, but we should just try it.

In [7]:
# For each item, gather its condition and target slices and check that stacking
# them (targets first) yields len(condition) + len(target) slices.
for i, image in enumerate(batch):
    condition, target = sample_indices()
    condition_mask, target_mask = indices_to_mask(condition), indices_to_mask(target)

    condition_slices = image[condition_mask]
    target_slices = image[target_mask]
    num_slices = len(target) + len(condition)

    print(condition_slices.shape, target_slices.shape, num_slices)
    print(torch.cat([target_slices, condition_slices]).shape)

torch.Size([0, 128, 128]) torch.Size([1, 128, 128]) 1
torch.Size([1, 128, 128])
torch.Size([8, 128, 128]) torch.Size([3, 128, 128]) 11
torch.Size([11, 128, 128])
torch.Size([4, 128, 128]) torch.Size([2, 128, 128]) 6
torch.Size([6, 128, 128])


In [11]:
len(condition + [0] * (20 - len(condition)) + target + [0] * (20 - len(target)))

40

NEWER IDEA

The input is 16 slices.

There are always 8 target slices.
Sometimes there are no condition slices; otherwise there are also 8.

For the encoding, you have a two-hot encoding

Keep the target and condition slices ordered... or it may be more robust not to.

In [19]:
import random

In [25]:
random.random()  # float in [0, 1); presumably prototyping the "sometimes no condition slices" coin flip

0.2184232158997489

In [1]:
from utils import sample_16_indices

In [9]:
# Inspect one sample. In the output below the result is a 3-tuple: a list of 8 indices,
# an empty list, and a 128-long 0/1 array that is 1 exactly at the first list's indices.
sample_16_indices()

([54, 78, 29, 80, 30, 44, 90, 31],
 [],
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0.]))

In [15]:
# Three random slice indices in [0, 128).
torch.randint(low=0, high=128, size=(3,))

tensor([ 39,  66, 125])