In [None]:
import numpy as np
import pandas as pd
import scipy
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from tqdm.auto import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [2]:
# Default settings.
# NOTE(review): the helper functions defined further down repeat 300, 100 and 0.01
# as literals instead of reading these constants; keep both in sync.
SIDEREAL_SCALE = 86400. / 86164.0905 # sidereal days per solar day (86400 s / 86164.0905 s)
TIME_WINDOW = 300  # number of one-day bins in the time grid
TIME_PAD = 100     # extra bins kept on each side of the window
ERROR_FLOOR = 0.01 # added in quadrature to the measurement error when building weights

In [3]:
def mag_to_flux(light_curves):
    """Convert magnitude columns to flux columns.

    The flux is ``10 ** (-0.4 * (magpsf + 48.60))``. The flux uncertainty is
    obtained by first-order error propagation of the magnitude uncertainty,
    ``fluxerr = flux * ln(10) / 2.5 * sigmapsf``; a magnitude error is a
    relative quantity and must not be pushed through the magnitude-to-flux
    formula as if it were a magnitude itself.

    Parameters
    ----------
    light_curves : pandas.DataFrame
        Must contain the columns ``magpsf`` and ``sigmapsf``.

    Returns
    -------
    pandas.DataFrame
        A new frame where ``magpsf``/``sigmapsf`` are replaced by
        ``flux``/``fluxerr`` (appended as the last two columns). The input
        frame is left untouched.
    """
    # ln(10) / 2.5 turns a magnitude uncertainty into a relative flux uncertainty.
    mag_err_to_rel_flux_err = np.log(10.0) / 2.5

    # The conversion is element-wise, so no per-object grouping is required.
    flux = 10 ** (-0.4 * (light_curves['magpsf'] + 48.60))
    fluxerr = flux * mag_err_to_rel_flux_err * light_curves['sigmapsf']

    return (
        light_curves
        .drop(columns=['magpsf', 'sigmapsf'])
        .assign(flux=flux, fluxerr=fluxerr)
    )

In [4]:
def crossmatch_object_alerce(alerce_lc: pd.DataFrame, object: pd.DataFrame) -> pd.DataFrame:
    """Inner-join the light-curve table with the object table on ``oid``.

    Only rows whose ``oid`` appears in both frames are kept.
    """
    return alerce_lc.merge(object, on='oid')

def process_light_curve_parsnip(ligth_curve, time_window=300, time_pad=100):
    """Align one light curve on a one-day sidereal grid centred on its best data.

    Steps:
      1. convert ``mjd`` to sidereal days and estimate the typical fractional
         observing phase, so that observations fall near integer grid values;
      2. take the five observations with the highest signal-to-noise ratio and
         use the (rounded) median of their times as the reference time;
      3. add ``grid_time`` (sidereal days from the reference) and
         ``time_index`` (integer bin, reference at ``time_window // 2``);
      4. drop observations outside ``[-time_pad, time_window + time_pad)``.

    Parameters
    ----------
    ligth_curve : pandas.DataFrame
        One object's observations. Needs ``mjd`` and either
        ``flux``/``fluxerr`` or ``sigmapsf``.
    time_window : int, optional
        Number of one-day bins in the grid (default 300).
    time_pad : int, optional
        Bins kept on each side of the window (default 100).

    Returns
    -------
    pandas.DataFrame
        Copy of the input with ``grid_time`` and ``time_index`` columns,
        restricted to the padded window.
    """
    SIDEREAL_SCALE = 86400. / 86164.0905

    new_light_curve = ligth_curve.copy()
    time = ligth_curve['mjd'].to_numpy(dtype=float)

    # Nothing to align: return the (empty) copy with the expected columns.
    if time.size == 0:
        new_light_curve['grid_time'] = np.array([], dtype=float)
        new_light_curve['time_index'] = np.array([], dtype=int)
        return new_light_curve

    sidereal_time = time * SIDEREAL_SCALE

    # Initial guess of the phase. Round everything to 0.1 days, and find the decimal
    # that has the largest count.
    mode, count = scipy.stats.mode(np.round(sidereal_time % 1 + 0.05, 1), keepdims=True)
    guess_offset = mode[0] - 0.05

    # Shift everything by the guessed offset
    guess_shift_time = sidereal_time - guess_offset

    # Do a proper estimate of the offset.
    sidereal_offset = guess_offset + np.median((guess_shift_time + 0.5) % 1) - 0.5

    # Shift everything by the final offset estimate.
    shift_time = sidereal_time - sidereal_offset

    # Signal-to-noise ratio of every observation. With flux columns it is
    # flux / fluxerr; with magnitudes only, a magnitude error sigma corresponds
    # to a relative flux error of ln(10)/2.5 * sigma, i.e. S/N = 1.0857 / sigma.
    # (Dividing the magnitude by its error is not a signal-to-noise ratio.)
    with np.errstate(divide='ignore', invalid='ignore'):
        if {'flux', 'fluxerr'}.issubset(ligth_curve.columns):
            s2n = (ligth_curve['flux'].to_numpy(dtype=float)
                   / ligth_curve['fluxerr'].to_numpy(dtype=float))
        else:
            s2n = (2.5 / np.log(10.0)) / ligth_curve['sigmapsf'].to_numpy(dtype=float)
    # NaN/inf values would otherwise sort last and be picked as the "best" points.
    s2n = np.where(np.isfinite(s2n), s2n, -np.inf)

    # Positions of the (up to) five highest-S/N observations.
    s2n_mask = np.argsort(s2n, kind='stable')[-5:]
    cut_times = shift_time[s2n_mask]

    max_time = np.round(np.median(cut_times))

    # Convert back to a reference time in the original units. This reference time
    # corresponds to the reference of the grid in sidereal time.
    reference_time = ((max_time + sidereal_offset) / SIDEREAL_SCALE)
    grid_times = (time - reference_time) * SIDEREAL_SCALE
    time_indices = np.round(grid_times).astype(int) + time_window // 2
    time_mask = (
        (time_indices >= -time_pad)
        & (time_indices < time_window + time_pad)
    )
    new_light_curve['grid_time'] = grid_times
    new_light_curve['time_index'] = time_indices
    new_light_curve = new_light_curve[time_mask]

    return new_light_curve

