In [7]:
import numpy as np
import matplotlib.pyplot as plt
from data import *
from utils import *
from NMF import *
import time
import pandas as pd
import seaborn
from copy import deepcopy
import warnings
warnings.filterwarnings("ignore")

Data handling is done in data.py, which uses tensorflow/keras to load the data. As long as these packages are installed, the experiments here should be reproducible on any computer.

Convergence experiment
---

We generate $2500$ samples of each type of data, using the digits $0$ and $1$. We select deterministic weights equal to $1/2$, and preselect $d = 32$. We do all experiments with the same data. The loss we show is the respective loss of the different methods. For all methods, we only show convergence of the first source. For ANMF we select the batch size of the adversarial data to be $100$ for all experiments. 

In [None]:
# Build the single data set that all convergence experiments below share.
np.random.seed(0)
mnist = MNIST()
Ms = [0,1]
M = len(Ms)
N_sup = 2500
N_sup_test = 0

mnist.generate_supervised(
    Ms = Ms,
    N_sup = N_sup,
    N_sup_test = N_sup_test,
    type = "det",
    weights = [1.0/M] * M,
)

# Mixtures: each image is flattened to 784 entries and the result is
# transposed so that every column holds one sample.
V_sup = np.copy(mnist.x_sup_train[:,0,:,:].reshape((mnist.N_sup, 784)).T)
V_test = np.copy(mnist.x_sup_test[:,0,:,:].reshape((mnist.N_sup_test, 784)).T)

# Sources: after the transpose the axes are (784, M, number of samples).
U_s = np.copy(mnist.y_sup_train.reshape((mnist.N_sup,M,784)).T)
U_test = np.copy(mnist.y_sup_test.reshape((mnist.N_sup_test,M,784)).T)

# One (784, number of samples) slice per source.
U_sup = [U_s[:,i,:] for i in range(M)]
U_test_fit = [U_test[:,i,:] for i in range(M)]

In [None]:
# Convergence of NMF
# Fit standard NMF on the first source once per batch size, recording the
# loss history returned by fit_std and the wall-clock time of each run.
np.random.seed(0)
plt.rcParams.update({'font.size': 14})

batch_size = [50,500,1000,2500]
linestyles = ['dashdot']*10
times = []
loss_stds = []
for b in batch_size:
    tick = time.time()
    nmf = NMF(
        d = 64,
        epochs = 100,
        batch_size = b,
        mu_W = 1e-10,
        mu_H = 1e-10,
        normalize = True,
    )
    history = nmf.fit_std(U_r = U_sup[0], conv = True)
    loss_stds.append(history)
    times.append(time.time() - tick)
print(times)

In [None]:
# Plotting
# One curve per batch size. The first loss entry is skipped, and the epoch
# axis is derived from the recorded history instead of being hard-coded to
# 100 epochs, so the plot stays valid if `epochs` is changed in the fit cell.
for i,b in enumerate(batch_size):
    curve = loss_stds[i][1:]
    plt.plot(np.arange(1,len(curve) + 1), curve, label = f"{b}", linestyle = linestyles[i])
plt.legend(title = "Batch Size", loc = "upper right")
plt.ylabel("Loss")
plt.xlabel("Epochs")
plt.title("Convergence of SMU for NMF")
plt.grid()
plt.savefig("fig/std_conv.png")
plt.show()

In [None]:
# Convergence of Adversarial NMF
# For each (true, adversarial) batch-size pair: build the adversarial data,
# fit the first source's NMF adversarially, and record loss history and time.
plt.rcParams.update({'font.size': 14})
np.random.seed(0)

batch_size = [500,500,1000,2500]
batch_size_z = [100,250,500,5000]
linestyles = ['dashdot']*10
times = []
loss_advs = []

