In [2]:

import os
import re
import numpy as np
import PIL.Image
import torch
import dnnlib
import legacy
from tqdm import tqdm

# ---------------------- SETTINGS ----------------------
IMG_SIZE = 128  # Resize your images to this size
BATCH_SIZE = 16
GENERATED_FRAMES_DIR = './generated_frames'
REAL_FRAMES_DIR = './real_frames'
NETWORK_PKL = 'stylegan3-r-ffhq-1024x1024.pkl'  # Your pre-trained model file
SEQUENCE_NAME = 'real_sequence03'
EPOCHS = 1000
LATENT_DIM = 512
# ------------------------------------------------------

# Create the directory for generated frames if it doesn't exist
os.makedirs(GENERATED_FRAMES_DIR, exist_ok=True)

# --------------------- LOADING PRE-TRAINED STYLEGAN3 --------------------- #

def load_pretrained_model():
    print(f'Loading networks from "{NETWORK_PKL}"...')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    with dnnlib.util.open_url(NETWORK_PKL) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)
    print('Model loaded successfully.')
    return G, device

# --------------------- GENERATING AND SAVING FRAMES --------------------- #

def save_generated_frames(generated_images, epoch, sequence_name=SEQUENCE_NAME):
    generated_images = (generated_images.permute(0, 2, 3, 1).cpu().numpy() * 127.5 + 128).astype(np.uint8)
    for i, frame in enumerate(generated_images):
        img = PIL.Image.fromarray(frame, 'RGB')
        file_name = f'{sequence_name}_frame{epoch}_{i:04d}.jpg'
        img.save(os.path.join(GENERATED_FRAMES_DIR, file_name))

def generate_images(G, device, seeds, truncation_psi=1.0, noise_mode='const'):
    label = torch.zeros([1, G.c_dim], device=device)
    for seed in seeds:
        print(f'Generating image for seed {seed}...')
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
        img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
        save_generated_frames(img, seed)

# --------------------- TRAINING FUNCTION --------------------- #

def train_gan(G, device, epochs, batch_size):
    for epoch in tqdm(range(epochs)):
        # Generate images
        seeds = np.random.randint(0, 10000, batch_size)
        generate_images(G, device, seeds)

if __name__ == '__main__':
    G, device = load_pretrained_model()
    train_gan(G, device, EPOCHS, BATCH_SIZE)


ModuleNotFoundError: No module named 'dnnlib'

ModuleNotFoundError: No module named 'dnnlib'

In [2]:
%pip install dnnlib

Note: you may need to restart the kernel to use updated packages.


ERROR: Could not find a version that satisfies the requirement dnnlib (from versions: none)
ERROR: No matching distribution found for dnnlib


In [None]:
pimnp