In [1]:
import numpy as np
import os
import torch 
import yaml 

new_directory = '/home/franciscoperez/Documents/GitHub/CNN-PELSVAE2/cnn-pels-vae/'
os.chdir(new_directory)

from src.utils import get_time_sequence, get_time_from_period

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
                      
with open('src/configuration/nn_config.yaml', 'r') as file:
    nn_config = yaml.safe_load(file)

In [3]:
def revert_light_curve(period, folded_normed_light_curve, original_sequences, faintness=1.0, classes = None):
    """
    Revert previously folded and normed light curves back to the original light curves.

    Parameters:
        period (float): The period of the variability in the light curves.
        folded_normed_light_curve (numpy.ndarray): A 3D array representing the folded and normed light curves.
        faintness (float, optional): A scaling factor to adjust the magnitude values of the reverted light curves.
                                     Defaults to 1.0, meaning no scaling.

    Returns:
        list: A list of 2D arrays representing the reverted real light curves with time and magnitude values.
    """
    num_sequences = folded_normed_light_curve.shape[0]
    print('num_sequences: ', num_sequences)
    reverted_light_curves = []
    time_sequences = get_time_sequence(n=1, star_class=classes)
    print('time_sequences: ', time_sequences)
    for i in range(num_sequences):
        # Extract the time (period) and magnitude values from the folded and normed light curve
        time = original_sequences[i] #folded_normed_light_curve[i,:,0]
        print('time: ', time)
        if np.max(folded_normed_light_curve[i,:,0])<0.95: 
            continue

        normed_magnitudes_min = np.min(folded_normed_light_curve[i,:,1])
        normed_magnitudes_max = np.max(folded_normed_light_curve[i,:,1])
        #To ensure maximum 1 and minimum 0
        normed_magnitudes = ((folded_normed_light_curve[i,:,1]-normed_magnitudes_min)/
                            (normed_magnitudes_max-normed_magnitudes_min))
        
        print('normed_magnitudes: ', normed_magnitudes)

        # Generate the time values for the reverted light curve
        [original_min, original_max] = time_sequences[i]
        
        print('original_min, original_max: ', original_min, original_max)

        #real_time =  time #get_time_from_period(period[i], time, example_sequence, sequence_length=600)

        real_time =  get_time_from_period(period[i],  folded_normed_light_curve[i,:,0], time, sequence_length=600)

        print('real_time: ', real_time)
        # Revert the normed magnitudes back to the original magnitudes using min-max scaling and faintness factor
        original_magnitudes = ((normed_magnitudes * (original_max - original_min)) + original_min) * faintness
        
        print('original_magnitudes: ', original_magnitudes)
        
        if isinstance(real_time, torch.Tensor):
            if real_time.is_cuda:
                real_time = real_time.cpu().numpy()
            else:
                real_time = real_time.numpy()

        # Convert original_magnitudes to NumPy array if it's a PyTorch tensor
        if isinstance(original_magnitudes, torch.Tensor):
            if original_magnitudes.is_cuda:
                original_magnitudes = original_magnitudes.cpu().numpy()
            else:
                original_magnitudes = original_magnitudes.numpy()
        # No need to convert if original_magnitudes is already a NumPy array

        # Now, you can use np.column_stack without issues
        reverted_light_curve = np.column_stack((original_magnitudes, real_time))

        reverted_light_curves.append(reverted_light_curve)
    
    reverted_light_curves = np.stack(reverted_light_curves)

    reverted_light_curves = np.swapaxes(reverted_light_curves, 1, 2)
    # Generate random unique indices along the last dimension
    random_indices = np.random.choice(350, nn_config['data']['seq_length']+1, replace=False)

    # Sort the indices for easier interpretation and debugging (optional)
    random_indices = random_indices.sort()

    # Select 200 random observations
    reverted_light_curves_random = reverted_light_curves[:, :, random_indices]

    reverted_light_curves_random = reverted_light_curves_random.squeeze(2) 

    for i in range(reverted_light_curves_random.shape[0]):
        sort_indices = np.argsort(reverted_light_curves_random[i, 1, :])
        for j in range(reverted_light_curves_random.shape[1]):
            reverted_light_curves_random[i, j, :] = reverted_light_curves_random[i, j, sort_indices]

    for i in range(reverted_light_curves_random.shape[0]):
        reverted_light_curves_random[i] = np.flipud(reverted_light_curves_random[i])

    print('Shape of reverted_light_curves_random[i, :, :]:', reverted_light_curves_random[i, :, :].shape)
    print('Shape of sort_indices:', sort_indices.shape)

    return reverted_light_curves_random

In [4]:
folded_normed_light_curve = np.random.rand(1, 600, 2)*10
period =  torch.from_numpy(np.asarray([1.0])).to(device) 
original_sequences = [[1.0, 30.0]]
classes = ['RRLYR']

In [5]:
revert_light_curve(period, folded_normed_light_curve, original_sequences, faintness=1.0, classes = classes)

num_sequences:  1


  lc_train['NUMBER'] = lc_train['NUMBER'].str.replace('.dat', '')
Getting time sequences: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.10it/s]

time_sequences:  [[18.251, 19.162]]
time:  [1.0, 30.0]
normed_magnitudes:  [5.60378482e-02 3.52190064e-01 4.35820342e-01 4.32411732e-01
 6.11407468e-01 8.26671677e-01 9.62115055e-01 7.93354642e-01
 7.40778967e-01 3.78368316e-01 9.34879202e-01 4.01317874e-01
 4.31374766e-01 6.09764909e-01 8.62446401e-01 4.00295605e-01
 5.56226647e-01 2.73382840e-01 8.24668649e-01 5.31192273e-01
 2.44875482e-01 2.40518863e-01 2.97799904e-01 8.26435358e-02
 6.69545673e-01 3.99551502e-01 1.09843714e-01 5.66102763e-01
 9.80205000e-01 5.13422337e-01 7.28379075e-01 1.94508853e-01
 7.55085957e-01 2.69847686e-01 9.83789370e-01 3.71846850e-01
 1.59803664e-01 2.49264398e-01 4.33402352e-01 3.19471829e-02
 6.36355244e-01 5.23919545e-01 6.03490293e-01 7.92391745e-01
 2.21130257e-02 9.59904939e-01 8.25638624e-01 2.33015313e-01
 3.66214525e-01 8.09112218e-01 8.62938566e-01 2.69903048e-01
 3.60248075e-01 7.48181384e-01 6.12128361e-01 5.14008342e-01
 6.37635266e-01 3.41957357e-01 5.20612104e-01 6.08258320e-01
 5.3172642




array([[[ 2.94907045,  3.34225178,  3.5141747 , ..., 38.57460022,
         38.68961334, 38.87120438],
        [18.85438276, 18.36449261, 19.04206773, ..., 18.37673247,
         18.7714721 , 18.96429197]]])