In [2]:
import torch
import numpy as np
import scipy.ndimage

from matplotlib import cm
from matplotlib.colors import ListedColormap, LinearSegmentedColormap


class Rotate3D:
    """Rotate an (input, target) pair of volumes identically about one axis.

    Every channel of a ``(C, D, H, W)`` array/tensor is rotated independently
    with ``scipy.ndimage.rotate`` using ``reshape=False``, so the spatial shape
    is unchanged. Results are returned as float32 tensors.
    """

    # Rotation plane (axes of the per-channel (D, H, W) array) for each axis
    # index: rotating about one axis mixes the other two.
    _AXES_MAP = {0: (1, 2), 1: (0, 2), 2: (0, 1)}

    def __init__(self, angle, axis, order=1, mode="constant"):
        """
        angle: rotation angle in degrees
        axis: the axis to rotate around (0 for x, 1 for y, 2 for z)
        order: spline interpolation order forwarded to scipy.ndimage.rotate
            (default 1 = linear, the previously hard-coded value)
        mode: boundary mode forwarded to scipy.ndimage.rotate (default
            "constant", scipy's zero fill, as before). For periodic volumes a
            wrapping mode such as "grid-wrap" may be more appropriate.

        Raises ValueError if ``axis`` is not 0, 1 or 2.
        """
        if axis not in self._AXES_MAP:
            raise ValueError(f"axis must be 0, 1 or 2, got {axis!r}")
        self.angle = angle
        self.axis = axis
        self.order = order
        self.mode = mode

    def __call__(self, input_tensor, target_tensor=None):
        """
        Rotate both volumes by ``self.angle`` about ``self.axis``.

        Both tensors have shape (C, D, H, W) and are rotated identically.
        Can be called as ``t(input, target)`` or with a single pair
        ``t((input, target))``. The second form is what a dataset uses when
        it passes its transform one tuple.

        Returns ``(rotated_input, rotated_target)`` as float32 tensors.
        Raises TypeError if neither calling form is used.
        """
        if target_tensor is None:
            # Single-argument form: unpack an explicit (input, target) pair only.
            if not isinstance(input_tensor, (tuple, list)) or len(input_tensor) != 2:
                raise TypeError(
                    "Rotate3D expects (input, target) as two arguments or as one pair"
                )
            input_tensor, target_tensor = input_tensor

        rotation_axes = self._AXES_MAP[self.axis]
        return (
            self._rotate_channels(input_tensor, rotation_axes),
            self._rotate_channels(target_tensor, rotation_axes),
        )

    def _rotate_channels(self, volume, rotation_axes):
        """Rotate each channel of ``volume`` within ``rotation_axes``; return a float32 tensor."""
        if isinstance(volume, torch.Tensor):
            # detach()/cpu() let tensors that require grad or live on a GPU convert;
            # a bare .numpy() raises for both.
            volume = volume.detach().cpu().numpy()
        rotated = np.array(
            [
                scipy.ndimage.rotate(
                    channel,
                    self.angle,
                    axes=rotation_axes,
                    reshape=False,
                    order=self.order,
                    mode=self.mode,
                )
                for channel in volume
            ]
        )
        return torch.tensor(rotated, dtype=torch.float32)

In [3]:
from torch.utils.data import DataLoader, Dataset


# Custom Dataset
class customDataset(Dataset):
    """Paired (input, target) samples sliced from two ``.npy`` files.

    Both arrays are read from ``path + name`` and restricted to rows
    ``[start_idx, start_idx + num_samples)`` along their first axis.
    """

    def __init__(
        self,
        path,
        name_data,
        name_targets,
        num_samples,
        start_idx,
        transform=None,
    ):
        """
        path: directory prefix, concatenated directly with the file names
        name_data / name_targets: .npy file names for inputs / targets
        num_samples: number of rows requested (fewer are used if the file ends first)
        start_idx: first row to use
        transform: optional callable applied to one ``(input, target)`` tuple
            and returning a new ``(input, target)`` pair

        Raises ValueError if the input and target slices differ in length.
        """
        self.path = path
        # Attribute names kept as-is for backward compatibility.
        self.name_quat = name_data
        self.name_toughness = name_targets
        self.num_samples = num_samples
        self.start_idx = start_idx
        self.data = self._load_slice(path + name_data)
        self.targets = self._load_slice(path + name_targets)
        if len(self.data) != len(self.targets):
            raise ValueError(
                f"data has {len(self.data)} samples but targets has {len(self.targets)}"
            )
        self.transform = transform

    def _load_slice(self, file_path):
        """Load rows [start_idx, start_idx + num_samples) of a .npy file into memory."""
        # mmap_mode="r" reads only the requested rows rather than the whole file;
        # np.array(...) copies them so the result is a regular in-memory array.
        full = np.load(file_path, mmap_mode="r")
        return np.array(full[self.start_idx : self.start_idx + self.num_samples])

    def __len__(self):
        # Length of the actual slice: it can be shorter than num_samples near
        # the end of the file, and indexing past it would fail.
        return len(self.data)

    def __getitem__(self, idx):
        input_img, target_img = self.data[idx], self.targets[idx]

        if self.transform:
            input_img, target_img = self.transform((input_img, target_img))

        return self._as_tensor(input_img), self._as_tensor(target_img)

    @staticmethod
    def _as_tensor(value):
        """Return a fresh tensor copy of ``value``.

        A tensor is cloned explicitly, which avoids the copy-construct warning
        torch.tensor() emits for tensor inputs.
        """
        if isinstance(value, torch.Tensor):
            return value.detach().clone()
        return torch.tensor(value)


