In [1]:
import os
import pytest
import torch
import numpy as np
from pprint import pprint
from multimodal_particles import config_dir

from multimodal_particles.data.particle_clouds.utils import sizes_to_histograms
from multimodal_particles.utils.experiment_configs import load_config
from multimodal_particles.models.generative.transdimensional import TransdimensionalJumpDiffusion
from multimodal_particles.config_classes.transdimensional_config_unconditional import TransdimensionalEpicConfig
from multimodal_particles.data.particle_clouds.dataloader import JetsGraphicalStructure
from multimodal_particles.models.generative.transdimensional.structure import Structure
from multimodal_particles.data.particle_clouds.jets import JetDataclass
from multimodal_particles.data.particle_clouds.dataloader import MultimodalBridgeDataloaderModule

In [2]:
# Build the experiment configuration from the defaults of TransdimensionalEpicConfig.
config = TransdimensionalEpicConfig()
# Ask for list-style batches; presumably each batch is then a list of tensors
# ordered as in `dataloader.name_to_index` -- TODO confirm against the dataloader.
config.data.return_type = "list"

# Create the jet dataset, preprocess it, and wrap it in the datamodule.
# `dataloader.train` is iterated in a later cell to fetch one batch.
jets = JetDataclass(config=config)
jets.preprocess()
dataloader = MultimodalBridgeDataloaderModule(config=config, jetdataset=jets)

INFO: building dataloaders...
INFO: train/val/test split ratios: 0.8/0.2/0.0
INFO: train size: 800, validation size: 200, testing sizes: 0


  discrete = torch.tensor(discrete).long()


In [7]:
# Pull a single batch from the training loader for the experiments below.
databatch = next(iter(dataloader.train))

In [10]:
# Display the field-name -> index mapping; `remove_problem_dims` uses such a
# mapping to pick tensors out of a batch.
dataloader.name_to_index

{'target_continuous': 0, 'target_discrete': 1, 'target_mask': 2}

In [3]:
# Display the shapes the datamodule reports under `with_onehot_shapes`;
# presumably the discrete field is one-hot encoded here -- TODO confirm.
dataloader.with_onehot_shapes

[torch.Size([128, 3]), torch.Size([128, 8]), torch.Size([128, 1])]

In [4]:
# Display the shapes the datamodule reports under `without_onehot_shapes`.
# NOTE(review): the output lists fewer entries than `with_onehot_shapes` --
# confirm which field is omitted.
dataloader.without_onehot_shapes

[torch.Size([128, 3]), torch.Size([128, 1])]

In [6]:
def create_and_apply_mask_3(one_tensor_from_databatch,new_dims_dev,device):
    """Zero entries of a tensor whose dim-1 index is not below a per-row limit.

    Args:
        one_tensor_from_databatch: tensor indexed over dims 0, 1 and 2;
            presumably (batch, positions, features) -- TODO confirm.
        new_dims_dev: tensor of limits, viewed as (-1, 1, 1) so that it
            broadcasts along dim 0 of the index grid.
        device: device on which the index grid is created.

    Returns:
        Tuple ``(masked, mask)``: ``mask`` is the boolean result of
        ``index < limit`` and ``masked`` is the input multiplied by it.
    """
    source = one_tensor_from_databatch
    # Running index along dim 1, tiled to cover dims 0 and 2 of the input.
    position_index = torch.arange(source.shape[1], device=device)
    position_grid = position_index.view(1, -1, 1).repeat(source.shape[0], 1, source.shape[2])
    # True where the dim-1 index is strictly below the limit of that row.
    keep = position_grid < new_dims_dev.view(-1, 1, 1)
    return source * keep, keep

def create_and_apply_mask_2(one_tensor_from_databatch,new_dims_dev,device):
    """Zero entries of a tensor whose dim-1 index is not below a per-row limit.

    Two-dimensional counterpart of ``create_and_apply_mask_3``.

    Args:
        one_tensor_from_databatch: tensor indexed over dims 0 and 1;
            presumably (batch, positions) -- TODO confirm.
        new_dims_dev: tensor of limits, viewed as (-1, 1) so that it
            broadcasts along dim 0 of the index grid.
        device: device on which the index grid is created.

    Returns:
        Tuple ``(masked, mask)``: ``mask`` is the boolean result of
        ``index < limit`` and ``masked`` is the input multiplied by it.
    """
    source = one_tensor_from_databatch
    # Running index along dim 1, tiled once per row of the input.
    position_index = torch.arange(source.shape[1], device=device)
    position_grid = position_index.view(1, -1).repeat(source.shape[0], 1)
    # True where the dim-1 index is strictly below the limit of that row.
    keep = position_grid < new_dims_dev.view(-1, 1)
    return source * keep, keep

