In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [47]:
class CNNTradingAgent(nn.Module):
    """Convolutional network mapping a ``(N, 256, 16)`` window to 3 probabilities.

    The 256-step axis is treated as the spatial "height" of a 2-D convolution
    (width 1) and the 16 features as input channels. The window length is
    fixed at 256 by the ``LayerNorm([256, 1])`` in ``conv0``.

    The two middle stages follow the inverted-bottleneck idea from Google's
    MobileNetV2: a 1x1 expansion, a depthwise convolution along the time axis,
    and a 1x1 projection.

    ``forward`` returns a ``(N, 3)`` tensor whose rows sum to 1 (softmax over
    the three output classes; what each class means is not defined here).
    """

    def __init__(self):
        super().__init__()

        # If `True`, temporarily switch the network into evaluation mode when
        # the input batch holds a single sample. This avoids the BatchNorm
        # error raised for batches of size 1 in training mode.
        self._auto_detect_single_batch = True

        # N * 256 * 16
        # x.transpose(-1, -2).unsqueeze(-1)
        # N * 16 * 256 * 1
        self.conv0 = nn.Sequential(
            nn.LayerNorm([256, 1]),
            nn.Conv2d(16, 32, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0)),
            nn.BatchNorm2d(32)
        )
        # N * 32 * 128 * 1
        self.bottleneck0 = self._bottleneck(32, 64)
        # N * 64 * 32 * 1
        self.bottleneck1 = self._bottleneck(64, 128)
        # N * 128 * 8 * 1
        self.conv1 = nn.Sequential(
            nn.Conv2d(128, 512, kernel_size=1),
            nn.BatchNorm2d(512),
            nn.AvgPool2d(kernel_size=(8, 1))
        )
        # N * 512 * 1 * 1
        self.conv2 = nn.Conv2d(512, 3, kernel_size=1)

        self.softmax = nn.Softmax(dim=1)

    @staticmethod
    def _bottleneck(in_channels, out_channels, expansion=6):
        """Build one inverted-bottleneck stage that shrinks the time axis by 4.

        The stage expands to ``in_channels * expansion`` channels, applies a
        stride-2 depthwise convolution along the time axis, projects down to
        ``out_channels`` and halves the time axis again with average pooling.
        The layer order matches the original hand-written stacks, so the
        ``state_dict`` keys (``bottleneckN.0`` ... ``bottleneckN.8``) are
        unchanged.
        """
        hidden = in_channels * expansion
        return nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU6(),
            # groups == channels makes this a depthwise convolution.
            nn.Conv2d(hidden, hidden, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0), groups=hidden),
            nn.BatchNorm2d(hidden),
            nn.ReLU6(),
            nn.Conv2d(hidden, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.AvgPool2d(kernel_size=(2, 1))
        )

    def forward(self, x):
        """Return class probabilities of shape ``(N, 3)`` for ``x`` of shape ``(N, 256, 16)``.

        When ``_auto_detect_single_batch`` is set and the module is in training
        mode, a batch of one sample is evaluated in eval mode. The training
        mode is restored afterwards, even if the forward pass raises.
        """
        switched = self._auto_detect_single_batch and self.training and x.size(0) == 1
        if switched:
            self.eval()

        try:
            # (N, 256, 16) -> (N, 16, 256, 1): features become channels.
            x = x.transpose(-1, -2).unsqueeze(-1)
            x = self.conv0(x)
            x = self.bottleneck0(x)
            x = self.bottleneck1(x)
            x = self.conv1(x)
            x = self.conv2(x)
            # (N, 3, 1, 1) -> (N, 3)
            x = x.view(-1, 3)
            return self.softmax(x)
        finally:
            # Always hand the module back in the mode the caller left it in.
            if switched:
                self.train()

