In [None]:
!pip install tqdm torchvision librosa wandb matplotlib lpips mediapipe pytorch-msssim scikit-image piq realesrgan

In [None]:
!curl -L -o checkpoints/RealESRGAN_x4plus.pth https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth

In [None]:
# Train SyncNet from scratch: no --checkpoint_path is passed, so no earlier checkpoint is resumed.
# NOTE(review): boolean flags such as `--use_wandb False` reach the script as strings; if it
# declares them with argparse `type=bool`, any non-empty string (including "False") parses as
# True — confirm how transformer_syncnet_train.py parses them.
!python transformer_syncnet_train.py --data_root training_data/ \
--checkpoint_dir checkpoints/syncnet_checkpoint \
--use_wandb False \
--train_root local_train_files

In [None]:
# Resume SyncNet training from the step-188400 checkpoint, writing new checkpoints to the same
# directory; requests augmentation off via --use_augmentation False.
!python transformer_syncnet_train.py --data_root training_data/ \
--checkpoint_dir checkpoints/syncnet_checkpoint \
--checkpoint_path checkpoints/syncnet_checkpoint/checkpoint_step000188400.pth \
--use_wandb False \
--train_root local_train_files \
--use_augmentation False

In [None]:
# Train Wav2Lip without resuming (no --checkpoint_path), saving to wav2lip_checkpoint/more-ref.
# The SyncNet checkpoint at step 202200 is handed to the script — presumably as the lip-sync
# expert used for the sync loss; confirm in transformer_wav2lip_train.py.
# --num_of_unet_layers 1 is forwarded to the script (presumably the generator depth).
!python transformer_wav2lip_train.py --data_root training_data/ \
--checkpoint_dir checkpoints/wav2lip_checkpoint/more-ref \
--use_wandb False \
--syncnet_checkpoint_path checkpoints/syncnet_checkpoint/checkpoint_step000202200.pth \
--train_root local_train_files \
--num_of_unet_layers 1 \
--use_augmentation False

In [None]:
# Resume Wav2Lip training from face-enhancer/checkpoint_step000139000.pth, with the SyncNet
# checkpoint at step 221800; new checkpoints are written to the same face-enhancer directory.
# NOTE(review): --num_of_unet_layers presumably has to match the architecture stored in the
# resumed checkpoint — confirm before changing it.
!python transformer_wav2lip_train.py --data_root training_data/ \
--checkpoint_dir checkpoints/wav2lip_checkpoint/face-enhancer \
--use_wandb False \
--syncnet_checkpoint_path checkpoints/syncnet_checkpoint/checkpoint_step000221800.pth \
--checkpoint_path checkpoints/wav2lip_checkpoint/face-enhancer/checkpoint_step000139000.pth \
--train_root local_train_files/ \
--use_augmentation False \
--num_of_unet_layers 1

## Helper functions

In [None]:
# Generate the training file
import os

def get_subfolders(directory):
    """Return the immediate sub-directories of ``directory`` as joined paths.

    Entries are sorted by name so that anything generated from the result (such as the
    training list file) is reproducible; ``os.listdir`` order is arbitrary and
    filesystem-dependent.

    Args:
        directory: Path of the folder to scan.

    Returns:
        list[str]: ``os.path.join(directory, name)`` for every entry that is a directory,
        in name order.
    """
    candidates = (os.path.join(directory, name) for name in sorted(os.listdir(directory)))
    return [path for path in candidates if os.path.isdir(path)]

# Path of the list file regenerated by this cell.
output_file_path = "output.txt"

# One line per second-level folder of training_data/: that folder's last two path components
# joined back together. The previous version wrapped the body in os.walk(subdir2) without using
# its results, so a folder with N nested sub-directories was listed N + 1 times. It also opened
# the file in append mode, so every re-run duplicated all entries; the file is now rewritten.
with open(output_file_path, 'w') as f:
    for subdir in get_subfolders('training_data'):
        for subdir2 in get_subfolders(subdir):
            print(subdir2)
            desired_portion = os.path.join(*subdir2.split(os.sep)[-2:])
            f.write(desired_portion + '\n')
            print(f"Appended '{desired_portion}' to {output_file_path}")


## Load checkpoint

In [1]:
import torch

