In [None]:
# %pip (not !pip) installs into the environment of the running kernel.
%pip install tqdm torchvision librosa wandb matplotlib lpips

In [9]:
# Train SyncNet from scratch.
# NOTE(review): the recorded run failed with "unrecognized arguments: " (empty list), which
# usually means stray whitespace/CR after a line-continuation backslash — check line endings.
# NOTE(review): if the script parses --use_wandb with type=bool, the string "False" is truthy;
# confirm how the flag is parsed.
!python transformer_syncnet_train.py --data_root training_data/ \
--checkpoint_dir checkpoints/syncnet_checkpoint \
--use_wandb False \
--train_root local_train_files

usage: transformer_syncnet_train.py [-h] --data_root DATA_ROOT
                                    --checkpoint_dir CHECKPOINT_DIR
                                    [--checkpoint_path CHECKPOINT_PATH]
                                    [--train_root TRAIN_ROOT]
                                    [--use_cosine_loss USE_COSINE_LOSS]
                                    [--sample_mode SAMPLE_MODE]
                                    [--use_wandb USE_WANDB]
transformer_syncnet_train.py: error: unrecognized arguments: 


In [None]:
# Resume SyncNet training from the checkpoint given in --checkpoint_path.
!python transformer_syncnet_train.py --data_root training_data/ \
--checkpoint_dir checkpoints/syncnet_checkpoint \
--checkpoint_path checkpoints/syncnet_checkpoint/checkpoint_step000124000.pth \
--use_wandb False \
--train_root local_train_files

In [None]:
# Train Wav2Lip from scratch; --syncnet_checkpoint_path points at a previously trained SyncNet checkpoint.
!python transformer_wav2lip_train.py --data_root training_data/ \
--checkpoint_dir checkpoints \
--use_wandb False \
--syncnet_checkpoint_path checkpoints/syncnet_checkpoint/checkpoint_step000066000.pth

In [None]:
# Resume Wav2Lip training from --checkpoint_path (same SyncNet checkpoint as the from-scratch run).
!python transformer_wav2lip_train.py --data_root training_data/ \
--checkpoint_dir checkpoints \
--use_wandb False \
--syncnet_checkpoint_path checkpoints/syncnet_checkpoint/checkpoint_step000066000.pth \
--checkpoint_path checkpoints/checkpoint_step000004600.pth

## Helper functions

In [None]:
# Generate the training file list: one "<folder>/<subfolder>" line for every second-level
# directory under DATA_ROOT.
import os


def get_subfolders(directory):
    """Return the full paths of the immediate subdirectories of ``directory``, sorted by name.

    Sorting makes the generated list deterministic (``os.listdir`` order is arbitrary).
    """
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )


DATA_ROOT = 'training_data'
output_file_path = "output.txt"

# Write mode (not append) so re-running the cell regenerates the file instead of
# duplicating every entry. Each second-level directory is written exactly once; the
# previous os.walk loop repeated the line once per nested directory inside it.
n_written = 0
with open(output_file_path, 'w') as f:
    for subdir in get_subfolders(DATA_ROOT):
        for subdir2 in get_subfolders(subdir):
            # Keep only the last two path components, e.g. "<folder>/<subfolder>".
            desired_portion = os.path.join(*subdir2.split(os.sep)[-2:])
            f.write(desired_portion + '\n')
            n_written += 1

print(f"Wrote {n_written} entries to {output_file_path}")


## LoRA fine-tuning: SyncNet

In [None]:
# %pip (not !pip) installs into the environment of the running kernel.
%pip install peft numpy

