**Note**
- Make sure the working directory of your Jupyter notebook is set to the project root directory (NOT the `notebooks` folder) to avoid import errors.

In [1]:
#@title Create the folders
%%capture
JUPYTER_MODE = True # False if running scripts on cluster
import os
folder_paths = ['cache/1D', 'out', 'src'] # create folders if they don't exist yet
for path in folder_paths:
    os.makedirs(path, exist_ok=True)
%pip install git+https://github.com/hosford42/EMNIST.git # Original EMNIST dataset is corrupted; this is an alternative download

In [2]:
#@title Import Libraries
import warnings; warnings.filterwarnings("ignore")
from tqdm.notebook import trange; from itertools import cycle
import numpy as np; import scipy.stats as stats
import torch; import torch.nn as nn; import torch.nn.functional as F
import torch.distributions as D; from torch.optim import Adam, SGD
from torch.distributions import Gamma, Normal, Weibull, LogNormal, MixtureSameFamily
from torch.utils.data import Subset, Dataset, DataLoader, TensorDataset
import numpy as np; import matplotlib.pyplot as plt; import functools
from src.model_1D import GaussianFourierProjection, Dense, ScoreNet, FusionNet, \
    loss_fn, fusion_loss_fn, find_test_loss, generic_train, exp_train
from src.samplers import marginal_prob_std, diffusion_coeff, Euler_Maruyama_sampler_1D, ode_sampler_1D
from src.datasets import make_data, make_mixture_data, Dataset_1D
from src.training import EarlyStopper, ExponentialMovingAverage
from src.plotter import plot_loss
# set_up_backend("torch", data_type="float32") # enable GPU support and set the floating point precision
# Random seed for reproducibility; applied to both the CPU and the CUDA generators.
seed = 515
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

## For matplotlib final printing
# The print-ready style is assembled from small themed groups and applied in one call.
font_style = {
    'font.size': 10,
    'font.family': 'sans-serif',
    'font.sans-serif': ['DejaVu Sans'],
}
text_sizes = {
    'axes.titlesize': 12,
    'axes.labelsize': 10,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
}
line_style = {
    'lines.linewidth': 1.5,
    'lines.markersize': 6,
}
resolution = {
    'figure.dpi': 100,   # on-screen rendering
    'savefig.dpi': 300,  # files written by savefig
}
plt.rcParams.update({**font_style, **text_sizes, **line_style, **resolution})

# Build Datasets

In [None]:
#@title Dataset parameters
TRAIN_SIZE = 512; VAL_SIZE = int(TRAIN_SIZE / 4)
MAX_TARGET_BATCH = 256 # upper bound on the target train/val batch size

sigma = 25.0 # forwarded to marginal_prob_std and diffusion_coeff below
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma, device = device)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma, device = device)
# Target distribution: two-component mixture described by means / stds / weights,
# together with the size and batch size of every data split.
mix_params = {
    'means': torch.tensor([-6., 6]),
    'stds': torch.tensor([0.8, 0.8]),
    'weights': torch.tensor([0.6, 0.4]),
    # A split smaller than MAX_TARGET_BATCH is served as one full batch.
    'train_size': TRAIN_SIZE, 'train_batch': min(TRAIN_SIZE, MAX_TARGET_BATCH),
    'val_size': VAL_SIZE, 'val_batch': min(VAL_SIZE, MAX_TARGET_BATCH),
    'test_size': 8096, 'test_batch': 1024
}

# Auxiliary distributions - two bimodal mixtures of Gaussians.
# Both share unit component stds and the same split / batch sizes; only the
# component means and the mixture weights differ.
_AUX_SPLIT_SIZES = {
    'train_size': 16384, 'train_batch': 4096,
    'val_size': 4096, 'val_batch': 2048,
    'test_size': 8096, 'test_batch': 1024,
}
_AUX_COMPONENTS = (
    ([-3., 10.], [0.5, 0.5]),
    ([-7., 9.], [0.7, 0.3]),
)
auxs_info = [
    {
        'means': torch.tensor(component_means),
        'stds': torch.tensor([1., 1.]),
        'weights': torch.tensor(component_weights),
        **_AUX_SPLIT_SIZES,
    }
    for component_means, component_weights in _AUX_COMPONENTS
]

