In [None]:
import torch
from models import Generator

# Rebuild the GAF generator and load its trained weights.
# map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only
# machine; later cells call `.numpy()` on generator outputs, which needs CPU tensors.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
gaf_generator = Generator()
gaf_generator.load_state_dict(torch.load('models/last/g0.pt', map_location='cpu'))
# Switch layers such as dropout / batch-norm to evaluation behaviour;
# the trailing `;` suppresses the (long) model repr in the cell output.
gaf_generator.eval();

In [6]:
from dataset import ForexData
import pandas as pd

# Minute-level EUR/USD quotes; only the BidClose column is handed to ForexData.
# NOTE(review): 64 is presumably the window / GAF size -- confirm against ForexData.
df = pd.read_csv('data/eurusd_minute.csv')
print('Loaded csv')
forex_data = ForexData(df['BidClose'].values, 64)

Loaded csv


In [None]:
from transforms import reverse_gaf, reverse_noisy_gaf
import numpy as np
import matplotlib.pyplot as plt

def gen_logs(n):
    """Sample ``n`` generated GAF matrices and decode each with ``reverse_noisy_gaf``.

    Args:
        n: number of latent vectors to draw (one generated series per vector).

    Returns:
        np.ndarray stacking ``reverse_noisy_gaf(mat)`` for every generated matrix,
        so its first axis has length ``n``.
    """
    # 100 is presumably the generator's latent size -- must match its training setup.
    z = torch.randn(n, 100)
    # Sampling needs no gradients; without no_grad the batched forward pass
    # records an autograd graph and keeps its intermediate activations alive.
    with torch.no_grad():
        gafs = gaf_generator(z).squeeze()
    # squeeze() also removes the batch axis when n == 1; restore it so the loop
    # below iterates over matrices rather than over the rows of a single matrix.
    if n == 1:
        gafs = gafs.unsqueeze(0)
    return np.array([reverse_noisy_gaf(mat).detach().numpy() for mat in gafs])

def gen_diag_logs(n):
    """Sample ``n`` generated GAF matrices and decode them in one batch with ``reverse_gaf``.

    This is the decoding the evaluation cell reports as "diagonal decoding".

    Args:
        n: number of latent vectors to draw (one generated series per vector).

    Returns:
        np.ndarray produced by ``reverse_gaf`` on the whole batch of matrices.
    """
    # 100 is presumably the generator's latent size -- must match its training setup.
    z = torch.randn(n, 100)
    # Sampling needs no gradients; avoid building an autograd graph for the batch.
    with torch.no_grad():
        gafs = gaf_generator(z).squeeze()
    # squeeze() also removes the batch axis when n == 1; restore it so
    # reverse_gaf always receives a batch of matrices.
    if n == 1:
        gafs = gafs.unsqueeze(0)
    return reverse_gaf(gafs).detach().numpy()

def norm_ret(log_prices):
    """Return the step-to-step differences of each series in ``log_prices``.

    Despite the name, no normalisation is applied: this is the first difference
    along axis 1, so a ``(n_series, length)`` input yields ``(n_series, length - 1)``.
    When the rows are log prices, the differences are log returns.
    """
    first_differences = np.diff(log_prices, n=1, axis=1)
    return first_differences

In [None]:
from transforms import reverse_gaf
import numpy as np
import torch
from scipy.stats import wasserstein_distance

# Number of real windows and of generated samples to compare.
n = 1000
# Fixed seed so both the sampled real windows and the generated series
# (torch.randn inside gen_logs / gen_diag_logs) are reproducible.
SEED = 0
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)

# Draw n distinct window indices from the whole dataset (every index is eligible).
assert len(forex_data) >= n, 'not enough windows to sample without replacement'
idx = rng.choice(len(forex_data), size=n, replace=False)

# Real series: decode the real GAF windows, then flatten all per-step returns.
real_logs = reverse_gaf(forex_data[idx]).numpy()
real_returns = norm_ret(real_logs).reshape(-1)

# Generated series decoded with reverse_noisy_gaf (per matrix).
fake_logs = gen_logs(n)
fake_returns = norm_ret(fake_logs).reshape(-1)

# Generated series decoded with reverse_gaf (batched, "diagonal" decoding).
fake_logs_diag = gen_diag_logs(n)
fake_returns_diag = norm_ret(fake_logs_diag).reshape(-1)

# 1-D Wasserstein distance between the pooled real and generated return samples.
w_distance = wasserstein_distance(real_returns, fake_returns)
w_distance_diag = wasserstein_distance(real_returns, fake_returns_diag)
print(f'Wasserstein distance: {w_distance:.4f}')
print(f'Wasserstein distance (diagonal decoding): {w_distance_diag:.4f}')

In [None]:
from os import devnull
from matplotlib.pyplot import figure
import matplotlib.pyplot as plt
import numpy as np

# Each panel overlays two return samples given as (data, label) pairs.
panels = [
    ((real_returns, 'real'), (fake_returns, 'fake')),
    ((real_returns, 'real'), (fake_returns_diag, 'fake (diagonal)')),
    ((fake_returns, 'fake'), (fake_returns_diag, 'fake (diagonal)')),
]

# Create exactly one figure at its final size; a separate figure(...) call
# beforehand would only leave an extra, empty figure behind.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for ax, ((data_a, label_a), (data_b, label_b)) in zip(axes, panels):
    # Shared bin edges make the two overlaid densities directly comparable.
    # Non-finite values are excluded from the edge computation only, since
    # histogram_bin_edges rejects a NaN/inf data range.
    both = np.concatenate([data_a, data_b])
    bins = np.histogram_bin_edges(both[np.isfinite(both)], bins=100)
    ax.hist(data_a, bins=bins, alpha=0.5, label=label_a, density=True)
    ax.hist(data_b, bins=bins, alpha=0.5, label=label_b, density=True)
    ax.set(title=f'{label_a} vs {label_b}', xlabel='return', ylabel='density')
    ax.legend()

plt.show()