In [None]:
!pip install tqdm torchvision librosa wandb matplotlib lpips

In [None]:
# Train syncnet from scrach
!python transformer_syncnet_train.py --data_root training_data/ \
--checkpoint_dir checkpoints/syncnet_checkpoint \
--use_wandb False

In [None]:
# syncnet training resume
!python transformer_syncnet_train.py --data_root training_data/ \
--checkpoint_dir checkpoints/syncnet_checkpoint \
--checkpoint_path checkpoints/syncnet_checkpoint/checkpoint_step000066000.pth \
--use_wandb False

In [None]:
# Train wav2lip
!python transformer_wav2lip_train.py --data_root training_data/ \
--checkpoint_dir checkpoints \
--use_wandb False \
--syncnet_checkpoint_path checkpoints/syncnet_checkpoint/checkpoint_step000066000.pth

In [25]:
# Train wav2lip resume
!python transformer_wav2lip_train.py --data_root training_data/ \
--checkpoint_dir checkpoints \
--use_wandb False \
--syncnet_checkpoint_path checkpoints/syncnet_checkpoint/checkpoint_step000066000.pth \
--checkpoint_path checkpoints/checkpoint_step000000001.pth

use_cuda: False
total trainable params 68442979
Load checkpoint from: checkpoints/checkpoint_step000000000.pth
  checkpoint = torch.load(checkpoint_path,
Traceback (most recent call last):
  File "/Users/eddyma/DEV/Github/Wav2Lip/transformer_wav2lip_train.py", line 391, in <module>
    load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False)
  File "/Users/eddyma/DEV/Github/Wav2Lip/transformer_wav2lip_train.py", line 343, in load_checkpoint
    checkpoint = _load(path)
                 ^^^^^^^^^^^
  File "/Users/eddyma/DEV/Github/Wav2Lip/transformer_wav2lip_train.py", line 334, in _load
    checkpoint = torch.load(checkpoint_path,
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/for_wav2lip/lib/python3.12/site-packages/torch/serialization.py", line 1065, in load
    with _open_file_like(f, 'rb') as opened_file:
         ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/for_wav2lip/lib/python3.12/site-packages/torch/serialization.py", lin

## Helper functions

In [None]:
# Generate the training file
import os

def get_subfolders(directory):
    return [os.path.join(directory, name) for name in os.listdir(directory)
            if os.path.isdir(os.path.join(directory, name))]

for subdir in get_subfolders('training_data'):
  for subdir2 in get_subfolders(subdir):
    for root, dirs, files in os.walk(subdir2):
        print(subdir2)
        # Extract the desired portion (last two parts of the path)
        desired_portion = os.path.join(*subdir2.split(os.sep)[-2:])

        # # Path to the output text file
        output_file_path = "output.txt"

        # # Append the extracted portion to the text file as a new line
        with open(output_file_path, 'a') as f:
             f.write(desired_portion + '\n')

        print(f"Appended '{desired_portion}' to {output_file_path}")


## LoRA fine tuning

In [None]:
!pip install peft numpy

In [2]:
def load_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True):
    global global_step
    global global_epoch

    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    new_s = {}
    for k, v in s.items():
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s, strict=False)
    if not reset_optimizer:
        optimizer_state = checkpoint["optimizer"]
        if optimizer_state is not None:
            print("Load optimizer state from {}".format(path))
            optimizer.load_state_dict(checkpoint["optimizer"])
    if overwrite_global_states:
        global_step = checkpoint["global_step"]
        global_epoch = checkpoint["global_epoch"]

    if optimizer != None:
      for param_group in optimizer.param_groups:
        param_group['lr'] = 0.00001

    return model

def _load(checkpoint_path):
    checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint

In [11]:
from peft import LoraConfig, get_peft_model
import torch.nn as nn
import torch
from torch.utils import data as data_utils
from models import TransformerSyncnet as TransformerSyncnet
from syncnet_dataset import Dataset

cross_entropy_loss = nn.CrossEntropyLoss()