VERBOSE = True # Print dataset specs
RUN_MODES = ['train', 'val', 'test']

def get_mixture_data(mode):
    """Return make_mixture_data(Normal, mix_params, mode) for the requested split."""
    return make_mixture_data(Normal, mix_params, mode)

def make_loader(dataset, batch_size):
    """Wrap a tensor in a TensorDataset and serve it from a shuffling DataLoader (no worker processes)."""
    return DataLoader(TensorDataset(dataset), batch_size,
                      shuffle = True, num_workers = 0)

# Load vanilla training data: one (data, mean, std) triple per split, drawn in RUN_MODES order
target_splits = [get_mixture_data(mode) for mode in RUN_MODES]
(train_tar_data, train_tar_mean, train_tar_std) = target_splits[0]
(val_tar_data, val_tar_mean, val_tar_std) = target_splits[1]
(test_tar_data, test_tar_mean, test_tar_std) = target_splits[2]

train_tar_loader = make_loader(train_tar_data, mix_params['train_batch'])
val_tar_loader = make_loader(val_tar_data, mix_params['val_batch'])
test_tar_loader = make_loader(test_tar_data, mix_params['test_batch'])

if VERBOSE:
    print('---target datasets specs---')
    train_shape = train_tar_loader.dataset.tensors[0].shape
    val_shape = val_tar_loader.dataset.tensors[0].shape
    test_shape = test_tar_loader.dataset.tensors[0].shape
    print('train data info: ', (train_shape, train_tar_mean, train_tar_std))
    print('val data info: ', (val_shape, val_tar_mean, val_tar_std))
    print('test data info:', (test_shape, test_tar_mean, test_tar_std))
    print('train val test batch sizes:',
          train_tar_loader.batch_size, val_tar_loader.batch_size, test_tar_loader.batch_size)

# Load auxiliary training data
# Each container maps a run mode to a list holding one entry per auxiliary distribution.
aux_datas = {mode: [] for mode in RUN_MODES}
aux_means = {mode: [] for mode in RUN_MODES}
aux_stds = {mode: [] for mode in RUN_MODES}
train_aux_loaders, val_aux_loaders, test_aux_loaders = [], [], []
print('\n','---aux datasets specs---')
for i, info in enumerate(auxs_info):
    print('\n',f'---auxiliary dataset set {i}---')
    # Draw the splits in RUN_MODES order and file each (data, mean, std) triple under its mode.
    for mode in RUN_MODES:
        split_data, split_mean, split_std = make_mixture_data(Normal, info, mode)
        aux_datas[mode].append(split_data)
        aux_means[mode].append(split_mean)
        aux_stds[mode].append(split_std)
    train_aux_loaders.append(make_loader(aux_datas['train'][i], info['train_batch']))
    val_aux_loaders.append(make_loader(aux_datas['val'][i], info['val_batch']))
    test_aux_loaders.append(make_loader(aux_datas['test'][i], info['test_batch']))
    if VERBOSE:
        report_rows = (
            ('train, ', 'train', train_aux_loaders[i]),
            ('val, ', 'val', val_aux_loaders[i]),
            ('test, ', 'test', test_aux_loaders[i]),
        )
        for label, mode, loader in report_rows:
            print(label, loader.dataset.tensors[0].shape,
                  aux_means[mode][i], aux_stds[mode][i], loader.batch_size)

# Training

## Baseline Training

