In [1]:
import torch  # This is all you need to use both PyTorch and TorchScript!
# Show the installed PyTorch version so the recorded outputs can be matched to it.
print(torch.__version__)
import torchviz
# Seed PyTorch's global RNG before any torch.rand(...) call or layer construction below.
torch.manual_seed(191009)  # set the seed for reproducibility
# NOTE(review): make_dot / make_dot_from_trace are not called in the cells visible here — confirm they are still needed.
from torchviz import make_dot, make_dot_from_trace

1.12.1


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MyCell(torch.nn.Module):
    """Parameter-free cell that combines an input with a hidden state.

    The forward pass adds the two tensors element-wise, applies ``tanh`` and
    returns the resulting tensor twice (as a 2-tuple).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, h):
        """Return ``(new_h, new_h)`` where ``new_h = tanh(x + h)``."""
        summed = x + h
        hidden = torch.tanh(summed)
        return hidden, hidden

# Instantiate the cell defined just above and build two random 3x4 inputs.
# x and h are module-level names that later statements reuse, so keep them as-is.
my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
# Calling the module runs its forward(); the printed tuple holds the same tensor twice.
print(my_cell(x, h))

class MyCell(torch.nn.Module):
    """Cell that projects its input through a 4->4 linear layer before mixing.

    The forward pass computes ``tanh(linear(x) + h)`` and returns that tensor
    twice (as a 2-tuple).
    """

    def __init__(self):
        super().__init__()
        # Learnable affine map with 4 input and 4 output features.
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        """Return ``(new_h, new_h)`` where ``new_h = tanh(self.linear(x) + h)``."""
        projected = self.linear(x)
        hidden = torch.tanh(projected + h)
        return hidden, hidden

# Rebind my_cell to the Linear-backed MyCell defined just above; x and h are reused from earlier.
my_cell = MyCell()
# Printing a Module lists its registered submodules (here: the `linear` layer).
print(my_cell)
# The recorded output for this call carries grad_fn=<TanhBackward0>, unlike the parameter-free version.
print(my_cell(x, h))

(tensor([[0.8219, 0.8990, 0.6670, 0.8277],
        [0.5176, 0.4017, 0.8545, 0.7336],
        [0.6013, 0.6992, 0.2618, 0.6668]]), tensor([[0.8219, 0.8990, 0.6670, 0.8277],
        [0.5176, 0.4017, 0.8545, 0.7336],
        [0.6013, 0.6992, 0.2618, 0.6668]]))
MyCell(
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[ 0.8573,  0.6190,  0.5774,  0.7869],
        [ 0.3326,  0.0530,  0.0702,  0.8114],
        [ 0.7818, -0.0506,  0.4039,  0.7967]], grad_fn=<TanhBackward0>), tensor([[ 0.8573,  0.6190,  0.5774,  0.7869],
        [ 0.3326,  0.0530,  0.0702,  0.8114],
        [ 0.7818, -0.0506,  0.4039,  0.7967]], grad_fn=<TanhBackward0>))


In short, TorchScript provides tools to capture the definition of your model, even in light of the flexible and dynamic nature of PyTorch. Let’s begin by examining what we call tracing.

In [3]:
class MyCell(torch.nn.Module):
    """Linear-backed cell used as the tracing example.

    The forward pass computes ``tanh(linear(x) + h)`` with a 4->4 linear layer
    and returns that tensor twice (as a 2-tuple).
    """

    def __init__(self):
        super().__init__()
        # Learnable affine map with 4 input and 4 output features.
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        """Return ``(new_h, new_h)`` where ``new_h = tanh(self.linear(x) + h)``."""
        pre_activation = self.linear(x) + h
        hidden = torch.tanh(pre_activation)
        return hidden, hidden
    


# Fresh module and fresh random 3x4 inputs; x is drawn before h (tuple elements evaluate left to right).
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
# torch.jit.trace runs my_cell on the example inputs (x, h) and returns a traced module.
traced_cell = torch.jit.trace(my_cell, (x, h))
# The traced module prints with original_name=... entries (see the recorded output).
print(traced_cell)
# Last expression of the cell: its value is displayed as the cell output.
traced_cell(x, h)



MyCell(
  original_name=MyCell
  (linear): Linear(original_name=Linear)
)


(tensor([[-0.0416,  0.7165,  0.5299, -0.0434],
         [ 0.2687,  0.1412,  0.6382,  0.1054],
         [ 0.3480,  0.5014,  0.6016, -0.3498]], grad_fn=<TanhBackward0>),
 tensor([[-0.0416,  0.7165,  0.5299, -0.0434],
         [ 0.2687,  0.1412,  0.6382,  0.1054],
         [ 0.3480,  0.5014,  0.6016, -0.3498]], grad_fn=<TanhBackward0>))

In [4]:
# Sanity check: run the eager module and its traced counterpart on the same inputs;
# the recorded outputs below are identical.
print(my_cell(x, h))
print(traced_cell(x, h))

(tensor([[-0.0416,  0.7165,  0.5299, -0.0434],
        [ 0.2687,  0.1412,  0.6382,  0.1054],
        [ 0.3480,  0.5014,  0.6016, -0.3498]], grad_fn=<TanhBackward0>), tensor([[-0.0416,  0.7165,  0.5299, -0.0434],
        [ 0.2687,  0.1412,  0.6382,  0.1054],
        [ 0.3480,  0.5014,  0.6016, -0.3498]], grad_fn=<TanhBackward0>))
(tensor([[-0.0416,  0.7165,  0.5299, -0.0434],
        [ 0.2687,  0.1412,  0.6382,  0.1054],
        [ 0.3480,  0.5014,  0.6016, -0.3498]], grad_fn=<TanhBackward0>), tensor([[-0.0416,  0.7165,  0.5299, -0.0434],
        [ 0.2687,  0.1412,  0.6382,  0.1054],
        [ 0.3480,  0.5014,  0.6016, -0.3498]], grad_fn=<TanhBackward0>))


In [5]:
class MyDecisionGate(torch.nn.Module):
    """Gate whose output depends on the sign of the input's sum.

    The branch is chosen from the tensor's values at run time, i.e. this is
    data-dependent Python control flow.
    """

    def forward(self, x):
        """Return ``x`` when ``x.sum() > 0``; otherwise return ``-x``."""
        return x if x.sum() > 0 else -x

class MyCell(torch.nn.Module):
    """Cell that routes its linear projection through a caller-supplied gate.

    Args:
        dg: module (or callable) applied to ``linear(x)`` before the hidden
            state is added.

    The forward pass computes ``tanh(dg(linear(x)) + h)`` and returns that
    tensor twice (as a 2-tuple).
    """

    def __init__(self, dg):
        super().__init__()
        # Keep the assignment order: `dg` is stored before `linear`.
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        """Return ``(new_h, new_h)`` where ``new_h = tanh(self.dg(self.linear(x)) + h)``."""
        gated = self.dg(self.linear(x))
        hidden = torch.tanh(gated + h)
        return hidden, hidden

# Trace the gated cell with the x and h left over from the previous cells.
my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))

# Inspect the generated TorchScript. In the recorded output the gate's code contains only
# torch.neg(...): the `if` is gone and just the branch taken for these inputs remains.
print(traced_cell.dg.code)
print(traced_cell.code)


def forward(self,
    argument_1: Tensor) -> Tensor:
  return torch.neg(argument_1)

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  _1 = torch.tanh(_0)
  return (_1, _1)



  if x.sum() > 0:


## Trying this on mechanisms


In [9]:
# the following two can be merged potentially:
class Torch_Irrev_MM_Uni(torch.nn.Module):
    """Single-substrate saturating rate law as a ``torch.nn.Module``.

    The forward pass evaluates ``vmax * (S / km) / (1 + S / km)``, which is
    algebraically the same as ``vmax * S / (km + S)`` (the class name suggests
    an irreversible uni-substrate Michaelis-Menten mechanism).

    Args:
        vmax: value multiplying the saturation term.
        km_substrate: value the substrate is divided by.
        to_be_learned: indexable with at least two entries. A truthy entry 0
            turns ``vmax`` into a learnable ``torch.nn.Parameter``; a truthy
            entry 1 does the same for ``km_substrate``. Values that are not
            learned are stored unchanged as plain attributes.
    """

    def __init__(self,
                 vmax: float,
                 km_substrate: float,
                 to_be_learned):
        super().__init__()

        # vmax is assigned before km_substrate so parameters() keeps that order.
        self.vmax = self._learnable(vmax) if to_be_learned[0] else vmax
        self.km_substrate = (
            self._learnable(km_substrate) if to_be_learned[1] else km_substrate
        )

    @staticmethod
    def _learnable(value):
        """Wrap a scalar in a one-element floating-point ``Parameter``."""
        # Force the default floating dtype: an integer input (e.g. vmax=2) would
        # otherwise give an int64 tensor, which cannot require gradients.
        return torch.nn.Parameter(
            torch.tensor([value], dtype=torch.get_default_dtype())
        )

    def forward(self, substrate):
        """Return the rate ``vmax * (S / km) / (1 + S / km)`` for ``substrate`` S."""
        saturation = substrate / self.km_substrate
        numerator = self.vmax * saturation
        denominator = 1 + saturation
        return numerator / denominator
    

# Example input used only to drive the trace.
substrate=torch.Tensor([2.1])
# Both vmax and km_substrate are made learnable here.
mechanism=Torch_Irrev_MM_Uni(vmax=2.0,km_substrate=3.0,to_be_learned=[True,True])
# NOTE: (substrate) is just a parenthesised tensor, not a 1-tuple; the example input is passed as a single Tensor.
traced_mechanism=torch.jit.trace(mechanism,(substrate))

# The recorded output lists two Parameters (2. and 3.), both with requires_grad=True.
print(list(traced_mechanism.parameters()))
# Called with a plain Python int rather than a tensor; the recorded result is tensor([0.8000]).
traced_mechanism.forward(2)


# mechanism(2)



[Parameter containing:
tensor([2.], requires_grad=True), Parameter containing:
tensor([3.], requires_grad=True)]


tensor([0.8000], grad_fn=<DivBackward0>)

In [59]:
%%timeit
# Timed statement: the eager (untraced) module called with a Python int.
mechanism(2)

44.2 µs ± 6.66 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [60]:
%%timeit
# Timed statement: the traced module called with the same Python int, for comparison.
traced_mechanism(2)

20.1 µs ± 1.31 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
# the following two can be merged potentially:
class Torch_Irrev_MM_Uni(torch.nn.Module):
    """Single-substrate saturating rate law exposed through ``calculate``.

    ``calculate`` evaluates ``vmax * (S / km) / (1 + S / km)``, which is
    algebraically the same as ``vmax * S / (km + S)`` (the class name suggests
    an irreversible uni-substrate Michaelis-Menten mechanism). ``forward``
    delegates to ``calculate`` so the module can also be called directly.

    Args:
        vmax: value multiplying the saturation term.
        km_substrate: value the substrate is divided by.
        to_be_learned: indexable with at least two entries. A truthy entry 0
            turns ``vmax`` into a learnable ``torch.nn.Parameter``; a truthy
            entry 1 does the same for ``km_substrate``. Values that are not
            learned are stored unchanged as plain attributes.
    """

    def __init__(self,
                 vmax: float,
                 km_substrate: float,
                 to_be_learned):
        super().__init__()

        # vmax is assigned before km_substrate so parameters() keeps that order.
        self.vmax = self._learnable(vmax) if to_be_learned[0] else vmax
        self.km_substrate = (
            self._learnable(km_substrate) if to_be_learned[1] else km_substrate
        )

    @staticmethod
    def _learnable(value):
        """Wrap a scalar in a one-element floating-point ``Parameter``."""
        # Force the default floating dtype: an integer input (e.g. vmax=2) would
        # otherwise give an int64 tensor, which cannot require gradients.
        return torch.nn.Parameter(
            torch.tensor([value], dtype=torch.get_default_dtype())
        )

    def calculate(self, substrate):
        """Return the rate ``vmax * (S / km) / (1 + S / km)`` for ``substrate`` S."""
        saturation = substrate / self.km_substrate
        numerator = self.vmax * saturation
        denominator = 1 + saturation
        return numerator / denominator

    def forward(self, substrate):
        """Delegate to ``calculate`` so ``module(substrate)`` works."""
        # Without a forward(), calling the module instance raises instead of computing.
        return self.calculate(substrate)
    


