In [38]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import pytorch_lightning as pl

In [39]:
transform = Compose(
    [
        ToTensor(),
    ]
)

In [40]:
# get the data
train = MNIST(
    root="./dara/",
    download=True,
    train=True,
    transform=transform
)
val = MNIST(
    root="./data/",
    download=True,
    train=False,
    transform=transform
)

In [90]:
class LinearModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def train_dataloader(self):
        dataset = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
        loader = DataLoader(dataset, batch_size=32, num_workers=4, shuffle=True)
        return loader
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def val_dataloader(self):
        # TODO: do a real train/val split
        dataset = MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
        loader = DataLoader(dataset, batch_size=32, num_workers=4)
        return loader
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        return {'test_loss': F.cross_entropy(y_hat, y)}

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        tensorboard_logs = {'test_loss': avg_loss}
        return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}

    def test_dataloader(self):
        # TODO: do a real train/val split
        dataset = MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
        loader = DataLoader(dataset, batch_size=32, num_workers=4)
        return loader
    

In [91]:
model = LinearModel()

In [99]:
trainer = pl.Trainer(max_epochs=10, num_processes=1)

GPU available: True, used: False
No environment variable for node rank defined. Set as 0.


In [100]:
# trainer = pl.Trainer

In [None]:
trainer.fit(model)


  | Name | Type   | Params
----------------------------
0 | l1   | Linear | 7 K   


Epoch 1:  86%|████████▌ | 1875/2188 [00:11<00:01, 157.57it/s, loss=0.989, v_num=11]
Validating: 0it [00:00, ?it/s][A
Epoch 1:  86%|████████▌ | 1879/2188 [00:12<00:01, 156.26it/s, loss=0.989, v_num=11]
Epoch 1:  88%|████████▊ | 1919/2188 [00:12<00:01, 158.26it/s, loss=0.989, v_num=11]
Epoch 1:  91%|█████████ | 1993/2188 [00:12<00:01, 163.00it/s, loss=0.989, v_num=11]
Epoch 1:  95%|█████████▍| 2074/2188 [00:12<00:00, 168.24it/s, loss=0.989, v_num=11]
Epoch 1: 100%|██████████| 2188/2188 [00:12<00:00, 174.86it/s, loss=0.989, v_num=11]
Epoch 2:   0%|          | 0/2188 [00:00<?, ?it/s, loss=0.989, v_num=11]            



Epoch 2:  86%|████████▌ | 1875/2188 [00:11<00:01, 164.92it/s, loss=0.945, v_num=11]
Validating: 0it [00:00, ?it/s][A
Validating:   1%|▏         | 4/313 [00:00<00:08, 35.04it/s][A
Epoch 2:  89%|████████▉ | 1944/2188 [00:11<00:01, 167.49it/s, loss=0.945, v_num=11]
Epoch 2:  93%|█████████▎| 2025/2188 [00:11<00:00, 172.95it/s, loss=0.945, v_num=11]
Epoch 2:  96%|█████████▋| 2106/2188 [00:11<00:00, 178.28it/s, loss=0.945, v_num=11]
Epoch 2: 100%|██████████| 2188/2188 [00:11<00:00, 183.28it/s, loss=0.945, v_num=11]
Epoch 3:  86%|████████▌ | 1875/2188 [00:11<00:01, 167.44it/s, loss=0.990, v_num=11]
Validating: 0it [00:00, ?it/s][A
Validating:   1%|▏         | 4/313 [00:00<00:08, 36.53it/s][A
Epoch 3:  89%|████████▉ | 1944/2188 [00:11<00:01, 170.15it/s, loss=0.990, v_num=11]
Epoch 3:  93%|█████████▎| 2025/2188 [00:11<00:00, 175.66it/s, loss=0.990, v_num=11]
Epoch 3:  96%|█████████▋| 2108/2188 [00:11<00:00, 181.28it/s, loss=0.990, v_num=11]
Epoch 3: 100%|██████████| 2188/2188 [00:11<00:00, 

In [69]:
trainer.test()

Testing:  79%|███████▊  | 246/313 [00:00<00:02, 25.08it/s]--------------------------------------------------------------------------------
TEST RESULTS
{'avg_test_loss': tensor(1.1330, device='cuda:0'),
 'test_loss': tensor(1.1330, device='cuda:0')}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 313/313 [00:00<00:00, 602.65it/s]


In [79]:
model.test_dataloader()

<torch.utils.data.dataloader.DataLoader at 0x7fb0a448cb50>