[Episode 1](https://www.youtube.com/watch?v=OMDn66kM9Qc&list=PLaMu-SDt_RB5NUm67hU2pdE75j6KaIOv2&index=1) 

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader

# PyTorch Lightning

1. model
2. optimizer
3. data
4. training loop "the magic"
5. validation loop "the validation magic"

In [None]:
import pytorch_lightning as pl
import torchmetrics

In [None]:
class MNISTDataModule(pl.LightningDataModule):
    """Lightning data module that serves MNIST as train / validation / test loaders.

    Args:
        batch_size: Number of samples per batch, used by all three loaders.
        data_dir: Directory the MNIST files are downloaded to and read from.
    """

    def __init__(self, batch_size=32, data_dir='data'):
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir

    def prepare_data(self):
        """Download both MNIST splits into ``data_dir``.

        Only the download side effect matters here; the dataset objects are
        discarded, so no transform is attached.
        """
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the train / validation / test datasets from the downloaded files.

        Args:
            stage: Stage name passed in by the Trainer. It is not used; every
                dataset is built regardless of its value. It defaults to
                ``None`` so ``setup()`` can also be called by hand.
        """
        to_tensor = transforms.ToTensor()
        full_train = datasets.MNIST(
            self.data_dir, train=True, download=False, transform=to_tensor
        )
        self.test_dataset = datasets.MNIST(
            self.data_dir, train=False, download=False, transform=to_tensor
        )
        # Hold out 5,000 of the training images for validation.
        self.train_dataset, self.val_dataset = random_split(full_train, [55000, 5000])

    def train_dataloader(self):
        """Return the training loader; batches are reshuffled every epoch."""
        # shuffle=True: without it SGD sees the samples in the same order in
        # every epoch (DataLoader defaults to shuffle=False).
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

In [None]:
class ResNet(pl.LightningModule):
    """Small MLP classifier with one residual (skip) connection.

    The network maps inputs whose last dimension is 28 * 28 to 10 logits:
    ``l1`` (784 -> 64), ``l2`` (64 -> 64) whose output is added to the output
    of ``l1``, dropout, then ``l3`` (64 -> 10). Training and validation use
    cross-entropy loss and SGD with a learning rate of 1e-2.
    """

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28 * 28, 64)
        self.l2 = nn.Linear(64, 64)
        self.l3 = nn.Linear(64, 10)
        self.do = nn.Dropout(0.1)

        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the logits for ``x``, whose last dimension must be 28 * 28."""
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        # Residual connection: add the first hidden activation back in.
        do = self.do(h2 + h1)
        return self.l3(do)

    def configure_optimizers(self):
        """Return a plain SGD optimizer over all parameters."""
        return optim.SGD(self.parameters(), lr=1e-2)

    def _shared_step(self, batch):
        """Flatten the inputs, run the model and return ``(loss, accuracy)``."""
        x, y = batch
        # Collapse everything but the batch dimension so each sample is a vector.
        x = x.view(x.size(0), -1)
        logits = self(x)
        J = self.loss(logits, y)  # J: loss value
        acc = torchmetrics.functional.accuracy(logits, y)
        return J, acc

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and log the batch accuracy."""
        J, acc = self._shared_step(batch)
        self.log('train_acc', acc, on_step=True, on_epoch=False, prog_bar=True, sync_dist=True)
        return {'loss': J}

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss for one batch and log epoch-level accuracy."""
        J, acc = self._shared_step(batch)
        self.log('val_acc', acc, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        return {'loss': J}

    def validation_epoch_end(self, val_step_outputs):
        """Log the mean validation loss over all validation steps as ``val_loss``.

        Args:
            val_step_outputs: List of the dicts returned by ``validation_step``.
        """
        if not val_step_outputs:
            return
        # torch.stack keeps the values as tensors on their device. Taking
        # .mean() per step first also copes with a step loss that holds more
        # than one element (e.g. one value per GPU), which
        # torch.tensor([...]) cannot convert.
        step_losses = [out['loss'].mean() for out in val_step_outputs]
        avg_val_loss = torch.stack(step_losses).mean()
        # Report through self.log rather than a returned dict.
        # NOTE(review): Lightning 1.x is expected to ignore (and warn about)
        # a return value from this hook - confirm for the pinned version.
        self.log('val_loss', avg_val_loss, prog_bar=True, sync_dist=True)


model = ResNet()

In [None]:
# Data module that supplies the MNIST train / validation / test loaders.
mnist_dm = MNISTDataModule()

# Train for 5 epochs on 2 GPUs of a single node with the "dp" accelerator.
# A refresh rate of 20 makes the progress bar update less often.
trainer = pl.Trainer(
    accelerator="dp",
    num_nodes=1,
    gpus=2,
    max_epochs=5,
    progress_bar_refresh_rate=20,
)
outputs = trainer.fit(model, mnist_dm)

In [None]:
! ls lightning_logs/version_2/checkpoints

In [None]:
class ImageClassifier(nn.Module):
    """Wrapper module that owns a ``ResNet`` instance as its ``resnet`` submodule."""

    def __init__(self):
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it, setting self.resnet raises AttributeError
        # ("cannot assign module before Module.__init__() call").
        super().__init__()
        self.resnet = ResNet()

In [None]:
# Load the MNIST training split from ./data (download=False: the files must
# already be there) and convert each image to a tensor.
train_data = datasets.MNIST('data', train=True, download=False, transform=transforms.ToTensor())
# Hold out 5,000 samples for validation; the other 55,000 are used for training.
train, val = random_split(train_data, [55000, 5000])
# Reshuffle the training samples every epoch so SGD does not see the same
# batch order each time (DataLoader defaults to shuffle=False).
train_loader = DataLoader(train, batch_size=32, shuffle=True)
# Validation order does not affect the metrics, so it is left fixed.
val_loader = DataLoader(val, batch_size=32)

In [None]:
# Plain feed-forward baseline: 784 -> 64 -> 64 -> 10, with dropout applied
# just before the output layer. Layers are created in network order.
_mlp_layers = (
    nn.Linear(28 * 28, 64),
    nn.ReLU(),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(64, 10),
)
model = nn.Sequential(*_mlp_layers)

# Move the parameters to the second GPU (in place for modules).
model.to('cuda:1')

In [None]:
class ResNet(nn.Module):
    """Small MLP classifier with one residual (skip) connection.

    Args:
        in_features: Size of the last dimension of the input (default 28 * 28).
        hidden_features: Width of both hidden layers (default 64).
        num_classes: Number of output logits (default 10).
        dropout: Dropout probability applied before the output layer
            (default 0.1).
    """

    def __init__(self, in_features=28 * 28, hidden_features=64, num_classes=10, dropout=0.1):
        super().__init__()
        self.l1 = nn.Linear(in_features, hidden_features)
        self.l2 = nn.Linear(hidden_features, hidden_features)
        self.l3 = nn.Linear(hidden_features, num_classes)
        self.do = nn.Dropout(dropout)

    def forward(self, x):
        """Return the logits for ``x``, whose last dimension must be ``in_features``."""
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        # Residual connection: add the first hidden activation back in
        # before dropout and the output layer.
        do = self.do(h2 + h1)
        return self.l3(do)

# Build the residual MLP, then move its parameters to the second GPU.
model = ResNet()
model = model.to('cuda:1')

In [None]:
# model.parameters() returns a one-shot generator. Materialise it so that
# `params` is still usable after the optimiser has consumed it.
params = list(model.parameters())
# Plain SGD with a learning rate of 1e-2.
optimiser = optim.SGD(params, lr=1e-2)

In [None]:
# Cross-entropy over raw logits and integer class labels, averaged over the batch.
loss = nn.CrossEntropyLoss(reduction='mean')

In [None]:
def _run_epoch(net, data_loader, criterion, device, optimizer=None):
    """Run one pass over ``data_loader`` and return ``(mean_loss, accuracy)``.

    Args:
        net: Model to run; it is put in train mode when ``optimizer`` is
            given and in eval mode otherwise.
        data_loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss callable taking ``(logits, labels)``.
        device: Device the batches are moved to before the forward pass.
        optimizer: When given, gradients are computed and a step is taken
            for every batch. When ``None``, the pass is evaluation only and
            runs without gradient tracking.

    Returns:
        Tuple of per-sample mean loss and fraction of correct predictions.
        Both are ``nan`` when the loader yields no samples.
    """
    training = optimizer is not None
    net.train(training)

    loss_sum, correct, seen = 0.0, 0, 0
    for x, y in data_loader:
        # Flatten each sample to a vector and move the batch next to the model.
        x = x.view(x.size(0), -1).to(device)
        y = y.to(device)

        with torch.set_grad_enabled(training):
            logits = net(x)
            J = criterion(logits, y)  # J: loss value

        if training:
            optimizer.zero_grad()
            J.backward()
            optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the
        # epoch averages.
        batch_size = y.size(0)
        loss_sum += J.item() * batch_size
        correct += (logits.argmax(dim=1) == y).sum().item()
        seen += batch_size

    if seen == 0:
        return float('nan'), float('nan')
    return loss_sum / seen, correct / seen


nb_epochs = 5
# Follow whichever device the model lives on instead of hard-coding one.
device = next(model.parameters()).device
for epoch in range(nb_epochs):
    train_loss, train_acc = _run_epoch(model, train_loader, loss, device, optimiser)
    print(f'Epoch {epoch + 1}, training loss: {train_loss:.2f}, training accuracy: {train_acc:.2f}')

    val_loss, val_acc = _run_epoch(model, val_loader, loss, device)
    print(f'Epoch {epoch + 1}, validation loss: {val_loss:.2f}, validation accuracy: {val_acc:.2f}')

In [None]:
import matplotlib.pyplot as plt