In [1]:
import numpy as np
import pandas as pd

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader

In [3]:
import models

In [4]:
# Switches that gate the two training cells further down: the VAE training
# call only runs when train_VAE is truthy, the GAN one only when train_GAN is.
train_VAE = True
train_GAN = True

In [5]:
# Constants with their units noted inline. None of them is referenced by the
# cells visible in this notebook.
# NOTE(review): presumably E is an energy scale and M a mass belonging to the
# event sample, with c = 1 indicating natural units — confirm with the author.
E = 200 # GeV
M = 114 # GeV
c = 1

In [6]:
# input_dimension = 50
# data = np.random.beta(1, 2, (100_000, input_dimension))

In [7]:
# Read the event table from a plain-text file into a NumPy array (one row per
# line of the file, columns split on whitespace by np.loadtxt's defaults).
data = np.loadtxt("data/events.txt")
# Disabled experiment: drop columns 0 and 4 before modelling.
# data = np.delete(data, [0, 4], axis=1)
# Number of columns per row; used as the input width of both models below.
input_dimension = data.shape[1]

In [8]:
# Mean and standard deviation pooled over every entry of the raw array (all
# columns together), displayed as a tuple.
data.mean(), data.std()

(20.579669222304243, 60.14590681015765)

In [9]:
# (number of rows, number of columns) of the loaded array.
data.shape

(100000, 10)

In [10]:
# Size of the latent vector handed to the VAE constructor.
latent_dim = 100

# Prefer the GPU when PyTorch can see one; otherwise stay on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the VAE, then place its parameters on the chosen device.
vae = models.VAE(input_dimension, latent_dim, device=device)
vae = vae.to(device)

# Adam over all VAE parameters with a learning rate of 1e-3.
vae_optimizer = torch.optim.Adam(params=vae.parameters(), lr=1e-3)


In [11]:
# Show which device was selected and the optimizer's default hyperparameters.
print(device, vae_optimizer.defaults)

