In [25]:
from datasets import Dataset, DatasetDict, load_dataset
import ast

In [26]:
from unsloth import FastLanguageModel

In [27]:
def get_data(csv_path, valid_set_ratio = 0.15):
    dataset = load_dataset('csv', data_files="arcelik_llm_training_set.csv")

    # Split the dataset into training and evaluation sets
    train_test_split = dataset['train'].train_test_split(test_size=0.15)
    dataset = DatasetDict({
        'train': train_test_split['train'],
        'eval': train_test_split['test']
    })
    return dataset

In [28]:

csv_path = "./arcelik_llm_training_set.csv"
base_model_id = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"
max_seq_length = 130
load_in_4bit =True

In [29]:
def get_model_and_tokenizer(model_id: str, max_length, load_in_4bit,):
    model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_id, # Choose ANY! eg teknium/OpenHermes-2.5-Mistral-7B
    max_seq_length = max_seq_length,
    dtype = None,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
    )
    return model, tokenizer

In [30]:
model, tokenizer = get_model_and_tokenizer(base_model_id,
                                           max_seq_length,
                                           load_in_4bit)

==((====))==  Unsloth: Fast Mistral patching release 2024.2
   \\   /|    GPU: NVIDIA RTX A4000. Max memory: 15.731 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.0+cu121. CUDA = 8.6. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. Xformers = 0.0.24. FA = False.
 "-____-"     Apache 2 free license: http://github.com/unslothai/unsloth




In [59]:

dataset = get_data(csv_path)

PROMPT_TEMPLATE = "[INST] Extract span of text from the customer review associated with the topic - {}. Customer Review : '{}'[/INST] {}{}"

In [60]:
max_length = 120

def format_row(row):
    indices = ast.literal_eval(row['indices'])
    span = row['sentence'][indices[0]:indices[1]]
    try:
        assert span == row['span']
    except Exception:
        print(row['sentence'])
        print(span)
        print(row['span'])
        print(indices)
        print("next")
    # formatted_string = f"### Question : Extract span of text from the customer review associated with the topic - {row['topic']}. Customer Review - '{row['sentence']}'\n ### Answer: {span} "
    formatted_string = PROMPT_TEMPLATE.format(row['topic'], row['sentence'], span, )
    result = tokenizer(formatted_string,truncation=True,
        max_length=max_length,
        padding="max_length",)
    new_record = {}
    result["input_ids"].append(tokenizer.eos_token_id)
    new_record ["labels"] = result["input_ids"].copy()
    return new_record

In [61]:
tokenized_train_dataset = dataset['train'].map(format_row)
tokenized_val_dataset = dataset['eval'].map(format_row)

Map:  80%|████████  | 18353/22825 [00:08<00:02, 2201.97 examples/s]

Pros: Space saving for a small kitchen like mine Cons: None
None
None
(55, 59)
next


Map: 100%|██████████| 22825/22825 [00:10<00:00, 2129.33 examples/s]
Map:  66%|██████▌   | 2646/4029 [00:01<00:00, 2004.30 examples/s]

Pros: Its easy to access and space saving Cons: None
None
None
(48, 52)
next


Map: 100%|██████████| 4029/4029 [00:02<00:00, 1745.92 examples/s]


In [46]:
tokenizer.decode(2)

'</s>'

In [65]:
tokenized_train_dataset[5]['labels'][40:]

[16289,
 28793,
 5804,
 633,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 2]

In [11]:
tokenized_train_dataset[0]

{'input_ids': [2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  1,
  733,
  16289,
  28793,
  1529,
  2107,
  12363,
  302,
  2245,
  477,
  272,
  6346,
  4058,
  5363,
  395,
  272,
  9067,
  387,
  8382,
  28723,
  16648,
  8349,
  714,
  464,
  28737,
  403,
  6416,
  11572,
  304,
  315,
  2613,
  298,
  737,
  456,
  18401,
  28705,
  562,
  378,
  28735,
  776,
  459,
  354,
  528,
  1815,
  28792,
  28748,
  16289,
  28793,
  459,
  354,
  528,
  28705,
  2],
 'attention_mask': [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,