for b, b_z in zip(batch_size, batch_size_z):
    tick = time.time()
    nmf_sep = NMF_separation(
        ds = [64,64],
        epochs = 100,
        prob = "adv",
        true_sample = "std",
        normalize = True,
        update_H = False,
        mu_W = 1e-10,
        mu_H = 1e-10,
        batch_size = b,
        batch_size_z = b_z,
        tau_A = 0.1,
    )
    U_z = nmf_sep.create_adversarial(U_sup, V_sup)
    history = nmf_sep.NMFs[0].fit_adv(U_r = U_sup[0], U_z = U_z[0], conv = True)
    loss_advs.append(history)
    times.append(time.time() - tick)
print(times)


In [None]:
# Plotting
# One curve per (true, adversarial) batch-size pair; the first loss entry is
# skipped and the epoch axis follows the length of each history.
for b, b_z, history, style in zip(batch_size, batch_size_z, loss_advs, linestyles):
    epoch_axis = np.arange(1,len(history))
    plt.plot(epoch_axis, history[1:], label = f"({b}, {b_z})", linestyle = style)
plt.legend(title = "Batch Sizes \n(True, Adv)")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Convergence of SMU for ANMF")
plt.grid()
plt.savefig("fig/adv_conv.png")
plt.show()


In [None]:
# Convergence of discriminative NMF
# Fit the supervised (discriminative) NMF once per supervised batch size,
# recording the loss history returned by fit_sup and the wall-clock time.
np.random.seed(1)
plt.rcParams.update({'font.size': 14})

batch_size_sup = [50,500,1000,2500]
linestyles = ['dashdot']*10
ds = [64,64]
times = []
loss_sups = []
for b in batch_size_sup:
    tick = time.time()
    nmf = NMF(
        d = np.sum(ds),
        ds = ds,
        epochs = 100,
        mu_W = 1e-10,
        mu_H = 1e-10,
        warm_start_epochs = 10,
        batch_size_sup = b,
        normalize = True,
    )
    history = nmf.fit_sup(U_sup = U_sup, V_sup = V_sup, conv = True)
    loss_sups.append(history)
    times.append(time.time() - tick)
print(times)

In [None]:
# Plotting
# One curve per supervised batch size, using row 0 of each loss history with
# its first entry skipped. The epoch axis is derived from the curve length
# instead of being hard-coded to 100 epochs.
for i, b in enumerate(batch_size_sup):
    curve = loss_sups[i][0,1:]
    plt.plot(np.arange(1,len(curve) + 1), curve, label = f"{b}", linestyle = linestyles[i])
plt.legend(title = "Batch Size", loc = "upper right")
plt.ylabel("Loss")
plt.xlabel("Epochs")
plt.title("Convergence of SMU for DNMF")
plt.grid()
plt.savefig("fig/sup_conv.png")
plt.show()

In [None]:
# Convergence of D+ANMF
# For each (true, adversarial, supervised) batch-size triple: build the
# adversarial data, fit the concatenated model with fit_full, and record the
# loss history and wall-clock time.
plt.rcParams.update({'font.size': 14})
np.random.seed(0)

batch_size_r = [500,100,1000,2500]
batch_size_z = [100, 100, 250, 5000]
batch_size_sup = [100, 500, 1000, 2500]
linestyles = ['dashdot']*10
times = []
loss_fulls = []
for b_r, b_z, b_sup in zip(batch_size_r, batch_size_z, batch_size_sup):
    tick = time.time()
    nmf_sep = NMF_separation(
        ds = [64,64],
        epochs = 100,
        mu_W = 1e-10,
        mu_H = 1e-10,
        prob = "full",
        batch_size = b_r,
        batch_size_z = b_z,
        batch_size_sup = b_sup,
        tau_A = 0.05,
        tau_S = 0.5,
        true_sample = "sup",
    )
    U_z = nmf_sep.create_adversarial(U_sup, V_sup)
    history = nmf_sep.NMF_concat.fit_full(U_r = U_sup, U_z = U_z, U_sup = U_sup, V_sup = V_sup, conv = True)
    loss_fulls.append(history)
    times.append(time.time() - tick)
