# 🔥 Building Rockpool modules in Torch

## Convert an existing Torch ``torch.nn.Module`` for use in Rockpool

In [112]:
# - Torch imports
import torch
import torch.nn as nn
import torch.nn.functional as F

# - Rockpool imports
from rockpool.nn.modules import TorchModule

# - Rich printing
try:
    # - Prefer rich's pretty printer when the package is installed
    from rich import print
except ImportError:
    # - rich is optional: fall back to the built-in print if it is missing.
    #   Only import failures are ignored, so genuine errors raised while
    #   importing rich (or a KeyboardInterrupt) are no longer swallowed.
    pass

# - Implement a Torch class
class TorchNet(torch.nn.Module):
    """Small convolutional classifier written against the plain Torch API.

    Pipeline: 3x3 conv (1 -> 2 channels, stride 1) -> ReLU -> 2x2 max-pool
    -> Dropout2d(p=0.25) -> flatten -> Linear(338 -> 10) -> ReLU
    -> log-softmax over dimension 1.

    The 338 input features of the linear layer correspond to a 1x28x28
    input: 28 -> 26 after the unpadded 3x3 conv, -> 13 after 2x2 pooling,
    and 2 channels * 13 * 13 = 338. Other input sizes would need a
    different ``fc1`` width.
    """

    def __init__(self, *args, **kwargs):
        # - Any constructor arguments go straight to torch.nn.Module
        super().__init__(*args, **kwargs)

        # - Convolution: 1 input channel -> 2 output channels, 3x3 kernel, stride 1
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, stride=1)

        # - Channel-wise dropout with drop probability 0.25
        self.dropout1 = nn.Dropout2d(p=0.25)

        # - Readout from the flattened feature map to 10 outputs
        self.fc1 = nn.Linear(in_features=338, out_features=10)

        # - Example non-trainable buffer (3x4 zeros); registered for
        #   demonstration only, forward() never reads it
        self.register_buffer("test_buf", torch.zeros((3, 4)))

    def forward(self, x):
        """Run ``x`` through the network and return log-softmax scores over dim 1."""
        # - Convolutional stage: conv -> ReLU -> 2x2 max-pool -> dropout
        features = F.max_pool2d(F.relu(self.conv1(x)), 2)
        features = self.dropout1(features)

        # - Readout stage. NOTE: ReLU is applied to the logits ahead of
        #   log_softmax, so negative logits are clamped to zero first.
        logits = F.relu(self.fc1(torch.flatten(features, start_dim=1)))
        return F.log_softmax(logits, dim=1)

In [113]:
# - Instantiate the network and exercise it through the plain Torch API

# - Input batch: a single random 28x28 image with one channel,
#   laid out as (batch, channel, height, width)
random_data = torch.rand(1, 1, 28, 28)

# - Build the Torch module and run one forward pass with the Torch API
mod = TorchNet()
result = mod(random_data)

In [114]:
# - Convert object to Rockpool API, in-place
# - The return value of from_torch() is not used: ``mod`` itself is modified,
#   and the following cells call it through the Rockpool API
TorchModule.from_torch(mod)
# - Show the converted module
print(mod)

In [115]:
# - Use the Rockpool API to evolve the module
# - After conversion, calling the module returns a 3-tuple; only the first
#   element (the network output) is kept. The other two are presumably the
#   updated state and a record dictionary — confirm against the Rockpool docs
output, _, _ = mod(random_data)
print(output)

In [116]:
# - Use the Rockpool API to access parameters
# - Labels carry no trailing space: print() already inserts one between its
#   arguments, and passing the value as its own argument lets rich
#   pretty-print it
print('Parameters:', mod.parameters())
print('State:', mod.state())

## Write a native Rockpool/Torch module using ``TorchModule``

In [121]:
# - Implement a Rockpool class using the TorchModule base class
class RockpoolNet(TorchModule):
    """Small convolutional classifier built directly on Rockpool's ``TorchModule``.

    Layers are declared exactly as in a plain ``torch.nn.Module``. The
    ``TorchModule`` base class is what provides the Rockpool-style API on
    instances, e.g. ``parameters()`` / ``state()`` and a call that returns
    a tuple.

    Architecture: 3x3 conv (1 -> 2 channels, stride 1) -> ReLU -> 2x2
    max-pool -> Dropout2d(p=0.25) -> flatten -> Linear(338 -> 10) -> ReLU
    -> log-softmax over dimension 1. The 338 flattened features correspond
    to a 1x28x28 input (28 -> 26 after the conv, -> 13 after pooling;
    2 * 13 * 13 = 338).
    """

    def __init__(self, *args, **kwargs):
        # - Forward any constructor arguments to TorchModule unchanged
        super().__init__(*args, **kwargs)

        # - Feature extractor: 3x3 convolution, then channel-wise dropout
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout2d(p=0.25)

        # - Readout from the flattened feature map to 10 outputs
        self.fc1 = nn.Linear(in_features=338, out_features=10)

        # - Demonstration buffer (3x4 zeros); forward() never reads it
        self.register_buffer("test_buf", torch.zeros((3, 4)))

    def forward(self, x):
        """Return the log-softmax (over dim 1) of the network's readout for ``x``."""
        # - Convolutional stage
        y = self.conv1(x)
        y = F.max_pool2d(F.relu(y), 2)
        y = self.dropout1(y)

        # - Readout stage; note that ReLU is applied before log_softmax
        y = self.fc1(torch.flatten(y, start_dim=1))
        return F.log_softmax(F.relu(y), dim=1)

In [118]:
# - Instantiate the Rockpool class directly
# - No from_torch() conversion is needed: RockpoolNet already derives from
#   TorchModule, so the instance is used with the Rockpool API as-is
rmod = RockpoolNet()
print(rmod)

In [119]:
# - Evaluate the module using the Rockpool API
# - The call returns a 3-tuple; only the first element (the network output)
#   is kept. The remaining two are presumably the updated state and a record
#   dictionary — confirm against the Rockpool docs
output, _, _ = rmod(random_data)
print(output)

In [122]:
# - Access parameters using the Rockpool API
# - Labels carry no trailing space: print() already inserts one between its
#   arguments, and passing the value as its own argument lets rich
#   pretty-print it
print('Parameters:', rmod.parameters())
print('State:', rmod.state())