# Outline
This tutorial demonstrates how to build, train, and evaluate a Recurrent Neural Network (RNN) model for classifying astronomical time series data using PyTorch.

- `Dataset`: PLAsTiCC, simulated time-series observations of astronomical transients and variable sources from LSST-like photometry in six optical bands (u, g, r, i, z, y)
- `Model`: a custom recurrent architecture (GRU encoder followed by a fully connected classifier)
- `Objective`: Classify input time series into one of 15 astronomical object classes

The tutorial will guide you through the following steps:

1) Setting up the environment
2) Downloading the dataset and preparing data loaders with appropriate transformations and augmentations
3) Building an RNN classifier with a custom architecture
4) Training the model
5) Evaluating the model on the test dataset

# Configuring the environment

## Module installation
We’ll begin by installing the necessary Python modules for this tutorial.

In [None]:
import os

# - Install modules from requirements_plasticc.txt if present
if os.path.isfile("requirements_plasticc.txt"):
  print("Installing modules from local requirements_plasticc.txt file ...")
  %pip install -q -r requirements_plasticc.txt
else:
  print("Installing modules ...")

  %pip install -q pandas scikit-learn                                          # Data analysis modules
  %pip install -q torch torchvision torchmetrics torchinfo torchtune torchao   # ML modules
  %pip install -q sh gdown matplotlib tqdm                                     # Plot/util modules

  # - Create requirements file
  %pip freeze > requirements_plasticc.txt

## Import modules
Next, we import the essential modules needed throughout the tutorial.

In [None]:
###########################
##   STANDARD MODULES
###########################
import os
import math
from pathlib import Path
import shutil
import gdown
import tarfile
import numpy as np
import json
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from itertools import islice
import random
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
import urllib.request
from sh import gunzip
import copy

###########################
##   DATA/TORCH MODULES
###########################
# - Data analysis
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score

# - Torch modules
import torch
from torch import Tensor
from torch.utils.data import Dataset, Subset, random_split, DataLoader
import torch.nn.functional as F
import torchvision
from torchvision.datasets.vision import VisionDataset
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torchvision.transforms import ToTensor
import torchmetrics
import torchinfo
from torchinfo import summary
from torchtune.training import get_cosine_schedule_with_warmup

## Project folders
We create a working directory `rundir` to run the tutorial in.

In [None]:
topdir= os.getcwd()
rundir= os.path.join(topdir, "run-plasticc_classifier")
path = Path(rundir)
path.mkdir(parents=True, exist_ok=True)

# 📚 Dataset
For this tutorial, we will use the [**PLAsTiCC dataset**](https://plasticc.org).
The PLAsTiCC (Photometric LSST Astronomical Time Series Classification Challenge) dataset provides simulated time-series observations of astronomical transients for classification tasks using LSST-like photometric data.
It was originally part of a [**Kaggle competition**](https://www.kaggle.com/competitions/PLAsTiCC-2018/overview) in 2018. 

For this tutorial we use the [**unblinded data collection**](https://zenodo.org/records/2539456), released after the competition closed, which provides classification labels for the full test set (previously undisclosed to participants).

The dataset contains two types of files, in CSV format:

- Metadata files: containing object identifiers, class labels and physical parameters (e.g. redshift)
- Light-curve files: containing the time series of object fluxes and other quantities for each observing band

Below are examples of light-curve data for three different objects (taken from the dataset description paper):

<img src="media/plasticc_sample1.png" style="display: block; margin-left: 0; width: 500px;" />
<img src="media/plasticc_sample2.png" style="display: block; margin-left: 0; width: 500px;" />
<img src="media/plasticc_sample3.png" style="display: block; margin-left: 0; width: 500px;" />

Each time series is labeled as one of the following 15 classes:

<img src="media/plasticc_classes.jpg" style="display: block; margin-left: 0; width: 800px;" />
<img src="media/plasticc_classes_counts.jpg" style="display: block; margin-left: 0; width: 800px;" />
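The class IDs are not contiguous, so the dataset class defined later maps them to indices 0 to 14, after folding the four post-challenge "other" sub-classes (991 to 994) into class 99. A minimal sketch of that encoding, reusing the class list and the 99x remapping from `PlasticcDataset` (the helper name `class_id_to_index` is illustrative):

```python
import numpy as np

# Class IDs used in PLAsTiCC (same order as PlasticcDataset.classes)
classes = np.array([6, 15, 16, 42, 52, 53, 62, 64, 65, 67, 88, 90, 92, 95, 99])

# Post-challenge labels 991-994 are all folded into class 99 ("OTHER")
remap_other = {991: 99, 992: 99, 993: 99, 994: 99}

def class_id_to_index(class_id: int) -> int:
    """Map a PLAsTiCC class ID to a contiguous index in [0, 14] (hypothetical helper)."""
    class_id = remap_other.get(class_id, class_id)
    return int(np.where(classes == class_id)[0][0])

print(class_id_to_index(90))   # SN-Ia -> 11
print(class_id_to_index(993))  # calcium-rich transient, folded into class 99 -> 14
```

The integer index is what the loss function and one-hot encoding operate on; the human-readable names are kept separately in `classid2label`.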

More details on the data format are provided below.


## 🔸 Metadata Files

Each row corresponds to a unique astronomical object and contains both observational and model parameters.

- **File names:**
  - `plasticc_train_metadata.csv.gz`
  - `plasticc_test_metadata.csv.gz`

- **Columns (subset):**
  - `object_id`: Unique identifier
  - `ra`, `decl`: Sky coordinates
  - `ddf_bool`: Deep Drilling Field flag (1 = DDF, 0 = WFD)
  - `hostgal_specz`: Spectroscopic redshift (available only for a small subset of objects)
  - `hostgal_photoz`, `hostgal_photoz_err`: Photometric redshift and its error
  - `distmod`: Distance modulus
  - `mwebv`: Milky Way (Galactic) dust reddening, E(B-V)
  - `target`: Challenge class label (only in training set)
  - `true_target`: Actual class label (post-challenge)
  - `true_submodel`: Sub-model variant ID (for some classes)
  - `true_z`, `true_distmod`, `true_lensdmu`: Redshift-related physical quantities
  - `tflux_[u,g,r,i,z,y]`: Template fluxes in each LSST band
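
The `PlasticcDataset` class defined later uses `hostgal_photoz` to separate galactic objects (photometric redshift exactly 0) from extragalactic ones when its `extragalactic` option is set. A small sketch of the same selection with pandas, on a toy metadata frame whose IDs and values are made up for illustration:

```python
import pandas as pd

# Toy metadata rows (hypothetical values, same column names as the PLAsTiCC files)
meta = pd.DataFrame({
    "object_id": [1, 2, 3, 4],
    "hostgal_photoz": [0.0, 1.5, 0.2, 0.3],
    "target": [92, 88, 42, 90],
})

# Galactic sources have hostgal_photoz == 0; all others are treated as extragalactic
is_galactic = meta["hostgal_photoz"] == 0
galactic = meta[is_galactic]
extragalactic = meta[~is_galactic]

print(galactic["object_id"].tolist())
print(extragalactic["object_id"].tolist())
```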

## 🔸 Lightcurve Files

These files contain the photometric time-series data for each object. Each row corresponds to one measurement at a given time and band.

- **Training file:**
  - `plasticc_train_lightcurves.csv.gz`

- **Test files (split into 11 subsets):**
  - `plasticc_test_lightcurves_01.csv.gz` to `plasticc_test_lightcurves_11.csv.gz`

- **Columns:**
  - `object_id`: Match to metadata
  - `mjd`: Observation time (Modified Julian Date)
  - `passband`: LSST filter index (0 = u, ..., 5 = y)
  - `flux`: Observed flux (corrected for Galactic extinction)
  - `flux_err`: Flux uncertainty
  - `detected_bool`: Detection flag from image subtraction
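
Each object therefore spans many rows, one per (time, band) measurement. The sketch below, on a made-up light curve, shows how the rows of one object can be collected with `groupby('object_id')` (as the dataset class does) and how many observations fall in each passband:

```python
import pandas as pd

# Hypothetical light-curve rows for two objects (values are illustrative only)
lc = pd.DataFrame({
    "object_id": [1, 1, 1, 1, 2, 2],
    "mjd": [59580.03, 59580.05, 59582.10, 59590.00, 59600.00, 59601.00],
    "passband": [2, 3, 2, 5, 0, 1],
    "flux": [-5.1, 10.2, 12.7, 3.3, 80.0, 95.5],
    "flux_err": [3.0, 2.5, 2.8, 6.1, 4.0, 4.2],
    "detected_bool": [0, 1, 1, 0, 1, 1],
})

summary = {}
for object_id, rows in lc.groupby("object_id"):
    rows = rows.sort_values("mjd")  # time-ordered sequence for this object
    counts = rows["passband"].value_counts().sort_index()
    n_per_band = {int(band): int(n) for band, n in counts.items()}
    summary[int(object_id)] = (len(rows), n_per_band)

print(summary)  # {1: (4, {2: 2, 3: 1, 5: 1}), 2: (2, {0: 1, 1: 1})}
```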
  
## 🔸 References
More details are available in these references:

- https://plasticc.org
- https://www.kaggle.com/competitions/PLAsTiCC-2018/overview
- R. Hložek et al., 2023, ApJS, 267, 25
- https://arxiv.org/pdf/1809.11145
- https://arxiv.org/pdf/1810.00001

## Dataset Download
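Next, we download the dataset files from Zenodo and decompress them into the `rundir` working folder. To keep the tutorial lightweight, only the first of the 11 test light-curve files is downloaded.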

In [None]:
def download_data(url, data_path, destdir):
  """ Download a gzipped file from url into destdir and decompress it there """
  data_fullpath= os.path.join(destdir, data_path)

  # - Download straight into destdir (no intermediate move, so a stale copy cannot block it)
  print("Downloading file from url %s to %s ..." % (url, data_fullpath))
  urllib.request.urlretrieve(url, data_fullpath)
  print("DONE!")

  # - Decompress: gunzip replaces file.csv.gz with file.csv
  print("Unzipping dataset file %s ..." % (data_fullpath))
  gunzip(data_fullpath)


# - Download train metadata
train_metadata_url= "https://zenodo.org/records/2539456/files/plasticc_train_metadata.csv.gz?download=1"
train_metadata_gz_path= 'plasticc_train_metadata.csv.gz'
train_metadata_gz_fullpath= os.path.join(rundir, train_metadata_gz_path)
train_metadata_fullpath= os.path.join(rundir, 'plasticc_train_metadata.csv')
if not os.path.isfile(train_metadata_fullpath):
  download_data(train_metadata_url, train_metadata_gz_path, rundir)

# - Download train data
train_data_url= "https://zenodo.org/records/2539456/files/plasticc_train_lightcurves.csv.gz?download=1"
train_data_gz_path= 'plasticc_train_lightcurves.csv.gz'
train_data_gz_fullpath= os.path.join(rundir, train_data_gz_path)
train_data_fullpath= os.path.join(rundir, 'plasticc_train_lightcurves.csv')
if not os.path.isfile(train_data_fullpath):
  download_data(train_data_url, train_data_gz_path, rundir)

# - Download test metadata
test_metadata_url= "https://zenodo.org/records/2539456/files/plasticc_test_metadata.csv.gz?download=1"
test_metadata_gz_path= 'plasticc_test_metadata.csv.gz'
test_metadata_gz_fullpath= os.path.join(rundir, test_metadata_gz_path)
test_metadata_fullpath= os.path.join(rundir, 'plasticc_test_metadata.csv')
if not os.path.isfile(test_metadata_fullpath):
  download_data(test_metadata_url, test_metadata_gz_path, rundir)

# - Download test data (part 1)
test_data_url= "https://zenodo.org/records/2539456/files/plasticc_test_lightcurves_01.csv.gz?download=1"
test_data_gz_path= 'plasticc_test_lightcurves_01.csv.gz'
test_data_gz_fullpath= os.path.join(rundir, test_data_gz_path)
test_data_fullpath= os.path.join(rundir, 'plasticc_test_lightcurves_01.csv')
if not os.path.isfile(test_data_fullpath):
  download_data(test_data_url, test_data_gz_path, rundir)
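
`sh.gunzip` calls the system `gunzip` binary, so it is only available on Unix-like systems. If that is a problem (e.g. on Windows), the standard-library `gzip` module can do the same job. A minimal sketch; the helper name `gunzip_file` is ours, not part of any library:

```python
import gzip
import shutil
from pathlib import Path

def gunzip_file(gz_path: str, remove_gz: bool = True) -> str:
    """Decompress file.csv.gz into file.csv in the same folder (hypothetical helper)."""
    gz_path = Path(gz_path)
    out_path = gz_path.with_suffix("")  # strips only the trailing ".gz"
    with gzip.open(gz_path, "rb") as f_in, open(out_path, "wb") as f_out:
        shutil.copyfileobj(f_in, f_out)  # stream in chunks, so large files are not held in memory
    if remove_gz:
        gz_path.unlink()  # mimic gunzip, which deletes the .gz file
    return str(out_path)
```

It could be called as `gunzip_file(data_fullpath)` in place of `gunzip(data_fullpath)` inside `download_data`.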

## Dataset loading

### Loading train metadata

In [None]:
# - Read train metadata as pandas data frame
train_metadata= pd.read_csv(train_metadata_fullpath)
print("--> Train metadata")
print(train_metadata)

### Loading train data

In [None]:
# - Load data
print(f"Loading train data from file {train_data_fullpath} ...")
train_data = pd.read_csv(train_data_fullpath)
print("train_data")
print(train_data)

### Splitting train/val sets
Let's reserve a small portion (10%) of the training dataset for validation. Below, we split the training metadata into train and validation data frames, then select the light-curve rows of each split by `object_id`.

In [None]:
# - Split metadata in train/val sets
random_state= 42
test_size= 0.1
meta_df_train, meta_df_val = train_test_split(
  train_metadata,
  test_size=test_size,
  random_state=random_state,
  shuffle=True
)

# - Get object_ids of train/val splits
train_object_ids = meta_df_train['object_id'].unique()
val_object_ids = meta_df_val['object_id'].unique()

# - Use object IDs to split data_df
data_df_train = train_data[train_data['object_id'].isin(train_object_ids)]
data_df_val  = train_data[train_data['object_id'].isin(val_object_ids)]

print(f"#{len(meta_df_train)}/{len(meta_df_val)} data entries in train/val sets ...")
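
Some classes are rare (see the class counts above), so a purely random 10% split can leave very few of their examples in the validation set. `train_test_split` accepts a `stratify` argument that preserves class proportions in both splits. A sketch on toy metadata; in the tutorial one could pass `stratify=train_metadata['target']`, an optional change that the code above does not make:

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Toy, imbalanced metadata: 90 objects of class 90 and 10 of class 64 (made-up)
toy_meta = pd.DataFrame({
    "object_id": range(100),
    "target": [90] * 90 + [64] * 10,
})

meta_tr, meta_va = train_test_split(
    toy_meta,
    test_size=0.1,
    random_state=42,
    shuffle=True,
    stratify=toy_meta["target"],  # keep the 90/10 class ratio in both splits
)

print(meta_va["target"].value_counts())
```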

### Loading test metadata

In [None]:
# - Read test metadata as pandas data frame
meta_df_test= pd.read_csv(test_metadata_fullpath)
print("--> Test metadata")
print(meta_df_test)

### Loading test data

In [None]:
# - Load data
print(f"Loading test data from file {test_data_fullpath} ...")
data_df_test = pd.read_csv(test_data_fullpath)
print("test_data")
print(data_df_test)
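
The test metadata file describes every test object, while only the first test light-curve file is loaded here. Restricting the metadata to objects that actually appear in the light curves, and indexing it by `object_id`, avoids scanning millions of rows per object. A sketch on toy frames (all values and the name `meta_by_id` are illustrative):

```python
import pandas as pd

# Toy stand-ins for meta_df_test / data_df_test (made-up values)
meta_df = pd.DataFrame({"object_id": [10, 11, 12, 13], "hostgal_photoz": [0.0, 0.3, 0.5, 0.9]})
data_df = pd.DataFrame({"object_id": [11, 11, 13], "flux": [1.0, 2.0, 3.0]})

# Keep only metadata rows whose object has light-curve data
meta_subset = meta_df[meta_df["object_id"].isin(data_df["object_id"].unique())]

# Index by object_id for fast per-object lookups
meta_by_id = meta_subset.set_index("object_id")
z_photo = meta_by_id.loc[13, "hostgal_photoz"]
print(len(meta_subset), z_photo)
```

In the tutorial the filtering step would amount to `meta_df_test = meta_df_test[meta_df_test['object_id'].isin(data_df_test['object_id'].unique())]`. `PlasticcDataset` selects metadata with a boolean mask and reads it as a one-row data frame, so switching it to an indexed lookup would need small changes there.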

### Create PyTorch datasets

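We now define a custom PyTorch `Dataset`. It groups the light-curve rows by `object_id` and, for each object, builds four features per observation (rest-frame time difference, normalized flux, normalized flux error, rest-frame wavelength) plus five per-object metadata features (DDF flag, redshift, redshift error, E(B-V), scaled log flux range). Every sequence is then padded or truncated to a fixed length by the two helper functions below.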

In [None]:
def pad_sequences_numpy(sequences, maxlen=None, dtype='float32', padding='post', truncating='post', value=0.0):
  """
    Pads a list of 2D numpy arrays (sequence_len_i, n_features) to shape (N, maxlen, n_features).
    
    Parameters:
        sequences : list of np.ndarray of shape (Ti, D)
        maxlen    : int or None, length to pad/truncate to. If None, use max sequence length.
        dtype     : data type of output array
        padding   : 'pre' or 'post'
        truncating: 'pre' or 'post'
        value     : value used for padding

    Returns:
        np.ndarray of shape (N, maxlen, D)
  """
  num_samples = len(sequences)
  feature_dim = sequences[0].shape[1]
  lengths = [seq.shape[0] for seq in sequences]

  if maxlen is None:
    maxlen = max(lengths)

  padded = np.full((num_samples, maxlen, feature_dim), value, dtype=dtype)

  for i, seq in enumerate(sequences):
    if truncating == 'pre':
      trunc = seq[-maxlen:]
    else:
      trunc = seq[:maxlen]

    if padding == 'pre':
      padded[i, -len(trunc):] = trunc
    else:
      padded[i, :len(trunc)] = trunc

  return padded


def pad_single_sequence_numpy(sequence, maxlen, dtype='float32', padding='post', truncating='post', value=0.0):
    """
    Pads or truncates a 2D array (seq_len, n_features) to shape (maxlen, n_features).

    Parameters:
        sequence  : np.ndarray of shape (seq_len, n_features)
        maxlen    : int, the target sequence length
        dtype     : data type of the output array
        padding   : 'pre' or 'post' — where to pad (default: 'post')
        truncating: 'pre' or 'post' — where to truncate (default: 'post')
        value     : value to use for padding

    Returns:
        np.ndarray of shape (maxlen, n_features)
    """
    seq_len, n_features = sequence.shape
    result = np.full((maxlen, n_features), value, dtype=dtype)

    if seq_len >= maxlen:
        if truncating == 'pre':
            trunc = sequence[-maxlen:]
        else:  # 'post'
            trunc = sequence[:maxlen]
    else:
        trunc = sequence

    if padding == 'pre':
        result[-len(trunc):] = trunc
    else:
        result[:len(trunc)] = trunc

    return result

#####################################
##      DATASET
#####################################
class PlasticcDataset(Dataset):
  def __init__(
    self,
    data_df, # Pandas data frame
    meta_df, # Pandas data frame
    use_specz=False,
    extragalactic=None,  
    nmax=-1,
    max_seq_size=256,
    auto_set_max_seq_size=True,
    do_augmentation=False, # replace original data with augmented version with augment_prob
    augment_prob= 0.5,  
    augment_data=False, # add augment_factor augmented copies per object to the precomputed tensors
    augment_factor=25,
    drop_rate= 0.3
  ):
    # - Set options
    self.meta_df= meta_df
    self.data_df= data_df
    self.metadata= []
    self.data= []
    self.X_seq = None    # shape: (N, seq_len, 4)
    self.X_meta = None   # shape: (N, num_features)
    self.Y= None  # one-hot labels: (N, num_classes)
    self.Y_target= None # class indices: (N,)
    self.wtable= None
    self.class_weights= None
    self.nmax= nmax
    self.use_specz= use_specz
    self.extragalactic= extragalactic
    self.classes = np.array([6, 15, 16, 42, 52, 53, 62, 64, 65, 67, 88, 90, 92, 95, 99], dtype='int32')
    self.class_names = ['class_6','class_15','class_16','class_42','class_52','class_53','class_62','class_64','class_65','class_67','class_88','class_90','class_92','class_95','class_99']
    self.classid2label= {
      6: "PS-MULENS",# Point_source_mu-lensing
      15: "TDE", # Tidal disruption event
      16: "EBE", # Eclipsing binary event
      42: "SN-II", # Core-collapse supernova Type II
      52: "SN-Iax", # Supernova Type Ia-x
      53: "MIRA", # Mira variable
      62: "SN-Ibc", # Core-collapse supernova Type Ibc
      64: "KN", # Kilonova
      65: "M-DWARF", # M dwarf
      67: "SN-Ia-91bg", # Supernova Type Ia-91bg
      88: "AGN", # Active galactic nucleus
      90: "SN-Ia", # Supernova Type Ia
      92: "RR-LY", # RR Lyrae
      95: "SLSN", # Superluminous supernova
      99: "OTHER", # Other class
    }
    
    self.classid_remap= {
      6: 6,# Point_source_mu-lensing
      15: 15, # Tidal disruption event
      16: 16, # Eclipsing binary event
      42: 42, # Core-collapse supernova Type II
      52: 52, # Supernova Type Ia-x
      53: 53, # Mira variable
      62: 62, # Core-collapse supernova Type Ibc
      64: 64, # Kilonova
      65: 65, # M dwarf
      67: 67, # Supernova Type Ia-91bg
      88: 88, # Active galactic nucleus
      90: 90, # Supernova Type Ia
      92: 92, # RR Lyrae
      95: 95, # Superluminous supernova
      991: 99, # Microlens from binary lens
      992: 99, # Intermediate luminous optical transient
      993: 99, # Calcium-rich transient
      994: 99, # Pair instability SN
    }    
        
    self.nclasses= len(self.classes)
    self.class_weight_factors= np.array([2,2,1,1,1,1,1,2,1,1,1,1,1,2,2], dtype='float32')
    
    # LSST passbands (nm)  u    g    r    i    z    y      
    self.passbands = np.array([357, 477, 621, 754, 871, 1004], dtype='float32')
    
    self.max_seq_size= max_seq_size
    self.auto_set_max_seq_size= auto_set_max_seq_size
    self.seq_sizes= []
    
    self.augment_data= augment_data
    self.augment_factor= augment_factor
    self.drop_rate = drop_rate
    self.do_augmentation= do_augmentation
    self.augment_prob= augment_prob
    
    # - Read data
    self.__read_data()
    
    # - Load data
    self.__load_data()
    
  def __compute_wtable(self):
    """
      Compute:
        - wtable: class frequencies (N_class,)
        - class_weights: inverse frequency, normalized to sum to num_classes

      Returns:
        wtable (torch.Tensor), class_weights (torch.Tensor)
    """
    class_counts = self.Y.sum(dim=0)  # sum over all samples (across rows)
    total_samples = self.Y.shape[0]
    wtable = class_counts / total_samples
    wtable[self.nclasses-1]= 1.0  # class 99 has no training examples: fix its frequency to avoid a huge weight
    
    # Inverse frequency as weight (avoid divide-by-zero)
    class_weights = 1.0 / (wtable + 1e-8)
    class_weights = class_weights * (len(wtable) / class_weights.sum())  # normalize to mean 1

    return wtable, class_weights

  def __get_sequence_data(self, inputdata, inputmetadata, augment=False):
    """ Return sequence data from raw plasticc data """    

    # - Copy input data
    data= copy.deepcopy(inputdata)
    meta= copy.deepcopy(inputmetadata)
    
    # - Retrieve original meta data
    z_photo= meta['hostgal_photoz'].iloc[0]         # photometric host-redshift (float32)
    zerr_photo= meta['hostgal_photoz_err'].iloc[0]  # uncertainty on photometric host-redshift
    ddf= meta['ddf_bool'].iloc[0]                   # boolean flag: 1 for DDF, 0 for WFD
    mwebv= meta['mwebv'].iloc[0]                    # Galactic E(B-V) extinction
    z_spec= meta['hostgal_specz'].iloc[0]           # accurate spectroscopic-redshift for small subset
    z= z_spec if self.use_specz else z_photo
    z_err= 0.0 if self.use_specz else zerr_photo
    
    # - Retrieve original seq data 
    mjd      = np.array(data['mjd'],      dtype='float32')
    band     = np.array(data['passband'], dtype='int32')
    flux     = np.array(data['flux'],     dtype='float32')
    flux_err = np.array(data['flux_err'], dtype='float32')
    detected = np.array(data['detected_bool'], dtype='float32')
    seq_size= mjd.shape[0]
    
    # - Augment features by dropping some records?
    if augment:
      # - Randomly drop some sequence observations
      mjd_aug= []
      band_aug= []  
      flux_aug= []
      flux_err_aug= []
      for k in range(seq_size):
        if random.uniform(0, 1) >= self.drop_rate:
          mjd_aug.append(mjd[k])
          band_aug.append(band[k])
          flux_aug.append(flux[k])
          flux_err_aug.append(flux_err[k])
       
      mjd_aug= np.array(mjd_aug, dtype='float32')
      band_aug= np.array(band_aug, dtype='int32')
      flux_aug= np.array(flux_aug, dtype='float32')
      flux_err_aug= np.array(flux_err_aug, dtype='float32') 
      
      # - Randomize redshift
      z_aug = random.normalvariate(z, z_err / 1.5)
      z_aug = max(z_aug, 0)
      z_aug = min(z_aug, 5)
      
      # - Randomize flux
      flux_aug = np.random.normal(flux_aug, flux_err_aug / 1.5)
    
      # - Override original data with augmented versions
      z= z_aug
      mjd= mjd_aug
      band= band_aug  
      flux= flux_aug
      flux_err= flux_err_aug
        
    # - Scale/modify feature data
    #   1) Convert time at observer (t) to time at source (t0): t0= t/(1+z)
    #   2) Convert wavelength at observer (l) to wavelength at source (l0): l0=l/(1+z)
    seq_size= mjd.shape[0]
    mjd -= mjd[0]
    mjd /= 100      # observer-frame time since first observation, in units of 100 days
    mjd /= (z + 1)  # rest-frame (source) time, in units of 100 days
    tdiff= np.ediff1d(mjd, to_begin = [0])  # time elapsed since the previous observation
    flux_max = np.max(flux)
    flux_min = np.min(flux)
    flux_norm= flux_max - flux_min
    flux_pow = math.log2(flux_norm)
    received_wavelength = self.passbands[band] # observer-frame wavelength in nm
    source_wavelength = received_wavelength / (z + 1) # rest-frame wavelength in nm
    source_wavelength/= 1000. # convert nm -> µm
    
    # - Set sequence data features
    features_seq= np.zeros( (seq_size, 4), dtype = 'float32')
    features_seq[:,0]= tdiff
    features_seq[:,1]= flux/flux_norm
    features_seq[:,2]= flux_err/flux_norm
    features_seq[:,3]= source_wavelength
    ##features_seq[:,4]= detected
       
    # - Set metadata features
    features_meta= np.zeros(5, dtype = 'float32')
    features_meta[0]= ddf
    features_meta[1]= z
    features_meta[2]= z_err
    features_meta[3]= mwebv
    features_meta[4]= flux_pow / 10
    
    # - Check for NaNs
    if np.any(np.isnan(features_seq)):
      print("WARN: features_seq has NaNs values!")   
    if np.any(np.isnan(features_meta)):
      print("WARN: features_meta has NaNs values!") 
    
    return features_seq, features_meta
    
    
  def __read_data(self):
    """ Read data/metadata from input data frames """
    
    # - Group data by object_id
    groups= self.data_df.groupby('object_id')
    print(f"Reading {len(groups)} data entries ...")
    
    # - Read data and apply selection
    self.metadata= []
    self.data= []
                       
    for idx, g in enumerate(groups):
      nentries= idx+1  
      if idx % 1000 == 0:
        print('Reading data {0}'.format(idx), end='\r') 
        
      if self.nmax!=-1 and nentries > self.nmax:
        print(f'Reached data sample limit {self.nmax}...stop reading data')  
        break
        
      # - Find data with object_id
      id = g[0]
      data_entry= g[1]
      meta = self.meta_df.loc[self.meta_df['object_id'] == id]  
      z_photo= meta['hostgal_photoz'].iloc[0] 
    
      # - Skip source with invalid redshift?
      if self.extragalactic == True and z_photo==0:
        continue

      if self.extragalactic == False and z_photo>0:
        continue
        
      # - Store selected data
      self.metadata.append(meta)
      self.data.append(data_entry)
    
    print(f'{len(self.data)} data samples read...')  
    
    
  def __load_data_item(self, idx, augment=False, random_augment=False):  
    """ Load data item """
    
    meta_item= self.metadata[idx]
    data_item= self.data[idx]
    
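    # - Get target id: use 'target', falling back to the post-challenge 'true_target'
    #   when 'target' is missing or 0; unlabelled objects are interpreted as class 99
    class_id= 0
    if 'target' in meta_item:
      class_id= int(meta_item['target'].iloc[0])
    if class_id==0 and 'true_target' in meta_item:
      class_id= int(meta_item['true_target'].iloc[0])
    if class_id in self.classid_remap:
      class_id= self.classid_remap[class_id] # remap id (done for the 99x other classes)
      target_id= int(np.where(self.classes == class_id)[0][0])
    else:
      target_id= len(self.classes) - 1  # interpret as class 99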
    
    # - Get sequence & meta pars
    augment_sample= False
    if augment:
      if random_augment:
        augment_sample= (random.uniform(0, 1) < self.augment_prob)  # augment with probability augment_prob
      else:
        augment_sample= True     
    features_seq, features_meta= self.__get_sequence_data(data_item, meta_item, augment=augment_sample)
    seq_size= features_seq.shape[0]
    
    # - Pad sequence data to desired max length
    features_seq_padded= pad_single_sequence_numpy(features_seq, maxlen=self.max_seq_size)
    
    return features_seq_padded, features_meta, target_id
    
    
  def __load_data(self):  
    """ Load data/metadata from files """
    
    features_seq_all= []
    features_meta_all= []
    target_ids= []
    
    for i in range(len(self.data)):
      meta_item= self.metadata[i]
      data_item= self.data[i]
      
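      # - Get target id: use 'target', falling back to the post-challenge 'true_target'
      #   when 'target' is missing or 0; unlabelled objects are interpreted as class 99
      class_id= 0
      if 'target' in meta_item:
        class_id= int(meta_item['target'].iloc[0])
      if class_id==0 and 'true_target' in meta_item:
        class_id= int(meta_item['true_target'].iloc[0])
      if class_id in self.classid_remap:
        class_id= self.classid_remap[class_id] # remap id (done for the 99x other classes)
        target_id= int(np.where(self.classes == class_id)[0][0])
      else:
        target_id= len(self.classes) - 1  # interpret as class 99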
      
      # - Read original data 
      features_seq, features_meta= self.__get_sequence_data(data_item, meta_item, augment=False)
      seq_size= features_seq.shape[0]
      features_seq_all.append(features_seq)
      features_meta_all.append(features_meta)
      self.seq_sizes.append(seq_size)
      target_ids.append(target_id) 
    
      # - Augment data, e.g. produce new samples by randomly dropping observations
      if self.augment_data:
        for j in range(self.augment_factor):
          features_seq_aug, features_meta_aug= self.__get_sequence_data(data_item, meta_item, augment=True)
          seq_size_aug= features_seq_aug.shape[0]
          features_seq_all.append(features_seq_aug)
          features_meta_all.append(features_meta_aug)
          self.seq_sizes.append(seq_size_aug)
          target_ids.append(target_id) 
       
    seq_min_size= np.min(self.seq_sizes)
    seq_max_size= np.max(self.seq_sizes)
    print(f"#{len(features_seq_all)} data added with seq range ({seq_min_size}, {seq_max_size})...")
    
    # - Find sequence truncation point (smallest power of 2 >= max seq size)
    if self.auto_set_max_seq_size:
      seq_size_opt= 2 ** math.ceil(math.log2(seq_max_size))
      self.max_seq_size= seq_size_opt  # keep __getitem__ padding consistent with X_seq
    else:
      seq_size_opt= self.max_seq_size
    
    print(f"Padding sequence data to a size of {seq_size_opt} ...")
    features_seq_padded= pad_sequences_numpy(features_seq_all, maxlen=seq_size_opt)
    print("features_seq_padded.shape")
    print(features_seq_padded.shape)
    
    # - Convert data to tensors
    self.X_seq= torch.from_numpy(features_seq_padded)
    self.X_meta= torch.from_numpy(np.array(features_meta_all))
    self.Y_target= torch.from_numpy(np.array(target_ids))
    #self.Y= F.one_hot(self.Y_target, num_classes=-1)
    self.Y= F.one_hot(self.Y_target, num_classes=len(self.classes))
    
    print("self.X_seq.shape")
    print(self.X_seq.shape)
    print("self.X_meta.shape")
    print(self.X_meta.shape)
    print("self.Y.shape")
    print(self.Y.shape)
    
    # - Compute wtable
    self.wtable, self.class_weights= self.__compute_wtable()
    print("self.wtable")
    print(self.wtable)
    print("self.class_weights")
    print(self.class_weights)
    
    
  def __len__(self):
    #return len(self.X_seq)
    return len(self.data) 
    
  def __getitem__(self, idx):
    #return self.X_seq[idx], self.X_meta[idx], self.Y[idx]
    
    # - Load item
    features_seq, features_meta, target_id= self.__load_data_item(
      idx, 
      augment=self.do_augmentation, 
      random_augment=True
    )
    
    # - Convert data to tensors
    X_seq= torch.from_numpy(features_seq)
    X_meta= torch.from_numpy(np.array(features_meta))
    Y_target= torch.tensor(target_id, dtype=torch.long)  # F.one_hot needs an int64 index tensor
    Y= F.one_hot(Y_target, num_classes=len(self.classes))
    
    return X_seq, X_meta, Y
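
To make the per-observation features built in `__get_sequence_data` concrete, here is a standalone sketch that reproduces them for a tiny made-up light curve at redshift z = 0.5. Times are shifted to start at 0, expressed in units of 100 days and divided by (1 + z) to correct for cosmological time dilation; wavelengths are divided by (1 + z) and converted to µm; fluxes and errors are scaled by the flux range. The input numbers are illustrative only.

```python
import math
import numpy as np

# LSST effective wavelengths (nm) for u, g, r, i, z, y, as in PlasticcDataset.passbands
passbands = np.array([357, 477, 621, 754, 871, 1004], dtype="float32")

# Made-up light curve: 3 observations of one object at redshift z = 0.5
z = 0.5
mjd = np.array([59580.0, 59630.0, 59730.0], dtype="float32")
band = np.array([1, 2, 1], dtype="int32")  # g, r, g
flux = np.array([10.0, 50.0, 30.0], dtype="float32")
flux_err = np.array([2.0, 4.0, 3.0], dtype="float32")

t = (mjd - mjd[0]) / 100 / (1 + z)                  # rest-frame time, units of 100 days
tdiff = np.ediff1d(t, to_begin=[0])                 # time since the previous observation
flux_norm = flux.max() - flux.min()                 # flux range used for scaling
wavelength_um = passbands[band] / (1 + z) / 1000.0  # rest-frame wavelength in µm

# Columns: tdiff, flux/range, flux_err/range, rest-frame wavelength
features_seq = np.stack([tdiff, flux / flux_norm, flux_err / flux_norm, wavelength_um], axis=1)
log_flux_range = math.log2(flux_norm) / 10          # last metadata feature

print(features_seq.round(3))
print(round(log_flux_range, 3))
```

Dividing by the flux range makes the model see light-curve shapes rather than absolute brightness, while the brightness scale survives as the separate `log_flux_range` metadata feature.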

Create train dataset

In [None]:
dataset_train= PlasticcDataset(
  data_df_train, 
  meta_df_train,
  use_specz=False,
  extragalactic=None,  
  nmax=-1,
  max_seq_size=352,
  auto_set_max_seq_size=False,
  do_augmentation=True,
  augment_prob= 0.5,  
  augment_data=False,
  augment_factor=25,
  drop_rate= 0.3
)
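
`PlasticcDataset` also computes `class_weights`: inverse class frequencies rescaled to average 1, with the absent class 99 assigned frequency 1. One common way to use such weights against class imbalance is to pass them to `torch.nn.CrossEntropyLoss(weight=...)`. The sketch below reproduces the computation on a toy label set and plugs the result into the loss; it illustrates one possible usage, not necessarily the one adopted in the training step.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy labels for 3 classes: class 0 common, class 1 rare, class 2 absent (like class 99)
y_target = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1])
y_onehot = F.one_hot(y_target, num_classes=3)

# Same recipe as PlasticcDataset.__compute_wtable
wtable = y_onehot.sum(dim=0) / y_onehot.shape[0]  # class frequencies (float)
wtable[-1] = 1.0                                  # absent last class: avoid a huge weight
class_weights = 1.0 / (wtable + 1e-8)
class_weights = class_weights * (len(wtable) / class_weights.sum())  # mean weight = 1

# Weighted loss: mistakes on rare classes cost more
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
logits = torch.randn(8, 3)
loss = criterion(logits, y_target)
print(class_weights, loss.item())
```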


Print summary statistics of the time-series lengths and plot each training sequence length against its class index

In [None]:
seq_lengths= dataset_train.seq_sizes
targets= dataset_train.Y_target
seq_min_size= np.min(seq_lengths)
seq_max_size= np.max(seq_lengths)
seq_mean_size= np.mean(seq_lengths)
seq_median_size= np.median(seq_lengths)
seq_std_size= np.std(seq_lengths)
print(f"seq stats: min/max={seq_min_size}/{seq_max_size}, mean/median={seq_mean_size}/{seq_median_size}, std={seq_std_size}")

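plt.scatter(targets.numpy(), seq_lengths, s=4)
plt.xlabel("class index")
plt.ylabel("sequence length (number of observations)")
plt.show()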

Load validation dataset

In [None]:
dataset_val= PlasticcDataset(
  data_df_val, 
  meta_df_val,
  use_specz=False,
  extragalactic=None,  
  nmax=-1,
  max_seq_size=352,
  auto_set_max_seq_size=False,
  do_augmentation=False,  
  augment_prob=0.5,  
  augment_data=False,
  augment_factor=1,
  drop_rate= 0.3  
)

Load test dataset

In [None]:
dataset_test= PlasticcDataset(
  data_df_test, 
  meta_df_test,
  use_specz=False,
  extragalactic=None,  
  nmax=-1,
  max_seq_size=352,
  auto_set_max_seq_size=False,
  do_augmentation=False,
  augment_prob=0.5,
  augment_data=False,
  augment_factor=1,
  drop_rate= 0.3  
)

### Create dataloaders

In [None]:
batch_size= 64
dl_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
dl_val   = DataLoader(dataset_val, batch_size=batch_size, shuffle=False)
dl_test  = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)
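The default `DataLoader` collate function stacks the `(X_seq, X_meta, Y)` items into batch tensors that have a leading batch dimension. Unless `drop_last=True` is set, the last batch is smaller. The sketch below uses a toy `TensorDataset` in place of `PlasticcDataset`. The feature counts (7 per time step, 5 metadata features) are hypothetical:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy tensors laid out like PlasticcDataset items (sizes are hypothetical)
n_samples, seq_len, n_features, n_meta, n_classes = 130, 352, 7, 5, 15
X_seq = torch.randn(n_samples, seq_len, n_features)
X_meta = torch.randn(n_samples, n_meta)
Y = torch.nn.functional.one_hot(torch.randint(0, n_classes, (n_samples,)), num_classes=n_classes)

dl = DataLoader(TensorDataset(X_seq, X_meta, Y), batch_size=64, shuffle=True)

# Collect the size of every batch in one epoch
batch_sizes = [x_seq.shape[0] for x_seq, _, _ in dl]
print(batch_sizes)  # [64, 64, 2]

x_seq, x_meta, y = next(iter(dl))
print(x_seq.shape, x_meta.shape, y.shape)  # (64, 352, 7), (64, 5), (64, 15)
```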

## Model

### Define the classifier model
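Define the RNN classifier. A GRU (optionally bidirectional) encodes the light curve. Its outputs are max-pooled over time, concatenated with the metadata features, and passed to a fully connected classification head. The model returns both the class probabilities and the raw logits. It takes two inputs:

- `x_seq`: time series, shape (batch_size, time_steps, n_features)
- `x_meta`: metadata features, shape (batch_size, n_meta_features)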

In [None]:
class PlasticcClassifier(torch.nn.Module):
  def __init__(
    self, 
    n_features, 
    n_meta_features, 
    hidden_size=64, 
    num_layers=2, 
    num_classes=15, 
    bidirectional=False, 
    dropout=0.5,
    fc_hidden_size=128  
  ):
    super().__init__()

    self.num_directions = 2 if bidirectional else 1

    self.gru = torch.nn.GRU(
      input_size=n_features,
      hidden_size=hidden_size,
      num_layers=num_layers,
      batch_first=True,
      bidirectional=bidirectional,
      dropout=dropout if num_layers > 1 else 0.0
    )

    self.dropout = torch.nn.Dropout(dropout)  # note: not applied in forward()

    self.classifier = torch.nn.Sequential(
      torch.nn.Linear(hidden_size * self.num_directions + n_meta_features, fc_hidden_size),
      torch.nn.ReLU(),
      torch.nn.Dropout(dropout),
      torch.nn.Linear(fc_hidden_size, num_classes),
    )

  def forward(self, x_seq, x_meta):
    rnn_out, _ = self.gru(x_seq)                      # (batch, seq_len, hidden * num_directions)
    pooled, _ = torch.max(rnn_out, dim=1)             # max over time: (batch, hidden * num_directions)
    x = torch.cat([pooled, x_meta], dim=1)            # (batch, hidden * num_directions + n_meta_features)
    logits = self.classifier(x)                       # unnormalized class scores
    probs = F.softmax(logits, dim=1)                  # class probabilities (each row sums to 1)
    
    return probs, logits
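The shape comments in `forward` can be checked directly on a stand-alone GRU. The sizes below are illustrative: 7 features per step and hidden size 128 with `bidirectional=True`, which is the configuration used when the model is instantiated. The GRU output concatenates the forward and backward states, so its last dimension is `2 * hidden_size`. Max pooling then keeps the largest activation of each channel over the time axis. If shorter light curves are zero-padded up to `max_seq_size` (an assumption about `PlasticcDataset`), the padded steps also take part in this max.

```python
import torch

# Illustrative sizes: batch 4, 352 time steps, 7 features per step
batch, seq_len, n_features, hidden = 4, 352, 7, 128
gru = torch.nn.GRU(input_size=n_features, hidden_size=hidden, num_layers=2,
                   batch_first=True, bidirectional=True)

x_seq = torch.randn(batch, seq_len, n_features)
rnn_out, h_n = gru(x_seq)
# Forward and backward states are concatenated on the last axis
print(rnn_out.shape)  # torch.Size([4, 352, 256])
# h_n is (num_layers * num_directions, batch, hidden) even with batch_first=True
print(h_n.shape)      # torch.Size([4, 4, 128])

# Max over the time dimension, as in PlasticcClassifier.forward
pooled, _ = torch.max(rnn_out, dim=1)
print(pooled.shape)   # torch.Size([4, 256])
```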

Create classifier instance

In [None]:
# - Create model
n_features= dataset_train.X_seq.shape[2]
n_meta_features= dataset_train.X_meta.shape[1]
num_classes= dataset_train.Y.shape[1]
seq_length= dataset_train.X_seq.shape[1]
print(f"seq_length={seq_length}")
print(f"n_features={n_features}")
print(f"n_meta_features={n_meta_features}")
print(f"num_classes={num_classes}")

model = PlasticcClassifier(
  n_features=n_features,
  n_meta_features=n_meta_features,
  hidden_size=128,
  num_layers=2,
  num_classes=num_classes,
  bidirectional=True,
  dropout=0.5,
  fc_hidden_size=128  
)

# - Print model structure
summary(
  model, 
  #input_data=[
  #  torch.randn(batch_size, seq_length, n_features),  # x_seq: (batch, seq_len, n_features)
  #  torch.randn(batch_size, n_meta_features)       # x_meta: (batch, n_meta_features)
  #]
  input_size=[(batch_size, seq_length, n_features), (batch_size, n_meta_features)] 
)
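Most of the parameters reported by `summary` come from the GRU. Each GRU layer and direction holds three stacked gates. Each gate has an input-to-hidden weight, a hidden-to-hidden weight, and two bias vectors, for a total of $3H(I + H + 2)$ parameters, where $I$ is the layer input size. In a bidirectional GRU, the second layer's input size is $2H$. The sketch below checks this count against PyTorch. The helper `gru_param_count` and the value `n_features=7` are hypothetical:

```python
import torch

def gru_param_count(input_size, hidden_size, num_layers, bidirectional):
    """Count GRU parameters: 3 gates x (W_ih + W_hh + b_ih + b_hh) per layer and direction."""
    dirs = 2 if bidirectional else 1
    total = 0
    for layer in range(num_layers):
        in_size = input_size if layer == 0 else hidden_size * dirs
        total += dirs * 3 * hidden_size * (in_size + hidden_size + 2)
    return total

gru = torch.nn.GRU(input_size=7, hidden_size=128, num_layers=2,
                   batch_first=True, bidirectional=True)
actual = sum(p.numel() for p in gru.parameters())
expected = gru_param_count(7, 128, 2, True)
print(actual, expected)  # 401664 401664
```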

### Model training

#### PLAsTiCC Weighted Log-Loss Metric

The PLAsTiCC challenge uses a **weighted multi-class logarithmic loss** to evaluate classification performance. This metric penalizes incorrect predictions and accounts for **class imbalance** by assigning each class a weight.

The loss is defined as:

$$
L = -\frac{\sum_{i=1}^{M} w_i \cdot \frac{1}{N_i} \sum_{j=1}^{N} y_{ij}\,\ln p_{ij}}{\sum_{i=1}^{M} w_i}
$$

Where:

- $M$ is the number of classes and $N$ is the total number of objects
- $N_i$ is the number of objects with class $i$
- $w_i$ is the weight for class $i$
- $y_{ij}$ is 1 if object $j$ belongs to class $i$ and 0 otherwise
- $p_{ij}$ is the predicted probability that object $j$ belongs to class $i$

**Notes**
- The loss is first averaged over the objects of each class and then **averaged across classes**, weighted by $w_i$ (= 1 for most classes, 2 for rare objects). Because of the per-class averaging, a class contributes according to its weight rather than its number of objects, so rare classes are not swamped by common ones.
- Predictions are clipped to the range $[10^{-15}, 1 - 10^{-15}]$ to prevent $\log(0)$.
- The sum of class weights is used to normalize the final score.
- During training, class 99 (the "none of the above" class) was excluded and evaluated separately.
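A small worked example shows how the per-class averaging works. The values are hypothetical: 3 objects, 2 classes, and weights $w = (1, 2)$. Class 0 has two objects, so their log-probabilities are averaged. Class 1 has a single object, which counts as much as both class-0 objects together before weighting. The rows of the arrays are objects and the columns are classes, matching the layout used in the code below.

```python
import numpy as np

# Toy example (hypothetical values): rows are objects, columns are classes
y_true = np.array([[1, 0], [1, 0], [0, 1]], dtype=float)
p = np.array([[0.8, 0.2], [0.6, 0.4], [0.3, 0.7]])
w = np.array([1.0, 2.0])

n_per_class = y_true.sum(axis=0)                                  # N_i = [2, 1]
per_class_mean = (y_true * np.log(p)).sum(axis=0) / n_per_class   # mean log-prob of the true class
loss = -(w * per_class_mean).sum() / w.sum()
print(round(loss, 4))  # 0.3601
```

Written out: $L = -\big(1 \cdot \tfrac{\ln 0.8 + \ln 0.6}{2} + 2 \ln 0.7\big)/3 \approx 0.3601$.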

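Define the PLAsTiCC loss function. The first version takes predicted probabilities and is used for evaluation. The second version takes raw logits and uses `log_softmax` for numerical stability; it is used for training.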

In [None]:
def multi_weighted_logloss(y_true, y_pred, class_weights, normalize_batch=True):
  """
    multi logloss for PLAsTiCC challenge
    Adapted from TF version: https://www.kaggle.com/ogrellier
  
    - y_true/y_pred are one-hot encoded
    - class_weights=[2,2,1,1,1,1,1,2,1,1,1,1,1,2,2] (see challenge result paper)
  """ 
    
  # - Normalize rows and limit y_preds to eps, 1-eps    
  eps = 1e-15
  preds = torch.clamp(y_pred, min=eps, max=1 - eps)

  # - Transform to log
  log_preds = torch.log(y_pred)  # (N, M)

  # (1) Per-class weighted log loss: sum across samples
  y_log_ones = torch.sum(y_true * log_preds, dim=0)              # (M,)
  nb_pos = torch.sum(y_true, dim=0)                              # (M,)
  nb_pos = torch.where(nb_pos == 0, torch.ones_like(nb_pos), nb_pos)

  y_w = y_log_ones * class_weights / nb_pos                       # (M,)

  loss = -torch.sum(y_w) / torch.sum(class_weights)
  if normalize_batch:
    loss = loss / y_pred.shape[0]  # normalize by batch size
    
  if not torch.isfinite(loss):
    print("⚠️ Warning: loss is NaN or Inf")
    print("y_pred finite? ")
    print(torch.all(torch.isfinite(y_pred)))
    print("y_true finite? ")
    print(torch.all(torch.isfinite(y_true)))
    print("log_preds finite? ")
    print(torch.all(torch.isfinite(log_preds)))
    print("y_log_ones finite? ")
    print(torch.all(torch.isfinite(y_log_ones)))
    print("y_w finite? ")
    print(torch.all(torch.isfinite(y_w)))
    
  return loss

In [None]:
def multi_weighted_logloss_stable(y_true, logits, class_weights, normalize_batch=True):
    """
    Numerically stable PLAsTiCC log-loss using log_softmax.

    Args:
        y_true: one-hot encoded targets (N, num_classes)
        logits: raw output of model before softmax (N, num_classes)
        class_weights: tensor of class weights (num_classes,)
        normalize_batch: if True, additionally divide the loss by the batch size N

    Returns:
        Scalar loss tensor
    """
    # - Use log_softmax for numerical stability
    log_probs = F.log_softmax(logits, dim=1)  # (N, num_classes)

    # - Multiply by one-hot true labels → picks correct class log-prob
    y_log_ones = torch.sum(y_true * log_probs, dim=0)  # (num_classes,)

    # - Number of positive examples per class
    nb_pos = torch.sum(y_true, dim=0)
    nb_pos = torch.where(nb_pos == 0, torch.ones_like(nb_pos), nb_pos)  # avoid divide-by-zero

    # - Weighted log-loss per class
    y_w = y_log_ones * class_weights / nb_pos

    loss = -torch.sum(y_w) / torch.sum(class_weights)
    if normalize_batch:
      loss = loss / logits.shape[0]  # normalize by batch size
    
    if not torch.isfinite(loss):
      print("⚠️ Warning: loss is NaN or Inf")
      print("logits finite? ")
      print(torch.all(torch.isfinite(logits)))
      print("y_true finite? ")
      print(torch.all(torch.isfinite(y_true)))
      print("log_probs finite? ")
      print(torch.all(torch.isfinite(log_probs)))
      print("y_log_ones finite? ")
      print(torch.all(torch.isfinite(y_log_ones)))
      print("y_w finite? ")
      print(torch.all(torch.isfinite(y_w)))
    
    return loss
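The reason for the `log_softmax` version is easy to reproduce. For moderate logits, `log(softmax(x))` and `log_softmax(x)` agree. For a very confident wrong prediction, `exp(-1000)` underflows to 0 in float32, so `log(softmax(x))` returns `-inf`. `log_softmax` instead computes `x - max(x) - log(sum(exp(x - max(x))))` and stays finite. The logit values below are arbitrary illustrations:

```python
import torch
import torch.nn.functional as F

# Moderate logits: both formulations agree
logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]])
naive = torch.log(F.softmax(logits, dim=1))
stable = F.log_softmax(logits, dim=1)
print(torch.allclose(naive, stable, atol=1e-6))  # True

# Extreme logits: softmax underflows, so its log is -inf for the second class
big = torch.tensor([[1000.0, 0.0]])
naive_big = torch.log(F.softmax(big, dim=1))
stable_big = F.log_softmax(big, dim=1)
print(naive_big)   # second entry is -inf
print(stable_big)  # second entry is -1000
```

Clipping, as in `multi_weighted_logloss`, also avoids the infinity. However, `torch.clamp` passes zero gradient to clipped entries, so a badly wrong prediction would receive no learning signal from that term. That is one reason to prefer the logits-based loss for training.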

Define a function to initialize weights before training.

In [None]:
def initialize_weights(model):
  """ Applies custom weight initialization to layers in the model """
  for m in model.modules():
    if isinstance(m, torch.nn.Linear):
      torch.nn.init.xavier_uniform_(m.weight)  # or kaiming_uniform_
      if m.bias is not None:
        torch.nn.init.zeros_(m.bias)

    elif isinstance(m, torch.nn.GRU):
      for name, param in m.named_parameters():
        if 'weight_ih' in name:
          torch.nn.init.xavier_uniform_(param.data)
        elif 'weight_hh' in name:
          torch.nn.init.orthogonal_(param.data)
        elif 'bias' in name:
          torch.nn.init.zeros_(param.data)
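`initialize_weights` applies `orthogonal_` to each `weight_hh` tensor. For a GRU, that tensor stacks the three gates into a `(3 * hidden_size, hidden_size)` matrix. Orthogonal initialization of this tall matrix gives it orthonormal columns, so $W^\top W = I$. It is applied to the stacked matrix as a whole, not to each gate block separately. The small sketch below (hidden size 16, input size 7, both illustrative) checks this and also checks that the biases are zeroed:

```python
import torch

torch.manual_seed(0)
gru = torch.nn.GRU(input_size=7, hidden_size=16, num_layers=1, batch_first=True)

# Same rules as initialize_weights for the recurrent weights and biases
for name, param in gru.named_parameters():
    if "weight_hh" in name:
        torch.nn.init.orthogonal_(param.data)
    elif "bias" in name:
        torch.nn.init.zeros_(param.data)

W = gru.weight_hh_l0.detach()  # (3 * hidden, hidden): reset, update and new gates stacked
print(W.shape)                 # torch.Size([48, 16])
# The stacked matrix has orthonormal columns
print(torch.allclose(W.T @ W, torch.eye(16), atol=1e-5))  # True
# A single gate block on its own is generally not orthogonal
W_r = W[:16]
print(torch.dist(W_r.T @ W_r, torch.eye(16)))
```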

Define the training loop

In [None]:
class AverageMeter:
  def __init__(self):
    self.reset()

  def reset(self):
    self.sum = 0
    self.count = 0

  def update(self, value, n=1):
    self.sum += value * n
    self.count += n

  @property
  def avg(self):
    return self.sum / self.count if self.count > 0 else 0


def train_model(
    model,
    dl_train,
    dl_val=None,
    num_epochs=1,
    lr=1e-3,
    use_lr_scheduler=False,
    warmup_ratio=0.1,
    class_weights=None,
    clip_grad=False,
    max_grad_norm=5,
    outfile_model= "model.pth",
    outfile_weights= "model_weights.pth"
):
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  model.to(device)

  # - Check class weights  
  if class_weights is None:
    raise ValueError("class_weights must be provided")

  class_weights = class_weights.to(device)
  num_classes = class_weights.numel()  # one weight per class; avoids relying on a global num_classes
  
  # - Init optimizer
  optimizer = torch.optim.Adam(model.parameters(), lr=lr)
  scheduler = None
  if use_lr_scheduler:
    num_training_steps = num_epochs * len(dl_train)
    num_warmup_steps = int(warmup_ratio * num_training_steps)
    print(f"--> #steps warmup/tot: {num_warmup_steps}/{num_training_steps}")
    scheduler = get_cosine_schedule_with_warmup(
      optimizer=optimizer,
      num_warmup_steps=num_warmup_steps,
      num_training_steps=num_training_steps
    )  

  # - Init metrics
  loss_meter = AverageMeter()
  acc_metric = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes).to(device)
  f1_metric = torchmetrics.F1Score(task="multiclass", num_classes=num_classes, average="macro").to(device)
  confusion_matrix_metric = torchmetrics.ConfusionMatrix(task="multiclass", num_classes=num_classes, normalize="true").to(device)
    
  val_acc_metric = val_f1_metric = None
  if dl_val is not None:
    val_loss_meter = AverageMeter()    
    val_acc_metric = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes).to(device)
    val_f1_metric = torchmetrics.F1Score(task="multiclass", num_classes=num_classes, average="macro").to(device)
    val_confusion_matrix_metric = torchmetrics.ConfusionMatrix(task="multiclass", num_classes=num_classes, normalize="true").to(device)
    
  history = {
    "loss_train": [],
    "acc_train": [],
    "f1score_train": [],
    "cm_train": None,
    "cm_metric_train": None,  
    "loss_val": [],
    "acc_val": [],
    "f1score_val": [],
    "cm_val": None,
    "cm_metric_val": None,   
  }

  # - Start training loop
  for epoch in range(num_epochs):
    model.train()
    
    # - Reset avg metrics
    progress = tqdm(dl_train, desc=f"Epoch {epoch+1} [Train]", leave=False)
    loss_meter.reset()
    acc_metric.reset()
    f1_metric.reset()
    confusion_matrix_metric.reset()
    
    total_loss = 0.0
    y_true_all = []
    y_pred_all = []

    # - Run batch loop
    for x_seq, x_meta, y in progress:    
      x_seq, x_meta, y = x_seq.to(device), x_meta.to(device), y.to(device)

      optimizer.zero_grad()
      #y_pred = model(x_seq, x_meta)
      y_pred, logits = model(x_seq, x_meta)    

      #loss = multi_weighted_logloss(y, y_pred, class_weights, normalize_batch=False)
      loss = multi_weighted_logloss_stable(y, logits, class_weights, normalize_batch=False)
      
      # ✅ Check 1: Loss is finite (not NaN or Inf)
      if not torch.isfinite(loss):
        print("⚠️ Warning: loss is NaN or Inf. Skipping this batch.")
        print("x_seq finite? ")
        print(torch.all(torch.isfinite(x_seq)))
        print("x_meta finite? ")
        print(torch.all(torch.isfinite(x_meta)))  
        continue  # skip backprop for this batch
        
      loss.backward()
    
      # ✅ Check 2: Optional: Print max gradient norm for debugging
      total_norm = 0
      for p in model.parameters():
        if p.grad is not None:
          param_norm = p.grad.data.norm(2)
          total_norm += param_norm.item() ** 2
      total_norm = total_norm ** 0.5
      if total_norm > 1e3:
        print(f"⚠️ High gradient norm: {total_norm:.2f}")

      # ✅ Check 3: Clip gradients to prevent explosion
      if clip_grad:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
    
      #if clip_grad:  
      #  #`clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
      #  torch.nn.utils.clip_grad_norm_(model.parameters(), gradclip)
      #  for p in model.parameters():
      #    p.data.add_(p.grad, alpha=-lr)  
    
      optimizer.step()
      if scheduler:
        scheduler.step()  # 🔁 per-step LR update
        current_lr = scheduler.get_last_lr()[0]
      else:
        current_lr= lr
        
      # - Update loss and metrics 
      target_pred = y_pred.argmax(dim=1)
      target_true= y.argmax(dim=1)
      loss_meter.update(loss.item(), x_seq.size(0))
      acc_metric.update(target_pred, target_true)
      f1_metric.update(target_pred, target_true)
      confusion_matrix_metric.update(target_pred, target_true)  
        
      total_loss += loss.item()
      y_true_all.append(target_true.cpu())
      y_pred_all.append(target_pred.cpu())  
    
      # - Update progress bar
      progress.set_postfix({
        "lr": f"{current_lr:.6f}",
        "loss": f"{loss_meter.avg:.4f}",
        "acc": f"{acc_metric.compute().item():.4f}",
        "f1": f"{f1_metric.compute().item():.4f}"
      })  

    # - Compute average metrics (v1)
    avg_train_loss = total_loss / len(dl_train)
    y_true_all = torch.cat(y_true_all)
    y_pred_all = torch.cat(y_pred_all)
    train_acc = (y_true_all == y_pred_all).float().mean().item()
    train_f1 = f1_score(y_true_all, y_pred_all, average='macro', zero_division=0)

    # - Compute average metrics (v2)
    avg_train_loss_v2 = loss_meter.avg
    train_acc_v2 = acc_metric.compute().item()
    train_f1_v2 = f1_metric.compute().item()
    confusion_matrix= confusion_matrix_metric.compute().cpu().numpy()
    history["loss_train"].append(avg_train_loss_v2)
    history["acc_train"].append(train_acc_v2)
    history["f1score_train"].append(train_f1_v2) 
    history["cm_train"]= confusion_matrix
    history["cm_metric_train"]= confusion_matrix_metric
        
    #print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_train_loss:.4f} | Acc: {train_acc:.4f} | F1: {train_f1:.4f}", end='')
    print(f"Epoch [{epoch+1}/{num_epochs}]: lr={current_lr:.6f}, loss={avg_train_loss:.4f}, {avg_train_loss_v2:.4f} | acc={train_acc:.4f}, {train_acc_v2:.4f} | f1={train_f1:.4f}, {train_f1_v2:.4f}", end='')

    
    if dl_val is not None:
      model.eval()
    
      # - Init val metrics
      val_loss_meter.reset()
      val_acc_metric.reset()
      val_f1_metric.reset()
      val_confusion_matrix_metric.reset()
      val_progress = tqdm(dl_val, desc=f"Epoch {epoch+1} [Val]", leave=False)
      val_loss = 0.0
      y_true_val = []
      y_pred_val = []

      with torch.no_grad():
        for x_seq, x_meta, y in val_progress:    
          x_seq, x_meta, y = x_seq.to(device), x_meta.to(device), y.to(device)
          #y_pred = model(x_seq, x_meta)
          y_pred, logits = model(x_seq, x_meta)  

          #loss = multi_weighted_logloss(y, y_pred, class_weights, normalize_batch=False)
          loss = multi_weighted_logloss_stable(y, logits, class_weights, normalize_batch=False)
        
          # - Update loss and accuracy  
          val_loss_meter.update(loss.item(), x_seq.size(0))
          target_pred = y_pred.argmax(dim=1)
          target_true = y.argmax(dim=1)  
          val_acc_metric.update(target_pred, target_true)
          val_f1_metric.update(target_pred, target_true)
          val_confusion_matrix_metric.update(target_pred, target_true) 
          val_loss += loss.item()

          y_true_val.append(y.argmax(dim=1).cpu())
          y_pred_val.append(y_pred.argmax(dim=1).cpu())
  
          # - Update progress bar
          val_progress.set_postfix({"loss": f"{val_loss_meter.avg:.4f}"})
        
      # - Compute average metrics (v1)    
      avg_val_loss = val_loss / len(dl_val)
      y_true_val = torch.cat(y_true_val)
      y_pred_val = torch.cat(y_pred_val)
      val_acc = (y_true_val == y_pred_val).float().mean().item()
      val_f1 = f1_score(y_true_val, y_pred_val, average='macro', zero_division=0)

      # - Compute average metrics (v2)    
      avg_val_loss_v2 = val_loss_meter.avg
      val_acc_v2 = val_acc_metric.compute().item()
      val_f1_v2 = val_f1_metric.compute().item()
      val_confusion_matrix= val_confusion_matrix_metric.compute().cpu().numpy()  
      history["loss_val"].append(avg_val_loss_v2)
      history["acc_val"].append(val_acc_v2)
      history["f1score_val"].append(val_f1_v2)  
      history["cm_val"]= val_confusion_matrix
      history["cm_metric_val"]= val_confusion_matrix_metric
    
      print(f" | Val Loss: {avg_val_loss:.4f}, {avg_val_loss_v2:.4f} | Val Acc: {val_acc:.4f}, {val_acc_v2:.4f} | Val F1: {val_f1:.4f}, {val_f1_v2:.4f}")
    else:
      print()

  # - Save final model
  torch.save(model.state_dict(), outfile_weights)  
  torch.save(model, outfile_model)
  print(f"\n✅ Model weights saved to: {outfile_weights}")
  print(f"✅ Full model saved to: {outfile_model}")

  print("Training complete.")
  return history
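`train_model` prints two versions of each metric. The "v1" loss divides the sum of batch losses by the number of batches. The "v2" loss, computed by `AverageMeter`, weights each batch loss by its number of samples. The two differ when the last batch is smaller. The batch losses and sizes below are hypothetical:

```python
# Per-batch losses and batch sizes (hypothetical): the final batch is smaller
batch_losses = [1.0, 1.0, 2.0]
batch_sizes = [64, 64, 10]

# v1: mean over batches (total_loss / len(dl_train))
v1 = sum(batch_losses) / len(batch_losses)
# v2: mean over samples (AverageMeter.update(loss, n) followed by .avg)
v2 = sum(l * n for l, n in zip(batch_losses, batch_sizes)) / sum(batch_sizes)
print(f"{v1:.4f} {v2:.4f}")  # 1.3333 1.0725
```

The PLAsTiCC loss is already a per-class average within each batch. Neither number is therefore exactly the loss over the whole epoch; both are approximations.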

Initialize weights and start training

In [None]:
# - Initialize weights
torch.manual_seed(10)
initialize_weights(model)

# - Run train
nepochs= 100
lr= 1e-3
warmup_ratio= 0.1
class_weights= torch.from_numpy(dataset_train.class_weight_factors)
outfile_weights= os.path.join(rundir, "model_weights.pth")
outfile_model= os.path.join(rundir, "model.pth")

metric_hist= train_model(
  model, 
  dl_train=dl_train, 
  dl_val=dl_val, 
  num_epochs=nepochs,
  lr=lr,
  use_lr_scheduler=False,  
  warmup_ratio=warmup_ratio,
  class_weights=class_weights,
  clip_grad=False,
  max_grad_norm=5,
  outfile_model=outfile_model,  
  outfile_weights=outfile_weights
)
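`train_model` can optionally use `get_cosine_schedule_with_warmup` (disabled here with `use_lr_scheduler=False`). The sketch below reproduces the usual shape of such a schedule with PyTorch's built-in `LambdaLR`: a linear warmup to the base learning rate, followed by cosine decay toward zero. `warmup_cosine_lambda` is a hypothetical helper, and the imported function may differ in details. The 10 warmup steps out of 100 match `warmup_ratio=0.1`.

```python
import math
import torch

def warmup_cosine_lambda(num_warmup_steps, num_training_steps):
    """LR multiplier: linear warmup from 0 to 1, then cosine decay from 1 to 0."""
    def lr_lambda(step):
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return lr_lambda

net = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
sched = torch.optim.lr_scheduler.LambdaLR(opt, warmup_cosine_lambda(10, 100))

lrs = []
for _ in range(100):
    lrs.append(opt.param_groups[0]["lr"])  # LR used for this step
    opt.step()    # no gradients here; only the schedule is being traced
    sched.step()  # per-step update, as in train_model

print(lrs[0], lrs[10], lrs[55])  # start of warmup, peak, halfway through the decay
```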

Let’s plot the training and validation metrics after the training run is complete.

In [None]:
def draw_metric_hist(metric_hist):
  
  epochs = np.arange(1, len(metric_hist["loss_train"]) + 1)
  fig = plt.figure(figsize=(14, 5))

  # - Plot train/val loss
  ax1 = fig.add_subplot(1, 2, 1)
  ax1.plot(epochs, metric_hist["loss_train"], '-o', label='Train Loss')
  ax1.plot(epochs, metric_hist["loss_val"], '--<', label='Validation Loss')
  ax1.set_title("Loss Over Epochs", fontsize=14)
  ax1.set_xlabel("Epoch", fontsize=12)
  ax1.set_ylabel("Loss", fontsize=12)
  ax1.legend(fontsize=11)
  ax1.grid(True) 
    
  # - Plot acc/f1score
  ax2 = fig.add_subplot(1, 2, 2)
  ax2.set_ylim(0, 1)
  ax2.plot(epochs, metric_hist["acc_train"], '-o', label='Train Accuracy')
  ax2.plot(epochs, metric_hist["acc_val"], '--<', label='Validation Accuracy')
  ax2.plot(epochs, metric_hist["f1score_train"], '-*', label='Train F1-score')
  ax2.plot(epochs, metric_hist["f1score_val"], '-->', label='Validation F1-score')
  ax2.set_title("Accuracy and F1-score", fontsize=14)
  ax2.set_xlabel("Epoch", fontsize=12)
  ax2.set_ylabel("Score", fontsize=12)
  ax2.legend(fontsize=11)
  ax2.grid(True)

  plt.tight_layout()
  plt.show()

# - Print & plot metrics
print("== Training Metrics ==")
#print(metric_hist)

draw_metric_hist(metric_hist)

# - Draw the normalized confusion matrix (training set, last epoch)
fig, ax = plt.subplots(figsize=(20,20))
metric_hist["cm_metric_train"].plot(ax=ax)
#fig_, ax_ = metric_hist["cm_metric_train"][-1].plot()
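The confusion matrix metrics are created with `normalize="true"`. Each row (true class) is therefore divided by the number of samples of that class, and the diagonal gives the per-class recall. The plain-PyTorch sketch below performs the same row normalization on toy labels. It is not the torchmetrics implementation; empty rows are simply left as zeros here.

```python
import torch

num_classes = 3
target = torch.tensor([0, 0, 1, 2])
preds = torch.tensor([0, 1, 1, 2])

# Count (true, predicted) pairs: row = true class, column = predicted class
counts = torch.bincount(target * num_classes + preds,
                        minlength=num_classes ** 2).reshape(num_classes, num_classes)
# normalize="true": divide each row by the number of samples of that true class
row_sums = counts.sum(dim=1, keepdim=True).clamp(min=1)
cm_true = counts.float() / row_sums
print(cm_true)  # rows: [0.5, 0.5, 0], [0, 1, 0], [0, 0, 1]
```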

### Model evaluation
We now evaluate the trained model on the test set.

Load the model from the saved checkpoint.

In [None]:
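# - Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Option 1: unpickle the full model object (needs the PlasticcClassifier class in scope;
# weights_only=False can execute arbitrary pickled code, so only use it on trusted files)
model_trained= torch.load(outfile_model, map_location=device, weights_only=False)
model_trained.eval()

# Option 2: restore the saved weights into the existing model instance
model.load_state_dict(torch.load(outfile_weights, map_location=device, weights_only=True))
model.to(device)
model.eval()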
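Saving and restoring a `state_dict` with `weights_only=True` needs only the architecture and a file of tensors. The minimal round trip below uses a small stand-in network and an in-memory buffer instead of `outfile_weights`, and checks that the restored model gives identical outputs:

```python
import io
import torch

torch.manual_seed(0)

def make_net():
    # Stand-in architecture (hypothetical), rebuilt identically before loading
    return torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 3))

net = make_net()

# Save only the weights (a dict of tensors) into an in-memory buffer
buf = io.BytesIO()
torch.save(net.state_dict(), buf)
buf.seek(0)

# weights_only=True restricts unpickling to tensors, primitive types and standard containers
restored = make_net()
restored.load_state_dict(torch.load(buf, weights_only=True))
restored.eval()

x = torch.randn(2, 4)
print(torch.equal(net(x), restored(x)))  # True
```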

Run model predictions on the training and test sets.

In [None]:
def predict_model(model, dl_test, class_weights):
    """
    Perform inference on a test set.

    Args:
        model: Trained AstroMultiInputClassifier
        dl_test: DataLoader for the test set

    Returns:
        all_probs: Tensor of predicted class probabilities (N, num_classes)
        all_preds: Tensor of predicted class indices (N,)
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    all_probs = []
    all_preds = []
    all_targets = []
    progress = tqdm(dl_test, desc="[Test]", leave=False)
    
    with torch.no_grad():
        for x_seq, x_meta, y in progress:
            x_seq = x_seq.to(device)
            x_meta = x_meta.to(device)
            y= y.to(device)

            #y_pred = model(x_seq, x_meta)  # (batch, num_classes)
            y_pred, logits = model(x_seq, x_meta)  # (batch, num_classes)
            all_probs.append(y_pred.cpu())
            all_preds.append(y_pred.argmax(dim=1).cpu())
            all_targets.append(y.argmax(dim=1).cpu())  # from one-hot

    probs = torch.cat(all_probs, dim=0)
    preds = torch.cat(all_preds, dim=0)
    targets = torch.cat(all_targets, dim=0)

    acc = accuracy_score(targets, preds)
    f1 = f1_score(targets, preds, average='macro', zero_division=0)

    # One-hot targets again for PLAsTiCC loss
    targets_onehot = torch.nn.functional.one_hot(targets, num_classes=probs.shape[1]).float()
    class_weights = class_weights.to(probs.device)
    metric = multi_weighted_logloss(targets_onehot, probs, class_weights, normalize_batch=False)
    y_true= targets_onehot

    print(f"✅ Inference Metrics:")
    print(f"  Accuracy        : {acc:.4f}")
    print(f"  Macro F1-score  : {f1:.4f}")
    print(f"  PLAsTiCC metric : {metric:.4f}")

    return y_true, probs, preds, acc, f1, metric

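# - Run prediction on train
# Note: dataset_train was created with do_augmentation=True, so these metrics are computed on augmented samples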
print("Run prediction on train dataset ...")
y_true_train, probs_train, preds_train, acc_train, f1_train, metric_train = predict_model(model_trained, dl_train, class_weights)

# - Run prediction on test
print("Run prediction on test dataset ...")
y_true_test, probs_test, preds_test, acc_test, f1_test, metric_test = predict_model(model_trained, dl_test, class_weights)

In [None]:
weights= dataset_test.class_weight_factors
print("weights")
print(weights)

def compute_plasticc_metric(y_true, probs, weights):
  """ NumPy version of the PLAsTiCC weighted log-loss (y_true: one-hot, probs: class probabilities) """
  # - Weights sum
  wsum= np.sum(weights)
  
  # - Clip probabilities to [eps, 1-eps] so that log() stays finite
  #   (positional a_min/a_max also work on NumPy versions older than 2.1)
  eps = 1e-15
  probs = np.clip(probs, eps, 1 - eps)

  # - Transform to log
  log_probs = np.log(probs)  # (N, M)

  # - Per-class sum of the log-probabilities assigned to the true class
  y_log_ones = np.sum(y_true * log_probs, axis=0)              # (M,)
  nb_pos = np.sum(y_true, axis=0)                              # (M,) objects per class
  nb_pos = np.where(nb_pos == 0, np.ones_like(nb_pos), nb_pos)  # avoid division by zero

  y_w = y_log_ones * weights / nb_pos                       # (M,)

  metric = -np.sum(y_w) / wsum
  return metric

#####################
## TRAIN METRICS
#####################
print("type(y_true_train)")
print(type(y_true_train))

metric_train= compute_plasticc_metric(y_true_train.numpy(), probs_train.numpy(), weights)    
print("== Plasticc metric (TRAIN)")
print(metric_train)

#####################
## TEST METRICS
#####################

metric_test= compute_plasticc_metric(y_true_test.numpy(), probs_test.numpy(), weights)    
print("== Plasticc metric (TEST)")
print(metric_test)
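Breaking the metric into per-class terms $-w_i \cdot \frac{1}{N_i}\sum_j y_{ij} \ln p_{ij}$ shows which classes dominate the score. These terms sum to the metric multiplied by $\sum_i w_i$. The helper `per_class_contributions` is hypothetical and is shown on toy data. In the notebook, it could be called as `per_class_contributions(y_true_test.numpy(), probs_test.numpy(), weights)`.

```python
import numpy as np

def per_class_contributions(y_true, probs, weights, eps=1e-15):
    """Per-class terms -w_i * mean log p of the true class; their sum / sum(weights) is the metric."""
    log_probs = np.log(np.clip(probs, eps, 1 - eps))
    n_pos = np.maximum(y_true.sum(axis=0), 1)  # guard classes with no objects
    return -weights * (y_true * log_probs).sum(axis=0) / n_pos

# Toy data (hypothetical): 3 objects, 2 classes
y_true = np.array([[1, 0], [1, 0], [0, 1]], dtype=float)
probs = np.array([[0.8, 0.2], [0.6, 0.4], [0.3, 0.7]])
weights = np.array([1.0, 2.0])

contrib = per_class_contributions(y_true, probs, weights)
print(contrib)                        # about [0.367, 0.713]
print(contrib.sum() / weights.sum())  # overall metric, about 0.3601
```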