In [1]:
import torch 
import torch.nn as nn
import functorch
from torch.optim.lr_scheduler import MultiStepLR

import os
import time
import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt
import pandas as pd
import copy
import sys
from itertools import chain
from IPython.display import clear_output

In [2]:
# --- Run configuration ---------------------------------------------------------------------------- #
cuda_num = 2                              # index of the GPU to use when it exists
seed = 0                                  # seed handed to torch / numpy wherever the notebook re-seeds
folder = 'cgdpo_european_vanilla_call'    # output directory for figures and error logs

# makedirs(exist_ok=True) replaces the exists()/mkdir() pair: no race between the
# check and the creation, and re-running the cell stays a no-op.
os.makedirs(folder, exist_ok=True)

torch.set_printoptions(sci_mode=False,precision=4)
np.set_printoptions(suppress=True,precision=4,linewidth=100)

# Use the requested GPU when it is present; otherwise fall back to the CPU so the
# notebook still runs on machines with fewer (or no) CUDA devices.
if torch.cuda.is_available() and cuda_num < torch.cuda.device_count() :
    dev = 'cuda:%d'%cuda_num
else :
    dev = 'cpu'

In [3]:
# --- Contract / market parameters ----------------------------------------------------------------- #
r = 0.07      # interest rate: used for discounting and as the drift of the simulated paths
vol = 0.3     # volatility of the simulated paths
K = 1.0       # strike of the call payoff relu(S-K)
T = 1.0       # maturity; only used here to define T_max

# --- Discretisation ------------------------------------------------------------------------------- #
m = 5         # number of time steps per simulated path (dt = T/m)
n = 1000      # number of sampled states per training batch

# --- Sampling domain ------------------------------------------------------------------------------ #
T_max = T     # upper bound of the sampled time to maturity
S_min = 0.1   # lower bound of the sampled spot
S_max = 2.0   # upper bound of the sampled spot

In [4]:
class PricingNet(nn.Module) :
    """Fully connected network mapping a 2-feature state to a non-negative scalar.

    Architecture: 2 -> 200 -> 200 -> 200 -> 1 with LeakyReLU between the hidden
    layers and a Softplus on the output. Callers in this notebook pass states
    whose columns are (S, time to maturity).
    """

    def __init__(self) :
        super().__init__()

        # Attribute names are kept as-is so existing state_dict keys stay valid.
        self.linear1a = nn.Linear(2,200)
        self.linear2a = nn.Linear(200,200)
        self.linear3a = nn.Linear(200,200)
        self.linear4a = nn.Linear(200,1)

        self.F = nn.LeakyReLU()
        self.H = nn.Softplus()

    def forward(self,x0) :
        """Evaluate the network.

        x0 : tensor of shape (..., 2).
        Returns a tensor of shape (...): only the trailing output-feature axis is
        removed. A bare squeeze() would also drop a batch axis of length 1, turning
        a (1, 2) input into a 0-d result instead of a length-1 vector.
        """
        xa = x0
        for hidden in (self.linear1a,self.linear2a,self.linear3a) :
            xa = self.F(hidden(xa))
        xa = self.H(self.linear4a(xa))

        return xa.squeeze(-1)

In [5]:
# Seed both RNGs before the network is built, then create the model on the
# selected device together with its Adam optimiser.
np.random.seed(seed)
torch.manual_seed(seed)

net_v = PricingNet().to(dev)
opt_v = torch.optim.Adam(net_v.parameters(),lr=1e-3)

In [6]:
def generate_domain_uniform_sampling(n) :
    """Draw `n` states uniformly from [0, T_max) x [S_min, S_max).

    Returns (T, S, dt): the sampled times to maturity, the sampled spots and the
    per-sample time step T/m, all 1-D tensors of length `n` on `dev`.
    """
    # Maturities are drawn before spots; the order fixes the RNG stream.
    maturities = T_max*torch.rand([n],device=dev)
    spots = S_min + (S_max-S_min)*torch.rand([n],device=dev)

    return maturities,spots,maturities/m

