In [1]:
import torch
from soma import aims
# import anatomist.api as ana
from betaVAE.beta_vae import VAE
from betaVAE.visualisation_anatomist import adjust_in_shape
from pathlib import Path
from betaVAE.preprocess import UkbDataset
from torch.utils.data import DataLoader

ROOT_SAVE = Path("/neurospin/dico/tsanchez/tmp")

PATH_EXP = Path("/neurospin/dico/tsanchez/Test_BetaVAE/2025-05-27/16-52-59_")
PATH_MODEL = PATH_EXP / "checkpoint.pt"
device = "cuda:0"

N_LATENT = 1024
DEPTH = 3

IN_SHAPE_WOUT_ADJUST = [1, 54, 120, 139] #One from the config.yaml file
IN_SHAPE = adjust_in_shape(IN_SHAPE_WOUT_ADJUST, depth=DEPTH)

CONFIG = {
    "in_shape" : IN_SHAPE,
    "root" : "/neurospin/dico/tsanchez/preprocessed/UKBio1000"
}

In [2]:
from torchsummary import summary
model = VAE(
    in_shape=IN_SHAPE, 
    n_latent=N_LATENT,
    depth=DEPTH
)
summary(model.cuda(), tuple(IN_SHAPE))

state_dict = torch.load(PATH_MODEL)
print(state_dict.keys())
model.load_state_dict(state_dict)
model.to(device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1     [-1, 16, 60, 120, 144]             448
       BatchNorm3d-2     [-1, 16, 60, 120, 144]              32
         LeakyReLU-3     [-1, 16, 60, 120, 144]               0
            Conv3d-4       [-1, 16, 30, 60, 72]          16,400
       BatchNorm3d-5       [-1, 16, 30, 60, 72]              32
         LeakyReLU-6       [-1, 16, 30, 60, 72]               0
            Conv3d-7       [-1, 32, 30, 60, 72]          13,856
       BatchNorm3d-8       [-1, 32, 30, 60, 72]              64
         LeakyReLU-9       [-1, 32, 30, 60, 72]               0
           Conv3d-10       [-1, 32, 15, 30, 36]          65,568
      BatchNorm3d-11       [-1, 32, 15, 30, 36]              64
        LeakyReLU-12       [-1, 32, 15, 30, 36]               0
           Conv3d-13       [-1, 64, 15, 30, 36]          55,360
      BatchNorm3d-14       [-1, 64, 15,

  state_dict = torch.load(PATH_MODEL)


odict_keys(['encoder.conv0.weight', 'encoder.conv0.bias', 'encoder.norm0.weight', 'encoder.norm0.bias', 'encoder.norm0.running_mean', 'encoder.norm0.running_var', 'encoder.norm0.num_batches_tracked', 'encoder.conv0a.weight', 'encoder.conv0a.bias', 'encoder.norm0a.weight', 'encoder.norm0a.bias', 'encoder.norm0a.running_mean', 'encoder.norm0a.running_var', 'encoder.norm0a.num_batches_tracked', 'encoder.conv1.weight', 'encoder.conv1.bias', 'encoder.norm1.weight', 'encoder.norm1.bias', 'encoder.norm1.running_mean', 'encoder.norm1.running_var', 'encoder.norm1.num_batches_tracked', 'encoder.conv1a.weight', 'encoder.conv1a.bias', 'encoder.norm1a.weight', 'encoder.norm1a.bias', 'encoder.norm1a.running_mean', 'encoder.norm1a.running_var', 'encoder.norm1a.num_batches_tracked', 'encoder.conv2.weight', 'encoder.conv2.bias', 'encoder.norm2.weight', 'encoder.norm2.bias', 'encoder.norm2.running_mean', 'encoder.norm2.running_var', 'encoder.norm2.num_batches_tracked', 'encoder.conv2a.weight', 'encoder.

VAE(
  (encoder): Sequential(
    (conv0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (norm0): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (LeakyReLU0): LeakyReLU(negative_slope=0.01)
    (conv0a): Conv3d(16, 16, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
    (norm0a): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (LeakyReLU0a): LeakyReLU(negative_slope=0.01)
    (conv1): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (norm1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (LeakyReLU1): LeakyReLU(negative_slope=0.01)
    (conv1a): Conv3d(32, 32, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
    (norm1a): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (LeakyReLU1a): LeakyReLU(negative_slope=0.01)
    (conv2): Conv3d(32, 64, kernel_size=(3, 3

In [3]:
dataset = UkbDataset(CONFIG)
loader = DataLoader(dataset=dataset, shuffle=True, batch_size=5)
split_batch, full_batch ,list_sub = next(iter(loader))

In [4]:
batch = split_batch[:,0,:,:,:].unsqueeze(1)
split_batch.shape, batch.shape

(torch.Size([5, 2, 60, 120, 144]), torch.Size([5, 1, 60, 120, 144]))

In [5]:
batch = batch.to(device, dtype = torch.float32)
batch.shape

torch.Size([5, 1, 60, 120, 144])

In [6]:
output_batch = model(batch)

In [7]:
torch.save((batch, list_sub, output_batch), ROOT_SAVE / "proba_test.pt")

In [None]:
batch, list_sub, output_batch = torch.load(ROOT_SAVE / "proba_test.pt")
batch = batch.to(device = "cpu")
out_proba = output_batch[0].to(device = "cpu")
z = output_batch[1].to(device = "cpu")
logvar = output_batch[2].to(device = "cpu")

  batch, list_sub, output_batch = torch.load(ROOT_SAVE / "proba_test.pt")


In [9]:
torch.max(out_proba[0]), torch.min(out_proba[0])

(tensor(0.3260, grad_fn=<MaxBackward1>),
 tensor(-0.5988, grad_fn=<MinBackward1>))

In [10]:
import anatomist.api as ana
from soma import aims
anatomist = ana.Anatomist()
win = anatomist.createWindow("3D")

create qapp
done
Starting Anatomist.....
config file : /home/ts283124/.anatomist/config/settings.cfg
PyAnatomist Module present
PythonLauncher::runModules()


existing QApplication: 0


global modules: /neurospin/dico/tsanchez/2025_tsanchez_cerrebellum/soma-env/build/share/anatomist-5.2/python_plugins
home   modules: /home/ts283124/.anatomist/python_plugins
loading module gltf_io
loading module palettecontrols
loading module paletteViewer
loading module meshsplit
loading module profilewindow
loading module ana_image_math
loading module anacontrolmenu
loading module foldsplit
loading module modelGraphs
loading module bsa_proba
loading module histogram
loading module gradientpalette
loading module infowindow
loading module simple_controls
loading module volumepalettes
loading module statsplotwindow
loading module save_resampled
loading module valuesplotwindow
loading module selection
all python modules loaded
Anatomist started.
Multitexturing present
function glActiveTexture found.
function glClientActiveTexture found.
function glBlendEquation found.
function glTexImage3D found.
function glMultiTexCoord3f found.
function glBindFramebuffer found.
function glBindRenderbuf

In [11]:
# Input volume
input_vol = aims.Volume(batch[0].squeeze(0).numpy())
ana_input = anatomist.toAObject(input_vol)



In [None]:
from torch.nn.functional import sigmoid
sample = out_proba[0]
print(sample.shape)
proba = sigmoid(sample)

torch.Size([1, 60, 120, 144])


In [13]:
to_plot = proba[0]
print(to_plot.shape)
proba_vol = aims.Volume(to_plot.detach().numpy())
proba_ana = anatomist.toAObject(proba_vol)

torch.Size([60, 120, 144])


no position could be read at 213, 32
Exiting QApplication
