# Load dataset

In [None]:
import os, sys
import torch.nn as nn
import torch
# Make the sibling "../VAUB-gp" directory importable. Presumably this is where
# synthetic_exp_util (imported below) lives — confirm. This must run before that import.
added_path = os.path.join(os.path.abspath(".."), "VAUB-gp")
if added_path not in sys.path:  # avoid duplicate entries when the cell is re-run
    sys.path.append(added_path)
import torch.optim as optim
from itertools import chain
from tqdm import tqdm
# Project-local helpers used by the training function and the experiment cells below.
from synthetic_exp_util import get_dataloader, calculate_auroc, vae_loss, calculate_gp_loss, Score_fn, UNet
import matplotlib.pyplot as plt



In [None]:
class VAE(nn.Module):
    """Fully connected VAE whose posterior mean is batch-normalised.

    Both the encoder and the decoder are Linear/SiLU stacks with widths
    ``in -> hidden_dim -> 2*hidden_dim -> hidden_dim -> out``. The encoder emits
    ``2 * latent_dim`` values, which are split into a mean and a log-variance.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        # The encoder head emits mean and log-variance side by side.
        self.encoder = self._mlp(input_dim, hidden_dim, 2 * latent_dim)
        self.decoder = self._mlp(latent_dim, hidden_dim, input_dim)
        # BatchNorm applied to the posterior mean only.
        self.bn_mean = nn.BatchNorm1d(latent_dim)

    @staticmethod
    def _mlp(in_features, width, out_features):
        """Build Linear(in, w), SiLU, Linear(w, 2w), SiLU, Linear(2w, w), SiLU, Linear(w, out)."""
        sizes = [in_features, width, 2 * width, width, out_features]
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            if idx:
                layers.append(nn.SiLU())  # activation between consecutive Linear layers
            layers.append(nn.Linear(fan_in, fan_out))
        return nn.Sequential(*layers)

    def encode(self, x):
        """Return ``(mean, logvar)``; mean goes through ``bn_mean``, logvar is capped at 4."""
        mean, logvar = torch.chunk(self.encoder(x), 2, dim=-1)
        return self.bn_mean(mean), logvar.clamp(max=4)

    def reparameterize(self, mean, logvar):
        """Draw ``mean + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``."""
        std = (0.5 * logvar).exp()
        noise = torch.randn_like(std)
        return mean + noise * std

    def decode(self, z):
        """Map latent codes back to input space."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample and decode; returns ``(reconstruction, z, mean, logvar)``."""
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        recon = self.decode(z)
        return recon, z, mean, logvar

    def init_weights(self, scale=0.1):
        """Re-initialise every Linear layer: weights ~ U(-scale, scale), biases = 0."""
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            nn.init.uniform_(module.weight, a=-scale, b=scale)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    
def _plot_training_snapshot(recon_x1_z2, recon_x1, z, labels, max_points=200):
    """Show a 3-panel diagnostic figure for the current score-phase batch.

    Panels: (0) domain-2 latents decoded by ``vae1``, (1) ``vae1`` reconstructions of
    the domain-1 batch, (2) latents of both domains split by label value 0/1.

    Args:
        recon_x1_z2: ``vae1.decode(z2)`` for the domain-2 batch.
        recon_x1: ``vae1`` reconstruction of the domain-1 batch.
        z: domain-1 latents stacked above domain-2 latents.
        labels: labels stacked in the same order as ``z`` (flattened before use).
        max_points: cap on plotted points per (domain, label) group.
    """
    import seaborn as sns

    sns.set(style='white')
    fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))

    translated = recon_x1_z2.detach().cpu()
    ax[0].scatter(translated[:, 0], translated[:, 1], marker='.', color='darkblue', s=30)

    reconstructed = recon_x1.detach().cpu()
    ax[1].scatter(reconstructed[:, 0], reconstructed[:, 1], marker='.', color='darkgreen', s=30)

    n1 = reconstructed.shape[0]
    n_samples = min(n1, max_points)
    latents = z.detach().cpu()
    flat_labels = labels.view((-1,))
    # Split at the domain-1 batch size (not chunk(2)) so unequal domain batches stay aligned.
    data1, data2 = latents[:n1], latents[n1:]
    labels1, labels2 = flat_labels[:n1], flat_labels[n1:]

    data1_l1, data1_l2 = data1[labels1 == 0], data1[labels1 == 1]
    data2_l1, data2_l2 = data2[labels2 == 0], data2[labels2 == 1]

    ax[2].scatter(data1_l1[:n_samples, 0], data1_l1[:n_samples, 1], marker='+', label='D1_L1', c='b', s=40)
    ax[2].scatter(data1_l2[:n_samples, 0], data1_l2[:n_samples, 1], marker='o', label='D1_L2', c='b', s=40, edgecolors='k')
    ax[2].scatter(data2_l1[:n_samples, 0], data2_l1[:n_samples, 1], marker='+', label='D2_L1', c='g', s=40)
    ax[2].scatter(data2_l2[:n_samples, 0], data2_l2[:n_samples, 1], marker='o', label='D2_L2', c='g', s=40, edgecolors='k')
    ax[2].legend(fontsize=10, loc='upper right')

    plt.tight_layout()
    plt.show()