In [None]:
#@title Baseline - directly train on target Gaussian mixtures data
# Hyperparameters handed to generic_train; the checkpoint file name carries the training-set size.
base_train_params = {
    'n_epochs': 2048,
    'lr': 1e-4,
    'patience': 100,
    'max_fraction': 1.0,
    'decay_points': [-1],
    'ckpt_path': f'cache/1D/base_1d_{TRAIN_SIZE}.pth',
}
base_training_outputs = generic_train(
    JUPYTER_MODE, train_tar_loader, val_tar_loader,
    base_train_params, ema_decay =0.999,
)
(base_score, base_train_losses, base_val_losses,
 base_ema_losses, base_last_saved_epoch) = base_training_outputs

In [None]:
#@title visualize baseline loss dynamics
# evaluate test set performance
EVAL_BASELINE_ON_TEST = True # toggle true or false to turn on evaluation
if EVAL_BASELINE_ON_TEST:
    find_test_loss(base_score, test_tar_loader, f'baseline {TRAIN_SIZE}')

# The plot offset is a tenth of the configured number of epochs.
loss_plot_offset = int(base_train_params['n_epochs'] / 10)
plot_loss(base_train_losses, base_val_losses, base_ema_losses,
          base_last_saved_epoch, offset = loss_plot_offset)
# save_path = f'out/loss_1d_{TRAIN_SIZE}.png'

## Auxiliary Training

In [5]:
#@title Training to approximate the auxiliary distributions
# number of epochs - increase if necessary
aux_train_params_list = [
    {
        'name': 'aux 0',
        'n_epochs': 1024, 'lr': 1e-4,
        'patience': 50, 'max_fraction': 0.5, 'decay_points': [-1],
        'ckpt_path': 'cache/1D/aux_1d_0.pth',
    },
    {
        'name': 'aux 1',
        'n_epochs': 1024, 'lr': 1e-4,
        'patience': 50, 'max_fraction': 0.5, 'decay_points': [-1],
        'ckpt_path': 'cache/1D/aux_1d_1.pth',
    },
]

# Per-model training history, appended in the order the models are trained.
# NOTE(review): a skipped model appends nothing, so list positions only match
# the model index when no earlier model was skipped.
aux_train_losses, aux_val_losses = [], []
aux_ema_losses, aux_last_saved_epochs = [], []
train_indices = [0,1,2] # [0,1,2]
# train score models for each component distribution (embarrassingly parallel)
for i, v_train_params in enumerate(aux_train_params_list):
    print(f'\n-----Component model {i}-----')
    if i in train_indices: # only train select component models
        aux_score, train_losses, val_losses, ema_losses, last_saved_epoch = \
            exp_train(JUPYTER_MODE, train_aux_loaders[i], val_aux_loaders[i], v_train_params)
        # append training information
        aux_train_losses.append(train_losses)
        aux_val_losses.append(val_losses)
        aux_ema_losses.append(ema_losses)
        aux_last_saved_epochs.append(last_saved_epoch)
    else:
        print('Skipped...')

In [None]:
#@title evaluate test performance & visualize auxiliary loss dynamics
# Rebuild every auxiliary score network and restore its weights from the checkpoint on disk.
aux_score_models = nn.ModuleList()
for aux_params in aux_train_params_list:
    score_net = ScoreNet(marginal_prob_std=marginal_prob_std_fn)
    score = torch.nn.DataParallel(score_net).to(device)
    score.load_state_dict(torch.load(aux_params['ckpt_path'], map_location=device))
    aux_score_models.append(score)

ID = 0 # index of the auxiliary model to inspect
model_info = f'aux {ID}'
VISUAL, CALC_TEST = True, True

if CALC_TEST:
    find_test_loss(aux_score_models[ID], test_aux_loaders[ID], text = model_info)
if VISUAL:
    plot_loss(aux_train_losses[ID],
              aux_val_losses[ID],
              aux_ema_losses[ID],
              aux_last_saved_epochs[ID],
              offset = 200, text = model_info)

## Fusion Training

In [139]:
# specify training parameters
fusion_train_params = {
    'n_epochs': 100,
    'optimizer': SGD,
    'lr': 1e-1,
    'momentum': 0.9,
    'ckpt_path': f'cache/1D/fusion_1d_{TRAIN_SIZE}.pth',
}

