In [None]:
from src.data import load_dataset_and_make_dataloaders
from src.train import add_noise
from src.utils import to_unit_range
import torch
import matplotlib.pyplot as plt

In [None]:
# Build the FashionMNIST dataloaders with a small batch size; each batch drawn
# below therefore holds only a handful of images.
dataloaders, info = load_dataset_and_make_dataloaders(
    dataset_name="FashionMNIST",
    root_dir="../data",
    batch_size=4,
)

# Keep a single iterator over the validation split so that each `next(...)`
# call hands back a fresh batch.
valid_loader = iter(dataloaders.valid)

In [None]:
# Pull one validation batch; element 0 of the batch is what gets noised below.
batch = next(valid_loader)
clean_images = batch[0]

# Noise levels to visualise, from no noise (0) up to 20.
sigmas = [0, 0.1, 0.5, 1, 2, 3, 5, 8, 10, 20]

# Quick look at the value range of the un-noised inputs.
print(clean_images.min(), clean_images.max())

# Apply every noise level to the same batch, collecting one result per sigma.
noisy_per_sigma = []
for sigma in sigmas:
    y = add_noise(clean_images, torch.tensor(sigma))
    noisy_per_sigma.append(y)

# Join the per-sigma results along dim 1, so index 1 of `ys` walks through the
# noise levels in the order of `sigmas`.
ys = torch.cat(noisy_per_sigma, dim=1)

In [None]:
# Grid of images: one row per entry along dim 0 of `ys`, one column per entry
# along dim 1 (the noise levels).
n_rows, n_cols = ys.shape[0], ys.shape[1]

# squeeze=False keeps `axes` a 2-D array even for a single row or column, so
# `.ravel()` below always works.
fig, axes = plt.subplots(
    n_rows,
    n_cols,
    figsize=(n_cols / 2, n_rows / 2),
    squeeze=False,
)
fig.subplots_adjust(wspace=0.0, hspace=0.0)

# Collapse the two leading dims so the images pair up with the axes in
# row-major order. The image size is read from `ys` itself rather than from a
# loop variable left over from another cell.
flat_images = ys.reshape(-1, *ys.shape[-2:])
for ax, img in zip(axes.ravel(), flat_images):
    # Each image is passed through the project's `to_unit_range` helper
    # individually before being drawn on the fixed [0, 1] colour scale.
    scaled_img = to_unit_range(img)
    # "gray" is the canonical colormap name; the "grey" spelling is not
    # registered in older Matplotlib releases.
    ax.imshow(scaled_img, interpolation="none", cmap="gray", vmin=0, vmax=1)
    ax.set_axis_off()

# Title the first len(sigmas) axes (the top row when there is one column per
# sigma) with their noise level; only the first one carries the "sigma =" prefix.
for position, (ax, sigma) in enumerate(zip(axes.ravel(), sigmas)):
    label = str(sigma)
    if position == 0:
        # Raw string: "\s" is not a valid escape sequence in a normal literal.
        label = r"$\sigma = $" + label
    ax.set_title(label, fontsize=8)