In [2]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl

In [7]:
# Data location: override with the MNIST_DATA_DIR environment variable; the default
# is the original author's local Google Drive path.
DATA_DIR = os.environ.get("MNIST_DATA_DIR", "/Volumes/GoogleDrive/내 드라이브/data/")
BATCH_SIZE = 100

# `mnist` is the training split (MNIST defaults to train=True).
mnist = MNIST(DATA_DIR, download=True, transform=transforms.ToTensor())
# Separate held-out split so evaluation does not reuse the training images.
mnist_test = MNIST(DATA_DIR, train=False, download=True, transform=transforms.ToTensor())

train_loader = DataLoader(mnist, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=BATCH_SIZE, shuffle=False)

In [128]:
class DRCN(pl.LightningModule):
    """Convolutional autoencoder with an auxiliary 10-way classification head.

    The encoder maps a single-channel image batch (``Conv2d(1, ...)``) to a
    100-d code. The decoder mirrors the encoder and reuses the max-pooling
    indices recorded on the way down to unpool on the way up.

    The layer sizes (``Linear(2400, ...)``, ``Unflatten(1, (150, 4, 4))`` and
    the hard-coded unpool output sizes) fix the input to 28x28. Spatial sizes
    through the encoder:

        enc1: conv5/pad2 -> 28x28, pool -> 14x14
        enc2: conv5/pad2 -> 14x14, pool -> 7x7
        enc3: conv3/pad2 -> 9x9,   pool -> 4x4   (150 * 4 * 4 = 2400)

    The ``label`` head (100 -> 10 logits) is defined here but is not trained by
    ``training_step``, which optimizes reconstruction MSE only.
    TODO(review): add a classification loss if the head is meant to be trained.
    """

    def __init__(self):
        super().__init__()
        # Encoder. Each conv stage ends in a MaxPool2d that also returns its
        # indices, so the nn.Sequential outputs a (features, indices) tuple.
        self.enc1 = nn.Sequential(
            nn.Conv2d(1, 50, 5, padding=2),
            nn.MaxPool2d(2, return_indices=True)
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(50, 100, 5, padding=2),
            nn.MaxPool2d(2, return_indices=True)
        )
        self.enc3 = nn.Sequential(
            nn.Conv2d(100, 150, 3, padding=2),
            nn.MaxPool2d(2, return_indices=True)
        )
        # Flatten the (150, 4, 4) feature map into a 100-d code.
        self.enc4 = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(),
            nn.Linear(2400, 100)
        )
        # Classification head on the code (10 logits).
        self.label = nn.Sequential(
            nn.Dropout(),
            nn.Linear(100, 10)
        )
        # Decoder: mirrors the encoder, unpooling with the encoder's indices.
        self.dec1 = nn.Sequential(
            nn.Linear(100, 2400),
            nn.Unflatten(1, (150, 4, 4))
        )
        self.unpool1 = nn.MaxUnpool2d(2, stride=2)
        self.dec2 = nn.ConvTranspose2d(150, 100, 3, padding=2)
        self.unpool2 = nn.MaxUnpool2d(2, stride=2)
        self.dec3 = nn.ConvTranspose2d(100, 50, 5, padding=2)
        self.unpool3 = nn.MaxUnpool2d(2, stride=2)
        self.recon = nn.ConvTranspose2d(50, 1, 5, padding=2)

    def _encode(self, x):
        """Run the encoder on ``x``.

        Returns:
            ``(code, indices)``. ``code`` is the output of ``enc4``, and
            ``indices`` is the tuple of max-pool indices from enc1, enc2 and
            enc3, in that order.
        """
        enc1, indices1 = self.enc1(x)
        enc2, indices2 = self.enc2(enc1)
        enc3, indices3 = self.enc3(enc2)
        code = self.enc4(enc3)
        return code, (indices1, indices2, indices3)

    def _decode(self, code, indices):
        """Reconstruct an image batch from ``code`` using the encoder's pool indices."""
        indices1, indices2, indices3 = indices
        dec1 = self.dec1(code)                                          # (N, 150, 4, 4)
        # Unpool back to the pre-pool size of each encoder stage (9, 14, 28).
        unpool1 = self.unpool1(dec1, indices3, output_size=(9, 9))
        dec2 = self.dec2(unpool1)                                       # (N, 100, 7, 7)
        unpool2 = self.unpool2(dec2, indices2, output_size=(14, 14))
        dec3 = self.dec3(unpool2)                                       # (N, 50, 14, 14)
        unpool3 = self.unpool3(dec3, indices1, output_size=(28, 28))
        return self.recon(unpool3)                                      # (N, 1, 28, 28)

    def forward(self, x):
        """Inference path: return the reconstruction of ``x``."""
        code, indices = self._encode(x)
        return self._decode(code, indices)

    def training_step(self, batch, batch_idx):
        """One training step: MSE between the reconstruction and the input.

        Labels in ``batch`` are ignored; only the reconstruction is optimized.
        """
        x, _ = batch
        recon = self(x)
        loss = F.mse_loss(recon, x)
        # Sent to whichever logger the Trainer is configured with (if any).
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters, lr=1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [129]:
# Seed every RNG Lightning knows about, so that weight init, dropout and the
# shuffled loader order are reproducible across Restart & Run All.
SEED = 42
MAX_EPOCHS = 10
pl.seed_everything(SEED)

# init model
drcn = DRCN()
trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    logger=False  # disable Lightning's default experiment logger
)
trainer.fit(model=drcn, train_dataloaders=train_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

   | Name    | Type            | Params
---------------------------------------------
0  | enc1    | Sequential      | 1.3 K 
1  | enc2    | Sequential      | 125 K 
2  | enc3    | Sequential      | 135 K 
3  | enc4    | Sequential      | 240 K 
4  | label   | Sequential      | 1.0 K 
5  | dec1    | Sequential      | 242 K 
6  | unpool1 | MaxUnpool2d     | 0     
7  | dec2    | ConvTranspose2d | 135 K 
8  | unpool2 | MaxUnpool2d     | 0     
9  | dec3    | ConvTranspose2d | 125 K 
10 | unpool3 | MaxUnpool2d     | 0     
11 | recon   | ConvTranspose2d | 1.3 K 
---------------------------------------------
1.0 M     Trainable params
0         Non-trainable params
1.0 M     Total params
4.026     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