# Load the auxiliary scores / turn off their autograd
aux_scores = nn.ModuleList()
for i, _ in enumerate(auxs_info):
    ckpt = torch.load(f'cache/1D/aux_1d_{i}.pth', map_location=device)
    frozen_score = torch.nn.DataParallel(ScoreNet(marginal_prob_std=marginal_prob_std_fn))
    frozen_score.load_state_dict(ckpt)
    # Freeze every parameter of the restored auxiliary model.
    frozen_score.requires_grad_(False)
    aux_scores.add_module(f'aux {i}', frozen_score)

In [85]:
#@title ScoreFusion train function

# Tolerance hyperparameter for the alternative "save unless validation got much
# worse" checkpoint rule. That rule is currently disabled (see the note at the
# end of fusion_train); the constant is kept so it can be re-enabled.
MAX_FRACTION_OVER = 0.5

def fusion_train(jupyter_mode, train_loader, val_loader, aux_scores, train_params):
    """Train a FusionNet built from `aux_scores` and checkpoint the best model.

    Args:
        jupyter_mode: if truthy, iterate epochs with a `trange` progress bar and
            show the latest train / validation loss in its description;
            otherwise use a plain `range`.
        train_loader: iterable of batches; element 0 of each batch is the data tensor.
        val_loader: iterable of batches with the same layout. One validation
            batch is consumed per epoch, restarting the loader when exhausted.
        aux_scores: modules passed to `FusionNet(aux_scores)`.
        train_params: dict with keys 'n_epochs', 'optimizer' (an optimizer
            class), 'lr' and 'ckpt_path'. 'momentum' is optional and only used
            when the optimizer is SGD (defaults to 0.0, SGD's own default).

    Returns:
        (fusion_score, train_losses, val_losses): the trained model, the
        per-batch training losses, and the per-epoch validation losses.

    Side effects:
        Saves `fusion_score.state_dict()` to `ckpt_path` every time the
        validation loss improves on the best value seen so far.

    Relies on the notebook-level names `device`, `marginal_prob_std_fn`,
    `FusionNet`, `fusion_loss_fn` and `loss_fn`.
    """
    fusion_score = FusionNet(aux_scores).to(device)
    n_epochs = train_params['n_epochs']
    optimizer_cls = train_params['optimizer']
    lr = train_params['lr']
    ckpt_path = train_params['ckpt_path']

    tqdm_epoch = trange(n_epochs) if jupyter_mode else range(n_epochs)
    if optimizer_cls is SGD:
        # Momentum only applies to SGD; other optimizers need not supply it.
        optimizer = optimizer_cls(fusion_score.parameters(), lr=lr,
                                  momentum=train_params.get('momentum', 0.0))
    else:
        optimizer = optimizer_cls(fusion_score.parameters(), lr=lr)

    val_iter = iter(val_loader)
    min_val_loss = float('inf')
    train_losses, val_losses = [], []

    for epoch in tqdm_epoch:
        # Stays NaN only when train_loader yields no batch in this epoch.
        last_train_loss = float('nan')
        for x in train_loader:
            x = x[0].to(device)
            loss = fusion_loss_fn(fusion_score, x, marginal_prob_std_fn)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            last_train_loss = loss.item()
            train_losses.append(last_train_loss)

        with torch.no_grad():
            try:
                x_val = next(val_iter)
            except StopIteration: # Reinitialize the iterator and fetch the next batch
                val_iter = iter(val_loader)
                x_val = next(val_iter)
            x_val = x_val[0].to(device)
            # NOTE(review): validation uses `loss_fn` while training uses
            # `fusion_loss_fn` - confirm this asymmetry is intended.
            val_loss = loss_fn(fusion_score, x_val, marginal_prob_std_fn).item()
        val_losses.append(val_loss)

        if jupyter_mode:
            tqdm_epoch.set_description(
                f'train Loss: {last_train_loss:.4f}; val loss: {val_loss:.4f}')

        # Keep only the checkpoint with the lowest validation loss so far.
        # (Alternative rule, disabled: save whenever
        #  val_loss <= min_val_loss * (1 + MAX_FRACTION_OVER).)
        if val_loss < min_val_loss:
            min_val_loss = val_loss
            torch.save(fusion_score.state_dict(), ckpt_path)

    return fusion_score, train_losses, val_losses