cpu {'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}


In [12]:
def count_parameters(model):
    """Return the number of scalar entries across the trainable parameters of ``model``.

    Only parameters whose ``requires_grad`` flag is set are counted; frozen
    parameters contribute nothing. A model without parameters yields 0.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Number of trainable scalar parameters in the VAE.
count_parameters(vae)

1375710

In [13]:
class PCDataset(Dataset):
    """Map-style dataset that serves rows of an indexable container.

    Parameters
    ----------
    data:
        Any object supporting ``len()`` and integer indexing (e.g. a NumPy
        array); one item is returned per index.
    transform:
        Optional callable applied to each item before it is returned.
        ``None`` (the default) returns items unchanged.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        """Return the number of items in the wrapped container."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the item at ``idx``, passed through ``transform`` if one was given."""
        item = self.data[idx]
        # Compare against None explicitly: a truthiness test would silently
        # skip callables that evaluate falsy (e.g. objects defining __len__
        # that happen to be empty).
        if self.transform is not None:
            item = self.transform(item)
        return item

In [14]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Hold out 15% of the rows for evaluation; the fixed random_state keeps the
# split reproducible between runs.
X_train, X_test = train_test_split(data, test_size=0.15, random_state=42)

# Standardize column-wise. The scaler's statistics are learned from the
# training rows only and then reused, unchanged, for the held-out rows.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

In [15]:
# Sanity check of the standardization: pooled mean and standard deviation of
# the training split first, then of the held-out split.
for split in (X_train, X_test):
    print(split.mean(), split.std())

3.989010295413209e-12 0.9999999999999998
0.00033962210244871914 1.003165219771273


In [16]:
# Wrap the standardized splits in PCDataset (no transform) so a DataLoader can
# index them row by row. Note that the held-out split is called `val_dataset`
# here even though it is built from X_test.
train_dataset = PCDataset(X_train)
val_dataset = PCDataset(X_test)

In [17]:
# Mini-batch size shared by both loaders.
BATCH_SIZE = 256

# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# The held-out loader keeps a fixed order: shuffling brings nothing for
# evaluation and would make per-batch results differ between runs.
test_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [18]:
import time

In [19]:
def train_VAE_helper(epochs):
    """Train the notebook-level ``vae`` for ``epochs`` passes over ``train_dataloader``.

    Relies on the module-level names ``vae``, ``vae_optimizer``,
    ``train_dataloader`` and ``device``. Every batch is cast to float32, moved
    to ``device`` and passed to ``vae.loss_function``, which returns three
    losses; the first one is back-propagated. After each epoch the elapsed
    wall-clock time and the per-batch averages of the three losses are printed.

    Parameters
    ----------
    epochs : int
        Number of full passes over the training loader.
    """
    vae.train()

    # Loader length is constant, so look it up once instead of per batch.
    n_batches = len(train_dataloader)

    for i in range(epochs):

        start_time = time.time()

        vae_loss_avg, recon_loss_avg, KL_loss_avg = 0.0, 0.0, 0.0

        for batch in train_dataloader:

            vae_loss, recon_loss, KL_loss = vae.loss_function(batch.float().to(device))

            # .item() extracts a plain Python float (no autograd history, no
            # device dependence), which is all the running averages need.
            vae_loss_avg += vae_loss.item() / n_batches
            recon_loss_avg += recon_loss.item() / n_batches
            KL_loss_avg += KL_loss.item() / n_batches

            vae_optimizer.zero_grad()
            vae_loss.backward()
            vae_optimizer.step()

        # Elapsed time uses two decimals, matching train_GAN_helper's report.
        print("Epoch:", i, f"time passed: {time.time() - start_time:.2f}s", f"Train Losses (avg.): {vae_loss_avg:.5f}, {recon_loss_avg:.5f}, {KL_loss_avg:.5f}")

In [20]:
# Train the VAE for 5 epochs, but only when the train_VAE switch is on.
if train_VAE:
    train_VAE_helper(5)

Epoch: 0 time passed: 1.1889012575149536e-05s Train Losses (avg.): 0.02810, 0.02564, 0.00491
Epoch: 1 time passed: 1.210441541671753e-05s Train Losses (avg.): 0.00252, 0.00251, 0.00002
Epoch: 2 time passed: 1.3817602157592774e-05s Train Losses (avg.): 0.00154, 0.00153, 0.00001
Epoch: 3 time passed: 1.2868263244628906e-05s Train Losses (avg.): 0.00115, 0.00114, 0.00000
Epoch: 4 time passed: 1.3550132513046264e-05s Train Losses (avg.): 0.00091, 0.00090, 0.00000


In [21]:
# Per-column summary statistics of the standardized training split.
pd.DataFrame(X_train).describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
count,85000.0,85000.0,85000.0,85000.0,85000.0,85000.0,85000.0,85000.0,85000.0,85000.0
mean,4.118267e-11,2.825452e-17,-4.9654400000000006e-17,-7.638784e-15,-1.278439e-12,-2.825452e-17,4.9654400000000006e-17,7.638784e-15,4.169381e-15,-1.829606e-14
std,1.000006,1.000006,1.000006,1.000006,1.000006,1.000006,1.000006,1.000006,1.000006,1.000006
min,-4.049064,-1.946276,-1.952758,-1.065234,-4.343428,-1.946913,-1.936472,-2.781868,-2.526795,-1.732742
25%,-0.6769501,-0.6963382,-0.702504,-0.9202029,-0.6740192,-0.6998756,-0.6994429,-0.8392746,-0.8758796,-0.8655321
50%,0.001601171,0.0008194046,-0.002646193,-0.3354269,0.001623124,-0.0008194046,0.002646193,0.3354269,0.1155677,-0.0007758934
75%,0.6771499,0.6998756,0.6994429,0.8392746,0.6755138,0.6963382,0.702504,0.9202029,0.9005816,0.8669307
max,4.266014,1.946913,1.936472,2.781868,4.197578,1.946276,1.952758,1.065234,1.505174,1.735803


In [22]:
# Standardized held-out rows as a float32 tensor on the selected device.
# NOTE(review): the name `input` shadows Python's builtin input() for the rest
# of the notebook; later cells refer to it by this name.
input = torch.tensor(X_test).float().to(device)

In [23]:
# Display the held-out tensor built in the previous cell.
input

tensor([[-0.4450,  1.7686, -0.5562,  ..., -0.4366, -0.5630, -1.4007],
        [ 1.0305,  0.5344, -1.8574,  ..., -1.4185, -1.3158, -0.3086],
        [ 0.9248,  1.8641,  0.5211,  ..., -1.3171, -1.2385, -1.4271],
        ...,
        [ 0.7533, -1.4933, -1.2526,  ..., -1.0162, -1.0103,  0.9683],
        [-1.6089, -0.9841,  1.6691,  ..., -1.1514, -1.1128,  0.5870],
        [-0.7212, -0.6345,  1.7932,  ..., -1.5185, -1.3927,  0.3753]])

In [24]:
# Put the VAE in evaluation mode and run the held-out rows through it. The
# model returns three values; only the first is kept.
vae.eval()
# Inference only: disabling autograd avoids building a graph for the whole
# held-out set. Later `.detach()` calls on `output` remain valid no-ops.
with torch.no_grad():
    output, _, _ = vae(input.float())

In [25]:
# VAE output for the held-out rows as a NumPy array (still in standardized units).
output.detach().cpu().numpy()

array([[-0.47420746,  1.7344112 , -0.5791223 , ..., -0.44630432,
        -0.5461898 , -1.4156277 ],
       [ 0.98803216,  0.53016865, -1.8923774 , ..., -1.4962971 ,
        -1.3737543 , -0.2784671 ],
       [ 0.9675129 ,  1.7924148 ,  0.5299396 , ..., -1.3837998 ,
        -1.2794145 , -1.3923792 ],
       ...,
       [ 0.70091206, -1.5071204 , -1.2787206 , ..., -1.0722032 ,
        -1.0529424 ,  0.96420383],
       [-1.6333561 , -1.1203703 ,  1.686213  , ..., -1.1872065 ,
        -1.1452025 ,  0.67121375],
       [-0.74024844, -0.7351668 ,  1.7828931 , ..., -1.5783894 ,
        -1.4255934 ,  0.4268867 ]], dtype=float32)

In [26]:
import torch.nn.functional as F
# Mean squared error between the standardized held-out rows and the VAE output
# computed from them, averaged over all entries.
F.mse_loss(input, output)

tensor(0.0019, grad_fn=<MseLossBackward0>)

In [27]:
# Map the VAE output back to the original units of the data. The tensor is
# moved to host memory first: scikit-learn converts its input to a NumPy array,
# which is not possible for a tensor that lives on a CUDA device.
scaler.inverse_transform(output.detach().cpu())

array([[100.97621646,  89.08042753, -29.36183777, ...,  30.44820564,
          1.89510231,   0.28728508],
       [101.04908593,  27.21798745, -96.89193301, ..., -17.86608267,
          1.39225378,   1.31722868],
       [101.04806337,  92.06009364,  27.66825267, ..., -12.68963915,
          1.44957692,   0.30834165],
       ...,
       [101.03477754, -77.43839284, -65.33652516, ...,   1.64814442,
          1.58718696,   2.44273455],
       [100.91845128, -57.57088004,  87.12605046, ...,  -3.6436108 ,
          1.53112744,   2.17736899],
       [100.96295853, -37.78281781,  92.09752386, ..., -21.64347079,
          1.36075506,   1.95607829]])

In [28]:
# Held-out rows mapped back to original units, for side-by-side comparison
# with the un-scaled VAE output. `.cpu()` keeps this working when the tensor
# lives on a CUDA device, since scikit-learn needs host memory.
scaler.inverse_transform(input.detach().cpu())

array([[100.97767136,  90.83617262, -28.18133728, ...,  30.89663198,
          1.8849023 ,   0.30082782],
       [101.05120271,  27.43679191, -95.09550229, ..., -14.28524883,
          1.42745347,   1.2899062 ],
       [101.04593388,  95.7434981 ,  27.2143253 , ...,  -9.62098303,
          1.47443745,   0.27693803],
       ...,
       [101.03738629, -76.72693233, -63.99310538, ...,   4.22612515,
          1.6130702 ,   2.44643926],
       [100.91967194, -50.57134785,  86.24701884, ...,  -1.997712  ,
          1.55081789,   2.10112217],
       [100.96390777, -32.61309232,  92.62601312, ..., -18.88935201,
          1.38076104,   1.90933569]])

In [29]:
# Build the GAN with the same input width as the data and place it on the
# device the training batches are sent to; otherwise model and batches end up
# on different devices whenever CUDA is available.
# NOTE(review): assumes models.GAN does not manage device placement on its
# own — confirm against models.py.
gan = models.GAN(input_shape=input_dimension).to(device)

In [32]:
# Trainable parameter counts: whole GAN, discriminator only, generator only.
count_parameters(gan), count_parameters(gan.discriminator), count_parameters(gan.generator)

(853899, 137217, 716682)

In [33]:
def train_GAN_helper(epochs):
    """Train the notebook-level ``gan`` for ``epochs`` passes over ``train_dataloader``.

    Relies on the module-level names ``gan``, ``train_dataloader`` and
    ``device``. Every batch is cast to float32, moved to ``device`` and handed
    to ``gan.train_with_batch``, which returns a discriminator loss and a
    generator loss. After each epoch the elapsed wall-clock time and the
    per-batch averages of both losses are printed.

    Parameters
    ----------
    epochs : int
        Number of full passes over the training loader.
    """
    # Mirror train_VAE_helper and enter training mode explicitly. The generator
    # contains BatchNorm layers and a later cell calls gan.eval(), so re-running
    # this helper would otherwise train with the layers stuck in eval mode.
    gan.train()

    # Loader length is constant, so look it up once instead of per batch.
    n_batches = len(train_dataloader)

    for i in range(epochs):

        start_time = time.time()

        d_loss_avg, g_loss_avg = 0, 0

        for batch in train_dataloader:

            d_loss, g_loss = gan.train_with_batch(batch.float().to(device))

            d_loss_avg += d_loss / n_batches
            g_loss_avg += g_loss / n_batches

        print("Epoch:", i, f"time passed: {time.time() - start_time:.2f}s", f"Train Losses (avg.): {d_loss_avg:.5f}, {g_loss_avg:.5f}")

In [34]:
# Train the GAN for 50 epochs, but only when the train_GAN switch is on.
if train_GAN:
    train_GAN_helper(50)

Epoch: 0 time passed: 12.738072156906128s Train Losses (avg.): 0.58388, 0.39128
Epoch: 1 time passed: 13.831665277481079s Train Losses (avg.): 0.58336, 0.39177
Epoch: 2 time passed: 12.70262098312378s Train Losses (avg.): 0.58293, 0.39176
Epoch: 3 time passed: 13.021274328231812s Train Losses (avg.): 0.58259, 0.39187
Epoch: 4 time passed: 16.3140230178833s Train Losses (avg.): 0.58256, 0.39180
Epoch: 5 time passed: 11.941186428070068s Train Losses (avg.): 0.58228, 0.39222
Epoch: 6 time passed: 11.398107051849365s Train Losses (avg.): 0.58145, 0.39231
Epoch: 7 time passed: 11.686532497406006s Train Losses (avg.): 0.58158, 0.39257
Epoch: 8 time passed: 11.522518634796143s Train Losses (avg.): 0.58171, 0.39256
Epoch: 9 time passed: 11.28159761428833s Train Losses (avg.): 0.58117, 0.39240
Epoch: 10 time passed: 11.740309000015259s Train Losses (avg.): 0.58050, 0.39245
Epoch: 11 time passed: 12.950963258743286s Train Losses (avg.): 0.58036, 0.39282
Epoch: 12 time passed: 11.70177412033081s 

In [35]:
# Switch the GAN to evaluation mode before sampling from it. As the last
# expression of the cell, the call's return value (the module) is displayed.
gan.eval()

GAN(
  (discriminator): Discriminator(
    (model): Sequential(
      (0): Linear(in_features=10, out_features=512, bias=True)
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
      (2): Linear(in_features=512, out_features=256, bias=True)
      (3): LeakyReLU(negative_slope=0.2, inplace=True)
      (4): Linear(in_features=256, out_features=1, bias=True)
      (5): Sigmoid()
    )
  )
  (generator): Generator(
    (model): Sequential(
      (0): Linear(in_features=100, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
      (2): Linear(in_features=128, out_features=256, bias=True)
      (3): BatchNorm1d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
      (4): LeakyReLU(negative_slope=0.2, inplace=True)
      (5): Linear(in_features=256, out_features=512, bias=True)
      (6): BatchNorm1d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
      (7): LeakyReLU(negative_slope=0.2, inplace=True)
      (8): Linear(

In [37]:
# Draw 100,000 samples from the GAN and map them back to the original units.
# Sampling needs no gradients, and the result is moved to host memory because
# scikit-learn cannot consume a tensor that lives on a CUDA device.
with torch.no_grad():
    gan_generated = scaler.inverse_transform(gan.generate(100_000).detach().cpu())

In [39]:
# Per-column summary statistics of the GAN samples (original units).
pd.DataFrame(gan_generated).describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
count,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0
mean,101.002279,2.231386,-5.367242,-79.969061,101.007432,-2.46154,5.749641,80.206828,2.567504,1.503521
std,0.032845,32.124176,38.143343,24.446165,0.033005,32.054429,38.131589,23.863475,0.327815,0.68318
min,100.950014,-51.38741,-51.004176,-96.080099,100.950063,-51.353418,-51.839661,4.970495,1.619356,0.663725
25%,100.974047,-23.682358,-42.136508,-94.026677,100.979279,-27.496551,-32.075245,79.800535,2.508496,0.803094
50%,101.003774,4.034888,-20.36968,-88.354267,101.008773,-4.443499,21.263999,88.593596,2.662339,1.444782
75%,101.031322,27.500135,32.269019,-80.157406,101.040263,23.297379,42.288222,94.118636,2.790326,2.197374
max,101.049682,51.353418,51.839661,-4.970495,101.049342,51.38741,51.004176,96.108088,2.834605,2.475156


In [42]:
# Standard-normal noise with the same width as the data, placed on the same
# device as the VAE (torch.normal already returns a tensor, so no extra
# torch.Tensor(...) wrapper is needed).
# NOTE(review): this noise lives in *input* space and is passed through the
# whole model. Sampling a VAE usually means drawing a latent vector of size
# latent_dim and running only the decoder — confirm the intended procedure
# against models.VAE.
dummy_data = torch.normal(0, 1, size=(100_000, input_dimension)).to(device)
vae.eval()
# No gradients are needed for sampling; `.cpu()` hands scikit-learn a
# host-memory tensor even when the model runs on CUDA.
with torch.no_grad():
    vae_generated = scaler.inverse_transform(vae(dummy_data)[0].detach().cpu())

In [43]:
# Per-column summary statistics of the VAE samples (original units).
pd.DataFrame(vae_generated).describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
count,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0
mean,100.999005,-1.156286,0.210678,-51.60534,101.001515,2.861635,-0.014338,52.207286,2.26016,1.586815
std,0.051067,33.247069,37.681944,26.876894,0.050768,33.197836,37.547272,26.914757,0.361631,0.653515
min,100.791574,-133.490389,-156.123924,-138.663897,100.789878,-122.534179,-146.077546,-69.14439,0.823752,-1.64578
25%,100.964718,-24.148581,-26.349363,-71.139865,100.967124,-20.054108,-26.371287,34.415042,2.014011,1.16253
50%,100.999106,-1.045769,0.56069,-53.720333,101.001621,2.908748,0.083792,54.23447,2.248934,1.586094
75%,101.033519,21.819946,26.871039,-33.896806,101.035994,25.91744,26.421167,71.785141,2.495454,2.016543
max,101.216289,122.755494,144.813166,70.842481,101.235118,132.610309,152.823891,140.023947,3.969447,4.623458


In [46]:
# Per-column summary statistics of the raw data, as the reference against
# which the generated samples above can be compared.
pd.DataFrame(data).describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
count,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0
mean,100.999815,-0.052197,0.336359,-50.928689,100.999737,0.052197,-0.336359,50.928689,2.226399,1.570741
std,0.049838,51.370386,51.45021,46.04548,0.049709,51.370386,51.45021,46.04548,0.608001,0.905922
min,100.798067,-99.998022,-99.996804,-100.0,100.784097,-99.996742,-99.9996,-77.020231,0.691638,7e-05
25%,100.96607,-35.954358,-35.795072,-93.337484,100.966294,-35.827942,-36.259057,12.218287,1.693285,0.785882
50%,100.99985,0.024934,0.227593,-66.311656,100.999805,-0.024934,-0.227593,66.311656,2.295771,1.568805
75%,101.033535,35.827942,36.259057,-12.218287,101.03335,35.954358,35.795072,93.337484,2.7745,2.356837
max,101.212441,99.996742,99.9996,77.020231,101.208068,99.998022,99.996804,100.0,3.141589,3.141583