print(times)
    

In [None]:
# Plotting
# One curve per batch-size triple, using row 0 of each loss history with its
# first entry skipped. The epoch axis is derived from the curve length
# instead of being hard-coded to 100 epochs.
for i, b in enumerate(batch_size_r):
    curve = loss_fulls[i][0,1:]
    plt.plot(np.arange(1,len(curve) + 1), curve, label = f"({b}, {batch_size_z[i]}, {batch_size_sup[i]})", linestyle = linestyles[i])
plt.legend(title = "Batch Sizes" + '\n' + "(True, Adv, Sup)")
plt.ylabel("Loss")
plt.xlabel("Epochs")
plt.title("Convergence of SMU for D+ANMF")
plt.grid()
plt.savefig("fig/full_conv.png")
plt.show()

We see that all methods converge, and smaller batch sizes tend to yield faster convergence. Selecting the batch sizes too small can lead to slower convergence and too much randomness. It is also interesting to see that standard NMF converges to a smaller loss than DNMF. This is to be expected, as DNMF has to fit the bases while also splitting mixed data, while standard NMF only has to learn the bases.


Data rich experiment
---

In this experiment we want to test the proposed methods in a setting where we have a lot of strong supervised data to see how the data settings compare. We test this on $0$ and $1$ digits, as NMF performs relatively well in this situation. We select $5000$ data of each source and $1000$ test data.


In [None]:
# Parameters for experiment
np.random.seed(1)
number_of_experiments = 1
M = 2
Ms_all = [0,1]
N_sup = 5000
N_sup_test = 1000
mu_W = 1e-8
mu_H = 1e-6

epochs = 100
test_epochs = 100
batch_size = 500
batch_size_sup = 500
# NOTE(review): batch_size_z is not used below; the fit passes
# batch_size_z = d instead. Confirm which of the two is intended.
batch_size_z = 100
wiener = True

Ds = [16,32,48,64,80,96,112,128]

probs = ["std", "adv","sup","exem"]

# tau_A used for the corresponding entry of probs.
taus = [0.0,0.1,0.0,0.0]

# "adv" and "sup" are warm-started from the bases learned by "std" for the
# same d, so "std" has to be fitted first in every sweep over probs.
assert probs[0] == "std", 'probs must start with "std": later problems are initialised from its bases'

# Per-run result frames; concatenated once after the loops instead of growing
# a DataFrame with pd.concat on every iteration.
result_columns = ["id", "d", "prob", "mean_psnrs"]
result_frames = []