In [None]:
#@title Train the fusion model (like constrained least squares on lambdas)
fusion_training_outputs = fusion_train(
    JUPYTER_MODE, train_tar_loader, val_tar_loader, aux_scores, fusion_train_params)
fusion_score, fusion_train_losses, fusion_val_losses = fusion_training_outputs
# Report the softmax of the parameter named 'lambdas_logits'.
lambda_logits = [p for n, p in fusion_score.named_parameters() if n == 'lambdas_logits']
for logits in lambda_logits:
    print(F.softmax(logits, dim = 0))

# Sampling

## Auxiliary Sampling

In [None]:
#@title Sampling from the auxiliary score models
VERBOSE = True
aux_samples = []
aux_score_models = nn.ModuleList()
sampler = Euler_Maruyama_sampler_1D
aux_train_stats = zip(auxs_info, aux_stds['train'], aux_means['train'])
for i, (info, std, mean) in enumerate(aux_train_stats):
    # Rebuild the i-th auxiliary score network and restore its checkpoint.
    aux_score = torch.nn.DataParallel(ScoreNet(marginal_prob_std=marginal_prob_std_fn)).to(device)
    aux_state = torch.load(f'cache/1D/aux_1d_{i}.pth', map_location=device)
    aux_score.load_state_dict(aux_state)
    aux_score_models.add_module(f'{i}', aux_score)

    sample_batch_size = info['test_size']
    raw_sample = sampler(aux_score,
                         marginal_prob_std_fn,
                         diffusion_coeff_fn,
                         sample_batch_size,
                         device=device, jupyter_mode = JUPYTER_MODE)
    # Remap back to the auxiliary's own data space with its training mean / std.
    # (Use train_tar_std / train_tar_mean here instead to map into the target space.)
    aux_sample = raw_sample.ravel().to('cpu') * std.to('cpu') + mean.to('cpu')
    aux_samples.append(aux_sample)
    if VERBOSE:
        print('\n', f'Sampling auxiliary model {i}...')
        # check the range of the auxiliary samples
        print(torch.min(aux_sample), torch.max(aux_sample))

In [5]:
#@title target & auxiliary distribution parameters
def make_mixture_dists(mixture_params, comp_prob: 'type[D.Distribution]'):
    """Build one MixtureSameFamily distribution per parameter dict.

    Args:
        mixture_params: iterable of dicts, each providing 'weights', 'means'
            and 'stds' entries.
        comp_prob: distribution class used for the mixture components; it is
            called as `comp_prob(means, stds)` (e.g. `Normal`).

    Returns:
        A list of `MixtureSameFamily` objects in the same order as
        `mixture_params`: a `Categorical` over the weights mixing the
        component distribution. Empty input gives an empty list.
    """
    return [
        MixtureSameFamily(D.Categorical(params['weights']),
                          comp_prob(params['means'], params['stds']))
        for params in mixture_params
    ]

In [None]:
#@title Plot auxiliary generated samples vs ground truth
ID = 0 # index of the auxiliary model to plot

# Undo the standardisation of the auxiliary test split using its stored test mean / std.
aux_test_mean = aux_means['test'][ID].cpu()
aux_test_std = aux_stds['test'][ID].cpu()
test_aux_data_np = test_aux_loaders[ID].dataset.tensors[0].cpu().ravel()
test_aux_data_np = (test_aux_data_np * aux_test_std + aux_test_mean).numpy()

wd_aux = stats.wasserstein_distance(test_aux_data_np, aux_samples[ID].numpy())
print(f"W1 distance for auxiliary {ID} model:", wd_aux)

