LoRA Fine Tuning #82

Draft · wants to merge 11 commits into base: main

Commits on Mar 4, 2024

  1. First Draft

    This is just the first draft so we can start building this feature.
    
    - Added dataloader.py, which loads data for training
    - Added train.py, with the current training loop
    - Added lora.py, for the LoRA wrapper of the stage 1 Transformer (see the sketch below)
    - Added a dummy_dataset folder with 25 data samples to test against (VCTK --> p311)
    - Commented out the initial inference code that runs when the stage 1 model is built.

    There is no batch processing in the training loop yet (I was getting dimension mismatches in KVCache.update).

    The dataloader works fine, but everything else needs some work. This is just an initial draft so we can start working on this together! :-)
    danablend committed Mar 4, 2024 · 53011d7
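
    The LoRA wrapper in lora.py presumably follows the standard low-rank adapter pattern: freeze the pretrained weight and learn a small A/B update on top of it. A minimal sketch of that pattern (hypothetical names, not the actual contents of lora.py):

    ```python
    import math
    import torch
    import torch.nn as nn

    class LoRALinear(nn.Module):
        """Wrap a frozen nn.Linear with a trainable low-rank update: W x + (B A) x."""

        def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
            super().__init__()
            self.base = base
            for p in self.base.parameters():
                p.requires_grad = False  # only the adapter is trained

            # A is random, B starts at zero, so the wrapped layer is initially
            # identical to the pretrained layer (as in the LoRA paper).
            self.lora_A = nn.Parameter(torch.empty(r, base.in_features))
            self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            self.scaling = alpha / r

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling

    if __name__ == "__main__":
        layer = LoRALinear(nn.Linear(512, 512))
        print(layer(torch.randn(2, 10, 512)).shape)  # torch.Size([2, 10, 512])
    ```
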

Commits on Mar 6, 2024

  1. Cleared everything except stage 1

    Switched from the Adam optimizer to SGD.

    Modified the DataLoader to return the first two encodec token hierarchies as a flattened, interleaved tensor (see the sketch below; let me know if that looks OK to you).

    Modified the LoRA wrapper to fine-tune only the speaker_cond_pos layer. In nanoGPT-LoRA, only the causal attention layer is fine-tuned (https://github.com/danielgrittner/nanoGPT-LoRA/blob/master/model.py#L128). Would it be worth trying something similar here?

    Modified the training loop to forward entire batches at a time. The loss calculation doesn't work yet; the ground-truth labels need to be matched up with the generated probabilities. I need some direction here.
    danablend committed Mar 6, 2024 · da6e475
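
    For reference, flattening the first two encodec codebook hierarchies into one interleaved stream might look like this (a sketch with made-up shapes, not the actual dataloader.py code):

    ```python
    import torch

    def interleave_first_two(tokens: torch.Tensor) -> torch.Tensor:
        """Interleave the first two codebook hierarchies into a flat 1-D tensor.

        tokens: (num_hierarchies, T) integer tensor of encodec codes.
        Returns a (2 * T,) tensor ordered h0[0], h1[0], h0[1], h1[1], ...
        """
        first_two = tokens[:2]          # (2, T)
        return first_two.T.reshape(-1)  # transpose to (T, 2), then flatten

    if __name__ == "__main__":
        codes = torch.arange(12).reshape(4, 3)  # 4 hierarchies, 3 timesteps
        print(interleave_first_two(codes))      # tensor([0, 3, 1, 4, 2, 5])
    ```
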
  2. Add accelerate, new training loop

    It almost trains; I have made a mistake somewhere that causes:
    
    RuntimeError: Trying to backward through the graph a second time
    
    I'm guessing it's because we need to do all the preprocessing in the dataloader rather than in the training loop.
    
    Let me know if you have any thoughts (a toy example of the usual cause of this error follows below). It's getting close :-)
    danablend committed Mar 6, 2024 · 22de9d6
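
    That RuntimeError usually means a tensor produced in an earlier iteration (a cached key/value, for example) is still part of the autograd graph when the next loss.backward() runs. A toy illustration of the pattern and the usual fix (detach anything that persists across iterations); this is an assumption about the cause, not a diagnosis of the actual code:

    ```python
    import torch

    w = torch.randn(4, 4, requires_grad=True)
    cache = torch.zeros(1, 4)  # stands in for a KV cache that persists across steps

    for step in range(3):
        x = torch.randn(1, 4)
        h = x @ w
        # The .detach() is the important part: storing `h` directly would keep the
        # cache attached to this step's graph, and the next backward() would raise
        # "Trying to backward through the graph a second time".
        cache = (cache + h).detach()
        loss = (x @ w + cache).sum()
        loss.backward()
        w.grad = None  # stand-in for optimizer.zero_grad()
    print("ran 3 steps without reusing a freed graph")
    ```
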
  3. move loss to model + format similar to nanoGPT

    Moved loss calculation to LoRA wrapper model.
    
    Modified the training loop to be similar to nanoGPT's. This involves using a sliding window for prompts & labels, which should more accurately replicate what the model is actually producing at the logits level (see the loss sketch below). If my intuition is wrong about this, please correct me.
    danablend committed Mar 6, 2024 · c23d878
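
    nanoGPT computes the loss as cross-entropy between the logits at each position and the token one step ahead, skipping masked positions via ignore_index=-1. A sketch of that shifted-target loss (a hypothetical helper, not the wrapper's actual code):

    ```python
    import torch
    import torch.nn.functional as F

    def next_token_loss(logits: torch.Tensor, tokens: torch.Tensor,
                        pad_id: int = -1) -> torch.Tensor:
        """Cross-entropy between logits[:, t, :] and tokens[:, t + 1].

        logits: (B, T, vocab_size); tokens: (B, T). Positions equal to `pad_id`
        are ignored, mirroring nanoGPT's ignore_index=-1.
        """
        shift_logits = logits[:, :-1, :].contiguous()  # predictions for 0..T-2
        shift_targets = tokens[:, 1:].contiguous()     # ground truth for 1..T-1
        return F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_targets.view(-1),
            ignore_index=pad_id,
        )

    if __name__ == "__main__":
        B, T, V = 2, 8, 1024
        loss = next_token_loss(torch.randn(B, T, V), torch.randint(0, V, (B, T)))
        print(loss.item())
    ```
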

Commits on Mar 7, 2024

  1. Get model to train by clearing cache between iters

    The model is fine-tuning now. Not correctly yet, but it is fine-tuning (a toy illustration of the cache fix follows below).

    Next, the data needs to be prepared correctly.

    But first, Attention and KVCache must be modified to support batches with batch size > 1.
    danablend committed Mar 7, 2024 · fecd0ac
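
    A toy version of why clearing the cache between iterations helps: if the model reads a cached tensor that belongs to a previous iteration's graph, the next backward() walks into the already-freed graph. Resetting the cache each step keeps the graphs independent (placeholder model, not the real KVCache API):

    ```python
    import torch
    import torch.nn as nn

    class TinyCachedModel(nn.Module):
        """Toy stand-in for the stage 1 model: keeps activations in a cache."""

        def __init__(self, dim: int = 16):
            super().__init__()
            self.proj = nn.Linear(dim, dim)
            self.kv_cache = []

        def clear_cache(self):
            self.kv_cache.clear()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            h = self.proj(x)
            if self.kv_cache:
                # Reusing a cached tensor ties this step's graph to the old one.
                h = h + self.kv_cache[-1].mean()
            self.kv_cache.append(h)
            return h.sum()

    model = TinyCachedModel()
    opt = torch.optim.SGD(model.parameters(), lr=1e-3)

    for step in range(3):
        model.clear_cache()  # without this line, the second iteration's backward()
                             # raises "Trying to backward through the graph a second time"
        opt.zero_grad(set_to_none=True)
        loss = model(torch.randn(4, 16))
        loss.backward()
        opt.step()
    print("trained 3 steps with the cache cleared between iterations")
    ```
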
  2. Cleaned up training loop, issue with data

    The training loop is pretty clean now, and all data preparation is done in dataloader.py.

    The loss becomes NaN when entries in a batch differ a lot in length (e.g. one entry has to be padded heavily during collation because of a big difference in the lengths of the prompt tensors, the encodec token tensors, or both).

    The issue could perhaps be solved by grouping data points of similar length together so that little padding is needed (see the bucketing sketch below). I'd love to hear any thoughts here.
    danablend committed Mar 7, 2024 · d473493
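
    One common way to avoid heavily padded batches is to group samples of similar length before collating. A rough sketch of the idea (hypothetical dataset and sampler, not tied to the actual dataloader.py):

    ```python
    import torch
    from torch.utils.data import DataLoader, Dataset

    class VariableLengthDataset(Dataset):
        """Toy dataset of 1-D token tensors with very different lengths."""

        def __init__(self, lengths):
            self.samples = [torch.randint(0, 1024, (n,)) for n in lengths]

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            return self.samples[idx]

    def length_sorted_batches(dataset, batch_size):
        """Group indices by similar sample length so collation needs little padding."""
        order = sorted(range(len(dataset)), key=lambda i: len(dataset[i]))
        return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]

    def pad_collate(batch, pad_id=0):
        return torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=pad_id)

    if __name__ == "__main__":
        ds = VariableLengthDataset([30, 500, 32, 480])
        loader = DataLoader(ds,
                            batch_sampler=length_sorted_batches(ds, batch_size=2),
                            collate_fn=pad_collate)
        for batch in loader:
            print(batch.shape)  # (2, 32) then (2, 500): similar lengths batched together
    ```
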
  3. Implement sliding window data loading

    The data loading could still be better optimized for memory, but this runs (a sketch of the sliding-window sampling follows below).

    I'm going to run some tests to make sure this correctly trains the LoRA. It might be worth testing LoRAs on different layers.
    danablend committed Mar 7, 2024 · 198fd9d
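
    Sliding-window loading in the nanoGPT style samples random block_size windows from the flattened token stream and uses the same window shifted by one as the target, so only block_size + 1 tokens are materialized per example. A sketch (made-up names, not the actual implementation):

    ```python
    import torch

    def get_windows(tokens: torch.Tensor, block_size: int, batch_size: int):
        """Sample random (input, target) windows from one long 1-D token stream."""
        ix = torch.randint(0, len(tokens) - block_size - 1, (batch_size,))
        x = torch.stack([tokens[i:i + block_size] for i in ix])
        y = torch.stack([tokens[i + 1:i + 1 + block_size] for i in ix])  # shifted by one
        return x, y

    if __name__ == "__main__":
        stream = torch.randint(0, 1024, (10_000,))
        x, y = get_windows(stream, block_size=256, batch_size=4)
        print(x.shape, y.shape)  # torch.Size([4, 256]) torch.Size([4, 256])
    ```
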

Commits on Mar 8, 2024

  1. Correct data loading

    Corrected a mistake in the data preparation. I will start training some LoRAs now to see whether this fine-tuning code is set up correctly.
    danablend committed Mar 8, 2024 · 7c4b93e
  2. e7d27e8

Commits on Mar 10, 2024

  1. Fix dtypes and copy nanoGPT-LoRA training params

    Disabled Accelerate for now.
    
    Properly aligned all dtypes between the model and the dataloader. Previously, loaded speaker embeddings were not converted to the correct dtype (sketched below).
    
    Copied most of the nanoGPT-LoRA training parameters.
    
    Renamed "epochs" to "iters" like in nanoGPT.
    danablend committed Mar 10, 2024 · 42c4b92
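
    The dtype alignment presumably boils down to casting loaded tensors to the model's parameter dtype before the forward pass; roughly (a stand-in layer, not the real model):

    ```python
    import torch
    import torch.nn as nn

    # Stand-in for the stage 1 model, running in bfloat16.
    model = nn.Linear(256, 256).to(dtype=torch.bfloat16)
    model_dtype = next(model.parameters()).dtype

    # Embeddings loaded from disk typically come back as float32; casting them
    # to the model dtype avoids mixed-dtype matmul errors.
    speaker_emb = torch.randn(1, 256)            # pretend this came from torch.load()
    speaker_emb = speaker_emb.to(dtype=model_dtype)

    print(model(speaker_emb).dtype)              # torch.bfloat16
    ```
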
  2. load pretrained weights in lora layers

    Updated to further mimic the nanoGPT-LoRA training process (see the sketch below).
    danablend committed Mar 10, 2024 · 5c6c7a8
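
    Loading pretrained weights into a LoRA layer usually just means copying the original parameters into the frozen base and leaving the adapters at their zero-product init, so the wrapped layer starts out identical to the original. A standalone sketch of that check (hypothetical class, not the code in this PR):

    ```python
    import torch
    import torch.nn as nn

    class LoRAWrapped(nn.Module):
        """Frozen base linear plus a zero-initialized low-rank adapter."""

        def __init__(self, in_f: int, out_f: int, r: int = 8):
            super().__init__()
            self.base = nn.Linear(in_f, out_f)
            self.base.weight.requires_grad_(False)
            self.base.bias.requires_grad_(False)
            self.lora_A = nn.Parameter(torch.randn(r, in_f) * 0.01)
            self.lora_B = nn.Parameter(torch.zeros(out_f, r))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.base(x) + x @ self.lora_A.T @ self.lora_B.T

    if __name__ == "__main__":
        pretrained = nn.Linear(512, 512)   # pretend this is the checkpointed layer
        wrapped = LoRAWrapped(512, 512)
        # "Load pretrained weights in lora layers": copy them into the frozen base.
        wrapped.base.load_state_dict(pretrained.state_dict())
        x = torch.randn(2, 512)
        print(torch.allclose(wrapped(x), pretrained(x)))  # True: adapters start at zero
    ```
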