In [2]:
# !pip install transformers==4.31.0 peft==0.4.0 accelerate==0.21.0 bitsandbytes==0.40.2 safetensors>=0.3.1 tokenizers>=0.13.3
# !pip install word2number

In [3]:
# Mount Google Drive; the n2c2 dataset paths used below live under this mount point.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [4]:
from tqdm import tqdm
import pandas as pd

In [5]:
# Sanity check: print the character span [10179, 10197) of note 100035, which the matching
# .ann file labels as a Reason entity ("recurrent seizures").
note_path = "/content/drive/My Drive/Datasets/Harvard n2c2 NLP Research Data Sets/2018 (Track 2) ADE and Medication Extraction Challenge/training_20180910/100035.txt"
with open(note_path, 'r') as f:
    text = f.read()
print(text[10179:10197])

recurrent seizures


In [6]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: kaivalya mannam
"""
import re
from bisect import bisect_right
from word2number import w2n

# Token limit used when splitting text into sentences: makeSentences passes it as-is and
# makeSentences_for_predict passes 3x this value (both read it as a module-level global).
max_sentence_length = 100

def readTextFile(filename):
    """Read a clinical note and tokenize it line by line.

    Args:
        filename: path of the .txt note.

    Returns:
        A list with one dict per physical line (blank lines included), each holding:
          'start' / 'end': character offsets of the line within the file (end exclusive,
              the newline included), counted on the text as read in text mode;
          'words', 'starts', 'normwords': the line's tokens, their file offsets and their
              normalized forms, as returned by getWordsArray;
          'sequences': one 'NA' per token (later replaced by an annotation id);
          'targets': one 'O' per token (later replaced by a BIO label).
    """
    text_info = []
    offset = 0  # character offset of the current line within the file

    # `with` guarantees the file is closed even if tokenization raises part-way through.
    with open(filename, 'rt') as text_file:
        for line in text_file:
            words_array, starts, normwords = getWordsArray(line, offset)
            text_info.append({
                "start": offset,
                "end": offset + len(line),
                'words': words_array,
                'sequences': ['NA'] * len(words_array),
                'targets': ['O'] * len(words_array),
                'starts': starts,
                'normwords': normwords,
            })
            offset += len(line)

    return text_info

# gets annotated lines

def readAnnFile(filename):
    """Parse an n2c2 .ann annotation file.

    Each non-blank line is split on tabs into id / metadata / text, and the metadata field
    is further split on spaces, giving rows such as
        ['T1', 'Reason', '10179', '10197', 'recurrent seizures']
        ['R1', 'Reason-Drug', 'Arg1:T1', 'Arg2:T3']
    The annotated text (third tab field) is kept as ONE un-tokenized string. Lines with
    fewer than two tab-separated fields are reported and skipped; blank lines are ignored.

    Returns:
        (t_lines, t_stats, r_lines, r_stats):
          t_lines: entity rows (id starting with "T"), sorted by start offset;
          t_stats: the entity type of each t_lines row;
          r_lines: relation rows (id starting with "R"), in file order;
          r_stats: the relation type of each r_lines row.
    """
    lines = []
    # `with` closes the file; the original version never closed it.
    with open(filename, 'rt') as ann_file:
        for line in ann_file:
            # get rid of \n (and surrounding whitespace) at the end of the line
            if line[-1:] == '\n':
                line = line.strip()
            if not line.strip():
                # blank line, e.g. a trailing newline at the end of the file
                continue
            inpcomps = line.split('\t')
            if len(inpcomps) < 2:
                print("problem skipping line {} in file {}".format(line, filename))
                continue
            # id, then the space-separated metadata (type + offsets, or type + relation args)
            newcomps = [inpcomps[0]] + inpcomps[1].split(' ')
            if len(inpcomps) > 2:
                # the annotated text, kept as a single string
                newcomps.append(inpcomps[2])
            lines.append(newcomps)

    # ids are prefixed with their annotation kind ("T3", "R1", ...)
    t_lines = sorted((row for row in lines if row[0].startswith("T")),
                     key=lambda row: int(row[2]))
    r_lines = [row for row in lines if row[0].startswith("R")]

    t_stats = [row[1] for row in t_lines]
    r_stats = [row[1] for row in r_lines]

    return t_lines, t_stats, r_lines, r_stats

# use annotated lines to update sequences and entities in our dictionaries
# assumes input is T lines

def readEntities(lines, text_info):
    """Project entity (T) annotations onto the per-line token structure.

    For each annotation row, the text lines it spans are located by character offset and
    each annotated word is matched against the tokens of those lines. A matching token gets
    its 'sequences' entry set to the annotation id and its 'targets' entry set to a BIO tag
    ("B-<entity>" for the first annotated word, "I-<entity>" for later ones). When a token
    already belongs to another annotation, the if/elif rules below decide which label wins
    (e.g. ADE is kept over Reason). If an annotated word is only a substring of a token, the
    token is split with modifyDict and matched again.

    Args:
        lines: annotation rows [id, entity_type, start, ..., end, <text fields>]; empty
            rows are skipped.
        text_info: list of per-line dicts (keys 'start', 'words', 'starts', 'sequences',
            'targets', 'normwords'); modified in place.

    Returns:
        (text_info, entity_dict). entity_dict maps annotation id ->
        {'entity', 'start', 'end', 'span'}; 'start'/'end' are kept as the original strings.

    NOTE(review): word_array comes from getAnnInfo, i.e. the text fields of the row. If
    the rows store the entity text as one un-tokenized string, a multi-word annotation is
    a single element (e.g. ["recurrent seizures"]) that never equals one token, and it
    falls through to the "couldn't update" path. Confirm whether the caller is expected to
    tokenize the text first.
    NOTE(review): if a matched word is the last token of the last spanned line while more
    annotated words remain, index reaches len(words) and the next comparison raises IndexError.
    """

    # obtain a list of the starting location of each line in the file
    start_indices = list(
        map(lambda line_dict: int(line_dict['start']), text_info))

    # annotation id -> {'entity', 'start', 'end', 'span'}
    entity_dict = {}

    # read annotation file
    for line_array in lines:

        # if its not a blank line
        if (not(len(line_array) == 0)):

            [sequence, entity, start, end, word_array] = getAnnInfo(
                line_array)  # unpack information about the annotation

            entity_dict[sequence]={'entity':entity, 'start':start, 'end':end, 'span':word_array}

            # finds which lines are spanned by the annotation
            # (bisect_right - 1 = last line whose start offset is <= the given offset)
            line_start = bisect_right(start_indices, int(start)) - 1
            line_end = bisect_right(start_indices, int(end)) - 1

            # how many extra lines the annotation spans (in the txt file)
            extraLines = line_end - line_start
            currentLine = 0  # current line being read (indexed at 0)

            # current line dictionary we're looking to find the words in our annotation
            my_dict = text_info[line_start]

            # the location (in dictionary) of the first word
            # NOTE(review): if the annotation starts before the line's first token this is -1,
            # and the comparison below then looks at the line's LAST token - confirm this cannot happen.
            index = bisect_right(my_dict['starts'], int(start)) - 1

            # loop through each word in the annotation
            # (i advances only on a match; on a miss the token index advances instead)
            i = 0
            while i < len(word_array):

                # found the word and it matches :)
                if my_dict['words'][index] == word_array[i]:

                    # update info accordingly
                    # update_target stays True unless an overlap rule below keeps the label
                    # written earlier by another annotation.
                    update_target = True
                    if my_dict['sequences'][index] != 'NA':
                        # the token already belongs to another annotation: decide which label wins
                        # (e.g. ADE is preferred over Reason)
                        existing = my_dict['targets'][index]
                        if existing in ['B-ADE', 'I-ADE'] and entity == 'Reason':
                            #print("Keeping ADE over Reason")
                            update_target = False
                        elif existing in ['B-Reason', 'I-Reason'] and entity == 'ADE':
                            #print("Prefering ADE over Reason")
                            update_target = True
                        elif existing in ['B-Drug'] and entity == 'ADE':
                            # the first word is a Drug as well as ADE
                            # so we will write this as a Drug for first token
                            # and keep the rest as ADE
                            #print("ADE overlaps with Drug, keeping subset for ADE")
                            update_target = False
                        elif existing in ['B-ADE', 'B-Form'] and entity == 'Drug':
                            # this is same as the previous case
                            # for form its like Insulin Pen where the Insulin part is Drug
                            # and Pen part is Form
                            # so we will write this as a Drug for first token
                            # and keep the rest as is
                            #print("... overlaps with Drug, keeping subset ...")
                            update_target = True
                        elif existing in ['B-Drug'] and entity == 'Form':
                            # skip this case and put the I- when we loop
                            #print("... overlap with Drug 2, keeping subset ...")
                            update_target = False
                        elif existing in ['B-Reason', 'I-Reason'] and entity == 'Drug':
                            # here reason ends with drug name
                            #print("... overlap of drug with Reason, keeping subset ...")
                            update_target = True
                        elif existing in ['I-Drug'] and entity == 'Strength':
                            # here strength overlaps with Drug ...
                            #print("... overlap of strength with Drug, keeping subset ...")
                            update_target = True
                        elif entity == existing[2:]:
                            # this is fine they are same
                            update_target = True
                        else:
                            # NOTE(review): despite the "Skipping" message, update_target is
                            # still True here, so the new entity overwrites the existing label.
                            print("Skipping duplicate:")
                            print("\tExisting: " + my_dict['targets'][index])
                            print("\tNew: " + entity)
                    if update_target:
                        my_dict['sequences'][index] = sequence
                        if (i == 0):
                            my_dict['targets'][index] = "B-" + entity
                        else:
                            my_dict['targets'][index] = "I-" + entity

                # found the word, but its in a compound word :|
                elif word_array[i] in my_dict['words'][index]:

                    # here we don't do the update_target check for now.
                    # try modifying the dict because the word is hidden in a compound word
                    # modifyDict splits the token in place and returns the new position of the word
                    new_dict, index = modifyDict(my_dict, word_array[i], index)

                    # we found the word in the modified dict, so update it
                    # (modifyDict mutates my_dict and returns it, so this stores the same object)
                    text_info[line_start + currentLine] = new_dict

                    # re-test the same word: index now points at the isolated token
                    continue

                # didn't find the word :(
                else:

                    # try moving along the dictionary to see if we can find it
                    index += 1

                    # if we've moved past the end of the line, give up on this annotation
                    # (the problem is printed, not raised)
                    if (index == len(my_dict['words'])):
                        # raise "Error! Couldn't find " + word_array[i] + " within" + '\n' + str(my_dict['words'])
                        print(
                            "Uh oh, couldn't update data structure for this annotation")
                        print(word_array)
                        break
                    continue

                # if this word was the last in txt file line, and we know we have multiple lines in this annotation
                if (index == len(my_dict['words']) - 1) and (currentLine < extraLines):

                    # move to the next line
                    while (currentLine < extraLines):
                        currentLine += 1
                        my_dict = text_info[line_start + currentLine]

                        # skip empty lines, and stop jumping past lines once we reach a non-blank line
                        if len(my_dict['words']) != 0:
                            break

                    # make index -1, so it becomes 0 after incrementing below
                    index = -1

                # move onto next word
                index += 1
                i += 1
    return text_info, entity_dict

# converts the text_info data structure into sentences

def makeSentences_internal(text_info, new_tok_counter, sent_len_counter, max_sent_len=100, paragraphMode=False):
    """Group the tokens of text_info into sentence records (see defaultSentence).

    The current sentence is closed and a new one started when:
      * a blank line (a line with no tokens) is reached;
      * paragraphMode is False and the token "." is reached;
      * max_sent_len is non-zero and the sentence already holds max_sent_len tokens.
    Each sentence therefore has at most max_sent_len tokens when the limit is enabled.
    (Previously the check was `> max_sent_len`, which emitted sentences of max_sent_len + 1.)

    Args:
        text_info: per-line dicts with parallel lists 'words', 'normwords', 'sequences',
            'targets' and 'starts'.
        new_tok_counter: receives .update([target_label]) once per token; presumably a
            collections.Counter (TODO confirm at the call site).
        sent_len_counter: receives .update([length]) once per emitted sentence; presumably
            a collections.Counter as well.
        max_sent_len: maximum tokens per sentence; 0 disables the limit.
        paragraphMode: if True, "." does not end a sentence.

    Returns:
        list of sentence dicts. 'starts', 'line_num' (1-based) and 'word_index' are stored
        as strings.

    Raises:
        AssertionError: if a line's 'words' and 'normwords' differ in length.
    """
    sentences = []

    def _close_sentence(current):
        # Store `current` if it has tokens and hand back a fresh sentence; otherwise keep it.
        if current['words']:
            sentences.append(current)
            sent_len_counter.update([len(current['words'])])
            return defaultSentence()
        return current

    sentence = defaultSentence()

    for line_num, line_dict in enumerate(text_info):
        words = line_dict['words']

        # a blank line ends the current sentence
        if len(words) == 0:
            sentence = _close_sentence(sentence)

        assert len(words) == len(line_dict['normwords']), (
            "line {}: 'words' and 'normwords' have different lengths".format(line_num + 1))

        for i, word in enumerate(words):
            sentence['seq'].append(line_dict['sequences'][i])
            sentence['words'].append(word)
            sentence['normwords'].append(line_dict['normwords'][i])
            sentence['starts'].append(str(line_dict['starts'][i]))
            sentence['line_num'].append(str(line_num + 1))
            sentence['word_index'].append(str(i))

            target = line_dict['targets'][i]
            sentence['targets'].append(target)
            new_tok_counter.update([target])

            # break at a period, or as soon as the sentence reaches the length limit
            at_period = not paragraphMode and word == "."
            at_limit = max_sent_len != 0 and len(sentence['words']) >= max_sent_len
            if at_period or at_limit:
                sentence = _close_sentence(sentence)

    # append the last remaining sentence, if any
    _close_sentence(sentence)

    return sentences


def makeSentences(text_info, new_tok_counter, sent_len_counter):
    """Split text_info into sentences at "." tokens, using the module-level
    max_sentence_length as the length limit (thin wrapper over makeSentences_internal)."""
    # max_sentence_length is only read here, so no `global` declaration is required.
    return makeSentences_internal(
        text_info,
        new_tok_counter,
        sent_len_counter,
        max_sent_len=max_sentence_length,
        paragraphMode=False,
    )

def makeSentences_for_predict(text_info, new_tok_counter, sent_len_counter):
    """Like makeSentences, but with a length limit three times the module-level
    max_sentence_length, so long sentences are split less often at prediction time."""
    predict_limit = max_sentence_length * 3
    return makeSentences_internal(
        text_info,
        new_tok_counter,
        sent_len_counter,
        max_sent_len=predict_limit,
        paragraphMode=False,
    )

def makeSentences_paragraph(text_info, new_tok_counter, sent_len_counter):
    """Split text_info into paragraphs: only blank lines end a unit ("." does not, and
    max_sent_len=0 disables the length limit)."""
    return makeSentences_internal(
        text_info,
        new_tok_counter,
        sent_len_counter,
        max_sent_len=0,
        paragraphMode=True,
    )

# writes unified tokens to a file


# modifies the dictionary to find a missing word
# suppose we are looking for "anthracycline", but the words list is ["anthracycline-induced", "cardiomyopathy", ...]
# this function modifies the list of words to be ["anthracycline", "-induced", "cardiomyopathy", ...], taking in index of 0

def modifyDict(line_dict, target_word, index):
    """Split the compound token at `index` so that `target_word` becomes its own token.

    Example: looking for "anthracycline" in ["anthracycline-induced", "cardiomyopathy"]
    at index 0 turns the words into ["anthracycline", "-induced", "cardiomyopathy"].

    Only the first occurrence of target_word inside the token is split out. The parallel
    per-token lists ('words', 'normwords', 'sequences', 'targets', 'starts') are spliced
    together: the new pieces get sequence 'NA', target 'O', normWord() forms, and start
    offsets derived from the compound token's start.

    Args:
        line_dict: per-line dict; modified in place.
        target_word: substring of line_dict['words'][index] to isolate.
        index: position of the compound token.

    Returns:
        (line_dict, target_location): the same dict object, and the new index of
        target_word within line_dict['words'].

    Raises:
        ValueError: if target_word does not occur in the token. This is raised before
        line_dict is modified.
    """
    # compound word containing the target word
    compound_word = line_dict['words'][index]

    # The capturing group keeps target_word itself in the result; maxsplit is passed by
    # keyword (positional maxsplit is deprecated since Python 3.13).
    pieces = re.split("(" + re.escape(target_word) + ")", compound_word, maxsplit=1)

    # drop the empty strings produced when target_word is at either end of the token
    new_words = [piece for piece in pieces if piece != '']
    norm_new_words = [normWord(piece) for piece in new_words]

    # where target_word will sit once the pieces replace the compound token
    target_location = new_words.index(target_word) + index

    # start offsets of the pieces, laid end to end from the compound token's start
    new_starts = []
    offset = line_dict['starts'][index]
    for piece in new_words:
        new_starts.append(offset)
        offset += len(piece)

    def _splice(values, replacement):
        # replace the single element at `index` with the items of `replacement`
        return values[:index] + replacement + values[index + 1:]

    line_dict['words'] = _splice(line_dict['words'], new_words)
    line_dict['normwords'] = _splice(line_dict['normwords'], norm_new_words)
    line_dict['sequences'] = _splice(line_dict['sequences'], ['NA'] * len(new_words))
    line_dict['targets'] = _splice(line_dict['targets'], ['O'] * len(new_words))
    line_dict['starts'] = _splice(line_dict['starts'], new_starts)

    return line_dict, target_location


# returns an empty default sentence

def defaultSentence():
    """Return a fresh, empty sentence record.

    The per-token fields ('seq', 'words', 'starts', 'line_num', 'word_index', 'normwords',
    'targets') start as empty lists, 'rels' is an empty list and 'relspan' an empty set.
    Every call builds new containers, so records never share state.
    """
    token_fields = ('seq', 'words', 'starts', 'line_num', 'word_index', 'normwords', 'targets')
    sentence = {field: [] for field in token_fields}
    sentence['rels'] = []
    sentence['relspan'] = set()
    return sentence

# unpack the information from the line

def getAnnInfo(line_array):
    """Unpack one annotation row into [id, entity_type, start, end, text_fields].

    The row looks like [id, type, start, <"a;b" fragment tokens>..., end, <text>...]. Tokens
    containing ';' (fragment boundaries of a discontinuous span) are stepped over, so `end`
    is the first ';'-free token at or after position 3. text_fields is everything after it
    (possibly an empty list).
    """
    sequence, entity, start = line_array[0:3]

    # skip the 'x;y' boundary tokens of discontinuous spans
    end_pos = 3
    while ';' in line_array[end_pos]:
        end_pos += 1

    return [sequence, entity, start, line_array[end_pos], line_array[end_pos + 1:]]

# converts a line to an array of words
# if we want to get indices at which each word starts, pass a starting index

def getWordsArray(line, start=None):
    """Tokenize one line of clinical text.

    Splitting rules. Delimiters are kept as tokens; whitespace is dropped.
      * any run of whitespace;
      * each of % * : ( ) , ; wherever it occurs;
      * "-" or "." followed by whitespace (e.g. the line-final "." in "BID.\\n");
      * any run of "-" characters and any run of "#" characters, wherever they occur.
    A trailing newline is first replaced by a space, so a line-final "." or "-" is split off.

    Args:
        line: the text of one line, possibly ending in "\\n".
        start: character offset of line[0] in the whole file. If given, token offsets and
            normalized tokens are returned as well.

    Returns:
        If start is None: the list of tokens.
        Otherwise: (tokens, starts, normwords), three parallel lists. starts[i] is the file
        offset of tokens[i], and normwords[i] is its normWord() form (stripped).
    """
    # replace \n at the end of the line with a space
    if line.endswith('\n'):
        line = line[:-1] + " "

    # Raw string: the pattern is unchanged, but the escapes are no longer invalid
    # Python string escapes (SyntaxWarning on 3.12+).
    pieces = re.split(r'(\s+|[\%\*\:\(\)\,\;]|[\-\.]\s+|[\-]+|[\#]+)', line)

    # we just want the array
    if start is None:
        return [piece.strip() for piece in pieces if piece.strip() != '']

    words_array, starts, normwords = [], [], []
    offset = start
    for piece in pieces:
        piece_start = offset
        offset += len(piece)
        if piece.strip() == '':
            # whitespace separators and the empty strings re.split emits between
            # adjacent delimiters carry no token
            continue
        words_array.append(piece.strip())
        starts.append(piece_start)
        normwords.append(normWord(piece).strip())

    return words_array, starts, normwords

# returns true if the word contains a numerical digit or is a word number like "six"

def isNumber(word):
    """Detect numeric tokens and build their normalized form.

    Returns:
        (True, "ORDINAL") if the whole word parses as an int (e.g. "25"), or if it has no
            ASCII digit and w2n.word_to_num accepts it (e.g. "six");
        (True, pattern) if it contains ASCII digits but is not a plain int. Each run of
            digits is collapsed to a single "0" (e.g. "25mg" -> "0mg", "2.5mg" -> "0.0mg");
        (False, word) otherwise.
    """
    digits = '0123456789'

    if any(ch in digits for ch in word):
        try:
            int(word)
        except ValueError:
            pattern = []
            for ch in word:
                if ch not in digits:
                    pattern.append(ch)
                elif not pattern or pattern[-1] != '0':
                    # first digit of a run; later digits of the run are skipped
                    pattern.append('0')
            return True, "".join(pattern)
        return True, "ORDINAL"  # fully a number

    # Try converting a number word. `Exception` (not a bare except) keeps the best-effort
    # fallback without swallowing KeyboardInterrupt / SystemExit.
    try:
        w2n.word_to_num(word)
    except Exception:
        return False, word
    return True, "ORDINAL"

def normWord(word):
    """Return the normalized form of `word` from isNumber, or the word itself if it is not numeric."""
    is_numeric, normalized = isNumber(word)
    return normalized if is_numeric else word


In [7]:
# Parse the annotation file that accompanies note 100035.
ann_path = "/content/drive/My Drive/Datasets/Harvard n2c2 NLP Research Data Sets/2018 (Track 2) ADE and Medication Extraction Challenge/training_20180910/100035.ann"
t_lines, t_stats, r_lines, r_stats = readAnnFile(ann_path)

In [8]:
# number of entity (T) annotations parsed from the .ann file
len(t_lines)

292

In [9]:
# Flatten discontinuous entity spans. Such rows carry extra 'a;b' offset tokens, e.g.
#   ['T1', 'Reason', '10179', '10190;10195', '10197', 'recurrent seizures']
# Keep id, type, first start, last end and text, so every row has the 5 columns used by
# t_lines_df. New lists are built, so t_lines itself is not mutated and re-running this
# cell gives the same result (the old `line.pop(3)` edited t_lines in place and removed
# only one fragment token).
t_lines2 = []
for line in tqdm(t_lines):
  if len(line) > 5:
    line = line[:3] + line[-2:]
  else:
    line = list(line)
  t_lines2.append(line)

100%|██████████| 292/292 [00:00<00:00, 1296017.74it/s]


In [10]:
# Index entity rows by id, e.g. 'T1' -> ['Reason', '10179', '10197', 'recurrent seizures'].
t_lines3 = {row[0]: row[1:] for row in t_lines2}

In [11]:
# Resolve each relation's arguments ("Arg1:T1" -> "T1") to their entity rows, giving
# [rel_id, rel_type, *entity_1 fields, *entity_2 fields].
r_lines2 = []
for rel in r_lines:
  arg1_id = rel[2].split(":")[1]
  arg2_id = rel[3].split(":")[1]
  r_lines2.append(rel[:2] + t_lines3[arg1_id] + t_lines3[arg2_id])

In [12]:
# One row per relation annotation, as parsed by readAnnFile.
# NOTE(review): 'red_id' / 'red_type' look like typos for 'rel_id' / 'rel_type'; they are
# left unchanged in case later code selects these column names.
relation_columns = ['red_id', 'red_type', 'rel_arg_1', 'rel_arg_2']
r_lines_df = pd.DataFrame(r_lines, columns=relation_columns)
r_lines_df.head()

Unnamed: 0,red_id,red_type,rel_arg_1,rel_arg_2
0,R1,Reason-Drug,Arg1:T1,Arg2:T3
1,R4,Route-Drug,Arg1:T5,Arg2:T3
2,R5,Strength-Drug,Arg1:T7,Arg2:T6
3,R6,Route-Drug,Arg1:T8,Arg2:T6
4,R7,Frequency-Drug,Arg1:T9,Arg2:T6


In [13]:
# One row per entity annotation; offsets are still the strings read from the .ann file.
entity_columns = ['entity_id', 'entity_type', 'entity_start', 'entity_end', 'entity']
t_lines_df = pd.DataFrame(t_lines2, columns=entity_columns)
t_lines_df.head()

Unnamed: 0,entity_id,entity_type,entity_start,entity_end,entity
0,T296,Drug,164,171,Vicodin
1,T344,Drug,739,742,CTX
2,T345,Drug,744,756,azithromycin
3,T346,Route,758,760,SC
4,T347,Drug,761,772,epinephrine


In [14]:
# Relations with both argument entities expanded in place of their ids (from r_lines2).
expanded_relation_columns = [
    'red_id', 'red_type',
    'entity_1_type', 'entity_1_start', 'entity_1_end', 'entity_1',
    'entity_2_type', 'entity_2_start', 'entity_2_end', 'entity_2',
]
r_lines_df2 = pd.DataFrame(r_lines2, columns=expanded_relation_columns)
r_lines_df2.head()

Unnamed: 0,red_id,red_type,entity_1_type,entity_1_start,entity_1_end,entity_1,entity_2_type,entity_2_start,entity_2_end,entity_2
0,R1,Reason-Drug,Reason,10179,10197,recurrent seizures,Drug,10227,10233,ativan
1,R4,Route-Drug,Route,10240,10242,IM,Drug,10227,10233,ativan
2,R5,Strength-Drug,Strength,10466,10470,25mg,Drug,10455,10465,Topiramate
3,R6,Route-Drug,Route,10471,10473,PO,Drug,10455,10465,Topiramate
4,R7,Frequency-Drug,Frequency,10474,10477,BID,Drug,10455,10465,Topiramate


## Work on Creating Training Sentences

In [None]:
import os
import argparse
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    set_seed,
    default_data_collator,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
)
from datasets import load_from_disk
import torch

import bitsandbytes as bnb
from huggingface_hub import login, HfFolder

# Run configuration.
# NOTE(review): main() takes its settings from parse_arge(), not from these variables,
# and parse_arge's defaults differ (e.g. lr 5e-5 vs 2e-4, per-device batch size 1 vs 2).
# Keep them in sync, or pass the values as command-line flags.
model_id = "meta-llama/Llama-2-13b-hf" # sharded weights
dataset_path = '/opt/ml/input/data/training'
epochs = 3
per_device_train_batch_size = 2
lr = 2e-4
hf_token = HfFolder.get_token()
merge_weights = True


def _str2bool(value):
    """argparse `type` for boolean flags: "true"/"1"/"yes"/... -> True, "false"/"0"/"no"/... -> False.

    argparse's `type=bool` would call bool("False"), which is True for any non-empty string.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