fig, ax = plt.subplots(figsize=(12, 5))
ax.hist(aux_samples[ID], bins = 400, histtype = 'step',
        density = True, label = 'auxiliary');
# ax.hist(test_aux_data_np, bins = 100,
#         density = True, histtype = 'step', color = 'red', label = 'Test');

# Analytic density of the auxiliary mixture evaluated on a grid.
x_values = np.linspace(-20, 20, 400)
dist = make_mixture_dists(auxs_info, Normal)[ID]
p_values = torch.exp(dist.log_prob(torch.tensor(x_values))).to('cpu').numpy()
ax.plot(x_values, p_values, label = "Ground Truth", linewidth=3, color = 'black')

ax.set_xlim(-20, 20)
ax.set_title(f"Auxiliary {ID} versus test")
ax.set_xlabel("x")
ax.legend()
ax.grid(True)

In [None]:
#@title Closest auxiliary to target ground truth - W1
VERBOSE = False
aux_samples_tar = []
# Target test data mapped back to the original data space.
test_tar_data_np = test_tar_loader.dataset.tensors[0].cpu().ravel()
test_tar_data_np = (test_tar_data_np * test_tar_std.cpu() + test_tar_mean.cpu()).numpy()
aux_score_models = nn.ModuleList()
sampler = Euler_Maruyama_sampler_1D
for i, info in enumerate(auxs_info):
    aux_score = torch.nn.DataParallel(ScoreNet(marginal_prob_std=marginal_prob_std_fn)).to(device)
    aux_state = torch.load(f'cache/1D/aux_1d_{i}.pth', map_location=device)
    aux_score.load_state_dict(aux_state)
    aux_score_models.add_module(f'{i}', aux_score)

    sample_batch_size = info['test_size']
    raw_sample = sampler(aux_score,
                         marginal_prob_std_fn,
                         diffusion_coeff_fn,
                         sample_batch_size,
                         device=device, jupyter_mode = JUPYTER_MODE)
    # map to target data space (uses the target training mean / std, not the auxiliary's own)
    aux_sample = raw_sample.ravel().to('cpu') * train_tar_std.to('cpu') + train_tar_mean.to('cpu')
    aux_samples_tar.append(aux_sample)

for i, aux_sample in enumerate(aux_samples_tar):
    wd_aux = stats.wasserstein_distance(test_tar_data_np, aux_sample.numpy())
    print(f"W1 distance between auxiliary {i} and target ground truth:", wd_aux)

## Baseline Sampling

In [None]:
#@title Sampling from baseline score model
# NOTE: TRAIN_SIZE is intentionally not reassigned in this cell. It is set once
# in the dataset-parameter cell, and the baseline checkpoint is saved under that
# value; overriding it here would point at a checkpoint this run never wrote
# and would mislabel the figures saved further down.
NUM_TRIALS = 10 # number of independent sampling runs used for the W1 mean / std

# Target test data mapped back to the original data space.
test_tar_data_np = test_tar_loader.dataset.tensors[0].cpu().ravel()
test_tar_data_np = (test_tar_data_np * test_tar_std.cpu() + test_tar_mean.cpu()).numpy()
wd_baselines = torch.zeros((NUM_TRIALS,))

# The model and its checkpoint do not change between trials, so load them once.
base_score = torch.nn.DataParallel(ScoreNet(marginal_prob_std=marginal_prob_std_fn))
base_score = base_score.to(device)
ckpt = torch.load(f'cache/1D/base_1d_{TRAIN_SIZE}.pth', map_location=device)
base_score.load_state_dict(ckpt)

sample_batch_size = mix_params['test_size']
sampler = Euler_Maruyama_sampler_1D