def _plot_loss_curves(total_loss_list, recon_loss_list, kl_loss_list, gp_loss_list):
    """Plot the per-step loss histories recorded by ``train_vaub_gp`` in four panels."""
    fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(12, 4))
    series = (
        (total_loss_list, 'total loss'),
        (recon_loss_list, 'recon list'),
        (kl_loss_list, 'kl list'),
        (gp_loss_list, 'gp list'),
    )
    for axis, (values, title) in zip(axs, series):
        axis.plot(values, label=title)
        axis.set_title(title)
    plt.tight_layout()
    plt.show()


def train_vaub_gp(mode, device, is_vanilla, input_dim, latent_dim, alpha, loops, hidden_dim, timesteps, sigma_max,
                  sigma_min, lr_vae, lr_score, beta, gp_lambda, num_epochs, dataloader_domain1, dataloader_domain2,
                  dataloader_score1, dataloader_score2, num_visual=1, num_log=1, plot=True):
    """Jointly train one VAE per domain plus a shared score model on their latents.

    Each VAE step runs both VAEs on a (domain1, domain2) batch and adds the GP term
    ``gp_lambda * calculate_gp_loss``. It then takes one Adam step on both VAEs.
    Next, ``loops`` score-model updates are run on latents of a batch drawn from
    the score loaders.

    Args:
        mode: ``'Gaussian'`` passes ``score=None`` to ``vae_loss``. Any other value
            passes the score-model term.
        num_visual: number of evenly spaced epochs at which to show diagnostic plots
            (0 disables them).
        num_log: number of evenly spaced epochs at which to print epoch totals
            (0 disables logging).
        plot: if True, plot the per-step loss histories after training.
        (Remaining arguments are forwarded to ``VAE``, ``Score_fn``/``UNet``,
        the optimizers and ``vae_loss``.)

    Returns:
        ``((vae1, vae2), score_model, z, labels)``. ``z``/``labels`` come from the last
        batch processed: the last score-phase batch when ``loops > 0``, otherwise
        the last training batch. ``z`` is sampled (reparameterised) in train mode.

    Raises:
        ValueError: if no training batch was processed (e.g. ``num_epochs == 0`` or
            empty dataloaders).
    """
    vae1 = VAE(input_dim, hidden_dim, latent_dim).to(device)
    vae2 = VAE(input_dim, hidden_dim, latent_dim).to(device)
    # The score model is evaluated on z, so its input/output width must equal latent_dim.
    score_model = Score_fn(UNet(in_dim=latent_dim, out_dim=latent_dim, num_timesteps=timesteps, is_warm_init=False),
                           sigma_min=sigma_min, sigma_max=sigma_max, num_timesteps=timesteps,
                           device=device).to(device)
    optimizer_vae = optim.Adam(chain(vae1.parameters(), vae2.parameters()), lr=lr_vae)
    optimizer_score = optim.Adam(score_model.parameters(), lr=lr_score)

    # Cadences; max(1, ...) avoids a zero modulus when num_visual/num_log exceed num_epochs.
    visual_every = max(1, num_epochs // num_visual) if num_visual > 0 else None
    log_every = max(1, num_epochs // num_log) if num_log > 0 else None

    total_loss_list = []
    recon_loss_list = []
    kl_loss_list = []
    gp_loss_list = []
    z = labels = train_labels = None

    for epoch in tqdm(range(num_epochs)):
        vae1.train()
        vae2.train()
        total_loss = 0
        total_recon_loss = 0
        total_kld_encoder_posterior = 0
        total_kld_prior = 0
        total_gp_loss = 0

        for i, (data1, data2) in enumerate(zip(dataloader_domain1, dataloader_domain2)):
            # ---- VAE step ----
            x1, label1 = data1
            x2, label2 = data2
            x1, x2 = x1.to(device), x2.to(device)
            train_labels = (label1, label2)
            optimizer_vae.zero_grad()

            recon_x1, z1, mean1, logvar1 = vae1(x1)
            recon_x2, z2, mean2, logvar2 = vae2(x2)
            x = torch.vstack((x1, x2))
            recon_x = torch.vstack((recon_x1, recon_x2))
            z = torch.vstack((z1, z2))
            mean = torch.vstack((mean1, mean2))
            logvar = torch.vstack((logvar1, logvar2))

            # Score term: score-model output at timestep index 5 minus 0.05*z, dotted with z
            # and summed. It is only consumed by vae_loss when mode != 'Gaussian'.
            t_index = 5 * torch.ones(z.shape[0], device=device).type(torch.long)
            score = score_model.get_mixing_score_fn(z, t_index, detach=True, is_residual=True,
                                                    is_vanilla=is_vanilla, alpha=alpha) - 0.05 * z
            score = torch.matmul(score.unsqueeze(1), z.unsqueeze(-1)).sum()

            loss, recon_loss, kld_encoder_posterior, kld_prior = vae_loss(
                recon_x, x, mean, logvar, beta,
                score=None if mode == 'Gaussian' else score, DSM=None)

            gp_loss = gp_lambda * calculate_gp_loss([x1, x2], [z1, z2])
            gp_value = gp_loss.item()
            gp_loss_list.append(gp_value)

            loss += gp_loss
            loss_value = loss.item()
            recon_value = recon_loss.item()
            total_loss_list.append(loss_value)
            recon_loss_list.append(recon_value)
            kl_loss_list.append((kld_encoder_posterior + kld_prior).item())

            loss.backward()
            optimizer_vae.step()

            total_loss += loss_value
            total_recon_loss += recon_value
            total_kld_encoder_posterior += kld_encoder_posterior.item()
            total_kld_prior += kld_prior.item()
            total_gp_loss += gp_value

            # ---- Score-model updates ----
            update_kwargs = dict(optimizer=optimizer_score, max_timestep=None, is_mixing=True,
                                 is_residual=True, is_vanilla=is_vanilla, alpha=alpha)
            for loop in range(loops):
                # NOTE(review): iter() builds a fresh iterator each pass, so this always takes the
                # first batch of the score loaders (a reshuffled one if they shuffle) — confirm intent.
                data1, data2 = next(iter(zip(dataloader_score1, dataloader_score2)))
                x1, label1 = data1
                x2, label2 = data2
                x1, x2 = x1.to(device), x2.to(device)
                recon_x1, z1, _, _ = vae1(x1)
                _, z2, _, _ = vae2(x2)
                z = torch.vstack((z1, z2))
                labels = torch.vstack((label1, label2))

                show = (visual_every is not None and loop == loops - 1 and i == 0
                        and (epoch + 1) % visual_every == 0)
                if show:
                    _plot_training_snapshot(vae1.decode(z2), recon_x1, z, labels)
                    score_model.update_score_fn(z, verbose=True, **update_kwargs)
                else:
                    score_model.update_score_fn(z, **update_kwargs)

        if log_every is not None and (epoch + 1) % log_every == 0:
            print(f'Epoch {epoch+1}, Total Loss: {total_loss:.2f}, Recon Loss: {total_recon_loss:.2f}, '
                  f'Encoder Posterior Loss: {total_kld_encoder_posterior:.2f}, Prior Loss: {total_kld_prior:.2f}, '
                  f'Gp loss: {total_gp_loss:.2f}')

    if plot:
        _plot_loss_curves(total_loss_list, recon_loss_list, kl_loss_list, gp_loss_list)

    if z is None:
        raise ValueError("train_vaub_gp processed no batches; check num_epochs and the dataloaders")
    if labels is None:
        # loops == 0: no score-phase batch was drawn, so pair the last training latents
        # with that batch's labels instead of failing on an unbound name.
        labels = torch.vstack(train_labels)
    return (vae1, vae2), score_model, z, labels

# num_samples = 50

In [None]:
# ---- n_points = 50, score-based prior ----
# Data loaders (batch_size equals n_points).
n_points = batch_size = 50
(dataloader_domain1, dataloader_domain2,
 dataloader_score1, dataloader_score2) = get_dataloader(n_points=n_points, batch_size=batch_size, plot=False)

# Device and model sizes.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
is_vanilla, alpha = True, None
input_dim = latent_dim = 2
hidden_dim = 24  # VAE MLP width

# Score-model settings (loops = score updates per VAE step).
loops, timesteps = 1, 200
sigma_max, sigma_min = 0.4, 0.01

# Optimisation settings.
lr_vae, lr_score = 1e-2, 2e-3
beta, gp_lambda, num_epochs = 0.25, 500, 500

# Train once, then score the returned latents with calculate_auroc.
for _ in range(1):
    (vae1, vae2), score_model, z, labels = train_vaub_gp(
        mode="score", device=device, is_vanilla=is_vanilla, input_dim=input_dim, latent_dim=latent_dim,
        alpha=alpha, loops=loops, hidden_dim=hidden_dim, timesteps=timesteps,
        sigma_max=sigma_max, sigma_min=sigma_min, lr_vae=lr_vae, lr_score=lr_score,
        beta=beta, gp_lambda=gp_lambda, num_epochs=num_epochs,
        dataloader_domain1=dataloader_domain1, dataloader_domain2=dataloader_domain2,
        dataloader_score1=dataloader_score1, dataloader_score2=dataloader_score2,
        plot=True,
    )
    auc_score = calculate_auroc(z, labels)
    print(f"{auc_score:.2f}")

In [None]:
# ---- n_points = 50, Gaussian prior (score=None in vae_loss) ----
# Device and model sizes.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
is_vanilla, alpha = True, None
input_dim = latent_dim = 2
hidden_dim = 24  # VAE MLP width

# Score-model settings (loops = score updates per VAE step).
loops, timesteps = 1, 200
sigma_max, sigma_min = 0.4, 0.01

# Optimisation settings.
lr_vae, lr_score = 5e-2, 2e-3
beta, gp_lambda, num_epochs = 2, 0.5, 1000

# Data loaders (batch_size equals n_points).
n_points = batch_size = 50
(dataloader_domain1, dataloader_domain2,
 dataloader_score1, dataloader_score2) = get_dataloader(n_points=n_points, batch_size=batch_size, plot=False)

# Train once, then score the returned latents with calculate_auroc.
for _ in range(1):
    (vae1, vae2), score_model, z, labels = train_vaub_gp(
        mode="Gaussian", device=device, is_vanilla=is_vanilla, input_dim=input_dim, latent_dim=latent_dim,
        alpha=alpha, loops=loops, hidden_dim=hidden_dim, timesteps=timesteps,
        sigma_max=sigma_max, sigma_min=sigma_min, lr_vae=lr_vae, lr_score=lr_score,
        beta=beta, gp_lambda=gp_lambda, num_epochs=num_epochs,
        dataloader_domain1=dataloader_domain1, dataloader_domain2=dataloader_domain2,
        dataloader_score1=dataloader_score1, dataloader_score2=dataloader_score2,
        plot=True,
    )
    auc_score = calculate_auroc(z, labels)
    print(f"{auc_score:.2f}")

# num_samples = 100

In [None]:
# ---- n_points = 100, score-based prior ----
# Data loaders (batch_size equals n_points).
n_points = batch_size = 100
(dataloader_domain1, dataloader_domain2,
 dataloader_score1, dataloader_score2) = get_dataloader(n_points=n_points, batch_size=batch_size, plot=False)

# Device and model sizes.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
is_vanilla, alpha = True, None
input_dim = latent_dim = 2
hidden_dim = 24  # VAE MLP width

# Score-model settings (loops = score updates per VAE step).
loops, timesteps = 1, 200
sigma_max, sigma_min = 0.4, 0.01

# Optimisation settings.
lr_vae, lr_score = 1e-2, 2e-3
beta, gp_lambda, num_epochs = 0.5, 500, 500

# Train once, then score the returned latents with calculate_auroc.
for _ in range(1):
    (vae1, vae2), score_model, z, labels = train_vaub_gp(
        mode="score", device=device, is_vanilla=is_vanilla, input_dim=input_dim, latent_dim=latent_dim,
        alpha=alpha, loops=loops, hidden_dim=hidden_dim, timesteps=timesteps,
        sigma_max=sigma_max, sigma_min=sigma_min, lr_vae=lr_vae, lr_score=lr_score,
        beta=beta, gp_lambda=gp_lambda, num_epochs=num_epochs,
        dataloader_domain1=dataloader_domain1, dataloader_domain2=dataloader_domain2,
        dataloader_score1=dataloader_score1, dataloader_score2=dataloader_score2,
        plot=True,
    )
    auc_score = calculate_auroc(z, labels)
    print(f"{auc_score:.2f}")

In [None]:
# ---- n_points = 100, Gaussian prior (score=None in vae_loss) ----
# Device and model sizes.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
is_vanilla, alpha = True, None
input_dim = latent_dim = 2
hidden_dim = 24  # VAE MLP width

# Score-model settings (loops = score updates per VAE step).
loops, timesteps = 1, 200
sigma_max, sigma_min = 0.4, 0.01

# Optimisation settings.
lr_vae, lr_score = 5e-2, 2e-3
beta, gp_lambda, num_epochs = 2, 0.5, 500

# Data loaders (batch_size equals n_points).
n_points = batch_size = 100
(dataloader_domain1, dataloader_domain2,
 dataloader_score1, dataloader_score2) = get_dataloader(n_points=n_points, batch_size=batch_size, plot=False)

# Train once, then score the returned latents with calculate_auroc.
for _ in range(1):
    (vae1, vae2), score_model, z, labels = train_vaub_gp(
        mode="Gaussian", device=device, is_vanilla=is_vanilla, input_dim=input_dim, latent_dim=latent_dim,
        alpha=alpha, loops=loops, hidden_dim=hidden_dim, timesteps=timesteps,
        sigma_max=sigma_max, sigma_min=sigma_min, lr_vae=lr_vae, lr_score=lr_score,
        beta=beta, gp_lambda=gp_lambda, num_epochs=num_epochs,
        dataloader_domain1=dataloader_domain1, dataloader_domain2=dataloader_domain2,
        dataloader_score1=dataloader_score1, dataloader_score2=dataloader_score2,
        plot=True,
    )
    auc_score = calculate_auroc(z, labels)
    print(f"{auc_score:.2f}")

# num_samples = 200

In [None]:
# ---- n_points = 200, score-based prior ----
# Data loaders (batch_size equals n_points).
n_points = batch_size = 200
(dataloader_domain1, dataloader_domain2,
 dataloader_score1, dataloader_score2) = get_dataloader(n_points=n_points, batch_size=batch_size, plot=False)

# Device and model sizes.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
is_vanilla, alpha = True, None
input_dim = latent_dim = 2
hidden_dim = 24  # VAE MLP width

# Score-model settings (loops = score updates per VAE step).
loops, timesteps = 1, 200
sigma_max, sigma_min = 0.4, 0.01

# Optimisation settings.
lr_vae, lr_score = 1e-2, 2e-3
beta, gp_lambda, num_epochs = 0.5, 200, 500

# Train once, then score the returned latents with calculate_auroc.
for _ in range(1):
    (vae1, vae2), score_model, z, labels = train_vaub_gp(
        mode="score", device=device, is_vanilla=is_vanilla, input_dim=input_dim, latent_dim=latent_dim,
        alpha=alpha, loops=loops, hidden_dim=hidden_dim, timesteps=timesteps,
        sigma_max=sigma_max, sigma_min=sigma_min, lr_vae=lr_vae, lr_score=lr_score,
        beta=beta, gp_lambda=gp_lambda, num_epochs=num_epochs,
        dataloader_domain1=dataloader_domain1, dataloader_domain2=dataloader_domain2,
        dataloader_score1=dataloader_score1, dataloader_score2=dataloader_score2,
        plot=True,
    )
    auc_score = calculate_auroc(z, labels)
    print(f"{auc_score:.2f}")

In [None]:
# ---- n_points = 200, Gaussian prior (score=None in vae_loss), 5 repeated runs ----
# Device and model sizes.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
is_vanilla, alpha = True, None
input_dim = latent_dim = 2
hidden_dim = 24  # VAE MLP width

# Score-model settings (loops = score updates per VAE step).
loops, timesteps = 1, 200
sigma_max, sigma_min = 0.4, 0.01

# Optimisation settings.
lr_vae, lr_score = 1e-2, 2e-3
beta, gp_lambda, num_epochs = 1, 0.5, 500

# Data loaders (batch_size equals n_points).
n_points = batch_size = 200
(dataloader_domain1, dataloader_domain2,
 dataloader_score1, dataloader_score2) = get_dataloader(n_points=n_points, batch_size=batch_size, plot=False)

# Train five independent models; print each run's calculate_auroc score.
for _ in range(5):
    (vae1, vae2), score_model, z, labels = train_vaub_gp(
        mode="Gaussian", device=device, is_vanilla=is_vanilla, input_dim=input_dim, latent_dim=latent_dim,
        alpha=alpha, loops=loops, hidden_dim=hidden_dim, timesteps=timesteps,
        sigma_max=sigma_max, sigma_min=sigma_min, lr_vae=lr_vae, lr_score=lr_score,
        beta=beta, gp_lambda=gp_lambda, num_epochs=num_epochs,
        dataloader_domain1=dataloader_domain1, dataloader_domain2=dataloader_domain2,
        dataloader_score1=dataloader_score1, dataloader_score2=dataloader_score2,
        plot=False,
    )
    auc_score = calculate_auroc(z, labels)
    print(f"{auc_score:.2f}")

# num_samples = 500

In [None]:
# ---- n_points = 500, score-based prior ----
# Data loaders (batch_size equals n_points).
n_points = batch_size = 500
(dataloader_domain1, dataloader_domain2,
 dataloader_score1, dataloader_score2) = get_dataloader(n_points=n_points, batch_size=batch_size, plot=False)

# Device and model sizes.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
is_vanilla, alpha = True, None
input_dim = latent_dim = 2
hidden_dim = 24  # VAE MLP width

# Score-model settings (loops = score updates per VAE step).
loops, timesteps = 1, 200
sigma_max, sigma_min = 0.4, 0.01

# Optimisation settings.
lr_vae, lr_score = 1e-2, 2e-3
beta, gp_lambda, num_epochs = 0.25, 200, 500

# Train once, then score the returned latents with calculate_auroc.
for _ in range(1):
    (vae1, vae2), score_model, z, labels = train_vaub_gp(
        mode="score", device=device, is_vanilla=is_vanilla, input_dim=input_dim, latent_dim=latent_dim,
        alpha=alpha, loops=loops, hidden_dim=hidden_dim, timesteps=timesteps,
        sigma_max=sigma_max, sigma_min=sigma_min, lr_vae=lr_vae, lr_score=lr_score,
        beta=beta, gp_lambda=gp_lambda, num_epochs=num_epochs,
        dataloader_domain1=dataloader_domain1, dataloader_domain2=dataloader_domain2,
        dataloader_score1=dataloader_score1, dataloader_score2=dataloader_score2,
        plot=True,
    )
    auc_score = calculate_auroc(z, labels)
    print(f"{auc_score:.2f}")

In [None]:
# ---- n_points = 500, Gaussian prior (score=None in vae_loss) ----
# Device and model sizes.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
is_vanilla, alpha = True, None
input_dim = latent_dim = 2
hidden_dim = 24  # VAE MLP width

# Score-model settings (loops = score updates per VAE step).
loops, timesteps = 1, 200
sigma_max, sigma_min = 0.4, 0.01

# Optimisation settings.
lr_vae, lr_score = 1e-2, 2e-3
beta, gp_lambda, num_epochs = 0.4, 0.5, 500

# Data loaders (batch_size equals n_points).
n_points = batch_size = 500
(dataloader_domain1, dataloader_domain2,
 dataloader_score1, dataloader_score2) = get_dataloader(n_points=n_points, batch_size=batch_size, plot=False)

# Train once (no loss-curve plot), then score the returned latents with calculate_auroc.
for _ in range(1):
    (vae1, vae2), score_model, z, labels = train_vaub_gp(
        mode="Gaussian", device=device, is_vanilla=is_vanilla, input_dim=input_dim, latent_dim=latent_dim,
        alpha=alpha, loops=loops, hidden_dim=hidden_dim, timesteps=timesteps,
        sigma_max=sigma_max, sigma_min=sigma_min, lr_vae=lr_vae, lr_score=lr_score,
        beta=beta, gp_lambda=gp_lambda, num_epochs=num_epochs,
        dataloader_domain1=dataloader_domain1, dataloader_domain2=dataloader_domain2,
        dataloader_score1=dataloader_score1, dataloader_score2=dataloader_score2,
        plot=False,
    )
    auc_score = calculate_auroc(z, labels)
    print(f"{auc_score:.2f}")

# num_samples = 20

In [None]:
# ---- n_points = 20, score-based prior ----
# Data loaders (batch_size equals n_points).
n_points = batch_size = 20
(dataloader_domain1, dataloader_domain2,
 dataloader_score1, dataloader_score2) = get_dataloader(n_points=n_points, batch_size=batch_size, plot=False)

# Device and model sizes.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
is_vanilla, alpha = True, None
input_dim = latent_dim = 2
hidden_dim = 24  # VAE MLP width

# Score-model settings (loops = score updates per VAE step).
loops, timesteps = 1, 200
sigma_max, sigma_min = 0.4, 0.01

# Optimisation settings.
lr_vae, lr_score = 1e-2, 2e-3
beta, gp_lambda, num_epochs = 0.25, 500, 500

# Train once, then score the returned latents with calculate_auroc.
for _ in range(1):
    (vae1, vae2), score_model, z, labels = train_vaub_gp(
        mode="score", device=device, is_vanilla=is_vanilla, input_dim=input_dim, latent_dim=latent_dim,
        alpha=alpha, loops=loops, hidden_dim=hidden_dim, timesteps=timesteps,
        sigma_max=sigma_max, sigma_min=sigma_min, lr_vae=lr_vae, lr_score=lr_score,
        beta=beta, gp_lambda=gp_lambda, num_epochs=num_epochs,
        dataloader_domain1=dataloader_domain1, dataloader_domain2=dataloader_domain2,
        dataloader_score1=dataloader_score1, dataloader_score2=dataloader_score2,
        plot=True,
    )
    auc_score = calculate_auroc(z, labels)
    print(f"{auc_score:.2f}")

In [None]:
# ---- n_points = 20, Gaussian prior (score=None in vae_loss) ----
# Device and model sizes.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
is_vanilla, alpha = True, None
input_dim = latent_dim = 2
hidden_dim = 24  # VAE MLP width

# Score-model settings (loops = score updates per VAE step).
loops, timesteps = 1, 200
sigma_max, sigma_min = 0.4, 0.01

# Optimisation settings.
lr_vae, lr_score = 1e-2, 2e-3
beta, gp_lambda, num_epochs = 0.5, 0.5, 500

# Data loaders (batch_size equals n_points).
n_points = batch_size = 20
(dataloader_domain1, dataloader_domain2,
 dataloader_score1, dataloader_score2) = get_dataloader(n_points=n_points, batch_size=batch_size, plot=False)

# Train once, then score the returned latents with calculate_auroc.
for _ in range(1):
    (vae1, vae2), score_model, z, labels = train_vaub_gp(
        mode="Gaussian", device=device, is_vanilla=is_vanilla, input_dim=input_dim, latent_dim=latent_dim,
        alpha=alpha, loops=loops, hidden_dim=hidden_dim, timesteps=timesteps,
        sigma_max=sigma_max, sigma_min=sigma_min, lr_vae=lr_vae, lr_score=lr_score,
        beta=beta, gp_lambda=gp_lambda, num_epochs=num_epochs,
        dataloader_domain1=dataloader_domain1, dataloader_domain2=dataloader_domain2,
        dataloader_score1=dataloader_score1, dataloader_score2=dataloader_score2,
        plot=True,
    )
    auc_score = calculate_auroc(z, labels)
    print(f"{auc_score:.2f}")

# num_samples = 1000

In [None]:
# ---- n_points = 1000, score-based prior ----
# Data loaders (batch_size equals n_points).
n_points = batch_size = 1000
(dataloader_domain1, dataloader_domain2,
 dataloader_score1, dataloader_score2) = get_dataloader(n_points=n_points, batch_size=batch_size, plot=False)

# Device and model sizes.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
is_vanilla, alpha = True, None
input_dim = latent_dim = 2
hidden_dim = 24  # VAE MLP width

# Score-model settings (loops = score updates per VAE step).
loops, timesteps = 1, 200
sigma_max, sigma_min = 0.4, 0.01

# Optimisation settings.
lr_vae, lr_score = 1e-2, 2e-3
beta, gp_lambda, num_epochs = 0.5, 200, 1000

# Train once, then score the returned latents with calculate_auroc.
for _ in range(1):
    (vae1, vae2), score_model, z, labels = train_vaub_gp(
        mode="score", device=device, is_vanilla=is_vanilla, input_dim=input_dim, latent_dim=latent_dim,
        alpha=alpha, loops=loops, hidden_dim=hidden_dim, timesteps=timesteps,
        sigma_max=sigma_max, sigma_min=sigma_min, lr_vae=lr_vae, lr_score=lr_score,
        beta=beta, gp_lambda=gp_lambda, num_epochs=num_epochs,
        dataloader_domain1=dataloader_domain1, dataloader_domain2=dataloader_domain2,
        dataloader_score1=dataloader_score1, dataloader_score2=dataloader_score2,
        plot=True,
    )
    auc_score = calculate_auroc(z, labels)
    print(f"{auc_score:.2f}")

In [None]:
# ---- n_points = 1000, Gaussian prior (score=None in vae_loss) ----
# Device and model sizes.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
is_vanilla, alpha = True, None
input_dim = latent_dim = 2
hidden_dim = 24  # VAE MLP width

# Score-model settings (loops = score updates per VAE step).
loops, timesteps = 1, 200
sigma_max, sigma_min = 0.4, 0.01

# Optimisation settings.
lr_vae, lr_score = 1e-2, 2e-3
beta, gp_lambda, num_epochs = 0.8, 0.5, 500

# Data loaders (batch_size equals n_points).
n_points = batch_size = 1000
(dataloader_domain1, dataloader_domain2,
 dataloader_score1, dataloader_score2) = get_dataloader(n_points=n_points, batch_size=batch_size, plot=False)

# Train once, then score the returned latents with calculate_auroc.
for _ in range(1):
    (vae1, vae2), score_model, z, labels = train_vaub_gp(
        mode="Gaussian", device=device, is_vanilla=is_vanilla, input_dim=input_dim, latent_dim=latent_dim,
        alpha=alpha, loops=loops, hidden_dim=hidden_dim, timesteps=timesteps,
        sigma_max=sigma_max, sigma_min=sigma_min, lr_vae=lr_vae, lr_score=lr_score,
        beta=beta, gp_lambda=gp_lambda, num_epochs=num_epochs,
        dataloader_domain1=dataloader_domain1, dataloader_domain2=dataloader_domain2,
        dataloader_score1=dataloader_score1, dataloader_score2=dataloader_score2,
        plot=True,
    )
    auc_score = calculate_auroc(z, labels)
    print(f"{auc_score:.2f}")