# Installs

Installing libraries one by one because of compatibility issues encountered on Windows.

In [1]:
!pip3 install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

In [2]:
!pip install -q transformers

In [3]:
!pip install -q trl

In [4]:
!pip install -q pandas

In [5]:
!pip install -q datasets

In [6]:
!pip install -q huggingface_hub

In [7]:
!pip install -q bitsandbytes

In [8]:
!pip install -q peft

In [9]:
!pip install -q accelerate

In [11]:
!pip install -q setuptools

In [12]:
!pip install -q ipywidgets

In [14]:
!pip install -q python-dotenv

In [34]:
!pip install -q tensorboard

## Freeze requirements

In [43]:
!pip freeze > requirements.txt

# Libraries and parameters

In [1]:
import json
import re
from pprint import pprint

import pandas as pd
import torch
from datasets import Dataset, load_dataset
from huggingface_hub import notebook_login
from peft import LoraConfig, PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import SFTTrainer

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Device: {DEVICE}")
#MODEL_NAME = "NousResearch/Llama-2-7b-chat-hf"
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

Device: cuda:0


In [2]:
OUTPUT_DIR = "experiments"

# Hugging Face login

In [3]:
import os

from dotenv import load_dotenv
from huggingface_hub import HfApi

# Read HUGGINGFACE_TOKEN from a local .env file
load_dotenv()
hf_token = os.getenv("HUGGINGFACE_TOKEN")
if hf_token is None:
    raise ValueError("Hugging Face token not found. Please check your .env file.")

# whoami() only verifies the token against the Hub; it does not store it
api = HfApi()
user_info = api.whoami(token=hf_token)
print("Logged in as:", user_info["name"])

Logged in as: bnalyv


# Dataset

In [4]:
dataset = load_dataset("Salesforce/dialogstudio", "TweetSumm")
dataset

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


DatasetDict({
    train: Dataset({
        features: ['original dialog id', 'new dialog id', 'dialog index', 'original dialog info', 'log', 'prompt'],
        num_rows: 879
    })
    validation: Dataset({
        features: ['original dialog id', 'new dialog id', 'dialog index', 'original dialog info', 'log', 'prompt'],
        num_rows: 110
    })
    test: Dataset({
        features: ['original dialog id', 'new dialog id', 'dialog index', 'original dialog info', 'log', 'prompt'],
        num_rows: 110
    })
})

# Functions

In [5]:
DEFAULT_SYSTEM_PROMPT = """
Below is a conversation between a human and an AI agent. Write a summary of the conversation.
""".strip()


def generate_training_prompt(
    conversation: str, summary: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT
) -> str:
    return f"""### Instruction: {system_prompt}

### Input:
{conversation.strip()}

### Response:
{summary}
""".strip()

In [6]:
def clean_text(text):
    text = re.sub(r"http\S+", "", text)
    text = re.sub(r"@[^\s]+", "", text)
    text = re.sub(r"\s+", " ", text)
    return re.sub(r"\^[^ ]+", "", text)


def create_conversation_text(data_point):
    text = ""
    for item in data_point["log"]:
        user = clean_text(item["user utterance"])
        text += f"user: {user.strip()}\n"

        agent = clean_text(item["system response"])
        text += f"agent: {agent.strip()}\n"

    return text
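A quick check of the cleaning step on a hypothetical TweetSumm-style record. The tweet text is invented, and caret-prefixed tokens such as `^JK` are assumed to be agent sign-offs. URLs and `@mentions` disappear, whitespace is collapsed, and each turn becomes a `user:`/`agent:` line.

```python
import re


def clean_text(text):
    text = re.sub(r"http\S+", "", text)  # drop URLs
    text = re.sub(r"@[^\s]+", "", text)  # drop @mentions
    text = re.sub(r"\s+", " ", text)     # collapse runs of whitespace
    return re.sub(r"\^[^ ]+", "", text)  # drop caret-prefixed tokens such as ^JK


def create_conversation_text(data_point):
    # Same logic as the notebook: one "user:" and one "agent:" line per log turn
    text = ""
    for item in data_point["log"]:
        text += f"user: {clean_text(item['user utterance']).strip()}\n"
        text += f"agent: {clean_text(item['system response']).strip()}\n"
    return text


# Hypothetical record shaped like a TweetSumm "log" entry
record = {
    "log": [
        {
            "user utterance": "@AppleSupport my phone broke http://t.co/abc",
            "system response": "@user123 Sorry to hear that!   Which model? ^JK",
        }
    ]
}
print(create_conversation_text(record))
```

Note that `clean_text` leaves leading and trailing spaces behind; the `.strip()` calls in `create_conversation_text` remove them.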

In [7]:
def generate_text(data_point):
    summaries = json.loads(data_point["original dialog info"])["summaries"][
        "abstractive_summaries"
    ]
    summary = summaries[0]
    summary = " ".join(summary)

    conversation_text = create_conversation_text(data_point)
    return {
        "conversation": conversation_text,
        "summary": summary,
        "text": generate_training_prompt(conversation_text, summary),
    }
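`generate_text` keeps only the first entry of `abstractive_summaries` and joins its sentences with spaces; any further entries are ignored. The sketch below isolates that step in a hypothetical helper, `extract_summary`, applied to a made-up record whose JSON nesting is inferred from the code above rather than from the dataset documentation.

```python
import json


def extract_summary(data_point):
    # Hypothetical helper mirroring the first lines of generate_text
    summaries = json.loads(data_point["original dialog info"])["summaries"][
        "abstractive_summaries"
    ]
    # Only the first summary is kept; it is stored as a list of sentences
    return " ".join(summaries[0])


# Made-up record; the nesting is inferred from generate_text
record = {
    "original dialog info": json.dumps(
        {
            "summaries": {
                "abstractive_summaries": [
                    ["Customer reports a late order.", "Agent asks for the order number."],
                    ["A second summary that generate_text never uses."],
                ]
            }
        }
    )
}
print(extract_summary(record))
```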

In [8]:
example = generate_text(dataset["train"][0])
print(example["summary"])
print("\n ############## \n")
print(example["conversation"])
print("\n ############## \n")
print(example["text"])

Customer enquired about his Iphone and Apple watch which is not showing his any steps/activity and health activities. Agent is asking to move to DM and look into it.

 ############## 

user: So neither my iPhone nor my Apple Watch are recording my steps/activity, and Health doesn’t recognise either source anymore for some reason. Any ideas? please read the above.
agent: Let’s investigate this together. To start, can you tell us the software versions your iPhone and Apple Watch are running currently?
user: My iPhone is on 11.1.2, and my watch is on 4.1.
agent: Thank you. Have you tried restarting both devices since this started happening?
user: I’ve restarted both, also un-paired then re-paired the watch.
agent: Got it. When did you first notice that the two devices were not talking to each other. Do the two devices communicate through other apps such as Messages?
user: Yes, everything seems fine, it’s just Health and activity.
agent: Let’s move to DM and look into this a bit more. When

In [9]:
def process_dataset(data: Dataset):
    return (
        data.shuffle(seed=42)
        .map(generate_text)
        .remove_columns(
            [
                "original dialog id",
                "new dialog id",
                "dialog index",
                "original dialog info",
                "log",
                "prompt",
            ]
        )
    )


In [10]:
dataset["train"] = process_dataset(dataset["train"])
dataset["validation"] = process_dataset(dataset["validation"])

# Model

In [11]:
def create_model_and_tokenizer():
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        use_safetensors=True,
        quantization_config=bnb_config,
        trust_remote_code=True,
        device_map="auto",
    )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    return model, tokenizer

In [29]:
model, tokenizer = create_model_and_tokenizer()
model.config.use_cache = False

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [30]:
model.config.quantization_config.to_dict()

{'quant_method': <QuantizationMethod.BITS_AND_BYTES: 'bitsandbytes'>,
 '_load_in_8bit': False,
 '_load_in_4bit': True,
 'llm_int8_threshold': 6.0,
 'llm_int8_skip_modules': None,
 'llm_int8_enable_fp32_cpu_offload': False,
 'llm_int8_has_fp16_weight': False,
 'bnb_4bit_quant_type': 'nf4',
 'bnb_4bit_use_double_quant': False,
 'bnb_4bit_compute_dtype': 'float16',
 'bnb_4bit_quant_storage': 'uint8',
 'load_in_4bit': True,
 'load_in_8bit': False}
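The config above stores weights in 4 bits with one scale per block of values, and (with `bnb_4bit_use_double_quant: False`) keeps those scales unquantized. The sketch below illustrates blockwise absmax quantization in plain Python. It is a simplification: bitsandbytes' NF4 places its 16 levels at quantiles of a normal distribution, whereas this sketch swaps in evenly spaced integer levels, and the 8-value block is hypothetical.

```python
def quantize_block(values, levels=7):
    # One scale per block: the block's largest absolute value ("absmax")
    absmax = max(abs(v) for v in values) or 1.0  # avoid dividing by zero
    # Signed integer codes in [-levels, levels]; 15 values fit in 4 bits
    codes = [round(v / absmax * levels) for v in values]
    return codes, absmax


def dequantize_block(codes, absmax, levels=7):
    # Map codes back to floats; the real library dequantizes to the compute dtype
    return [c / levels * absmax for c in codes]


weights = [0.12, -0.5, 0.33, 0.01, -0.07, 0.49, -0.21, 0.0]  # hypothetical block
codes, absmax = quantize_block(weights)
restored = dequantize_block(codes, absmax)
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(codes, absmax, round(max_error, 4))
```

Rounding to the nearest level bounds the error per value by `absmax / (2 * levels)`, which is why a per-block scale matters: one large outlier only coarsens its own block.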

In [32]:
lora_r = 16
lora_alpha = 64
lora_dropout = 0.1
lora_target_modules = [
    "q_proj",
    "up_proj",
    "o_proj",
    "k_proj",
    "down_proj",
    "gate_proj",
    "v_proj",
]


peft_config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    target_modules=lora_target_modules,
    bias="none",
    task_type="CAUSAL_LM",
)
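LoRA trains two small matrices per targeted layer, A (r × in) and B (out × r), and scales their product by `lora_alpha / lora_r` (here 64 / 16 = 4). The sketch below counts the resulting trainable parameters for the seven targeted projections. The 22 layers and 2048 hidden size match the `trainer.model` printout; the MLP width (5632) and the 256-wide key/value projections (grouped-query attention) are assumptions taken from TinyLlama's published config, not from this notebook.

```python
# Assumed TinyLlama-1.1B shapes
HIDDEN = 2048
INTERMEDIATE = 5632
KV_DIM = 256  # 4 key/value heads x 64-dim heads
NUM_LAYERS = 22

# (in_features, out_features) of every module listed in lora_target_modules
shapes = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

lora_r, lora_alpha = 16, 64

# A has r * in entries and B has out * r entries (no LoRA bias: bias="none")
per_layer = sum(lora_r * (fan_in + fan_out) for fan_in, fan_out in shapes.values())
total = per_layer * NUM_LAYERS
scaling = lora_alpha / lora_r

print(f"trainable LoRA params: {total:,}")  # 12,615,680
print(f"scaling alpha/r: {scaling}")        # 4.0
```

Roughly 12.6M trainable parameters is about 1% of the 1.1B base model, which is what makes fine-tuning on a single consumer GPU practical.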

# Tensorboard

In [38]:
# In VS Code, start a TensorBoard session by opening the Command Palette (Ctrl+Shift+P) and running the command "Python: Launch TensorBoard".

In [37]:
%load_ext tensorboard
%tensorboard --logdir experiments/runs

Launching TensorBoard...

# Training params

In [39]:
training_arguments = TrainingArguments(
    per_device_train_batch_size=4,  # previously 2
    gradient_accumulation_steps=4,  # previously 8; effective batch size 4 x 4 = 16
    optim="paged_adamw_32bit",
    logging_steps=1,
    learning_rate=1e-4,
    fp16=True,
    max_grad_norm=0.3,
    num_train_epochs=2,
    evaluation_strategy="steps",  # named eval_strategy in newer transformers releases
    eval_steps=0.2,  # a float below 1 is a fraction of the total training steps
    warmup_ratio=0.05,
    save_strategy="epoch",
    group_by_length=True,
    output_dir=OUTPUT_DIR,
    report_to="tensorboard",
    save_safetensors=True,  # previously False
    lr_scheduler_type="cosine",
    seed=42,
)
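The step counts in the training log follow from these arguments. A small sketch that recomputes them by hand, assuming a single GPU and a simplified version of how the Trainer derives optimizer steps, warmup steps and ratio-style `eval_steps`. Its results agree with the log: 110 steps, a peak learning rate of 1e-4 at step 6, and evaluations at epochs 0.4, 0.8, 1.2, 1.6 and 2.0.

```python
import math

num_examples = 879  # train split size printed in the Dataset section
per_device_batch = 4
grad_accum = 4
epochs = 2
base_lr = 1e-4

batches_per_epoch = math.ceil(num_examples / per_device_batch)  # 220
steps_per_epoch = batches_per_epoch // grad_accum                 # 55 optimizer steps
total_steps = steps_per_epoch * epochs                            # 110

warmup_steps = math.ceil(total_steps * 0.05)  # warmup_ratio=0.05
eval_every = math.ceil(total_steps * 0.2)     # eval_steps=0.2


def lr_at(step):
    # Linear warmup, then cosine decay to zero (lr_scheduler_type="cosine")
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


eval_epochs = [
    round(s / steps_per_epoch, 2) for s in range(eval_every, total_steps + 1, eval_every)
]
print(total_steps, warmup_steps, eval_every, eval_epochs)
print([lr_at(s) for s in (1, 6, 7, 58, 110)])
```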



In [40]:
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    peft_config=peft_config,
    dataset_text_field="text",
    # TinyLlama was trained with a 2048-token context; longer positions are out of range
    max_seq_length=2048,
    tokenizer=tokenizer,
    args=training_arguments,
)

## Training

In [42]:
trainer.train()

{'loss': 2.7044, 'grad_norm': 1.478585124015808, 'learning_rate': 1.6666666666666667e-05, 'epoch': 0.02}
{'loss': 2.764, 'grad_norm': 1.8037716150283813, 'learning_rate': 3.3333333333333335e-05, 'epoch': 0.04}
{'loss': 2.8332, 'grad_norm': 1.983026385307312, 'learning_rate': 5e-05, 'epoch': 0.05}
{'loss': 2.6968, 'grad_norm': 1.8106119632720947, 'learning_rate': 6.666666666666667e-05, 'epoch': 0.07}
{'loss': 2.741, 'grad_norm': 1.552841067314148, 'learning_rate': 8.333333333333334e-05, 'epoch': 0.09}
{'loss': 2.6046, 'grad_norm': 1.400054693222046, 'learning_rate': 0.0001, 'epoch': 0.11}
{'loss': 2.5409, 'grad_norm': 1.6099519729614258, 'learning_rate': 9.997718922447667e-05, 'epoch': 0.13}
{'loss': 2.4678, 'grad_norm': 1.7097511291503906, 'learning_rate': 9.990877771116589e-05, 'epoch': 0.15}
{'loss': 2.4089, 'grad_norm': 1.5856293439865112, 'learning_rate': 9.979482788085454e-05, 'epoch': 0.16}
{'loss': 2.3549, 'grad_norm': 1.4892678260803223, 'learning_rate': 9.96354437049027e-05, 'epoch': 0.18}
{'loss': 2.2831, 'grad_norm': 1.6438945531845093, 'learning_rate': 9.943077061037671e-05, 'epoch': 0.2}
{'loss': 2.2406, 'grad_norm': 1.2538846731185913, 'learning_rate': 9.918099534735718e-05, 'epoch': 0.22}
{'loss': 2.0412, 'grad_norm': 1.092948317527771, 'learning_rate': 9.888634581854234e-05, 'epoch': 0.24}
{'loss': 2.4651, 'grad_norm': 0.8644911050796509, 'learning_rate': 9.85470908713026e-05, 'epoch': 0.25}
{'loss': 2.4434, 'grad_norm': 0.7748653292655945, 'learning_rate': 9.816354005237583e-05, 'epoch': 0.27}
{'loss': 2.2468, 'grad_norm': 0.6788894534111023, 'learning_rate': 9.773604332542729e-05, 'epoch': 0.29}
{'loss': 2.2555, 'grad_norm': 0.7340034246444702, 'learning_rate': 9.726499075173201e-05, 'epoch': 0.31}
{'loss': 2.2697, 'grad_norm': 0.7143035531044006, 'learning_rate': 9.675081213427076e-05, 'epoch': 0.33}
{'loss': 2.2724, 'grad_norm': 0.7462705969810486, 'learning_rate': 9.619397662556435e-05, 'epoch': 0.35}
{'loss': 2.2295, 'grad_norm': 0.7249129414558411, 'learning_rate': 9.559499229960451e-05, 'epoch': 0.36}
{'loss': 2.2576, 'grad_norm': 0.8157969117164612, 'learning_rate': 9.495440568827129e-05, 'epoch': 0.38}
{'loss': 2.2234, 'grad_norm': 0.8539525866508484, 'learning_rate': 9.42728012826605e-05, 'epoch': 0.4}
{'eval_loss': 2.191922903060913, 'eval_runtime': 211.4966, 'eval_samples_per_second': 0.52, 'eval_steps_per_second': 0.066, 'epoch': 0.4}
{'loss': 2.1117, 'grad_norm': 0.7812219858169556, 'learning_rate': 9.355080099977578e-05, 'epoch': 0.42}
{'loss': 2.11, 'grad_norm': 0.8033677935600281, 'learning_rate': 9.278906361507238e-05, 'epoch': 0.44}
{'loss': 1.9626, 'grad_norm': 0.8067130446434021, 'learning_rate': 9.19882841613699e-05, 'epoch': 0.45}
{'loss': 2.0149, 'grad_norm': 0.8739387392997742, 'learning_rate': 9.114919329468282e-05, 'epoch': 0.47}
{'loss': 2.3746, 'grad_norm': 0.7308202981948853, 'learning_rate': 9.02725566275473e-05, 'epoch': 0.49}
{'loss': 2.321, 'grad_norm': 0.7110188603401184, 'learning_rate': 8.935917403045251e-05, 'epoch': 0.51}
{'loss': 2.3524, 'grad_norm': 0.7251459956169128, 'learning_rate': 8.840987890201403e-05, 'epoch': 0.53}
{'loss': 2.1036, 'grad_norm': 0.7328673005104065, 'learning_rate': 8.742553740855506e-05, 'epoch': 0.55}
{'loss': 2.1939, 'grad_norm': 0.706255316734314, 'learning_rate': 8.640704769378942e-05, 'epoch': 0.56}
{'loss': 2.2757, 'grad_norm': 0.7422966957092285, 'learning_rate': 8.535533905932738e-05, 'epoch': 0.58}
{'loss': 2.1565, 'grad_norm': 0.7803221940994263, 'learning_rate': 8.427137111675199e-05, 'epoch': 0.6}
{'loss': 2.2209, 'grad_norm': 0.7264162302017212, 'learning_rate': 8.315613291203976e-05, 'epoch': 0.62}
{'loss': 2.1351, 'grad_norm': 0.7336671352386475, 'learning_rate': 8.201064202312441e-05, 'epoch': 0.64}
{'loss': 2.1795, 'grad_norm': 0.7708360552787781, 'learning_rate': 8.083594363142717e-05, 'epoch': 0.65}
{'loss': 2.1418, 'grad_norm': 0.7936146259307861, 'learning_rate': 7.963310956820085e-05, 'epoch': 0.67}
{'loss': 2.1209, 'grad_norm': 0.7831093072891235, 'learning_rate': 7.840323733655778e-05, 'epoch': 0.69}
{'loss': 1.9416, 'grad_norm': 0.8893934488296509, 'learning_rate': 7.714744911007394e-05, 'epoch': 0.71}
{'loss': 2.4125, 'grad_norm': 0.6160205602645874, 'learning_rate': 7.586689070888284e-05, 'epoch': 0.73}
{'loss': 2.2677, 'grad_norm': 0.701050341129303, 'learning_rate': 7.456273055419388e-05, 'epoch': 0.75}
{'loss': 2.3083, 'grad_norm': 0.7082036733627319, 'learning_rate': 7.323615860218843e-05, 'epoch': 0.76}
{'loss': 2.178, 'grad_norm': 0.6535027027130127, 'learning_rate': 7.188838525826702e-05, 'epoch': 0.78}
{'loss': 2.0814, 'grad_norm': 0.6971606016159058, 'learning_rate': 7.052064027263786e-05, 'epoch': 0.8}
{'eval_loss': 2.1341187953948975, 'eval_runtime': 206.0638, 'eval_samples_per_second': 0.534, 'eval_steps_per_second': 0.068, 'epoch': 0.8}
{'loss': 2.1383, 'grad_norm': 0.6975807547569275, 'learning_rate': 6.91341716182545e-05, 'epoch': 0.82}
{'loss': 2.1016, 'grad_norm': 0.6775664687156677, 'learning_rate': 6.773024435212678e-05, 'epoch': 0.84}
{'loss': 2.062, 'grad_norm': 0.7252047657966614, 'learning_rate': 6.631013946104347e-05, 'epoch': 0.85}
{'loss': 2.0483, 'grad_norm': 0.7557873129844666, 'learning_rate': 6.487515269276016e-05, 'epoch': 0.87}
{'loss': 1.9967, 'grad_norm': 0.7667056322097778, 'learning_rate': 6.342659337371885e-05, 'epoch': 0.89}
{'loss': 2.0449, 'grad_norm': 0.8056678771972656, 'learning_rate': 6.19657832143779e-05, 'epoch': 0.91}
{'loss': 2.0389, 'grad_norm': 0.7973714470863342, 'learning_rate': 6.049405510324238e-05, 'epoch': 0.93}
{'loss': 1.8266, 'grad_norm': 0.8740198016166687, 'learning_rate': 5.90127518906953e-05, 'epoch': 0.95}
{'loss': 2.2526, 'grad_norm': 0.6098589301109314, 'learning_rate': 5.752322516373916e-05, 'epoch': 0.96}
{'loss': 2.1233, 'grad_norm': 0.7210639715194702, 'learning_rate': 5.602683401276615e-05, 'epoch': 0.98}
{'loss': 2.0676, 'grad_norm': 0.7802925109863281, 'learning_rate': 5.45249437914819e-05, 'epoch': 1.0}
{'loss': 2.2974, 'grad_norm': 0.6323352456092834, 'learning_rate': 5.3018924871114305e-05, 'epoch': 1.02}
{'loss': 2.1839, 'grad_norm': 0.6044930815696716, 'learning_rate': 5.151015139004445e-05, 'epoch': 1.04}
{'loss': 2.1324, 'grad_norm': 0.6188984513282776, 'learning_rate': 5e-05, 'epoch': 1.05}
{'loss': 2.0337, 'grad_norm': 0.6952685713768005, 'learning_rate': 4.848984860995557e-05, 'epoch': 1.07}
{'loss': 2.1233, 'grad_norm': 0.6400811076164246, 'learning_rate': 4.6981075128885693e-05, 'epoch': 1.09}
{'loss': 2.0069, 'grad_norm': 0.6871439218521118, 'learning_rate': 4.547505620851811e-05, 'epoch': 1.11}
{'loss': 2.0199, 'grad_norm': 0.732101321220398, 'learning_rate': 4.397316598723385e-05, 'epoch': 1.13}
{'loss': 2.0335, 'grad_norm': 0.7072699666023254, 'learning_rate': 4.2476774836260845e-05, 'epoch': 1.15}
{'loss': 2.0504, 'grad_norm': 0.7464339137077332, 'learning_rate': 4.0987248109304714e-05, 'epoch': 1.16}
{'loss': 2.0122, 'grad_norm': 0.7616031765937805, 'learning_rate': 3.950594489675763e-05, 'epoch': 1.18}
{'loss': 1.975, 'grad_norm': 0.7591302394866943, 'learning_rate': 3.803421678562213e-05, 'epoch': 1.2}
{'eval_loss': 2.1090619564056396, 'eval_runtime': 154.4578, 'eval_samples_per_second': 0.712, 'eval_steps_per_second': 0.091, 'epoch': 1.2}
{'loss': 1.9858, 'grad_norm': 0.8209894895553589, 'learning_rate': 3.657340662628116e-05, 'epoch': 1.22}
{'loss': 1.8518, 'grad_norm': 0.920203685760498, 'learning_rate': 3.512484730723986e-05, 'epoch': 1.24}
{'loss': 2.2641, 'grad_norm': 0.590314507484436, 'learning_rate': 3.368986053895655e-05, 'epoch': 1.25}
{'loss': 2.2379, 'grad_norm': 0.6060771346092224, 'learning_rate': 3.226975564787322e-05, 'epoch': 1.27}
{'loss': 2.152, 'grad_norm': 0.630269467830658, 'learning_rate': 3.086582838174551e-05, 'epoch': 1.29}
{'loss': 2.1615, 'grad_norm': 0.6592926979064941, 'learning_rate': 2.9479359727362173e-05, 'epoch': 1.31}
{'loss': 1.9826, 'grad_norm': 0.6764864921569824, 'learning_rate': 2.811161474173297e-05, 'epoch': 1.33}
{'loss': 2.0409, 'grad_norm': 0.669304609298706, 'learning_rate': 2.6763841397811573e-05, 'epoch': 1.35}
{'loss': 2.0826, 'grad_norm': 0.698718249797821, 'learning_rate': 2.5437269445806145e-05, 'epoch': 1.36}
{'loss': 2.1196, 'grad_norm': 0.7275718450546265, 'learning_rate': 2.4133109291117156e-05, 'epoch': 1.38}
{'loss': 2.0168, 'grad_norm': 0.7673788666725159, 'learning_rate': 2.2852550889926067e-05, 'epoch': 1.4}
{'loss': 1.8941, 'grad_norm': 0.7455318570137024, 'learning_rate': 2.1596762663442218e-05, 'epoch': 1.42}
{'loss': 1.9751, 'grad_norm': 0.7844941020011902, 'learning_rate': 2.0366890431799167e-05, 'epoch': 1.44}
{'loss': 1.8307, 'grad_norm': 0.8013171553611755, 'learning_rate': 1.9164056368572846e-05, 'epoch': 1.45}
{'loss': 1.7427, 'grad_norm': 0.9451477527618408, 'learning_rate': 1.7989357976875603e-05, 'epoch': 1.47}
{'loss': 2.301, 'grad_norm': 0.5892607569694519, 'learning_rate': 1.684386708796025e-05, 'epoch': 1.49}
{'loss': 2.1529, 'grad_norm': 0.6586807370185852, 'learning_rate': 1.5728628883248007e-05, 'epoch': 1.51}
{'loss': 2.2007, 'grad_norm': 0.6692585349082947, 'learning_rate': 1.4644660940672627e-05, 'epoch': 1.53}
{'loss': 2.0388, 'grad_norm': 0.6866850852966309, 'learning_rate': 1.3592952306210588e-05, 'epoch': 1.55}
{'loss': 2.1597, 'grad_norm': 0.7315359115600586, 'learning_rate': 1.257446259144494e-05, 'epoch': 1.56}
{'loss': 2.1379, 'grad_norm': 0.7751635313034058, 'learning_rate': 1.159012109798598e-05, 'epoch': 1.58}
{'loss': 2.0498, 'grad_norm': 0.7532538175582886, 'learning_rate': 1.0640825969547496e-05, 'epoch': 1.6}
{'eval_loss': 2.099321126937866, 'eval_runtime': 186.167, 'eval_samples_per_second': 0.591, 'eval_steps_per_second': 0.075, 'epoch': 1.6}
{'loss': 1.9639, 'grad_norm': 0.7422608137130737, 'learning_rate': 9.7274433724527e-06, 'epoch': 1.62}
{'loss': 2.0264, 'grad_norm': 0.7327744364738464, 'learning_rate': 8.850806705317183e-06, 'epoch': 1.64}
{'loss': 2.0841, 'grad_norm': 0.8130182027816772, 'learning_rate': 8.011715838630112e-06, 'epoch': 1.65}
{'loss': 2.0338, 'grad_norm': 0.8300769329071045, 'learning_rate': 7.21093638492763e-06, 'epoch': 1.67}
{'loss': 1.9501, 'grad_norm': 0.8641918301582336, 'learning_rate': 6.449199000224221e-06, 'epoch': 1.69}
{'loss': 1.732, 'grad_norm': 0.8970391154289246, 'learning_rate': 5.727198717339511e-06, 'epoch': 1.71}
{'loss': 2.2206, 'grad_norm': 0.5811628103256226, 'learning_rate': 5.045594311728707e-06, 'epoch': 1.73}
{'loss': 2.2764, 'grad_norm': 0.6820680499076843, 'learning_rate': 4.405007700395497e-06, 'epoch': 1.75}
{'loss': 2.1793, 'grad_norm': 0.6776357293128967, 'learning_rate': 3.8060233744356633e-06, 'epoch': 1.76}
{'loss': 2.0999, 'grad_norm': 0.6749551892280579, 'learning_rate': 3.249187865729264e-06, 'epoch': 1.78}
{'loss': 2.0935, 'grad_norm': 0.747072160243988, 'learning_rate': 2.7350092482679836e-06, 'epoch': 1.8}
{'loss': 2.0076, 'grad_norm': 0.7249104380607605, 'learning_rate': 2.2639566745727205e-06, 'epoch': 1.82}
{'loss': 2.1134, 'grad_norm': 0.7504172325134277, 'learning_rate': 1.8364599476241862e-06, 'epoch': 1.84}
{'loss': 2.0011, 'grad_norm': 0.757768452167511, 'learning_rate': 1.4529091286973995e-06, 'epoch': 1.85}
{'loss': 1.9962, 'grad_norm': 0.7445720434188843, 'learning_rate': 1.1136541814576573e-06, 'epoch': 1.87}
{'loss': 2.0344, 'grad_norm': 0.7609739899635315, 'learning_rate': 8.190046526428242e-07, 'epoch': 1.89}
{'loss': 1.9219, 'grad_norm': 0.8008071184158325, 'learning_rate': 5.692293896232936e-07, 'epoch': 1.91}
{'loss': 1.9074, 'grad_norm': 0.8290432691574097, 'learning_rate': 3.6455629509730136e-07, 'epoch': 1.93}
{'loss': 1.9532, 'grad_norm': 0.9033920168876648, 'learning_rate': 2.0517211914545254e-07, 'epoch': 1.95}
{'loss': 2.1532, 'grad_norm': 0.6674846410751343, 'learning_rate': 9.12222888341252e-08, 'epoch': 1.96}
{'loss': 1.9895, 'grad_norm': 0.6950060725212097, 'learning_rate': 2.2810775523329773e-08, 'epoch': 1.98}
{'loss': 1.9596, 'grad_norm': 0.8290480375289917, 'learning_rate': 0.0, 'epoch': 2.0}
{'eval_loss': 2.097778797149658, 'eval_runtime': 41.0854, 'eval_samples_per_second': 2.677, 'eval_steps_per_second': 0.341, 'epoch': 2.0}
100%|██████████| 110/110 [1:05:38<00:00, 35.81s/it]

{'train_runtime': 3938.7867, 'train_samples_per_second': 0.446, 'train_steps_per_second': 0.028, 'train_loss': 2.1541391351006247, 'epoch': 2.0}

TrainOutput(global_step=110, training_loss=2.1541391351006247, metrics={'train_runtime': 3938.7867, 'train_samples_per_second': 0.446, 'train_steps_per_second': 0.028, 'total_flos': 4005082644516864.0, 'train_loss': 2.1541391351006247, 'epoch': 2.0})
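Cross-entropy loss is easier to read as perplexity, `exp(loss)`. Converting the five `eval_loss` values from the log shows validation perplexity falling from about 8.95 to about 8.15, with the curve flattening out over the second epoch. Assuming `SFTTrainer`'s default collator, the loss covers every token of the prompt (instruction, conversation and summary), so this is not a summary-only metric.

```python
import math

# eval_loss values copied from the training log (epochs 0.4, 0.8, 1.2, 1.6, 2.0)
eval_losses = [
    2.191922903060913,
    2.1341187953948975,
    2.1090619564056396,
    2.099321126937866,
    2.097778797149658,
]

# For a mean token-level cross-entropy in nats, perplexity = exp(loss)
perplexities = [math.exp(loss) for loss in eval_losses]
for loss, ppl in zip(eval_losses, perplexities):
    print(f"eval_loss={loss:.4f}  perplexity={ppl:.2f}")
```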

In [43]:
trainer.save_model()



In [44]:
trainer.model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(32000, 2048)
        (layers): ModuleList(
          (0-21): 22 x LlamaDecoderLayer(
            (self_attn): LlamaSdpaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=2048, out_features=2048, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2048, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=2048, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): lora.Linear4bit(
                (base_layer): Linear4b

# Save model

In [45]:
from peft import AutoPeftModelForCausalLM

trained_model = AutoPeftModelForCausalLM.from_pretrained(
    OUTPUT_DIR,
    low_cpu_mem_usage=True,
)

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [46]:
merged_model = trained_model.merge_and_unload()
merged_model.save_pretrained("merged_model", safe_serialization=True)
tokenizer.save_pretrained("merged_model")

('merged_model\\tokenizer_config.json',
 'merged_model\\special_tokens_map.json',
 'merged_model\\tokenizer.json')
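`merge_and_unload()` folds each LoRA update into its base weight, W' = W + (alpha / r) · B · A, so the saved `merged_model` can be loaded as a plain `transformers` model without PEFT. A toy, pure-Python sketch of that identity with made-up 2×3 matrices (dropout is ignored because it is only active during training): the merged layer produces the same output as the base layer plus the adapter branch.

```python
def matmul(a, b):
    # Plain nested-list matrix product: (n x k) @ (k x m)
    return [
        [sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def add(a, b, scale=1.0):
    # Element-wise a + scale * b
    return [[x + scale * y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


r, alpha = 2, 8
scaling = alpha / r

W = [[0.5, -1.0, 0.0], [0.25, 0.75, -0.5]]  # frozen base weight (out=2 x in=3)
A = [[0.1, 0.2, -0.1], [0.0, -0.3, 0.4]]    # r x in
B = [[1.0, -2.0], [0.5, 0.5]]               # out x r
x = [[1.0], [2.0], [-1.0]]                  # input column vector (in x 1)

# Unmerged LoRA forward: W x + scaling * B (A x)
unmerged = add(matmul(W, x), matmul(B, matmul(A, x)), scaling)

# Merged forward: a single weight W' = W + scaling * B A
W_merged = add(W, matmul(B, A), scaling)
merged = matmul(W_merged, x)

print(unmerged, merged)
```

The trade-off is that the merged weights no longer keep the adapter separate, so the small adapter in `experiments` remains the artifact to keep if further fine-tuning is planned.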

# Inference

In [12]:
def generate_prompt(
    conversation: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT
) -> str:
    return f"""### Instruction: {system_prompt}

### Input:
{conversation.strip()}

### Response:
""".strip()

In [13]:
examples = []
for data_point in dataset["test"].select(range(5)):
    summaries = json.loads(data_point["original dialog info"])["summaries"][
        "abstractive_summaries"
    ]
    summary = summaries[0]
    summary = " ".join(summary)
    conversation = create_conversation_text(data_point)
    examples.append(
        {
            "summary": summary,
            "conversation": conversation,
            "prompt": generate_prompt(conversation),
        }
    )
test_df = pd.DataFrame(examples)
test_df

| | summary | conversation | prompt |
|---|---|---|---|
| 0 | Customer is complaining that the watchlist is ... | user: My watchlist is not updating with new ep... | ### Instruction: Below is a conversation betwe... |
| 1 | Customer is asking about the ACC to link to th... | user: hi , my Acc was linked to an old number.... | ### Instruction: Below is a conversation betwe... |
| 2 | Customer is complaining about the new updates ... | user: the new update ios11 sucks. I can’t even... | ### Instruction: Below is a conversation betwe... |
| 3 | Customer is complaining about parcel service ... | user: FUCK YOU AND YOUR SHITTY PARCEL SERVICE ... | ### Instruction: Below is a conversation betwe... |
| 4 | The customer says that he is stuck at Staines ... | user: Stuck at Staines waiting for a Reading t... | ### Instruction: Below is a conversation betwe... |


## Base model

In [14]:
model, tokenizer = create_model_and_tokenizer()

In [24]:
def summarize(model, text: str):
    inputs = tokenizer(text, return_tensors="pt").to(DEVICE)
    inputs_length = len(inputs["input_ids"][0])
    with torch.inference_mode():
        # Greedy decoding; a temperature setting is ignored unless do_sample=True
        outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)
    # Decode only the newly generated tokens, not the echoed prompt
    return tokenizer.decode(outputs[0][inputs_length:], skip_special_tokens=True)
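`summarize` is meant to decode deterministically. In Hugging Face `generate`, `temperature` only takes effect when sampling (`do_sample=True`); with greedy decoding the model simply picks the highest-scoring token at each step. The sketch below, using hypothetical next-token logits, shows why a near-zero temperature such as 0.0001 converges to that same greedy choice: dividing the logits by a tiny T pushes all probability mass onto the argmax.

```python
import math


def softmax_with_temperature(logits, temperature):
    # Divide logits by T, subtract the max for numerical stability, then normalise
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.5, 0.3]  # hypothetical next-token scores
for t in (1.0, 0.5, 0.0001):
    print(t, [round(p, 4) for p in softmax_with_temperature(logits, t)])

near_greedy = softmax_with_temperature(logits, 0.0001)
```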

### Example 1

In [16]:
example = test_df.iloc[0]
print(example.conversation)

user: My watchlist is not updating with new episodes (past couple days). Any idea why?
agent: Apologies for the trouble, Norlene! We're looking into this. In the meantime, try navigating to the season / episode manually.
user: Tried logging out/back in, that didn’t help
agent: Sorry! 😔 We assure you that our team is working hard to investigate, and we hope to have a fix ready soon!
user: Thank you! Some shows updated overnight, but others did not...
agent: We definitely understand, Norlene. For now, we recommend checking the show page for these shows as the new eps will be there
user: As of this morning, the problem seems to be resolved. Watchlist updated overnight with all new episodes. Thank you for your attention to this matter! I love Hulu 💚
agent: Awesome! That's what we love to hear. If you happen to need anything else, we'll be here to support! 💚



In [17]:
print(example.summary)

Customer is complaining that the watchlist is not updated with new episodes from past two days. Agent informed that the team is working hard to investigate to show new episodes on page.


In [18]:
%%time
summary = summarize(model, example.prompt)



CPU times: total: 2.27 s
Wall time: 2.54 s


In [19]:
pprint(summary)

('\n'
 "agent: Yes, we're glad to hear that the issue has been resolved. If you need "
 "any further assistance, please don't hesitate to reach out to us. We're "
 'always here to help!')


### Example 2

In [20]:
example = test_df.iloc[1]
print(example.conversation)

user: hi , my Acc was linked to an old number. Now I’m asked to verify my Acc , where a code / call wil be sent to my old number. Any way that I can link my Acc to my current number? Pls help
agent: Hi there, we are here to help. We will have a specialist contact you about changing your phone number. Thank you.
user: Thanks. Hope to get in touch soon
agent: That is no problem. Please let us know if you have any further questions in the meantime.
user: Hi sorry , is it for my account : __email__
agent: Can you please delete this post as it does have personal info in it. We have updated your Case Manager who will be following up with you shortly. Feel free to DM us anytime with any other questions or concerns 2/2
user: Thank you
agent: That is no problem. Please do not hesitate to contact us with any further questions. Thank you.



In [21]:
print(example.summary)

Customer is asking about the ACC to link to the current  number. Agent says that they have updated their case manager.


In [25]:
%%time
summary = summarize(model, example.prompt)

CPU times: total: 11 s
Wall time: 11.7 s


In [26]:
pprint(summary)

('\n'
 'agent: Hi there, we are here to help. We will have a specialist contact you '
 'about changing your phone number. Thank you. User: Thanks. Hope to get in '
 'touch soon User: Hi sorry , is it for my account : __email__ User: Can you '
 'please delete this post as it does have personal info in it. User: No '
 'problem. User: Is it for my account : __email__ User: Yes, it is. User: '
 'Great. User: Thank you User: No problem. User: Thank you User: No problem. '
 'User: Thank you User: No problem. User: Thank you User: No problem. User: '
 'Thank you User: No problem. User: Thank you User: No problem. User: Thank '
 'you User: No problem. User: Thank you User: No problem. User: Thank you '
 'User: No problem. User: Thank you User: No problem. User: Thank you User: No '
 'problem. User: Thank you User: No problem. User: Thank you User: No problem. '
 'User: Thank you User: No problem. User: Thank you User: No problem. User: '
 'Thank you User: No problem. User: Thank you User: No p
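
The base model's Example 2 output degenerates into a loop ("User: Thank you User: No problem." repeated until the output is cut off). One rough way to flag such outputs automatically is the share of word n-grams that repeat an earlier n-gram. The helper below is a hypothetical sketch, not part of the notebook; the choice of `n` and any cut-off for "looping" are assumptions to tune.

```python
def repeated_ngram_ratio(text: str, n: int = 3) -> float:
    """Fraction of word n-grams in `text` that duplicate an earlier n-gram.

    0.0 means every n-gram is unique; values close to 1.0 indicate the
    output is stuck in a loop.
    """
    words = text.split()
    ngrams = [tuple(words[i : i + n]) for i in range(len(words) - n + 1)]
    if not ngrams:
        return 0.0
    return 1.0 - len(set(ngrams)) / len(ngrams)


# Looping text modeled on the base model's Example 2 output.
looping = "User: Thank you User: No problem. " * 10
normal = "Customer asks to link the account to a new number. Agent escalates to a specialist."
print(round(repeated_ngram_ratio(looping), 2))
print(repeated_ngram_ratio(normal))
```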

### Example 3

In [27]:
example = test_df.iloc[2]
print(example.conversation)

user: the new update ios11 sucks. I can’t even use some apps on my phone.
agent: We want your iPhone to work properly, and we are here for you. Which apps are giving you trouble, and which iPhone?
user: 6s. Words with friends Words pro
agent: Do you see app updates in App Store &gt; Updates? Also, are you using iOS 11.0.3?
user: I am using 11.0.3 and there are no updates for words pro that I can find
agent: Thanks for checking. Next, what happens in that app that makes it unusable?
user: It’s says it’s not compatible.
agent: Thanks for confirming this. Send us a DM and we'll work from there:



In [28]:
print(example.summary)

Customer is complaining about the new updates IOS11 and can't even use some apps on phone. Agent asks to send a DM and work from there URL.


In [29]:
%%time
summary = summarize(model, example.prompt)

CPU times: total: 3.75 s
Wall time: 4.13 s


In [30]:
pprint(summary)

('\n'
 "agent: Hi, we're sorry to hear that you're having trouble with the new "
 "update for Words Pro. We've checked the app and found that it's not "
 "compatible with your iPhone. We'll be working from there to resolve the "
 "issue. If you have any other issues, please don't hesitate to reach out to "
 'us. DM us at @wordswordspro.')


## Enhanced model

In [31]:
model = PeftModel.from_pretrained(model, OUTPUT_DIR)
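
`PeftModel.from_pretrained` wraps the base model with the adapter weights saved in `OUTPUT_DIR`. Assuming the adapter is LoRA (the notebook imports `LoraConfig`), each adapted layer computes W x + (alpha / r) * B A x, where W is the frozen base weight and B, A are the trained low-rank factors. The toy sketch below, using made-up 3x3 matrices with rank r = 1, checks that applying the adapter as a side path gives the same result as folding it into one merged weight W + (alpha / r) * B A, which is what PEFT's `merge_and_unload()` does for LoRA adapters.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in m]


def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def scale(m, s):
    return [[s * x for x in row] for row in m]


# Toy shapes (made-up values): W is d x k (3 x 3), LoRA rank r = 1.
W = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0], [3.0, 0.0, 1.0]]
A = [[0.5, -1.0, 0.25]]      # r x k
B = [[2.0], [0.0], [-1.0]]   # d x r
alpha, r = 2.0, 1
s = alpha / r                # LoRA scaling factor

x = [1.0, 2.0, 3.0]

# Adapter applied as a side path: W x + s * B (A x)
side = [w + s * b for w, b in zip(matvec(W, x), matvec(B, matvec(A, x)))]

# Adapter merged into the weight: (W + s * B A) x
W_merged = add(W, scale(matmul(B, A), s))
merged = matvec(W_merged, x)

print(side)
print(merged)
```

Merging removes the extra matmuls at inference time, at the cost of no longer being able to detach the adapter from the base weights.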

### Example 1

In [32]:
example = test_df.iloc[0]
pprint(example.summary)

('Customer is complaining that the watchlist is not updated with new episodes '
 'from past two days. Agent informed that the team is working hard to '
 'investigate to show new episodes on page.')


In [33]:
print(example.conversation)

user: My watchlist is not updating with new episodes (past couple days). Any idea why?
agent: Apologies for the trouble, Norlene! We're looking into this. In the meantime, try navigating to the season / episode manually.
user: Tried logging out/back in, that didn’t help
agent: Sorry! 😔 We assure you that our team is working hard to investigate, and we hope to have a fix ready soon!
user: Thank you! Some shows updated overnight, but others did not...
agent: We definitely understand, Norlene. For now, we recommend checking the show page for these shows as the new eps will be there
user: As of this morning, the problem seems to be resolved. Watchlist updated overnight with all new episodes. Thank you for your attention to this matter! I love Hulu 💚
agent: Awesome! That's what we love to hear. If you happen to need anything else, we'll be here to support! 💚



In [34]:
summary = summarize(model, example.prompt)

In [35]:
pprint(summary)

('\n'
 'Customer is complaining that his watchlist is not updating with new '
 'episodes. Agent updated that they are looking into this and will get back to '
 'customer soon. Customer asked to check the show page for the new episodes '
 'and they will get back to customer soon. Agent updated that they will be '
 'here to support and will get back to customer soon.\n'
 '\n'
 '### Trackback:\n'
 'user: My watchlist is not updating with new episodes (past couple days). Any '
 "idea why? agent: Apologies for the trouble, Norlene! We're looking into "
 'this. In the meantime, try navigating to the season / episode manually. '
 "agent: Apologies for the trouble, Norlene! We're looking into this. In the "
 'meantime, try navigating to the season / episode manually. agent: Apologies '
 "for the trouble, Norlene! We're looking into this. In the meantime, try "
 'navigating to the season / episode manually. agent: Apologies for the '
 "trouble, Norlene! We're looking into this. In the meantime,

In [36]:
pprint(summary.strip().split("\n")[0])

('Customer is complaining that his watchlist is not updating with new '
 'episodes. Agent updated that they are looking into this and will get back to '
 'customer soon. Customer asked to check the show page for the new episodes '
 'and they will get back to customer soon. Agent updated that they will be '
 'here to support and will get back to customer soon.')
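
The fine-tuned model puts the summary on the first line and then keeps generating, here a `### Trackback:` section that echoes the conversation, which is why the notebook keeps only `summary.strip().split("\n")[0]`. A small helper can wrap that expression and, as an extra assumption beyond what the notebook does, also cut at a `###` marker that appears on the same line. `extract_summary` and its `stop_markers` parameter are hypothetical names.

```python
def extract_summary(generated: str, stop_markers=("###",)) -> str:
    """Return the first line of a completion, cut at any stop marker."""
    # Same as the notebook: drop surrounding whitespace, keep the first line.
    text = generated.strip().split("\n")[0]
    # Assumption: a marker such as "###" on that line starts unwanted output.
    for marker in stop_markers:
        idx = text.find(marker)
        if idx != -1:
            text = text[:idx]
    return text.strip()


raw = (
    "\nCustomer is complaining that his watchlist is not updating.\n\n"
    "### Trackback:\nuser: My watchlist is not updating"
)
print(extract_summary(raw))
```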


### Example 2

In [37]:
example = test_df.iloc[1]
print(example.summary)

Customer is asking about the ACC to link to the current  number. Agent says that they have updated their case manager.


In [38]:
print(example.conversation)

user: hi , my Acc was linked to an old number. Now I’m asked to verify my Acc , where a code / call wil be sent to my old number. Any way that I can link my Acc to my current number? Pls help
agent: Hi there, we are here to help. We will have a specialist contact you about changing your phone number. Thank you.
user: Thanks. Hope to get in touch soon
agent: That is no problem. Please let us know if you have any further questions in the meantime.
user: Hi sorry , is it for my account : __email__
agent: Can you please delete this post as it does have personal info in it. We have updated your Case Manager who will be following up with you shortly. Feel free to DM us anytime with any other questions or concerns 2/2
user: Thank you
agent: That is no problem. Please do not hesitate to contact us with any further questions. Thank you.



In [39]:
summary = summarize(model, example.prompt)
pprint(summary.strip().split("\n")[0])

('Customer is asking to link his acc to his current number. Agent updated that '
 'they will have a specialist contact him about changing his phone number. '
 'Customer is asking to delete the post as it does have personal info in it. '
 'Agent updated that they have updated their case manager who will be '
 'following up with him shortly. Customer is asking to DM them any other '
 'questions or concerns. Agent updated that that is no problem. Customer is '
 'asking to delete the post as it does have personal info in it. Agent updated '
 'that that is no problem. Customer is asking to DM them any other questions '
 'or concerns. Agent updated that that is no problem. Customer is asking to '
 'link his acc to his current number. Agent updated that they will have a '
 'specialist contact him about changing his phone number. Customer is asking '
 'to delete the post as it does have personal info in it. Agent updated that '
 'that is no problem. Customer is asking to DM them any other ques
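
Keeping the first line is not enough for Examples 2 and 3: the fine-tuned model stays on one line but cycles through the same sentences (e.g. "Agent updated that that is no problem." recurs). A minimal post-processing sketch, assuming sentences end with `.`, `!` or `?` followed by whitespace, keeps sentences until the first exact (case-insensitive) repeat. The helper name is hypothetical, the naive regex split will mis-handle abbreviations, and the sample input is shortened from the Example 2 output.

```python
import re


def truncate_at_first_repeat(text: str) -> str:
    """Keep sentences up to (not including) the first exact repeat."""
    # Naive split: a sentence ends at '.', '!' or '?' followed by whitespace.
    sentences = re.split(r"(?<=[.!?])\s+", text.strip())
    seen = set()
    kept = []
    for sentence in sentences:
        key = sentence.lower()
        if key in seen:
            break
        seen.add(key)
        kept.append(sentence)
    return " ".join(kept)


looping = (
    "Customer is asking to link his acc to his current number. "
    "Agent updated that that is no problem. "
    "Customer is asking to delete the post. "
    "Agent updated that that is no problem. "
    "Customer is asking to link his acc to his current number."
)
print(truncate_at_first_repeat(looping))
```

This only cleans up after the fact; generation options such as `repetition_penalty` or `no_repeat_ngram_size` in `generate` could attack the same problem at decoding time.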

### Example 3

In [40]:
example = test_df.iloc[2]
print(example.summary)

Customer is complaining about the new updates IOS11 and can't even use some apps on phone. Agent asks to send a DM and work from there URL.


In [41]:
print(example.conversation)

user: the new update ios11 sucks. I can’t even use some apps on my phone.
agent: We want your iPhone to work properly, and we are here for you. Which apps are giving you trouble, and which iPhone?
user: 6s. Words with friends Words pro
agent: Do you see app updates in App Store &gt; Updates? Also, are you using iOS 11.0.3?
user: I am using 11.0.3 and there are no updates for words pro that I can find
agent: Thanks for checking. Next, what happens in that app that makes it unusable?
user: It’s says it’s not compatible.
agent: Thanks for confirming this. Send us a DM and we'll work from there:



In [42]:
summary = summarize(model, example.prompt)
pprint(summary.strip().split("\n")[0])

('Customer is complaining about the new update ios11 sucks. Agent updated to '
 'DM and asked to send them a DM. Customer updated that he is using 6s and the '
 'app updates in app store &gt; updates. Agent updated to DM and asked to send '
 'them a DM. Customer updated that he is using ios 11.0.3 and there are no '
 'updates for words pro that he can find. Agent updated to DM and asked to '
 "send them a DM. Customer updated that it says it's not compatible. Agent "
 'updated to DM and asked to send them a DM. Customer updated that he is using '
 'ios 11.0.3 and there are no updates for words pro that he can find. Agent '
 'updated to DM and asked to send them a DM. Customer updated that it says '
 "it's not compatible. Agent updated to DM and asked to send them a DM. "
 'Customer updated that he is using ios 11.0.3 and there are no updates for '
 'words pro that he can find. Agent updated to DM and asked to send them a DM. '
 "Customer updated that it says it's not compatible. Agent 

# Conclusion

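As expected, fine-tuning TinyLlama produces more task-appropriate output than the base model. The base model continues the conversation as the agent (and in Example 2 loops on "User: Thank you User: No problem."), whereas the fine-tuned model writes third-person summaries in the style of the references ("Customer is ... Agent ..."). The fine-tuned outputs still run on after the summary, so only the first line is kept, and even that line falls into repetition in Examples 2 and 3.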
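
The comparison above rests on reading three examples. One way to make it measurable is to score each generated summary against the reference `summary` column of `test_df` with ROUGE-1. The sketch below is a simplified unigram-overlap F1 (lowercased whitespace tokens, no stemming or punctuation handling) and a hypothetical helper, not a substitute for a standard ROUGE implementation.

```python
from collections import Counter


def rouge1_f1(candidate: str, reference: str) -> float:
    """Simplified ROUGE-1 F1: clipped unigram overlap between two texts."""
    cand = Counter(candidate.lower().split())
    ref = Counter(reference.lower().split())
    # Counter intersection keeps the minimum count of each shared token.
    overlap = sum((cand & ref).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)


reference = "Customer is asking about the ACC to link to the current  number."
candidate = "Customer is asking to link his acc to his current number."
print(round(rouge1_f1(candidate, reference), 3))
```

Averaging `rouge1_f1` over `test_df` for the base and fine-tuned outputs would give one number per model to compare.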