for i in range(NUM_TRIALS):
    base_samples = sampler(base_score,
                    marginal_prob_std_fn,
                    diffusion_coeff_fn,
                    sample_batch_size,
                    device=device, jupyter_mode=False)

    base_samples = base_samples.ravel().to('cpu')
    # undo data standardization to remap to the original data space
    base_samples = base_samples * train_tar_std.to('cpu') + train_tar_mean.to('cpu')
    wd_baseline = stats.wasserstein_distance(test_tar_data_np, base_samples.numpy())
    wd_baselines[i] = wd_baseline
# Mean and standard deviation of the W1 distance over the trials.
print(wd_baselines.mean(), wd_baselines.std())

In [None]:
#@title Plot baseline-generated data vs ground truth
# Uses `base_samples` and `test_tar_data_np` left behind by the baseline sampling cell.
wd_baseline = stats.wasserstein_distance(test_tar_data_np, base_samples.numpy())
print("W1 distance for baseline model:", wd_baseline)

fig, ax = plt.subplots(figsize=(12, 5))
shared_hist_style = dict(bins = 200, histtype = 'step', density = True)
ax.hist(base_samples, label = 'Vanilla', **shared_hist_style);
ax.hist(test_tar_data_np, color = 'red', label = 'Test', **shared_hist_style);
ax.set_xlim(-30, 30)
ax.set_title("Vanilla versus test")
ax.set_xlabel("x")
ax.legend()
ax.grid(True)

## Fusion Sampling

In [None]:
#@title Sampling from the fusion model
# Below set the optimal lambdas obtained from Frank-Wolfe
NUM_TRIALS = 10
wd_fusions = torch.zeros((NUM_TRIALS,))

lambdas = torch.tensor([0.5781, 0.4219], device = device) # 1024
sample_batch_size = mix_params['test_size']
sampler = Euler_Maruyama_sampler_1D
# Target training statistics used to undo the standardisation of the samples.
target_std_cpu = train_tar_std.to('cpu')
target_mean_cpu = train_tar_mean.to('cpu')
for i in range(NUM_TRIALS):
    # Uses the `aux_score_models` ModuleList built by an earlier cell.
    fusion_score = FusionNet(aux_score_models, lambdas).to(device)
    raw_samples = sampler(fusion_score,
                          marginal_prob_std_fn,
                          diffusion_coeff_fn,
                          sample_batch_size,
                          device=device)
    fusion_samples = raw_samples.ravel().to('cpu') * target_std_cpu + target_mean_cpu
    wd_fusion = stats.wasserstein_distance(test_tar_data_np, fusion_samples.numpy())
    wd_fusions[i] = wd_fusion
# Mean and standard deviation of the W1 distance over the trials.
print(wd_fusions.mean(), wd_fusions.std())

In [None]:
#@title Plot fusion vs ground truth
# Target test data mapped back to the original data space.
test_tar_data_np = test_tar_loader.dataset.tensors[0].cpu().ravel()
test_tar_data_np = (test_tar_data_np * test_tar_std.cpu() + test_tar_mean.cpu()).numpy()
wd_fusion = stats.wasserstein_distance(test_tar_data_np, fusion_samples.numpy())
print("W1 distance for fusion model:", wd_fusion)

fig, ax = plt.subplots(figsize=(5, 2))
ax.hist(fusion_samples, bins = 200, histtype = 'step',
        density = True, label = 'ScoreFusion', linewidth = 2);
# Analytic density of the target mixture evaluated on a grid.
x_values = np.linspace(-20, 20, 600)
dist = make_mixture_dists([mix_params], Normal)[0]
p_values = torch.exp(dist.log_prob(torch.tensor(x_values))).to('cpu').numpy()
ax.plot(x_values, p_values, label = "Ground Truth",
        linewidth=3, color = 'black', alpha = 0.6)
ax.set_xlabel("x (sample space)")
ax.set_ylabel("Density")
ax.legend(fontsize='small', loc='upper right', framealpha=0.5)
ax.grid(True)
ax.set_xlim(-15, 15)
# Saved under its own name: the "Comparison 1" cell writes
# out/1d_{TRAIN_SIZE}_comp1.png, so sharing that path would silently
# overwrite this figure.
fig.savefig(f'out/1d_{TRAIN_SIZE}_fusion.png', bbox_inches='tight')

