In [1]:
from torch.utils.data import Dataset, DataLoader

class Colored3DMNIST(Dataset):
    """Map-style dataset over an in-memory array of colored 3-D MNIST samples.

    Item ``index`` is ``data[index]``, optionally passed through ``transform``.

    Args:
        data: Indexable collection of samples supporting ``len()``. In this
            notebook it is a NumPy array loaded from ``.npy`` (presumably
            shaped (N, 3, 16, 16, 16) -- TODO confirm).
        transform: Optional callable applied to each raw sample before it is
            returned. It receives the sample exactly as stored in ``data``.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __getitem__(self, index):
        sample = self.data[index]
        # Transform the raw array directly. A PIL round-trip is not an option
        # here: Image.fromarray cannot represent a float64 multi-channel
        # volume.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def __len__(self):
        return len(self.data)

In [2]:
import numpy as np

# Load the two training splits and the test split from .npy files in the
# current working directory, in that order.
train_a_np, train_b_np, test_np = (
    np.load(npy_path)
    for npy_path in ("x_train_a.npy", "x_train_b.npy", "x_test.npy")
)

In [3]:
# Batch size shared by every loader. Shuffling is left at DataLoader's
# default (off), so iterating a loader always yields samples in array order.
BATCH_SIZE = 32

train_a = Colored3DMNIST(train_a_np)
train_a_loader = DataLoader(train_a, batch_size=BATCH_SIZE)

train_b = Colored3DMNIST(train_b_np)
train_b_loader = DataLoader(train_b, batch_size=BATCH_SIZE)

test = Colored3DMNIST(test_np)
test_loader = DataLoader(test, batch_size=BATCH_SIZE)

In [7]:
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objs as go
import pandas as pd

%matplotlib inline

class data_config:
    """Shared numeric settings for the voxel data, read as class attributes.

    ``x_shape``/``y_shape``/``z_shape`` are the grid extent iterated over when
    plotting. ``lower``, ``upper`` and ``threshold`` have no consumer in this
    cell; see the hedged notes below.
    """

    # Voxel-grid extent along each spatial axis.
    x_shape: int = 16
    y_shape: int = 16
    z_shape: int = 16

    # Presumably the value range of the data and an occupancy cutoff --
    # TODO confirm against their users.
    threshold: float = 0.2
    upper: int = 1
    lower: int = 0
    

def _voxel_frame(image, plot_only_figure=True, color_ind=0):
    """Flatten a channel-first RGB voxel volume into one row per grid cell.

    Args:
        image: Array-like indexed as ``image[channel, x, y, z]`` with at least
            3 channels (NumPy array or CPU torch tensor).
        plot_only_figure: If True, keep only occupied cells.
        color_ind: Channel whose positive values mark a cell as occupied.

    Returns:
        DataFrame with columns ``x``, ``y``, ``z``, ``isVoxel`` and ``color``
        (CSS ``rgb(r, g, b)`` strings; unoccupied cells are white). Rows are
        ordered with x outermost and z innermost.
    """
    # np.asarray also accepts CPU torch tensors, e.g. a DataLoader batch item.
    volume = np.asarray(image, dtype=np.float64)
    nx, ny, nz = data_config.x_shape, data_config.y_shape, data_config.z_shape

    occupied = np.squeeze(volume[color_ind])[:nx, :ny, :nz] > 0

    # Rescale channel values from the configured [lower, upper] range onto
    # 0..255 and clamp. An unscaled ``(v + 1) * 255`` pushes every positive
    # value past the valid RGB range.
    span = data_config.upper - data_config.lower
    rgb = (volume[:3, :nx, :ny, :nz] - data_config.lower) / span * 255
    rgb = np.clip(rgb, 0, 255).astype(int)

    # np.indices + ravel reproduces nested x/y/z loop order (C order).
    xs, ys, zs = (axis.ravel() for axis in np.indices((nx, ny, nz)))
    is_voxel = occupied.ravel()
    red, green, blue = np.where(is_voxel, rgb.reshape(3, -1), 255)

    frame = pd.DataFrame({
        "x": xs,
        "y": ys,
        "z": zs,
        "isVoxel": is_voxel,
        "color": [f"rgb({r}, {g}, {b})" for r, g, b in zip(red, green, blue)],
    })
    if plot_only_figure:
        frame = frame[frame["isVoxel"]]
    return frame


def plot_img(image, plot_only_figure=True, color_ind=0, label="3"):
    """Render a colored 3-D voxel sample as an interactive Plotly scatter.

    Args:
        image: Channel-first volume ``image[channel, x, y, z]`` with 3 color
            channels (NumPy array or CPU torch tensor, e.g. one item of a
            DataLoader batch).
        plot_only_figure: If True, draw only occupied voxels; otherwise also
            draw the empty grid cells in white.
        color_ind: Channel whose positive values mark a voxel as occupied
            (0 - red, 1 - yellow and green, 2 - blue).
        label: Text shown in every point's hover tooltip.

    Displays the figure with ``fig.show()`` and returns None.
    """
    plot_df = _voxel_frame(image, plot_only_figure, color_ind)
    fig = go.Figure(data=[go.Scatter3d(
        x=plot_df["x"],
        y=plot_df["y"],
        z=plot_df["z"],
        mode="markers",
        text=f"current label: {label}",
        # Explicit per-point RGB strings; a colorscale would be ignored.
        marker=dict(color=plot_df["color"], size=5, opacity=0.8),
    )])
    fig.show()

# Take the first batch from the (unshuffled) test loader and plot its
# sample at index 10, using channel 1 to decide which voxels are drawn.
first_test_batch = next(iter(test_loader))
plot_img(first_test_batch[10], color_ind=1)