# In [5]:
import torch
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from torch import nn
import numpy as np
import wandb
import os  # Import os for directory operations

# Use the current CUDA device when one is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Generator(nn.Module):
    """DCGAN-style generator built from transposed convolutions.

    For a latent tensor of shape (N, 100, 1, 1), the stack produces an
    image tensor of shape (N, 3, 64, 64) with values in (-1, 1) (Tanh output).
    """

    def __init__(self):
        super().__init__()
        # Stem: 1x1 latent -> 4x4 feature map with 512 channels.
        layers = [
            nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
        ]
        # Each stage doubles the spatial size (kernel 4, stride 2, pad 1)
        # and halves the channel count.
        for in_ch, out_ch in ((512, 256), (256, 128), (128, 64)):
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]
        # Head: upsample once more to 3 channels and squash into (-1, 1).
        layers += [
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        # The attribute name `model` and this exact layer order determine the
        # state_dict keys (model.0.weight, model.1.running_mean, ...), which
        # must match the saved checkpoint.
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Map latent noise to generated images."""
        images = self.model(input)
        return images

class Discriminator(nn.Module):
    """DCGAN-style discriminator built from strided convolutions.

    For a (N, 3, 64, 64) input, the four stride-2 convolutions reduce the map
    to 4x4, the final 4x4 convolution reduces it to 1x1, and the output is
    reshaped to (N, 1) sigmoid scores. Other input sizes would leave a larger
    spatial map, which `view(-1, 1)` folds into the first dimension; that case
    is presumably unintended (NOTE(review): confirm inputs are always 64x64).
    """

    def __init__(self):
        super().__init__()
        # First stage has no BatchNorm, unlike the later stages.
        layers = [
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Each remaining stage halves the spatial size and doubles the channels.
        for in_ch, out_ch in ((64, 128), (128, 256), (256, 512)):
            layers += [
                nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Collapse the 4x4 map to a single logit, then squash it into (0, 1).
        layers += [
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        # Attribute name and layer order fix the state_dict keys used by the checkpoint.
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Return sigmoid scores reshaped to a column of shape (-1, 1)."""
        scores = self.model(input)
        return scores.view(-1, 1)

def _load_eval(model_cls, checkpoint_path):
    """Build `model_cls` on `device`, load its state_dict from disk, and switch it to eval mode."""
    net = model_cls().to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    # Passing weights_only=True would restrict this on torch versions that support it.
    net.load_state_dict(torch.load(checkpoint_path, map_location=device))
    net.eval()
    return net


# The trained generator is what the inference below actually uses.
generator = _load_eval(Generator, 'generator_final.pth')

# The discriminator is loaded only for completeness; nothing below calls it.
discriminator = _load_eval(Discriminator, 'discriminator_final.pth')

def generate_images(num_images=64):
    """Sample `num_images` images from the module-level `generator`.

    Draws standard-normal latent vectors of shape (num_images, 100, 1, 1) on
    `device` and runs them through the generator with autograd disabled. The
    result has shape (num_images, 3, 64, 64) with values in (-1, 1).
    """
    latent = torch.randn(num_images, 100, 1, 1, device=device)
    with torch.no_grad():
        return generator(latent)

# Initialize W&B. Everything after this runs inside try/finally so the run is
# always closed. Otherwise a failure would leave the run open in the notebook kernel.
wandb.init(project="Anime_Face_GAN_Inference", name="Inference_Run")
try:
    # Create a directory for saving images if it doesn't exist
    os.makedirs('inferimages', exist_ok=True)

    # Generate a batch of images
    num_images = 64
    generated_images = generate_images(num_images)

    # Tile the batch into one grid, rescaled to [0, 1] using the batch-wide min/max.
    # Move it to the CPU before converting to numpy.
    grid = vutils.make_grid(generated_images, padding=2, normalize=True).cpu()
    plt.figure(figsize=(15, 15))
    plt.axis("off")
    plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))  # CHW -> HWC for matplotlib
    plt.savefig("inferimages/generated_anime_faces.png", dpi=300, bbox_inches='tight')
    plt.close()

    # Log the generated grid to W&B
    wandb.log({"Generated Anime Faces": wandb.Image("inferimages/generated_anime_faces.png")})

    print(f"{num_images} anime faces have been generated and saved as 'inferimages/generated_anime_faces.png'")

    # Save each image individually. Collect all of them into a single
    # wandb.log call: every call advances the W&B step, so logging one at a
    # time would spread the faces over 64 separate steps.
    # Note: save_image(normalize=True) rescales each image by its own min/max,
    # unlike the batch-wide rescaling used for the grid above.
    face_images = {}
    for i, img in enumerate(generated_images, start=1):
        img_filename = f"inferimages/generated_anime_face_{i}.png"
        vutils.save_image(img, img_filename, normalize=True)
        face_images[f"Generated Face {i}"] = wandb.Image(img_filename)
    wandb.log(face_images)

    print("Individual images have been saved and logged to W&B")
finally:
    # Finish the W&B run even if something above raised
    wandb.finish()

# Output:
# 64 anime faces have been generated and saved as 'inferimages/generated_anime_faces.png'
# Individual images have been saved and logged to W&B
