In [85]:
import os

# Project root used as the working directory so the relative 'src/...' paths below resolve.
# Defaults to the original checkout location; set the CNN_PELSVAE_ROOT environment
# variable to run the notebook from a different machine or checkout.
new_directory = os.environ.get(
    'CNN_PELSVAE_ROOT',
    '/home/franciscoperez/Documents/GitHub/CNN-PELSVAE2/cnn-pels-vae/',
)
os.chdir(new_directory)

import yaml
import pickle
import torch
import numpy as np
import src.utils as utils
from typing import Union, Tuple, Optional, Any, Dict, List
import src.gmm.modifiedgmm as mgmm
from src.utils import load_pp_list


In [86]:
# Load the project path table. The parsed mapping is bound directly to PATHS instead of
# re-using the file-handle name for the parsed dict, so no name holds two kinds of object.
with open('src/paths.yaml', 'r') as paths_stream:
    PATHS = yaml.safe_load(paths_stream)['paths']
PATH_MODELS = PATHS['PATH_MODELS']
PATH_DATA = PATHS['PATH_DATA_FOLDER']

# Last expression of the cell: display both resolved locations.
PATH_MODELS, PATH_DATA

('/home/franciscoperez/Documents/GitHub/CNN-PELSVAE2/cnn-pels-vae/models/',
 '/home/franciscoperez/Documents/GitHub/CNN-PELSVAE2/cnn-pels-vae/data')

In [87]:
# Restore the label encoder stored next to the trained models.
# NOTE: pickle.load can execute arbitrary code; only open encoder files produced by this project.
encoder_path = PATH_MODELS+'label_encoder_vae.pkl'
with open(encoder_path, 'rb') as encoder_file:
    label_encoder = pickle.load(encoder_file)
print("Classes in Label Encoder:", label_encoder.classes_)

Classes in Label Encoder: ['ACEP' 'CEP' 'DSCT' 'ECL' 'ELL' 'LPV' 'RRLYR' 'T2CEP']


In [88]:
# Parse the network configuration file into a plain Python mapping.
with open('src/nn_config.yaml', 'r') as nn_config_stream:
    nn_config = yaml.safe_load(nn_config_stream)

In [89]:
# Parse the GMM priors file into a plain Python mapping.
with open('src/gmm/priors.yaml', 'r') as priors_stream:
    mean_prior_dict = yaml.safe_load(priors_stream)

In [90]:
# Read the regressor configuration, report which VAE model ID it names, and fetch
# the list of physical parameters associated with that model.
with open('src/regressor.yaml', 'r') as regressor_stream:
    config_file: Dict[str, Any] = yaml.safe_load(regressor_stream)
vae_model: str = config_file['model_parameters']['ID']
print('Using vae model: ' + vae_model)
PP_list = load_pp_list(vae_model)
# Last expression of the cell: display the parameter list.
PP_list

Using vae model: 16f09v2s


['Period', 'teff_val', '[Fe/H]_J95', 'abs_Gmag', 'radius_val', 'logg']

In [98]:
# Run parameters for the synthetic-batch cells below.
plot_example = False
b = 1.0
wandb_active = False
samples_dict = None
lb = []
n_samples = 8
PP = PP_list
priors = True
sinthetic_samples_by_class = 8
# NOTE: this cell used to end with a bare `dict_priorization` expression. That name is not
# assigned in any cell visible in this notebook, so on a fresh kernel (Restart & Run All)
# it raised NameError. It has been removed; re-introduce it only together with its definition.

In [99]:
from typing import Union, Tuple, Optional, Any, Dict, List
import numpy as np
import pickle
import yaml
import torch
from torch.utils.data import DataLoader, TensorDataset
import src.utils as utils
import src.gmm.modifiedgmm as mgmm
import src.sampler.fit_regressor as reg
import matplotlib.pyplot as plt
from src.sampler.LightCurveRandomSampler import LightCurveRandomSampler

# Module-level GPU flag (not read by the code visible in this cell).
gpu: bool = True

