In [1]:
import torch 
import torch.nn as nn
from torch.optim.lr_scheduler import MultiStepLR

import os
import time
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Runtime configuration: compute device, RNG seed and output folder.
cuda_num = 1
# Use the requested GPU when it exists; otherwise fall back to the CPU so the
# notebook still runs on machines with fewer (or no) CUDA devices.
if torch.cuda.is_available() and cuda_num < torch.cuda.device_count() :
    dev = 'cuda:%d'%cuda_num
else :
    dev = 'cpu'
seed = 0
folder = 'cgdpo_complete'
# exist_ok=True avoids the check-then-create race of os.path.exists + os.mkdir.
os.makedirs(folder,exist_ok=True)

# Compact, non-scientific formatting for printed tensors and arrays.
torch.set_printoptions(sci_mode=False,precision=4)
np.set_printoptions(suppress=True,precision=4,linewidth=100)

In [3]:
# --- Model parameters (meanings taken from how they are used below) ---
mu = 0.18       # drift applied to the fraction pi in the wealth update
r = 0.07        # rate applied to the remaining fraction (1-pi)
vol = 0.3       # volatility multiplying the Brownian increment dZ
gamma = 2.0     # exponent parameter of the power utility x**(1-gamma)/(1-gamma)
rho = 0.1       # discount rate, used as exp(-rho*t)
eps = 0.1       # terminal-wealth utility weight, enters as eps**gamma

# --- Discretisation ---
m = 5           # number of time steps per trajectory (dt = T/m)
n = 1000        # number of sampled trajectories per training batch

# --- Sampling ranges and numerical floors ---
T_max = 1.0     # horizons T are sampled from [0, T_max)
W_min = 0.1     # lower end of the sampled wealth range
W_max = 2.0     # upper end of the sampled wealth range
lb_w = 1e-4     # floor applied to wealth after every update
lb_c = 1e-4     # offset added to the remaining time in W/(T-t+lb_c)

In [4]:
class MyopicNet(nn.Module) :
    """MLP (2 -> 200 -> 200 -> 200 -> 1, LeakyReLU) evaluated on an all-zero input.

    The forward pass replaces the first two input columns by zeros, so the
    output does not depend on the values in ``x0``: every row of a batch gets
    the same learned scalar.
    """

    def __init__(self) :
        super().__init__()

        # Layers are created in this fixed order so that seeded parameter
        # initialisation is unchanged.
        self.linear1a = nn.Linear(2,200)
        self.linear2a = nn.Linear(200,200)
        self.linear3a = nn.Linear(200,200)
        self.linear4a = nn.Linear(200,1)

        self.F = nn.LeakyReLU()
        self.H = nn.Softplus()   # defined but not applied in forward()

    def forward(self,x0) :
        """Return one value per row of ``x0`` (squeezed); input values are ignored."""
        hidden = torch.zeros_like(x0[:,0:2])

        for layer in (self.linear1a, self.linear2a, self.linear3a) :
            hidden = self.F(layer(hidden))

        return self.linear4a(hidden).squeeze()

In [5]:
class ConsumeNet(nn.Module) :
    """MLP (2 -> 200 -> 200 -> 200 -> 1) with LeakyReLU hidden activations.

    A final Softplus keeps the output strictly positive.
    """

    def __init__(self) :
        super().__init__()

        # Layers are created in this fixed order so that seeded parameter
        # initialisation is unchanged.
        self.linear1a = nn.Linear(2,200)
        self.linear2a = nn.Linear(200,200)
        self.linear3a = nn.Linear(200,200)
        self.linear4a = nn.Linear(200,1)

        self.F = nn.LeakyReLU()
        self.H = nn.Softplus()

    def forward(self,x0) :
        """Map a batch of 2-column rows to one positive value per row (squeezed)."""
        hidden = x0

        for layer in (self.linear1a, self.linear2a, self.linear3a) :
            hidden = self.F(layer(hidden))

        positive = self.H(self.linear4a(hidden))
        return positive.squeeze()

In [6]:
# Seed both RNGs so the initial network weights are reproducible.
np.random.seed(seed)
torch.manual_seed(seed)

# The consumption network is built before the myopic one; the order matters
# because both draw their initial weights from the same seeded stream.
net_c = ConsumeNet().to(dev)
net_m = MyopicNet().to(dev)

# One Adam optimizer per network, each with its own learning rate.
opt_c = torch.optim.Adam(net_c.parameters(),lr=1e-5)
opt_m = torch.optim.Adam(net_m.parameters(),lr=1e-3)