In [None]:
class RNNTradingAgent(nn.Module):
    """Bidirectional-GRU network mapping a ``(N, 256, 16)`` window to 3 probabilities.

    Pipeline (shapes follow the comments in ``__init__``):

    1. LayerNorm over the 256-step time axis, separately for each feature.
    2. Linear projection of the 16 features to 32.
    3. Two-layer bidirectional GRU with hidden size 64.
    4. The final hidden states of the top layer (forward and backward) are
       concatenated into a ``(N, 128)`` summary and projected to 3 logits.
    5. Softmax over the 3 classes.

    NOTE(review): the original cell was unfinished, so it did not say how the
    GRU output should be reduced to ``(N, 3)``. Using the top layer's final
    hidden states is this rewrite's choice; confirm it matches the intent.
    """

    def __init__(self):
        super().__init__()

        # Kept for parity with CNNTradingAgent: when `True`, a batch holding a
        # single sample is run in evaluation mode. This network has no
        # BatchNorm layers, so the switch does not change the numbers today.
        self._auto_detect_single_batch = True

        num_layers = 2
        num_directions = 2
        hidden_size = 64

        # N * 256 * 16
        # x.transpose(-1, -2)
        # N * 16 * 256
        self.ln = nn.LayerNorm(256)
        # x.transpose(-1, -2)
        # N * 256 * 16
        self.fc_in = nn.Linear(16, 32)
        # N * 256 * 32
        self.rnn = nn.GRU(
            input_size=32,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )
        # Top-layer final hidden states, both directions concatenated: N * 128
        self.fc_out = nn.Linear(num_directions * hidden_size, 3)
        # N * 3

        # Learnable initial hidden state shared by every sample in the batch.
        # It starts at zero, so an untrained model behaves as with zero state.
        self.init_hidden = nn.Parameter(
            torch.zeros(num_layers * num_directions, 1, hidden_size)
        )

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return class probabilities of shape ``(N, 3)`` for ``x`` of shape ``(N, 256, 16)``.

        If the module is temporarily switched to eval mode for a single-sample
        batch, training mode is restored afterwards, even if the pass raises.
        """
        switched = self._auto_detect_single_batch and self.training and x.size(0) == 1
        if switched:
            self.eval()

        try:
            # Normalise each feature over the time axis, then restore (N, 256, 16).
            x = self.ln(x.transpose(-1, -2)).transpose(-1, -2)
            x = self.fc_in(x)

            # Broadcast the shared initial state across the batch dimension.
            h0 = self.init_hidden.expand(-1, x.size(0), -1).contiguous()
            _, h_n = self.rnn(x, h0)

            # h_n is (num_layers * 2, N, 64); the last two entries belong to
            # the top layer (forward direction, then backward direction).
            summary = torch.cat([h_n[-2], h_n[-1]], dim=1)
            return self.softmax(self.fc_out(summary))
        finally:
            if switched:
                self.train()

In [48]:
# Instantiate the CNN agent with freshly initialised (untrained) weights.
model = CNNTradingAgent()

In [49]:
# Display the module tree to check the layer configuration.
model

CNNTradingAgent(
  (conv0): Sequential(
    (0): LayerNorm(torch.Size([256, 1]), eps=1e-05, elementwise_affine=True)
    (1): Conv2d(16, 32, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0))
    (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (bottleneck0): Sequential(
    (0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU6()
    (3): Conv2d(192, 192, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0), groups=192)
    (4): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU6()
    (6): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): AvgPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0)
  )
  (bottleneck1): Sequential(
    (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNo

In [50]:
# Seed the generator so the toy batch is identical on every Restart & Run All.
# (This only fixes the input; the model weights were drawn before this call.)
torch.manual_seed(0)

# Toy batch: 3 samples, 256 time steps, 16 features per step.
toy_input = torch.randn(3, 256, 16)
print(toy_input)

tensor([[[ 1.6086,  1.0247,  0.8554,  ...,  0.0395,  2.3656, -1.0874],
         [ 0.2770,  0.8557,  0.4668,  ...,  0.2226,  0.9344,  0.7121],
         [-0.2932,  0.2818,  0.4620,  ...,  2.1147, -1.9017,  0.8169],
         ...,
         [-0.0395,  1.5317,  0.5689,  ..., -1.0811,  0.1842, -0.6972],
         [ 0.6453,  0.9932,  1.3931,  ..., -0.0586,  0.0549,  0.4356],
         [ 0.2619, -1.7053,  0.1622,  ...,  0.2687, -0.4136,  0.0930]],

        [[ 0.1257, -0.4109, -0.0192,  ...,  1.0087,  0.6787, -0.2372],
         [-1.7194,  0.1472,  1.4590,  ...,  1.2508,  1.2125,  0.1135],
         [-1.0114, -0.0751, -0.7786,  ..., -0.8571,  1.3649,  0.3910],
         ...,
         [-0.3882, -2.8003, -0.7627,  ...,  0.0135,  1.1887, -0.2866],
         [ 1.3291, -0.5450,  0.3006,  ..., -0.8672,  0.8090,  0.0489],
         [-0.0511, -1.1801, -0.7393,  ..., -0.1444, -0.3574, -0.0697]],

        [[ 1.4249,  0.7177, -0.3779,  ..., -1.6122, -3.0001, -0.3684],
         [-0.1546, -0.5691,  1.4686,  ..., -0

In [53]:
# Smoke test: forward the toy batch through the untrained network.
# The result has one row per sample and one softmax probability per class.
# NOTE(review): no `.eval()` / `torch.no_grad()` is visible before this call, so
# this presumably runs in training mode with gradient tracking — confirm intended.
model(toy_input)

tensor([[0.3328, 0.3400, 0.3272],
        [0.3658, 0.2747, 0.3595],
        [0.3201, 0.3514, 0.3285]], grad_fn=<SoftmaxBackward>)