# Step 1: Create an instance of the TransformerSyncnet model
model = TransformerSyncnet(num_heads=8, num_encoder_layers=6)

# Step 2: Load the pre-trained weights
load_checkpoint("checkpoints/syncnet_checkpoint/checkpoint_step000066000.pth", model, None, reset_optimizer=True, overwrite_global_states=False)

# Define the LoRA configuration
lora_config = LoraConfig(
    r=8,  # Rank of the low-rank update
    lora_alpha=32,  # Scaling factor
    target_modules=['self_attn.out_proj'],  # Targeting the attention layers
    lora_dropout=0.1,
    bias="none"
)

# Apply LoRA to the transformer encoder
model.transformer_encoder = get_peft_model(model.transformer_encoder, lora_config)

# Define optimizer for fine-tuning
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


train_dataset = Dataset('train', 'training_data/')

train_data_loader = data_utils.DataLoader(
        train_dataset, batch_size=10, shuffle=True,
        num_workers=0)

# Fine-tuning loop (example)
for epoch in range(2):
    model.train()
    for step, (x, mel, y) in enumerate(train_data_loader):
        # Your training loop here
        optimizer.zero_grad()
        output, audio_embedding, face_embedding = model(x, mel)
        loss = cross_entropy_loss(output, y) #if (global_epoch // 50) % 2 == 0 else contrastive_loss2(a, v, y)
        print('The loss', loss.item())
        loss.backward()
        optimizer.step()

# Save LoRA adapter weights after fine-tuning
torch.save(model.state_dict(), "lora_adapter_weights.pth")


Load checkpoint from: checkpoints/syncnet_checkpoint/checkpoint_step000066000.pth


  checkpoint = torch.load(checkpoint_path,


-----
The loss 0.4326261878013611
The loss 0.08643348515033722
The loss 0.2720319628715515
The loss 0.18335837125778198
The loss 1.0465641021728516
The loss 0.6543803215026855
The loss 0.3780989944934845
The loss 0.707392692565918
The loss 1.0432274341583252
The loss 2.1238856315612793


In [12]:
# Load the base model
# Step 1: Create an instance of the TransformerSyncnet model
model = TransformerSyncnet(num_heads=8, num_encoder_layers=6)

# Step 2: Load the pre-trained weights
load_checkpoint("checkpoints/syncnet_checkpoint/checkpoint_step000066000.pth", model, None, reset_optimizer=True, overwrite_global_states=False)

# Load the LoRA adapter weights
lora_weights = torch.load("lora_adapter_weights.pth")

# Update only the LoRA weights (specific keys in the state dict)
model.load_state_dict(lora_weights, strict=False)


Load checkpoint from: checkpoints/syncnet_checkpoint/checkpoint_step000066000.pth


  checkpoint = torch.load(checkpoint_path,
  lora_weights = torch.load("lora_adapter_weights.pth")


_IncompatibleKeys(missing_keys=['transformer_encoder.layers.0.self_attn.in_proj_weight', 'transformer_encoder.layers.0.self_attn.in_proj_bias', 'transformer_encoder.layers.0.self_attn.out_proj.weight', 'transformer_encoder.layers.0.self_attn.out_proj.bias', 'transformer_encoder.layers.0.linear1.weight', 'transformer_encoder.layers.0.linear1.bias', 'transformer_encoder.layers.0.linear2.weight', 'transformer_encoder.layers.0.linear2.bias', 'transformer_encoder.layers.0.norm1.weight', 'transformer_encoder.layers.0.norm1.bias', 'transformer_encoder.layers.0.norm2.weight', 'transformer_encoder.layers.0.norm2.bias', 'transformer_encoder.layers.1.self_attn.in_proj_weight', 'transformer_encoder.layers.1.self_attn.in_proj_bias', 'transformer_encoder.layers.1.self_attn.out_proj.weight', 'transformer_encoder.layers.1.self_attn.out_proj.bias', 'transformer_encoder.layers.1.linear1.weight', 'transformer_encoder.layers.1.linear1.bias', 'transformer_encoder.layers.1.linear2.weight', 'transformer_enco