In [None]:
%%capture
!pip install numpy pandas matplotlib torch gretel-synthetics


In [None]:
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch

from gretel_synthetics.timeseries_dgan.dgan import DGAN
from gretel_synthetics.timeseries_dgan.config import DGANConfig, Normalization

In [None]:
# Load the Wikipedia web-traffic (WWT) training data: one row per series, three
# categorical attribute columns plus one column of scaled page views per day.
S3_BASE_URL = "https://gretel-public-website.s3.us-west-2.amazonaws.com/datasets/wiki-web-traffic-data/"
ATTRIBUTE_COLUMNS = ["domain_name", "access_name", "agent_name"]

# Only the *_name attribute columns are modeled; drop the domain/access/agent columns.
# Chained .drop (instead of inplace=True) keeps the cell idempotent on re-run.
wwt = pd.read_csv(S3_BASE_URL + "wikipedia-web-traffic-training.csv", index_col=0).drop(
    columns=["domain", "access", "agent"]
)
# 553 = 3 attribute columns + 550 daily values.
assert wwt.shape[1] == 553, f"expected 553 columns, got {wwt.shape[1]}"
assert set(ATTRIBUTE_COLUMNS).issubset(wwt.columns), "attribute columns missing from data"

wwt.head()

In [None]:
# Extract numpy arrays for charts: the attribute columns, and the per-day value
# columns with a trailing singleton axis -> shape (rows, days, 1).
wwt_attributes = wwt.loc[:, ATTRIBUTE_COLUMNS].to_numpy()
print(wwt_attributes.shape)
wwt_features = wwt.drop(columns=ATTRIBUTE_COLUMNS).to_numpy()[..., np.newaxis]
print(wwt_features.shape)

In [None]:
# Overlay the first three real series to get a feel for the data.
for row in range(3):
    plt.plot(wwt_features[row, :, 0])

plt.title("Sample WIKI time series")
plt.xlabel("day")
plt.ylabel("scaled page views")
plt.show()

In [None]:
# Autocorrelation computation
# From https://github.com/fjxmlzn/DoppelGANger/issues/20#issuecomment-858234890
# The numeric quirks in `autocorr` (EPS substitution, zeroing out-of-range values)
# are kept exactly as in the reference so curves stay comparable with it.
EPS = 1e-8


def autocorr(X, Y):
    """Per-series Pearson correlation between X and Y, reducing over dimension 1.

    Args:
        X, Y: tensors of the same shape; dimension 0 indexes series and
            dimension 1 indexes time steps (e.g. (n_series, n_steps) or
            (n_series, n_steps, n_features)).

    Returns:
        Tensor with dimension 1 reduced away: one correlation per series
        (and per trailing feature, if any).

    Quirks inherited from the reference implementation:
      * an exactly-zero numerator or denominator is replaced by EPS, so when
        both are exactly zero (e.g. a constant segment) the result is 1;
      * values outside [-1, 1] (floating-point overshoot) are set to 0,
        not clipped to +/-1.
    """
    Xm = torch.mean(X, 1).unsqueeze(1)
    Ym = torch.mean(Y, 1).unsqueeze(1)
    r_num = torch.sum((X - Xm) * (Y - Ym), 1)
    r_den = torch.sqrt(torch.sum((X - Xm) ** 2, 1) * torch.sum((Y - Ym) ** 2, 1))

    r_num[r_num == 0] = EPS
    r_den[r_den == 0] = EPS

    r = r_num / r_den
    r[r > 1] = 0
    r[r < -1] = 0

    return r


def get_autocorr(feature, max_lag=None):
    """Autocorrelation at lags 1..max_lag, averaged over all series.

    Args:
        feature: array-like of shape (n_series, n_steps) or
            (n_series, n_steps, n_features). Non-float input (e.g. the object
            arrays DataFrame.to_numpy() returns for mixed dtypes, or ints) is
            converted to float64; float input is used as-is.
        max_lag: largest lag to compute. Defaults to n_steps - 2, the largest
            lag for which each series still has two overlapping points.

    Returns:
        1-D numpy array (torch default float dtype) of length max_lag.
        Index k holds lag k + 1, so index 0 is lag 1. Empty when n_steps < 3.

    Raises:
        ValueError: if `feature` has fewer than 2 dimensions.
    """
    arr = np.asarray(feature)
    if not np.issubdtype(arr.dtype, np.floating):
        # torch.from_numpy rejects object arrays and torch.mean rejects ints.
        arr = arr.astype(np.float64)
    if arr.ndim < 2:
        raise ValueError(
            f"expected (n_series, n_steps[, n_features]) array, got shape {arr.shape}"
        )
    series = torch.from_numpy(arr)

    # Lag n_steps - 1 would leave one point per series, so stop at n_steps - 2.
    n_lags = series.shape[1] - 2
    if max_lag is not None:
        n_lags = min(n_lags, max_lag)
    n_lags = max(n_lags, 0)

    acf = torch.zeros(n_lags)
    for lag in range(1, n_lags + 1):
        acf[lag - 1] = torch.mean(autocorr(series[:, :-lag], series[:, lag:]))

    return acf.numpy()