for ex in range(number_of_experiments):
    Ms = np.random.choice(Ms_all, M, replace = False)

    # Generate data
    mnist = MNIST()

    mnist.generate_supervised(Ms = Ms, N_sup = N_sup, N_sup_test = N_sup_test, pytorch = False, type = "det", weights = [1.0/M] * M)
    V_sup = np.copy(mnist.x_sup_train[:,0,:,:].reshape((mnist.N_sup, 784)).T)
    U_s = np.copy(mnist.y_sup_train.reshape((mnist.N_sup,M,784)).T)
    U_sup = [U_s[:,k,:] for k in range(M)]

    V_test = np.copy(mnist.x_sup_test[:,0,:,:].reshape((mnist.N_sup_test, 784)).T)
    U_test = np.copy(mnist.y_sup_test.reshape((mnist.N_sup_test,M,784)).T)

    # arrays to store results
    # Each problem, each source, each data
    psnrs = np.zeros((len(probs), M, len(Ds), N_sup_test))

    # Fitted models, in (d, prob) order; reused by the "specific example" cell.
    seps = []
    for j,d in enumerate(Ds):
        for i, prob in enumerate(probs):
            print(prob)

            # Fit
            sep = NMF_separation(ds = [d] * M, tau_A = taus[i], normalize = True, update_H = False, true_sample = "std",
                epochs = epochs,prob = prob, use_adv_basis = False,
                mu_W = mu_W, mu_H = mu_H, test_epochs = test_epochs, wiener = wiener,
                batch_size = batch_size, batch_size_z = d, batch_size_sup = batch_size_sup)

            # Use standard NMF as initial conditions
            if prob == "adv" or prob == "sup":
                for k in range(M):
                    sep.NMFs[k].W = np.copy(stdWs[k])

            if prob == "std" or prob == "exem" or prob == "adv":
                sep.fit(U_r = U_sup, V = V_sup)
            else:
                sep.fit(U_sup = U_sup, V_sup = V_sup)

            # Remember the standard-NMF bases for the warm starts above.
            if prob == "std":
                stdWs = [sep.NMFs[k].W for k in range(M)]

            # Separate
            out = sep.separate(V_test)

            # Measure quality SHOULD USE eval member function of sep
            psnrs[i,:,j,:] = PSNR(U_test,out)

            # One row per test sample: PSNR averaged over the sources.
            result_frames.append(pd.DataFrame({
                "d": [d]*N_sup_test,
                "id": np.arange(0,N_sup_test),
                "prob": [prob]*N_sup_test,
                "mean_psnrs": np.mean(psnrs[i,:,j,:],axis = 0),
            }))
            seps.append(deepcopy(sep))

if result_frames:
    df = pd.concat(result_frames)[result_columns]
else:
    df = pd.DataFrame(columns = result_columns)
df.to_csv('data_rich.csv')

In [None]:
# Plotting
# Median PSNR against the number of basis vectors d, one line per method,
# read back from the CSV written by the experiment cell.
data_rich_df = pd.read_csv('data_rich.csv', index_col = 0)
plt.rcParams.update({'font.size': 10})
lineplot_options = dict(
    x = "d",
    y = "mean_psnrs",
    hue = "prob",
    estimator = "median",
    errorbar = "se",
    markers = True,
    dashes = True,
    marker = 'o',
    legend = False,
)
seaborn.lineplot(data_rich_df, **lineplot_options)
plt.title("Impact of number of basis vectors for strong supervised data")
plt.ylabel("Median PSNR")
# Every second entry is a placeholder so that only the lines get a label.
plt.legend(labels = ["NMF", "_nolegend_" , "ANMF", "_nolegend_", "DNMF", "_nolegend_", "ENMF"])
plt.grid()
plt.savefig("fig/Data_rich.pdf")
plt.show()

We see that ANMF outperforms all other methods, and performance increases with $d$.

Data rich - Specific example
---

We now plot a specific data point

In [None]:
# Pick some separations
# Entries 12..15 of seps are four consecutive fitted models; they are used
# below as the std, adv, sup and exem models, in that order.
seps_ = seps[12:16]
for sep in seps_:
    sep.wiener = True
    print(sep.ds)

sep_std, sep_adv, sep_sup, sep_exem = seps_[0], seps_[1], seps_[2], seps_[3]

# PSNR evaluated with axis = 0; the std and adv results are used to rank the
# test samples below.
psnrs_std = sep_std.eval(U_test, V_test, "psnr" , axis = 0)
psnrs_adv = sep_adv.eval(U_test, V_test, "psnr" , axis = 0)
psnrs_sup = sep_sup.eval(U_test, V_test, "psnr" , axis = 0)
psnrs_exem = sep_exem.eval(U_test, V_test, "psnr", axis = 0)

# Separated test sources for each method.
out_std = sep_std.separate(V_test)
out_adv = sep_adv.separate(V_test)
out_sup = sep_sup.separate(V_test)
out_exem = sep_exem.separate(V_test)

# Gain of adversarial over standard NMF; ids sorts it in ascending order, so
# ids[-1] is the entry where adversarial NMF gains the most.
diff_adv = [psnrs_adv[i] - psnrs_std[i] for i in range(len(psnrs_std))]
ids = np.argsort(diff_adv)