In [7]:
def generate_uniform_domain(n) :
    """Draw ``n`` uniform samples of horizon and wealth.

    Returns ``(T, W, dt)``: ``T`` uniform on [0, T_max), ``W`` uniform on
    [W_min, W_max), and the per-sample time step ``dt = T/m``. All three are
    1-D tensors of length ``n`` on ``dev``.
    """
    # T is drawn before W; the order is kept so the random stream is unchanged.
    T = torch.rand([n],device=dev)*T_max
    W = (W_max-W_min)*torch.rand([n],device=dev) + W_min

    return T, W, T/m

T,W,dt = generate_uniform_domain(n)

In [8]:
def get_nu() :
    """Return the constant ``nu`` computed from rho, gamma, mu, r and vol."""
    growth_term = (mu-r)**2/(2.*vol**2*gamma) + r
    return (rho - (1.-gamma)*growth_term)/gamma

def get_consume(x, t) :
    """Closed-form reference consumption for wealth ``x`` and remaining time ``t``.

    Computes ``nu*x / (1 + (nu*eps - 1)*exp(-nu*t))`` with ``nu = get_nu()``.
    """
    nu = get_nu()
    denominator = 1.+(nu*eps-1.)*np.exp(-nu*t)
    c = nu*x
    # In-place division, so the result keeps the dtype of nu*x.
    c /= denominator
    return c

def get_myopic(x, t) :
    """Reference allocation: the constant (mu-r)/(vol**2*gamma) times ``x``.

    ``t`` is accepted for signature symmetry with ``get_consume`` but is unused.
    """
    fraction = (mu-r)/(vol**2*gamma)
    return fraction*x

In [9]:
def measure(xx) :
    """Score each row of ``xx`` by the normalised curvature of ``net_c``.

    For every row the pure second derivatives of ``net_c`` with respect to
    column 0 and column 1 are computed. Each is normalised by the root of its
    sum of squares over the batch, and the two squared normalised values are
    added.

    Parameters
    ----------
    xx : tensor with two columns; the callers in this file fill column 0
        with W and column 1 with T. ``requires_grad`` is switched on
        temporarily and restored to its previous value before returning.

    Returns
    -------
    Detached tensor of shape (rows, 1) holding the score per row. If a second
    derivative is zero over the whole batch, the normalisation divides by zero.
    """
    previous_flag = xx.requires_grad
    xx.requires_grad = True
    try :
        a_c = net_c(xx)

        grad_outputs = torch.ones_like(a_c)
        # create_graph=True is required here: d1 is differentiated again below.
        d1 = torch.autograd.grad(a_c, xx, grad_outputs=grad_outputs, create_graph=True)[0]
        # The second derivatives are detached immediately, so no higher-order
        # graph is needed. retain_graph=True only keeps d1's graph alive for
        # the second call.
        d11 = torch.autograd.grad(d1[:,0], xx, grad_outputs=grad_outputs, retain_graph=True)[0][:,0:1].detach()
        d22 = torch.autograd.grad(d1[:,1], xx, grad_outputs=grad_outputs)[0][:,1:2].detach()
    finally :
        # Restore the caller's flag even if the forward/backward pass fails.
        xx.requires_grad = previous_flag

    # |d| / ||d||_2 over the batch, for each direction separately.
    d11 = torch.sqrt(d11**2/torch.sum(d11**2))
    d22 = torch.sqrt(d22**2/torch.sum(d22**2))
    d2 = torch.cat([d11,d22],axis=1).detach()
    std = torch.sum(d2**2,axis=1,keepdims=True)

    return std

In [10]:
def weighted_sampling_torch(vector, num_samples):
    """Sample ``num_samples`` distinct entries of ``vector``, favouring early positions.

    Position ``k`` is weighted by ``linspace(1, 0.1, len(vector))[k]**2``, so
    the first entry is 100 times more likely to be drawn than the last.
    Sampling is without replacement.

    Parameters
    ----------
    vector : sequence or 1-D tensor of numbers.
    num_samples : number of entries to draw; must not exceed ``len(vector)``
        (``torch.multinomial`` raises otherwise).

    Returns
    -------
    list of float: the sampled entries, converted to float.
    """
    # torch.as_tensor accepts sequences and tensors alike; torch.tensor() on an
    # existing tensor emits a copy-construct UserWarning.
    tensor = torch.as_tensor(vector, dtype=torch.float)
    weights = torch.linspace(1, 0.1, len(vector))**2
    probabilities = weights / weights.sum()
    sampled_indices = torch.multinomial(probabilities, num_samples, replacement=False)
    # The probabilities live on the CPU; move the indices to wherever the data is.
    samples = tensor[sampled_indices.to(tensor.device)].tolist()
    return samples