def _get_data(light_curves):
    device = 'cpu'  
    redshifts = []
    compare_data = []
    
    error_floor = 0.01
    # Build a grid for the input
    # The first grid is created for saved the data
    # The second grid is created for save the weights that will be used
    # on the loss_function
    try:
        len_light_curves = len(light_curves.oid.unique())
    except:
        len_light_curves = len(light_curves)
    #print(len_light_curves)
    grid_flux    = np.zeros((len_light_curves,1,300))
    grid_weights = np.zeros_like(grid_flux) 

    # Iterate over each unique object ID (oid)
    try:
        for idx, (oid, light_curve) in enumerate(light_curves.groupby('oid')):
            redshifts.append(0.01)
            print(light_curve)
            # Mask observations outside the window
            mask = (light_curve['time_index'] >= 0) & (light_curve['time_index'] < 300)
            light_curve = light_curve[mask]

            # Calculate weights
            weights = 1 / (light_curve['sigmapsf']**2 + error_floor**2)

            # Fill in the input arrays
            grid_flux[idx, 0, light_curve['time_index']] = light_curve['magpsf']
            grid_weights[idx, 0, light_curve['time_index']] = error_floor**2 * weights


            obj_compare_data = torch.FloatTensor(np.vstack([
                light_curve['grid_time'],
                light_curve['magpsf'],
                light_curve['sigmapsf'],
                weights,
            ]))
            compare_data.append(obj_compare_data.T)
    except:
        print('No paso con oids')
        pass
    try:
        for idx, light_curve in enumerate(light_curves):
                redshifts.append(0.01)
                #print(light_curve)
                light_curve = pd.DataFrame(light_curve[list(light_curve.keys())[0]])
                #print(light_curve)

                # Mask observations outside the window
                mask = (light_curve['time_index'] >= 0) & (light_curve['time_index'] < 300)
                light_curve = light_curve[mask]

                # Calculate weights
                weights = 1 / (light_curve['sigmapsf']**2 + error_floor**2)

                # Fill in the input arrays
                grid_flux[idx, 0, light_curve['time_index']] = light_curve['magpsf']
                grid_weights[idx, 0, light_curve['time_index']] = error_floor**2 * weights


                obj_compare_data = torch.FloatTensor(np.vstack([
                    light_curve['grid_time'],
                    light_curve['magpsf'],
                    light_curve['sigmapsf'],
                    weights,
                ]))
                compare_data.append(obj_compare_data.T)
    except:
        pass
        
    redshifts = np.array(redshifts)
    extra_input_data = [redshifts]

    input_data = np.concatenate(
            [i[:, None, None].repeat(300, axis=2) for i in extra_input_data]
            + [grid_flux, grid_weights],
            axis=1
        )
    
    input_data = torch.FloatTensor(input_data).to(device)
    redshifts = torch.FloatTensor(redshifts).to(device)

    # Pad all of the compare data to have the same shape.
    compare_data = nn.utils.rnn.pad_sequence(compare_data, batch_first=True)
    compare_data = compare_data.permute(0, 2, 1)
    compare_data = compare_data.to(device)

    data = {
        'input_data': input_data,
        'compare_data': compare_data,
        'redshift': redshifts,
    }

    return data