# Draw one uniform batch. NOTE: this rebinds the module-level name T (previously the
# scalar maturity 1.) to a tensor of sampled maturities; T_max was assigned from T
# before this point and keeps the scalar value.
T,S,dt = generate_domain_uniform_sampling(n)

In [7]:
def black_scholes_call(S, T, strike=None, rate=None, sigma=None):
    """Black-Scholes price of a European call.

    Parameters
    ----------
    S : scalar or array of spot prices.
    T : scalar or array of times to maturity (broadcast against S).
    strike, rate, sigma : optional overrides; when None the module-level
        K, r and vol are used, which is the original behaviour.

    Returns
    -------
    Price(s) with the broadcast shape of S and T (a NumPy scalar for scalar
    inputs). Where T <= 0 the intrinsic value max(S - strike, 0) is returned
    instead of evaluating the formula, which would divide by sqrt(0) and
    yield NaN for S == strike.
    """
    strike = K if strike is None else strike
    rate = r if rate is None else rate
    sigma = vol if sigma is None else sigma

    S, T = np.broadcast_arrays(np.asarray(S, dtype=float), np.asarray(T, dtype=float))

    expired = T <= 0
    # Substitute a harmless maturity on expired entries; their result is replaced below.
    T_safe = np.where(expired, 1.0, T)
    sqrt_T = np.sqrt(T_safe)

    d1 = (np.log(S / strike) + (rate + 0.5 * sigma ** 2) * T_safe) / (sigma * sqrt_T)
    d2 = d1 - sigma * sqrt_T
    price = S * norm.cdf(d1) - strike * np.exp(-rate * T_safe) * norm.cdf(d2)

    price = np.where(expired, np.maximum(S - strike, 0.0), price)
    # [()] turns a 0-d array back into a NumPy scalar and leaves n-d arrays untouched.
    return price[()]

In [8]:
def per_sample_gradients(model, inputs):
    """Gradient of the summed model output with respect to each input sample.

    Parameters
    ----------
    model : callable taking a single (unbatched) sample and returning a tensor.
    inputs : tensor whose leading axis indexes the samples.

    Returns
    -------
    Tensor with the same shape as `inputs`; row i is d(model(inputs[i]).sum())/d(inputs[i]).

    Uses torch.func.grad / torch.vmap, the maintained replacements for the
    deprecated functorch.grad / functorch.vmap entry points.
    """
    def compute_output(x):
        return model(x).sum()

    grad_fn = torch.func.grad(compute_output)
    return torch.vmap(grad_fn)(inputs)

def measure(xx):
    """Per-sample sampling score for the states in `xx`, evaluated with `net_v`.

    For every row of `xx` the score is

        sqrt( |d net_v / d x|^2 / n_params  *  ((v_perturbed - v) / v)^2 )

    where the gradient is taken with respect to the input row, n_params is the
    number of parameters of `net_v`, and v_perturbed is the network value after
    shrinking column 1 of the row by 1% (callers store the time to maturity in
    that column).

    This cell used to define `measure` twice; the first definition was fully
    shadowed by this one and never executed, so only the effective definition is
    kept. The deprecated functorch.grad / functorch.vmap calls are replaced by
    torch.func.grad / torch.vmap.

    Parameters
    ----------
    xx : tensor of shape (mm, 2); it is not modified.

    Returns
    -------
    1-D tensor of length mm.
    """
    mm = len(xx)

    def compute_output(x):
        return net_v(x).sum()

    # Gradient of the network output w.r.t. each input row, one row at a time.
    grad_fn = torch.func.grad(compute_output)
    batched_grad = torch.vmap(grad_fn)(xx)

    batched_grad_flat = batched_grad.view(mm, -1)

    total_n_param = sum(p.numel() for p in net_v.parameters())

    net_grad = torch.sum(batched_grad_flat ** 2, dim=1) / total_n_param

    # Work on a copy so the caller's tensor keeps its original column 1.
    xx_perturbed = xx.detach().clone()
    xx_perturbed[:,1] = 0.99 * xx_perturbed[:,1]

    with torch.no_grad():
        v_p = net_v(xx)
        v_p_perturbed = net_v(xx_perturbed)

    dv_relative_squared = (v_p_perturbed - v_p) ** 2 / v_p ** 2

    std = torch.sqrt( net_grad * dv_relative_squared )

    return std