def generate_domain_adaptive_sampling(n) :
    """Draw ``n`` (T, W) samples, 90% of them concentrated where ``measure`` is large.

    A uniform pool of ``5*n`` candidates is scored with ``measure`` and ranked
    from highest to lowest score. ``int(0.9*n)`` candidates are then picked by
    rank with ``weighted_sampling_torch``, and the remaining samples are drawn
    uniformly. Returns ``(T, W, dt)`` with ``dt = T/m``; the uniform samples
    come first in each tensor.
    """
    pool_size = 5*n
    T_pool,W_pool,_ = generate_uniform_domain(pool_size)

    # Candidate states as rows (W, T). The torch.rand allocation is kept so
    # that the random stream stays the same as before.
    xx = torch.rand([pool_size,2],device=dev)
    xx[:,0] = W_pool
    xx[:,1] = T_pool

    score = measure(xx).reshape(-1)
    ranking = torch.argsort(score,descending=True)

    # Adaptive part: rank positions are sampled with weights decreasing in rank.
    n_adaptive = int(0.9*n)
    n_uniform = n - n_adaptive
    chosen = weighted_sampling_torch(torch.arange(len(ranking),device=dev), n_adaptive)
    W_adaptive = W_pool[ranking][chosen]
    T_adaptive = T_pool[ranking][chosen]

    # Uniform part: T is drawn before W, as in the original random stream.
    T_uniform = T_max*torch.rand([n_uniform],device=dev)
    W_uniform = W_min + (W_max-W_min)*torch.rand([n_uniform],device=dev)

    T = torch.cat([T_uniform,T_adaptive])
    W = torch.cat([W_uniform,W_adaptive])
    dt = T/m

    return T,W,dt

In [11]:
def plot(iter,min,save,U,case,lr_c,lr_m) :
    """Write diagnostic figures and an error summary into ``folder``.

    Parameters
    ----------
    iter : training iteration. Used in file names when ``save`` is true and
        always reported in the summary line.
    min : elapsed minutes, reported in the summary line.
    save : if false, file names use ``-1`` as the iteration tag, so each
        unsaved snapshot overwrites the previous one.
    U : objective value reported in the summary line.
    case : integer tag used in the file names.
    lr_c, lr_m : current learning rates, reported in the summary line.

    Side effects
    ------------
    Writes PNG files (sampling points; network output, reference solution and
    their difference for both networks) and one ``errs_*.txt`` file.
    NOTE: the global numpy and torch RNGs are re-seeded with ``seed`` on every
    call, so random draws made by the caller afterwards continue from this
    re-seeded state.
    """

    np.random.seed(seed)
    torch.manual_seed(seed)

    # Keep the true iteration for the summary line. Previously that line read
    # the notebook-global `i`, which only worked when the caller passed `i`.
    iteration = iter
    if not save :
        iter = -1

    def draw(x,t,u,filename) :
        """Filled contour plot of ``u`` over the scattered points (x, t)."""

        # Quantile-based levels: each colour band holds about the same number
        # of points.
        ordered = np.sort(u)
        n_levels = 100
        stride = (len(ordered)-1)//n_levels
        levels = [ordered[k*stride] for k in range(n_levels+1)]
        levels[-1] = ordered[-1]

        color_map = plt.cm.jet
        colors = color_map(np.linspace(0, 1, n_levels+1))

        plt.figure(figsize=(6.5,5))
        plt.plot(x, t, 'o', markersize=0.5, alpha=0.5, color='grey')
        try :
            contourf_plot = plt.tricontourf(x,t,u,levels=levels,colors=colors)
        except Exception :
            # Quantile levels can contain ties (e.g. constant u); presumably
            # that is why the call can fail. Fall back to default levels.
            # `Exception` rather than a bare except, so KeyboardInterrupt
            # still stops the run.
            contourf_plot = plt.tricontourf(x,t,u)
        plt.xlabel('W')
        plt.ylabel('T-t')
        plt.colorbar(contourf_plot)

        plt.savefig(filename)
        plt.close()

    # ************************************************************************************************** #
    # Scatter of the adaptive sampling points.

    T,W,dt = generate_domain_adaptive_sampling(10000)

    # The torch.rand allocation is kept so the random stream feeding the
    # uniform evaluation grid below is unchanged.
    xx = torch.rand([10000,2],device=dev)
    xx[:,0] = W
    xx[:,1] = T

    x = xx[:,0].clone().cpu().numpy()
    t = xx[:,1].clone().cpu().numpy()

    plt.figure(figsize=(5.2,5.2))
    plt.plot(x, t, 'o', markersize=0.5, alpha=0.5, color='grey')
    plt.xlabel('W'); plt.xlim([0,W_max])
    plt.ylabel('T-t'); plt.ylim([0,T_max])
    plt.grid()
    plt.savefig('%s/sampling_points_case_%d_iter_%d.png'%(folder,case,iter))
    plt.close()

    # ************************************************************************************************** #
    # Network output versus reference solution on a uniform evaluation set.

    T,W,dt = generate_uniform_domain(10000)

    xx = torch.rand([10000,2],device=dev)
    xx[:,0] = W
    xx[:,1] = T

    with torch.no_grad() :
        a_c = net_c(xx)
        a_m = net_m(xx)

    a_c = a_c.detach().cpu().numpy()
    a_m = a_m.detach().cpu().numpy()
    x = xx[:,0].clone().cpu().numpy()
    t = xx[:,1].clone().cpu().numpy()

    # Same output scaling as in the training loop: consumption is
    # net_c * W / (remaining time + lb_c). The myopic output is scaled by W to
    # compare with get_myopic.
    a_c = a_c*x/(t+lb_c)
    a_m = a_m*x

    c_sol = get_consume(x,t)
    m_sol = get_myopic(x,t)

    draw(x,t,a_c,'%s/c_net_case_%d_iter_%d.png'%(folder,case,iter))
    draw(x,t,c_sol,'%s/c_sol_case_%d_iter_%d.png'%(folder,case,iter))
    draw(x,t,a_c-c_sol,'%s/c_err_case_%d_iter_%d.png'%(folder,case,iter))
    # Root-mean-square relative error.
    err_c = np.sqrt( np.mean((a_c-c_sol)**2/c_sol**2) )

    draw(x,t,a_m,'%s/m_net_case_%d_iter_%d.png'%(folder,case,iter))
    draw(x,t,m_sol,'%s/m_sol_case_%d_iter_%d.png'%(folder,case,iter))
    draw(x,t,a_m-m_sol,'%s/m_err_case_%d_iter_%d.png'%(folder,case,iter))
    err_m = np.sqrt( np.mean((a_m-m_sol)**2/m_sol**2) )

    with open('%s/errs_case_%d_iter_%d.txt'%(folder,case,iter),'wt') as f :
        print("i: %d  min: %d  err_c: %10.6e  err_m: %10.6e  U: %10.6e  lr_c: %10.3e  lr_m: %10.3e"%(iteration,min,err_c,err_m,U,lr_c,lr_m),file=f)