In [1]:
def load_syncnet_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True,
                            lr=0.00001):
    """Load a SyncNet checkpoint into ``model`` (and optionally ``optimizer``).

    Args:
        path: Checkpoint file; a dict with "state_dict", "global_step", "global_epoch"
            and (optionally) "optimizer" entries.
        model: Module receiving ``checkpoint["state_dict"]``. A leading ``module.``
            prefix (as written by ``nn.DataParallel``) is stripped from each key. Loading is
            non-strict; missing/unexpected keys are printed instead of raising.
        optimizer: Optimizer to restore / re-configure, or ``None``.
        reset_optimizer: If False, a non-None ``checkpoint["optimizer"]`` is loaded into
            ``optimizer`` (skipped with a message when ``optimizer`` is None).
        overwrite_global_states: If True, set the module-level ``global_step`` and
            ``global_epoch`` from the checkpoint.
        lr: Learning rate forced onto every param group of ``optimizer`` whenever one is
            given (default 1e-5, the previously hard-coded value). ``None`` keeps the
            optimizer's own learning rate.

    Returns:
        The same ``model`` object, updated in place.
    """
    global global_step
    global global_epoch

    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)

    # Strip only a *leading* "module." — str.replace would also rewrite keys such as
    # "submodule.weight" into "subweight", silently dropping those weights.
    prefix = 'module.'
    new_s = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in checkpoint["state_dict"].items()
    }
    incompatible = model.load_state_dict(new_s, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        # strict=False would otherwise hide architecture/key mismatches completely.
        print("Checkpoint/model key mismatch: {} missing (e.g. {}), {} unexpected (e.g. {})".format(
            len(incompatible.missing_keys), incompatible.missing_keys[:5],
            len(incompatible.unexpected_keys), incompatible.unexpected_keys[:5]))

    if not reset_optimizer:
        optimizer_state = checkpoint.get("optimizer")
        if optimizer_state is not None:
            if optimizer is None:
                print("Checkpoint has optimizer state but no optimizer was given; skipping it")
            else:
                print("Load optimizer state from {}".format(path))
                optimizer.load_state_dict(optimizer_state)

    if overwrite_global_states:
        global_step = checkpoint["global_step"]
        global_epoch = checkpoint["global_epoch"]

    if optimizer is not None and lr is not None:
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    return model


def _load(checkpoint_path):
    """Deserialize ``checkpoint_path`` with every tensor mapped onto the CPU."""
    return torch.load(checkpoint_path, map_location="cpu")

In [None]:
from peft import LoraConfig, get_peft_model
import torch.nn as nn
import torch
from torch.utils import data as data_utils
from models import TransformerSyncnet as TransformerSyncnet
from syncnet_dataset import Dataset

# Use the GPU when available; the model and every batch are moved there explicitly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

cross_entropy_loss = nn.CrossEntropyLoss()

# Step 1: Create an instance of the TransformerSyncnet model
model = TransformerSyncnet(num_heads=8, num_encoder_layers=6)

# Step 2: Load the pre-trained weights (no optimizer yet; global step/epoch untouched)
load_syncnet_checkpoint("checkpoints/syncnet_checkpoint/checkpoint_step000066000.pth", model, None, reset_optimizer=True, overwrite_global_states=False)

# Define the LoRA configuration
lora_config = LoraConfig(
    r=8,  # Rank of the low-rank update
    lora_alpha=32,  # Scaling factor
    # NOTE(review): nn.MultiheadAttention.forward hands out_proj.weight/out_proj.bias to the
    # functional attention kernel instead of calling out_proj(...), so an adapter placed on
    # "self_attn.out_proj" may never take part in the forward pass. Confirm the LoRA weights
    # actually receive gradients (assumes transformer_encoder is built from
    # nn.TransformerEncoderLayer — TODO confirm in models).
    target_modules=['self_attn.out_proj'],
    lora_dropout=0.1,
    bias="none"
)

# Apply LoRA to the transformer encoder
model.transformer_encoder = get_peft_model(model.transformer_encoder, lora_config)
model = model.to(device)

# NOTE(review): get_peft_model only manages requires_grad inside the module it wraps; the
# other sub-modules of TransformerSyncnet stay trainable and are updated as well, so this is
# not adapter-only fine-tuning. Freeze them explicitly if that is the intent.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=1e-4)

train_dataset = Dataset('train', 'training_data/')

train_data_loader = data_utils.DataLoader(
        train_dataset, batch_size=10, shuffle=True,
        num_workers=0)

# Fine-tuning loop
NUM_EPOCHS = 2
model.train()
for epoch in range(NUM_EPOCHS):
    epoch_loss, n_batches = 0.0, 0
    for x, mel, y in train_data_loader:
        x, mel, y = x.to(device), mel.to(device), y.to(device)
        optimizer.zero_grad()
        output, audio_embedding, face_embedding = model(x, mel)
        loss = cross_entropy_loss(output, y)
        print('The loss', loss.item())
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        n_batches += 1
    if n_batches:
        print('Epoch {}: mean loss {:.4f}'.format(epoch, epoch_loss / n_batches))

# Save the full state dict of the peft-wrapped model (base weights, LoRA adapters and every
# other trained parameter). Loading it back requires wrapping transformer_encoder with the
# same LoraConfig first.
torch.save(model.state_dict(), "lora_adapter_weights.pth")


In [None]:
# Rebuild the LoRA-fine-tuned SyncNet from the base checkpoint plus the fine-tuned weights.
from peft import LoraConfig, get_peft_model
import torch
from models import TransformerSyncnet

# Step 1: Create an instance of the TransformerSyncnet model
model = TransformerSyncnet(num_heads=8, num_encoder_layers=6)

# Step 2: Load the pre-trained base weights
load_syncnet_checkpoint("checkpoints/syncnet_checkpoint/checkpoint_step000066000.pth", model, None, reset_optimizer=True, overwrite_global_states=False)

# Step 3: Re-apply the LoRA wrapping used during fine-tuning. The weights were saved from the
# peft-wrapped model, so their encoder keys only exist after wrapping; loading them into the
# bare model with strict=False would silently ignore them.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=['self_attn.out_proj'],
    lora_dropout=0.1,
    bias="none"
)
model.transformer_encoder = get_peft_model(model.transformer_encoder, lora_config)