transform = None

dataset_dir = "./data/"
for_train = 8000  # rows [0, for_train) are used for training
for_val = 2000  # rows [for_train, for_train + for_val) are used for validation

# Per-voxel mean/std of the target field over every sample in the file.
# Later cells use them to de-normalise predictions (pred * phi_std + phi_avg).
# NOTE(review): these stats include the validation rows; confirm they match the
# statistics used when the model was trained.
phi = np.load(dataset_dir + "phi_10000_32^3.npy")
phi_avg = np.mean(phi, axis=0)
phi_std = np.std(phi, axis=0)

mytrain_dataset = customDataset(
    path=dataset_dir,
    name_data="quat_10000_32^3.npy",
    name_targets="phi_10000_32^3.npy",
    num_samples=for_train,
    start_idx=0,
    transform=transform,
)
myval_dataset = customDataset(
    path=dataset_dir,
    name_data="quat_10000_32^3.npy",
    name_targets="phi_10000_32^3.npy",
    num_samples=for_val,
    start_idx=for_train,
    transform=transform,
)

mytrain_loader = DataLoader(mytrain_dataset, batch_size=10, shuffle=False)
myval_loader = DataLoader(myval_dataset, batch_size=10, shuffle=False)

# Take the first training batch (shuffle=False, so this is deterministic).
input_img, target_img = next(iter(mytrain_loader))
print(input_img.shape, target_img.shape)

# Select one sample from the batch; later cells visualise this sample.
image_idx = 0

input_img, target_img = input_img[image_idx], target_img[image_idx]


In [7]:
import pyvista as pv

import torch
import numpy as np
import scipy.ndimage

# Fields of the sample selected in the data-loading cell.
micro = (input_img[0, :, :, :]).cpu().numpy().squeeze()  # input channel 0 (plotted in a later cell)
output = target_img.cpu().numpy().squeeze()  # ground-truth target field

scalar_field = output

# Grid dimensions follow the data instead of being hard-coded to 32^3.
# Later cells reuse nx, ny, nz for their own grids.
nx, ny, nz = scalar_field.shape

# Uniform PyVista grid with unit spacing and its origin at 0.
grid = pv.ImageData(dimensions=(nx, ny, nz), spacing=(1, 1, 1), origin=(0, 0, 0))

# Fortran order makes numpy axis 0 vary fastest, matching VTK's x-fastest point order.
grid.point_data["scalar_field"] = scalar_field.flatten(order="F")

plotter = pv.Plotter()

# Volume-render the target field.
plotter.add_volume(
    grid,
    cmap="jet",
    clim=[0.2, 1],  # colour-range limits for the scalar values
    shade=False,  # no lighting/shading applied to the volume
    opacity=1,  # one opacity value for every scalar
    opacity_unit_distance=0,  # NOTE(review): value kept from the original; confirm intended effect
)

plotter.show_grid()
plotter.show()