In [12]:
# Snapshot of the untrained networks (save=False, so files are tagged iter -1).
i = 0
# Named elapsed_min rather than `min`, which would rebind the builtin min()
# for the rest of the notebook session.
elapsed_min = 0
lr_c = opt_c.param_groups[0]['lr']
lr_m = opt_m.param_groups[0]['lr']
plot(i,elapsed_min,False,0,0,lr_c,lr_m)

  tensor = torch.tensor(vector, dtype=torch.float)


In [13]:
# Training cell: rebuilds both networks from the seed and trains them until the
# last checkpoint time in `mins` has elapsed (wall-clock, excluding plot time).

# Elapsed-minute checkpoints at which a saved snapshot is written; the last
# entry is also the stopping time.
mins = np.array([1,5,10,30,60])

torch.manual_seed(seed)
np.random.seed(seed)

net_c = ConsumeNet()
net_c = net_c.to(dev)
opt_c = torch.optim.Adam(net_c.parameters(),lr=1e-5)

net_m = MyopicNet()
net_m = net_m.to(dev)
opt_m = torch.optim.Adam(net_m.parameters(),lr=1e-3)

# flag_write: a checkpoint time was passed -> write a saved snapshot.
# flag_draw : more than 3 minutes since the last periodic preview.
# flag_exit : the last checkpoint time was passed -> leave the loop.
flag_write = False
flag_draw = False
flag_exit = False
i = 0; mins_idx = 0
start_time = time.time()
start_time_draw = time.time()
# Accumulated seconds spent in the bookkeeping/plot section; subtracted from
# the wall clock so checkpoints refer to training time only.
draw_time = 0.