def _default_bf16():
    """True when a CUDA GPU with compute capability >= 8 is available, else False."""
    return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8


def parse_arge():
    """Parse the command-line arguments and log into the Hugging Face Hub if a token is set.

    Unknown arguments are ignored (parse_known_args), e.g. any extra flags a notebook
    kernel may carry in sys.argv.

    Returns:
        argparse.Namespace with model_id, dataset_path, hf_token, epochs,
        per_device_train_batch_size, lr, seed, gradient_checkpointing, bf16, merge_weights.
    """
    parser = argparse.ArgumentParser()
    # add model id and dataset path argument
    parser.add_argument(
        "--model_id",
        type=str,
        default="meta-llama/Llama-2-13b-hf",
        help="Model id to use for training.",
    )
    parser.add_argument(
        "--dataset_path", type=str, default="lm_dataset", help="Path to dataset."
    )
    parser.add_argument(
        "--hf_token",
        type=str,
        default=HfFolder.get_token(),
        help="Hugging Face Hub access token.",
    )
    # add training hyperparameters for epochs, batch size, learning rate, and seed
    parser.add_argument(
        "--epochs", type=int, default=3, help="Number of epochs to train for."
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=1,
        help="Batch size to use for training.",
    )
    parser.add_argument(
        "--lr", type=float, default=5e-5, help="Learning rate to use for training."
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Seed to use for training."
    )
    parser.add_argument(
        "--gradient_checkpointing",
        type=_str2bool,
        default=True,
        help="Whether to use gradient checkpointing.",
    )
    parser.add_argument(
        "--bf16",
        type=_str2bool,
        # The old `== 8` check excluded compute capability 9 and raised without CUDA.
        default=_default_bf16(),
        help="Whether to use bf16.",
    )
    parser.add_argument(
        "--merge_weights",
        type=_str2bool,
        default=True,
        help="Whether to merge LoRA weights with base model.",
    )
    args, _ = parser.parse_known_args()

    if args.hf_token:
        # Do not echo any part of the token: notebook output is easily shared.
        print("Logging into the Hugging Face Hub...")
        login(token=args.hf_token)

    return args


