# Trening klasyfikatora

In [1]:
import torch
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from torch import nn
from torch.optim import Adam
from torch.utils.data import random_split, DataLoader

from supernova.dataset import SupernovaDataset, supernova_collate_fn
from supernova.modeling.model import SupernovaClassifierV1Config, SupernovaClassifierV1

In [2]:
class SupernovaTraining(LightningModule):
    """Lightning module that trains a classifier with cross-entropy loss.

    Args:
        model: Network invoked as ``model(metadata, sequences, lengths)``;
            its output is passed to ``nn.CrossEntropyLoss`` as logits.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, model, learning_rate=1e-3):
        super().__init__()
        # ``model`` is excluded from the saved hyperparameters, so only
        # ``learning_rate`` is recorded in ``self.hparams``.
        self.save_hyperparameters(ignore=["model"])

        self.model = model
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, metadata, sequences, lengths):
        """Delegate to the wrapped model and return its output unchanged."""
        return self.model(metadata, sequences, lengths)

    def _shared_step(self, batch, stage):
        """Compute loss and accuracy for one batch and log both.

        Args:
            batch: Mapping with the keys ``"metadata"``, ``"sequences"``,
                ``"lengths"`` and ``"labels"``.
            stage: Prefix for the logged metric names (``"train"`` or
                ``"val"``), producing ``"<stage>_loss"`` and ``"<stage>_acc"``.

        Returns:
            The cross-entropy loss for the batch.
        """
        metadata = batch["metadata"]
        sequences = batch["sequences"]
        lengths = batch["lengths"]
        labels = batch["labels"]

        logits = self(metadata, sequences, lengths)
        loss = self.criterion(logits, labels)
        # Accuracy: fraction of samples whose arg-max over dim 1 equals the
        # label (assumes labels are class indices -- TODO confirm).
        acc = (logits.argmax(dim=1) == labels).float().mean()

        self.log(f"{stage}_loss", loss, prog_bar=True)
        self.log(f"{stage}_acc", acc, prog_bar=True)

        return loss

    def training_step(self, batch, batch_idx):
        """Run one training batch; logs ``train_loss`` and ``train_acc``."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Run one validation batch; logs ``val_loss`` and ``val_acc``."""
        return self._shared_step(batch, "val")

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of this module."""
        return Adam(self.parameters(), lr=self.hparams.learning_rate)

In [3]:
# Load the preprocessed training set from disk.
dataset = SupernovaDataset("../data/processed/training_set.pkl")

# 80/20 train/validation split; validation receives the remainder so the two
# sizes always sum to len(dataset).
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Seed the split so a fresh kernel run reproduces the same partition; without
# a generator the split follows the global RNG state and validation metrics
# are not comparable between runs.
SPLIT_SEED = 42
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Create dataloaders
# Options shared by both loaders; only the shuffling differs.
loader_kwargs = dict(
    batch_size=32,
    collate_fn=supernova_collate_fn,
    num_workers=4,
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [4]:
# Hyperparameters of the classifier, grouped by their name prefix.
model_hparams = {
    # metadata_* settings
    "metadata_input_size": 10,
    "metadata_num_hidden_layers": 2,
    "metadata_hidden_size": 32,
    "metadata_output_size": 16,
    # lightcurve_* settings
    "lightcurve_input_size": 4,
    "lightcurve_num_hidden_layers": 2,
    "lightcurve_hidden_size": 32,
    # classifier_* settings and output size
    "classifier_hidden_size": 64,
    "classifier_num_hidden_layers": 2,
    "num_classes": 14,
    "dropout": 0.2,
}

config = SupernovaClassifierV1Config(**model_hparams)
model = SupernovaClassifierV1(config)

In [5]:
# Wrap the network in the Lightning training module.
training = SupernovaTraining(model, learning_rate=1e-3)

# Keep only the single best checkpoint, ranked by the highest "val_acc".
checkpoint_callback = ModelCheckpoint(
    dirpath="../models/checkpoints",
    filename="supernova-{epoch:02d}-{val_acc:.2f}",
    monitor="val_acc",
    mode="max",
    save_top_k=1,
)

# Early stopping watches "val_loss" (lower is better).
# NOTE(review): patience=10 is larger than max_epochs=4 below, so with one
# validation pass per epoch (presumably the default here) the patience limit
# cannot be reached -- confirm which of the two values is intended.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=10,
)

# Trainer
trainer_options = dict(
    accelerator="auto",
    devices=1,
    max_epochs=4,
    log_every_n_steps=1,
    callbacks=[checkpoint_callback, early_stop_callback],
)
trainer = Trainer(**trainer_options)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
/home/mgarbowski/repos/fo-projekt/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [6]:
# Run the training loop, validating on the held-out split.
trainer.fit(
    model=training,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

/home/mgarbowski/repos/fo-projekt/.venv/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:881: Checkpoint directory /home/mgarbowski/repos/fo-projekt/models/checkpoints exists and is not empty.

  | Name      | Type                  | Params | Mode  | FLOPs
--------------------------------------------------------------------
0 | model     | SupernovaClassifierV1 | 105 K  | train | 0    
1 | criterion | CrossEntropyLoss      | 0      | train | 0    
--------------------------------------------------------------------
105 K     Trainable params
0         Non-trainable params
105 K     Total params
0.422     Total estimated model params size (MB)
37        Modules in train mode
0         Modules in eval mode
0         Total Flops


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Trening kończy się przedwcześnie, ponieważ braki w danych powodują `loss=nan`.