Widget(value='<iframe src="http://localhost:43691/index.html?ui=P_0x7fd875c99400_0&reconnect=auto" class="pyvi…

In [None]:
import pyvista as pv
import torch
import torch.nn as nn

# Input channel 0 of the selected sample (plotted in a later cell).
micro = (input_img[0, :, :, :]).cpu().numpy().squeeze()

# Grid dimensions for the prediction plots below. They come from the input
# volume rather than being hard-coded to 32^3; the U-Net defined below returns
# the same spatial size as its input.
nx, ny, nz = micro.shape


class DoubleConv(nn.Module):
    """Two periodic 3x3x3 convolutions, each followed by an in-place ReLU.

    Maps ``in_c`` channels to ``out_c`` channels. Spatial size is preserved:
    stride 1, padding 1, and circular (wrap-around) padding at the borders.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        layers = []
        for conv_in in (in_c, out_c):
            layers.append(
                nn.Conv3d(
                    conv_in,
                    out_c,
                    kernel_size=3,
                    padding=1,
                    padding_mode="circular",
                )
            )
            layers.append(nn.ReLU(inplace=True))
        # The Sequential must stay named `conv` so state_dict keys remain conv.0.* / conv.2.*.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        activated = self.conv(x)
        return activated


## DO NOT USE, USE UNet3D_upsample_periodic_conv INSTEAD
class UNet3D_conv_transpose(nn.Module):
    """3D U-Net whose decoder upsamples with transposed convolutions.

    Kept for reference only (see the note directly above the class).
    Encoder: one DoubleConv + 2x max-pool per entry of ``encoder_channels``.
    Bottleneck: DoubleConv to ``2 * encoder_channels[-1]`` channels.
    Decoder: ConvTranspose3d(kernel 2, stride 2), concatenation with the
    matching encoder output, then a DoubleConv.
    Head: 1x1x1 conv to ``out_channels`` channels, followed by a sigmoid.

    Spatial sizes should be divisible by ``2 ** len(encoder_channels)``.
    Otherwise the upsampled maps will not match their skip connections.
    """

    def __init__(self, encoder_channels=(16, 32, 64), in_channels=4, out_channels=1):
        """
        encoder_channels: output channels of each encoder level (non-empty)
        in_channels: channels of the input volume (default 4, as before)
        out_channels: channels of the output volume (default 1, as before)

        Raises ValueError if ``encoder_channels`` is empty.
        """
        super().__init__()
        # Keep a private list copy so neither the default nor a caller's list is aliased.
        encoder_channels = list(encoder_channels)
        if not encoder_channels:
            raise ValueError("encoder_channels must contain at least one entry")
        self.encoder_channels = encoder_channels

        # Encoder
        self.encoders = nn.ModuleList()
        current_in = in_channels
        for out_c in encoder_channels:
            self.encoders.append(DoubleConv(current_in, out_c))
            current_in = out_c

        # Pooling layers
        self.pools = nn.ModuleList([nn.MaxPool3d(2) for _ in encoder_channels])

        # Bottleneck
        self.bottleneck = DoubleConv(encoder_channels[-1], 2 * encoder_channels[-1])

        # Decoder
        self.up_layers = nn.ModuleList()
        self.decoder_convs = nn.ModuleList()
        reversed_channels = encoder_channels[::-1]

        # The first up-layer consumes the bottleneck (2x the deepest width); each
        # later one consumes the previous decoder level's output.
        up_in = [2 * reversed_channels[0]] + reversed_channels[:-1]
        for in_c, out_c in zip(up_in, reversed_channels):
            self.up_layers.append(nn.ConvTranspose3d(in_c, out_c, 2, 2))
            # Upsampled map + skip connection are concatenated: 2 * out_c channels.
            self.decoder_convs.append(DoubleConv(2 * out_c, out_c))

        # Final output
        self.final = nn.Conv3d(reversed_channels[-1], out_channels, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Encoder path: keep each level's pre-pool output for the skip connections.
        skips = []
        for enc, pool in zip(self.encoders, self.pools):
            x = enc(x)
            skips.append(x)
            x = pool(x)

        # Bottleneck
        x = self.bottleneck(x)

        # Decoder path: deepest skip first.
        for up, dec_conv, skip in zip(
            self.up_layers, self.decoder_convs, reversed(skips)
        ):
            x = up(x)
            x = torch.cat([x, skip], 1)
            x = dec_conv(x)

        return self.sigmoid(self.final(x))


class UpConv(nn.Module):
    """Double the spatial resolution, then apply a periodic 3x3x3 convolution.

    Upsampling is nearest-neighbour with scale factor 2. The convolution maps
    ``in_c`` to ``out_c`` channels with circular padding, so it keeps the
    upsampled size.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2, mode="nearest")
        conv = nn.Conv3d(
            in_c, out_c, kernel_size=3, padding=1, padding_mode="circular"
        )
        # The Sequential must stay named `up` so state_dict keys remain up.1.*.
        self.up = nn.Sequential(upsample, conv)

    def forward(self, x):
        upsampled = self.up(x)
        return upsampled


