In [1]:
import torch.nn as nn
import torchvision.transforms as T
import torch
from IPython.display import display

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
class GeneratorDCGAN(nn.Module):
    """DCGAN-style generator built from transposed convolutions.

    Given a latent tensor of shape (N, 128, 1, 1), the spatial size grows
    1 -> 4 -> 8 -> 16 -> 32 -> 64, producing an (N, 3, 64, 64) tensor whose
    values lie in [-1, 1] because of the final Tanh.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride, padding) for each hidden stage.
        # Every stage uses a 4x4 kernel without bias, since BatchNorm follows.
        hidden_stages = [
            (128, 512, 1, 0),  # 1x1   -> 4x4
            (512, 256, 2, 1),  # 4x4   -> 8x8
            (256, 128, 2, 1),  # 8x8   -> 16x16
            (128, 64, 2, 1),   # 16x16 -> 32x32
        ]
        layers = []
        for c_in, c_out, stride, pad in hidden_stages:
            layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=stride, padding=pad, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        # Output stage: 32x32 -> 64x64, projected to 3 channels and squashed to [-1, 1].
        layers += [
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]
        # Same attribute name and layer order as before, so checkpoint keys
        # (model.0.weight, model.1.running_mean, ...) still match.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the latent batch ``x`` through the upsampling stack."""
        return self.model(x)
    
class GeneratorGAN(nn.Module):
    """Fully-connected generator: ``z_dim`` latent -> 256 hidden units -> ``img_dim`` outputs.

    The final Tanh keeps outputs in [-1, 1].
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        hidden_width = 256
        layers = [
            nn.Linear(z_dim, hidden_width),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_width, img_dim),
            # Inputs are normalized to [-1, 1], so outputs are made [-1, 1] too.
            nn.Tanh(),
        ]
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Map latent vectors ``x`` (last dim ``z_dim``) to flat images (last dim ``img_dim``)."""
        return self.gen(x)



In [None]:
@torch.inference_mode()
def inference_gan(weights_path='model_weights/mnist-G.pth', n_samples=30):
    """Sample from the pretrained fully-connected MNIST generator.

    Args:
        weights_path: Path to the generator checkpoint. It is loaded into a
            ``torch.compile``-wrapped ``GeneratorGAN(256, 784)``, so its keys
            presumably carry the ``_orig_mod.`` prefix (TODO: confirm).
        n_samples: Number of latent vectors to draw. Must be >= 1. Only the
            first generated sample is converted and returned.

    Returns:
        PIL.Image.Image: the first sample as a 28x28 single-channel image.

    Raises:
        ValueError: if ``n_samples`` is less than 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")
    generator = torch.compile(GeneratorGAN(256, 784)).to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    # weights_only restricts unpickling to tensors and plain containers.
    state = torch.load(weights_path, map_location=device, weights_only=True)
    generator.load_state_dict(state)
    # Draw latents on the model's device rather than hard-coding 'cuda'.
    z = torch.randn(n_samples, 256, device=device)
    samples = generator(z)
    # GeneratorGAN ends in Tanh, so outputs lie in [-1, 1]. to_pil_image
    # multiplies float tensors by 255 and casts to uint8, so it needs [0, 1]
    # input. Rescale first so negative values do not wrap around.
    digit = (samples[0].cpu().view(28, 28) + 1) / 2
    return T.functional.to_pil_image(digit)

@torch.inference_mode()
def inference_dcgan(weights_path='model_weights/animefacedataset-G2.pth', n_samples=64):
    """Sample from the pretrained DCGAN anime-face generator.

    Args:
        weights_path: Path to the generator checkpoint. It is loaded into a
            ``torch.compile``-wrapped ``GeneratorDCGAN()``, so its keys
            presumably carry the ``_orig_mod.`` prefix (TODO: confirm).
        n_samples: Number of latent vectors to draw as one batch. Must be >= 1.
            Only the first generated sample is converted and returned.

    Returns:
        PIL.Image.Image: the first sample as a 3-channel 64x64 image.

    Raises:
        ValueError: if ``n_samples`` is less than 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")
    generator = torch.compile(GeneratorDCGAN()).to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    # weights_only restricts unpickling to tensors and plain containers.
    state = torch.load(weights_path, map_location=device, weights_only=True)
    generator.load_state_dict(state)
    # NOTE(review): the generator stays in training mode, so its BatchNorm
    # layers normalize with this batch's statistics and sample 0 depends on
    # the other n_samples - 1 latents. Consider generator.eval(). TODO: confirm
    # which mode these weights produce good images in.

    def denorm(img_tensors):
        # Map Tanh output [-1, 1] back to [0, 1] via x * std + mean
        # (presumably the inverse of the training-time normalization).
        stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
        return img_tensors * stats[1][0] + stats[0][0]

    # Draw latents on the model's device rather than hard-coding 'cuda'.
    z = torch.randn(n_samples, 128, 1, 1, device=device)
    samples = generator(z)
    return T.functional.to_pil_image(denorm(samples[0].cpu()))

# Generate one MNIST sample. Leaving the image as the cell's last expression
# makes Jupyter render it inline.
gan_sample = inference_gan()
gan_sample