def load_wav2lip_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True, lr=0.00001):
    """Load a Wav2Lip checkpoint into ``model`` (and optionally ``optimizer``).

    Args:
        path: Checkpoint file read via ``_load``. Must contain ``state_dict``; the keys
            ``optimizer``, ``global_step`` and ``global_epoch`` are read depending on the flags.
        model: Module receiving the weights. Every ``'module.'`` substring is stripped from the
            keys first (as added by ``nn.DataParallel``). Loading is non-strict, and missing or
            unexpected keys are reported instead of being silently ignored.
        optimizer: Optimizer to restore/update, or ``None``.
        reset_optimizer: When False and an optimizer is given, restore its state from the
            checkpoint (skipped if the checkpoint holds no optimizer state).
        overwrite_global_states: When True, set the globals ``global_step`` and
            ``global_epoch`` from the checkpoint.
        lr: Learning rate forced onto every param group of ``optimizer`` after loading
            (the default is the previously hard-coded 1e-5).

    Returns:
        The same ``model`` instance.
    """
    global global_step
    global global_epoch

    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)

    new_s = {k.replace('module.', ''): v for k, v in checkpoint["state_dict"].items()}
    # strict=False tolerates architecture drift, but silently skipping keys hides a wrong
    # checkpoint/model pairing, so report what did not line up.
    incompatible = model.load_state_dict(new_s, strict=False)
    skipped = list(incompatible.missing_keys) + list(incompatible.unexpected_keys)
    if skipped:
        print("Non-strict load: {} missing / {} unexpected keys, e.g. {}".format(
            len(incompatible.missing_keys), len(incompatible.unexpected_keys), skipped[:5]))

    if not reset_optimizer and optimizer is not None:
        optimizer_state = checkpoint.get("optimizer")
        if optimizer_state is not None:
            print("Load optimizer state from {}".format(path))
            optimizer.load_state_dict(optimizer_state)

    if overwrite_global_states:
        global_step = checkpoint["global_step"]
        global_epoch = checkpoint["global_epoch"]

    if optimizer is not None:
        # Override whatever lr the optimizer was built with or restored from the checkpoint.
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    return model

def _load(checkpoint_path):
    checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint

## LoRA fine-tuning: SyncNet

In [None]:
!pip install peft numpy

In [1]:
def load_syncnet_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True, lr=0.00001):
    """Load a SyncNet checkpoint into ``model`` (and optionally ``optimizer``).

    Args:
        path: Checkpoint file read via ``_load``. Must contain ``state_dict``; the keys
            ``optimizer``, ``global_step`` and ``global_epoch`` are read depending on the flags.
        model: Module receiving the weights. Every ``'module.'`` substring is stripped from the
            keys first (as added by ``nn.DataParallel``). Loading is non-strict, and missing or
            unexpected keys are reported instead of being silently ignored.
        optimizer: Optimizer to restore/update, or ``None``.
        reset_optimizer: When False and an optimizer is given, restore its state from the
            checkpoint (skipped if the checkpoint holds no optimizer state).
        overwrite_global_states: When True, set the globals ``global_step`` and
            ``global_epoch`` from the checkpoint.
        lr: Learning rate forced onto every param group of ``optimizer`` after loading
            (the default is the previously hard-coded 1e-5).

    Returns:
        The same ``model`` instance.
    """
    global global_step
    global global_epoch

    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)

    new_s = {k.replace('module.', ''): v for k, v in checkpoint["state_dict"].items()}
    # strict=False tolerates architecture drift, but silently skipping keys hides a wrong
    # checkpoint/model pairing, so report what did not line up.
    incompatible = model.load_state_dict(new_s, strict=False)
    skipped = list(incompatible.missing_keys) + list(incompatible.unexpected_keys)
    if skipped:
        print("Non-strict load: {} missing / {} unexpected keys, e.g. {}".format(
            len(incompatible.missing_keys), len(incompatible.unexpected_keys), skipped[:5]))

    if not reset_optimizer and optimizer is not None:
        optimizer_state = checkpoint.get("optimizer")
        if optimizer_state is not None:
            print("Load optimizer state from {}".format(path))
            optimizer.load_state_dict(optimizer_state)

    if overwrite_global_states:
        global_step = checkpoint["global_step"]
        global_epoch = checkpoint["global_epoch"]

    if optimizer is not None:
        # Override whatever lr the optimizer was built with or restored from the checkpoint.
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    return model