class UNet3D_upsample_periodic_conv(nn.Module):
    """3D U-Net built from periodic (circular-padded) convolutions.

    Encoder: one DoubleConv + 2x max-pool per entry of ``encoder_channels``.
    Bottleneck: DoubleConv to ``2 * encoder_channels[-1]`` channels.
    Decoder: UpConv (nearest-neighbour upsample + periodic conv), concatenation
    with the matching encoder output, then a DoubleConv.
    Head: 1x1x1 conv to ``out_channels`` channels. No output activation is
    applied: ``self.sigmoid`` is constructed but not used in ``forward``.

    Spatial sizes should be divisible by ``2 ** len(encoder_channels)``.
    Otherwise the upsampled maps will not match their skip connections.
    """

    def __init__(self, encoder_channels=(16, 32, 64), in_channels=4, out_channels=1):
        """
        encoder_channels: output channels of each encoder level (non-empty)
        in_channels: channels of the input volume (default 4, as before)
        out_channels: channels of the output volume (default 1, as before)

        Raises ValueError if ``encoder_channels`` is empty.
        """
        super().__init__()
        # Keep a private list copy so neither the default nor a caller's list is aliased.
        encoder_channels = list(encoder_channels)
        if not encoder_channels:
            raise ValueError("encoder_channels must contain at least one entry")
        self.encoder_channels = encoder_channels

        # Encoder
        self.encoders = nn.ModuleList()
        current_in = in_channels
        for out_c in encoder_channels:
            self.encoders.append(DoubleConv(current_in, out_c))
            current_in = out_c

        self.pools = nn.ModuleList([nn.MaxPool3d(2) for _ in encoder_channels])
        self.bottleneck = DoubleConv(encoder_channels[-1], 2 * encoder_channels[-1])

        # Decoder
        self.up_layers = nn.ModuleList()
        self.decoder_convs = nn.ModuleList()
        reversed_channels = encoder_channels[::-1]

        for i, out_c in enumerate(reversed_channels):
            # Level 0 consumes the bottleneck (2x the deepest width); later levels
            # consume the previous decoder level's output.
            in_c = 2 * reversed_channels[0] if i == 0 else reversed_channels[i - 1]
            self.up_layers.append(UpConv(in_c, out_c))
            # Upsampled map + skip connection are concatenated: 2 * out_c channels.
            self.decoder_convs.append(DoubleConv(2 * out_c, out_c))

        self.final = nn.Conv3d(reversed_channels[-1], out_channels, 1)
        # Not applied in forward(); kept so the module structure is unchanged.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        skips = []
        # Encoder path: keep each level's pre-pool output for the skip connections.
        for enc, pool in zip(self.encoders, self.pools):
            x = enc(x)
            skips.append(x)
            x = pool(x)

        # Bottleneck
        x = self.bottleneck(x)

        # Decoder path: deepest skip first.
        for up, dec_conv, skip in zip(
            self.up_layers, self.decoder_convs, reversed(skips)
        ):
            x = up(x)
            x = torch.cat([x, skip], 1)
            x = dec_conv(x)

        # Raw (unbounded) output; no sigmoid.
        return self.final(x)


# NOTE(review): absolute path at the filesystem root, while the data directory
# above is relative ("./data/"). Confirm "/models/" is intended.
model_dir = '/models/'
model = UNet3D_upsample_periodic_conv(encoder_channels=[256, 512, 1024])

# The checkpoint goes straight into load_state_dict, so it only needs to hold
# tensors. weights_only=True refuses to unpickle arbitrary Python objects.
pretrained_dict = torch.load(
    model_dir + 'scalar_field_predictor_epoch_44.pth', weights_only=True, map_location='cpu'
)

# strict=False tolerates extra checkpoint entries. Missing parameters, however,
# would silently leave layers at random initialisation, so fail loudly on them.
load_result = model.load_state_dict(pretrained_dict, strict=False)
if load_result.missing_keys:
    raise RuntimeError(
        f"checkpoint is missing {len(load_result.missing_keys)} model parameters, "
        f"e.g. {load_result.missing_keys[:5]}"
    )
if load_result.unexpected_keys:
    print(
        f"ignoring {len(load_result.unexpected_keys)} unexpected checkpoint entries, "
        f"e.g. {load_result.unexpected_keys[:5]}"
    )

model.eval()

# Inference only, so no autograd graph is needed. .float() matches the model's float32 weights.
with torch.no_grad():
    output = model(input_img.unsqueeze(0).float())