# Print-ready Comparisons

In [None]:
def count_parameters(model):
    '''Return the total element count of the parameters of `model` that require gradients.'''
    trainable_total = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            trainable_total += tensor.numel()
    return trainable_total
# Trainable parameter counts: baseline model first, then the fusion model.
for counted_model in (base_score, fusion_score):
    print(count_parameters(counted_model))

In [None]:
#@title Comparison 1: baseline vs ScoreFusion
# Target test data mapped back to the original data space.
test_tar_data_np = test_tar_loader.dataset.tensors[0].cpu().ravel()
test_tar_data_np = (test_tar_data_np * test_tar_std.cpu() + test_tar_mean.cpu()).numpy()

wd_base = stats.wasserstein_distance(test_tar_data_np, base_samples.numpy())
wd_fusion = stats.wasserstein_distance(test_tar_data_np, fusion_samples.numpy())
print("W1 distance for baseline:", wd_base)
print("W1 distance for ScoreFusion:", wd_fusion)

fig, ax = plt.subplots(figsize=(5, 2))
shared_hist_style = dict(histtype = 'step', density = True, linewidth = 2.5)
ax.hist(base_samples, bins = 800, label = 'Baseline', **shared_hist_style);
ax.hist(fusion_samples, bins = 200, label = 'ScoreFusion', **shared_hist_style);
# calculate the ground truth
x_values = np.linspace(-20, 20, 600)
dist = make_mixture_dists([mix_params], Normal)[0]
p_values = torch.exp(dist.log_prob(torch.tensor(x_values))).to('cpu').numpy()
ax.plot(x_values, p_values, label = "Ground Truth",
        linewidth=2.5, color = 'black', alpha = 0.6)
ax.set_ylabel('Density')
ax.set_xlim(-15, 15)
ax.grid(True)
ax.set_xlabel("x (sample space)")
ax.legend(fontsize='small', loc='upper right', framealpha=0.5)
fig.savefig(f'out/1d_{TRAIN_SIZE}_comp1.png', bbox_inches='tight')

In [None]:
#@title Comparison 2: true vs auxiliaries
colors = cycle(['green', 'purple']) # one colour per auxiliary curve, repeating if needed
# Target test data mapped back to the original data space.
test_tar_data_np = test_tar_loader.dataset.tensors[0].cpu().ravel()
test_tar_data_np = (test_tar_data_np * test_tar_std.cpu() + test_tar_mean.cpu()).numpy()

for i, samples in enumerate(aux_samples):
    wd_aux = stats.wasserstein_distance(test_tar_data_np, samples.numpy())
    print(f"W1 distance for aux {i} model to target: {wd_aux}")

x_values = np.linspace(-20, 20, 600)
x_grid = torch.tensor(x_values)
dists = make_mixture_dists(auxs_info, Normal)

fig, ax = plt.subplots(figsize=(5, 2))
# Analytic density of every auxiliary mixture.
for i, dist in enumerate(dists):
    p_values = torch.exp(dist.log_prob(x_grid)).to('cpu').numpy()
    ax.plot(x_values, p_values, label = f"Auxiliary {i}",
            linewidth=2.5, color = next(colors))
# calculate the ground truth
dist = make_mixture_dists([mix_params], Normal)[0]
p_values = torch.exp(dist.log_prob(x_grid)).to('cpu').numpy()
ax.plot(x_values, p_values, label = "Ground Truth",
        linewidth=2.5, color = 'black', alpha = 0.7)
ax.set_ylabel('Density')
ax.set_xlim(-15, 15);
ax.set_xlabel("x (sample space)")
ax.legend(fontsize='small', loc='upper right', framealpha=0.5)
ax.grid(True)
fig.savefig(f'out/1d_{TRAIN_SIZE}_comp2.png', bbox_inches='tight')