In [1]:
import torch
import os
from torchvision.utils import save_image
from torchvision.transforms.functional import to_pil_image
from tqdm import tqdm

from idem_net_celeba import IdemNetCeleba
from idem_net_mnist import IdemNetMnist
from gen_utils import generate_frequency_noise
from data_loader import load_CelebA

In [2]:
# Which training run to load, and which checkpoint file inside that run's folder.
run_id = "celeba20241130-101804"
epoch_num = "_final.pth"
checkpoint_path = f"checkpoints/{run_id}/{epoch_num}"

# Backend preference order: CUDA, then Apple MPS, then plain CPU.
# The MPS probe only runs when CUDA is unavailable.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print(device)

# The run id decides which architecture is instantiated.
# The argument 3 is presumably the channel count -- confirm against IdemNetCeleba.
model = IdemNetCeleba(3) if "celeba" in run_id else IdemNetMnist()

# Read the saved tensors straight onto the selected device, then copy them into
# the model. The model itself is not moved to `device` here.
state_dict = torch.load(checkpoint_path, weights_only=True, map_location=device)
model.load_state_dict(state_dict)

cpu


<All keys matched successfully>

In [4]:
# Settings shared by the generation cell below.
batch_size = 256
output_dir = "./data/generated_images_perceptual"

# Build the CelebA loader with the same batch size used for generation.
# load_CelebA returns two values; only the first is kept, the second is discarded.
data_loader, _ = load_CelebA(batch_size=batch_size)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
def save_model_outputs_to_images(model, output_dir, batch_size, fft_noise, data_loader, num_generations, device, image_shape=(3, 64, 64)):
    """
    Generate outputs batchwise from a model and save them as PNG images.

    Files are named ``output_00000.png``, ``output_00001.png``, ... in the
    order the outputs are produced. Exactly ``num_generations`` images are
    requested from the model: the number of batches is rounded up and the
    last batch is shortened, so totals that are not a multiple of
    ``batch_size`` (or are smaller than it) are honoured.

    Parameters:
        model: The PyTorch model to generate outputs. It is called as
            ``model(noise)`` inside ``torch.no_grad()``. This function neither
            moves it to ``device`` nor changes its train/eval mode; the caller
            is responsible for both.
        output_dir: Directory to save the images (created if missing).
            Ignored when ``fft_noise`` is true: images then go to
            ``./data/generated_images_fft``.
        batch_size: Maximum number of images to generate per batch (>= 1).
        fft_noise: If false, the model input is Gaussian noise drawn with
            ``torch.randn``. If true, the input is
            ``generate_frequency_noise`` applied to a batch of real images
            taken from ``data_loader``.
        data_loader: Iterable yielding ``(images, labels)`` pairs. Only used
            (and only iterated) when ``fft_noise`` is true; it is restarted
            when exhausted. If it yields fewer images than requested for a
            batch, that batch is correspondingly smaller and fewer than
            ``num_generations`` images may be written.
        num_generations: Total number of images to generate.
        device: Device to run the model on (e.g., 'cuda' or 'cpu').
        image_shape: Per-image shape of the Gaussian noise, without the batch
            dimension. Only used when ``fft_noise`` is false. Defaults to
            ``(3, 64, 64)``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    if fft_noise:
        # FFT-seeded generations always go to their own fixed directory.
        output_dir = "./data/generated_images_fft"

    os.makedirs(output_dir, exist_ok=True)

    total = max(int(num_generations), 0)
    # Ceiling division: a trailing partial batch still gets generated.
    num_batches = -(-total // batch_size)

    counter = 0  # Number of images written so far; also the filename index
    progress = tqdm(range(num_batches))
    # Only touch the loader when it is needed, so the random-noise path does
    # not spin up loader workers or require a loader at all.
    data_iter = iter(data_loader) if fft_noise else None

    with torch.no_grad():  # Disable gradient computation
        for _ in progress:
            # Shrink the final batch so the total never exceeds the request.
            current = min(batch_size, total - counter)

            if not fft_noise:
                noise = torch.randn([current, *image_shape], device=device)
            else:
                try:
                    # Get the next batch of real images
                    real_imgs, _ = next(data_iter)
                except StopIteration:
                    # Reinitialize the iterator if it runs out of data
                    data_iter = iter(data_loader)
                    real_imgs, _ = next(data_iter)

                real_imgs = real_imgs.to(device)  # Move real images to the target device
                noise = generate_frequency_noise(real_imgs[:current])

            outputs = model(noise)

            # NOTE(review): to_pil_image expects float tensors in [0, 1];
            # confirm the model's output range matches.
            for output in outputs:
                filename = os.path.join(output_dir, f"output_{counter:05d}.png")
                to_pil_image(output).save(filename)
                counter += 1  # Increment counter for unique filenames

    print(f"Images saved to {output_dir}")


# Generate one batch worth of images from Gaussian noise and write them to disk.
# The model is moved to the selected device as part of the call.
save_model_outputs_to_images(
    model=model.to(device),
    output_dir=output_dir,
    batch_size=batch_size,
    fft_noise=False,
    data_loader=data_loader,
    num_generations=256,
    device=device,
)

100%|██████████| 1/1 [00:05<00:00,  5.25s/it]

Images saved to ./data/generated_images_perceptual



