In [5]:
import torch
import torchvision.transforms as transforms
from torchvision.utils import save_image
from pytorch_pretrained_biggan import (BigGAN, one_hot_from_int, truncated_noise_sample)

# Pick the compute backend: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f'Using device: {device}')

# Load the pretrained 'biggan-deep-256' weights, then move the model to `device`.
model = BigGAN.from_pretrained('biggan-deep-256')
model = model.to(device)

# Generate synthetic images for a specific class (e.g., class index 0)
def generate_synthetic_images(class_idx, num_images, truncation=0.4):
    """Generate synthetic images of one class with the module-level BigGAN model.

    Args:
        class_idx: Integer class index passed to ``one_hot_from_int``.
        num_images: Number of images to generate. Each image is produced by
            its own model call with a freshly sampled truncated noise vector.
        truncation: Truncation value used both for sampling the noise and as
            the third argument of the model call.

    Returns:
        A list with ``num_images`` tensors, one per model call (batch size 1),
        each mapped through ``(x + 1) / 2`` to go from [-1, 1] to [0, 1].
    """
    model.eval()

    # Put the inputs on the device the model's weights actually live on
    # instead of trusting the global `device`. When the two disagree the
    # forward pass fails with
    # "Tensor for argument weight is on cpu but expected on mps".
    # NOTE(review): if that error still shows up with inputs and parameters
    # on the same device, the stray tensor is inside the model itself --
    # confirm against the installed torch / pytorch_pretrained_biggan versions.
    first_param = next(model.parameters(), None)
    target_device = device if first_param is None else first_param.device

    # The class vector is the same for every sample, so build and transfer
    # it once rather than on each iteration.
    class_vector = torch.tensor(
        one_hot_from_int([class_idx], batch_size=1),
        dtype=torch.float32,
        device=target_device,
    )

    synthetic_images = []
    with torch.no_grad():
        for _ in range(num_images):
            # Fresh random noise vector for every image.
            noise_vector = torch.tensor(
                truncated_noise_sample(truncation=truncation, batch_size=1),
                dtype=torch.float32,
                device=target_device,
            )

            synthetic_image = model(noise_vector, class_vector, truncation)

            # Scale from [-1, 1] to [0, 1].
            synthetic_images.append((synthetic_image + 1) / 2)

    return synthetic_images

# Example: Generate 100 synthetic images for class 0
# The class index and image count are named once so that the generation call
# and the output file names cannot drift apart.
CLASS_IDX = 0
NUM_IMAGES = 100

synthetic_images = generate_synthetic_images(class_idx=CLASS_IDX, num_images=NUM_IMAGES)
for i, img in enumerate(synthetic_images):
    # Each generated tensor is written to its own PNG, named by class and
    # position in the list, relative to the current working directory.
    save_image(img, f'synthetic_image_class_{CLASS_IDX}_{i}.png')

# Data augmentation: Add synthetic images to the training set
# You can modify the MultimodalDataset class to include both original and synthetic data


Using device: mps


RuntimeError: Tensor for argument weight is on cpu but expected on mps