In [5]:
def plot_light_curve(light_curve, oid: object = None):
    """Scatter-plot one light curve against MJD.

    Magnitudes are plotted (with an inverted y axis) when the frame has a
    ``magpsf`` column; otherwise the ``flux`` column is plotted.

    Parameters
    ----------
    light_curve : pandas.DataFrame
        Needs ``mjd`` and either ``magpsf`` or ``flux``.
    oid : optional
        Object identifier shown in the title when given.

    Raises
    ------
    KeyError
        If neither ``magpsf`` nor ``flux`` is present.
    """
    # Decide what to plot from the columns instead of catching arbitrary
    # exceptions, so unrelated plotting errors are not silently masked.
    if 'magpsf' in light_curve:
        column, ylabel, is_magnitude = 'magpsf', 'Apparent magnitude', True
    elif 'flux' in light_curve:
        column, ylabel, is_magnitude = 'flux', r'Flux erg s−1 cm−2 Hz−1', False
    else:
        raise KeyError("light_curve must contain a 'magpsf' or a 'flux' column")

    time = light_curve['mjd'].to_numpy()

    fig, ax = plt.subplots()
    ax.plot(time, light_curve[column].to_numpy(), 'o')
    if is_magnitude:
        # Brighter objects have smaller magnitudes: flip the axis.
        ax.set_ylim(ax.get_ylim()[::-1])
    ax.set_ylabel(ylabel)
    ax.set_xlabel('MJD')

    if oid is not None:
        ax.set_title(f'oid: {oid}')

# Real Data

In [None]:
# Load the object table (one row per object, keyed by `oid`).
# NOTE(review): unpickling executes arbitrary code; only load pickles from a trusted source.
object_table = pd.read_pickle('~/Supernovae_DeepLearning/object_ZTF_ALeRCE_19052024.pkl')
print(object_table)
print('\nNumber of Different Objects in object_table:', len(object_table.oid.unique()))

In [None]:
# Load the light-curve detections (many rows per `oid`).
# NOTE(review): this is an absolute, user-specific path while the object-table cell
# uses a '~'-relative one; consider a single configurable data directory.
lightcurves_alercextns = pd.read_pickle('/home/jurados/Supernovae_DeepLearning/data/lightcurves/lcs_transients_20240517.pkl')
print(lightcurves_alercextns)
print('\nNumber of Different Objects in lightcurves_alercextns:', len(lightcurves_alercextns.oid.unique()))

In [None]:
# Cross-match the light curves with the object table: keep only the
# detections whose `oid` is present in both tables.
lightcurves = crossmatch_object_alerce(lightcurves_alercextns, object_table)
print(lightcurves)
print('\nNumber of Different Objects in lightcurves (Crossmatch):', len(lightcurves.oid.unique()))

In [None]:
# Keep only the detections with fid == 1 and count the remaining objects.
# NOTE(review): presumably `fid` is the filter/band id, so this selects one band — confirm.
# This cell overwrites `lightcurves`; re-run the cross-match cell to recover all rows.
lightcurves = lightcurves[lightcurves.fid == 1]
lightcurves.oid.unique().size

In [None]:
print('The oid of transients with maximum length is:', lightcurves.groupby(by='oid')['mjd'].size().idxmax())
print('The maximum length of observation is:', lightcurves.groupby(by='oid')['mjd'].size().max())

In [11]:
sn_unique = lightcurves.oid.unique()
# Shuffle the object ids once and cut the shuffled array in two, so the train
# (80 %) and test (remaining ~20 %) sets are disjoint. Drawing the two sets
# independently would let the same object land in both (data leakage).
shuffled_oids = np.random.permutation(sn_unique)
n_train_objects = int(0.8 * len(sn_unique))
selected_unique_values = shuffled_oids[:n_train_objects]
selected_unique_values_test = shuffled_oids[n_train_objects:]

In [12]:
train_data = lightcurves[lightcurves.oid.isin(selected_unique_values)]
test_data = lightcurves[lightcurves.oid.isin(selected_unique_values_test)]

In [None]:
train_data

In [None]:
print('The oid of transients with maximum length is:', train_data.groupby(by='oid')['mjd'].size().idxmax())
print('The maximum length of observation is:', train_data.groupby(by='oid')['mjd'].size().max())

In [None]:
test = train_data.copy()
test = test[['oid','mjd','magpsf']]
test

In [None]:
one_light_curve = train_data[train_data.oid == train_data.oid.unique()[3]]
one_light_curve

In [None]:
plot_light_curve(one_light_curve, one_light_curve.oid.unique()[0])

# Synthetic Model

