In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning import Trainer

from sklearn.datasets import load_iris

In [2]:
# GPU の設定状況に基づいたデバイスの選択
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [3]:
# Iris データセットの読み込み
x, t = load_iris(return_X_y=True)

x = torch.tensor(x, dtype=torch.float32)
t = torch.tensor(t, dtype=torch.int64)

In [4]:
x

tensor([[5.1000, 3.5000, 1.4000, 0.2000],
        [4.9000, 3.0000, 1.4000, 0.2000],
        [4.7000, 3.2000, 1.3000, 0.2000],
        [4.6000, 3.1000, 1.5000, 0.2000],
        [5.0000, 3.6000, 1.4000, 0.2000],
        [5.4000, 3.9000, 1.7000, 0.4000],
        [4.6000, 3.4000, 1.4000, 0.3000],
        [5.0000, 3.4000, 1.5000, 0.2000],
        [4.4000, 2.9000, 1.4000, 0.2000],
        [4.9000, 3.1000, 1.5000, 0.1000],
        [5.4000, 3.7000, 1.5000, 0.2000],
        [4.8000, 3.4000, 1.6000, 0.2000],
        [4.8000, 3.0000, 1.4000, 0.1000],
        [4.3000, 3.0000, 1.1000, 0.1000],
        [5.8000, 4.0000, 1.2000, 0.2000],
        [5.7000, 4.4000, 1.5000, 0.4000],
        [5.4000, 3.9000, 1.3000, 0.4000],
        [5.1000, 3.5000, 1.4000, 0.3000],
        [5.7000, 3.8000, 1.7000, 0.3000],
        [5.1000, 3.8000, 1.5000, 0.3000],
        [5.4000, 3.4000, 1.7000, 0.2000],
        [5.1000, 3.7000, 1.5000, 0.4000],
        [4.6000, 3.6000, 1.0000, 0.2000],
        [5.1000, 3.3000, 1.7000, 0

In [5]:
#DataSetに格納
dataset = torch.utils.data.TensorDataset(x, t)
dataset

<torch.utils.data.dataset.TensorDataset at 0x2160a219f10>

In [6]:
# datasetの分割
n_train = int(len(dataset) * 0.6)
n_val = int(len(dataset) * 0.2)
n_test = len(dataset) - n_train - n_val

n_train, n_val, n_test

(90, 30, 30)

In [7]:
# ランダムに分割する
torch.manual_seed(0)

train, val, test = torch.utils.data.random_split(dataset, [n_train, n_val, n_test])

In [8]:
len(train), len(val), len(test)

(90, 30, 30)

In [9]:
# 学習データ用クラス
class TrainNet(pl.LightningModule):
    
    @pl.data_loader
    def train_dataloader(self):
        return torch.utils.data.DataLoader(train, self.batch_size, shuffle=True)
    
    def training_step(self, batch, batch_nb):
        x, t = batch
        y = self.forward(x)
        loss = self.lossfun(y, t)
        y_label = torch.argmax(y, dim=1)
        acc = torch.sum(t == y_label) * 1.0 / len(t)        
        tensorboard_logs = {'train/train_loss': loss, 'train/train_acc': acc} # tensorboard
        results = {'loss': loss, 'log': tensorboard_logs}
        #results = {'loss': loss}
        return results

In [10]:
# 検証データ用クラス
class ValidationNet(pl.LightningModule):

    @pl.data_loader
    def val_dataloader(self):
        return torch.utils.data.DataLoader(val, self.batch_size)

    def validation_step(self, batch, batch_nb):
        x, t = batch
        y = self.forward(x)
        loss = self.lossfun(y, t)
        y_label = torch.argmax(y, dim=1)
        acc = torch.sum(t == y_label) * 1.0 / len(t)
        results = {'val_loss': loss, 'val_acc': acc}
        return results

    def validation_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['val_acc'] for x in outputs]).mean()
        tensorboard_logs = {'val/avg_loss': avg_loss, 'val/avg_acc': avg_acc}
        results = {'val_loss': avg_loss, 'val_acc': avg_acc, 'log': tensorboard_logs}        
        #results = {'val_loss': avg_loss, 'val_acc': avg_acc}
        return results

In [11]:
# テストデータ用クラス
class TestNet(pl.LightningModule):

    @pl.data_loader
    def test_dataloader(self):
        return torch.utils.data.DataLoader(test, self.batch_size)

    def test_step(self, batch, batch_nb):
        x, t = batch
        y = self.forward(x)
        loss = self.lossfun(y, t)
        y_label = torch.argmax(y, dim=1)
        acc = torch.sum(t == y_label) * 1.0 / len(t)
        results = {'test_loss': loss, 'test_acc': acc}
        return results

    def test_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['test_acc'] for x in outputs]).mean()
        results = {'test_loss': avg_loss, 'test_acc': avg_acc}
        return results

