In [11]:
import os
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl

inputs = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float)
outputs = torch.tensor([[0], [1], [1], [0]], dtype=torch.float)

In [12]:
class Model(pl.LightningModule):
    def __init__(self, hidden_dim = 2):
        super().__init__()
        self.inner = torch.nn.Sequential(
            torch.nn.Linear(inputs.shape[1], hidden_dim),
            torch.nn.Tanh(),
            torch.nn.Linear(hidden_dim, 1),
            torch.nn.Tanh()
        )

    def forward(self, x):
        return self.inner.forward(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = F.mse_loss(logits, y)

        logs = {'train_loss': loss}
        return {'loss': loss, 'log': logs}

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.035)
        return optimizer

model = Model()
trainer = pl.Trainer(max_epochs=1000, enable_progress_bar=False, log_every_n_steps=1, accelerator='cpu')
trainer.fit(model, [inputs, outputs])

trainer.predict(model, inputs, return_predictions=True)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type       | Params
-------------------------------------
0 | inner | Sequential | 9     
-------------------------------------
9         Trainable params
0         Non-trainable params
9         Total params
0.000     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=1000` reached.


[tensor([0.0054]), tensor([0.9376]), tensor([0.9310]), tensor([0.0018])]

In [13]:
print(list(model.parameters()))

[Parameter containing:
tensor([[ 2.0390, -1.9987],
        [ 1.3765, -1.3589]], requires_grad=True), Parameter containing:
tensor([ 1.0219, -0.5786], requires_grad=True), Parameter containing:
tensor([[-1.6162,  1.7088]], requires_grad=True), Parameter containing:
tensor([2.1423], requires_grad=True)]