def remove_problem_dims(self, data, new_dims,name,name_to_index):
    """Mask every supported batch field so entries beyond ``new_dims`` are zeroed.

    For each name in ``self.names_in_batch`` that is one of the supported
    field names, the matching tensor is looked up in ``data`` through
    ``name_to_index`` and passed to ``create_and_apply_mask_3``. Only the
    masked tensor is kept; the boolean mask returned by the helper is dropped.

    Args:
        self: object exposing ``names_in_batch``, an iterable of field names.
        data: indexable batch of tensors; the device of ``data[0]`` is used
            for the mask computation.
        new_dims: tensor of per-row limits; it is moved to the batch device.
        name: unused; kept in the signature so existing call sites still work.
        name_to_index: mapping from field name to its index in ``data``.

    Returns:
        list of masked tensors, ordered like ``self.names_in_batch``; names
        outside the supported set are skipped.
    """
    maskable_names = (
        "target_continuous",
        "target_discrete",
        "target_mask",
        "context_continuous",
        "context_discrete",
    )

    device = data[0].device
    new_dims_dev = new_dims.to(device)

    databatch_with_dimensions_removed = []
    for batch_name in self.names_in_batch:
        if batch_name not in maskable_names:
            continue
        # Look the tensor up under its own name (not always "target_continuous").
        one_tensor_from_databatch = data[name_to_index[batch_name]]
        # NOTE(review): create_and_apply_mask_3 reads shape[2], so every field is
        # assumed to have at least three dims; a two-dimensional field would need
        # create_and_apply_mask_2 instead -- confirm against the dataloader.
        masked_tensor, _ = create_and_apply_mask_3(one_tensor_from_databatch, new_dims_dev, device)
        databatch_with_dimensions_removed.append(masked_tensor)

    return databatch_with_dimensions_removed

In [None]:
import torch

def create_masked_tensor(tensor, new_dims_dev, device):
    """Return ``tensor`` with entries zeroed where the dim-1 index is not below a per-row limit.

    Works for tensors with two or more dims: limits and the dim-1 index are
    both expanded to dim 0 and the trailing dims of ``tensor`` before the
    comparison.

    Args:
        tensor: input tensor; dim 0 is matched against the limits and dim 1
            is the axis being masked.
        new_dims_dev: tensor of limits, viewed with one entry per dim-0 row.
        device: device on which the index tensor is created.

    Returns:
        ``tensor`` multiplied by the boolean mask ``index < limit``.
    """
    trailing_shape = tuple(tensor.shape[2:])
    # One singleton axis per trailing dim so the views line up with `tensor`.
    singleton_tail = (1,) * len(trailing_shape)
    expanded_shape = (tensor.shape[0], -1, *trailing_shape)

    limits = new_dims_dev.view(-1, 1, *singleton_tail).expand(*expanded_shape)
    positions = torch.arange(tensor.shape[1], device=device).view(1, -1, *singleton_tail).expand(*expanded_shape)
    return tensor * (positions < limits)

# Example usage: mask three random tensors that share a batch size of 5 but
# differ in length along dim 1 and in rank.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
new_dims_dev = torch.tensor([10, 20, 15, 12, 18], device=device)  # Ensure the size matches the batch size

# Each tensor gets its own name so that none overwrites another before masking.
pos = torch.randn(5, 10, 3, device=device)
atom_type = torch.randn(5, 15, 3, device=device)
charge = torch.randn(5, 12, device=device)

masked_pos = create_masked_tensor(pos, new_dims_dev[:pos.shape[0]], device)
masked_atom_type = create_masked_tensor(atom_type, new_dims_dev[:atom_type.shape[0]], device)
masked_charge = create_masked_tensor(charge, new_dims_dev[:charge.shape[0]], device)