In [12]:
# 学習データ、検証データ、テストデータクラスの継承クラス
class Net(TrainNet, ValidationNet, TestNet):
    def __init__(self, input_size=4, hidden_size=4, output_size=3, batch_size=10):
        super(Net, self).__init__()
        self.L1 = nn.Linear(input_size, hidden_size)
        self.L2 = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        
    def forward(self, x):
        x = self.L1(x)
        x = F.relu(x)
        x = self.L2(x)
        return x
    
    def lossfun(self, y, t):
        return F.cross_entropy(y, t)
    
    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

In [13]:
net = Net()

#print(net.L1.weight, net.L1.bias) # 学習前のパラメータ

trainer = Trainer(max_epochs=30) # 学習用のインスタンス化と学習の
trainer.fit(net)


INFO:lightning:
  | Name | Type        | Params
---------------------------------
0 | L1   | Linear      | 20    
1 | L2   | Linear      | 15    
2 | bn   | BatchNorm1d | 8     


HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …



HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=3.0, style=Prog…




1

In [14]:
trainer.callback_metrics

{'loss': 0.4427512288093567,
 'train/train_loss': 0.4427512288093567,
 'train/train_acc': 0.699999988079071,
 'val_loss': 0.42039212584495544,
 'val_acc': 0.7333333492279053,
 'val/avg_loss': 0.42039212584495544,
 'val/avg_acc': 0.7333333492279053,
 'epoch': 29}

In [18]:
# 重みの更新内容
print(net.L1.weight, net.L1.bias)

Parameter containing:
tensor([[ 1.3241,  1.9703, -3.1035, -2.3658],
        [-0.2348,  0.1136, -0.0432,  0.3854],
        [-0.0444,  0.1323, -0.1511, -0.0983],
        [-0.4777, -0.3311, -0.2061,  0.0185]], requires_grad=True) Parameter containing:
tensor([ 1.0529,  0.2931, -0.3390, -0.2177], requires_grad=True)


In [15]:
trainer.test()

HBox(children=(FloatProgress(value=0.0, description='Testing', layout=Layout(flex='2'), max=3.0, style=Progres…

--------------------------------------------------------------------------------
TEST RESULTS
{'test_acc': 0.6333333253860474, 'test_loss': 0.49433454871177673}
--------------------------------------------------------------------------------



In [16]:
trainer.callback_metrics

{'loss': 0.4427512288093567,
 'train/train_loss': 0.4427512288093567,
 'train/train_acc': 0.699999988079071,
 'val_loss': 0.42039212584495544,
 'val_acc': 0.7333333492279053,
 'val/avg_loss': 0.42039212584495544,
 'val/avg_acc': 0.7333333492279053,
 'epoch': 29,
 'test_loss': 0.49433454871177673,
 'test_acc': 0.6333333253860474}

In [17]:
# パラメータの保存
net = net.to('cpu')
torch.save(net.state_dict(), 'lightning_param.pt')

In [None]:
# 保存したパラメータの読み込み
net = Net()
net.load_state_dict(torch.load('lightning_param.pt'))