In [3]:

%load_ext tensorboard
#%reload_ext tensorboard


The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [4]:
import torch.optim.lr_scheduler as lr_scheduler
import pytorch_lightning as pl
import torch
from torch import nn
from torchmetrics import functional as FM

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [5]:
# Cross-entropy on raw (unnormalised) logits, averaged over the batch.
loss_ftn = nn.CrossEntropyLoss(reduction='mean')

In [14]:
class Model(pl.LightningModule):
    """Two-branch MLP classifier trained with cross-entropy and Adam.

    Each input is flattened to 28*28 = 784 features, fed through two
    independent ``Linear(784, hidden) + ReLU`` branches, and the two branch
    outputs are concatenated and mapped to ``num_classes`` logits.

    Args:
        lr: Adam learning rate.
        hidden: Width of each of the two hidden branches.
        num_classes: Number of output classes (logit dimension).
    """

    def __init__(self, lr: float = 1e-3, hidden: int = 32, num_classes: int = 10):
        super().__init__()
        # Record constructor args in self.hparams (logged and stored in checkpoints).
        self.save_hyperparameters()
        self.lr = lr
        self.num_classes = num_classes
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(28*28, hidden)
        self.linear2 = nn.Linear(28*28, hidden)
        self.linear3 = nn.Linear(hidden + hidden, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return unnormalised class logits of shape (batch, num_classes)."""
        x0 = self.flatten(x)
        # Two parallel branches over the same flattened input.
        x2_1 = self.relu(self.linear1(x0))
        x2_2 = self.relu(self.linear2(x0))
        # Concatenate along the feature dimension before the classifier head.
        x3 = torch.cat([x2_1, x2_2], dim=1)
        return self.linear3(x3)

    def _shared_step(self, batch, prefix: str):
        """Compute and log loss/accuracy for one (x, y) batch; return the loss.

        Metrics are logged under ``f'{prefix}loss'`` and ``f'{prefix}acc'``.
        """
        x, y = batch
        logits = self(x)
        # Same computation as a default nn.CrossEntropyLoss, without relying on
        # a notebook-global loss object.
        loss = nn.functional.cross_entropy(logits, y)
        acc = FM.accuracy(logits, y, task='multiclass', num_classes=self.num_classes)
        self.log_dict({f'{prefix}loss': loss, f'{prefix}acc': acc})
        return loss

    def training_step(self, batch, batch_idx):
        """Log 'loss'/'acc' and return the loss for backpropagation."""
        return self._shared_step(batch, '')

    def validation_step(self, batch, batch_idx):
        """Log 'val_loss'/'val_acc' for one validation batch."""
        self._shared_step(batch, 'val_')

    def test_step(self, batch, batch_idx):
        """Log 'test_loss'/'test_acc' for one test batch (used by trainer.test)."""
        self._shared_step(batch, 'test_')

    def configure_optimizers(self):
        """Adam over all model parameters at the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)
    



In [15]:
# Seed Python, NumPy and torch RNGs before building the model so that weight
# initialisation is reproducible across kernel restarts.
SEED = 42
pl.seed_everything(SEED)
model = Model()

In [16]:
# Event files are written under tb_logs/model_tb/ (viewed via the %tensorboard cell).
# NOTE: `data_module` is created in the MNISTDataModule cell further down the
# notebook; on a fresh kernel that cell must be run before this one.
logger = pl.loggers.TensorBoardLogger(save_dir='tb_logs', name='model_tb')

trainer = pl.Trainer(
    accelerator='auto',
    logger=logger,
    max_epochs=20,
)
trainer.fit(model, datamodule=data_module)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type    | Params | Mode 
--------------------------------------------
0 | flatten | Flatten | 0      | train
1 | linear1 | Linear  | 25.1 K | train
2 | linear2 | Linear  | 25.1 K | train
3 | linear3 | Linear  | 650    | train
4 | relu    | ReLU    | 0      | train
--------------------------------------------
50.9 K    Trainable params
0         Non-trainable params
50.9 K    Total params
0.204     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\msong\anaconda3\envs\py3_11_8\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
c:\Users\msong\anaconda3\envs\py3_11_8\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.


In [17]:
# Open the TensorBoard dashboard over the TensorBoardLogger output directory.
# If an instance with the same arguments is already running, it is reused.
%tensorboard --logdir ./tb_logs


Reusing TensorBoard on port 6006 (pid 18848), started 0:25:38 ago. (Use '!kill 18848' to kill it.)

In [9]:
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils import data
from torch.utils.data import DataLoader

class MNISTDataModule(pl.LightningDataModule):
    """LightningDataModule serving MNIST with a 55,000 / 5,000 train/val split.

    Args:
        data_dir: Root directory MNIST is downloaded to / read from
            ('' resolves relative to the current working directory).
        batch_size: Batch size used by every dataloader.
        num_workers: DataLoader worker processes; 0 (default) loads batches in
            the main process, as before.
        seed: Seed for the train/val ``random_split``. A dedicated generator
            keeps the split identical every time ``setup`` runs (it may run
            again for each fit/validate/test call) and across kernel restarts,
            so validation samples never drift into the training set.
    """

    def __init__(self, data_dir: str = '', batch_size: int = 32,
                 num_workers: int = 0, seed: int = 42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed

    def setup(self, stage=None):
        """Build the test set and the seeded train/val split of the training set.

        ``stage`` is accepted for the Lightning hook signature but unused:
        all three datasets are (re)built on every call.
        """
        # ToTensor converts images to float tensors scaled by 1/255;
        # Normalize then applies mean 0.1307 / std 0.3081.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        self.mnist_test = MNIST(self.data_dir, train=False, transform=transform, download=True)
        mnist_full = MNIST(self.data_dir, train=True, transform=transform, download=True)
        # Explicit generator => deterministic split, independent of global RNG state.
        generator = torch.Generator().manual_seed(self.seed)
        self.mnist_train, self.mnist_val = data.random_split(
            mnist_full, [55000, 5000], generator=generator)

    def train_dataloader(self):
        """Training loader; reshuffles the training split every epoch."""
        return DataLoader(self.mnist_train, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Validation loader in fixed order."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        """Test loader in fixed order."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size,
                          num_workers=self.num_workers)

# NOTE: the trainer.fit cell earlier in the notebook reads `data_module`;
# on a fresh kernel this cell has to be executed before that one.
data_module = MNISTDataModule(
    batch_size=256,  # shared by the train, val and test loaders
)