# Train fast model
Modified parameters with a larger batch_size to better utilize the GPU.

Specific changes from the parameters used in https://github.com/fjxmlzn/DoppelGANger:
* batch_size=1000 (was 100)
* learning_rate=1e-4 (was 1e-3), changed for the generator and both discriminators



In [None]:
# Train DGAN model with the faster settings: batch_size=1000 and learning rate 1e-4
# for the generator and both discriminators.

# Training is stochastic; seed torch and numpy so reruns are comparable.
# NOTE(review): GPU kernels may still be nondeterministic, so runs can differ slightly.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

N_SYNTHETIC = 50000  # number of synthetic series to generate

config1 = DGANConfig(
    max_sequence_len=wwt.shape[1] - len(ATTRIBUTE_COLUMNS),
    sample_len=10,
    use_attribute_discriminator=True,
    gradient_penalty_coef=10.0,
    attribute_gradient_penalty_coef=10.0,
    generator_learning_rate=1e-4,
    discriminator_learning_rate=1e-4,
    attribute_discriminator_learning_rate=1e-4,
    attribute_loss_coef=1.0,
    apply_feature_scaling=False,  # features are already scaled to [-1,1]
    apply_example_scaling=True,
    normalization=Normalization.MINUSONE_ONE,
    batch_size=1000,
    epochs=400,
)

wwt_model1 = DGAN(config=config1)

# perf_counter is monotonic (unlike time.time), so it is the right clock for durations.
start_time = time.perf_counter()
wwt_model1.train_dataframe(
    df=wwt,
    attribute_columns=ATTRIBUTE_COLUMNS,
)
train_end_time = time.perf_counter()

# Generate data
synthetic1 = wwt_model1.generate_dataframe(N_SYNTHETIC)

end_time = time.perf_counter()
print("Elapsed time: {} seconds".format(end_time - start_time))
print("  training: {:.1f} s, generation: {:.1f} s".format(
    train_end_time - start_time, end_time - train_end_time))

In [None]:
# Preview the first rows only; the full 50,000-row frame is too large to display usefully.
synthetic1.head()

In [None]:
# Persist the fast model's synthetic data for later comparison.
synthetic1.to_csv(path_or_buf="synthetic_pytorch_fast.csv")

In [None]:
# Extract numpy arrays for charts: the attribute columns as-is, and the remaining
# value columns cast explicitly to float with a trailing singleton axis.
wwt_synthetic_attributes1 = synthetic1.loc[:, ATTRIBUTE_COLUMNS].to_numpy()
print(wwt_synthetic_attributes1.shape)
wwt_synthetic_features1 = (
    synthetic1.drop(columns=ATTRIBUTE_COLUMNS)
    .to_numpy()
    .astype("float")[..., np.newaxis]
)
print(wwt_synthetic_features1.shape)



In [None]:
# Compare real and synthetic distributions of (scaled) page views.
# density=True normalizes each histogram, so the comparison is not skewed when
# the real and synthetic sets contain different numbers of series.
plt.hist(
    [wwt_features.flatten(), wwt_synthetic_features1.flatten()],
    bins=25,
    density=True,
    label=["real", "synthetic"],
)

plt.xlabel("scaled page views")
plt.ylabel("density")
plt.title("Feature value distribution")
plt.legend()
plt.show()


In [None]:
# Autocorrelation by lag for the real data and the fast model's synthetic data.
wwt_acf = get_autocorr(wwt_features)
wwt_synthetic_acf1 = get_autocorr(wwt_synthetic_features1)