In [9]:
def weighted_sampling_torch(vector, num_samples):
    """Sample `num_samples` distinct entries of `vector`, favouring early positions.

    Position j is drawn with probability proportional to linspace(1, 0.1, len)[j]**2,
    without replacement.

    Parameters
    ----------
    vector : 1-D tensor or sequence of numbers.
    num_samples : number of entries to draw; torch.multinomial raises if it exceeds
        len(vector).

    Returns
    -------
    list of Python floats (the sampled entries converted to float32).
    """
    # as_tensor converts tensors and sequences alike without the copy-construct
    # UserWarning that torch.tensor(<tensor>) emits.
    tensor = torch.as_tensor(vector, dtype=torch.float)
    weights = torch.linspace(1, 0.1, len(vector))**2
    probabilities = weights / weights.sum()
    sampled_indices = torch.multinomial(probabilities, num_samples, replacement=False)
    return tensor[sampled_indices].tolist()

def generate_domain_adaptive_sampling(n) :
    """Draw `n` states, 90% of them biased towards high `measure` scores.

    A pool of 5*n uniform candidates is scored with `measure` and ranked from the
    highest score down; int(0.9*n) of them are picked by `weighted_sampling_torch`
    over the ranks. The remaining states are drawn uniformly.

    Returns (T, S, dt) as 1-D tensors of length `n`, uniform part first.
    """
    n_pool = 5*n
    T_pool,S_pool,_ = generate_domain_uniform_sampling(n_pool)

    # Candidate states with columns (S, T). The buffer is created with torch.rand
    # and then overwritten, which also advances the RNG.
    candidates = torch.rand([n_pool,2],device=dev)
    candidates[:,0] = S_pool
    candidates[:,1] = T_pool

    score = measure(candidates).reshape(-1)
    S_pool = S_pool.reshape(-1)
    T_pool = T_pool.reshape(-1)

    ranking = torch.argsort(score,descending=True)
    n_adaptive = int(0.9*n)
    n_uniform = n - n_adaptive
    picked = weighted_sampling_torch(torch.arange(len(ranking),device=dev), n_adaptive)
    S_adaptive = S_pool[ranking][picked]
    T_adaptive = T_pool[ranking][picked]

    # Uniform remainder: maturities are drawn before spots.
    T_uniform = T_max*torch.rand([n_uniform],device=dev)
    S_uniform = S_min + (S_max-S_min)*torch.rand([n_uniform],device=dev)

    T = torch.hstack([T_uniform,T_adaptive])
    S = torch.hstack([S_uniform,S_adaptive])
    dt = T/m

    return T,S,dt

