#!/usr/bin/env python3
import torch.nn as nn
import cherry as ch
class RoboticsLinear(nn.Linear):
    """
    ## Description

    Akin to `nn.Linear`, but with proper initialization for robotic control.

    ## Credit

    Adapted from Ilya Kostrikov's implementation.

    ## Arguments

    * **gain** (float, *optional*, default=None) - Gain factor passed to
      `robotics_init_` initialization.
    * This class extends `nn.Linear` and supports all of its arguments.

    ## Example

    ~~~python
    linear = RoboticsLinear(23, 5, bias=True)
    action_mean = linear(state)  # `state` has 23 features in its last dimension
    ~~~
    """

    def __init__(self, *args, **kwargs):
        # `gain` is specific to this subclass: strip it from kwargs before
        # delegating, since the remaining arguments go to `nn.Linear` as-is.
        gain = kwargs.pop('gain', None)
        super().__init__(*args, **kwargs)
        # Re-initialize the freshly constructed layer in place, replacing the
        # default `nn.Linear` initialization.
        ch.nn.init.robotics_init_(self, gain=gain)