In [None]:
from peft import LoraConfig, get_peft_model
import torch.nn as nn
import torch
from torch.utils import data as data_utils
from models import TransformerSyncnet as TransformerSyncnet
from syncnet_dataset import Dataset

cross_entropy_loss = nn.CrossEntropyLoss()
# Previously neither the model nor the batches ever left the CPU; use the GPU when present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Step 1: Create an instance of the TransformerSyncnet model
model = TransformerSyncnet(num_heads=8, num_encoder_layers=6)

# Step 2: Load the pre-trained weights. No optimizer state is restored, and the notebook's
# global_step / global_epoch are left untouched (overwrite_global_states=False).
load_syncnet_checkpoint("checkpoints/syncnet_checkpoint/checkpoint_step000066000.pth", model, None, reset_optimizer=True, overwrite_global_states=False)

# LoRA adapters on modules whose name matches 'self_attn.out_proj'
# (presumably each encoder layer's attention output projection).
# NOTE(review): if self_attn is torch.nn.MultiheadAttention, its forward reads out_proj.weight
# directly instead of calling out_proj, so adapters on out_proj may get no gradient — verify
# that the LoRA weights actually change during training.
lora_config = LoraConfig(
    r=8,  # rank of the low-rank update
    lora_alpha=32,  # scaling factor (PEFT scales the update by lora_alpha / r)
    target_modules=['self_attn.out_proj'],
    lora_dropout=0.1,
    bias="none"
)

# Wrap only the transformer encoder. PEFT freezes that encoder's original weights. Nothing in
# this cell freezes the modules outside it, so they remain fully trainable.
model.transformer_encoder = get_peft_model(model.transformer_encoder, lora_config)
model.to(device)

# Give the optimizer only tensors that can receive gradients.
optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-4)

train_dataset = Dataset('train', 'training_data/')

train_data_loader = data_utils.DataLoader(
        train_dataset, batch_size=10, shuffle=True,
        num_workers=0)

# Fine-tuning loop (2 epochs)
for epoch in range(2):
    model.train()
    for step, (x, mel, y) in enumerate(train_data_loader):
        x, mel, y = x.to(device), mel.to(device), y.to(device)
        optimizer.zero_grad()
        # The model also returns audio/face embeddings; only `output` feeds the loss here.
        output, _, _ = model(x, mel)
        loss = cross_entropy_loss(output, y)
        print('The loss', loss.item())
        loss.backward()
        optimizer.step()

# Save the full model state dict (base weights plus LoRA adapters, despite the file name).
torch.save(model.state_dict(), "lora_adapter_weights.pth")


In [None]:
# Load the base model
# Step 1: Create an instance of the TransformerSyncnet model
model = TransformerSyncnet(num_heads=8, num_encoder_layers=6)

# Step 2: Load the pre-trained weights
load_syncnet_checkpoint("checkpoints/syncnet_checkpoint/checkpoint_step000066000.pth", model, None, reset_optimizer=True, overwrite_global_states=False)

# Step 3: Wrap the encoder exactly as the fine-tuning cell does. The saved state dict comes from
# the PEFT-wrapped model, so its encoder keys only exist after wrapping. Loading it into the
# plain model with strict=False silently ignored them and left the base weights in place.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=['self_attn.out_proj'],
    lora_dropout=0.1,
    bias="none"
)
model.transformer_encoder = get_peft_model(model.transformer_encoder, lora_config)

# Step 4: Load the fine-tuned weights. The fine-tuning cell saves "lora_adapter_weights.pth".
# The previous name, "lora_syncnet_weights.pth", is not written anywhere in this notebook.
# Map to CPU so a GPU-saved file also loads on CPU-only hosts.
lora_weights = torch.load("lora_adapter_weights.pth", map_location="cpu")

load_result = model.load_state_dict(lora_weights, strict=False)
print("Missing keys: {}, unexpected keys: {}".format(len(load_result.missing_keys), len(load_result.unexpected_keys)))


## LoRA fine-tuning: Wav2Lip

