In [None]:
%%bash
# Fetch the repository whose modules (flow, data,
# train_variational_autoencoder_pytorch) are presumably imported further down.
# Guarded so that Restart & Run All does not fail with
# "destination path already exists" once the checkout is present.
if [ ! -d variational-autoencoder ]; then
  git clone https://github.com/dmhd1/variational-autoencoder.git
fi

In [None]:
# Make the cloned repository the working directory: the training cell resolves
# data_dir / train_dir against pathlib.Path.cwd(), and the local-module imports
# presumably rely on it too.
# NOTE(review): the path is relative, so re-running this cell after it has
# already succeeded will try to descend into the checkout again.
%cd variational-autoencoder/

In [None]:
# Install the third-party packages this notebook uses. `%pip` (unlike `!pip`)
# installs into the environment of the running kernel. The original line had
# a typo ("torchvisio") that made pip look for a nonexistent package.
# pyyaml is listed explicitly because the training cell calls yaml.safe_load.
# NOTE(review): pin versions (pkg==X.Y.Z) for a reproducible environment.
%pip install nomen pyyaml torch torchvision

In [3]:
# All imports for the notebook live here (stdlib -> third-party -> local).
import pathlib
import random

import nomen
import numpy as np
import torch
import yaml

# Local modules from the cloned repository (the working directory was changed
# into it above). Original import order is kept.
import flow
import train_variational_autoencoder_pytorch
import data

# The training cell below uses these names unqualified; the original notebook
# never brought them into scope, so it failed with NameError on a fresh kernel.
# NOTE(review): assumed to be defined in the repository modules named here —
# confirm against the checked-out revision of the repo.
from data import load_binary_mnist
from train_variational_autoencoder_pytorch import (
    Model,
    VariationalFlow,
    VariationalMeanField,
    config,
    cycle,
    evaluate,
)

# Show whether PyTorch can see a CUDA device (cell output). The original line
# was missing its closing parenthesis, which made the whole cell a SyntaxError.
torch.cuda.is_available()

In [None]:
# Configure, build and train the model/variational pair with early stopping on
# the validation ELBO, then report test metrics for the best checkpoint.
#
# The original cell contained this entire setup + training loop twice in a row,
# so a single execution trained two models back to back (re-seeding and
# overwriting `best_state_dict` the second time). It now runs exactly once.

# `config` is a YAML string; safe_load constructs only plain Python data types.
dictionary = yaml.safe_load(config)
cfg = nomen.Config(dictionary)
# NOTE(review): parse_args presumably reads sys.argv; inside a Jupyter kernel
# that holds kernel-launcher flags (e.g. `-f <connection file>`) — confirm that
# nomen tolerates them.
cfg.parse_args()

# Resolve both directories against the current working directory and create
# them up front: torch.save below does not create missing parent directories.
cfg.data_dir = pathlib.Path.cwd() / cfg.data_dir
cfg.train_dir = pathlib.Path.cwd() / cfg.train_dir
cfg.data_dir.mkdir(parents=True, exist_ok=True)
cfg.train_dir.mkdir(parents=True, exist_ok=True)
best_state_path = cfg.train_dir / "best_state_dict"

# Honour `use_gpu` only when CUDA is actually available; otherwise fall back to
# the CPU instead of failing later on `.to(device)`.
use_gpu = cfg.use_gpu and torch.cuda.is_available()
device = torch.device("cuda:0" if use_gpu else "cpu")

# Seed every RNG the run may touch, for reproducibility.
torch.manual_seed(cfg.seed)
np.random.seed(cfg.seed)
random.seed(cfg.seed)

# `model(z, x)` yields log p(x, z); `variational(x)` yields samples z and log q(z).
model = Model(latent_size=cfg.latent_size, data_size=cfg.data_size)
if cfg.variational == "flow":
    variational = VariationalFlow(
        latent_size=cfg.latent_size,
        data_size=cfg.data_size,
        flow_depth=cfg.flow_depth,
    )
elif cfg.variational == "mean-field":
    variational = VariationalMeanField(
        latent_size=cfg.latent_size, data_size=cfg.data_size
    )
else:
    raise ValueError(
        "Variational distribution not implemented: %s" % cfg.variational
    )

model.to(device)
variational.to(device)

# A single optimizer over both parameter sets: the two networks train jointly.
optimizer = torch.optim.RMSprop(
    list(model.parameters()) + list(variational.parameters()),
    lr=cfg.learning_rate,
    centered=True,
)

# Worker processes / pinned host memory only help when batches go to a GPU.
# NOTE(review): assumes load_binary_mnist forwards these to torch DataLoader.
kwargs = {"num_workers": 4, "pin_memory": True} if use_gpu else {}
train_data, valid_data, test_data = load_binary_mnist(cfg, **kwargs)

best_valid_elbo = -np.inf
num_no_improvement = 0

# NOTE(review): `cycle` presumably repeats train_data indefinitely, which makes
# the early-stopping `break` below the only exit from this loop — confirm.
for step, batch in enumerate(cycle(train_data)):
    x = batch[0].to(device)
    model.zero_grad()
    variational.zero_grad()
    # Draw a single sample from the variational distribution per step.
    z, log_q_z = variational(x, n_samples=1)
    log_p_x_and_z = model(z, x)
    # average over sample dimension
    elbo = (log_p_x_and_z - log_q_z).mean(1)
    # sum over batch dimension; maximising the ELBO == minimising its negative
    loss = -elbo.sum(0)
    loss.backward()
    optimizer.step()

    if step % cfg.log_interval == 0:
        print(f"step:\t{step}\ttrain elbo: {elbo.detach().mean().item():.2f}")
        with torch.no_grad():
            valid_elbo, valid_log_p_x = evaluate(
                cfg.n_samples, model, variational, valid_data
            )
        print(
            f"step:\t{step}\t\tvalid elbo: {valid_elbo:.2f}\tvalid log p(x): {valid_log_p_x:.2f}"
        )
        if valid_elbo > best_valid_elbo:
            # New best on validation: reset patience and checkpoint both networks.
            num_no_improvement = 0
            best_valid_elbo = valid_elbo
            states = {
                "model": model.state_dict(),
                "variational": variational.state_dict(),
            }
            torch.save(states, best_state_path)
        else:
            num_no_improvement += 1

        if num_no_improvement > cfg.early_stopping_interval:
            # Patience exhausted: restore the best checkpoint and score it on test.
            checkpoint = torch.load(best_state_path, map_location=device)
            model.load_state_dict(checkpoint["model"])
            variational.load_state_dict(checkpoint["variational"])
            with torch.no_grad():
                test_elbo, test_log_p_x = evaluate(
                    cfg.n_samples, model, variational, test_data
                )
            print(
                f"step:\t{step}\t\ttest elbo: {test_elbo:.2f}\ttest log p(x): {test_log_p_x:.2f}"
            )
            break