# Re-evaluate without aggregation; these values are indexed as
# [source, sample] by the plotting cell.
psnrs_std = sep_std.eval(U_test, V_test, "psnr", aggregate = None)
psnrs_adv = sep_adv.eval(U_test, V_test, "psnr", aggregate = None)
psnrs_sup = sep_sup.eval(U_test, V_test, "psnr", aggregate = None)
psnrs_exem = sep_exem.eval(U_test, V_test, "psnr", aggregate = None)


In [None]:
# Select id where adversarial outperforms standard NMF.
# ids is sorted by ascending (adversarial - standard) PSNR gain, so the last
# entry is the sample with the largest gain.
sample_id = ids[-1]
vmin = 0
vmax = 0.5  # NOTE(review): defined but never passed to imshow; confirm whether it should be.


def _show_reference(image, path):
    """Show one flattened 28x28 image without axes and save it to `path`."""
    plt.imshow(image.reshape((28,28)), cmap = "gray", vmin = vmin)
    plt.axis('off')
    plt.savefig(path, bbox_inches = "tight")
    plt.show()


def _show_estimate(image, psnr, plain_path, psnr_path):
    """Show one separated 28x28 image and save it twice.

    The figure is first saved to `plain_path` as-is, then the PSNR value is
    written in the top-left corner and the figure is saved to `psnr_path`.
    """
    fig, ax = plt.subplots()
    ax.imshow(image.reshape((28,28)), cmap = "gray", vmin = vmin)
    ax.axis('off')
    plt.savefig(plain_path)
    ax.text(0.02, 0.95, "PSNR: {:.2f}".format(psnr),
            horizontalalignment='left',
            verticalalignment='top',
            transform=ax.transAxes,
            color='white', fontsize = 20)
    plt.savefig(psnr_path, bbox_inches = "tight")
    plt.show()


# Mixture and the two true sources.
_show_reference(V_test[:,sample_id], 'fig/v.png')
_show_reference(U_test[:,0,sample_id], 'fig/u0.png')
_show_reference(U_test[:,1,sample_id], 'fig/u1.png')

# File-name prefix for each source index of the separated outputs. This keeps
# the naming the notebook already used (source 0 -> "u1", source 1 -> "u0").
# The second exemplar image used to be written to 'fig/u1_exem.png' as well,
# overwriting the first one; it now goes to 'fig/u0_exem.png' like the others.
source_prefixes = [(0, "u1"), (1, "u0")]
methods = [
    ("std", out_std, psnrs_std),
    ("adv", out_adv, psnrs_adv),
    ("sup", out_sup, psnrs_sup),
    ("exem", out_exem, psnrs_exem),
]
for method, out, method_psnrs in methods:
    for source, prefix in source_prefixes:
        _show_estimate(
            out[:,source,sample_id],
            method_psnrs[source,sample_id],
            f'fig/{prefix}_{method}.png',
            f'fig/{prefix}_{method}_psnr.png',
        )




Data rich - Effect of $\tau_A$
---

We now run an equivalent experiment where we try several values of $\tau_A$.

In [None]:
# Parameters for the tau_A sweep: standard NMF as the baseline (tau = 0),
# followed by adversarial NMF for several values of tau_A.
np.random.seed(1)
number_of_experiments = 1
M = 2
Ms_all = [0,1]
N_adv = 1000
N_sup = 5000
N_sup_test = 1000
mu_W = 1e-8
mu_H = 1e-6

epochs_std = 100
epochs = 100
test_epochs = 100
batch_size = 500
batch_size_sup = 500
# NOTE(review): batch_size_z is not used below; the fit passes
# batch_size_z = d instead. Confirm which of the two is intended.
batch_size_z = 100
wiener = True

Ds = [32,64,96,128]