The following synthetic model is based on the paper by [Olivares et al. 2010](https://ui.adsabs.harvard.edu/abs/2010ApJ...715..833O/abstract).

This synthetic model creates light curves as the sum of the following three functions:

$$ f_{\text{FD}}(t) = \frac{-a_0}{1+\exp\left( (t - t_{\text{PT}}) / w_0 \right)} $$
$$ l(t) = p_0 (t-t_{\text{PT}}) + m_0 $$
$$ g(t) = -P \exp\left[-\left(\frac{t-Q}{R}\right)^{2}\right] $$

In [18]:
class SyntheticLigthCurve:
    """Analytic light curve: Fermi-Dirac-like step + linear term + Gaussian."""

    def __init__(self,tpt,a0,w0,p0,m0,P,Q,R):
        # Step (transition) term
        self.tpt = tpt  # middle of the transition
        self.a0 = a0    # height of the step
        self.w0 = w0    # width of the transition phase
        # Linear term
        self.p0 = p0    # slope of the radioactive tail
        self.m0 = m0    # zero point in magnitude
        # Gaussian term
        self.P = P      # height of the Gaussian peak
        self.Q = Q      # center of the Gaussian
        self.R = R      # width of the Gaussian

    def olivares(self,t):
        """Evaluate the model at time(s) `t` (scalar or NumPy array)."""
        dt = t - self.tpt
        step_term = -self.a0 / (1 + np.exp(dt / self.w0))
        linear_term = self.p0 * dt + self.m0
        gaussian_term = -self.P * np.exp(-((t - self.Q) / self.R)**2)
        return step_term + linear_term + gaussian_term

In [None]:
# Parameters of two example curves, ordered as (tpt, a0, w0, p0, m0, P, Q, R).
po_v = [95,1.744,3.602,0.008,14.482,1.675,20.148,-15.984]
po_r = [88.948, 1.584, 4.485, 0.005, 13.759, 1.528, 101.390, -17.934]
synthetic_lc_v = SyntheticLigthCurve(*po_v)
synthetic_lc_r = SyntheticLigthCurve(*po_r)

time = np.linspace(50,150,300)

# Draw both models on the same axes; magnitudes grow downwards.
fig, ax = plt.subplots()
for model_curve in (synthetic_lc_v, synthetic_lc_r):
    ax.plot(time, model_curve.olivares(time))
ax.invert_yaxis()
ax.set_xlabel('Time [days]')
ax.set_ylabel('Apparent Magnitude')

In [20]:
# Base parameters of the synthetic model, ordered as (tpt, a0, w0, p0, m0, P, Q, R).
p0 = [91.026,1.744,3.602,0.008,14.482,1.675,102.148,-15.984] # tpt, a0, w0, p0, m0, P, Q, R

N = 10000          # number of synthetic light curves
time_length = 300 # days  (NOTE(review): not used in this cell)

# NOTE(review): `oid` stays an empty list in this cell; the object ids are built
# as `code` in the second loop below.
oid = []
mjd = []
mag = []

# Generate N curves. A single uniform draw in [0, 1) per curve perturbs the base
# parameters: tpt and a0 by the full draw, the others (except p0) by 1% of it.
# NOTE(review): the global NumPy RNG is used without a seed, so every run differs.
for n in range(N):
    increase_random = np.random.random()
    time = np.linspace(50, 150, 50)
    po = [p0[0] + increase_random, p0[1] + increase_random, p0[2] + 0.01*increase_random, p0[3], p0[4] + 0.01*increase_random,
          p0[5] + 0.01*increase_random, p0[6] + 0.01*increase_random, p0[7] + 0.01*increase_random]
    #print(po)
    synthetic_lc = SyntheticLigthCurve(po[0],po[1],po[2],po[3],po[4],po[5],po[6],po[7])
    f = synthetic_lc.olivares(time)
    mjd.append(np.array(time))
    mag.append(np.array(f))

# Row lists (one row per observation) for the two splits
train_data_synthetic = []
test_data_synthetic  = []

# Iterate over every (times, magnitudes) pair
for i in range(len(mjd)):
    time_set = mjd[i]
    brightness_set = mag[i]
    code = 'ZTF' + str(i + 1)  # object id: 'ZTF1' ... 'ZTF<N>'
    for time, brightness in zip(time_set, brightness_set):
        # The first 8000 curves go to the training set, the rest to the test set.
        if i+1 <= 8_000:
          train_data_synthetic.append([code, time, brightness])
        else:
          test_data_synthetic.append([code, time, brightness])

# Build the DataFrames and attach a constant magnitude error
train_data_synthetic = pd.DataFrame(train_data_synthetic, columns=['oid','mjd', 'magpsf'])
test_data_synthetic = pd.DataFrame(test_data_synthetic, columns=['oid','mjd', 'magpsf'])
train_data_synthetic['sigmapsf'] = 0.001
test_data_synthetic['sigmapsf'] = 0.001

In [None]:
oid_idx = 0
plot_light_curve(train_data_synthetic[train_data_synthetic.oid == train_data_synthetic.oid.unique()[oid_idx]], oid=train_data_synthetic.oid.unique()[oid_idx])

In [None]:
train_data_synthetic = mag_to_flux(train_data_synthetic)
test_data_synthetic  = mag_to_flux(test_data_synthetic)

In [None]:
oid_idx = 0
plot_light_curve(train_data_synthetic[train_data_synthetic.oid == train_data_synthetic.oid.unique()[oid_idx]], oid=train_data_synthetic.oid.unique()[oid_idx])

In [None]:
print(train_data_synthetic.head())
print('\nShape train_data_synthethic:', train_data_synthetic.shape)
print('Length train_data_synthethic:', len(train_data_synthetic))
print()
print(test_data_synthetic.head())
print('\nShape test_data_synthethic:', test_data_synthetic.shape)
print('Length test_data_synthethic:', len(test_data_synthetic))

# Processing the Data

In [23]:
def new_light_curves(light_curves):
    """Process every object with `process_light_curve_parsnip` and stack the results.

    Parameters
    ----------
    light_curves : pandas.DataFrame
        Observations of many objects, identified by the ``oid`` column.

    Returns
    -------
    pandas.DataFrame
        All processed light curves concatenated in ``oid`` group order, i.e.
        the input rows that survive processing plus the ``grid_time`` and
        ``time_index`` columns. An empty frame is returned for empty input.
    """
    # Collect the pieces first and concatenate once: growing a DataFrame with
    # `pd.concat` inside the loop copies all previous rows at every iteration.
    processed = [
        process_light_curve_parsnip(group)
        for _, group in light_curves.groupby('oid')
    ]

    if not processed:
        return pd.DataFrame()

    return pd.concat(processed)


train_data_synthetic = new_light_curves(train_data_synthetic)
test_data_synthetic  = new_light_curves(test_data_synthetic)

In [None]:
print(train_data_synthetic.head())
print('\nShape train_data_synthethic:', train_data_synthetic.shape)
print('Length train_data_synthethic:', len(train_data_synthetic))
print()
print(test_data_synthetic.head())
print('\nShape test_data_synthethic:', test_data_synthetic.shape)
print('Length test_data_synthethic:', len(test_data_synthetic))

In [25]:
#matrix_train = _get_data(train_data_synthetic)
#matrix_test = _get_data(test_data_synthetic)

In [26]:
#matrix_train['compare_data'][:,0]

In [None]:
class LightCurveDataset():
    """Map-style dataset yielding one light curve (as a dict of tensors) per ``oid``.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Needs the columns ``oid``, ``mjd``, ``flux``, ``fluxerr``,
        ``grid_time`` and ``time_index``.
    target_device : optional
        Device the tensors are moved to. When omitted, the module-level
        ``device`` is used.
    """

    def __init__(self, dataframe, target_device=None):
        self.dataframe = dataframe
        # Group by 'oid' and keep the groups
        self.groups = dataframe.groupby('oid')
        # Materialise the groups once; rebuilding `list(self.groups)` on every
        # `__getitem__` call makes each lookup linear in the number of objects.
        self._light_curves = [group for _, group in self.groups]
        self.device = device if target_device is None else target_device

    def __len__(self):
        return len(self._light_curves)

    def __getitem__(self, idx):
        group = self._light_curves[idx]

        def column_tensor(name, dtype):
            return torch.tensor(group[name].to_numpy(), dtype=dtype).to(self.device)

        # `grid_time` is fractional and can be negative, and `time_index` can be
        # negative inside the padding region, so neither fits an unsigned
        # integer dtype. `time_index` is int64 so it can be used for indexing.
        light_curve_dict = {
          'mjd': column_tensor('mjd', torch.float32),
          'flux': column_tensor('flux', torch.float32),
          'fluxerr': column_tensor('fluxerr', torch.float32),
          'grid_time': column_tensor('grid_time', torch.float32),
          'time_index': column_tensor('time_index', torch.int64),
        }

        return light_curve_dict

# Instantiate the custom datasets
train_pandas_dataset = LightCurveDataset(train_data_synthetic)
test_pandas_dataset = LightCurveDataset(test_data_synthetic)

# Build the DataLoaders. `collate_fn=list` keeps every batch as a plain list of
# per-object dicts; the default collate function would instead try to stack the
# tensors of different light curves. The test loader uses the same collate
# function and is not shuffled, so evaluation order is reproducible.
train_loader = DataLoader(train_pandas_dataset, batch_size=64, collate_fn=list, shuffle=True)
test_loader = DataLoader(test_pandas_dataset, batch_size=64, collate_fn=list, shuffle=False)

In [None]:
for batch in train_loader:
    print(type(batch))  # Verifica el tipo de batch
    print(batch[0]['time_index'])
    print(_get_data(batch))
    break

In [None]:
class ModelV1(nn.Module):
    """Convolutional variational auto-encoder with one latent vector per time step.

    The output length of ``Conv1d`` is

        L_out = [L_in + 2 x padding - dilation x (kernel_size - 1) - 1] / stride + 1

    so each encoder layer (kernel 3, padding 2) adds 2 samples and each
    ``ConvTranspose1d`` decoder layer (kernel 3, padding 2) removes 2: three
    layers each way map a length ``L`` input to ``L + 6`` and back to ``L``.

    Parameters
    ----------
    input_size : int
        Number of input channels.
    latent_size : int
        Size of the latent vector.
    """

    def __init__(self, input_size: int, latent_size: int, **kwargs) -> None:
        super().__init__()

        # Encoder: [batch, input_size, L] -> [batch, 64, L + 6]
        self.encoder = nn.Sequential(
            nn.Conv1d(in_channels=input_size, out_channels=16, kernel_size=3, dilation=1, padding=2*1),
            nn.ReLU(),
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, dilation=1, padding=2*1),
            nn.ReLU(),
            nn.Conv1d(in_channels=32, out_channels=64, kernel_size=3, dilation=1, padding=2*1),
            nn.ReLU(),
        )

        # Linear layers for encoding mean and logvar (applied on the channel axis)
        self.encode_mean = nn.Linear(in_features=64, out_features=latent_size)
        self.encode_logvar = nn.Linear(in_features=64, out_features=latent_size)

        # Decoder input layer
        self.decoder_input = nn.Linear(in_features=latent_size, out_features=64)

        # Decoder with transposed convolutions: [batch, 64, L + 6] -> [batch, 3, L]
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(in_channels=64, out_channels=32, kernel_size=3, padding=2*1),
            nn.ReLU(),
            nn.ConvTranspose1d(in_channels=32, out_channels=16, kernel_size=3, padding=2*1),
            nn.ReLU(),
            nn.ConvTranspose1d(in_channels=16, out_channels=3, kernel_size=3, padding=2*1)
        )

    def encode(self, x):
        """Encode ``[batch, input_size, L]`` into ``(mu, log_var)``, each ``[batch, L + 6, latent_size]``."""
        # Run the convolutional stack first: the linear heads expect its 64
        # output channels, not the raw input channels.
        features = self.encoder(x)
        # nn.Linear acts on the last axis, so move the channels there.
        features = features.permute(0, 2, 1)
        mu, log_var = self.encode_mean(features), self.encode_logvar(features)
        return mu, log_var

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latents ``[batch, L + 6, latent_size]`` into ``[batch, 3, L]``."""
        result = self.decoder_input(z)
        # Back to channels-first, as required by ConvTranspose1d.
        result = result.permute(0, 2, 1)
        result = self.decoder(result)
        return result

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, 1)`` and ``std = exp(0.5 * logvar)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, input: torch.Tensor, **kwargs) -> list:
        """Return ``[reconstruction, input, mu, log_var]``."""
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, mu, log_var]

    def obtain_results(self, light_curves):
        """Grid the light curves with `_get_data`, encode them and collect the loss inputs."""
        data = _get_data(light_curves)

        # `_get_data` returns CPU tensors; move them next to the model weights.
        model_device = next(self.parameters()).device
        input_data = data['input_data'].to(model_device)
        compare_data = data['compare_data'].to(model_device)

        # Encode the light_curves
        encoding_mu, encoding_logvar = self.encode(input_data)

        # compare_data rows: grid_time, value, error, weight
        results = {
            'redshift': data['redshift'].to(model_device),
            'time': compare_data[:, 0],
            'obs_flux': compare_data[:, 1],
            'obs_fluxerr': compare_data[:, 2],
            'obs_weight': compare_data[:, 3],
            'encoding_mu': encoding_mu,
            'encoding_logvar': encoding_logvar,
        }

        return results

In [29]:
class ModelV0(nn.Module):
    """Fully connected variational auto-encoder applied independently to each time step.

    Parameters
    ----------
    input_size : int
        Number of input channels.
    hidden_dim : int
        Width of the hidden layer.
    latent_size : int
        Width of the last encoder layer (the bottleneck itself has size 2).
    device : torch.device or str
        Device the model is moved to.
    """

    def __init__(self, input_size, hidden_dim, latent_size, device):

        # nn.Module must be initialised before attributes are set on it.
        super().__init__()
        self.device = device

        # encoder
        self.encoder = nn.Sequential(
            nn.Linear(in_features=input_size, out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=latent_size),
            nn.ReLU()
            )

        # latent mean and log-variance (two-dimensional bottleneck)
        self.encode_mean_layer = nn.Linear(in_features=latent_size, out_features=2)
        self.encode_logvar_layer = nn.Linear(in_features=latent_size, out_features=2)

        # decoder
        self.decoder = nn.Sequential(
            nn.Linear(in_features=2, out_features=latent_size),
            nn.ReLU(),
            nn.Linear(in_features=latent_size, out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=input_size),
            )

        self.to(self.device)

    def encode(self, x):
        """Encode ``[batch, input_size, L]`` into ``(mean, logvar)``, each ``[batch, L, 2]``."""
        # nn.Linear acts on the last axis, so move the channels there.
        x = x.permute(0,2,1)
        x = self.encoder(x)
        mean, logvar = self.encode_mean_layer(x), self.encode_logvar_layer(x)
        return mean, logvar

    def reparameterization(self, mean, logvar):
        """Sample ``z = mean + eps * std`` with ``eps ~ N(0, 1)``."""
        # Var(x) = std**2 -> 0.5*ln(Var(x)) = ln(std)
        # std = exp(0.5*ln(Var(x)))
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        z = mean + eps*std
        return z

    def decode(self, x):
        """Decode latents ``[batch, L, 2]`` into ``[batch, L, input_size]``."""
        return self.decoder(x)

    def forward(self, x):
        """Return ``(x_hat, mean, logvar)``; ``x_hat`` is channel-last ``[batch, L, input_size]``."""
        mean, logvar = self.encode(x)
        z = self.reparameterization(mean, logvar)
        x_hat = self.decode(z)
        return x_hat, mean, logvar

    def obtain_results(self, light_curves):
        """Grid the light curves with `_get_data`, encode them and collect the loss inputs."""
        data = _get_data(light_curves)

        # `_get_data` returns CPU tensors; move them to the model's device so
        # the encoder does not fail when the model lives on the GPU.
        input_data = data['input_data'].to(self.device)
        compare_data = data['compare_data'].to(self.device)

        # Encode the light_curves
        encoding_mu, encoding_logvar = self.encode(input_data)

        # compare_data rows: grid_time, value, error, weight
        results = {
            'redshift': data['redshift'].to(self.device),
            'time': compare_data[:, 0],
            'obs_flux': compare_data[:, 1],
            'obs_fluxerr': compare_data[:, 2],
            'obs_weight': compare_data[:, 3],
            'encoding_mu': encoding_mu,
            'encoding_logvar': encoding_logvar,
        }

        return results


In [30]:
# Model and optimisation hyper-parameters.
input_size = 3 # input_size = 2 * N_bands + 1 -> 2: flux and error 1 for redshift.
h_dim = 200           # hidden layer width
z_dim = 20            # passed as `latent_size`
num_epoch = 5         # number of passes over the training loader
batch_size = 32       # NOTE(review): the DataLoaders above are built with batch_size=64, not this value
learning_rate = 1e-4  # Adam step size

model = ModelV0(input_size=input_size, hidden_dim=h_dim, latent_size=z_dim, device=device)
#model = ModelV1(input_size = input_size, latent_size = z_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
from torchinfo import summary
summary(model)

In [32]:
def loss_function(results):
    """Weighted negative log-likelihood plus KL divergence to a unit Gaussian.

    Parameters
    ----------
    results : dict
        Needs ``obs_weight``, ``obs_flux``, ``encoding_mu`` and
        ``encoding_logvar``. If the optional key ``model_flux`` is present,
        the likelihood term uses the residual ``obs_flux - model_flux``;
        otherwise it uses ``obs_flux`` alone.
        NOTE(review): without ``model_flux`` the likelihood term does not
        depend on the model, so only the KL term is optimised.

    Returns
    -------
    torch.Tensor
        Scalar loss, still attached to the autograd graph of the encodings.
    """
    # `torch.as_tensor` returns tensors unchanged, so the encodings stay
    # connected to the model. Re-wrapping them with
    # `torch.tensor(..., requires_grad=True)` would create detached copies and
    # no gradient would ever reach the model parameters.
    obs_weight = torch.as_tensor(results['obs_weight'])
    obs_flux = torch.as_tensor(results['obs_flux'])
    encoding_mu = torch.as_tensor(results['encoding_mu'])
    encoding_logvar = torch.as_tensor(results['encoding_logvar'])

    model_flux = results.get('model_flux')
    residual = obs_flux if model_flux is None else obs_flux - torch.as_tensor(model_flux)

    nll = torch.sum(0.5 * obs_weight * residual**2)
    kl_div = torch.sum(
        -0.5 * (1 + encoding_logvar - encoding_mu**2 - torch.exp(encoding_logvar))
    )

    return nll + kl_div

In [33]:
def replace_nan_grads(parameters, value=0.0):
    """Overwrite NaN entries of every parameter gradient, in place.

    Parameters
    ----------
    parameters : Iterator[torch.Tensor]
        Model parameters, e.g. the result of `model.parameters()`.
        Parameters without a gradient are skipped.
    value : float, optional
        Replacement for the NaN entries.
    """
    for param in parameters:
        if p_grad_missing := (param.grad is None):
            continue
        grad_values = param.grad.data
        grad_values.masked_fill_(torch.isnan(grad_values), value)

In [None]:
# One entry per epoch (mean batch loss), so the list can be plotted against
# the epoch number afterwards.
train_losses = []
for epoch in tqdm(range(num_epoch)):
    # Set the model to training mode
    model.train()
    batch_losses = []
    for batch in train_loader:
        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        results = model.obtain_results(batch)
        loss = loss_function(results)

        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    # Mean over the batches of this epoch (NaN if the loader was empty).
    epoch_loss = float(np.mean(batch_losses)) if batch_losses else float('nan')
    train_losses.append(epoch_loss)

    print(f'Epoch [{epoch+1}/{num_epoch}], Loss: {epoch_loss/300:.4f}')

In [None]:
fig, ax = plt.subplots()
ax.plot(np.arange(num_epoch), np.array(train_losses)/300)
ax.set_xlabel('Epochs')
ax.set_ylabel('Train loss')

In [None]:
def plot_curves(time, original_magnitude, reconstructed_magnitude, title='Curvas Originales y Reconstruidas'):
    """Overlay the original (dashed blue) and reconstructed (solid red) curves and show the figure."""
    fig, ax = plt.subplots(figsize=(12, 6))

    curves = (
        (original_magnitude, 'Original', '--', 'blue'),
        (reconstructed_magnitude, 'Reconstruido', '-', 'red'),
    )
    for values, label, linestyle, color in curves:
        ax.plot(time, values, label=label, linestyle=linestyle, color=color)

    ax.set_xlabel('Tiempo')
    ax.set_ylabel('Magnitud')
    ax.set_title(title)
    ax.legend()
    plt.show()

In [None]:
# Evaluate the trained model on the synthetic test set.
# NOTE(review): this cell does not match the rest of the visible notebook and
# is not expected to run as written:
#   * `create_grid` is not defined anywhere in the visible source;
#   * `loss_function` is defined with a single `results` argument, but is called
#     here with four arguments;
#   * `oid` in the plot title is the module-level list from the synthetic-data
#     cell, not the id of the current group (the loop discards it as `_`);
#   * the `break` makes the loss/reconstruction bookkeeping below it unreachable,
#     so `test_losses` stays empty when the mean is taken.
model.eval()

# List holding the loss of every test sample
test_losses = []

# List holding the reconstruction of every test sample
reconstructions = []

#with torch.inference_mode():
with torch.inference_mode():
    for _, group_oid in test_data_synthetic.groupby(by='oid'):
        X = group_oid[['mjd','magpsf','sigmapsf']]
        #X = process_light_curve_atat(X)# Karpathy constant is just a joke
        X = process_light_curve_parsnip(X)
        X, X_weights = create_grid(X)
        time = X[0,:]
        original_magnitude = X[1,:]
        X = torch.tensor(X, dtype=torch.float32)
        X_weights = torch.tensor(X_weights, dtype=torch.float32)
        X = X.T
        X_weights = X_weights.T

        # Forward pass
        X_reconstructed, mu, logvar = model(X)

        #print(X_reconstructed)
        # Convert the predictions to numpy
        reconstructed_magnitude = X_reconstructed.cpu().numpy()[:,0]  # NOTE(review): takes column 0; the original comment said "second row for magnitude" — confirm the index
        #print(len(reconstructed_magnitude))
        
        # Plot
        plot_curves(time, original_magnitude, reconstructed_magnitude, title=f'Curvas para OID {oid}')
        
        # Leave the loop after the first object (only one plot is wanted)
        break

        # Compute loss
        loss = loss_function(X_reconstructed, X_weights, mu, logvar)
        #replace_nan_grads(model.parameters())

        # Store the loss and the reconstruction
        test_losses.append(loss.item())
        reconstructions.append(X_reconstructed.cpu().numpy())

    # Mean loss over the test data
    average_test_loss = np.mean(test_losses)
    print("Average test loss:", average_test_loss)

In [None]:
reconstructions = np.array(reconstructions)
len(reconstructions[0][:,0]), len(reconstructions[0][:,1])

In [None]:
reconstructions[0][:,0]

In [None]:
# Compare one test light curve with its stored reconstruction.
# NOTE(review): relies on `create_grid`, which is not defined in the visible
# source, and on `reconstructions` holding at least `sample_index + 1` entries.
sample_index = 2
sample_data = test_data_synthetic[test_data_synthetic.oid == test_data_synthetic.oid.unique()[sample_index]]
sample_data = process_light_curve_parsnip(sample_data)
sample_data, _ = create_grid(sample_data)
sample_reconstruction = reconstructions[sample_index]

fig, ax = plt.subplots()

# Row 0 of the gridded sample is used as x, row 1 as y.
ax.plot(sample_data[0,:], sample_data[1,:], 'o', color='C0')
# The reconstruction is drawn on a fixed 300-point axis from 0 to 300.
ax.plot(np.linspace(0,300,300), sample_reconstruction[:,1], 'o', color='C1')

ax.set_xlabel('Time')
ax.set_ylabel('Flux')
ax.set_title('Original vs. Reconstructed Flux')
#ax.legend()
#ax.invert_yaxis()
plt.show()