Now that the preprocessing is done, let us focus on the _train.py_ file, which contains all the code required to fine-tune an LLM on our specific data.

First, let us define the **MedData** class, which will hold our specific medical data.

In [2]:
# Importing libraries for Data Utils
import torch
import datasets
import os
from dataclasses import dataclass
from huggingface_hub import login
from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase, BitsAndBytesConfig)
from transformers.utils import PaddingStrategy
from transformers import PreTrainedTokenizerBase
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union

# Importing libraries for Model Utils
import lightning as L
# The Accelerator is the main class for enabling distributed training on any type of training setup
from accelerate import Accelerator
from peft import LoraConfig, TaskType, prepare_model_for_kbit_training, get_peft_model
from pytorch_lightning.loggers import TensorBoardLogger
from callback_utils import GenerateText # TODO: to define

  from .autonotebook import tqdm as notebook_tqdm


Let us define the DataCollatorWithPaddingAndLabels class.
In this class we are using the __@dataclass__ decorator.
As written here (https://dzone.com/articles/understanding-pythons-dataclass-decorator), _In a nutshell, the primary goal of the @dataclass decorator is to simplify the creation of classes._ 

In [48]:
@dataclass
class DataCollatorWithPaddingAndLabels:
    """Collate a list of per-sample feature dicts into one padded batch dict.

    Every key containing ``"input_ids"`` (e.g. ``input_ids`` and
    ``prompt.input_ids``) is padded with ``tokenizer.pad`` and gets a matching
    attention-mask entry whose key is the same with ``"input_ids"`` replaced by
    ``"attention_mask"``. Keys containing ``"labels"`` are padded the same way
    (only the padded ids are kept). All other keys are passed through as plain
    Python lists, one element per sample.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def _pad(self, ids: List[Any]) -> Any:
        """Pad a list of token-id sequences using this collator's settings."""
        return self.tokenizer.pad(
            {'input_ids': ids},
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
            return_attention_mask=True,
        )

    def __call__(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Turn ``samples`` (one dict per sample) into a single batch dict."""
        if not samples:
            return {}

        # Transpose list-of-dicts into dict-of-lists, using the first sample's keys.
        features_list: Dict[str, List[Any]] = {key: [] for key in samples[0]}
        for sample in samples:
            for key, val in sample.items():
                features_list[key].append(val)

        batch: Dict[str, Any] = {}
        for key, val in features_list.items():
            if "input_ids" in key:
                padded = self._pad(val)
                batch[key] = padded['input_ids']
                batch[key.replace('input_ids', 'attention_mask')] = padded['attention_mask']
            elif "labels" in key:
                # NOTE(review): labels are padded with the tokenizer's pad id, not -100,
                # so padded positions would count toward a loss. MedData does not
                # produce a "labels" key, so this branch is currently unused.
                batch[key] = self._pad(val)['input_ids']
            else:
                batch[key] = val
        return batch

In [4]:
# Let us introduce the tokenize_sample function
def tokenize_sample(tokenizer, sample, features, add_special_tokens=False, eos_token=False, postpend=""):
    """Tokenize the selected fields of ``sample`` as one tagged text.

    Each feature is rendered as ``"\\n[<name>]\\n<value>"`` and the pieces are
    joined with single spaces before tokenization.

    Args:
        tokenizer: Hugging Face-style tokenizer (callable, with ``encode`` and
            ``eos_token_id``).
        sample: Mapping from feature name to value.
        features: Ordered feature names to include.
        add_special_tokens: Forwarded to the tokenizer for the main text.
        eos_token: If True, append ``tokenizer.eos_token_id`` after the text.
        postpend: Optional extra text, tokenized without special tokens and
            appended last (after the EOS token, if any).

    Returns:
        The tokenizer output for the main text (no attention mask), with its
        ``input_ids`` list extended in place by the EOS / postpend tokens.
    """
    input_text = " ".join(f"\n[{key}]\n{sample[key]}" for key in features)
    input_tokens = tokenizer(input_text, return_attention_mask=False, add_special_tokens=add_special_tokens)
    if eos_token:
        input_tokens['input_ids'] += [tokenizer.eos_token_id]

    if postpend:
        # `encode` already returns a plain list of ids; no torch tensor round-trip needed.
        input_tokens['input_ids'] += tokenizer.encode(postpend, add_special_tokens=False)

    return input_tokens

In [5]:
class MedData(torch.utils.data.Dataset):
    """Map-style dataset turning rows of ``ds_hf`` into tokenized training samples.

    Each item is a copy of the original row (a dict) extended with:
      * the tokenizer outputs for the full text (all features followed by EOS),
        e.g. ``input_ids``;
      * the tokenizer outputs for the prompt only (all features except the
        outcome, followed by the ``"[<outcome>]\\n"`` header), with keys
        prefixed by ``"prompt."``.
    When the full text is longer than ``max_length`` tokens, the oldest (leftmost)
    part of the ``event`` field is dropped and replaced by ``"(...)"``.

    Args:
        ds_hf: Indexable dataset whose rows contain the ``all_features`` columns.
        tokenizer: Tokenizer passed to ``tokenize_sample``.
        max_length: Maximum number of tokens for the full text (default 2048).
    """

    def __init__(self, ds_hf, tokenizer, max_length=2048):
        self.ds_hf = ds_hf
        self.tokenizer = tokenizer
        self.max_length = max_length

        # The last feature is the outcome the model has to predict.
        self.all_features = ['static', 'event', 'death_status']  # same as self.ds_hf.features.keys()
        self.prompt_features = self.all_features[:-1]
        self.outcome_feature = self.all_features[-1]

    def __len__(self):
        return len(self.ds_hf)

    def _tokenize_full(self, sample):
        """Tokenize every feature of ``sample`` and append the EOS token."""
        return tokenize_sample(tokenizer=self.tokenizer, sample=sample, features=self.all_features, eos_token=True)

    def __getitem__(self, idx):
        # Copy the row so truncation never mutates the underlying dataset
        # (matters if ds_hf hands out shared dict objects, e.g. a plain list).
        sample = dict(self.ds_hf[idx])
        full_text = self._tokenize_full(sample)

        overflow = len(full_text['input_ids']) - self.max_length
        if overflow > 0:
            # Cut 10 extra tokens as slack: decode/re-encode and the "(...)" marker
            # can change the token count slightly.
            cut_by = 10 + overflow
            # Tokenize without special tokens, like tokenize_sample does, so a BOS
            # token is not counted among the dropped event tokens.
            event_ids = self.tokenizer(sample['event'], add_special_tokens=False)['input_ids']
            # Drop from the left, i.e. truncate the oldest events.
            # NOTE(review): if `event` has fewer than cut_by tokens it collapses to
            # "(...)" and the text can still exceed max_length.
            sample['event'] = "(...)" + self.tokenizer.decode(event_ids[cut_by:])
            full_text = self._tokenize_full(sample)

        for key, val in full_text.items():
            sample[key] = val

        prompt = tokenize_sample(
            tokenizer=self.tokenizer,
            sample=sample,
            features=self.prompt_features,
            postpend=f"[{self.outcome_feature}]\n",
        )
        for key, val in prompt.items():
            sample[f"prompt.{key}"] = val
        return sample


In [3]:
def build_dataloaders(tokenizer, params):
    """Build one DataLoader per split of the preprocessed dataset stored on disk.

    Args:
        tokenizer: Tokenizer used by MedData and by the padding collator.
        params: Config dict. Reads ``data_path`` (default ``""``, the value that
            was previously hard-coded), ``batch_size`` (default 1) and
            ``max_length`` (default 2048, forwarded to the collator).

    Returns:
        Dict mapping each split name of the loaded dataset to its DataLoader.
        Only the ``"train"`` split is shuffled.
    """
    data = datasets.load_from_disk(params.get("data_path", ""))
    datasets_by_phase = {phase: MedData(ds_hf, tokenizer=tokenizer) for phase, ds_hf in data.items()}
    data_collate = DataCollatorWithPaddingAndLabels(tokenizer=tokenizer, max_length=params.get("max_length", 2048))
    return {
        phase: torch.utils.data.DataLoader(
            phase_ds,
            batch_size=params.get("batch_size", 1),
            shuffle=(phase == "train"),
            collate_fn=data_collate,
        )
        for phase, phase_ds in datasets_by_phase.items()
    }

Now that all the data utilities are in place, we can focus on the model utilities (a GPU is required to run them).

In [None]:
class ModelLightning(L.LightningModule):
    """LightningModule wrapping a causal LM for language-modeling fine-tuning.

    Args:
        model: The (PEFT-wrapped) causal LM. It is called with ``input_ids``,
            ``attention_mask`` and ``labels`` and must return an object with a
            ``.loss`` attribute.
        tokenizer: Kept on the module (e.g. for callbacks); excluded from the
            saved hyperparameters.
        learning_rate: AdamW learning rate.
    """

    def __init__(self, model, tokenizer, learning_rate):
        super().__init__()
        # Store only learning_rate in hparams; model and tokenizer are excluded.
        self.save_hyperparameters(ignore=['model', 'tokenizer'])
        self.model = model
        self.tokenizer = tokenizer
        self.learning_rate = learning_rate

    def forward(self, input_ids, attention_mask, labels=None, *args, **kwargs):
        return self.model(input_ids, attention_mask=attention_mask, labels=labels)

    def step(self, batch, phase):
        """Compute and log the LM loss for one batch; shared by train/val/test."""
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        # Labels are the input ids themselves. Padding positions are set to -100
        # (the default ignore_index of torch's CrossEntropyLoss) so they do not
        # contribute to the loss. They must be masked via the attention mask,
        # since the __main__ block uses eos as the pad token.
        labels = input_ids.masked_fill(attention_mask == 0, -100)
        output = self.forward(input_ids, attention_mask, labels)
        loss = output.loss
        # The batch also holds non-tensor entries (lists passed through by the
        # collator), so give Lightning the batch size explicitly.
        self.log(f"{phase}/loss", loss, batch_size=input_ids.shape[0])
        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, 'train')

    # Backward-compatible alias for the old misspelled name; Lightning only calls `training_step`.
    traininig_step = training_step

    def validation_step(self, batch, batch_idx):
        return self.step(batch, 'validation')

    def test_step(self, batch, batch_idx):
        return self.step(batch, 'test')

    def configure_optimizers(self):
        # Optimize only trainable parameters (the base weights are expected to be
        # frozen by get_peft_model).
        trainable = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.AdamW(trainable, lr=self.learning_rate)

Here we define a function __load_model()__, which loads the model we want to train and passes it through the __prepare_model_for_kbit_training__ method; that method wraps all the steps needed to prepare a (possibly quantized) model for training.
There are hyperparameters both for the model itself (__AutoModelForCausalLM__) and for the LoRA adapter (__LoraConfig__).

Some references:
    - LoRA Parameters in general: https://huggingface.co/docs/peft/package_reference/lora
    - PEFT parameters (e.g. inference_mode): https://huggingface.co/docs/peft/quicktour
    - LoRA alpha : https://datascience.stackexchange.com/questions/123229/understanding-alpha-parameter-tuning-in-lora-paper

It is still somewhat unclear how to tune these hyperparameters.


In [1]:
# we should have params in the name
def load_model(params):
    """Load the base causal LM (optionally quantized) and wrap it with LoRA adapters.

    Args:
        params: Config dict. Reads ``model_name``, ``lora_dim``, ``lora_alpha``,
            ``lora_dropout``, ``lora_target_modules`` and, optionally,
            ``load_in_8bit`` / ``load_in_4bit`` (both default False).

    Returns:
        The PEFT model returned by ``get_peft_model``.
    """
    # Put the whole model on the device matching this process's accelerate index.
    device_map = {"": Accelerator().process_index}

    load_in_8bit = params.get("load_in_8bit", False)
    load_in_4bit = params.get("load_in_4bit", False)
    # Passing load_in_8bit/load_in_4bit directly to from_pretrained is deprecated
    # in favour of an explicit BitsAndBytesConfig.
    quantization_config = (
        BitsAndBytesConfig(load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit)
        if (load_in_8bit or load_in_4bit)
        else None
    )
    model = AutoModelForCausalLM.from_pretrained(
        params["model_name"],
        trust_remote_code=True,
        quantization_config=quantization_config,
        device_map=device_map,
    )

    # Build the LoRA adapter configuration.
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=params["lora_dim"],
        lora_alpha=params["lora_alpha"],
        lora_dropout=params["lora_dropout"],
        target_modules=params["lora_target_modules"],
    )

    model = prepare_model_for_kbit_training(model)

    # Add adapters.
    model = get_peft_model(model, lora_config)
    # print_trainable_parameters() prints by itself and returns None.
    model.print_trainable_parameters()
    return model

In [None]:
def load_pl_module(tokenizer, params, checkpoint_path=None):
    """Build the LoRA model and wrap it in a ModelLightning module.

    If ``checkpoint_path`` is truthy, the module is restored with
    ``ModelLightning.load_from_checkpoint``, with the freshly built model passed
    to the constructor. Otherwise a new ModelLightning is created around it.
    """
    print("build_model...")
    model = load_model(params)

    if not checkpoint_path:
        return ModelLightning(model=model, tokenizer=tokenizer, learning_rate=params['learning_rate'])

    print("loading from checkpoint...")
    pl_module = ModelLightning.load_from_checkpoint(
        checkpoint_path,
        model=model,
        tokenizer=tokenizer,
        learning_rate=params['learning_rate'],
    )
    # Drop the local name and release unused cached CUDA memory. pl_module still
    # references the model, so this does not free the model itself.
    del model
    torch.cuda.empty_cache()
    return pl_module

Now, let us define the __params__ dictionary, which collects all the configurable parameters used by this code.

In [None]:
# Allow reduced-precision internal computation for float32 matmuls (faster on recent GPUs).
torch.set_float32_matmul_precision("medium")

params = dict(
    model_name="meta-llama/Llama-3.1-8B",
    # Trainer / data settings
    accumulate_grad_batches=16,
    precision=16,
    val_check_interval=0.25,
    max_epochs=100,
    batch_size=1,
    max_length=2048,
    learning_rate=1e-6,
    # Quantization and LoRA settings
    load_in_8bit=True,
    lora_dim=256,
    lora_alpha=256,
    lora_dropout=0.1,
    lora_target_modules=None,
)
# Short run name (last path segment of the model id), e.g. "Llama-3.1-8B".
params['name'] = params['model_name'].rsplit("/", 1)[-1]

In [None]:
if __name__ == "__main__":

    # Never hard-code access tokens in source: read the Hugging Face token from the
    # environment. If HF_TOKEN is unset, rely on credentials configured elsewhere
    # (e.g. via `huggingface-cli login`).
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    tokenizer = AutoTokenizer.from_pretrained(params['model_name'])
    # Reuse EOS as the padding token (the collator pads with tokenizer.pad).
    tokenizer.pad_token = tokenizer.eos_token
    print("build dataloaders...")
    dl = build_dataloaders(tokenizer, params)
    checkpoint_path = None  # set to a checkpoint path to resume training
    pl_module = load_pl_module(tokenizer, params, checkpoint_path=checkpoint_path)

    callbacks = [
        GenerateText(dataloaders=dl, max_token_len=params['max_length']),
        # Checkpoint on the logged validation loss (ModelCheckpoint's default mode is "min").
        L.pytorch.callbacks.ModelCheckpoint(monitor="validation/loss"),
    ]
    print("start training..")
    # TensorBoard logger from the same `lightning` package as the Trainer, rather
    # than mixing in classes from the separate `pytorch_lightning` package.
    logger = L.pytorch.loggers.TensorBoardLogger("tb_logs", name=params['name'])
    trainer = L.Trainer(
        logger=logger,
        callbacks=callbacks,
        max_epochs=params.get('max_epochs', 10),
        accumulate_grad_batches=params.get('accumulate_grad_batches', 1),
        precision=params.get('precision', '16-mixed'),
        val_check_interval=params.get('val_check_interval', 0.5),
    )

    trainer.fit(model=pl_module, train_dataloaders=dl['train'], val_dataloaders=dl['validation'])