In [1]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from data.custom_dataset import CustomDataset
from Network.conditional_net import ConditionalNet
from Network.sdf_net import SDFNet

In [12]:
# --- Configuration: architecture hyper-parameters shared by ConditionalNet and SDFNet ---
MACRO_SIZE = 124       # macro input width passed to both networks
LSTM_HIDDEN = 16
LSTM_LAYER = 1
CHAR_SIZE = 27         # characteristic input width passed to both networks
FF_HIDDEN = [32, 16]
# Derived rather than hard-coded so it cannot drift from FF_HIDDEN.
# NOTE(review): assumes one FF_HIDDEN entry per feed-forward layer — confirm against the nets.
FF_LAYER = len(FF_HIDDEN)
BATCH_SIZE = 3         # NOTE(review): not referenced by any cell visible in this notebook

# Seed every RNG in use so network initialisation (next cells) is reproducible
# under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [30]:
# Build the two adversaries with identical architecture hyper-parameters.
_net_args = (MACRO_SIZE, CHAR_SIZE, LSTM_HIDDEN, LSTM_LAYER, FF_HIDDEN, FF_LAYER)
conditional_net = ConditionalNet(*_net_args)
sdf_net = SDFNet(*_net_args)

# One Adam optimizer per network; both use the same beta coefficients.
# NOTE(review): the SDF learning rate (0.1) is 500x the conditional-net rate and is
# applied 100 times per sample in the training loop — confirm this is intentional,
# the logged sdf_loss spikes into the thousands late in the run.
_adam_betas = (0.9, 0.99)
optimizer_C = torch.optim.Adam(conditional_net.parameters(), lr=0.0002, betas=_adam_betas)
optimizer_S = torch.optim.Adam(sdf_net.parameters(), lr=0.1, betas=_adam_betas)

In [4]:
# Load characteristics, macro series and returns; relative to the notebook's directory.
DATA_DIR = "./data"
CHAR_PATH = DATA_DIR + "/27_features_rets_normalized_filled.csv"
MACRO_PATH = DATA_DIR + "/124_macro_data.csv"
dataset = CustomDataset(CHAR_PATH, MACRO_PATH)


In [31]:
# Adversarial training, one pass over `dataset`. For each item:
#   1. one gradient-ascent step on conditional_net (maximise sum((SDF * g_hat)^2)),
#   2. SDF_STEPS gradient-descent steps on sdf_net (minimise the same quantity)
#      with conditional_net held fixed.
# Each phase evaluates the *other* network under torch.no_grad(): its gradients were
# always discarded (zero_grad runs before every backward), so this only skips work.
# NOTE(review): the moment is SDF * g_hat; the usual pricing condition also multiplies
# by the return, E[M * R * g] = 0 — confirm leaving `ret` out of the loss is intended.
# NOTE: re-running this cell continues training from the current weights.
SDF_STEPS = 100  # inner SDF updates per conditional-net update
assert SDF_STEPS > 0, "need at least one SDF step so sdf_loss/SDF exist for logging"

for i in range(len(dataset)):
    data = dataset[i]
    char = torch.from_numpy(data["char"]).float()
    macro = torch.from_numpy(data["macro"]).float()
    ret = torch.from_numpy(data["ret"]).float()

    # Phase 1: update conditional_net; SDF is a constant here.
    with torch.no_grad():
        SDF = sdf_net(macro, char, ret)
    g_hat = conditional_net(macro, char, ret)
    conditional_loss = -torch.sum((SDF * g_hat).pow(2))  # negated -> gradient ascent
    optimizer_C.zero_grad()
    conditional_loss.backward()
    optimizer_C.step()

    # Phase 2: update sdf_net. conditional_net does not change inside this loop,
    # so evaluate g_hat once (no graph) and reuse it for every inner step.
    with torch.no_grad():
        g_hat = conditional_net(macro, char, ret)
    g_hat_mean = torch.mean(g_hat)

    for j in range(SDF_STEPS):
        SDF = sdf_net(macro, char, ret)
        sdf_loss = torch.sum((SDF * g_hat).pow(2))
        optimizer_S.zero_grad()
        sdf_loss.backward()
        optimizer_S.step()

    # Logged sdf_loss / SDF are from the last inner step, i.e. before its optimizer update.
    print("[Iter %d/%d] [sdf_loss: %f] [sdf: %f] [g_hat_mean: %f]"
          % (i, len(dataset), sdf_loss.item(), SDF.item(), g_hat_mean.item()))
    