In [43]:
import torch
import torch.nn as nn
from models import LoRAConv2d, ResUNet, LoRATransposeConv2d  # Assume you have a LoRA-wrapped Conv2d

# True when a CUDA device is available; read by _load below and by the training cell's device choice.
use_cuda = torch.cuda.is_available()
# Base checkpoint loaded before LoRA adapters are applied.
checkpoint_path = "checkpoints/wav2lip_checkpoint/2-step-process/checkpoint_step000042000.pth"
# File the LoRA training cell saves to and the evaluation cell reloads from.
lora_checkpoint_point = "checkpoints/wav2lip_lora/lora_wav2lip_eddy_weights.pth"

def apply_lora_to_model(model, lora_rank=4, lora_scaling=0.5, max_layers=500,
                        target_pattern="output_block.output_block"):
    """Replace selected ``nn.Conv2d`` layers of ``model`` with ``LoRAConv2d`` wrappers, in place.

    A conv is selected when its lower-cased qualified name (as yielded by
    ``model.named_modules()``) contains ``target_pattern``. Each wrapper reuses the original
    conv's weight and bias ``Parameter`` objects (shared, not copied), so the pretrained
    weights are kept.

    Args:
        model: ``nn.Module`` to modify in place.
        lora_rank: Rank forwarded to ``LoRAConv2d``.
        lora_scaling: Scaling forwarded to ``LoRAConv2d``.
        max_layers: Maximum number of convs to wrap (500 was the previous hard-coded cap).
        target_pattern: Lower-case substring a module name must contain to be wrapped.

    Returns:
        The same ``model`` object.
    """
    # Snapshot the targets before editing. Replacing children while named_modules() is still
    # walking the tree makes the traversal hard to reason about.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if name and target_pattern in name.lower() and isinstance(module, nn.Conv2d)
    ][:max(max_layers, 0)]

    for name, module in targets:
        print('The module name', name)
        # Build the wrapper on the replaced layer's device/dtype so its new LoRA tensors match.
        lora_module = LoRAConv2d(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            groups=module.groups,
            bias=(module.bias is not None),
            lora_rank=lora_rank,
            lora_scaling=lora_scaling,
        ).to(device=module.weight.device, dtype=module.weight.dtype)
        # Share (not copy) the pretrained parameters with the wrapper's inner conv.
        lora_module.conv.weight = module.weight
        if module.bias is not None:
            lora_module.conv.bias = module.bias

        # Put the wrapper at the same attribute path as the original conv.
        parent_module = model
        *parent_names, child_name = name.split('.')
        for parent_name in parent_names:
            parent_module = getattr(parent_module, parent_name)
        setattr(parent_module, child_name, lora_module)

    return model

def _load(checkpoint_path):
    """Load a checkpoint, remapping storages to the CPU only when CUDA is unavailable.

    With ``use_cuda`` set, tensors are restored to the devices they were saved from.
    Otherwise each storage is returned as deserialized (on the CPU).
    """
    load_kwargs = {}
    if not use_cuda:
        load_kwargs["map_location"] = lambda storage, loc: storage
    return torch.load(checkpoint_path, **load_kwargs)

In [None]:
# Start training

from torch import optim
from wav2lip_dataset import Dataset
from tqdm import tqdm
from torch.utils import data as data_utils



recon_loss = nn.L1Loss()
cross_entropy_loss = nn.CrossEntropyLoss()

train_dataset = Dataset('train', 'preprocessed', 'lora_training_files', False)
train_data_loader = data_utils.DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=0)

device = torch.device("cuda" if use_cuda else "cpu")

# Base model with its pretrained weights. No optimizer is passed here: it must be created
# *after* LoRA is applied and the base is frozen. Created earlier, it held only the
# (now frozen) base parameters and never the new LoRA tensors, so optimizer.step() updated nothing.
model = ResUNet()
load_wav2lip_checkpoint(checkpoint_path, model, None, reset_optimizer=True)

# Apply LoRA to the model
model = apply_lora_to_model(model, lora_rank=4, lora_scaling=1)

# Freeze every parameter, then re-enable only the LoRA factors.
for param in model.parameters():
    param.requires_grad = False