probs = ["std", "adv","adv","adv", "adv", "adv", "adv"]

# tau_A used for the corresponding entry of probs.
taus = [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.5]

# The adversarial fits are warm-started from the bases learned by "std" for
# the same d, so "std" has to be fitted first in every sweep over probs.
assert probs[0] == "std", 'probs must start with "std": later problems are initialised from its bases'

# Per-run result frames; concatenated once after the loops instead of growing
# a DataFrame with pd.concat on every iteration.
result_columns = ["id", "d", "prob", "tau", "mean_psnrs"]
result_frames = []

for ex in range(number_of_experiments):
    Ms = np.random.choice(Ms_all, M, replace = False)

    # Generate data
    mnist = MNIST()

    mnist.generate_supervised(Ms = Ms, N_sup = N_sup, N_sup_test = N_sup_test, pytorch = False, type = "det", weights = [1.0/M] * M)
    V_sup = np.copy(mnist.x_sup_train[:,0,:,:].reshape((mnist.N_sup, 784)).T)
    U_s = np.copy(mnist.y_sup_train.reshape((mnist.N_sup,M,784)).T)
    U_sup = [U_s[:,k,:] for k in range(M)]

    V_test = np.copy(mnist.x_sup_test[:,0,:,:].reshape((mnist.N_sup_test, 784)).T)
    U_test = np.copy(mnist.y_sup_test.reshape((mnist.N_sup_test,M,784)).T)

    # arrays to store results
    psnrs = np.zeros((len(probs), M, len(Ds), N_sup_test))

    seps = []
    for j,d in enumerate(Ds):
        for i, prob in enumerate(probs):
            print(prob)

            # Fit
            sep = NMF_separation(ds = [d] * M, tau_A = taus[i], normalize = True, update_H = False, true_sample = "std",
                epochs = epochs_std if prob == "std" else epochs,prob = prob, use_adv_basis = False,
                mu_W = mu_W, mu_H = mu_H, test_epochs = test_epochs, wiener = wiener,
                batch_size = batch_size, batch_size_z = d, batch_size_sup = batch_size_sup)

            # Use standard NMF as initial conditions
            if prob == "adv" or prob == "sup":
                for k in range(M):
                    sep.NMFs[k].W = np.copy(stdWs[k])

            sep.fit(U_r = U_sup, V = V_sup, U_sup = U_sup, V_sup = V_sup)

            # Remember the standard-NMF bases for the warm starts above.
            if prob == "std":
                stdWs = [sep.NMFs[k].W for k in range(M)]

            # Separate
            out = sep.separate(V_test)

            # Measure quality SHOULD USE eval member function of sep
            psnrs[i,:,j,:] = PSNR(U_test,out)

            # One row per test sample: PSNR averaged over the sources.
            result_frames.append(pd.DataFrame({
                "d": [d]*N_sup_test,
                "id": np.arange(0,N_sup_test),
                "prob": [prob]*N_sup_test,
                "tau": [taus[i]]*N_sup_test,
                "mean_psnrs": np.mean(psnrs[i,:,j,:],axis = 0),
            }))
            seps.append(deepcopy(sep))

if result_frames:
    df = pd.concat(result_frames)[result_columns]
else:
    df = pd.DataFrame(columns = result_columns)
df.to_csv('data_rich_tau.csv')

In [None]:
# Plotting
# Median PSNR against tau_A, one line per number of basis vectors d, read
# back from the CSV written by the experiment cell.
data_rich_tau_df = pd.read_csv('data_rich_tau.csv', index_col = 0)
plt.rcParams.update({'font.size': 10}) 
seaborn.lineplot(data_rich_tau_df, x = "tau", y = "mean_psnrs", hue = "d",
    estimator = "median", errorbar = "se", markers = True, dashes = True,
    marker = 'o',legend=False, palette = "flare")
