In [14]:
import os
# Pin this kernel to physical GPU 2; must run before torch initialises CUDA.
os.environ.update(CUDA_VISIBLE_DEVICES='2')

In [15]:
import numpy as np
import matplotlib.pyplot as plt

In [23]:
import torch
from torch.utils.data import Dataset

In [24]:
class BatchesDataset(Dataset):
    """Map-style dataset serving fixed-size row batches from ``.npy`` files.

    Item ``idx`` is a dict mapping every key of ``batches_file_paths`` to rows
    ``[idx * batch_size, (idx + 1) * batch_size)`` of that key's array. Files are
    opened memory-mapped on every access, so only the requested rows are read.

    Args:
        batches_file_paths: dict of ``name -> path to a .npy file``. The length is
            taken from the first entry only; all arrays are assumed to share the
            same number of rows (first axis) — not verified here.
        batch_size: rows per item. The last item may be shorter.

    Raises:
        IndexError: from ``__getitem__`` for an out-of-range batch index.
        ValueError: if ``batches_file_paths`` is empty.
    """

    def __init__(self, batches_file_paths, batch_size):
        self.batches_file_paths = batches_file_paths
        self.batch_size = batch_size

    def _num_rows(self):
        """Return the first-axis length of the first array in ``batches_file_paths``."""
        paths = list(self.batches_file_paths.values())
        if not paths:
            raise ValueError('batches_file_paths is empty')
        return np.load(paths[0], mmap_mode='r').shape[0]

    def __len__(self):
        # Integer ceil division: a trailing partial batch counts as one item.
        # Returns a plain Python int (no float round-trip, no np.int32).
        return -(-self._num_rows() // self.batch_size)

    def __getitem__(self, idx):
        n_batches = len(self)
        if idx < 0:
            idx += n_batches
        # Raising IndexError (instead of returning empty slices) gives normal
        # sequence semantics and lets `for batch in dataset` terminate.
        if not 0 <= idx < n_batches:
            raise IndexError(f'batch index out of range for {n_batches} batches')
        start = idx * self.batch_size
        stop = start + self.batch_size
        # np.copy materialises the rows as a regular in-memory array,
        # independent of the memory map.
        return {k: np.copy(np.load(bf, mmap_mode='r')[start:stop])
                for k, bf in self.batches_file_paths.items()}

In [25]:
class RandomCoordinateDataset(Dataset):
    """Dataset yielding one batch of uniformly random 3-D coordinates per item.

    Points are drawn uniformly in the box whose per-axis range is
    ``[0, cube_shape[i] - 1]``. If ``buffer`` is given, the x and y ranges are
    widened on both sides. Coordinates are then divided by ``spatial_norm``.
    The item index is ignored: every access draws a fresh sample.

    Args:
        cube_shape: three grid sizes, indexed as axes 0, 1, 2 (x, y, z).
        spatial_norm: scalar divisor applied to all coordinates.
        batch_size: number of points per item.
        buffer: optional fraction of the x/y extent added on each side of the
            x and y ranges (z is not padded). ``None``/0 disables padding.
    """

    def __init__(self, cube_shape, spatial_norm, batch_size, buffer=None):
        super().__init__()
        # Float bounds: with an integer array the fractional padding below would be
        # silently truncated toward zero (e.g. a 0.4-cell buffer would become 0).
        cube_shape = np.array([[0, cube_shape[0] - 1],
                               [0, cube_shape[1] - 1],
                               [0, cube_shape[2] - 1]], dtype=np.float64)
        if buffer:
            # Pad only the first two axes (x, y), proportionally to their extent.
            pad = (cube_shape[:2, 1] - cube_shape[:2, 0]) * buffer
            cube_shape[:2, 0] -= pad
            cube_shape[:2, 1] += pad
        self.cube_shape = cube_shape
        self.spatial_norm = spatial_norm
        self.batch_size = batch_size
        self.float_tensor = torch.FloatTensor

    def __len__(self):
        # A single logical item; the sampler decides how often it is drawn.
        return 1

    def __getitem__(self, item):
        random_coords = self.float_tensor(self.batch_size, 3).uniform_()
        low = torch.as_tensor(self.cube_shape[:, 0], dtype=random_coords.dtype,
                              device=random_coords.device)
        high = torch.as_tensor(self.cube_shape[:, 1], dtype=random_coords.dtype,
                               device=random_coords.device)
        # Map U[0, 1) onto [low, high) per axis (broadcast over the batch), then normalise.
        random_coords = random_coords * (high - low) + low
        random_coords /= self.spatial_norm
        return random_coords

In [26]:
# Boundary data on disk: one .npy file per field, served in fixed-size row batches.
batches_path = dict(coords='coords.npy', values='values.npy')
boundary_batch_size = 10000

# Random-coordinate sampling: grid size per axis, coordinate divisor,
# points per batch, and optional x/y padding fraction (None = no padding).
cube_shape = [513, 257, 257]
spatial_norm = 320
random_batch_size = 20000
buffer = None

In [27]:
# Boundary batches streamed from disk, plus an on-the-fly random-coordinate sampler.
dataset = BatchesDataset(batches_file_paths=batches_path, batch_size=boundary_batch_size)
random_dataset = RandomCoordinateDataset(
    cube_shape, spatial_norm, random_batch_size, buffer=buffer,
)

In [29]:
# First boundary batch straight from the dataset (numpy arrays, not tensors).
dataset[0]

{'coords': array([[  0.,   0.,   0.],
        [  0.,   1.,   0.],
        [  0.,   2.,   0.],
        ...,
        [ 38., 231.,   0.],
        [ 38., 232.,   0.],
        [ 38., 233.,   0.]], dtype=float32),
 'values': array([[ 27.81,  48.85, -13.17],
        [ -1.57, -78.04,  33.92],
        [-25.44,  96.69, -57.32],
        ...,
        [ 92.06,  20.57,  46.14],
        [  0.78, -76.94,  19.8 ],
        [-40.19, -61.35,  -7.04]], dtype=float32)}

In [30]:
# One batch of normalised random coordinates; the index is ignored, so every call differs.
random_dataset[0]

tensor([[0.6744, 0.3834, 0.5365],
        [0.0892, 0.4419, 0.7103],
        [1.4901, 0.5261, 0.7151],
        ...,
        [0.7834, 0.3754, 0.5140],
        [0.3527, 0.2094, 0.6267],
        [0.7048, 0.7431, 0.1538]])

In [31]:
from torch.utils.data import DataLoader, RandomSampler

In [32]:
# Number of items each DataLoader draws (sampled with replacement).
iterations = 10_000

In [33]:
# Both loaders draw `iterations` items with replacement. batch_size=None disables
# automatic batching because every dataset item is already a full batch.
data_loader = DataLoader(dataset, batch_size=None, num_workers=8, pin_memory=True,
                         sampler=RandomSampler(dataset, replacement=True, num_samples=iterations))
# The sampler must be built over random_dataset itself, not `dataset`: sampling
# indices from the boundary batches would tie this loader to the .npy files.
random_loader = DataLoader(random_dataset, batch_size=None, num_workers=8, pin_memory=True,
                           sampler=RandomSampler(random_dataset, replacement=True, num_samples=iterations))

In [40]:
# Peek at one boundary batch drawn through the DataLoader (arrives as torch tensors).
# NOTE(review): each iter() call presumably starts fresh worker processes (num_workers=8).
next(iter(data_loader))

{'coords': tensor([[311.,  73.,   0.],
         [311.,  74.,   0.],
         [311.,  75.,   0.],
         ...,
         [350.,  47.,   0.],
         [350.,  48.,   0.],
         [350.,  49.,   0.]]),
 'values': tensor([[ 393.0100, -579.0800,  402.1200],
         [ 468.9400, -517.9100,  375.7600],
         [ 528.1800, -517.6700,  252.9300],
         ...,
         [ -58.9500, -115.0400,   -7.7000],
         [  38.3600,   94.1800,   -8.7200],
         [  43.9000,  115.1900,  -17.6400]])}

In [41]:
# Peek at one batch of random coordinates drawn through the DataLoader.
# NOTE(review): each iter() call presumably starts fresh worker processes (num_workers=8).
next(iter(random_loader))

tensor([[0.3767, 0.5381, 0.3489],
        [0.7747, 0.1157, 0.5143],
        [0.2035, 0.3301, 0.2012],
        ...,
        [0.8248, 0.5338, 0.4677],
        [0.7827, 0.4453, 0.7107],
        [0.0829, 0.1674, 0.7622]])

In [42]:
from torch import nn

In [43]:
class Sine(nn.Module):
    """Element-wise activation computing ``sin(w0 * x)``.

    Args:
        w0: frequency factor multiplied into the input before the sine.
    """

    def __init__(self, w0=1.):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        scaled = x * self.w0
        return scaled.sin()

class BModel(nn.Module):
    """Fully connected network mapping per-point input coordinates to output values.

    Layout: input layer -> ``n_layers`` hidden ``Linear(dim, dim)`` layers, with a
    ``Sine`` activation after the input layer and after every hidden layer ->
    linear output layer (no activation). Submodule names (``d_in``,
    ``linear_layers``, ``d_out``, ``activation``) are fixed so checkpoints stay compatible.

    Args:
        in_coords: number of input coordinates per point (last input dimension).
        out_values: number of output values per point.
        dim: hidden width.
        pos_encoding: if True, prepend ``PositionalEncoding(8, 20)`` to the input
            layer. NOTE(review): ``PositionalEncoding`` is not defined in the
            visible notebook cells; it must be defined/imported before using this.
        n_layers: number of hidden layers (default 8, the previous fixed depth).
    """

    def __init__(self, in_coords, out_values, dim, pos_encoding=False, n_layers=8):
        super().__init__()
        if pos_encoding:
            posenc = PositionalEncoding(8, 20)
            # Assumes the encoding expands each coordinate into 40 features — TODO confirm.
            d_in = nn.Linear(in_coords * 40, dim)
            self.d_in = nn.Sequential(posenc, d_in)
        else:
            self.d_in = nn.Linear(in_coords, dim)
        self.linear_layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_layers)])
        self.d_out = nn.Linear(dim, out_values)
        self.activation = Sine()  # alternative tried: torch.tanh

    def forward(self, x):
        x = self.activation(self.d_in(x))
        for layer in self.linear_layers:
            x = self.activation(layer(x))
        return self.d_out(x)

In [44]:
# Hidden width of every layer; the network maps 3 input coordinates to 3 output values.
dim = 256
model = BModel(in_coords=3, out_values=3, dim=dim)