for module in model.modules():
    if isinstance(module, LoRAConv2d):
        module.lora_A.requires_grad = True
        module.lora_B.requires_grad = True

model.to(device)

# Built last, so it contains exactly the trainable LoRA parameters (same lr as before: 1e-5).
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=0.00001)

nepochs = 100
for epoch in range(1, nepochs + 1):
    # NOTE(review): train() still updates running statistics of any (frozen) BatchNorm layers.
    model.train()
    for step, (x, indiv_mels, mel, gt) in enumerate(train_data_loader):
        x, indiv_mels, gt = x.to(device), indiv_mels.to(device), gt.to(device)
        optimizer.zero_grad()
        g = model(indiv_mels, x)
        loss = recon_loss(g, gt)
        loss.backward()
        optimizer.step()

        print('Step: {}, Epoch: {}, Sync Loss: {}, L1: {}'.format(step, epoch, 0, loss.item()))

# Save the fine-tuned weights (the full state dict, base layers included).
torch.save(model.state_dict(), lora_checkpoint_point)




In [None]:
# Rebuild the fine-tuned model the way the LoRA training cell built it: ResUNet, the same base
# checkpoint, then the same LoRA wrapping. The previous version had two problems:
#   - It instantiated Wav2Lip, which is not imported before this cell and is not the
#     fine-tuned class.
#   - Loading the wrapped state dict into an unwrapped model with strict=False silently
#     dropped the LoRA keys.
model = ResUNet()
load_wav2lip_checkpoint(checkpoint_path, model, None, reset_optimizer=True)
model = apply_lora_to_model(model, lora_rank=4, lora_scaling=1)

# Load the LoRA fine-tuned weights on top (onto CPU first; the model is moved below).
lora_params = torch.load(lora_checkpoint_point, map_location="cpu")
load_result = model.load_state_dict(lora_params, strict=False)
print("Missing keys: {}, unexpected keys: {}".format(len(load_result.missing_keys), len(load_result.unexpected_keys)))

# The batches are moved to `device`, so the model has to be there too; eval() switches off
# training-only behaviour such as dropout.
model.to(device)
model.eval()

# NOTE(review): the training cells construct this Dataset with four arguments
# ('train', data root, file-list dir, False); confirm the two-argument form is supported.
val_dataset = Dataset('val', 'training_data')
val_data_loader = data_utils.DataLoader(val_dataset, batch_size=1, shuffle=True, num_workers=0)

with torch.no_grad():
    for x, indiv_mels, mel, gt in val_data_loader:
        # Move data to the model's device
        x = x.to(device)
        gt = gt.to(device)
        indiv_mels = indiv_mels.to(device)

        g = model(indiv_mels, x)

        l1loss = recon_loss(g, gt)
        print('The eval L1 loss', l1loss.item())


## Inference

In [None]:
# Lip-sync inference: drive input/eddy.mp4 with input/en_1.wav using the stripped step-505000
# checkpoint, writing the result to results/.
# NOTE(review): the stripped checkpoint is saved from ResUNet(3), which matches
# --model_layers 3, but the output filename says "1layers" — rename it if that label matters.
# NOTE(review): if inference.py parses --use_ref_img / --use_esrgan with argparse type=bool,
# the string "False" is truthy — confirm.
# The --lora_checkpoint_path line below is commented out and not part of the command (the line
# before it has no trailing backslash).
!python inference.py \
--checkpoint_path checkpoints/wav2lip_checkpoint/face-enhancer/checkpoint_step000505000_striped.pth \
--face input/eddy.mp4 \
--audio input/en_1.wav \
--outfile results/eddy_505000_1layers_gen_en1_1.mp4 \
--model_layers 3 \
--use_ref_img True \
--use_esrgan False \
--iteration 1 
#--lora_checkpoint_path checkpoints/wav2lip_lora/lora_wav2lip_eddy_weights.pth 

## Facemesh

In [83]:
import os