# Step 4: Load the file written by the LoRA fine-tuning cell.
lora_weights = torch.load("lora_adapter_weights.pth", map_location="cpu")
incompatible = model.load_state_dict(lora_weights, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print("LoRA load mismatch: {} missing, {} unexpected keys".format(
        len(incompatible.missing_keys), len(incompatible.unexpected_keys)))


## LoRA fine-tuning: Wav2Lip

In [2]:
import torch
import torch.nn as nn
from models import LoRAConv2d, Wav2Lip  # Assume you have a LoRA-wrapped Conv2d

def apply_lora_to_model(model, lora_rank=4):
    """Replace every ``nn.Conv2d`` inside ``model`` with a ``LoRAConv2d`` wrapper, in place.

    Each wrapper is built with the original conv's hyper-parameters. Its ``conv.weight`` and
    ``conv.bias`` are set to the original Parameter objects, which are shared rather than
    copied. Conv2d modules that already sit directly inside a ``LoRAConv2d`` are skipped, so
    calling this twice does not double-wrap.

    Args:
        model: ``nn.Module`` to modify in place.
        lora_rank: Rank passed to ``LoRAConv2d``.

    Returns:
        The same ``model`` object.
    """
    # Snapshot the targets before mutating: swapping children while named_modules() is
    # still walking the tree is fragile.
    conv_targets = [(name, module) for name, module in model.named_modules()
                    if isinstance(module, nn.Conv2d)]

    for name, module in conv_targets:
        if not name:
            # The root module itself cannot be replaced in place.
            continue
        *parent_names, child_name = name.split('.')
        parent_module = model
        for parent_name in parent_names:
            parent_module = getattr(parent_module, parent_name)
        if isinstance(parent_module, LoRAConv2d):
            continue  # inner conv of an existing wrapper

        lora_module = LoRAConv2d(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            groups=module.groups,
            bias=(module.bias is not None),
            lora_rank=lora_rank
        )
        # Share (not copy) the original weights with the wrapper.
        lora_module.conv.weight = module.weight
        if module.bias is not None:
            lora_module.conv.bias = module.bias
        # Keep the freshly created LoRA parameters on the same device as the weights they adapt.
        lora_module.to(module.weight.device)
        setattr(parent_module, child_name, lora_module)
    return model


In [3]:
def load_wav2lip_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True,
                            lr=0.00001):
    """Load a Wav2Lip checkpoint into ``model`` (and optionally ``optimizer``).

    Same contract as ``load_syncnet_checkpoint``. It is kept as a self-contained function so
    the Wav2Lip cells do not depend on the SyncNet section having been run.

    Args:
        path: Checkpoint file; a dict with "state_dict", "global_step", "global_epoch"
            and (optionally) "optimizer" entries.
        model: Module receiving ``checkpoint["state_dict"]``. A leading ``module.``
            prefix (as written by ``nn.DataParallel``) is stripped from each key. Loading is
            non-strict; missing/unexpected keys are printed instead of raising.
        optimizer: Optimizer to restore / re-configure, or ``None``.
        reset_optimizer: If False, a non-None ``checkpoint["optimizer"]`` is loaded into
            ``optimizer`` (skipped with a message when ``optimizer`` is None).
        overwrite_global_states: If True, set the module-level ``global_step`` and
            ``global_epoch`` from the checkpoint.
        lr: Learning rate forced onto every param group of ``optimizer`` whenever one is
            given (default 1e-5, the previously hard-coded value). ``None`` keeps the
            optimizer's own learning rate.

    Returns:
        The same ``model`` object, updated in place.
    """
    global global_step
    global global_epoch

    print("Load checkpoint from: {}".format(path))
    checkpoint = torch.load(path, map_location="cpu")

    # Strip only a *leading* "module." — str.replace would also rewrite keys such as
    # "submodule.weight" into "subweight", silently dropping those weights.
    prefix = 'module.'
    new_s = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in checkpoint["state_dict"].items()
    }
    incompatible = model.load_state_dict(new_s, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        # strict=False would otherwise hide architecture/key mismatches completely.
        print("Checkpoint/model key mismatch: {} missing (e.g. {}), {} unexpected (e.g. {})".format(
            len(incompatible.missing_keys), incompatible.missing_keys[:5],
            len(incompatible.unexpected_keys), incompatible.unexpected_keys[:5]))

    if not reset_optimizer:
        optimizer_state = checkpoint.get("optimizer")
        if optimizer_state is not None:
            if optimizer is None:
                print("Checkpoint has optimizer state but no optimizer was given; skipping it")
            else:
                print("Load optimizer state from {}".format(path))
                optimizer.load_state_dict(optimizer_state)

    if overwrite_global_states:
        global_step = checkpoint["global_step"]
        global_epoch = checkpoint["global_epoch"]

    if optimizer is not None and lr is not None:
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    return model

In [None]:
import os

import torch
import torch.nn as nn
from torch import optim
from torch.utils import data as data_utils
from tqdm import tqdm

from wav2lip_dataset import Dataset

# --- Configuration ------------------------------------------------------------------
BASE_CHECKPOINT = "checkpoints/checkpoint_step000136000.pth"
LORA_OUTPUT_PATH = "checkpoints/wav2lip_lora/lora_wav2lip_weights.pth"
LORA_RANK = 2
LORA_LR = 1e-4  # tunable; same value the SyncNet LoRA cell uses
nepochs = 100

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

recon_loss = nn.L1Loss()
cross_entropy_loss = nn.CrossEntropyLoss()

train_dataset = Dataset('train', 'training_data')
train_data_loader = data_utils.DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=0)