In [10]:
def plot(iter,min,save,U,case,lr_p) :
    """Write diagnostic figures and an error log for the current `net_v` into `folder`.

    Parameters
    ----------
    iter : iteration index; written to the log line and, when `save` is true, used
        in the file names.
    min : elapsed training minutes, written to the log line.
    save : when false the files are tagged with iteration -1, so repeated scratch
        plots overwrite each other.
    U : loss value written to the log line (formatted with %e).
    case : integer case tag used in the file names.
    lr_p : learning rate written to the log line.

    The log line is built from the `iter` and `lr_p` arguments. It previously read
    the globals `i` and `lr_v`, which only matched because every caller passed
    exactly those globals.

    NOTE(review): the global torch / numpy RNGs are re-seeded with `seed` on entry,
    so the caller's random stream restarts after every call.
    """

    np.random.seed(seed)
    torch.manual_seed(seed)

    # Remember the real iteration for the log before the file tag is overridden.
    iter_logged = iter
    if not save :
        iter = -1

    n_points = 10000

    def draw(x,t,u,filename) :
        """Filled contour plot of `u` over the scattered points (x, t), saved to `filename`."""

        # Contour levels at evenly spaced order statistics of u, ending at its maximum.
        u_sorted = np.sort(u)
        n_levels = 100
        stride = (len(u_sorted)-1)//n_levels
        levels = [u_sorted[j*stride] for j in range(n_levels+1)]
        levels[-1] = u_sorted[-1]

        color_map = plt.cm.jet
        colors = color_map(np.linspace(0, 1, n_levels+1))

        plt.figure(figsize=(6.5,5))
        plt.plot(x, t, 'o', markersize=0.5, alpha=0.5, color='grey')
        try :
            contourf_plot = plt.tricontourf(x,t,u,levels=levels,colors=colors)
        except Exception :
            # Fall back to matplotlib's default levels when the custom ones are
            # rejected. A bare `except` would also swallow KeyboardInterrupt.
            contourf_plot = plt.tricontourf(x,t,u)
        plt.xlabel('S')
        plt.ylabel('T-t')
        plt.colorbar(contourf_plot)

        plt.savefig(filename)
        plt.close()

    # ************************************************************************************************** #
    # Scatter of the points produced by the adaptive sampler.

    T,S,dt = generate_domain_adaptive_sampling(n_points)

    # torch.rand is kept for the buffer so the RNG stream, and with it the evaluation
    # points drawn further down, stays the same as before. The former measure(xx) call
    # here was dropped because its result was never used.
    xx = torch.rand([n_points,2],device=dev)
    xx[:,0] = S
    xx[:,1] = T

    x = xx[:,0].clone().cpu().numpy()
    t = xx[:,1].clone().cpu().numpy()

    plt.figure(figsize=(5.2,5.2))
    plt.plot(x, t, 'o', markersize=0.5, alpha=0.5, color='grey')
    plt.xlabel('S'); plt.xlim([0,S_max])
    plt.ylabel('T-t'); plt.ylim([0,T_max])
    plt.grid()
    plt.savefig('%s/sampling_points_case_%d_iter_%d.png'%(folder,case,iter))
    plt.close()

    # ************************************************************************************************** #
    # Network value vs. closed-form solution on a uniform sample.

    T,S,dt = generate_domain_uniform_sampling(n_points)

    xx = torch.rand([n_points,2],device=dev)
    xx[:,0] = S
    xx[:,1] = T

    with torch.no_grad() :
        v_p = net_v(xx)

    v_p = v_p.detach().cpu().numpy()
    x = xx[:,0].clone().cpu().numpy()
    t = xx[:,1].clone().cpu().numpy()

    v_sol = black_scholes_call(x,t)

    draw(x,t,v_p,'%s/v_net_case_%d_iter_%d.png'%(folder,case,iter))
    draw(x,t,v_sol,'%s/v_sol_case_%d_iter_%d.png'%(folder,case,iter))
    draw(x,t,v_p-v_sol,'%s/v_err_case_%d_iter_%d.png'%(folder,case,iter))

    # Relative RMS error, restricted to points where the reference price is not ~0.
    idx = (v_sol>1e-6)
    err_v = np.sqrt( np.mean((v_p[idx]-v_sol[idx])**2/v_sol[idx]**2) )

    with open('%s/errs_case_%d_iter_%d.txt'%(folder,case,iter),'wt') as f :
        print("i: %d  min: %d  err_v: %10.6e  loss: %10.6e  lr_v: %10.3e"%(iter_logged,min,err_v,U,lr_p),file=f)

In [11]:
# Smoke-test the plotting pipeline on the untrained network: iteration 0, minute 0,
# scratch mode (save=False), loss 0, case 0.
# NOTE(review): binding `min` at module level shadows the builtin min() for every
# later cell; consider renaming it.
lr_v = opt_v.param_groups[0]['lr']
i = min = 0
plot(i,min,False,0,0,lr_v)

  warn_deprecated('grad')
  warn_deprecated('vmap', 'torch.vmap')
  tensor = torch.tensor(vector, dtype=torch.float)


In [12]:
# Training minutes (plotting time excluded) at which results are saved; the run
# stops after the last entry.
mins = np.array([1,5,10,30,60])

torch.manual_seed(seed)
np.random.seed(seed)

# Fresh network and optimiser for the actual training run.
net_v = PricingNet()
net_v = net_v.to(dev)
opt_v = torch.optim.Adam(net_v.parameters(),lr=1e-3)

max_norm = 0.01    # gradient-norm clipping threshold (loop-invariant, so set once)