[Iter 0/659] [sdf_loss: 111.663902] [sdf: 0.939704] [g_hat_mean: -0.005539]
[Iter 1/659] [sdf_loss: 65.536499] [sdf: 0.699179] [g_hat_mean: -0.007095]
[Iter 2/659] [sdf_loss: 20.622366] [sdf: 0.429297] [g_hat_mean: -0.002810]
[Iter 3/659] [sdf_loss: 56.711067] [sdf: 0.904513] [g_hat_mean: 0.002500]
[Iter 4/659] [sdf_loss: 56.882061] [sdf: 0.784364] [g_hat_mean: 0.022514]
[Iter 5/659] [sdf_loss: 59.827568] [sdf: 0.950487] [g_hat_mean: -0.001727]
[Iter 6/659] [sdf_loss: 79.560730] [sdf: 0.947687] [g_hat_mean: 0.003719]
[Iter 7/659] [sdf_loss: 53.685181] [sdf: 0.730529] [g_hat_mean: -0.000498]
[Iter 8/659] [sdf_loss: 0.000088] [sdf: -0.000776] [g_hat_mean: -0.019843]
[Iter 9/659] [sdf_loss: 0.000244] [sdf: 0.001494] [g_hat_mean: -0.018766]
[Iter 10/659] [sdf_loss: 0.007980] [sdf: -0.009161] [g_hat_mean: 0.000018]
[Iter 11/659] [sdf_loss: 0.000742] [sdf: -0.002016] [g_hat_mean: -0.021885]
[Iter 12/659] [sdf_loss: 0.001678] [sdf: 0.003399] [g_hat_mean: -0.025211]
[Iter 13/659] [sdf_loss: 0.

[Iter 109/659] [sdf_loss: 0.001872] [sdf: 0.001751] [g_hat_mean: 0.012373]
[Iter 110/659] [sdf_loss: 0.000117] [sdf: -0.000371] [g_hat_mean: 0.010835]
[Iter 111/659] [sdf_loss: 0.002068] [sdf: 0.002020] [g_hat_mean: 0.013193]
[Iter 112/659] [sdf_loss: 0.022630] [sdf: 0.005769] [g_hat_mean: 0.003635]
[Iter 113/659] [sdf_loss: 0.000000] [sdf: -0.000010] [g_hat_mean: 0.002669]
[Iter 114/659] [sdf_loss: 0.000080] [sdf: 0.000387] [g_hat_mean: 0.001629]
[Iter 115/659] [sdf_loss: 0.009241] [sdf: -0.003614] [g_hat_mean: 0.011664]
[Iter 116/659] [sdf_loss: 0.046840] [sdf: -0.009548] [g_hat_mean: -0.000077]
[Iter 117/659] [sdf_loss: 0.004685] [sdf: 0.002484] [g_hat_mean: 0.012891]
[Iter 118/659] [sdf_loss: 0.025955] [sdf: 0.005782] [g_hat_mean: 0.012684]
[Iter 119/659] [sdf_loss: 0.038073] [sdf: -0.006807] [g_hat_mean: -0.015232]
[Iter 120/659] [sdf_loss: 0.035162] [sdf: 0.006695] [g_hat_mean: 0.004222]
[Iter 121/659] [sdf_loss: 0.000059] [sdf: -0.000274] [g_hat_mean: -0.004613]
[Iter 122/659] [

[Iter 217/659] [sdf_loss: 0.000199] [sdf: 0.000346] [g_hat_mean: 0.003662]
[Iter 218/659] [sdf_loss: 0.001556] [sdf: 0.000888] [g_hat_mean: -0.007852]
[Iter 219/659] [sdf_loss: 0.001186] [sdf: -0.000830] [g_hat_mean: -0.012004]
[Iter 220/659] [sdf_loss: 38.197178] [sdf: 0.126656] [g_hat_mean: 0.003885]
[Iter 221/659] [sdf_loss: 0.001070] [sdf: -0.000725] [g_hat_mean: 0.021907]
[Iter 222/659] [sdf_loss: 0.217344] [sdf: 0.009158] [g_hat_mean: 0.031811]
[Iter 223/659] [sdf_loss: 0.068256] [sdf: -0.005753] [g_hat_mean: 0.024580]
[Iter 224/659] [sdf_loss: 0.009413] [sdf: -0.002230] [g_hat_mean: 0.021620]
[Iter 225/659] [sdf_loss: 97.310516] [sdf: 0.218321] [g_hat_mean: 0.014409]
[Iter 226/659] [sdf_loss: 549.792786] [sdf: 0.471619] [g_hat_mean: 0.018051]
[Iter 227/659] [sdf_loss: 0.004447] [sdf: 0.001449] [g_hat_mean: -0.006465]
[Iter 228/659] [sdf_loss: 0.076088] [sdf: 0.006031] [g_hat_mean: 0.003617]
[Iter 229/659] [sdf_loss: 1397.285034] [sdf: 0.861471] [g_hat_mean: 0.019449]
[Iter 230/6

[Iter 325/659] [sdf_loss: 1813.387329] [sdf: 0.759526] [g_hat_mean: 0.014617]
[Iter 326/659] [sdf_loss: 149.504166] [sdf: -0.184983] [g_hat_mean: 0.027255]
[Iter 327/659] [sdf_loss: 0.243544] [sdf: -0.007294] [g_hat_mean: -0.005466]
[Iter 328/659] [sdf_loss: 0.641891] [sdf: 0.011983] [g_hat_mean: 0.009706]
[Iter 329/659] [sdf_loss: 0.109374] [sdf: -0.004481] [g_hat_mean: 0.011256]
[Iter 330/659] [sdf_loss: 0.011544] [sdf: 0.001530] [g_hat_mean: 0.006059]
[Iter 331/659] [sdf_loss: 0.019897] [sdf: 0.001970] [g_hat_mean: -0.000394]
[Iter 332/659] [sdf_loss: 0.170296] [sdf: 0.005590] [g_hat_mean: 0.000199]
[Iter 333/659] [sdf_loss: 0.029373] [sdf: -0.002289] [g_hat_mean: 0.000343]
[Iter 334/659] [sdf_loss: 0.046000] [sdf: -0.002954] [g_hat_mean: 0.010493]
[Iter 335/659] [sdf_loss: 0.192069] [sdf: 0.004584] [g_hat_mean: -0.060258]
[Iter 336/659] [sdf_loss: 303.386200] [sdf: 0.209612] [g_hat_mean: -0.006519]
[Iter 337/659] [sdf_loss: 0.077150] [sdf: -0.004061] [g_hat_mean: 0.003607]
[Iter 33

[Iter 432/659] [sdf_loss: 2888.577393] [sdf: 0.604617] [g_hat_mean: 0.071506]
[Iter 433/659] [sdf_loss: 283.402740] [sdf: 0.192236] [g_hat_mean: 0.026530]
[Iter 434/659] [sdf_loss: 2333.519531] [sdf: 0.552957] [g_hat_mean: 0.001988]
[Iter 435/659] [sdf_loss: 4276.760742] [sdf: 0.708958] [g_hat_mean: -0.009591]
[Iter 436/659] [sdf_loss: 1780.915649] [sdf: 0.468168] [g_hat_mean: 0.048108]
[Iter 437/659] [sdf_loss: 0.004442] [sdf: 0.000715] [g_hat_mean: 0.065275]
[Iter 438/659] [sdf_loss: 0.004229] [sdf: -0.000682] [g_hat_mean: 0.010365]
[Iter 439/659] [sdf_loss: 0.002403] [sdf: -0.000584] [g_hat_mean: 0.027728]
[Iter 440/659] [sdf_loss: 170.769562] [sdf: 0.144886] [g_hat_mean: -0.033460]
[Iter 441/659] [sdf_loss: 26.282108] [sdf: 0.053948] [g_hat_mean: -0.006506]
[Iter 442/659] [sdf_loss: 0.000202] [sdf: -0.000143] [g_hat_mean: -0.002365]
[Iter 443/659] [sdf_loss: 0.035865] [sdf: -0.001931] [g_hat_mean: 0.000156]
[Iter 444/659] [sdf_loss: 14.888023] [sdf: 0.040844] [g_hat_mean: 0.067438]

[Iter 538/659] [sdf_loss: 0.013906] [sdf: 0.001169] [g_hat_mean: -0.002363]
[Iter 539/659] [sdf_loss: 0.006141] [sdf: 0.000745] [g_hat_mean: -0.068427]
[Iter 540/659] [sdf_loss: 91.773460] [sdf: 0.112775] [g_hat_mean: 0.025355]
[Iter 541/659] [sdf_loss: 1.331595] [sdf: 0.014514] [g_hat_mean: 0.025073]
[Iter 542/659] [sdf_loss: 0.002394] [sdf: -0.000723] [g_hat_mean: -0.019524]
[Iter 543/659] [sdf_loss: 2.487014] [sdf: -0.017742] [g_hat_mean: 0.069131]
[Iter 544/659] [sdf_loss: 0.035039] [sdf: -0.002433] [g_hat_mean: 0.065062]
[Iter 545/659] [sdf_loss: 0.017522] [sdf: 0.001905] [g_hat_mean: 0.030071]
[Iter 546/659] [sdf_loss: 0.437365] [sdf: 0.009166] [g_hat_mean: 0.040030]
[Iter 547/659] [sdf_loss: 0.027876] [sdf: 0.002262] [g_hat_mean: 0.032306]
[Iter 548/659] [sdf_loss: 0.102642] [sdf: -0.004397] [g_hat_mean: 0.006816]
[Iter 549/659] [sdf_loss: 0.030633] [sdf: 0.002349] [g_hat_mean: 0.037657]
[Iter 550/659] [sdf_loss: 0.025425] [sdf: 0.001657] [g_hat_mean: -0.012678]
[Iter 551/659] [

[Iter 645/659] [sdf_loss: 952.232666] [sdf: 0.391097] [g_hat_mean: 0.058887]
[Iter 646/659] [sdf_loss: 2452.382324] [sdf: 0.582896] [g_hat_mean: 0.091578]
[Iter 647/659] [sdf_loss: 5354.444336] [sdf: 1.240822] [g_hat_mean: -0.042293]
[Iter 648/659] [sdf_loss: 12823.214844] [sdf: 1.009651] [g_hat_mean: 0.048361]
[Iter 649/659] [sdf_loss: 8071.586914] [sdf: 0.755612] [g_hat_mean: -0.023319]
[Iter 650/659] [sdf_loss: 383.016693] [sdf: 0.215893] [g_hat_mean: 0.071843]
[Iter 651/659] [sdf_loss: 4.460891] [sdf: 0.023752] [g_hat_mean: 0.024972]
[Iter 652/659] [sdf_loss: 6223.739746] [sdf: 0.427124] [g_hat_mean: 0.019736]
[Iter 653/659] [sdf_loss: 0.005718] [sdf: 0.000803] [g_hat_mean: 0.074872]
[Iter 654/659] [sdf_loss: 269.052612] [sdf: 0.168907] [g_hat_mean: 0.040079]
[Iter 655/659] [sdf_loss: 0.503791] [sdf: -0.006055] [g_hat_mean: 0.070993]
[Iter 656/659] [sdf_loss: 0.771129] [sdf: 0.010433] [g_hat_mean: 0.059400]
[Iter 657/659] [sdf_loss: 0.001669] [sdf: -0.000405] [g_hat_mean: 0.005226]

In [16]:
# Sanity check that torch.mul is element-wise: multiplying a tensor by an equal
# tensor squares each entry.
t1 = torch.tensor([1, 2, 3, 4])
t2 = torch.tensor([1, 2, 3, 4])
t3 = t1 * t2
t3

tensor([ 1,  4,  9, 16])