In [1]:
import torch
import torch.nn.functional as F

import time
import numpy as np
from scipy.stats import multivariate_normal as normal

In [3]:
dtype = torch.float64
DELTA_CLIP = 50.0

In [4]:
class Equation(object):
    def __init__(self,eqn_config):
        self.dim = eqn_config["dim"]
        self.total_time = eqn_config["total_time"]
        self.num_time_interval = eqn_config["num_time_interval"]
        self.delta_t = self.total_time / self.num_time_interval
        self.sqrt_delta_t = np.sqrt(self.delta_t)
        self.y_init = None

In [5]:
class HJBLQ(Equation):
    def __init__(self,eqn_config):
        super(HJBLQ,self).__init__(eqn_config)
        self.x_init = np.zeros(self.dim)
        self.sigma = np.sqrt(2.0)
        self.lambd = 1.0
        
    def sample(self, num_sample):
        dw_sample = normal.rvs(size=[num_sample,
                                     self.dim,
                                     self.num_time_interval]) * self.sqrt_delta_t
        x_sample = np.zeros([num_sample, self.dim, self.num_time_interval + 1])
        x_sample[:, :, 0] = np.ones([num_sample, self.dim]) * self.x_init
        for i in range(self.num_time_interval):
            x_sample[:, :, i + 1] = x_sample[:, :, i] + self.sigma * dw_sample[:, :, i]
        return dw_sample, x_sample
    
    def f_tf(self,t,x,y,z):
        return -self.lambd * torch.mul(z,z).sum(1)
    def g_tf(self,t,x):
        return torch.log((1+torch.mul(x,x).sum(1))/2)

In [6]:
config = {
  "eqn_config": {
    "_comment": "HJB equation in PNAS paper doi.org/10.1073/pnas.1718942115",
    "eqn_name": "HJBLQ",
    "total_time": 1.0,
    "dim": 100,
    "num_time_interval": 20
  },
  "net_config": {
    "y_init_range": [0, 1],
    "num_hiddens": [110, 110],
    "lr_values": [1e-2, 1e-2],
    "lr_boundaries": [1000],
    "num_iterations": 2000,
    "batch_size": 64,
    "valid_size": 256,
    "logging_frequency": 100,
    "dtype": "float64",
    "verbose": True
  }
}

In [13]:
class BSDESolver(object):
    def __init__(self,config,bsde):
        self.eqn_config = config["eqn_config"]
        self.net_config = config["net_config"]
        self.bsde = bsde
        
        self.model = NonsharedModel(config,bsde)
        self.t_init = self.model.y_init
        self.optimizer = torch.optim.Adam(self.model.parameters(),lr=1e-2, eps=1e-8)
        
    def train(self):
        start_time = time.time()
        training_history = []
        valid_data = self.bsde.sample(self.net_config["valid_size"])
        
        for step in range(self.net_config["num_iterations"]+1):
            if step % self.net_config["logging_frequency"] == 0 :
                loss = self.loss_fn(valid_data,training=False).numpy()
                y_init = self.y_init.numpy()[0]
                elapsed_time = time.time() - start_time
                training_history.append([step, loss, y_init, elapsed_time])
                
            loss = self.loss_fn(self.bsde.sample(self.net_config.batch_size),training=True)
            
            self.optimizer.zoro_grad()
            loss.backward()
            optimizer.step()
            
        return np.array(training_history)
        
        
    def loss_fn(self, inputs, training):
        dw, x = inputs
        y_terminal = self.model(inputs,training)
        delta = y_terminal - self.bsde.g_tf(self.bsde.total_time, x[:, :, -1])
        
        loss = torch.mean(torch.where(torch.abs(delta)<DELTA_CLIP,torch.mul(delta),
                                      2 * DELTA_CLIP * torch.abs(delta) - DELTA_CLIP **2))
        
        return loss

In [12]:
class NonsharedModel(torch.nn.Module):
    def __init__(self,config,bsde):
        super(NonsharedModel,self).__init__()
        self.eqn_config = config["eqn_config"]
        self.net_config = config["net_config"]
        self.bsde = bsde
        
        self.y_init = torch.tensor(np.random.uniform(low=self.net_config["y_init_range"][0],
                                                    high=self.net_config["y_init_range"][1],
                                                    size=[1])
                                  )
        self.z_init = torch.tensor(np.random.uniform(low=-.1, high=.1,
                                                    size=[1, self.eqn_config["dim"]])
                                  )
        
        self.subnet = torch.nn.ModuleList([FeedForwardSubNet(config) for _ in range(self.eqn_config["num_time_interval"]-1)])
    
    def forward(self, inputs, training):
        dw = torch.tensor(inputs[0],dtype=dtype)
        x = torch.tensor(inputs[1],dtype=dtype)
        time_stamp = np.arange(0, self.eqn_config["num_time_interval"]) * self.bsde.delta_t
        all_one_vec = torch.ones(self.eqn_config["dim"],1,dtype=dtype)
        y = all_one_vec * self.y_init
        z = all_one_vec.mm(self.z_init)
        
        for t in range(0, self.bsde.num_time_interval-1):
            y = y - self.bsde.delta_t * (
                self.bsde.f_tf(time_stamp[t], x[:, :, t], y, z)
            ) + torch.mm(z,dw[:,:,t].t()).sum(1)          
            z = self.subnet[t](x[:, :, t + 1], training) / self.bsde.dim
        # terminal time
        y = y - self.bsde.delta_t * self.bsde.f_tf(time_stamp[-1], x[:, :, -2], y, z) + \
            (z * dw[:, :, -1]).sum(1) 

        return y
        

In [9]:
class FeedForwardSubNet(torch.nn.Module):
    def __init__(self, config):
        super(FeedForwardSubNet, self).__init__()
        dim = config["eqn_config"]["dim"]
        num_hiddens = config["net_config"]["num_hiddens"]
        
        self.bn_layers = torch.nn.ModuleList([torch.nn.BatchNorm1d(dim,momentum=0.99, eps=1e-6,affine=True)])
        for i in num_hiddens:
            self.bn_layers.append(torch.nn.BatchNorm1d(i,momentum=0.99, eps=1e-6,affine=True))
        self.bn_layers.append(torch.nn.BatchNorm1d(dim,momentum=0.99, eps=1e-6,affine=True))
  

        self.dense_layers = torch.nn.ModuleList([torch.nn.Linear(dim,num_hiddens[0],bias=False)])
        for i in range(len(num_hiddens)-1):
            self.dense_layers.append(torch.nn.Linear(num_hiddens[i],num_hiddens[i+1],bias=False))
        self.dense_layers.append(torch.nn.Linear(num_hiddens[-1],dim,bias=False))
  
    def forward(self,x,training):
        x = self.bn_layers[0](x)
        for i in range(len(self.dense_layers)-1):
            x = self.dense_layers[i](x)
            x = self.bn_layers[i+1](x)
            x = F.relu(x)
        x = self.dense_layers[-1](x)
        x = self.bn_layers[-1](x)
        return x

In [14]:
bsde = HJBLQ(config["eqn_config"])
bsde_solver = BSDESolver(config,bsde)

In [15]:
training_history = bsde_solver.train()

RuntimeError: expected scalar type Double but found Float

In [104]:
torch.tensor([1,2])[:]

tensor([1, 2])