In [None]:
import torch
import pytorch_lightning as pl

from typing import Any

In [None]:
# localization network
class LocalizationNetwork(pl.LightningModule):
    """Small CNN that regresses one 2x3 matrix per 3-channel input image.

    Two conv + max-pool + ReLU stages feed a two-layer MLP whose six outputs
    are reshaped to ``(batch, 2, 3)``. The training, validation and test
    steps all minimise the MSE between that output and the target.

    NOTE(review): the name and the 2x3 output suggest these are affine
    parameters for a spatial transformer -- confirm against the caller.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Build the layers; all arguments are forwarded to the base class."""
        super().__init__(*args, **kwargs)
        self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=7)
        self.conv2 = torch.nn.Conv2d(8, 10, kernel_size=5)
        self.fc1 = torch.nn.Linear(10 * 3 * 3, 32)
        self.fc2 = torch.nn.Linear(32, 3 * 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of images to a ``(batch, 2, 3)`` tensor.

        Args:
            x: Image batch with 3 channels (required by ``conv1``).

        Returns:
            Tensor of shape ``(batch, 2, 3)``.
        """
        functional = torch.nn.functional
        xs = functional.relu(functional.max_pool2d(self.conv1(x), 2))
        xs = functional.relu(functional.max_pool2d(self.conv2(xs), 2))
        # fc1 expects exactly 10 * 3 * 3 features. The conv stack only yields
        # a 3x3 map for 28x28 inputs (28 -> 22 -> 11 -> 7 -> 3); larger images
        # such as 224x224 give 52x52 and used to break the flatten below.
        # Pooling to a fixed 3x3 grid is an identity for the 28x28 case and
        # makes every other input size fit the linear head.
        xs = functional.adaptive_avg_pool2d(xs, (3, 3))
        xs = xs.reshape(-1, 10 * 3 * 3)
        xs = functional.relu(self.fc1(xs))
        xs = self.fc2(xs)
        return xs.view(-1, 2, 3)

    def _mse_step(self, batch, metric_name: str) -> torch.Tensor:
        """Compute the MSE loss for ``batch`` and log it as ``metric_name``."""
        x, y = batch
        loss = torch.nn.functional.mse_loss(self(x), y)
        self.log(metric_name, loss)
        return loss

    # training step
    def training_step(self, batch, batch_idx):
        """Return the MSE loss for one training batch, logged as 'train_loss'."""
        return self._mse_step(batch, 'train_loss')

    # validation step
    def validation_step(self, batch, batch_idx):
        """Return the MSE loss for one validation batch, logged as 'val_loss'."""
        return self._mse_step(batch, 'val_loss')

    # test step
    def test_step(self, batch, batch_idx):
        """Return the MSE loss for one test batch, logged as 'test_loss'."""
        return self._mse_step(batch, 'test_loss')

    # configure optimizer
    def configure_optimizers(self):
        """Return AdamW (lr=1e-3) with a StepLR schedule (x0.1 every 7 steps).

        Returns:
            A ``([optimizer], [scheduler])`` pair of single-element lists.
        """
        # adamw optimizer
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        # learning rate scheduler
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
        return [optimizer], [scheduler]

In [None]:
# Instantiate the localization network with its default constructor arguments.
model = LocalizationNetwork()

In [None]:
# Three independent synthetic splits (train, validation, test): each holds
# 100 uniformly random 3x224x224 inputs paired with random 2x3 targets.
train_dataset, val_dataset, test_dataset = (
    torch.utils.data.TensorDataset(
        torch.rand(100, 3, 224, 224),
        torch.rand(100, 2, 3),
    )
    for _ in range(3)
)

# One loader per split, batch size 32; only the training loader shuffles.
train_dataloader, val_dataloader, test_dataloader = (
    torch.utils.data.DataLoader(split, batch_size=32, shuffle=reshuffle)
    for split, reshuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

In [None]:
# train model
# Ask for one GPU only when CUDA is actually usable; otherwise pass 0 so the
# same notebook still trains (on the CPU) on machines without a GPU.
trainer = pl.Trainer(
    gpus=1 if torch.cuda.is_available() else 0,
    max_epochs=10,
    progress_bar_refresh_rate=20,
)
trainer.fit(model, train_dataloader, val_dataloader)

In [None]:
# test model
# The loader is passed positionally: the keyword for this argument was renamed
# from `test_dataloaders` to `dataloaders` in later Lightning releases, and the
# positional form works with both spellings.
trainer.test(model, test_dataloader)