In [1]:
import pandas as pd 
import numpy as np
from utils import * 
from grid_search_parallelized import GAN,Generator,Discriminator,Params
import json


In [2]:
# Training data (log-returns, per the filename) for four stocks. The CSV has no header
# row; its first column becomes the index and the remaining four are named stock1..stock4.
file_path = 'data_train_log_return.csv'
header = [f"stock{k}" for k in range(1, 5)]
df_train = pd.read_csv(file_path, index_col=0, header=None)
df_train.columns = header


In [4]:
# Hyper-parameters saved alongside the trained weights: latent size plus the keyword
# arguments for the Generator and Discriminator constructors.
with open("submission/model_params.json", "r") as json_file:
    config = json.load(json_file)

latent_dim = config['latent_dim']
g_config, d_config = config['generator'], config['discriminator']

# Params (from grid_search_parallelized) supplies the remaining GAN options; only the
# latent size and epoch count are overridden here.
# NOTE(review): other fields used below (e.g. shape_data) come from Params' defaults — confirm.
opt = Params()
opt.latent_dim = latent_dim
opt.n_epochs = 0  # weights are loaded from disk, so no training epochs are requested

def generate_noise(n_samples, latent_dim=None, rho=0.75, rng=None):
    """Draw correlated Gaussian latent noise augmented with its square and cube.

    Samples ``z ~ N(0, Sigma)`` with the Toeplitz covariance
    ``Sigma[i, j] = rho ** |i - j|``. That gives unit variance on the diagonal, and
    the correlation decays geometrically with index distance (the off-diagonal
    values are deterministic, not random). Returns ``[z, z**2, z**3]``
    concatenated along the feature axis.

    Parameters
    ----------
    n_samples : int
        Number of rows to draw.
    latent_dim : int, optional
        Dimension of ``z``. Defaults to the global ``opt.latent_dim``, which is the
        original behaviour.
    rho : float, default 0.75
        Decay base of the covariance. 0.75 is the value the original hard-coded.
    rng : numpy.random.Generator, optional
        Random source for reproducible draws. If None, the legacy global
        ``np.random`` state is used, which is the original behaviour.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_samples, 3 * latent_dim)``.
    """
    if latent_dim is None:
        latent_dim = opt.latent_dim

    idx = np.arange(latent_dim)
    covariance_matrix = rho ** np.abs(np.subtract.outer(idx, idx))

    # Keep the legacy global RNG as the default so existing seeded runs are unchanged.
    sampler = np.random if rng is None else rng
    z = sampler.multivariate_normal(mean=np.zeros(latent_dim),
                                    cov=covariance_matrix,
                                    size=n_samples)
    return np.concatenate([z, z**2, z**3], axis=1)

## build GAN: a Generator and a Discriminator (both imported from grid_search_parallelized)
## wrapped together in GAN with the options held in `opt`.
# The generator is built from latent_dim and its saved constructor kwargs (model_params.json).
# NOTE(review): opt.shape_data is not set in this notebook, so it presumably comes from
# Params' defaults — confirm it matches the 4-stock training data.
generator = Generator(latent_dim,output_shape = opt.shape_data, **g_config)

### load weights
# Restore the pre-trained generator weights from the submission folder.
generator.model.load_weights('submission/generator_weights.h5')
# The discriminator is built from its saved kwargs; no weights are loaded for it here.
discriminator = Discriminator(opt.shape_data,**d_config)
# opt.n_epochs was set to 0, presumably so no further training happens — confirm in GAN.
gan = GAN(generator,discriminator,opt)


# Divisor applied to the generator's raw outputs.
# NOTE(review): presumably the generator was trained on log-returns scaled by 100
# (percent units) — confirm against the training code.
RETURN_SCALE = 100

### Load the latent noise stored with the submission (first CSV column is the index).
# Distinct names keep the DataFrame and the ndarray apart instead of rebinding `noise`.
noise_df = pd.read_csv('submission/noise.csv', index_col=0)
noise = noise_df.to_numpy()

# Map noise -> synthetic returns, label columns like the training data, undo the scaling.
raw_predictions = generator.model.predict(noise)
synthetic_data = pd.DataFrame(raw_predictions, columns=df_train.columns) / RETURN_SCALE
synthetic_data





Unnamed: 0,stock1,stock2,stock3,stock4
0,0.008924,0.001562,0.008722,0.009693
1,0.029719,0.023891,0.018772,0.019281
2,0.016568,0.004815,0.006206,0.012109
3,0.031298,0.058209,0.002368,0.007062
4,0.028111,0.008061,0.004144,0.034699
...,...,...,...,...
405,0.014071,0.037233,0.006712,0.009907
406,0.006035,0.001339,0.009284,0.004527
407,0.017424,0.018642,0.003280,0.002421
408,0.035604,0.014567,0.005901,0.020795
