
Commit

update pytorch docs
mlech26l committed Jul 28, 2023
1 parent 6b5ba81 commit 4687046
Showing 2 changed files with 68 additions and 72 deletions.
4 changes: 3 additions & 1 deletion docs/examples/torch_first_steps.rst
@@ -115,6 +115,9 @@ For the wiring we will use the ```AutoNCP`` class, which creates a NCP wiring di

.. code-block:: python

    out_features = 1
    in_features = 2
    wiring = AutoNCP(16, out_features)  # 16 units, 1 motor neuron
    ltc_model = LTC(in_features, wiring, batch_first=True)
@@ -123,7 +126,6 @@ For the wiring we will use the ```AutoNCP`` class, which creates a NCP wiring di
        logger=pl.loggers.CSVLogger("log"),
        max_epochs=400,
        gradient_clip_val=1,  # Clip gradient to stabilize training
        gpus=0,
    )
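A note on the dropped gpus=0 argument: PyTorch Lightning 2.x removed the gpus parameter of Trainer in favor of accelerator and devices. Below is a minimal sketch of explicit device selection under the newer API (assuming pytorch_lightning >= 2.0; illustrative only, not part of the documented example):

    # Sketch (assumes pytorch_lightning >= 2.0): explicit device selection
    # replacing the removed `gpus` argument.
    trainer = pl.Trainer(
        logger=pl.loggers.CSVLogger("log"),
        max_epochs=400,
        gradient_clip_val=1,  # Clip gradient to stabilize training
        accelerator="cpu",  # or accelerator="gpu", devices=1
    )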
Draw the wiring diagram of the network
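For reference, a minimal sketch of the wiring-diagram drawing code referred to above, assuming the draw_graph helper of ncps wirings and reusing the wiring object defined in the snippet above (styling choices are illustrative, not taken from this commit):

    # Sketch (assumes ncps' Wiring.draw_graph): render the AutoNCP wiring.
    import matplotlib.pyplot as plt
    import seaborn as sns

    sns.set_style("white")
    plt.figure(figsize=(6, 4))
    legend_handles = wiring.draw_graph(draw_labels=True, neuron_colors={"command": "tab:cyan"})
    plt.legend(handles=legend_handles, loc="upper center", bbox_to_anchor=(1, 1))
    sns.despine(left=True, bottom=True)
    plt.tight_layout()
    plt.show()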
136 changes: 65 additions & 71 deletions examples/pt_example.py
@@ -1,37 +1,42 @@
# Copyright (2017-2021)
# The Wormnet project
# Mathias Lechner (mlechner@ist.ac.at)
import numpy as np
import torch.nn as nn
import kerasncp as kncp
from kerasncp.torch import LTCCell
from ncps.wirings import AutoNCP
from ncps.torch import LTC
import pytorch_lightning as pl
import torch
import torch.utils.data as data

# nn.Module that unfolds an RNN cell into a sequence
class RNNSequence(nn.Module):
    def __init__(
        self,
        rnn_cell,
    ):
        super(RNNSequence, self).__init__()
        self.rnn_cell = rnn_cell

    def forward(self, x):
        device = x.device
        batch_size = x.size(0)
        seq_len = x.size(1)
        hidden_state = torch.zeros(
            (batch_size, self.rnn_cell.state_size), device=device
        )
        outputs = []
        for t in range(seq_len):
            inputs = x[:, t]
            new_output, hidden_state = self.rnn_cell.forward(inputs, hidden_state)
            outputs.append(new_output)
        outputs = torch.stack(outputs, dim=1)  # return entire sequence
        return outputs
import matplotlib.pyplot as plt
import seaborn as sns

N = 48 # Length of the time-series
out_features = 1
in_features = 2
# Input feature is a sine and a cosine wave
data_x = np.stack(
    [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], axis=1
)
data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension
# Target output is a sine with double the frequency of the input signal
data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32)
print("data_x.shape: ", str(data_x.shape))
print("data_y.shape: ", str(data_y.shape))
data_x = torch.Tensor(data_x)
data_y = torch.Tensor(data_y)
dataloader = data.DataLoader(
    data.TensorDataset(data_x, data_y), batch_size=1, shuffle=True, num_workers=4
)

# Let's visualize the training data
sns.set()
plt.figure(figsize=(6, 4))
plt.plot(data_x[0, :, 0], label="Input feature 1")
plt.plot(data_x[0, :, 1], label="Input feature 2")
plt.plot(data_y[0, :, 0], label="Target output")
plt.ylim((-1, 1))
plt.title("Training data")
plt.legend(loc="upper right")
plt.savefig("pt_plot1.png")


# LightningModule for training a RNNSequence module
@@ -43,15 +48,15 @@ def __init__(self, model, lr=0.005):

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model.forward(x)
        y_hat, _ = self.model.forward(x)
        y_hat = y_hat.view_as(y)
        loss = nn.MSELoss()(y_hat, y)
        self.log("train_loss", loss, prog_bar=True)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model.forward(x)
        y_hat, _ = self.model.forward(x)
        y_hat = y_hat.view_as(y)
        loss = nn.MSELoss()(y_hat, y)

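The switch from y_hat = self.model.forward(x) to y_hat, _ = self.model.forward(x) is needed because the full ncps.torch.LTC module returns a pair (output sequence, final hidden state), whereas the old RNNSequence wrapper returned only the outputs. A minimal sketch of that return contract (shapes assumed from the example data; the printed shape is an expectation, not verified here):

    # Sketch: ncps.torch.LTC returns (outputs, final_state), hence the tuple unpacking.
    import torch
    from ncps.torch import LTC
    from ncps.wirings import AutoNCP

    model = LTC(2, AutoNCP(16, 1), batch_first=True)
    x = torch.randn(1, 48, 2)  # (batch, seq_len, in_features), as in the example
    outputs, hx = model(x)  # outputs: per-step predictions, hx: final hidden state
    print(outputs.shape)  # expected: torch.Size([1, 48, 1])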
@@ -65,52 +70,41 @@ def test_step(self, batch, batch_idx):
    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def optimizer_step(
        self,
        current_epoch,
        batch_nb,
        optimizer,
        optimizer_idx,
        closure,
        on_tpu=False,
        using_native_amp=False,
        using_lbfgs=False,
    ):
        optimizer.optimizer.step(closure=closure)
        # Apply weight constraints
        self.model.rnn_cell.apply_weight_constraints()

wiring = AutoNCP(16, out_features) # 16 units, 1 motor neuron

in_features = 2
out_features = 1
N = 48 # Length of the time-series
# Input feature is a sine and a cosine wave
data_x = np.stack(
    [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], axis=1
)
data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension
# Target output is a sine with double the frequency of the input signal
data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32)
data_x = torch.Tensor(data_x)
data_y = torch.Tensor(data_y)
print("data_y.shape: ", str(data_y.shape))

wiring = kncp.wirings.FullyConnected(8, out_features)  # 8 units, 1 motor neuron
ltc_cell = LTCCell(wiring, in_features)
dataloader = data.DataLoader(
    data.TensorDataset(data_x, data_y), batch_size=1, shuffle=True, num_workers=4
)

ltc_sequence = RNNSequence(
    ltc_cell,
)
learn = SequenceLearner(ltc_sequence, lr=0.01)
ltc_model = LTC(in_features, wiring, batch_first=True)
learn = SequenceLearner(ltc_model, lr=0.01)
trainer = pl.Trainer(
    logger=pl.loggers.CSVLogger("log"),
    max_epochs=400,
    progress_bar_refresh_rate=1,
    gradient_clip_val=1,  # Clip gradient to stabilize training
    gpus=1,
)


# Train the model for 400 epochs (= training steps)
trainer.fit(learn, dataloader)
results = trainer.test(learn, dataloader)

# Let's visualize how LTC initially performs before the training
sns.set()
with torch.no_grad():
    prediction = ltc_model(data_x)[0].numpy()
plt.figure(figsize=(6, 4))
plt.plot(data_y[0, :, 0], label="Target output")
plt.plot(prediction[0, :, 0], label="NCP output")
plt.ylim((-1, 1))
plt.title("Before training")
plt.legend(loc="upper right")
plt.savefig("pt_plot2.png")

# How does the trained model now fit to the sinusoidal function?
sns.set()
with torch.no_grad():
    prediction = ltc_model(data_x)[0].numpy()
plt.figure(figsize=(6, 4))
plt.plot(data_y[0, :, 0], label="Target output")
plt.plot(prediction[0, :, 0], label="NCP output")
plt.ylim((-1, 1))
plt.title("After training")
plt.legend(loc="upper right")
plt.savefig("pt_plot3.png")