# --- Model: base weights first, then LoRA wrappers -----------------------------------
# Load the checkpoint into the plain Wav2Lip so the keys line up. No optimizer is passed
# because nothing is restored from it (reset_optimizer=True). The optimizer can only be
# built once the LoRA parameters exist.
model = Wav2Lip()
load_wav2lip_checkpoint(BASE_CHECKPOINT, model, None, reset_optimizer=True)
model = apply_lora_to_model(model, lora_rank=LORA_RANK)

# Freeze everything, then re-enable only the LoRA factors. requires_grad_() works whether
# lora_A / lora_B are Parameters or sub-modules (a plain attribute assignment on a module
# would be a silent no-op).
model.requires_grad_(False)
for module in model.modules():
    if isinstance(module, LoRAConv2d):
        module.lora_A.requires_grad_(True)
        module.lora_B.requires_grad_(True)

model = model.to(device)

# Build the optimizer *after* wrapping + freezing. An optimizer created earlier would hold
# only the (now frozen) base weights and never update the LoRA parameters.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=LORA_LR)
print('Trainable parameters: {}'.format(sum(p.numel() for p in trainable_params)))

# --- Fine-tuning loop -----------------------------------------------------------------
# NOTE(review): train mode also updates the running statistics of any BatchNorm layers even
# though their affine parameters are frozen — confirm that is intended.
model.train()
for epoch in range(1, nepochs + 1):
    prog_bar = tqdm(train_data_loader, desc='Epoch {}/{}'.format(epoch, nepochs))
    running_l1 = 0.0
    for step, (x, indiv_mels, mel, gt) in enumerate(prog_bar):
        x = x.to(device)
        indiv_mels = indiv_mels.to(device)
        gt = gt.to(device)

        optimizer.zero_grad()
        g = model(indiv_mels, x)
        loss = recon_loss(g, gt)
        loss.backward()
        optimizer.step()

        running_l1 += loss.item()
        prog_bar.set_postfix(L1=running_l1 / (step + 1))

