In [1]:
# run this all in docker 
import torch.nn as nn 
import torch.nn.functional as F 
import torch 
# Print the installed PyTorch version string.
print(torch.__version__)

In [None]:
#https://github.com/pytorch/tutorials/blob/master/beginner_source/data_loading_tutorial.py

In [None]:
# Seed the RNG and detect the compute device.
#
# `args` is not defined anywhere in the visible notebook, so reading `args.seed`
# / `args.cuda` unconditionally raises NameError on a fresh kernel. The settings
# are therefore taken from `args` only when such an object exists in the
# namespace; otherwise notebook-friendly defaults are used.
DEFAULT_SEED = 1111

_cli_args = globals().get("args")  # None when no `args` object has been defined
seed = getattr(_cli_args, "seed", DEFAULT_SEED)
use_cuda = bool(getattr(_cli_args, "cuda", torch.cuda.is_available()))

torch.manual_seed(seed)

if torch.cuda.is_available():
    if not use_cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")
elif use_cuda:
    # CUDA was requested but this machine has none: fall back instead of
    # handing out a "cuda" device that fails on first use.
    print("WARNING: CUDA was requested but is not available; falling back to CPU")
    use_cuda = False

device = torch.device("cuda" if use_cuda else "cpu")


In [None]:
# Below is a Module, the basic unit of composition in PyTorch. A module contains:
# 1. a constructor, which prepares the module for invocation;
# 2. a set of parameters and sub-modules;
# 3. a forward function.
# This one simply applies ReLU to x + h, where x and h are random 3x4 tensors.

class MyCell(torch.nn.Module):
    """Cell with no sub-modules: applies ReLU to the sum of ``x`` and ``h``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, h):
        """Return ``relu(x + h)`` twice, as a 2-tuple of the same tensor object."""
        summed = x + h
        activated = summed.relu()
        return activated, activated
# Instantiate the cell and run it once on two random 3x4 tensors.
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
print(my_cell(x, h))

In [None]:
class MyCell(torch.nn.Module):
    """Cell that sends ``x`` through a 4->4 linear layer, adds ``h``, and applies ReLU."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(in_features=4, out_features=4)
        # NOTE(review): conv2d is registered as a sub-module (so it appears in the
        # printed module) but forward() below never calls it.
        self.conv2d = torch.nn.Conv2d(in_channels=4, out_channels=4, kernel_size=2)

    def forward(self, x, h):
        """Return ``relu(linear(x) + h)`` twice, as a 2-tuple of the same tensor object."""
        projected = self.linear(x)
        activated = (projected + h).relu()
        return activated, activated
# Build the cell, print its module structure, then print its output.
# x and h are the tensors created in an earlier cell.
my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))

In [None]:
class MyDecisionGate(torch.nn.Module):
    """Pass the input through when its element sum is positive; otherwise negate it."""

    def forward(self, x):
        """Return ``x`` if ``x.sum() > 0``, else ``-x``."""
        # The branch taken depends on the tensor's values, not only on its shape.
        total = x.sum()
        return x if total > 0 else -x
class MyCell(torch.nn.Module):
    """Cell that gates the linear projection of ``x`` with ``MyDecisionGate`` before adding ``h``."""

    def __init__(self):
        super().__init__()
        self.dg = MyDecisionGate()
        self.linear = torch.nn.Linear(in_features=4, out_features=4)

    def forward(self, x, h):
        """Return ``tanh(dg(linear(x)) + h)`` twice, as a 2-tuple of the same tensor object."""
        gated = self.dg(self.linear(x))
        new_h = (gated + h).tanh()
        return new_h, new_h

# Build the gated cell, print its module structure, then print its output.
# x and h are the tensors created in an earlier cell.
my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))

In [None]:
class MyCell(torch.nn.Module):
    """Cell that sends ``x`` through a 4->4 linear layer, adds ``h``, and applies tanh."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(in_features=4, out_features=4)

    def forward(self, x, h):
        """Return ``tanh(linear(x) + h)`` twice, as a 2-tuple of the same tensor object."""
        pre_activation = self.linear(x) + h
        new_h = torch.tanh(pre_activation)
        return new_h, new_h

# Build a fresh cell and example inputs, then trace the cell with torch.jit.trace.
my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, example_inputs=(x, h))
print(traced_cell)
# Bare last expression: the traced module's output is shown as the cell result.
traced_cell(x, h)

In [None]:
# Print the `graph` attribute of the traced module.
print(traced_cell.graph)


In [None]:
# Print the `code` attribute of the traced module.
print(traced_cell.code)

In [None]:
# Run the original (untraced) module on x and h and print the result.
print(my_cell(x, h))


In [None]:
# Run the traced module on the same x and h and print the result.
print(traced_cell(x, h))