In [2]:
import torch
import torch.nn as nn

In [6]:
class MyDecisionGate(nn.Module):
    """Gate whose output sign depends on the sum of its input.

    The input is returned unchanged when its elements sum to a strictly
    positive value; in every other case the negated input is returned.
    """

    def forward(self, x):
        """Return ``x`` if ``x.sum() > 0``, otherwise ``-x``."""
        # Data-dependent control flow: the branch is chosen from tensor values.
        return x if x.sum() > 0 else -x

In [8]:
class MyCell(nn.Module):
    """Small module combining a linear layer, a decision gate and ``tanh``.

    Submodules:
        dg: ``MyDecisionGate`` applied to the output of ``linear``.
        linear: ``nn.Linear(4, 4)`` applied to the input ``x``.
    """

    def __init__(self):
        super().__init__()
        self.dg = MyDecisionGate()
        self.linear = nn.Linear(4, 4)

    def forward(self, x, h):
        """Compute ``tanh(dg(linear(x)) + h)``.

        Returns:
            A 2-tuple holding the same result tensor in both positions.
        """
        projected = self.linear(x)
        gated = self.dg(projected)
        new_h = torch.tanh(gated + h)
        return new_h, new_h

# Seed the global RNG so the Linear weights and the example inputs below are
# identical on every Restart & Run All (the branch taken inside MyDecisionGate,
# and therefore what gets traced later, depends on these values).
torch.manual_seed(0)
cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)

In [9]:
# Call the module itself rather than .forward(): Module.__call__ also runs any
# registered hooks, which a direct .forward() call silently skips.
cell(x, h)

(tensor([[ 0.7993,  0.2224,  0.6987,  0.3583],
         [ 0.3484,  0.4165,  0.9276,  0.1115],
         [ 0.3825,  0.6215,  0.8559, -0.1185]], grad_fn=<TanhBackward0>),
 tensor([[ 0.7993,  0.2224,  0.6987,  0.3583],
         [ 0.3484,  0.4165,  0.9276,  0.1115],
         [ 0.3825,  0.6215,  0.8559, -0.1185]], grad_fn=<TanhBackward0>))

In [11]:
# Record the operations executed for this single example input pair.
# The `if x.sum() > 0` in MyDecisionGate is Python control flow on tensor
# values, which is what the tracer warns about below.
traced_cell = torch.jit.trace(cell, example_inputs=(x, h))

  if x.sum() > 0:


In [12]:
# Show the eager module's submodule tree (dg and linear).
print(cell)

MyCell(
  (dg): MyDecisionGate()
  (linear): Linear(in_features=4, out_features=4, bias=True)
)


In [13]:
# Show the traced module's submodule tree; each entry reports the
# original_name of the module it was traced from.
print(traced_cell)

MyCell(
  original_name=MyCell
  (dg): MyDecisionGate(original_name=MyDecisionGate)
  (linear): Linear(original_name=Linear)
)


In [14]:
# Run the traced module on the same (x, h) pair that was used to record it.
traced_cell(x, h)

(tensor([[ 0.7993,  0.2224,  0.6987,  0.3583],
         [ 0.3484,  0.4165,  0.9276,  0.1115],
         [ 0.3825,  0.6215,  0.8559, -0.1185]], grad_fn=<TanhBackward0>),
 tensor([[ 0.7993,  0.2224,  0.6987,  0.3583],
         [ 0.3484,  0.4165,  0.9276,  0.1115],
         [ 0.3825,  0.6215,  0.8559, -0.1185]], grad_fn=<TanhBackward0>))

In [18]:
# Tracing did not capture MyDecisionGate's data-dependent branch (see the
# TracerWarning above), yet the traced and eager results still agree here
# because x is the very input the trace was recorded with.
# Each module is run once and the paired outputs are compared with allclose.
tuple(
    torch.allclose(traced_out, eager_out)
    for traced_out, eager_out in zip(traced_cell(x, h), cell(x, h))
)

(True, True)

In [24]:
# Inspect the IR recorded by the trace. In the printed graph the result of the
# dg forward call is a NoneType value that nothing consumes, and aten::add
# takes the linear output directly: the if/else of MyDecisionGate is absent.
print(traced_cell.graph)

graph(%self.1 : __torch__.MyCell,
      %x.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %dg : __torch__.MyDecisionGate = prim::GetAttr[name="dg"](%self.1)
  %linear : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)
  %31 : Tensor = prim::CallMethod[name="forward"](%linear, %x.1)
  %32 : NoneType = prim::CallMethod[name="forward"](%dg, %31)
  %18 : int = prim::Constant[value=1]() # /tmp/ipykernel_54144/1343332710.py:8:0
  %19 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%31, %h, %18) # /tmp/ipykernel_54144/1343332710.py:8:0
  %20 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%19) # /tmp/ipykernel_54144/1343332710.py:8:0
  %21 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%20, %20)
  return (%21)



In [23]:
# Inspect the traced MyDecisionGate on its own: the printed graph only returns
# a None constant, i.e. no branching logic was recorded for it.
print(traced_cell.dg.graph)

graph(%self : __torch__.MyDecisionGate,
      %3 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)):
  %6 : NoneType = prim::Constant()
  return (%6)



In [25]:
# torch.jit.script compiles the module from its source code, so it needs no
# example inputs. A second positional argument would bind to the deprecated
# `optimize` parameter rather than act as example inputs, so it is dropped.
scripted_cell = torch.jit.script(cell)