# Save the full state dict of the LoRA-wrapped model (frozen base weights, LoRA factors and
# buffers). Loading it requires applying apply_lora_to_model with the same rank first.
os.makedirs(os.path.dirname(LORA_OUTPUT_PATH), exist_ok=True)
torch.save(model.state_dict(), LORA_OUTPUT_PATH)

In [None]:
# Evaluate the LoRA-fine-tuned Wav2Lip on the validation split (L1 reconstruction loss).
import torch
import torch.nn as nn
from torch.utils import data as data_utils

from wav2lip_dataset import Dataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
recon_loss = nn.L1Loss()
LORA_RANK = 2  # must match the rank used when the LoRA weights were trained

# Rebuild the model with the same layout used for fine-tuning: base weights, then LoRA
# wrappers, then the fine-tuned state dict. Without the wrappers, strict=False silently
# ignores every LoRA-wrapped key and the *base* model is what gets evaluated.
model = Wav2Lip()
load_wav2lip_checkpoint("checkpoints/checkpoint_step000136000.pth", model, None, reset_optimizer=True)
model = apply_lora_to_model(model, lora_rank=LORA_RANK)

lora_params = torch.load("checkpoints/wav2lip_lora/lora_wav2lip_weights.pth", map_location="cpu")
incompatible = model.load_state_dict(lora_params, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print("LoRA load mismatch: {} missing, {} unexpected keys".format(
        len(incompatible.missing_keys), len(incompatible.unexpected_keys)))

# The model must live on the same device as the inputs moved below.
model = model.to(device)
model.eval()

val_dataset = Dataset('val', 'training_data')
val_data_loader = data_utils.DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=0)

l1_losses = []
with torch.no_grad():
    for x, indiv_mels, mel, gt in val_data_loader:
        x = x.to(device)
        gt = gt.to(device)
        indiv_mels = indiv_mels.to(device)

        g = model(indiv_mels, x)

        l1loss = recon_loss(g, gt)
        print('The eval L1 loss', l1loss.item())
        l1_losses.append(l1loss.item())

if l1_losses:
    print('Mean eval L1 loss over {} samples: {:.4f}'.format(len(l1_losses), sum(l1_losses) / len(l1_losses)))