plt.grid()
plt.ylabel("Median PSNR")
# Every second entry is a "_nolegend_" placeholder so that only the lines get
# a label (the third placeholder was previously misspelled "_nolegend").
plt.legend(labels = [r"$d = 32$", "_nolegend_" , r"$d = 64$", "_nolegend_", r"$d = 96$", "_nolegend_", r"$d = 128$"])
plt.title("Impact of " + r"$\tau_A$" + " for ANMF")
plt.xlabel(r"$\tau_A$")
plt.savefig("fig/Data_rich_tau.pdf")
plt.show()

Data poor tuning experiment
---

We now do a larger tuning experiment where we have low amounts of data. We select $250$ strong supervised data and $500$ weak supervised data of each source. We mix each digit with a "one"-digit and use random search to find the best parameters for this specific setting.

In [None]:
# BIG NUMERICAL EXAMPLE WITH RANDOM SEARCH
#
# For every digit other than "one": mix it with "one" digits, then run a
# random search over the hyper-parameters of each method and record the
# median test PSNR of the best model found.
np.random.seed(0)
number_of_experiments = 9
number_of_searches = 10
M = 2
Ms_all = [0,2,3,4,5,6,7,8,9]
N_adv = 500
N_V = 0
N_sup = 250
N_sup_test = 1000

# Search space: each entry names a constructor argument of NMF_separation and
# provides a zero-argument sampler for it.
d_dict = {"name": 'ds', "dist": lambda: [64, 64]}
normalize_dict = {"name": 'normalize', "dist": lambda: np.random.choice([True,False], replace = True)}
mu_H_dict = {"name": 'mu_H', "dist": lambda: np.power(10,np.random.uniform(-10,-3))}
mu_W_dict = {"name": 'mu_W', "dist": lambda: np.power(10,np.random.uniform(-10,-3))}
epochs_dict = {"name": 'epochs', "dist": lambda: np.random.randint(1,100)}
test_epochs_dict = {"name": 'test_epochs', "dist": lambda: np.random.randint(100,150)}
batch_size_r_dict = {"name": 'batch_size', "dist": lambda: np.random.choice([250,500], replace = True)}
batch_size_z_dict = {"name": 'batch_size_z', "dist": lambda: np.random.choice([125,250], replace = True)}
batch_size_sup_dict = {"name": 'batch_size_sup', "dist": lambda: np.random.choice([125], replace = True)}
true_sample_dict = {"name": 'true_sample', "dist": lambda: np.random.choice(["std", "sup"], replace = True)}
tau_A_dict = {"name": 'tau_A', "dist": lambda: np.random.uniform(0.0,0.3)}
tau_S_dict = {"name": 'tau_S', "dist": lambda: np.random.uniform(0.0,0.5)}
betas_dict = {"name": 'betas', "dist": lambda: [np.random.uniform(0.5,1.0), np.random.uniform(0.5,1.0)]} 

probs = ["exem", "std", "adv", "sup", "full"]
results = {"M": [], "exem" : [], "std" : [], "adv" : [], "sup" : [], "full": []}

mnist = MNIST()