# Undo the per-voxel normalisation: prediction * std + mean.
scalar_field = (output.cpu().numpy().squeeze() * phi_std) + phi_avg

# Uniform PyVista grid with unit spacing and its origin at 0.
grid = pv.ImageData(dimensions=(nx, ny, nz), spacing=(1, 1, 1), origin=(0, 0, 0))

# Fortran order makes numpy axis 0 vary fastest, matching VTK's x-fastest point order.
grid.point_data["-"] = scalar_field.flatten(order="F")

plotter = pv.Plotter()

# Volume-render the predicted field with the same colour range as the target plot.
plotter.add_volume(
    grid,
    cmap='jet',
    clim=[0.2, 1],  # colour-range limits for the scalar values
    shade=False,  # no lighting/shading applied to the volume
    opacity=1,  # one opacity value for every scalar
    opacity_unit_distance=0,  # NOTE(review): value kept from the original; confirm intended effect
)

plotter.show_grid()
plotter.show()


Widget(value='<iframe src="http://localhost:43691/index.html?ui=P_0x7fd875c995b0_1&reconnect=auto" class="pyvi…

In [None]:
model_sst = UNet3D_upsample_periodic_conv(encoder_channels=[256, 512, 1024])

# The checkpoint goes straight into load_state_dict, so it only needs to hold
# tensors. weights_only=True refuses to unpickle arbitrary Python objects.
pretrained_dict = torch.load(
    model_dir + 'scalar_field_predictor_epoch_18.pth', weights_only=True, map_location='cpu'
)

# strict=False tolerates extra checkpoint entries. Missing parameters, however,
# would silently leave layers at random initialisation, so fail loudly on them.
load_result = model_sst.load_state_dict(pretrained_dict, strict=False)
if load_result.missing_keys:
    raise RuntimeError(
        f"checkpoint is missing {len(load_result.missing_keys)} model parameters, "
        f"e.g. {load_result.missing_keys[:5]}"
    )
if load_result.unexpected_keys:
    print(
        f"ignoring {len(load_result.unexpected_keys)} unexpected checkpoint entries, "
        f"e.g. {load_result.unexpected_keys[:5]}"
    )

model_sst.eval()

# Inference only, so no autograd graph is needed. .float() matches the model's float32 weights.
with torch.no_grad():
    output = model_sst(input_img.unsqueeze(0).float())

# Undo the per-voxel normalisation: prediction * std + mean.
scalar_field = (output.cpu().numpy().squeeze() * phi_std) + phi_avg

# Uniform PyVista grid with unit spacing and its origin at 0.
grid = pv.ImageData(dimensions=(nx, ny, nz), spacing=(1, 1, 1), origin=(0, 0, 0))

# Fortran order makes numpy axis 0 vary fastest, matching VTK's x-fastest point order.
grid.point_data["-"] = scalar_field.flatten(order="F")

plotter = pv.Plotter()

# Volume-render the predicted field with the same colour range as the target plot.
plotter.add_volume(
    grid,
    cmap='jet',
    clim=[0.2, 1],  # colour-range limits for the scalar values
    shade=False,  # no lighting/shading applied to the volume
    opacity=1,  # one opacity value for every scalar
    opacity_unit_distance=0,  # NOTE(review): value kept from the original; confirm intended effect
)

plotter.show_grid()
plotter.show()

Widget(value='<iframe src="http://localhost:43691/index.html?ui=P_0x7fd875c92640_3&reconnect=auto" class="pyvi…

In [14]:
# Input channel 0 of the selected sample, extracted in an earlier cell.
scalar_field = micro

# Uniform PyVista grid sized from the field itself, with unit spacing and its origin at 0.
grid = pv.ImageData(dimensions=scalar_field.shape, spacing=(1, 1, 1), origin=(0, 0, 0))

# Fortran order makes numpy axis 0 vary fastest, matching VTK's x-fastest point order.
grid.point_data["scalar_field"] = scalar_field.flatten(order="F")

plotter = pv.Plotter()

# Volume-render the input channel. No clim is given, so PyVista picks the colour range.
plotter.add_volume(
    grid,
    cmap="gist_rainbow",
    shade=False,  # no lighting/shading applied to the volume
    opacity=1,  # one opacity value for every scalar
    opacity_unit_distance=0,  # NOTE(review): value kept from the original; confirm intended effect
)

plotter.show_grid()
plotter.show()


Widget(value='<iframe src="http://localhost:43691/index.html?ui=P_0x7fd7dfc8c5b0_7&reconnect=auto" class="pyvi…