# Konstrukcja modelu

* Moduł MLP do przetwarzania metadanych
* 6 modułów LSTM do przetwarzania sekwencji z różnych pasm
* Głowica klasyfikacyjna łącząca wyniki z 7 gałęzi

## Ładowanie danych

* `collate_fn` dodaje padding przy składaniu minipakietu

In [1]:
import torch
from torch.nn.utils.rnn import pad_sequence


def supernova_collate_fn(batch):
    """Merge a list of SupernovaDataset samples into one padded mini-batch.

    Each sample is read through the keys ``"metadata"``, ``"label"``,
    ``"lengths"`` and ``"sequences"``; the last two are indexed by band id
    (0-5).

    Returns:
        Dict with ``"metadata"`` and ``"labels"`` stacked along a new leading
        batch axis, plus ``"sequences"`` and ``"lengths"`` as dicts keyed by
        band id. The sequences of one band are zero-padded at the end up to
        the longest sequence of that band within the batch.
    """
    band_ids = range(6)

    stacked_metadata = torch.stack([sample["metadata"] for sample in batch])
    stacked_labels = torch.stack([sample["label"] for sample in batch])

    # True (unpadded) length of every sample's sequence, one tensor per band.
    per_band_lengths = {}
    for band in band_ids:
        per_band_lengths[band] = torch.stack(
            [sample["lengths"][band] for sample in batch]
        )

    # Sequences differ in length between samples, so each band is padded
    # separately; batch_first=True keeps the batch axis in front.
    per_band_sequences = {}
    for band in band_ids:
        per_band_sequences[band] = pad_sequence(
            [sample["sequences"][band] for sample in batch],
            batch_first=True,
            padding_value=0,
        )

    return {
        "metadata": stacked_metadata,
        "sequences": per_band_sequences,
        "lengths": per_band_lengths,
        "labels": stacked_labels,
    }

In [2]:
from torch.utils.data import DataLoader
from supernova.dataset import SupernovaDataset

# Dataset backed by the processed training-set file (path is relative to the
# notebook's working directory).
dataset = SupernovaDataset("../data/processed/training_set.pkl")

# Shuffled mini-batches of 32 samples; supernova_collate_fn pads the
# variable-length sequences while each batch is assembled.
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=supernova_collate_fn,
    num_workers=4,
)

## Architektura modelu
* 7 gałęzi
  * MLP dla metadanych
  * 6 LSTM (po jednym na pasmo) dla sekwencji (krzywych fotometrycznych)
* Głowica klasyfikacyjna
  * MLP
  * Wejście: połączone wyjścia z 7 gałęzi
  * Wyjście: logity klas

In [3]:
from torch import nn
from dataclasses import dataclass