# The experiment index is named `ex` and the per-source index `k`; the
# original reused `i` for both, so the inner loops overwrote the outer index.
for ex in range(number_of_experiments):

    # Digit mixed with "one": 0, 2, 3, ..., 9 (the digit 1 itself is skipped).
    digit = ex + (ex>=1)
    Ms = [1,digit]
    results["M"].append(digit)

    mnist.generate_adverserial(Ms = Ms, Ns = [N_adv]*M, N_V = N_V, type = "det", weights = [1.0/M]*M)
    U_r = []
    for k in range(M):
        U_r.append(np.copy(mnist.x_r_train[k][:,:,:].reshape((mnist.Ns_adv[k],784)).T))
    V = np.copy(mnist.x_v_train.reshape((mnist.N_adv_V,784)).T)
    
    mnist.generate_supervised(Ms = Ms, N_sup = N_sup, N_sup_test = N_sup_test, type = "det", weights = [1.0/M] * M)
    
    V_sup = np.copy(mnist.x_sup_train[:,0,:,:].reshape((mnist.N_sup, 784)).T)
    U_s = np.copy(mnist.y_sup_train.reshape((mnist.N_sup,M,784)).T)
    U_sup = []
    U_test = np.copy(mnist.y_sup_test.reshape((mnist.N_sup_test,M,784)).T)
    U_test_fit = []
    for k in range(M):
        U_sup.append(U_s[:,k,:])
        U_test_fit.append(U_test[:,k,:])
    V_test = np.copy(mnist.x_sup_test[:,0,:,:].reshape((mnist.N_sup_test, 784)).T)


    for prob in probs:
        print(prob)

        prob_dict = {"name": 'prob', "dist": lambda: prob}

        # "exem" and "std" get no initial bases; the later problems reuse the
        # W_dict set at the end of the previous iteration (best bases so far).
        if prob == "exem" or prob == "std":
            W_dict = {"name": 'Ws', "dist": lambda: None}

        param_dicts = [d_dict,
            normalize_dict,
            mu_H_dict, 
            prob_dict, 
            mu_W_dict, 
            epochs_dict, 
            test_epochs_dict,
            batch_size_r_dict, 
            batch_size_z_dict, 
            batch_size_sup_dict,
            tau_A_dict,
            tau_S_dict,
            true_sample_dict,
            betas_dict,
            W_dict,
        ]

        # cv = 2 only for the problems that use supervised data ("sup", "full").
        rs = random_search(NMF_separation, param_dicts, N_ex = number_of_searches, cv = 0 if prob != "full" and prob != "sup" else 2, source_aggregate = "mean", data_aggregate = "median")
        rs.fit(U_r = U_r, V = V, U_sup = U_sup, V_sup = V_sup, refit = True)

        results[prob].append(np.median(rs.best_model.eval(U_test[:,:,:],V_test[:,:], "psnr")))

        # Keep the best model's bases as initial bases for the next problems.
        if prob == "std" or prob == "exem":
            Ws = [np.copy(rs.best_model.NMFs[0].W), np.copy(rs.best_model.NMFs[1].W)]
            W_dict = {"name": 'Ws', "dist": lambda : Ws}

In [None]:
# Tabulate the random-search results (one row per digit mixed with "one"),
# echo the raw dictionary, and persist the table for the plotting cell.
data_poor_df = pd.DataFrame(results)
print(results)
data_poor_df.to_csv('data_poor_tuning.csv')

In [None]:
# Plot each method's median PSNR relative to standard NMF, per mixed digit.
data_poor_df = pd.read_csv('data_poor_tuning.csv', index_col = 0)
new_df = data_poor_df.copy(deep = True)

# Express every method as a difference to the "std" column. The subtraction
# reads from the untouched data_poor_df, so "std" itself becomes zero.
for column in new_df:
    if column != "M":
        new_df[column] = new_df[column] - data_poor_df["std"]
ax = seaborn.lineplot(data = new_df.drop("M", axis = 1).drop("exem", axis = 1), marker = 'o')

# Pin one tick per row and label it with the digit stored in "M", rather than
# relabelling whatever ticks the automatic locator happened to choose.
digit_names = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']
ax.set_xticks(list(new_df.index))
ax.set_xticklabels([digit_names[int(m)] for m in data_poor_df["M"]])
ax.set_xlabel('Digit mixed with "one" digit')
ax.set_ylabel(r"$\Delta$" + "Median PSNR")
plt.title("Comparison in the constrained data setting for different classes of digits")
plt.legend(labels = ["NMF", "_nolegend_", "ANMF", "_nolegend_", "DNMF", "_nolegend_", "D+ANMF"])
plt.grid()
plt.savefig('fig/data_poor.pdf')
plt.show()
