# Preprocess & Add Noise

In [1]:
from dataclasses import dataclass
from utils.visualization import plot_reconstructed_images

from preprocessing.noise import BlockNoise, BlurNoise
from preprocessing.custom_preprocess import ImagePreprocessor
from preprocessing.load_data import load_image_data
from algorithms import MultiplicativeUpdateNMF, RobustNMF, SparseNMF

@dataclass
class ExperimentParam:
    """Configuration for a single experiment run.

    Attributes:
        nmf: Which NMF variant to fit: 'multiplicative_update', 'robust'
            or 'sparse'.
        dataset: Which image dataset to load: 'yaleB' or 'orl'.
        noises: Names of the noise types to apply while loading; the
            recognised names are 'block_noise' and 'blur_noise'. An empty
            list means no noise.
    """

    nmf: str
    dataset: str
    noises: list

class Experiment:
    """Load one image dataset, fit one NMF variant on it and collect results.

    The configuration is supplied once as an ``ExperimentParam``. Every call
    to ``run`` reloads the data and fits a newly constructed model.
    """

    # Dataset name -> directory handed to ``load_image_data``.
    _DATA_PATHS = {
        'yaleB': './data/CroppedYaleB',
        'orl': './data/ORL',
    }
    # NMF variants that ``_build_nmf`` knows how to construct.
    _NMF_NAMES = ('multiplicative_update', 'robust', 'sparse')
    _NMF_ERROR = "nmf must be one of 'multiplicative_update', 'robust', 'sparse'"

    def __init__(self, param: ExperimentParam):
        """Store the configuration; nothing is loaded or fitted until ``run``."""
        self.param = param

    def run(self, run_name: str):
        """Execute the experiment once and return its results.

        Args:
            run_name: Label copied verbatim into the returned dict under
                ``"name"``.

        Returns:
            A dict with the keys ``"name"`` (``run_name``), ``"param"`` (the
            ``ExperimentParam`` this experiment was built with),
            ``"metrics"`` (whatever ``nmf_object.evaluate`` returns) and
            ``"reconstructed_images_path"`` (whatever
            ``plot_reconstructed_images`` returns -- presumably a file path,
            TODO confirm).

        Raises:
            ValueError: If ``param.nmf`` or ``param.dataset`` is not one of
                the supported names.
        """
        # Reject a bad configuration before any image is loaded: loading and
        # fitting are the expensive steps, and an unknown dataset name must
        # not silently fall back to ORL.
        if self.param.nmf not in self._NMF_NAMES:
            raise ValueError(self._NMF_ERROR)
        if self.param.dataset not in self._DATA_PATHS:
            raise ValueError("dataset must be one of 'yaleB', 'orl'")

        # set up
        preprocessor = ImagePreprocessor(max_size=3_000)
        block_noise = BlockNoise(mean=0.2, std=0.05, prob=0.6)
        blur_noise = BlurNoise(mean=2, std=0.5, prob=0.6)

        # noise: block noise always precedes blur noise, whatever the order
        # in ``param.noises``; names other than these two are ignored.
        noises = []
        if "block_noise" in self.param.noises:
            noises.append(block_noise)
        if "blur_noise" in self.param.noises:
            noises.append(blur_noise)

        # load data
        image_data, image_labels, image_objects = load_image_data(
            self._DATA_PATHS[self.param.dataset],
            preprocess_function=preprocessor.preprocess,
            noise_functions=noises,
            data_fraction=0.5,
        )

        # run NMF algorithm; the solvers receive the transposed data matrix
        nmf_object = self._build_nmf(image_data.T)
        nmf_object.fit(
            plot_metrics=False,
            use_tqdm=True,
            early_stop=True,
            patience=20,
            tol=1e-4,
        )

        # collect metrics; the reconstruction is transposed back before plotting
        metrics = nmf_object.evaluate(image_labels)
        reconstructed_images_path = plot_reconstructed_images(
            nmf_object.get_reconstruction().T,
            image_objects,
            preprocessor,
        )

        return {
            "name": run_name,
            "param": self.param,
            "metrics": metrics,
            "reconstructed_images_path": reconstructed_images_path,
        }

    def _build_nmf(self, V):
        """Construct the (unfitted) NMF solver named by ``self.param.nmf`` for ``V``."""
        if self.param.nmf == 'multiplicative_update':
            return MultiplicativeUpdateNMF(
                V=V,
                num_features=20,
                max_iters=1000,
            )
        if self.param.nmf == 'robust':
            return RobustNMF(
                V=V,
                num_features=20,
                max_iters=1000,
                lambda_param=0.5,
                learning_rate=0.0001,
            )
        if self.param.nmf == 'sparse':
            return SparseNMF(
                V=V,
                num_features=20,
                max_iters=1000,
                alpha=0.5,
                beta=0.5,
            )
        # Defensive: ``run`` validates first, but keep the helper safe on its own.
        raise ValueError(self._NMF_ERROR)



In [2]:
# Run Experiments

# Baseline configuration: multiplicative-update NMF on the ORL dataset with
# no noise applied.
param = ExperimentParam(
    nmf='multiplicative_update',
    dataset='orl',
    noises=[],
)

# Run the same configuration five times, labelled base_orl_1 .. base_orl_5.
# Each result is appended as soon as its run finishes, so `results` keeps the
# completed runs if the cell is interrupted part-way through.
results = []
for i in range(1, 6):
    # A new Experiment per repetition; run() reloads the data and refits.
    experiment = Experiment(param)
    result = experiment.run(f"base_orl_{i}")
    results.append(result)

MultiplicativeUpdateNMF Progress:   3%|▎         | 28/1000 [00:38<22:33,  1.39s/it, Reconstruction RMSE=0.183, Cost Function=1.3e+5] 


KeyboardInterrupt: 