In [2]:
!pip install pytorch-lightning

Collecting pytorch-lightning
  Downloading pytorch_lightning-2.4.0-py3-none-any.whl.metadata (21 kB)
Collecting torchmetrics>=0.7.0 (from pytorch-lightning)
  Downloading torchmetrics-1.5.1-py3-none-any.whl.metadata (20 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch-lightning)
  Downloading lightning_utilities-0.11.8-py3-none-any.whl.metadata (5.2 kB)
Downloading pytorch_lightning-2.4.0-py3-none-any.whl (815 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m815.2/815.2 kB[0m [31m44.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.11.8-py3-none-any.whl (26 kB)
Downloading torchmetrics-1.5.1-py3-none-any.whl (890 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m890.6/890.6 kB[0m [31m55.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lightning-utilities, torchmetrics, pytorch-lightning
Successfully installed lightning-utilities-0.11.8 pytorch-lightning-2.4.0 torchmetrics-1.5.1


In [None]:
from google.colab import drive
drive.mount('/content/drive')
# Não tá usando ainda no código
checkpoint_dir = "/content/drive/MyDrive/Checkpoints"

# Certifica que existee
import os
os.makedirs(checkpoint_dir, exist_ok=True)

In [3]:
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Trocar aqui pelo nosso modelo
class SimpleNN(pl.LightningModule):
    def __init__(self, input_size, hidden_size, num_classes, learning_rate=1e-3):
        super(SimpleNN, self).__init__()
        self.learning_rate = learning_rate
        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes)
        )
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log("train_loss", loss) # salvando no arquivo de log q ele cria
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log("val_loss", loss) # salvando no arquivo de log q ele cria
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

# substituir pelos nosso dados reais
def load_data():
    X_train = torch.randn(1000, 10)  # 1000 amostras, 10 características
    y_train = torch.randint(0, 2, (1000,))  # Classes 0 e 1
    X_val = torch.randn(200, 10)
    y_val = torch.randint(0, 2, (200,))

    train_dataset = TensorDataset(X_train, y_train)
    val_dataset = TensorDataset(X_val, y_val)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    return train_loader, val_loader
# Salvar a cada n épocas
# Dara para usar mais de um ao mesmo tempo, adicionando ao callbacks de pl.Trainer
#checkpoint_callback = ModelCheckpoint(
#    every_n_epochs=5,  # Salva o modelo a cada 5 épocas
#    dirpath="checkpoints/",
#    filename="epoch-{epoch:02d}",
#)

# Configuração para checkpoints no PyTorch Lightning
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath="checkpoints/",
    filename="best-checkpoint",
    save_top_k=1, # aqui daria para slavar mais de um
    mode="min"
)

# Função principal para treinar o modelo
def train_model():
    train_loader, val_loader = load_data()
    model = SimpleNN(input_size=10, hidden_size=16, num_classes=2, learning_rate=1e-3)

    trainer = pl.Trainer(
        max_epochs=10,
        callbacks=[checkpoint_callback], # aqui da para adicionar mais callbacks para slavar coisas diferentes
        accelerator="gpu",
        devices=1 if torch.cuda.is_available() else None
    )

    trainer.fit(model, train_loader, val_loader)
    print("Treinamento finalizado!")
    print(f"Melhor modelo salvo em: {checkpoint_callback.best_model_path}")

train_model()

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | model     | Sequential       | 210    | train
1 | criterion | CrossEntropyLoss | 0      | train
-------------------------------------------------------
210       Trainable params
0         Non-trainable params
210       Total params
0.001     Total estimated model params size (MB)
5         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (32) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.


Treinamento finalizado!
Melhor modelo salvo em: /content/checkpoints/best-checkpoint.ckpt


In [None]:
best_model = SimpleNN.load_from_checkpoint("checkpoints/best-checkpoint.ckpt")