FaceLocator - mask conditioning in unet - sample code - animateanything #28

johndpope opened this issue Mar 24, 2024 · 0 comments

https://github.com/alibaba/animate-anything/blob/main/train_svd.py

[Screenshot from 2024-03-25 05-49-43]

Seems like the entire UNet is fine-tuned for this task, which sort of aligns with the image.

```python
# Imports this excerpt relies on (not shown in the issue). The location of
# _resize_with_antialiasing is assumed to be diffusers' SVD pipeline module.
import math
import random

import torch
import torch.nn.functional as F
import torchvision.transforms as T
from einops import rearrange, repeat
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
    _resize_with_antialiasing,
)


def finetune_unet(accelerator, pipeline, batch, use_offset_noise,
    rescale_schedule, offset_noise_strength, unet, motion_mask,
    P_mean=0.7, P_std=1.6):
    pipeline.vae.eval()
    pipeline.image_encoder.eval()
    device = unet.device
    dtype = pipeline.vae.dtype
    vae = pipeline.vae
    # Convert videos to latent space
    pixel_values = batch['pixel_values']
    bsz, num_frames = pixel_values.shape[:2]

    frames = rearrange(pixel_values, 'b f c h w-> (b f) c h w').to(dtype)
    latents = vae.encode(frames).latent_dist.mode() * vae.config.scaling_factor
    latents = rearrange(latents, '(b f) c h w-> b f c h w', b=bsz)

    # Encode the noise-augmented first frame as the conditioning image latent
    image = pixel_values[:,0].to(dtype)
    noise_aug_strength = math.exp(random.normalvariate(mu=-3, sigma=0.5))
    image = image + noise_aug_strength * torch.randn_like(image)
    image_latent = vae.encode(image).latent_dist.mode() * vae.config.scaling_factor

    if motion_mask:
        # Scale the mask to [0, 1], resize it to latent resolution and binarize it
        mask = batch['mask']
        mask = mask.div(255)
        h, w = latents.shape[-2:]
        mask = T.Resize((h, w), antialias=False)(mask)
        mask[mask<0.5] = 0
        mask[mask>=0.5] = 1
        mask = repeat(mask, 'b h w -> b f 1 h w', f=num_frames).detach().clone()
        mask[:,0] = 0  # the first frame is never masked
        # mask == 1: condition on the first-frame latent; mask == 0: on each frame's own latent
        freeze = repeat(image_latent, 'b c h w -> b f c h w', f=num_frames)
        condition_latent = latents * (1-mask) + freeze * mask
    else:
        condition_latent = repeat(image_latent, 'b c h w->b f c h w',f=num_frames)

    # CLIP image embedding of the first frame
    pipeline.image_encoder.to(device, dtype=dtype)
    images = _resize_with_antialiasing(pixel_values[:,0], (224, 224)).to(dtype)
    images = (images + 1.0) / 2.0 # [-1, 1] -> [0, 1]
    images = pipeline.feature_extractor(
        images=images,
        do_normalize=True,
        do_center_crop=False,
        do_resize=False,
        do_rescale=False,
        return_tensors="pt",
    ).pixel_values
    image_embeddings = pipeline._encode_image(images, device, 1, False)

    encoder_hidden_states = image_embeddings
    uncond_hidden_states = torch.zeros_like(image_embeddings)

    # Drop the image conditioning for the whole batch 15% of the time (classifier-free guidance)
    if random.random() < 0.15:
        encoder_hidden_states = uncond_hidden_states
    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process) #[bsz, f, c, h , w]
    # EDM-style preconditioning with sigma_data = 1 and a v-prediction output
    rnd_normal = torch.randn([bsz, 1, 1, 1, 1], device=device)
    sigma = (rnd_normal * P_std + P_mean).exp()
    c_skip = 1 / (sigma**2 + 1)
    c_out = -sigma / (sigma**2 + 1) ** 0.5
    c_in = 1 / (sigma**2 + 1) ** 0.5
    c_noise = (sigma.log() / 4).reshape([bsz])
    loss_weight = (sigma ** 2 + 1) / sigma ** 2

    noisy_latents = latents + torch.randn_like(latents) * sigma
    # Concatenate along channels (dim=2 of [b, f, c, h, w]); the condition is passed unscaled
    input_latents = torch.cat([c_in * noisy_latents,
        condition_latent/vae.config.scaling_factor], dim=2)
    if motion_mask:
        input_latents = torch.cat([mask, input_latents], dim=2)

    # SVD micro-conditioning: fps, motion bucket and noise-augmentation strength
    motion_bucket_id = 127
    fps = 7
    added_time_ids = pipeline._get_add_time_ids(fps, motion_bucket_id,
        noise_aug_strength, image_embeddings.dtype, bsz, 1, False)
    added_time_ids = added_time_ids.to(device)

    loss = 0

    accelerator.wait_for_everyone()
    model_pred = unet(input_latents, c_noise, encoder_hidden_states=encoder_hidden_states, added_time_ids=added_time_ids).sample
    # Turn the UNet output into an x0 estimate and apply the EDM-weighted MSE
    predict_x0 = c_out * model_pred + c_skip * noisy_latents
    loss += ((predict_x0 - latents)**2 * loss_weight).mean()
    if motion_mask:
        # Extra term: match the condition latent outside the masked region
        loss += F.mse_loss(predict_x0*(1-mask), condition_latent*(1-mask))
    return loss
```

UPDATE: maybe we can just use Alibaba's pretrained 6.8 GB model (https://github.com/alibaba/animate-anything?tab=readme-ov-file). We just need to wire it into `MaskStableVideoDiffusionPipeline` (note its `motion_bucket_id: int = 127,` argument).

The finetune_unet function is designed to fine-tune a U-Net model for video frame generation, integrating it with the Stable Video Diffusion model from Hugging Face's diffusers library. Here's a block-by-block explanation:

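```python
def finetune_unet(accelerator, pipeline, batch, use_offset_noise,
    rescale_schedule, offset_noise_strength, unet, motion_mask,
    P_mean=0.7, P_std=1.6):
```
Defines the function. Its parameters are the Accelerate `accelerator`, the video diffusion `pipeline`, the training `batch`, the `unet` being trained, a `motion_mask` flag, and `P_mean`/`P_std`, the mean and standard deviation of the log-normal distribution from which the noise level σ is drawn. `use_offset_noise`, `rescale_schedule` and `offset_noise_strength` are accepted but not used in the function body.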
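A minimal pure-Python sketch of what `P_mean` and `P_std` control, assuming nothing beyond the sampling line in the function: log σ is normal with mean 0.7 and standard deviation 1.6, so the median noise level is e^0.7 ≈ 2.01 and roughly a third of the draws fall below σ = 1. `sample_sigmas` is a hypothetical helper standing in for the torch expression.

```python
import math
import random
import statistics


def sample_sigmas(n, p_mean=0.7, p_std=1.6, seed=0):
    """Draw n noise levels with log(sigma) ~ N(p_mean, p_std**2).

    Hypothetical stand-in for (torch.randn([bsz, 1, 1, 1, 1]) * P_std + P_mean).exp().
    """
    rng = random.Random(seed)
    return [math.exp(rng.gauss(p_mean, p_std)) for _ in range(n)]


sigmas = sample_sigmas(20000)
print("median sigma:", round(statistics.median(sigmas), 3))
print("fraction with sigma < 1:", sum(s < 1.0 for s in sigmas) / len(sigmas))
```

Raising `P_mean` shifts training toward noisier inputs; raising `P_std` spreads the draws over a wider range of noise levels.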
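```python
    pipeline.vae.eval()
    pipeline.image_encoder.eval()
```
Puts the pipeline's VAE and image encoder into evaluation mode, which disables training-time behavior such as dropout. `eval()` by itself does not freeze their weights; whether they are updated depends on `requires_grad` and on which parameters are passed to the optimizer.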
```python
    device = unet.device
    dtype = pipeline.vae.dtype
```
Retrieves the device from the UNet and the data type from the VAE, so the tensors built below are consistent with both.
```python
    vae = pipeline.vae
```
Extracts the VAE model from the pipeline.
```python
    pixel_values = batch['pixel_values']
    bsz, num_frames = pixel_values.shape[:2]
```
Extracts pixel values from the batch and determines the batch size and number of frames.
```python
    frames = rearrange(pixel_values, 'b f c h w-> (b f) c h w').to(dtype)
    latents = vae.encode(frames).latent_dist.mode() * vae.config.scaling_factor
    latents = rearrange(latents, '(b f) c h w-> b f c h w', b=bsz)
```
Flattens the batch and frame dimensions into one so every frame is encoded independently by the VAE, takes the mode of the latent distribution (a deterministic encoding rather than a random sample), multiplies by the VAE's `scaling_factor`, and then restores the `[batch, frames, channels, height, width]` layout.
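The `(b f)` grouping can be illustrated with a pure-Python stand-in (hypothetical helpers `flatten_bf` and `unflatten_bf`; each tuple stands for one `c h w` frame). Flattening keeps each video's frames contiguous, so frame `f` of video `b` ends up at row `b * num_frames + f`, and the reverse rearrange needs `b=bsz` to know where to split.

```python
def flatten_bf(video):
    """'b f c h w -> (b f) c h w': b is the outer index, f the inner one."""
    return [frame for clip in video for frame in clip]


def unflatten_bf(frames, b):
    """'(b f) c h w -> b f c h w': b is needed to recover the split."""
    f = len(frames) // b
    return [frames[i * f:(i + 1) * f] for i in range(b)]


# Two videos of three frames; each "frame" is labelled (batch index, frame index).
video = [[(bi, fi) for fi in range(3)] for bi in range(2)]
flat = flatten_bf(video)
print(flat)  # [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
```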
```python
    image = pixel_values[:,0].to(dtype)
    noise_aug_strength = math.exp(random.normalvariate(mu=-3, sigma=0.5))
    image = image + noise_aug_strength * torch.randn_like(image)
    image_latent = vae.encode(image).latent_dist.mode() * vae.config.scaling_factor
```
Takes the first frame, adds Gaussian noise whose strength is the exponential of a normal sample with mean −3 and standard deviation 0.5, and encodes the noisy frame with the VAE to get the conditioning image latent. The same `noise_aug_strength` is later passed to the UNet through `added_time_ids`.
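A sketch of the noise augmentation using only the standard library, assuming a flat list of pixel values in place of the image tensor (`sample_noise_aug_strength` and `noise_augment` are hypothetical names). Because log(strength) is normal with mean −3 and standard deviation 0.5, the median strength is e^−3 ≈ 0.05, so the conditioning frame is usually only lightly perturbed.

```python
import math
import random
import statistics


def sample_noise_aug_strength(rng):
    # Same draw as the training code: exp(N(mu=-3, sigma=0.5))
    return math.exp(rng.normalvariate(mu=-3, sigma=0.5))


def noise_augment(pixels, strength, rng):
    # image + strength * randn_like(image), element-wise on a flat list
    return [p + strength * rng.gauss(0.0, 1.0) for p in pixels]


rng = random.Random(0)
strengths = [sample_noise_aug_strength(rng) for _ in range(20000)]
print("median strength:", round(statistics.median(strengths), 4))
print(noise_augment([0.0, 0.5, -0.5], strengths[0], rng))
```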
```python
    if motion_mask:
        mask = batch['mask']
        ...
        condition_latent = latents * (1-mask) + freeze * mask
    else:
        condition_latent = repeat(image_latent, 'b c h w->b f c h w',f=num_frames)
```
With `motion_mask`, the mask is scaled to [0, 1], resized to the latent resolution, binarized at 0.5, repeated across frames and zeroed for the first frame. The condition then uses the first-frame latent (`freeze`) where the mask is 1 and each frame's own ground-truth latent where the mask is 0. Without a mask, the first-frame latent is simply repeated across all frames.
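The masking logic can be traced on toy numbers with a pure-Python sketch. Assumptions: each frame's latent is a flat list, the mask is already at latent resolution (the `T.Resize` step is omitted), and `binarize` / `condition_latents` are hypothetical helpers; the values are made up.

```python
def binarize(mask_u8):
    # mask.div(255), then values < 0.5 -> 0 and values >= 0.5 -> 1
    return [1.0 if v / 255 >= 0.5 else 0.0 for v in mask_u8]


def condition_latents(latents, first_latent, mask):
    """latents: one flat list per frame; first_latent: the (noise-augmented) image latent.

    Mirrors: mask repeated over frames, mask[:, 0] = 0,
    condition = latents * (1 - mask) + freeze * mask
    """
    cond = []
    for f, frame in enumerate(latents):
        m = [0.0] * len(mask) if f == 0 else mask  # frame 0 is never masked
        cond.append([x * (1 - mi) + z * mi for x, z, mi in zip(frame, first_latent, m)])
    return cond


mask = binarize([0, 200, 127, 128])
latents = [[10, 11, 12, 13], [20, 21, 22, 23], [30, 31, 32, 33]]
first_latent = [1, 1, 1, 1]
cond = condition_latents(latents, first_latent, mask)
print(mask)  # [0.0, 1.0, 0.0, 1.0]
print(cond)  # [[10.0, 11.0, 12.0, 13.0], [20.0, 1.0, 22.0, 1.0], [30.0, 1.0, 32.0, 1.0]]
```

Positions 1 and 3 (mask = 1) carry the first-frame value in every later frame, while the unmasked positions keep the true per-frame latents.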
```python
    pipeline.image_encoder.to(device, dtype=dtype)
    images = _resize_with_antialiasing(pixel_values[:,0], (224, 224)).to(dtype)
    ...
    image_embeddings = pipeline._encode_image(images, device, 1, False)
```
Moves the image encoder to the training device and dtype, resizes the first frame to 224×224 with antialiasing, maps it from [−1, 1] to [0, 1], normalizes it with the CLIP feature extractor, and encodes it into image embeddings.
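A per-pixel sketch of the value mapping before the CLIP encoder. It assumes the feature extractor uses the standard OpenAI CLIP mean and standard deviation (check the pipeline's `feature_extractor` config to confirm); `to_clip_input` is a hypothetical helper and the 224×224 resize is omitted.

```python
# Assumed: standard OpenAI CLIP normalization constants (per RGB channel).
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)


def to_clip_input(rgb_pixel):
    """Map one RGB pixel from [-1, 1] to the normalized CLIP input space.

    (x + 1) / 2 undoes the [-1, 1] scaling used for the VAE; do_normalize=True
    then applies (x - mean) / std per channel.
    """
    unit = [(x + 1.0) / 2.0 for x in rgb_pixel]
    return [(u - m) / s for u, m, s in zip(unit, CLIP_MEAN, CLIP_STD)]


print(to_clip_input([-1.0, 0.0, 1.0]))
```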
```python
    encoder_hidden_states = image_embeddings
    uncond_hidden_states = torch.zeros_like(image_embeddings)
```
Uses the image embeddings as the encoder hidden states and prepares an all-zeros tensor of the same shape to serve as the unconditional (null) conditioning.
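```python
    if random.random() < 0.15:
        encoder_hidden_states = uncond_hidden_states
```
With 15% probability, replaces the image embeddings with zeros for the whole batch. This conditioning dropout teaches the UNet to also make unconditional predictions, which classifier-free guidance needs at sampling time.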
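A pure-Python sketch of both halves of the idea: the training-time dropout (one random draw per call, so the whole batch is dropped together, as in the code above) and the standard classifier-free guidance combination used when sampling. `maybe_drop_condition`, `cfg_combine` and the guidance scale are hypothetical; the sketch does not reproduce the SVD pipeline's own guidance schedule.

```python
import random


def maybe_drop_condition(cond, rng, p_uncond=0.15):
    """With probability p_uncond, replace the conditioning with zeros."""
    if rng.random() < p_uncond:
        return [0.0] * len(cond)
    return cond


def cfg_combine(pred_uncond, pred_cond, guidance_scale):
    """Classifier-free guidance: uncond + w * (cond - uncond)."""
    return [u + guidance_scale * (c - u) for u, c in zip(pred_uncond, pred_cond)]


rng = random.Random(0)
embedding = [0.3, -0.7, 1.2]
dropped = sum(maybe_drop_condition(embedding, rng) != embedding for _ in range(20000))
print("dropout rate:", dropped / 20000)
print(cfg_combine([1.0, 2.0], [3.0, 2.0], guidance_scale=2.0))  # [5.0, 2.0]
```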
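```python
    rnd_normal = torch.randn([bsz, 1, 1, 1, 1], device=device)
    ...
    loss_weight = (sigma ** 2 + 1) / sigma ** 2
```
Draws one noise level per sample as σ = exp(P_mean + P_std · z) with z ~ N(0, 1), then computes the EDM-style preconditioning coefficients (`c_skip`, `c_out`, `c_in`, `c_noise`) and the loss weight (σ² + 1)/σ², which up-weights samples with little noise.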
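The coefficients can be evaluated for a few noise levels with a scalar pure-Python sketch (`edm_svd_coefficients` is a hypothetical name; the formulas are copied from the function). They correspond to the EDM preconditioning with σ_data = 1, with a negative `c_out` for the v-prediction output; `c_in` rescales the noisy input to unit variance when the clean latents have unit variance.

```python
import math


def edm_svd_coefficients(sigma):
    """Preconditioning as written in finetune_unet, for a scalar sigma."""
    c_skip = 1 / (sigma**2 + 1)
    c_out = -sigma / (sigma**2 + 1) ** 0.5
    c_in = 1 / (sigma**2 + 1) ** 0.5
    c_noise = math.log(sigma) / 4  # what the UNet receives as its "timestep"
    loss_weight = (sigma**2 + 1) / sigma**2
    return c_skip, c_out, c_in, c_noise, loss_weight


for sigma in (0.1, 1.0, 10.0):
    print(sigma, [round(v, 4) for v in edm_svd_coefficients(sigma)])
```

At σ = 1 this gives `c_skip` = 0.5, `c_out` = −1/√2, `c_in` = 1/√2, `c_noise` = 0 and a loss weight of 2.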
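```python
    noisy_latents = latents + torch.randn_like(latents) * sigma
    input_latents = torch.cat([c_in * noisy_latents,
        condition_latent/vae.config.scaling_factor], dim=2)
```
Adds Gaussian noise scaled by σ to the clean latents, scales the result by `c_in`, and concatenates it with the condition latent along the channel dimension (`dim=2` in the `[b, f, c, h, w]` layout). The condition latent is divided by `scaling_factor` first, so the UNet receives it unscaled.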
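A small arithmetic sketch of the resulting UNet input width, assuming 4-channel VAE latents and a 1-channel mask (`unet_input_channels` is a hypothetical helper). Without the mask this gives 8 channels, the input width of the stock SVD UNet; with `motion_mask` the mask is prepended, so the UNet's input convolution has to accept 9 channels.

```python
def unet_input_channels(latent_channels=4, use_motion_mask=False):
    """Channel count of input_latents after the torch.cat(..., dim=2) calls."""
    channels = latent_channels   # c_in * noisy_latents
    channels += latent_channels  # condition_latent / scaling_factor
    if use_motion_mask:
        channels += 1            # mask prepended in front of the latents
    return channels


print(unet_input_channels(), unet_input_channels(use_motion_mask=True))  # 8 9
```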
```python
    motion_bucket_id = 127
    fps = 7
    added_time_ids = pipeline._get_add_time_ids(fps, motion_bucket_id,
        noise_aug_strength, image_embeddings.dtype, bsz, 1, False)
```
Fixes the motion bucket ID (127) and frames per second (7) and, together with `noise_aug_strength`, builds the additional time IDs that the SVD UNet receives as extra conditioning.
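A pure-Python sketch of the shape of `added_time_ids` for this call, assuming `num_videos_per_prompt=1` and no classifier-free guidance (the last two arguments above). `get_add_time_ids` is a hypothetical mirror, not the diffusers method, and the strength value is made up. Because `fps` and `motion_bucket_id` are constants here, every training sample carries the same values; only `noise_aug_strength` changes between steps.

```python
def get_add_time_ids(fps, motion_bucket_id, noise_aug_strength, batch_size):
    """One [fps, motion_bucket_id, noise_aug_strength] row per sample."""
    row = [fps, motion_bucket_id, noise_aug_strength]
    return [list(row) for _ in range(batch_size)]


ids = get_add_time_ids(fps=7, motion_bucket_id=127, noise_aug_strength=0.05, batch_size=2)
print(ids)  # [[7, 127, 0.05], [7, 127, 0.05]]
```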
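```python
    loss = 0
    accelerator.wait_for_everyone()
```
Initializes the loss accumulator and waits until all processes reach this point (Accelerate's `wait_for_everyone()` barrier) before the UNet forward pass.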
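The remaining lines convert the UNet output into an estimate of the clean latents, `predict_x0 = c_out * model_pred + c_skip * noisy_latents`, and apply the σ-weighted MSE. A scalar pure-Python sketch checks the parameterization: the output that makes the estimate exact is (noise − σ·x0)/√(σ² + 1), i.e. a v-prediction target. `v_target` is derived here from the formulas, not taken from the source, and the numbers are made up.

```python
import math


def predict_x0(model_pred, noisy, sigma):
    c_skip = 1 / (sigma**2 + 1)
    c_out = -sigma / (sigma**2 + 1) ** 0.5
    return c_out * model_pred + c_skip * noisy


def v_target(x0, noise, sigma):
    # The UNet output that makes predict_x0 exact (derived from c_skip and c_out)
    return (noise - sigma * x0) / math.sqrt(sigma**2 + 1)


def weighted_loss(pred_x0, x0, sigma):
    # Same weighting as loss_weight in finetune_unet
    return (pred_x0 - x0) ** 2 * (sigma**2 + 1) / sigma**2


x0, noise, sigma = 0.8, -1.3, 2.5
noisy = x0 + noise * sigma
perfect = predict_x0(v_target(x0, noise, sigma), noisy, sigma)
print(perfect, weighted_loss(perfect, x0, sigma))
```

With `motion_mask`, the function adds one more term, an unweighted MSE between `predict_x0` and `condition_latent` outside the mask.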

