In [None]:
import torch
import os
from torchvision.utils import save_image
from torchvision.transforms.functional import to_pil_image
from tqdm import tqdm

from idem_net_celeba import IdemNetCeleba
from idem_net_mnist import IdemNetMnist

In [None]:
# --- Run selection -----------------------------------------------------------
run_id = "celeba20241130-101804"
epoch_num = "_final.pth"

# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU otherwise.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print(device)

# NOTE(review): with the values above this resolves to
# "checkpoints/<run_id>/_final.pth" — confirm the checkpoint file is literally
# named "_final.pth" inside the run directory.
checkpoint_path = f"checkpoints/{run_id}/{epoch_num}"

# Run ids containing "celeba" use the 3-channel CelebA network; anything else
# is treated as an MNIST run.
model = IdemNetCeleba(3) if "celeba" in run_id else IdemNetMnist()

# weights_only=True restricts unpickling to tensors/primitive containers;
# map_location places the loaded tensors directly on the selected device.
state_dict = torch.load(checkpoint_path, weights_only=True, map_location=device)
model.load_state_dict(state_dict)

In [None]:
def save_model_outputs_to_images(model, output_dir, batch_size, num_generations, device,
                                 input_shape=(3, 64, 64)):
    """
    Generate images from random-noise inputs batchwise and save them as PNG files.

    Parameters:
        model: PyTorch model that maps a noise tensor of shape
            (batch, *input_shape) to a batch of images, one per input.
            It should already live on `device`.
        output_dir: Directory to save the images into (created if missing).
        batch_size: Number of noise samples fed to the model per forward pass.
        num_generations: Total number of images to generate and save. The last
            batch is shrunk if needed so exactly this many images are written.
        device: Device the noise is created on (e.g. 'cuda' or 'cpu').
        input_shape: Per-sample shape of the noise input. Defaults to
            (3, 64, 64); pass a different shape for models with other inputs.

    Images are written as output_00000.png, output_00001.png, ...
    The model is switched to eval mode for generation, and its previous
    train/eval mode is restored afterwards.

    Raises:
        ValueError: if batch_size is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    total = int(num_generations)
    # Ceil division: flooring (as before) silently dropped the remainder,
    # e.g. 50000 / 256 produced only 49920 images.
    num_batches = max(0, (total + batch_size - 1) // batch_size)

    counter = 0  # Counter for filenames
    was_training = model.training
    model.eval()  # inference behaviour for dropout / batch-norm layers
    try:
        with torch.no_grad():  # Disable gradient computation
            for batch_idx in tqdm(range(num_batches)):
                # The final batch may be smaller so that exactly `total` images are produced.
                current_size = min(batch_size, total - batch_idx * batch_size)
                noise = torch.randn([current_size, *input_shape], device=device)
                outputs = model(noise)  # Generate outputs

                for output in outputs:
                    # to_pil_image scales float tensors by 255 and casts to uint8;
                    # clamp first so out-of-range values saturate instead of wrapping.
                    # NOTE(review): assumes the model emits values in [0, 1]; if it
                    # uses another range (e.g. tanh -> [-1, 1]) rescale here instead.
                    if output.is_floating_point():
                        output = output.clamp(0, 1)
                    output_image = to_pil_image(output)  # Convert to PIL image

                    filename = os.path.join(output_dir, f"output_{counter:05d}.png")
                    output_image.save(filename)

                    counter += 1  # Increment counter for unique filenames
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)

    print(f"Images saved to {output_dir}")

# Generate images from random noise with the loaded `model` and write them to disk.
# No DataLoader is involved: the generation function creates its own noise inputs.

output_dir = "./data/generated_images_perceptual"

# The noise is created on `device` inside the function, so the model is moved
# to the same device before generating.
save_model_outputs_to_images(model.to(device), output_dir, batch_size=256, num_generations=50000, device=device)