# Network configuration, loaded at module level before the class definition below.
with open('src/nn_config.yaml', 'r') as nn_config_stream:
    nn_config = yaml.safe_load(nn_config_stream)

class SyntheticDataBatcher:
    """Build a ``DataLoader`` of synthetic light curves.

    For every class listed in ``nn_config['data']['classes']`` the batcher draws
    physical-parameter samples through ``mgmm.ModifiedGaussianSampler``, maps them to
    latent vectors with ``reg.process_regressors``, decodes them with the VAE returned by
    ``utils.load_model_list``, post-processes the curves, stores them with
    ``utils.save_arrays_to_folder`` and wraps the saved arrays in a ``DataLoader``.
    """

    def __init__(self, config_file_path: str = 'src/regressor.yaml', 
                 nn_config_path: str = 'src/nn_config.yaml', paths: str = 'src/paths.yaml', PP=None, vae_model=None, 
                 n_samples=16, seq_length = 100, batch_size=128, prior=False):
        """Load the YAML configuration files and store the sampling settings.

        Args:
            config_file_path: YAML file passed on to ``reg.process_regressors``.
            nn_config_path: YAML file providing ``data.classes``, ``sampling.mode`` and
                ``training.sinthetic_samples_by_class``.
            paths: YAML file whose ``paths`` section provides ``PATH_PRIOS``,
                ``PATH_MODELS`` and ``PATH_DATA_FOLDER``.
            PP: list of physical-parameter names; ``None`` (default) means an empty list.
            vae_model: model ID handed to ``utils.load_model_list``.
            n_samples: samples drawn per class when no ``samples_dict`` is supplied.
            seq_length: sequence length handed to ``LightCurveRandomSampler``.
            batch_size: batch size of the returned ``DataLoader``.
            prior: flag that becomes part of the mixture-model file name.
        """
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.config_file = self.load_yaml(config_file_path)
        self.nn_config = self.load_yaml(nn_config_path)
        self.path = self.load_yaml(paths)['paths']
        # Per-class prior description; its 'StarTypes' section is read when sampling.
        self.mean_prior_dict = self.load_yaml(self.path['PATH_PRIOS'])
        self.priors = prior
        # A fresh list per instance: a `PP=[]` default would be shared by every instance.
        self.PP = PP if PP is not None else []
        self.vae_model = vae_model
        self.n_samples = n_samples
        self.seq_length = seq_length
        # Upper threshold applied to the differenced light curves in create_synthetic_batch.
        self.delta_max = 100
        self.CLASSES = ['ACEP','CEP', 'DSCT', 'ECL',  'ELL', 'LPV',  'RRLYR', 'T2CEP']
        self.batch_size = batch_size
        # Filled by create_synthetic_batch with the arrays reloaded from disk.
        self.x_array = None
        self.y_array = None

    @staticmethod
    def load_yaml(path: str) -> Dict[str, Any]:
        """Parse the YAML file at ``path`` with ``yaml.safe_load`` and return the result."""
        with open(path, 'r') as file:
            return yaml.safe_load(file)

    def construct_model_name(self, star_class: str, base_path: str = 'PATH_MODELS'):
        """Construct a model name given parameters.

        The name is ``base_path`` immediately followed by
        ``bgm_model_<class>_priors_<self.priors>_PP_<len(self.PP)>.pkl``; no path
        separator is inserted, so ``base_path`` must already end with one.
        """
        file_name = f"{base_path}bgm_model_{str(star_class)}_priors_{self.priors}_PP_{len(self.PP)}.pkl"
        return file_name

    @staticmethod
    def count_subclasses(star_type_data: Dict[str, Any]) -> int:
        """Count the keys of ``star_type_data`` that are not metadata entries.

        ``CompleteName``, ``min_period`` and ``max_period`` are excluded; every other
        key is counted.
        """
        excluded_keys = ('CompleteName', 'min_period', 'max_period')
        return sum(1 for key in star_type_data if key not in excluded_keys)

    def process_in_batches(self, model, mu_, times, onehot, phy, batch_size):
        """Decode the inputs in slices of ``batch_size`` rows and concatenate the outputs.

        Each slice is passed to ``model.decoder(mu, times, label=onehot, phy=phy)``; the
        per-slice results are concatenated along dimension 0.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        total_samples = mu_.size(0)
        results = []
        for start_idx in range(0, total_samples, batch_size):
            end_idx = start_idx + batch_size  # slicing clamps at the end of the tensor
            results.append(model.decoder(mu_[start_idx:end_idx], times[start_idx:end_idx],
                                         label=onehot[start_idx:end_idx],
                                         phy=phy[start_idx:end_idx]))
            torch.cuda.empty_cache()

        # Concatenate results from all batches
        return torch.cat(results, dim=0)

    def attempt_sample_load(self, model_name: str, sampler: 'YourSamplerType', n_samples=None) -> Tuple[Union[np.ndarray, None], bool]:
        """Draw samples from the stored mixture model through ``sampler``.

        Args:
            model_name: path handed to ``sampler.modify_and_sample``.
            sampler: object exposing ``modify_and_sample(model_name, n_samples=, mode=)``.
            n_samples: number of samples; ``None`` falls back to
                ``self.nn_config['training']['sinthetic_samples_by_class']``.

        Returns:
            ``(samples, True)`` on success.

        Raises:
            RuntimeError: wrapping any exception raised while sampling.
        """
        # Read the instance configuration (not a module-level global) so the file given
        # to the constructor is the one actually honoured.
        if n_samples is None:
            n_samples = self.nn_config['training']['sinthetic_samples_by_class']
        try:
            samples = sampler.modify_and_sample(model_name, n_samples=n_samples,
                                                mode=self.nn_config['sampling']['mode'])
        except Exception as e:
            raise RuntimeError(f"Failed to load samples from model {model_name}. Error: {str(e)}") from e
        return samples, True

    def create_time_sequences(self, lb, period, based_on_real_lc = True):
        """Build the time grid handed to the decoder.

        With ``based_on_real_lc`` the times and original sequences come from
        ``utils.get_only_time_sequence``; otherwise a regular grid of 600 points
        (``i/600``) is tiled ``n_samples * number_of_classes`` times and no original
        sequences are returned.

        Returns:
            ``(times, original_sequences)`` where ``times`` is a float32 tensor on
            ``self.device``.
        """
        np.set_printoptions(suppress=True)
        print(lb)
        print(period)
        if based_on_real_lc:
            times, original_sequences =  utils.get_only_time_sequence(n=1, star_class=lb, 
                                                                     period = period, factor1=0.8, 
                                                                     factor2= 1.2)
            times = np.array(times) 
            original_sequences = np.array(original_sequences) 
        else: 
            n_rows = self.n_samples*len(list(self.nn_config['data']['classes']))
            times = np.tile([i/600 for i in range(600)], (n_rows, 1))
            original_sequences = None #TODO

        times = torch.from_numpy(times).to(self.device).to(dtype=torch.float32)
        return times, original_sequences

    def create_synthetic_batch(self, plot_example=False, b=1.0, wandb_active=False, samples_dict = None):
        """Sample, decode and post-process synthetic light curves for every class.

        Args:
            plot_example: when true, show scatter plots of the first curve at several stages.
            b: forwarded to ``mgmm.ModifiedGaussianSampler``.
            wandb_active: forwarded to ``utils.plot_wall_lcs_sampling``.
            samples_dict: optional mapping ``class -> number of samples``; when ``None``
                every class gets ``self.n_samples`` samples.

        Returns:
            A shuffled ``DataLoader`` over the arrays written by
            ``utils.save_arrays_to_folder`` and reloaded from ``PATH_DATA_FOLDER``.

        Raises:
            ValueError: if no class is configured, if a sampler returns ``None``, or if
                NaN values remain after post-processing.
        """
        print(self.path)
        PATH_MODELS = self.path['PATH_MODELS']
        PATH_DATA = self.path['PATH_DATA_FOLDER']

        # SECURITY: pickle.load can execute arbitrary code; the encoder file must be trusted.
        with open(PATH_MODELS+'label_encoder_vae.pkl', 'rb') as f:
            label_encoder = pickle.load(f)

        lb = []
        per_class_samples = []
        for star_class in list(self.nn_config['data']['classes']):
            torch.cuda.empty_cache()
            print('------- sampling ' +star_class+'---------')

            if samples_dict is None:
                n_samples = self.n_samples
            else:
                n_samples = int(samples_dict[star_class])
            lb += [star_class] * n_samples

            print(samples_dict)

            components = self.count_subclasses(self.mean_prior_dict['StarTypes'][star_class])

            print(star_class +' includes '+ str(components) +' components ')

            sampler = mgmm.ModifiedGaussianSampler(b=b,
                                                   components=components,
                                                   features=self.PP)
            model_name = self.construct_model_name(star_class, PATH_MODELS)
            samples, _ = self.attempt_sample_load(model_name, sampler, n_samples=n_samples)

            # NOTE(review): a "retry without priors" step used to follow here, but it rebuilt
            # exactly the same model name, so it could never yield a different result. If a
            # real fallback is wanted it needs a name built with priors=False.
            if samples is None:
                raise ValueError(f"The model can't be loaded: {model_name}")

            if not per_class_samples:
                print(samples.shape)
            per_class_samples.append(samples)

        if not per_class_samples:
            raise ValueError("nn_config['data']['classes'] is empty; nothing to sample.")
        all_classes_samples = np.vstack(per_class_samples)

        # Encode the labels once, after the loop: only the encoding of the complete label
        # list was ever used downstream.
        integer_encoded = label_encoder.transform(lb)
        onehot = np.eye(len(label_encoder.classes_))[integer_encoded]

        encoded_labels, _ = utils.transform_to_consecutive(integer_encoded, label_encoder)
        onehot_to_train = np.eye(len(np.unique(encoded_labels)))[encoded_labels]

        print('cuda: ', torch.cuda.is_available())
        print('model: ', self.vae_model)

        columns = ['Period', 'teff_val', '[Fe/H]_J95', 'abs_Gmag', 'radius_val', 'logg']
        index_period = columns.index('Period')
        mu_ = reg.process_regressors(self.config_file, phys2=columns, samples= all_classes_samples, 
                                            from_vae=False, train_rf=False)

        # Directly convert to tensors and move to the target device
        mu_ = torch.tensor(mu_, device=self.device)
        onehot = torch.tensor(onehot, device=self.device)
        lb = np.array(lb)  
        pp = torch.tensor(all_classes_samples, device=self.device)

        vae, _ = utils.load_model_list(ID=self.vae_model, device=self.device)
        times, original_sequences = self.create_time_sequences(lb, all_classes_samples[:,index_period])

        torch.cuda.empty_cache()
        # The decoded curves are detached and converted to NumPy right away, so building
        # an autograd graph for them would only cost memory.
        with torch.no_grad():
            xhat_mu = self.process_in_batches(vae, mu_, times, onehot, pp, 1)
        xhat_mu = torch.cat([times.unsqueeze(-1), xhat_mu], dim=-1).cpu().detach().numpy()

        #TODO: filter here light curves

        # Preview at most 24 curves; sampling without replacement fails if asked for more
        # items than exist (e.g. a small samples_dict).
        n_preview = min(24, xhat_mu.shape[0])
        indices = np.random.choice(xhat_mu.shape[0], n_preview, replace=False)
        sampled_arrays = xhat_mu[indices, :, :]

        utils.plot_wall_lcs_sampling(sampled_arrays, sampled_arrays,  cls=lb[indices],  column_to_sensivity=index_period,
                                to_title = pp[indices], sensivity = 'Period', all_columns=columns, save=False, wandb_active=wandb_active) 

        lc_reverted = utils.revert_light_curve(pp[:,index_period], xhat_mu, original_sequences, classes = lb)

        if plot_example:
            plt.figure()
            plt.scatter(lc_reverted[0][1], lc_reverted[0][0])
            plt.show()

        # Replace NaNs with the mean of the remaining values.
        mean_value = np.nanmean(lc_reverted)
        lc_reverted[np.isnan(lc_reverted)] = mean_value
        oversampling = True

        if oversampling: 
            sampler = LightCurveRandomSampler(lc_reverted, onehot_to_train, self.seq_length, 12)
            lc_reverted, onehot_to_train = sampler.sample()
        else:
            lc_reverted = lc_reverted[:, :, :self.seq_length]

        # Work on first differences along the last axis.
        lc_reverted = np.diff(lc_reverted, axis=-1)
        mean_value = np.nanmean(lc_reverted)
        # NOTE(review): only values above +delta_max are replaced; large negative deltas
        # are kept as they are - confirm this is intended.
        lc_reverted[(lc_reverted)> self.delta_max] = mean_value

        if plot_example:
            plt.figure()
            plt.scatter(lc_reverted[0][1], lc_reverted[0][0])
            plt.show()

        nan_count = np.sum(np.isnan(lc_reverted))
        if nan_count > 0:
            print(f"Number of NaN values detected: {nan_count}")
            raise ValueError("NaN values detected in lc_reverted array")

        utils.save_arrays_to_folder(lc_reverted, onehot_to_train , PATH_DATA)

        numpy_array_x = np.load(PATH_DATA+'/x_batch_pelsvae.npy', allow_pickle=True)
        numpy_array_y = np.load(PATH_DATA+'/y_batch_pelsvae.npy', allow_pickle=True)

        self.x_array = numpy_array_x
        self.y_array = numpy_array_y 
        
        if plot_example:
            plt.figure()
            plt.scatter(numpy_array_x[0][0], numpy_array_x[0][1])
            plt.show()

        synth_data = utils.move_data_to_device((numpy_array_x, numpy_array_y), self.device)
        synthetic_dataset = TensorDataset(*synth_data)
        synthetic_dataloader = DataLoader(synthetic_dataset, batch_size=self.batch_size, shuffle=True)

        return synthetic_dataloader

In [102]:
# Build the batcher from the run parameters chosen earlier in the notebook.
batcher = SyntheticDataBatcher(
    PP=PP,
    vae_model=vae_model,
    n_samples=sinthetic_samples_by_class,
    seq_length=300,
    prior=priors,
)

In [None]:
# Generate the synthetic batch; samples_dict=None means every class gets batcher.n_samples samples.
synthetic_data_loader = batcher.create_synthetic_batch(
    b=b,
    wandb_active=wandb_active,
    samples_dict=None,
)

{'PATH_PRIOS': 'src/gmm/priors.yaml', 'PATH_PP': '/home/franciscoperez/Documents/GitHub/CNN-PELSVAE2/cnn-pels-vae/data/metadata_updated_0823.csv', 'PATH_LIGHT_CURVES_OGLE': '/home/franciscoperez/Desktop/Code/FATS/LCsOGLE/data/', 'PATH_FEATURES_TRAIN': '/home/franciscoperez/Documents/GitHub/data/BIASEDFATS/Train_rrlyr-1.csv', 'PATH_FEATURES_TEST': '/home/franciscoperez/Documents/GitHub/data/BIASEDFATS/Test_rrlyr-1.csv', 'PATH_NUMPY_DATA_X_TRAIN': 'data/train_np_array.npy', 'PATH_NUMPY_DATA_X_TEST': 'data/test_np_array.npy', 'PATH_NUMPY_DATA_Y_TRAIN': 'data/train_np_array_y.npy', 'PATH_NUMPY_DATA_Y_TEST': 'data/test_np_array_y.npy', 'PATH_SUBCLASSES': '/home/franciscoperez/Documents/GitHub/vsbms_multiple_classes/bayesianMLP/src/data/all_subclasses', 'PATH_DATA_FOLDER': '/home/franciscoperez/Documents/GitHub/CNN-PELSVAE2/cnn-pels-vae/data', 'PATH_COLAB_ROOT': '/content/drive/My Drive/Colab_Notebooks/data', 'PATH_EXALEARN_ROOT': '/home/user/data', 'PATH_MODELS': '/home/franciscoperez/Docum