# COPIED FROM https://github.com/artidoro/qlora/blob/main/qlora.py
def print_trainable_parameters(model, use_4bit=False):
    """
    Prints the number of trainable parameters in the model.

    Args:
        model: object whose named_parameters() yields (name, param) pairs, where each param
            has numel() and requires_grad (and possibly ds_numel under DeepSpeed ZeRO-3).
        use_4bit: if True, the trainable count is halved before printing.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        num_params = param.numel()
        # if using DS Zero 3 and the weights are initialized empty
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel

        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    if use_4bit:
        # Integer division keeps the count an int. `/= 2` produced a float, which made
        # the `:,d` format below raise ValueError.
        trainable_params //= 2
    # guard against a model with no parameters
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"all params: {all_param:,d} || trainable params: {trainable_params:,d} || trainable%: {trainable_pct}"
    )


# COPIED FROM https://github.com/artidoro/qlora/blob/main/qlora.py
def find_all_linear_names(model):
    """Return the distinct leaf names (last dotted component) of all bitsandbytes 4-bit
    Linear modules in `model`, excluding "lm_head"; used as LoRA target modules."""
    # The last component of "a.b.q_proj" is "q_proj"; for an undotted name it is the name itself.
    leaf_names = {
        module_name.split(".")[-1]
        for module_name, module in model.named_modules()
        if isinstance(module, bnb.nn.Linear4bit)
    }
    leaf_names.discard("lm_head")  # needed for 16-bit
    return list(leaf_names)


def create_peft_model(model, gradient_checkpointing=True, bf16=True,
                      lora_r=64, lora_alpha=16, lora_dropout=0.1):
    """Wrap a 4-bit quantized causal LM with LoRA adapters.

    Steps:
      1. prepare_model_for_kbit_training (optionally with gradient checkpointing);
      2. attach LoRA adapters to every bitsandbytes 4-bit Linear layer
         (see find_all_linear_names);
      3. adjust dtypes: LoRA layers -> bfloat16 when bf16; modules whose name contains
         "norm" -> float32; float32 lm_head / embed_tokens weights -> bfloat16 when bf16.

    Args:
        model: a model loaded with a bitsandbytes 4-bit quantization_config.
        gradient_checkpointing: enable gradient checkpointing.
        bf16: cast LoRA layers (and float32 head / embedding modules) to bfloat16.
        lora_r, lora_alpha, lora_dropout: LoRA hyperparameters. The defaults are the values
            previously hard-coded here.

    Returns:
        the PEFT-wrapped model. A trainable-parameter summary is printed.
    """
    from peft import (
        get_peft_model,
        LoraConfig,
        TaskType,
        prepare_model_for_kbit_training,
    )
    from peft.tuners.lora import LoraLayer

    # prepare int-4 model for training
    model = prepare_model_for_kbit_training(
        model, use_gradient_checkpointing=gradient_checkpointing
    )
    if gradient_checkpointing:
        model.gradient_checkpointing_enable()

    # get lora target modules (the 4-bit Linear layers; they are already quantized)
    modules = find_all_linear_names(model)
    print(f"Found {len(modules)} LoRA target modules: {modules}")

    peft_config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )

    model = get_peft_model(model, peft_config)

    # Post-process dtypes: LoRA layers in bf16 (if enabled), layer norms upcast to float32,
    # float32 head / embedding modules cast to bf16 (if enabled).
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            if bf16:
                module = module.to(torch.bfloat16)
        if "norm" in name:
            module = module.to(torch.float32)
        if "lm_head" in name or "embed_tokens" in name:
            if hasattr(module, "weight"):
                if bf16 and module.weight.dtype == torch.float32:
                    module = module.to(torch.bfloat16)

    model.print_trainable_parameters()
    return model


def training_function(args):
    """Fine-tune args.model_id with 4-bit quantization + LoRA on the dataset at args.dataset_path.

    The base model is loaded in 4-bit NF4 with double quantization and wrapped with LoRA
    adapters by create_peft_model. It is trained with the HF Trainer using
    default_data_collator, so the dataset is presumably already tokenized (TODO confirm
    its columns). The result is written to /opt/ml/model/:
      * with args.merge_weights: the trained PEFT model is saved to /tmp/llama2, reloaded
        in fp16, merged with merge_and_unload and saved;
      * otherwise: the trained PEFT model is saved as-is.
    The tokenizer is saved there as well.

    Args:
        args: namespace from parse_arge(). Uses seed, dataset_path, model_id,
            gradient_checkpointing, bf16, per_device_train_batch_size, lr, epochs and
            merge_weights.
    """
    # set seed
    set_seed(args.seed)

    dataset = load_from_disk(args.dataset_path)

    # Load the model from the hub with a bnb config. The 4-bit compute dtype follows
    # args.bf16, so it matches the Trainer's mixed-precision setting below. Previously it
    # was always bfloat16, even when bf16 was disabled.
    compute_dtype = torch.bfloat16 if args.bf16 else torch.float16
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
    )

    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        # the KV cache is disabled when gradient checkpointing is used
        use_cache=not args.gradient_checkpointing,
        device_map="auto",
        quantization_config=bnb_config,
    )

    # attach LoRA adapters and adjust layer dtypes
    model = create_peft_model(
        model, gradient_checkpointing=args.gradient_checkpointing, bf16=args.bf16
    )

    # Define training args
    output_dir = "/tmp/llama2"
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=args.per_device_train_batch_size,
        bf16=args.bf16,  # Use BF16 if available
        learning_rate=args.lr,
        num_train_epochs=args.epochs,
        gradient_checkpointing=args.gradient_checkpointing,
        # logging strategies
        logging_dir=f"{output_dir}/logs",
        logging_strategy="steps",
        logging_steps=10,
        save_strategy="no",
    )

    # Create Trainer instance
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=default_data_collator,
    )

    # Start training
    trainer.train()

    sagemaker_save_dir = "/opt/ml/model/"
    if args.merge_weights:
        # save the trained (4-bit base + LoRA) PEFT model so it can be reloaded in fp16
        trainer.model.save_pretrained(output_dir, safe_serialization=False)
        # drop the references to the quantized model before reloading, then clear the CUDA cache
        del model
        del trainer
        torch.cuda.empty_cache()

        from peft import AutoPeftModelForCausalLM

        # load PEFT model in fp16
        model = AutoPeftModelForCausalLM.from_pretrained(
            output_dir,
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
        )
        # Merge LoRA and base model and save
        model = model.merge_and_unload()
        model.save_pretrained(
            sagemaker_save_dir, safe_serialization=True, max_shard_size="2GB"
        )
    else:
        trainer.model.save_pretrained(
            sagemaker_save_dir, safe_serialization=True
        )

    # save tokenizer for easy inference
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    tokenizer.save_pretrained(sagemaker_save_dir)


def main():
    """Entry point: parse the arguments (logging into the Hub when a token is set), then train."""
    training_function(parse_arge())


# NOTE(review): inside a notebook kernel __name__ is normally "__main__", so executing this
# cell presumably starts training immediately. parse_known_args ignores any unknown kernel
# flags, so the argparse defaults are used. Confirm this is intended.
if __name__ == "__main__":
    main()