In [None]:
# Figure 1, autocorrelation. get_autocorr stores lag j at index j - 1, so plot
# against explicit lag values to keep the x-axis in days.
plt.plot(np.arange(1, wwt_acf.size + 1), wwt_acf, label="real")
plt.plot(np.arange(1, wwt_synthetic_acf1.size + 1), wwt_synthetic_acf1, label="generated")
plt.xlabel("Time lag (days)")
plt.ylabel("Autocorrelation")
plt.title("Autocorrelation of daily page views for WWT dataset")
plt.legend()
plt.show()



In [None]:
# Zoom in on the first 50 day lags of the autocorrelation. get_autocorr stores
# lag j at index j - 1, so plot against explicit lag values.
ZOOM_LAGS = 50
real_zoom = wwt_acf[:ZOOM_LAGS]
synthetic_zoom = wwt_synthetic_acf1[:ZOOM_LAGS]
plt.plot(np.arange(1, real_zoom.size + 1), real_zoom, label="real")
plt.plot(np.arange(1, synthetic_zoom.size + 1), synthetic_zoom, label="generated")
plt.xlabel("Time lag (days)")
plt.ylabel("Autocorrelation")
plt.title("Autocorrelation of daily page views, first {} lags".format(ZOOM_LAGS))
plt.legend()
plt.show()

# Original params
The paper uses batch_size=100, which is slower. Training with these settings is also less consistent at producing a good model than with the lower learning rate and larger batch size used above.

In [None]:
# Train DGAN model with the original DoppelGANger settings: batch_size=100 and
# learning rate 1e-3 for the generator and both discriminators.

# Training is stochastic; seed torch and numpy so reruns are comparable.
# NOTE(review): GPU kernels may still be nondeterministic, so runs can differ slightly.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

N_SYNTHETIC = 50000  # number of synthetic series to generate

config2 = DGANConfig(
    max_sequence_len=wwt.shape[1] - len(ATTRIBUTE_COLUMNS),
    sample_len=10,
    use_attribute_discriminator=True,
    gradient_penalty_coef=10.0,
    attribute_gradient_penalty_coef=10.0,
    generator_learning_rate=1e-3,
    discriminator_learning_rate=1e-3,
    attribute_discriminator_learning_rate=1e-3,
    attribute_loss_coef=1.0,
    apply_feature_scaling=False,  # features are already scaled to [-1,1]
    apply_example_scaling=True,
    normalization=Normalization.MINUSONE_ONE,
    batch_size=100,
    epochs=400,
)

wwt_model2 = DGAN(config=config2)

# perf_counter is monotonic (unlike time.time), so it is the right clock for durations.
start_time = time.perf_counter()
wwt_model2.train_dataframe(
    df=wwt,
    attribute_columns=ATTRIBUTE_COLUMNS,
)
train_end_time = time.perf_counter()

# Generate data
synthetic2 = wwt_model2.generate_dataframe(N_SYNTHETIC)

end_time = time.perf_counter()
print("Elapsed time: {} seconds".format(end_time - start_time))
print("  training: {:.1f} s, generation: {:.1f} s".format(
    train_end_time - start_time, end_time - train_end_time))

In [None]:
# Preview the first rows of the original-params synthetic data.
synthetic2.head(n=5)

In [None]:
# Persist the original-params synthetic data for later comparison.
synthetic2.to_csv(path_or_buf="synthetic_pytorch_original.csv")

In [None]:
# Extract numpy arrays for charts: the attribute columns as-is, and the remaining
# value columns cast explicitly to float with a trailing singleton axis.
wwt_synthetic_attributes2 = synthetic2.loc[:, ATTRIBUTE_COLUMNS].to_numpy()
print(wwt_synthetic_attributes2.shape)
wwt_synthetic_features2 = (
    synthetic2.drop(columns=ATTRIBUTE_COLUMNS)
    .to_numpy()
    .astype("float")[..., np.newaxis]
)
print(wwt_synthetic_features2.shape)

In [None]:
# Autocorrelation by lag for the original-params synthetic data.
wwt_synthetic_acf2 = get_autocorr(feature=wwt_synthetic_features2)

In [None]:
# Figure 1 for the original-params model. get_autocorr stores lag j at index
# j - 1, so plot against explicit lag values to keep the x-axis in days.
plt.plot(np.arange(1, wwt_acf.size + 1), wwt_acf, label="real")
plt.plot(np.arange(1, wwt_synthetic_acf2.size + 1), wwt_synthetic_acf2, label="generated")
plt.xlabel("Time lag (days)")
plt.ylabel("Autocorrelation")
plt.title("Autocorrelation of daily page views for WWT dataset (original params)")
plt.legend()
plt.show()