class MLP(nn.Module):
    """Fully connected network with ReLU activations and dropout.

    The stack is an input block (``Linear -> ReLU -> Dropout``), followed by
    ``num_hidden_layers`` hidden blocks of the same shape, followed by a final
    ``Linear`` projection without activation.

    Args:
        input_size: Number of features in the last dimension of the input.
        num_hidden_layers: Number of ``hidden_size -> hidden_size`` blocks
            placed after the input block (so the network holds
            ``num_hidden_layers + 2`` linear layers in total).
        hidden_size: Width of the input block and of every hidden block.
        output_size: Number of features produced by the final ``Linear``.
        dropout: Probability passed to every ``nn.Dropout``.
    """

    def __init__(
        self,
        input_size: int,
        num_hidden_layers: int,
        hidden_size: int,
        output_size: int,
        dropout: float,
    ):
        super().__init__()

        # Input block: its three modules sit directly in the outer Sequential.
        stages = [nn.Linear(input_size, hidden_size), nn.ReLU(), nn.Dropout(dropout)]

        # Every hidden block is wrapped in its own nested Sequential.
        for _ in range(num_hidden_layers):
            hidden_block = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Dropout(dropout),
            )
            stages.append(hidden_block)

        # Output projection; no activation, so the result is unnormalized.
        stages.append(nn.Linear(hidden_size, output_size))

        self.network = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the whole stack to ``x`` and return the result."""
        return self.network(x)


@dataclass
class SupernovaClassifierV1Config:
    """Hyperparameters of SupernovaClassifierV1.

    Every size and layer count must be strictly positive and ``dropout`` must
    lie in ``[0, 1)``; ``__post_init__`` raises ``AssertionError`` otherwise.
    """

    # Metadata branch (MLP).
    metadata_input_size: int
    metadata_num_hidden_layers: int
    metadata_hidden_size: int
    metadata_output_size: int
    # Light-curve branches (one LSTM per band).
    lightcurve_input_size: int
    lightcurve_num_hidden_layers: int  # passed to nn.LSTM as num_layers
    lightcurve_hidden_size: int
    # Classification head (MLP).
    classifier_hidden_size: int
    classifier_num_hidden_layers: int
    num_classes: int
    # Dropout probability shared by the MLPs and the LSTMs.
    dropout: float

    def __post_init__(self):
        """Validate the field values right after construction.

        Raises:
            AssertionError: If a size or layer count is not ``> 0`` or if
                ``dropout`` is outside ``[0, 1)``. The message names the
                offending field.
        """
        # Explicit raises instead of `assert` statements: asserts are removed
        # under `python -O`, which would silently disable this validation.
        # AssertionError is kept as the exception type so existing callers
        # that catch it keep working.
        positive_fields = (
            "metadata_input_size",
            "metadata_num_hidden_layers",
            "metadata_hidden_size",
            "metadata_output_size",
            "lightcurve_input_size",
            "lightcurve_num_hidden_layers",
            "lightcurve_hidden_size",
            "classifier_hidden_size",
            "classifier_num_hidden_layers",
            "num_classes",
        )
        for field_name in positive_fields:
            value = getattr(self, field_name)
            # `not value > 0` (rather than `value <= 0`) also rejects NaN.
            if not value > 0:
                raise AssertionError(f"{field_name} must be > 0, got {value!r}")

        if not 0 <= self.dropout < 1.0:
            raise AssertionError(f"dropout must be in [0, 1), got {self.dropout!r}")


class SupernovaClassifierV1(nn.Module):
    """Multi-branch classifier: a metadata MLP, one LSTM per band, and an MLP head.

    The metadata vector goes through ``metadata_mlp``; the light curve of each
    band goes through its own LSTM, whose last-layer final hidden state is
    used as the band's feature vector. All ``1 + NUM_BANDS`` feature vectors
    are concatenated and mapped to class logits by ``classifier_mlp``.
    """

    # Number of bands; `sequences` and `lengths` are keyed 0..NUM_BANDS-1.
    NUM_BANDS = 6

    def __init__(self, config: "SupernovaClassifierV1Config"):
        """Build the branches and the head from ``config``."""
        super().__init__()
        self.config = config

        classifier_input_size = (
            config.metadata_output_size
            + self.NUM_BANDS * config.lightcurve_hidden_size
        )

        # nn.LSTM applies dropout only between stacked layers and emits a
        # warning when dropout > 0 is combined with a single layer, so pass
        # it only where it has an effect.
        lstm_dropout = (
            config.dropout if config.lightcurve_num_hidden_layers > 1 else 0.0
        )

        self.metadata_mlp = MLP(
            input_size=config.metadata_input_size,
            num_hidden_layers=config.metadata_num_hidden_layers,
            hidden_size=config.metadata_hidden_size,
            output_size=config.metadata_output_size,
            dropout=config.dropout,
        )
        self.lightcurve_lstm_modules = nn.ModuleList(
            [
                nn.LSTM(
                    input_size=config.lightcurve_input_size,
                    hidden_size=config.lightcurve_hidden_size,
                    num_layers=config.lightcurve_num_hidden_layers,
                    batch_first=True,
                    dropout=lstm_dropout,
                )
                for _ in range(self.NUM_BANDS)
            ]
        )
        self.classifier_mlp = MLP(
            input_size=classifier_input_size,
            num_hidden_layers=config.classifier_num_hidden_layers,
            hidden_size=config.classifier_hidden_size,
            output_size=config.num_classes,
            dropout=config.dropout,
        )

    def forward(self, metadata, sequences, lengths):
        """
        Args:
            metadata: tensor (batch_size, metadata_input_size)
            sequences: dict mapping band_id (0-5) to padded sequences (batch_size, max_seq_len, n_lightcurve_features)
            lengths: dict mapping band_id (0-5) to sequence lengths tensor (batch_size)

        Returns:
            logits: tensor of shape (batch_size, num_classes)
        """
        # Process metadata through MLP
        metadata_features = self.metadata_mlp(metadata)

        # Process each band's lightcurve through corresponding LSTM
        lightcurve_features = [
            self._process_lightcurve(band_id, sequences[band_id], lengths[band_id])
            for band_id in range(self.NUM_BANDS)
        ]

        # Concatenate all features along the feature axis
        combined_features = torch.cat([metadata_features] + lightcurve_features, dim=1)

        # Pass through classifier
        logits = self.classifier_mlp(combined_features)

        return logits

    def _process_lightcurve(
        self, band_id: int, sequence: torch.Tensor, lengths: torch.Tensor
    ) -> torch.Tensor:
        """Process a single band's lightcurve through its LSTM module.

        Samples whose length is 0 (no observations in this band) get an
        all-zero feature vector instead of making ``pack_padded_sequence``
        raise; samples with a positive length are processed exactly as before.

        Args:
            band_id: Band index
            sequence: Padded sequences (batch_size, max_seq_len, lightcurve_input_size)
            lengths: Sequence lengths tensor (batch_size)

        Returns:
            Final hidden state of the last LSTM layer (batch_size, lightcurve_hidden_size)
        """
        lstm = self.lightcurve_lstm_modules[band_id]

        # pack_padded_sequence needs the lengths on the CPU.
        lengths_cpu = lengths.cpu()
        is_empty = lengths_cpu <= 0

        # Nothing to run the LSTM on (this also covers max_seq_len == 0).
        if bool(is_empty.all()):
            return sequence.new_zeros((sequence.size(0), lstm.hidden_size))

        # pack_padded_sequence rejects lengths <= 0, so empty samples are
        # packed as a single padded step and their output is discarded below.
        packed = nn.utils.rnn.pack_padded_sequence(
            sequence,
            lengths_cpu.clamp(min=1),
            batch_first=True,
            enforce_sorted=False,
        )
        _, (hidden, _) = lstm(packed)

        # hidden is (num_layers, batch_size, hidden_size); keep the last layer.
        final_state = hidden[-1]
        if bool(is_empty.any()):
            empty_rows = is_empty.to(final_state.device).unsqueeze(1)
            final_state = final_state.masked_fill(empty_rows, 0.0)
        return final_state

## Sprawdzenie działania

Przepuszczenie jednego minipakietu przez model

In [4]:
# Pull a single mini-batch and show the shape of every tensor in it.
batch = next(iter(dataloader))
print(f"Metadat: {batch['metadata'].shape}")
for band_id in range(6):
    seq_shape = batch["sequences"][band_id].shape
    len_shape = batch["lengths"][band_id].shape
    print(f"Band {band_id} sequences: {seq_shape}, lengths: {len_shape}")
print(f"Labels: {batch['labels'].shape}")

Metadat: torch.Size([32, 10])
Band 0 sequences: torch.Size([32, 72, 4]), lengths: torch.Size([32])
Band 1 sequences: torch.Size([32, 58, 4]), lengths: torch.Size([32])
Band 2 sequences: torch.Size([32, 58, 4]), lengths: torch.Size([32])
Band 3 sequences: torch.Size([32, 58, 4]), lengths: torch.Size([32])
Band 4 sequences: torch.Size([32, 58, 4]), lengths: torch.Size([32])
Band 5 sequences: torch.Size([32, 57, 4]), lengths: torch.Size([32])
Labels: torch.Size([32])


In [5]:
config = SupernovaClassifierV1Config(
    # Metadata branch; 10 input features matches the metadata batch shape
    # printed above.
    metadata_input_size=10,
    metadata_num_hidden_layers=2,
    metadata_hidden_size=32,
    metadata_output_size=16,
    # Light-curve branches; 4 features per time step matches the sequence
    # batch shapes printed above.
    lightcurve_input_size=4,
    lightcurve_num_hidden_layers=2,
    lightcurve_hidden_size=32,
    # Classification head.
    classifier_hidden_size=64,
    classifier_num_hidden_layers=2,
    num_classes=14,
    # Shared dropout probability.
    dropout=0.2,
)

model = SupernovaClassifierV1(config)

In [6]:
# Forward pass of the sampled mini-batch; the labels are not needed here.
logits = model(
    metadata=batch["metadata"],
    sequences=batch["sequences"],
    lengths=batch["lengths"],
)
# Left as the last expression so the notebook displays the shape.
logits.shape

torch.Size([32, 14])

Softmax na logitach daje rozkład prawdopodobieństwa

In [7]:
from torch import softmax

# logits is (batch_size, num_classes); dim=1 normalizes over the class axis,
# so every row is turned into a distribution independently of the others.
softmax(logits, dim=1)

tensor([[   nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
            nan,    nan,    nan,    nan,    nan],
        [0.0745, 0.0882, 0.0983, 0.0790, 0.0452, 0.0807, 0.0655, 0.0604, 0.0454,
         0.0534, 0.0789, 0.1041, 0.0645, 0.0618],
        [0.0711, 0.0792, 0.0795, 0.0740, 0.0580, 0.0774, 0.0726, 0.0644, 0.0618,
         0.0624, 0.0722, 0.0952, 0.0701, 0.0620],
        [0.0693, 0.0767, 0.0703, 0.0802, 0.0620, 0.0706, 0.0686, 0.0782, 0.0623,
         0.0650, 0.0740, 0.0783, 0.0727, 0.0719],
        [0.0660, 0.0753, 0.0650, 0.0804, 0.0615, 0.0685, 0.0736, 0.0743, 0.0675,
         0.0658, 0.0737, 0.0841, 0.0737, 0.0706],
        [0.0687, 0.0782, 0.0833, 0.0697, 0.0569, 0.0735, 0.0653, 0.0676, 0.0594,
         0.0611, 0.0921, 0.0958, 0.0666, 0.0618],
        [0.0680, 0.0789, 0.0715, 0.0736, 0.0648, 0.0672, 0.0719, 0.0746, 0.0671,
         0.0693, 0.0698, 0.0840, 0.0722, 0.0671],
        [0.0685, 0.0824, 0.0715, 0.0808, 0.0616, 0.0732, 0.0693, 0.0751, 0.0534,
  

W danych występują braki (stąd wartości `nan` w pierwszym wierszu powyższego wyniku) – trzeba będzie to poprawić, wprowadzając jakąś strategię imputacji (dla metadanych).