while True :

    # ************************************************************************************************** #
    # One optimisation step on a fresh batch of n simulated trajectories.

    opt_c.zero_grad()
    opt_m.zero_grad()

    T,W,dt = generate_domain_adaptive_sampling(n)

    # U accumulates the per-trajectory objective over the m time steps.
    U = 0.
    for k in range(m) :
        t = k*dt

        # Network input rows are (W, remaining time T-t).
        state = torch.vstack([W,T-t]).T
        # Consumption: positive net_c output scaled by W/(T-t+lb_c).
        a_c = net_c(state)*W/(T-t+lb_c)
            
        # Consume over dt, then floor wealth at lb_w.
        W = W - a_c*dt
        W = torch.relu(W - lb_w) + lb_w

        # Discounted power utility of consumption over [t, t+dt].
        U = U + torch.exp(-rho*t) * a_c**(1-gamma)/(1-gamma) * (1.-torch.exp(-rho*dt))/rho
        
        # Allocation from net_m, evaluated on the post-consumption wealth.
        state = torch.vstack([W,T-t]).T
        pi = net_m(state)      
        
        # Log-normal wealth update with Brownian increment dZ ~ N(0, dt).
        dZ = torch.randn([n],device=dev)*torch.sqrt(dt)
        W = W*torch.exp( (mu*pi + r*(1-pi) - 0.5*vol**2*pi**2 )*dt + vol*pi*dZ )
        W = torch.relu(W - lb_w) + lb_w

    # After m steps of size dt = T/m the time equals the horizon T.
    t = T
    
    # Discounted terminal-wealth utility, weighted by eps**gamma.
    U = U + torch.exp(-rho*t) *eps**gamma * W**(1-gamma)/(1-gamma)     
    # Loss is the negated batch mean; nanmean skips NaN trajectories.
    U = -U.nanmean()
    U.backward()
   
    # Each network's gradient norm is clipped separately.
    max_norm = 0.01
    torch.nn.utils.clip_grad_norm_(net_c.parameters(), max_norm)
    torch.nn.utils.clip_grad_norm_(net_m.parameters(), max_norm)

    opt_c.step()
    opt_m.step()

    # ************************************************************************************************** #
    # Wall-clock bookkeeping: decide whether to preview, save, or stop.

    time_i = time.time()
    # Training time so far, excluding time spent in this bookkeeping section.
    elapsed_time = time_i - start_time - draw_time
    current_min = int(elapsed_time//60)
    elapsed_time_draw = time_i - start_time_draw
    
    target_min = mins[mins_idx]
    current_set_time = 60*target_min
    last_set_time = 60*mins[-1]
    
    # Periodic preview every 3 minutes of wall-clock time.
    if elapsed_time_draw > 60*3 : 
        start_time_draw = time.time()
        flag_draw = True
    
    # Passed the current checkpoint: save and move on to the next one.
    if elapsed_time > current_set_time : 
        flag_write = True
        mins_idx = mins_idx + 1
        
    # Passing the last checkpoint also sets flag_write in the same iteration,
    # so the loop breaks before mins[mins_idx] is read out of range.
    if elapsed_time > last_set_time :
        flag_exit = True

    lr_c = opt_c.param_groups[0]['lr']
    lr_m = opt_m.param_groups[0]['lr']
    
    # NOTE: plot() re-seeds the global numpy/torch RNGs, so the sampling in the
    # next iteration continues from that re-seeded state.
    if not flag_write :
        if flag_draw or i == 0 :
            print('current_min: %5d  target_min: %5d  last_min: %5d'%(current_min,target_min,mins[-1]))
            plot(i,current_min,False,U,0,lr_c,lr_m)
            flag_draw = False
        
    elif flag_write :
        print('current_min: %5d  target_min: %5d  last_min: %5d'%(current_min,target_min,mins[-1]),'-->','saved')
        plot(i,current_min,True,U,0,lr_c,lr_m)
        flag_write = False    

    draw_time = draw_time + (time.time() - time_i)

    if flag_exit : break
    i = i+1        

print('finished')

  tensor = torch.tensor(vector, dtype=torch.float)


current_min:     0  target_min:     1  last_min:   480
current_min:     1  target_min:     1  last_min:   480 --> saved
current_min:     2  target_min:     5  last_min:   480
current_min:     5  target_min:     5  last_min:   480 --> saved
current_min:     5  target_min:    10  last_min:   480
current_min:     8  target_min:    10  last_min:   480
current_min:    10  target_min:    10  last_min:   480 --> saved
current_min:    11  target_min:    30  last_min:   480
current_min:    14  target_min:    30  last_min:   480
current_min:    17  target_min:    30  last_min:   480
current_min:    20  target_min:    30  last_min:   480
current_min:    23  target_min:    30  last_min:   480
current_min:    26  target_min:    30  last_min:   480
current_min:    29  target_min:    30  last_min:   480
current_min:    30  target_min:    30  last_min:   480 --> saved
current_min:    32  target_min:    60  last_min:   480
current_min:    35  target_min:    60  last_min:   480
current_min:    38  targe

KeyboardInterrupt: 