In [1]:
import torch
from torch import nn

from ctgan import TVAE
from ctgan.data_transformer import DataTransformer
from sdmetrics.reports.single_table import QualityReport
from sdv.metadata import SingleTableMetadata
from sdv.single_table.utils import detect_discrete_columns

import os
import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split

from ctgan import TVAE
import warnings


# Silence all warnings for the rest of the kernel session to keep cell output
# readable. NOTE(review): with no `category=` filter this also hides pandas
# SettingWithCopyWarning and library deprecation warnings — consider narrowing.
warnings.filterwarnings("ignore")

In [2]:
def get_dataset(filename: str, first_col: int = 2, last_col: int = -6) -> pd.DataFrame:
    """Load the raw CSV and keep only the feature columns that have no missing values.

    Parameters
    ----------
    filename : str
        Path to a comma-separated file.
    first_col, last_col : int
        Positional column bounds, applied as ``iloc[:, first_col:last_col]``.
        The defaults reproduce the original selection (drop the first 2 and the
        last 6 columns — presumably id/metadata columns; TODO confirm against
        the dataset description).

    Returns
    -------
    pandas.DataFrame
        The selected columns minus every column containing at least one NaN.
        The row index produced by ``read_csv`` is kept unchanged.
    """
    raw = pd.read_csv(filename, sep=",")
    # Take an explicit copy so the column drop below works on an independent
    # frame instead of a slice of `raw`.
    features = raw.iloc[:, first_col:last_col].copy()

    # Remove columns with NaN values.
    has_nan = features.isna().any()
    na_columns = features.columns[has_nan.to_numpy()].to_list()
    print("columns with nan values", na_columns)
    return features.drop(columns=na_columns)


filename = os.path.normpath("../data/fulldataset.csv")
data = get_dataset(filename)

# 80/20 split over positional row ids; random_state pins the split.
idx = list(range(data.shape[0]))
fit_idx, val_idx = train_test_split(idx, test_size=0.20, random_state=42)
fit, val = data.iloc[fit_idx, :], data.iloc[val_idx, :]

# Integer-typed columns are cast to `category` so SDV metadata detection treats
# them as discrete.
# NOTE(review): `fit` and `val` were sliced (as copies) *before* this cast, so
# they keep their int64 dtypes; only `data` — and therefore `metadata` and
# `discrete_columns` — sees the category dtype.
binary_columns = [name for name, dtype in data.dtypes.items() if dtype == "int64"]
data[binary_columns] = data[binary_columns].astype("category")

metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data)
discrete_columns = detect_discrete_columns(metadata, data)

columns with nan values ['data_172', 'data_173']


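Check whether the dataset has already been transformed and cached. If it has, load the cached train/validation split and the fitted `DataTransformer`; otherwise fit the transformer and cache everything.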

In [3]:
# Cache switch: False -> fit a DataTransformer on `fit` and cache it together
# with the fit/val split; True -> reuse the cached copy stored under `savepath`.
is_transformed = True

savepath = os.path.normpath("../data/transformed_data_tvae")
filename = "transformed_data.pkl"
cache_path = os.path.join(savepath, filename)
if not is_transformed:
    data_transformer = DataTransformer()
    data_transformer.fit(fit, discrete_columns)
    os.makedirs(savepath, exist_ok=True)
    torch.save(
        {
            "fit": fit,
            "val": val,
            "data_transformer": data_transformer,
        },
        cache_path,
    )
else:
    # The cache holds pickled DataFrames and a DataTransformer object, not just
    # tensors, so it has to be read with weights_only=False (PyTorch >= 2.6
    # defaults to weights_only=True and refuses to unpickle these objects).
    # Unpickling can execute arbitrary code: only load cache files you created.
    cnt = torch.load(cache_path, weights_only=False)

    fit = cnt["fit"]
    val = cnt["val"]
    data_transformer = cnt["data_transformer"]

Fit the TVAE model to the training split

In [4]:
# from ctgan import TVAE
# tvae = TVAE(
#     epochs=30,
#     cuda=True
# )
# tvae.fit(fit, discrete_columns, transformer=data_transformer)

Generate synthetic samples and save the SDMetrics quality report

In [5]:
@torch.no_grad()
def sample(model, n_samples, device=None):
    """Draw ``n_samples`` synthetic rows from a TVAE-style model.

    The model is expected to expose ``decoder`` (returning ``(output, sigmas)``),
    ``batch_size``, ``embedding_dim``, ``device`` and ``transformer`` (with an
    ``inverse_transform(data, sigmas)`` method).

    Parameters
    ----------
    model : object
        Trained generator with the attributes listed above.
    n_samples : int
        Number of rows to return.
    device : torch.device or str, optional
        Device the latent noise is moved to before decoding. Defaults to
        ``model.device`` (the argument used to be accepted but ignored).

    Returns
    -------
    Whatever ``model.transformer.inverse_transform`` returns for the first
    ``n_samples`` decoded rows.

    Notes
    -----
    Leaves ``model.decoder`` in eval mode.
    """
    model.decoder.eval()

    if device is None:
        device = model.device

    # Ceil division: just enough batches to cover n_samples, but always at least
    # one so that `sigmas` is defined even for n_samples == 0.
    steps = max(1, -(-n_samples // model.batch_size))

    # Standard-normal latent noise; mean/std do not change between batches.
    mean = torch.zeros(model.batch_size, model.embedding_dim)
    std = mean + 1

    batches = []
    for _ in range(steps):
        noise = torch.normal(mean=mean, std=std).to(device)
        fake, sigmas = model.decoder(noise)
        batches.append(torch.tanh(fake).detach().cpu().numpy())

    generated = np.concatenate(batches, axis=0)[:n_samples]
    # sigmas may be a learnable Parameter (requires grad even under no_grad),
    # hence the explicit detach.
    return model.transformer.inverse_transform(generated, sigmas.detach().cpu().numpy())


In [6]:
# fake = sample(tvae, val.shape[0])
# real = val


In [7]:
# quality_report = QualityReport()
# quality_report.generate(real, fake, metadata.to_dict())
# quality_report.get_visualization(property_name="Column Shapes").show()
# quality_report.get_visualization(property_name="Column Pair Trends").show()

# savepath = os.path.normpath("../results/sdmetrics_quality_reports/")
# os.makedirs(savepath, exist_ok=True)
# quality_report.save(os.path.join(savepath, "quality_report_tvae.pkl"))

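Can we do better? Train a VAE with a VampPrior (pythae `VAMP`) on the same training split and compare its quality report.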

In [8]:
from pythae.models.nn import BaseEncoder, BaseDecoder
from pythae.models import VAMP, VAMPConfig
from pythae.pipelines.training import TrainingPipeline
from pythae.trainers import BaseTrainerConfig
from pythae.models.base.base_utils import ModelOutput
from pythae.trainers import BaseTrainer, BaseTrainerConfig
from pythae.data.datasets import BaseDataset

from uqvae.models import Encoder, Decoder

class bEncoder(BaseEncoder):
    """MLP encoder for pythae models.

    Hidden stack: for every width in ``compress_dims`` a ``Linear -> ReLU`` pair,
    optionally followed by ``nn.Dropout(p)`` when ``add_dropouts`` is true.
    Two linear heads on top of the stack produce the ``embedding`` and
    ``log_covariance`` entries of the returned ``ModelOutput``.
    """

    def __init__(self, data_dim, compress_dims, embedding_dim, add_dropouts=False, p=None):
        BaseEncoder.__init__(self)

        layers = []
        in_features = data_dim
        for out_features in compress_dims:
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.ReLU())
            if add_dropouts:
                layers.append(nn.Dropout(p))
            in_features = out_features

        self.seq = nn.Sequential(*layers)
        # Separate heads for the two outputs pythae reads from the encoder.
        self.fc1 = nn.Linear(in_features, embedding_dim)
        self.fc2 = nn.Linear(in_features, embedding_dim)

    def forward(self, x):
        hidden = self.seq(x)
        output = ModelOutput()
        output["embedding"] = self.fc1(hidden)
        output["log_covariance"] = self.fc2(hidden)
        return output

class bDecoder(BaseDecoder):
    """MLP decoder for pythae models.

    Hidden stack: for every width in ``decompress_dims`` a ``Linear -> ReLU``
    pair, optionally followed by ``nn.Dropout(p)`` when ``add_dropouts`` is
    true, then a final linear layer to ``data_dim``. ``forward`` returns the
    unbounded linear output under the ``reconstruction`` key.
    """

    def __init__(self, embedding_dim, decompress_dims, data_dim, add_dropouts=False, p=None):
        BaseDecoder.__init__(self)

        layers = []
        in_features = embedding_dim
        for out_features in decompress_dims:
            layers += [nn.Linear(in_features, out_features), nn.ReLU()]
            if add_dropouts:
                layers.append(nn.Dropout(p))
            in_features = out_features
        layers.append(nn.Linear(in_features, data_dim))

        self.seq = nn.Sequential(*layers)
        # A learnable per-feature output scale was tried and disabled:
        #   self.sigma = nn.Parameter(torch.ones(data_dim) * 0.1)
        # (forward would then also return output["sigma"] = self.sigma).

    def forward(self, x):
        return ModelOutput(reconstruction=self.seq(x))

# Network sizes for the VAMP encoder/decoder.
data_dim = fit.shape[-1]
embedding_dim = 128
compress_dims = (128, 128)
decompress_dims = (128, 128)

encoder = bEncoder(data_dim, compress_dims, embedding_dim)
decoder = bDecoder(embedding_dim, decompress_dims, data_dim)

model_config = VAMPConfig(input_dim=(data_dim,), lat_dim=embedding_dim)
model = VAMP(model_config=model_config, encoder=encoder, decoder=decoder)

# NOTE(review): the VAMP model is trained on the raw `fit` columns (cast to
# float32), not on the `data_transformer` encoding used by TVAE.
# The second tensor is a placeholder label per row (all ones).
train_dataset = BaseDataset(
    torch.from_numpy(fit.values.astype("float32")),
    torch.ones(fit.shape[0]),
)

# Library-default trainer settings. A shorter custom run was sketched as:
#   BaseTrainerConfig(output_dir="my_model", learning_rate=1e-4,
#                     per_device_train_batch_size=64,
#                     per_device_eval_batch_size=64,
#                     num_epochs=10)  # raise num_epochs to train longer
training_config = BaseTrainerConfig()
trainer = BaseTrainer(
    model=model,
    train_dataset=train_dataset,
    training_config=training_config,
)
trainer.train()

! No eval dataset provided ! -> keeping best model on train.

Model passed sanity check !
Ready for training.

Training params:
 - max_epochs: 100
 - per_device_train_batch_size: 64
 - per_device_eval_batch_size: 64
 - checkpoint saving every: None
Optimizer: Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 0.0001
    maximize: False
    weight_decay: 0
)
Scheduler: None

Successfully launched training !



Training of epoch 1/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 164.6207
--------------------------------------------------------------------------


Training of epoch 2/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 163.915
--------------------------------------------------------------------------


Training of epoch 3/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 163.5521
--------------------------------------------------------------------------


Training of epoch 4/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 163.6102
--------------------------------------------------------------------------


Training of epoch 5/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 163.0156
--------------------------------------------------------------------------


Training of epoch 6/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 150.9111
--------------------------------------------------------------------------


Training of epoch 7/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 133.6991
--------------------------------------------------------------------------


Training of epoch 8/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 127.5398
--------------------------------------------------------------------------


Training of epoch 9/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 122.3484
--------------------------------------------------------------------------


Training of epoch 10/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 118.5853
--------------------------------------------------------------------------


Training of epoch 11/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 116.8888
--------------------------------------------------------------------------


Training of epoch 12/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 115.7083
--------------------------------------------------------------------------


Training of epoch 13/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 114.326
--------------------------------------------------------------------------


Training of epoch 14/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 113.02
--------------------------------------------------------------------------


Training of epoch 15/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 111.9236
--------------------------------------------------------------------------


Training of epoch 16/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 110.7361
--------------------------------------------------------------------------


Training of epoch 17/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 109.9936
--------------------------------------------------------------------------


Training of epoch 18/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 109.259
--------------------------------------------------------------------------


Training of epoch 19/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 108.2964
--------------------------------------------------------------------------


Training of epoch 20/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 107.6722
--------------------------------------------------------------------------


Training of epoch 21/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 107.0937
--------------------------------------------------------------------------


Training of epoch 22/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 106.3654
--------------------------------------------------------------------------


Training of epoch 23/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 106.0007
--------------------------------------------------------------------------


Training of epoch 24/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 105.5184
--------------------------------------------------------------------------


Training of epoch 25/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 105.0297
--------------------------------------------------------------------------


Training of epoch 26/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 104.7389
--------------------------------------------------------------------------


Training of epoch 27/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 104.4348
--------------------------------------------------------------------------


Training of epoch 28/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 104.0686
--------------------------------------------------------------------------


Training of epoch 29/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 104.0071
--------------------------------------------------------------------------


Training of epoch 30/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 103.2795
--------------------------------------------------------------------------


Training of epoch 31/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 102.9878
--------------------------------------------------------------------------


Training of epoch 32/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 102.5121
--------------------------------------------------------------------------


Training of epoch 33/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 102.2334
--------------------------------------------------------------------------


Training of epoch 34/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 102.0295
--------------------------------------------------------------------------


Training of epoch 35/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 101.6482
--------------------------------------------------------------------------


Training of epoch 36/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 101.16
--------------------------------------------------------------------------


Training of epoch 37/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 100.8698
--------------------------------------------------------------------------


Training of epoch 38/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 100.6863
--------------------------------------------------------------------------


Training of epoch 39/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 100.238
--------------------------------------------------------------------------


Training of epoch 40/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 100.0337
--------------------------------------------------------------------------


Training of epoch 41/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 99.8231
--------------------------------------------------------------------------


Training of epoch 42/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 99.5222
--------------------------------------------------------------------------


Training of epoch 43/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 99.3247
--------------------------------------------------------------------------


Training of epoch 44/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 99.1759
--------------------------------------------------------------------------


Training of epoch 45/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 98.9531
--------------------------------------------------------------------------


Training of epoch 46/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 98.7403
--------------------------------------------------------------------------


Training of epoch 47/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 98.5118
--------------------------------------------------------------------------


Training of epoch 48/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 98.4544
--------------------------------------------------------------------------


Training of epoch 49/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 98.1204
--------------------------------------------------------------------------


Training of epoch 50/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 97.8638
--------------------------------------------------------------------------


Training of epoch 51/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 97.7339
--------------------------------------------------------------------------


Training of epoch 52/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 97.4901
--------------------------------------------------------------------------


Training of epoch 53/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 97.4099
--------------------------------------------------------------------------


Training of epoch 54/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 97.2982
--------------------------------------------------------------------------


Training of epoch 55/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 96.9494
--------------------------------------------------------------------------


Training of epoch 56/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 96.8473
--------------------------------------------------------------------------


Training of epoch 57/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 96.5126
--------------------------------------------------------------------------


Training of epoch 58/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 96.4848
--------------------------------------------------------------------------


Training of epoch 59/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 96.3154
--------------------------------------------------------------------------


Training of epoch 60/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 96.0505
--------------------------------------------------------------------------


Training of epoch 61/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 95.977
--------------------------------------------------------------------------


Training of epoch 62/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 95.8526
--------------------------------------------------------------------------


Training of epoch 63/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 95.7222
--------------------------------------------------------------------------


Training of epoch 64/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 95.6551
--------------------------------------------------------------------------


Training of epoch 65/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 95.2597
--------------------------------------------------------------------------


Training of epoch 66/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 95.1129
--------------------------------------------------------------------------


Training of epoch 67/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 95.0894
--------------------------------------------------------------------------


Training of epoch 68/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 94.7809
--------------------------------------------------------------------------


Training of epoch 69/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 94.7708
--------------------------------------------------------------------------


Training of epoch 70/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 95.5004
--------------------------------------------------------------------------


Training of epoch 71/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 94.4498
--------------------------------------------------------------------------


Training of epoch 72/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 94.2262
--------------------------------------------------------------------------


Training of epoch 73/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 94.3881
--------------------------------------------------------------------------


Training of epoch 74/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 93.8955
--------------------------------------------------------------------------


Training of epoch 75/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 93.8046
--------------------------------------------------------------------------


Training of epoch 76/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 93.579
--------------------------------------------------------------------------


Training of epoch 77/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 93.4158
--------------------------------------------------------------------------


Training of epoch 78/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 93.2522
--------------------------------------------------------------------------


Training of epoch 79/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 93.1122
--------------------------------------------------------------------------


Training of epoch 80/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 93.303
--------------------------------------------------------------------------


Training of epoch 81/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 92.8091
--------------------------------------------------------------------------


Training of epoch 82/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 92.6064
--------------------------------------------------------------------------


Training of epoch 83/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 92.5236
--------------------------------------------------------------------------


Training of epoch 84/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 92.2764
--------------------------------------------------------------------------


Training of epoch 85/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 92.1856
--------------------------------------------------------------------------


Training of epoch 86/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 92.1756
--------------------------------------------------------------------------


Training of epoch 87/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 91.8609
--------------------------------------------------------------------------


Training of epoch 88/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 91.7357
--------------------------------------------------------------------------


Training of epoch 89/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 91.5648
--------------------------------------------------------------------------


Training of epoch 90/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 91.2945
--------------------------------------------------------------------------


Training of epoch 91/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 91.0865
--------------------------------------------------------------------------


Training of epoch 92/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 90.8832
--------------------------------------------------------------------------


Training of epoch 93/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 90.9308
--------------------------------------------------------------------------


Training of epoch 94/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 90.427
--------------------------------------------------------------------------


Training of epoch 95/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 90.4132
--------------------------------------------------------------------------


Training of epoch 96/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 90.2483
--------------------------------------------------------------------------


Training of epoch 97/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 90.0374
--------------------------------------------------------------------------


Training of epoch 98/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 89.8606
--------------------------------------------------------------------------


Training of epoch 99/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 89.6784
--------------------------------------------------------------------------


Training of epoch 100/100:   0%|          | 0/188 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 89.4626
--------------------------------------------------------------------------
Training ended!
Saved final model in dummy_output_dir\VAMP_training_None\VAMP_rep_None\final_model


In [22]:
# Draw synthetic rows from the trained VAMP model.
# Passing Gaussian noise through `model(...)` only *reconstructs* that noise via
# encoder -> decoder; it is not a sample from the generative model. pythae's
# VAMPSampler draws latents from the learned VampPrior (built from the model's
# pseudo-inputs) and decodes them, on GPU if available, else CPU.
from pythae.samplers import VAMPSampler

model.eval()
vamp_sampler = VAMPSampler(model=model)
fake = vamp_sampler.sample(num_samples=val.shape[0], return_gen=True)

In [9]:
# from uqvae.models.tvae_wrapper import TVAMP

# tvamp = TVAMP(
#     train_data=fit,
#     embedding_dim=128,
#     compress_dims=(128, 128),
#     decompress_dims=(128, 128),
#     l2scale=0.0,
#     discrete_columns=discrete_columns,
#     transformer=data_transformer,
#     epochs=30,
#     batch_size=128,
#     cuda=True,
#     transform_data=False
# )

# tvamp.fit()

In [10]:
# x = torch.randn((10, 1601), device=torch.device("cuda"))
# tvamp.encoder(x)

In [24]:
# For the ctgan-style wrapper use instead: fake = sample(tvamp, val.shape[0])

# Wrap the generated tensor in a DataFrame aligned with the validation split.
fake = pd.DataFrame(fake.detach().cpu().numpy(), index=val.index, columns=val.columns)
real = val

# The generator emits continuous values for every column, but integer columns
# (marked categorical in `metadata`) only hold whole values in `real`. Raw
# floats would never match a real category, so round them and cast to the real
# dtypes before scoring.
int_columns = real.select_dtypes(include="integer").columns
if len(int_columns) > 0:
    fake[int_columns] = (
        fake[int_columns].round().astype(real.dtypes[int_columns].to_dict())
    )

quality_report = QualityReport()
quality_report.generate(real, fake, metadata.to_dict())
quality_report.get_visualization(property_name="Column Shapes").show()
quality_report.get_visualization(property_name="Column Pair Trends").show()

savepath = os.path.normpath("../results/sdmetrics_quality_reports/")
os.makedirs(savepath, exist_ok=True)
quality_report.save(os.path.join(savepath, "quality_report_tvamp.pkl"))

Creating report:  75%|███████▌  | 3/4 [01:38<00:43, 43.82s/it]