flag_write = False
flag_draw = False
flag_exit = False
i = 0; mins_idx = 0
start_time = time.time()
start_time_draw = time.time()
draw_time = 0.

while True :

    # ************************************************************************************************** #
    # One optimisation step.

    opt_v.zero_grad()

    T,S,dt = generate_domain_adaptive_sampling(n)

    # Simulate m steps of the path from each sampled state and record the
    # discounted network value exp(-r*t_k) * v(S_k, T-t_k) at every step.
    U = []
    for k in range(m) :
        t = k*dt

        state = torch.vstack([S,T-t]).T
        v_p = net_v(state)

        U.append( torch.exp(-r*t)*v_p )

        dZ = torch.randn([n],device=dev)*torch.sqrt(dt)
        S = S*torch.exp( (r - 0.5*vol**2)*dt + vol*dZ )

    # Terminal payoff discounted back to the sampling time. Every U[k] carries the
    # factor exp(-r*t_k), so the matching target at maturity is exp(-r*T)*payoff.
    # The previous code compared against the undiscounted payoff, which fits
    # exp(r*T) times the price that black_scholes_call reports as reference.
    t = T
    h = torch.exp(-r*t)*torch.relu(S-K)

    loss = 0.
    for k in range(m) :
        loss = loss + (h-U[k])**2
    loss = loss*dt
    loss = loss.nanmean()
    loss.backward()

    torch.nn.utils.clip_grad_norm_(net_v.parameters(), max_norm)

    opt_v.step()

    # ************************************************************************************************** #
    # Bookkeeping: elapsed_time counts training time only (plotting time is subtracted).

    time_i = time.time()
    elapsed_time = time_i - start_time - draw_time
    current_min = int(elapsed_time//60)
    elapsed_time_draw = time_i - start_time_draw

    target_min = mins[mins_idx]
    current_set_time = 60*target_min
    last_set_time = 60*mins[-1]

    # Scratch plot every 3 wall-clock minutes.
    if elapsed_time_draw > 60*3 :
        start_time_draw = time.time()
        flag_draw = True

    # Reached the next checkpoint: save and move on to the following one.
    if elapsed_time > current_set_time :
        flag_write = True
        mins_idx = mins_idx + 1

    # The last checkpoint also ends the run, so mins[mins_idx] is never read past the end.
    if elapsed_time > last_set_time :
        flag_exit = True

    lr_v = opt_v.param_groups[0]['lr']

    # loss.item() hands plot a plain float instead of a tensor still attached to the graph.
    if not flag_write :
        if flag_draw or i == 0 :
            print('current_min: %5d  target_min: %5d  last_min: %5d'%(current_min,target_min,mins[-1]))
            plot(i,current_min,False,loss.item(),0,lr_v)
            flag_draw = False

    else :
        print('current_min: %5d  target_min: %5d  last_min: %5d'%(current_min,target_min,mins[-1]),'-->','saved')
        plot(i,current_min,True,loss.item(),0,lr_v)
        flag_write = False

    draw_time = draw_time + (time.time() - time_i)

    if flag_exit : break
    i = i+1

print('finished')

  tensor = torch.tensor(vector, dtype=torch.float)


current_min:     0  target_min:     1  last_min:    60
current_min:     1  target_min:     1  last_min:    60 --> saved
current_min:     2  target_min:     5  last_min:    60
current_min:     5  target_min:     5  last_min:    60 --> saved
current_min:     5  target_min:    10  last_min:    60
current_min:     8  target_min:    10  last_min:    60
current_min:    10  target_min:    10  last_min:    60 --> saved
current_min:    11  target_min:    30  last_min:    60
current_min:    14  target_min:    30  last_min:    60
current_min:    17  target_min:    30  last_min:    60
current_min:    20  target_min:    30  last_min:    60
current_min:    23  target_min:    30  last_min:    60
current_min:    26  target_min:    30  last_min:    60
current_min:    29  target_min:    30  last_min:    60
current_min:    30  target_min:    30  last_min:    60 --> saved
current_min:    32  target_min:    60  last_min:    60
current_min:    35  target_min:    60  last_min:    60
current_min:    38  targe