In [1]:
# Automatically re-import edited project modules (e.g. `models.gan`, `utils.logger`)
# before each cell runs, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

import sys

import torch

# Put the kbclean project root on the import path (presumably where the
# project-local `models` / `utils` packages imported below live). The
# membership guard keeps re-runs of this cell from appending duplicates.
KBCLEAN_DIR = r"../../../kb-data-cleaning/kbclean"
if KBCLEAN_DIR not in sys.path:
    sys.path.append(KBCLEAN_DIR)

# Experiment name; selects ../../config/<method>.yaml when loading hparams.
method = "oc_gan"

## Load hyper-parameters for experiments

In [2]:
import yaml

# Read this experiment's hyper-parameters from ../../config/<method>.yaml.
# The `with` block closes the file as soon as parsing finishes; the previous
# bare `open(...)` left the handle open until garbage collection.
# FullLoader is kept for compatibility. The config holds only plain scalars
# (see output), so `yaml.safe_load` would also work and is the safer option.
with open(f"../../config/{method}.yaml", "r") as config_file:
    hparams = yaml.load(config_file, Loader=yaml.FullLoader)
hparams

{'batch_size': 3500,
 'input_dim': 200,
 'gen_hid_dim': 100,
 'latent_dim': 50,
 'dropout_p': 0.75,
 'lr': 0.0005,
 'amp_level': 'O1',
 'max_length': 100}

In [3]:
from argparse import Namespace

import torch

# Wrap the raw config dict so fields read as attributes (hparams.batch_size).
# Guarded so re-running this cell is safe: `Namespace(**hparams)` on an
# existing Namespace raises TypeError ("argument after ** must be a mapping").
if not isinstance(hparams, Namespace):
    hparams = Namespace(**hparams)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Preprocess data

In [4]:
import numpy as np

# Pre-encoded feature matrix; rows are treated as samples by the split below.
# NOTE(review): the config reports input_dim=200 while this array has 20
# columns — confirm how the model consumes these features.
encoded_data = np.load(file="../../data/numpy/encoded.npy")

np.shape(encoded_data)

(1421482, 20)

In [5]:
from sklearn.preprocessing import MinMaxScaler

# Rescale every feature column to the [0, 1] range.
# NOTE(review): the scaler is fit on all rows *before* the train/val split
# below, so validation statistics leak into the scaling. Consider fitting on
# the training rows only.
scaler = MinMaxScaler()
scaler.fit(encoded_data)
encoded_data = scaler.transform(encoded_data)

In [6]:
from torch.utils.data import DataLoader, random_split

# Reproducible train/validation split over the rows of encoded_data.
TRAIN_FRACTION = 0.7
SPLIT_SEED = 0

train_length = int(len(encoded_data) * TRAIN_FRACTION)

# random_split draws its permutation from torch's default RNG. Seeding it
# makes the split (and later model init) repeatable across kernel restarts.
torch.manual_seed(SPLIT_SEED)

# Pass the array itself: Subset only needs __len__/__getitem__, and each item
# is still a single row. The previous list(encoded_data) materialised ~1.4M
# per-row Python objects, which also get copied into every loader worker.
train_dataset, val_dataset = random_split(
    encoded_data, [train_length, len(encoded_data) - train_length],
)

len(train_dataset)

995037

In [7]:
# Worker processes per loader (shared by train and validation).
NUM_WORKERS = 16

# Reshuffle training rows every epoch. Without this, every epoch replays the
# single fixed order produced once by random_split.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=hparams.batch_size,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

# Validation keeps a fixed, deterministic order.
val_dataloader = DataLoader(
    val_dataset,
    batch_size=hparams.batch_size,
    num_workers=NUM_WORKERS,
)

## Load pre-trained encoder + regular-GAN discriminator

In [8]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from utils.logger import MyTensorBoardLogger
from models.gan import RGANDiscriminator, OneClassGAN

# The RGAN discriminator is trained once and then reused from its checkpoint.
# Set TRAIN_RGAN = True to retrain it with the recipe below instead of loading
# the saved epoch=9 weights. This replaces the old commented-out code.
TRAIN_RGAN = False
RGAN_CHECKPOINT = "../../tt_logs/rgan/version_0/checkpoints/epoch=9.ckpt"

if TRAIN_RGAN:
    rgan = RGANDiscriminator(hparams)
    rgan_trainer = Trainer(
        gpus=[0, 1, 2, 3],
        amp_level="O1",
        default_root_dir="../../checkpoints/rgan/",
        distributed_backend="dp",
        logger=MyTensorBoardLogger("../../tt_logs", "rgan"),
        max_epochs=10,
    )
    rgan_trainer.fit(
        rgan, train_dataloader=train_dataloader, val_dataloaders=[val_dataloader]
    )
else:
    rgan = RGANDiscriminator.load_from_checkpoint(RGAN_CHECKPOINT)

In [None]:
# One-class GAN built on top of the pre-trained RGAN discriminator.
ogan = OneClassGAN(hparams, rgan)

trainer = Trainer(
    gpus=[0, 1, 2, 3],
    amp_level="O1",
    benchmark=False,
    # default_root_dir supersedes the deprecated default_save_path alias and
    # matches the RGAN training configuration above.
    default_root_dir="../../checkpoints/ocgan/",
    # Lightning's "ddp" backend is not supported inside Jupyter notebooks; "dp"
    # is the notebook-safe multi-GPU option and matches the RGAN recipe.
    # dp splits each batch of hparams.batch_size across the listed GPUs.
    distributed_backend="dp",
    # auto_scale_batch_size is deliberately not set. The batch-size finder can
    # only resize loaders the LightningModule builds itself, and it refuses
    # dataloaders passed to fit(). This notebook builds its loaders externally
    # with a fixed batch_size.
    logger=MyTensorBoardLogger("../../tt_logs", "ocgan"),
    max_epochs=10,
)

# NOTE(review): val_dataloader is built above but not passed here. Add
# val_dataloaders=[val_dataloader] if OneClassGAN defines a validation step.
trainer.fit(ogan, train_dataloader=train_dataloader)