def get_all_file_paths(base_dir):
    """List every file below ``base_dir``, recursing into sub-directories.

    Each entry is ``os.path.join(dirpath, filename)``, in ``os.walk`` top-down order.
    A missing ``base_dir`` yields an empty list.
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(base_dir)
        for filename in filenames
    ]




In [None]:
import mediapipe as mp
import cv2
import numpy as np
import math

import os

LANDMARK_SIZE = 192  # side length the images are resized to, and of the landmark mask

mp_face_mesh = mp.solutions.face_mesh

# Define your base directory
base_dir = 'training_data/'

# Get all file paths
all_file_paths = get_all_file_paths(base_dir)

# Use the detector as a context manager so its resources are released when the loop ends.
with mp_face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1, refine_landmarks=True, min_detection_confidence=0.5) as face_mesh:
    for fname in all_file_paths:
        lower_name = fname.lower()
        if not lower_name.endswith(".jpg") or lower_name.endswith("_landmarks.jpg"):
            continue

        img = cv2.imread(fname)
        if img is None:  # cv2.imread returns None for unreadable files instead of raising
            print('Unreadable image', fname)
            continue
        img = cv2.resize(img, (LANDMARK_SIZE, LANDMARK_SIZE))
        # MediaPipe expects RGB input; OpenCV decodes images as BGR.
        result = face_mesh.process(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

        # Single-channel mask: 255 at every detected landmark, 0 elsewhere.
        landmark_channel = np.zeros((LANDMARK_SIZE, LANDMARK_SIZE), dtype=np.uint8)
        # Build the output name from the extension only. str.replace('.jpg', ...) had two bugs:
        #   - it missed upper-case ".JPG" names, so the source image itself was overwritten;
        #   - it also rewrote any ".jpg" appearing earlier in the path.
        output_fname = os.path.splitext(fname)[0] + '_landmarks.jpg'

        if result.multi_face_landmarks:
            for face_landmarks in result.multi_face_landmarks:
                for landmark in face_landmarks.landmark:
                    # Clamp both ends: the upper end was already clamped, but a negative
                    # coordinate would wrap around to the opposite image edge in numpy indexing.
                    x = min(max(int(math.floor(landmark.x * LANDMARK_SIZE)), 0), LANDMARK_SIZE - 1)
                    y = min(max(int(math.floor(landmark.y * LANDMARK_SIZE)), 0), LANDMARK_SIZE - 1)
                    landmark_channel[y, x] = 255
        else:
            print('No face', fname)

        # Write once per image (previously the file was rewritten after every landmark).
        cv2.imwrite(output_fname, landmark_channel)

In [None]:
import os

# Undo the landmark pass: remove every "*_landmarks.jpg" (any case) under training_data/.
base_dir = 'training_data/'
all_file_paths = get_all_file_paths(base_dir)

landmark_files = (path for path in all_file_paths if path.lower().endswith("_landmarks.jpg"))
for fname in landmark_files:
    try:
        os.remove(fname)
    except OSError as e:
        print(f"Error deleting {fname}: {e}")
    else:
        print(f"Deleted: {fname}")


## Draw network

In [None]:
!pip install torchviz

In [None]:
from torchviz import make_dot
from models import TransformerSyncnet as TransformerSyncnet

# Trace one forward pass on random dummy inputs and render the autograd graph to
# simplified_syncnet_model.png; only the tensor shapes matter here, not the values.
model = TransformerSyncnet(num_heads=8, num_encoder_layers=6)

dummy_batch = 5
face_data = torch.randn(dummy_batch, 15, 96, 192)
audio_data = torch.randn(dummy_batch, 1, 80, 16)

y = model(face_data, audio_data)
# Pass params=dict(model.named_parameters()) to make_dot to label weight nodes by name.
syncnet_graph = make_dot(y)
syncnet_graph.render("simplified_syncnet_model", format="png")


In [None]:
from torchviz import make_dot
from models import Wav2Lip

# Trace one forward pass on random dummy inputs and render the autograd graph to
# simplified_wav2lip_model.png; the model is called as model(audio, face).
model = Wav2Lip(num_of_blocks=2)

dummy_batch = 5
face_data = torch.randn(dummy_batch, 9, 192, 192)
audio_data = torch.randn(dummy_batch, 1, 80, 16)

y = model(audio_data, face_data)
# Pass params=dict(model.named_parameters()) to make_dot to label weight nodes by name.
wav2lip_graph = make_dot(y)
wav2lip_graph.render("simplified_wav2lip_model", format="png")

## Calculate performance metrics

In [3]:
import torch
import torch.nn.functional as F
import numpy as np



def calculate_psnr(pred_img, target_img, max_pixel_value=255.0):
    """
    Calculate the PSNR between two images (or two equally-shaped batches of images).

    The MSE is taken over *all* elements, so a batch yields one PSNR for the whole batch
    rather than the mean of per-image PSNRs.

    Args:
        pred_img (torch.Tensor): Generated image tensor, e.g. (batch, channels, height, width).
            Integer tensors (such as uint8) are accepted and converted to float.
        target_img (torch.Tensor): Ground-truth tensor with the same shape as ``pred_img``.
        max_pixel_value (float): Maximum possible pixel value (255 for 8-bit images,
            1.0 for images scaled to [0, 1]).

    Returns:
        psnr (float): Peak Signal-to-Noise Ratio in dB; ``float('inf')`` for identical inputs.

    Raises:
        ValueError: If the two tensors have different shapes (F.mse_loss would otherwise
            broadcast them and return a misleading value).
    """
    if pred_img.shape != target_img.shape:
        raise ValueError("calculate_psnr: shape mismatch {} vs {}".format(
            tuple(pred_img.shape), tuple(target_img.shape)))

    # F.mse_loss needs floating-point inputs; existing float dtypes are left as they are.
    pred = pred_img if pred_img.is_floating_point() else pred_img.float()
    target = target_img if target_img.is_floating_point() else target_img.float()

    mse = F.mse_loss(pred, target)
    if mse == 0:
        return float('inf')  # identical images: PSNR is unbounded

    # Build the constant on the MSE's device/dtype so GPU inputs need no implicit transfer.
    max_value = torch.tensor(max_pixel_value, dtype=mse.dtype, device=mse.device)
    psnr = 20 * torch.log10(max_value) - 10 * torch.log10(mse)
    return psnr.item()



In [None]:
from torch import optim
from wav2lip_dataset import Dataset
from tqdm import tqdm
from torch.utils import data as data_utils
from models import ResUNet
import lpips
import torch
import piq
from pytorch_msssim import ms_ssim, ssim
import cv2


print(piq.__version__)

use_cuda = torch.cuda.is_available()
dataset_path = 'metrics'
training_files_path = 'metrics_files'
# Alternative data: dataset_path = 'training_data', training_files_path = 'local_train_files'

train_dataset = Dataset('train', dataset_path, training_files_path, False)
train_data_loader = data_utils.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)

device = torch.device("cuda" if use_cuda else "cpu")

# Checkpoint under evaluation. Earlier candidates:
#   checkpoints/wav2lip_checkpoint/more-ref/checkpoint_step000116000.pth
#   checkpoints/wav2lip_checkpoint/more-ref-2layers/checkpoint_step000182000.pth
#   checkpoints/wav2lip_checkpoint/more-ref-3layers/checkpoint_step000505000.pth
checkpoint_path = "checkpoints/wav2lip_checkpoint/face-enhancer/checkpoint_step000160000.pth"
# Generated frames are written here for visual inspection (overwritten by every batch).
frames_dir = 'checkpoints/wav2lip_checkpoint/face-enhancer'

model = ResUNet(1)
load_wav2lip_checkpoint(checkpoint_path, model, None, reset_optimizer=True)
# Metrics must be computed in eval mode, on the same device as the batches.
# The original left model.eval() commented out, so dropout/BatchNorm ran in training mode.
model.to(device)
model.eval()

lpips_loss = lpips.LPIPS(net='vgg').to(device)  # You can choose 'alex', 'vgg', or 'squeeze'

iteration = 10  # number of passes over the (shuffled) metrics set
num_batches = 0
total_psnr = 0.0

# Per-batch mean of each frame metric.
global_loipses = []
global_ssim = []
global_brisque = []

# Best values seen in any batch (best frame for LPIPS / MS-SSIM / BRISQUE).
psnr_max = 0.0
ssim_max = 0.0
lipis_min = 10.0
brisque_min = 100.0

with torch.no_grad():
    for _ in range(iteration):
        for x, indiv_mels, mel, gt in train_data_loader:
            x, indiv_mels, gt = x.to(device), indiv_mels.to(device), gt.to(device)
            g = torch.clamp(model(indiv_mels, x), 0, 1)

            # Assumes (batch, channels, time, H, W), matching the g[:, :, i] frame slicing
            # below; reorder to (batch, time, H, W, channels) uint8 frames for cv2.imwrite.
            img = (g.cpu().numpy().transpose(0, 2, 3, 4, 1) * 255).astype(np.uint8)
            for batch_idx, clip in enumerate(img):
                for t, frame in enumerate(clip):
                    cv2.imwrite(f'{frames_dir}/{batch_idx}_{t}.jpg', frame)

            # PSNR on 8-bit quantised images: one value for the whole batch.
            pred_u8 = torch.tensor((g.cpu().numpy() * 255).astype(np.uint8)).float()
            gt_u8 = torch.tensor((gt.cpu().numpy() * 255).astype(np.uint8)).float()
            psnr_value = calculate_psnr(pred_u8, gt_u8)
            total_psnr += psnr_value
            num_batches += 1

            lpips_scores, ssim_scores, brisque_scores = [], [], []
            for i in range(g.shape[2]):
                gen_frame = g[:, :, i, :, :]
                gt_frame = gt[:, :, i, :, :]
                # LPIPS expects inputs in [-1, 1]; normalize=True rescales these [0, 1] frames.
                lpips_scores.append(lpips_loss(gen_frame, gt_frame, normalize=True))
                ssim_scores.append(ms_ssim(gen_frame, gt_frame, data_range=1.0))
                brisque_scores.append(piq.brisque(gen_frame))

            lpips_stack = torch.stack(lpips_scores)
            ssim_stack = torch.stack(ssim_scores)
            brisque_stack = torch.stack(brisque_scores)

            global_loipses.append(lpips_stack.mean())
            global_ssim.append(ssim_stack.mean())
            global_brisque.append(brisque_stack.mean())

            # Best frame of this batch for each metric.
            lpips_loss_value = lpips_stack.min().item()
            ssim_value = ssim_stack.max().item()
            brisque_score = brisque_stack.min().item()

            psnr_max = max(psnr_max, psnr_value)
            lipis_min = min(lipis_min, lpips_loss_value)
            ssim_max = max(ssim_max, ssim_value)
            brisque_min = min(brisque_min, brisque_score)

            print(f"PSNR(higher the better): {psnr_value} dB, LIPIS(lower the better): {lpips_loss_value}, MS-SSIM(higher the better): {ssim_value}, BRISQUE(lower the better): {brisque_score}")

print(f"The max PSNR: {psnr_max} dB, LIPIS: {lipis_min}, MS-SSIM: {ssim_max}, BRISQUE: {brisque_min}")
if num_batches:
    # Means over all evaluated batches. The original divided the PSNR sum by `iteration` instead
    # of the batch count, and printed the *minimum* per-batch LPIPS/MS-SSIM/BRISQUE as "avg".
    print(f"The avg PSNR: {total_psnr / num_batches} dB, LIPIS: {torch.stack(global_loipses).mean().item()}, MS-SSIM: {torch.stack(global_ssim).mean().item()}, BRISQUE: {torch.stack(global_brisque).mean().item()}")
else:
    print("No batches were evaluated.")


In [None]:
import torch

import os

# Source: the trained 3-layer model. Target: the weights-only copy the inference cell loads.
source_checkpoint = 'checkpoints/wav2lip_checkpoint/more-ref-3layers/checkpoint_step000505000.pth'
stripped_checkpoint = 'checkpoints/wav2lip_checkpoint/face-enhancer/checkpoint_step000505000_striped.pth'

# Load the original model and its weights. With the default overwrite_global_states=True,
# load_wav2lip_checkpoint also sets the globals global_step / global_epoch from the file.
model = ResUNet(3)
load_wav2lip_checkpoint(source_checkpoint, model, None, reset_optimizer=True)

# Save the weights without optimizer state. The checkpoint's own step/epoch are carried over
# instead of the previously hard-coded 505000 / 100, and the target folder is created first.
os.makedirs(os.path.dirname(stripped_checkpoint), exist_ok=True)
torch.save({
        "state_dict": model.state_dict(),
        "optimizer": None,
        "global_step": global_step,
        "global_epoch": global_epoch,
    }, stripped_checkpoint)
