# Setup


## GPUs

In [None]:
# Select the GPUs used for QLoRA training. Devices are enumerated in PCI bus
# order and restricted to the first four; this cell runs before torch is
# imported in the next cell.
import os

os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "0,1,2,3",
    }
)

## Imports


In [8]:
import numpy as np
from tqdm import tqdm
import json
import warnings
import gc
import torch

from datasets import (
    load_dataset,
    concatenate_datasets,
    load_from_disk,
    Features,
    Sequence,
    Value,
)
from datasets import logging as ds_logging
from transformers import AutoTokenizer
from transformers import logging as trans_logging

from qlora import train

## Logging


In [9]:
# Silence library output for a readable notebook: only errors from the
# `datasets` and `transformers` loggers, no `datasets` progress bars, and no
# Python warnings at all.
# NOTE(review): filterwarnings("ignore") also hides deprecation and other
# potentially useful warnings for the whole session.
ds_logging.set_verbosity_error()
ds_logging.disable_progress_bar()
trans_logging.set_verbosity_error()
warnings.filterwarnings("ignore")

# Data


## Load datasets


In [10]:
def read_annotations_from_file(path: str, file: str):
    """Load the "Annotations" records of one SpkAtt JSON file as a Dataset.

    Every role column is declared explicitly (not inferred) as a sequence of
    strings, so all files get the same schema. A ``FileName`` column holding
    ``file`` is appended to every row.

    Args:
        path: directory containing the JSON file.
        file: file name inside ``path``.

    Returns:
        A ``datasets.Dataset`` with the eight role columns plus ``FileName``.
    """
    role_names = (
        "PTC",
        "Evidence",
        "Medium",
        "Topic",
        "Cue",
        "Addr",
        "Message",
        "Source",
    )
    # a fresh Sequence object per column, in the same column order as the roles above
    schema = Features(
        {
            name: Sequence(feature=Value(dtype="string", id=None), length=-1, id=None)
            for name in role_names
        }
    )
    annotations = load_dataset(
        "json",
        data_files=os.path.join(path, file),
        field="Annotations",
        split="train",
        features=schema,
    )
    return annotations.add_column("FileName", [file] * len(annotations))

In [11]:
def read_sentences_from_file(path: str, file: str):
    """Load the "Sentences" records of one SpkAtt JSON file as a Dataset.

    Two columns are appended: ``FileName`` (always ``file``) and ``Sentence``,
    the row's ``Tokens`` joined with single spaces.
    """
    source_file = os.path.join(path, file)
    sentences = load_dataset(
        "json", data_files=source_file, field="Sentences", split="train"
    )
    file_column = [file] * len(sentences)
    sentence_column = [" ".join(tokens) for tokens in sentences["Tokens"]]
    return sentences.add_column("FileName", file_column).add_column(
        "Sentence", sentence_column
    )

In [12]:
def read_annotations_from_path(path: str):
    """Load and concatenate the annotations of every file in ``path``.

    Files are read in sorted name order, so the row order is reproducible.

    Args:
        path: directory containing the SpkAtt task-1 JSON files.

    Returns:
        One ``datasets.Dataset`` with the annotations of all files, or ``None``
        when ``path`` contains no files.
    """
    per_file = [
        read_annotations_from_file(path, file)
        for file in tqdm(sorted(os.listdir(path)))
    ]
    if not per_file:
        return None
    # a single concatenation instead of re-concatenating the growing dataset per file
    return concatenate_datasets(per_file)

In [13]:
def read_sentences_from_path(path: str):
    """Load and concatenate the sentences of every file in ``path``.

    Files are read in sorted name order, so the sentences of one file are
    stored consecutively. An ``id`` column is added that holds each row's
    position in the combined dataset.

    Args:
        path: directory containing the SpkAtt task-1 JSON files.

    Returns:
        One ``datasets.Dataset`` with the sentences of all files.

    Raises:
        ValueError: if ``path`` contains no files.
    """
    per_file = [
        read_sentences_from_file(path, file)
        for file in tqdm(sorted(os.listdir(path)))
    ]
    if not per_file:
        raise ValueError(f"no sentence files found in {path!r}")

    # a single concatenation instead of re-concatenating the growing dataset per file
    dataset = concatenate_datasets(per_file)
    return dataset.add_column("id", list(range(len(dataset))))

In [14]:
def read_sentences_dataset(ds_name: str):
    """Return the sentence dataset of split ``ds_name``, parsing and caching it on first use.

    The parsed dataset is cached in ./transformed_datasets/<ds_name>/sentences.
    If that directory exists, it is loaded directly. Otherwise the raw files
    are read from ./SpkAtt-2023/data/<ds_name>/task1/ (task1_test/ for
    "eval"), and the result is saved to the cache.
    """
    cache_dir = f"./transformed_datasets/{ds_name}/sentences"
    if os.path.isdir(cache_dir):
        return load_from_disk(cache_dir)

    # the evaluation split is stored in "task1_test", the others in "task1"
    raw_dir = "./SpkAtt-2023/data/" + ds_name + (
        "/task1_test/" if ds_name == "eval" else "/task1/"
    )
    sentences = read_sentences_from_path(raw_dir)
    os.makedirs(cache_dir, exist_ok=True)
    sentences.save_to_disk(cache_dir)
    return sentences

In [15]:
def read_annotations_dataset(ds_name: str):
    """Return the annotation dataset of split ``ds_name``, parsing and caching it on first use.

    The cache lives in ./transformed_datasets/<ds_name>/annotations. On a cache
    miss, the raw files are read from ./SpkAtt-2023/data/<ds_name>/task1/
    (task1_test/ for "eval"), and the parsed dataset is saved to the cache.
    """
    cache_dir = f"./transformed_datasets/{ds_name}/annotations"
    if not os.path.isdir(cache_dir):
        raw_subdir = "/task1_test/" if ds_name == "eval" else "/task1/"
        annotations = read_annotations_from_path(
            "./SpkAtt-2023/data/" + ds_name + raw_subdir
        )
        os.makedirs(cache_dir, exist_ok=True)
        annotations.save_to_disk(cache_dir)
        return annotations
    return load_from_disk(cache_dir)

In [16]:
# Parsed sentence datasets for each split, cached under
# ./transformed_datasets/<split>/sentences. The "eval" split is read from the
# raw task1_test folder.
train_sentences_dataset = read_sentences_dataset("train")
val_sentences_dataset = read_sentences_dataset("dev")
test_sentences_dataset = read_sentences_dataset("eval")

In [17]:
# Annotations are loaded only for train and dev. The eval split is processed
# without annotations (build_complete_dataset is called with None for it).
train_annotations_dataset = read_annotations_dataset("train")
val_annotations_dataset = read_annotations_dataset("dev")

## Format datasets for usage in langchain


In [18]:
def _find_sentence_row(sentences_dataset, row, sentence_id):
    """Return the row of ``sentences_dataset`` for ``sentence_id`` in ``row``'s file.

    Sentence datasets built by ``read_sentences_from_path`` store each file's
    sentences consecutively and carry an ``id`` column equal to the row
    position. The target row is therefore usually at ``row["id"]`` shifted by
    the SentenceId difference. That candidate is verified (same file, same
    SentenceId) before it is used. Otherwise we fall back to a full ``filter``
    scan over the dataset.
    """
    guess = row.get("id")
    if guess is not None:
        guess += sentence_id - row["SentenceId"]
        if 0 <= guess < len(sentences_dataset):
            candidate = sentences_dataset[guess]
            if (
                candidate["FileName"] == row["FileName"]
                and candidate["SentenceId"] == sentence_id
            ):
                return candidate

    file_name = row["FileName"]
    return sentences_dataset.filter(
        lambda r: r["FileName"] == file_name and r["SentenceId"] == sentence_id
    )[0]


def get_text_from_label(train_sentences_dataset, row, annotations):
    """Map token references such as ``"12:3"`` to the referenced token strings.

    Each reference has the form ``"<SentenceId>:<token index>"``. References
    into ``row`` itself are resolved directly from ``row["Tokens"]``.
    References into another sentence of the same file are looked up in
    ``train_sentences_dataset``; despite its name, this can be any split's
    sentence dataset.

    Args:
        train_sentences_dataset: sentence dataset that contains ``row``.
        row: the sentence the annotations belong to. It needs ``SentenceId``,
            ``Tokens`` and ``FileName``; an ``id`` column enables a fast lookup.
        annotations: iterable of ``"<SentenceId>:<token index>"`` strings.

    Returns:
        The referenced tokens, in the order of ``annotations``.
    """
    tokens = []
    for anno in annotations:
        parts = anno.split(":")
        sentence_id, token_index = int(parts[0]), int(parts[1])
        if sentence_id == row["SentenceId"]:
            sentence_tokens = row["Tokens"]
        else:
            sentence_tokens = _find_sentence_row(
                train_sentences_dataset, row, sentence_id
            )["Tokens"]
        tokens.append(sentence_tokens[token_index])
    return tokens

In [19]:
def _group_annotations_by_cue_sentence(annotations_dataset):
    """Group annotation rows by ``(FileName, SentenceId of their first cue token)``.

    The key includes the file name, so annotations of one file are never
    attached to a sentence with the same SentenceId in another file. Grouping
    (instead of walking a cursor) also works when the annotations are not
    sorted by sentence. Within a group, the dataset order is preserved.
    """
    grouped = {}
    for annotation in annotations_dataset:
        cue_sentence_id = int(annotation["Cue"][0].split(":")[0])
        grouped.setdefault((annotation["FileName"], cue_sentence_id), []).append(
            annotation
        )
    return grouped


def _extended_context(rows, i, window=2):
    """Join sentence ``i`` with up to ``window`` following sentences of the same file.

    Returns ``(text, tokens, sentence_ids)``. ``sentence_ids`` holds, for every
    token in ``tokens``, the SentenceId that token came from.
    """
    row = rows[i]
    text = row["Sentence"]
    # copy, so extending the context never mutates the source row's token list
    tokens = list(row["Tokens"])
    sentence_ids = [row["SentenceId"]] * len(row["Tokens"])
    for j in range(i + 1, min(i + 1 + window, len(rows))):
        following = rows[j]
        if following["FileName"] != row["FileName"]:
            continue
        text += " " + following["Sentence"]
        tokens.extend(following["Tokens"])
        sentence_ids.extend([following["SentenceId"]] * len(following["Tokens"]))
    return text, tokens, sentence_ids


def build_complete_dataset(sentences_dataset, annotations_dataset, dataset_name):
    """Add context windows and (optionally) role annotations to every sentence.

    The result is cached in ./transformed_datasets/<dataset_name>/complete. If
    that directory exists, it is loaded and returned unchanged.

    Three columns are always added:
    - ``sentence_extended``: the sentence plus up to two following sentences
      of the same file.
    - ``tokens_extended``: the tokens of that window.
    - ``sentence_extended_ids``: for each token, the SentenceId it came from.

    If ``annotations_dataset`` is not ``None``, two columns are added for each
    role (ptc, evidence, medium, topic, cue, addr, message, source):
    - ``<role>``: the raw ``"<SentenceId>:<token index>"`` labels, one list per
      annotation.
    - ``<role>_mapped``: the referenced token strings.

    An annotation is attached to the sentence that contains its first cue
    token. Sentences without annotations get empty lists.

    Args:
        sentences_dataset: sentence dataset, e.g. from ``read_sentences_dataset``.
        annotations_dataset: matching annotation dataset, or ``None`` (eval split).
        dataset_name: split name, used for the cache path.

    Returns:
        The extended ``datasets.Dataset``.
    """
    path_to_dataset = "./transformed_datasets/" + dataset_name + "/complete"
    if os.path.isdir(path_to_dataset):
        return load_from_disk(path_to_dataset)

    # (column in annotations_dataset, output column prefix), in output column order
    role_columns = (
        ("PTC", "ptc"),
        ("Evidence", "evidence"),
        ("Medium", "medium"),
        ("Topic", "topic"),
        ("Cue", "cue"),
        ("Addr", "addr"),
        ("Message", "message"),
        ("Source", "source"),
    )
    role_values = {}
    for _, prefix in role_columns:
        role_values[prefix] = []
        role_values[prefix + "_mapped"] = []

    sentence_extended, tokens_extended, sentence_extended_ids = [], [], []

    # materialize once: indexing a datasets.Dataset decodes the row on every access
    rows = list(sentences_dataset)
    annotations_by_sentence = (
        _group_annotations_by_cue_sentence(annotations_dataset)
        if annotations_dataset is not None
        else {}
    )

    for i, row in tqdm(enumerate(rows), total=len(rows)):
        text, tokens, ids = _extended_context(rows, i)
        sentence_extended.append(text)
        tokens_extended.append(tokens)
        sentence_extended_ids.append(ids)

        if annotations_dataset is None:
            continue

        # pop, so whatever is left afterwards matched no sentence; sentences
        # without annotations get an empty list in every role column
        sentence_annotations = annotations_by_sentence.pop(
            (row["FileName"], row["SentenceId"]), []
        )
        for source_column, prefix in role_columns:
            labels = [annotation[source_column] for annotation in sentence_annotations]
            role_values[prefix].append(labels)
            role_values[prefix + "_mapped"].append(
                [get_text_from_label(sentences_dataset, row, label) for label in labels]
            )

    if annotations_by_sentence:
        unmatched = sum(len(group) for group in annotations_by_sentence.values())
        print(
            f"WARNING: {unmatched} annotation(s) in '{dataset_name}' reference a cue "
            "sentence that is not in the sentence dataset and were skipped"
        )

    res = sentences_dataset.add_column("sentence_extended", sentence_extended)
    res = res.add_column("tokens_extended", tokens_extended)
    res = res.add_column("sentence_extended_ids", sentence_extended_ids)

    if annotations_dataset is not None:
        for column_name, values in role_values.items():
            res = res.add_column(column_name, values)

    os.makedirs(path_to_dataset, exist_ok=True)
    res.save_to_disk(path_to_dataset)

    return res

In [20]:
# Attach context windows (and, for train/dev, role annotations) to every
# sentence. The eval split has no annotations, so only the context columns are
# added there.
train_ds = build_complete_dataset(
    train_sentences_dataset, train_annotations_dataset, "train"
)
val_ds = build_complete_dataset(val_sentences_dataset, val_annotations_dataset, "dev")
test_ds = build_complete_dataset(test_sentences_dataset, None, "eval")

In [21]:
# NOTE(review): `inputs` (test sentences with "Sentence" renamed to "Satz") is
# not used anywhere in the visible part of this notebook; confirm it is still needed.
inputs = test_sentences_dataset.rename_column("Sentence", "Satz")

## Dataset Showcase


In [22]:
# show one training sentence with its context window and annotation columns
train_ds[52]

{'Tokens': ['-',
  'Letzter',
  'Redner',
  'in',
  'der',
  'Debatte',
  ':',
  'Bernd',
  'Westphal',
  'für',
  'die',
  'SPD-Fraktion',
  '.'],
 'SentenceId': 52,
 'FileName': '19002_Zusatzpunkt_3_CDUCSU_Jung_ID19209800_21.11.2017.json',
 'Sentence': '- Letzter Redner in der Debatte : Bernd Westphal für die SPD-Fraktion .',
 'id': 52,
 'sentence_extended': '- Letzter Redner in der Debatte : Bernd Westphal für die SPD-Fraktion .',
 'tokens_extended': ['-',
  'Letzter',
  'Redner',
  'in',
  'der',
  'Debatte',
  ':',
  'Bernd',
  'Westphal',
  'für',
  'die',
  'SPD-Fraktion',
  '.'],
 'sentence_extended_ids': [52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52],
 'ptc': [[]],
 'ptc_mapped': [[]],
 'evidence': [[]],
 'evidence_mapped': [[]],
 'medium': [[]],
 'medium_mapped': [[]],
 'topic': [[]],
 'topic_mapped': [[]],
 'cue': [['52:5']],
 'cue_mapped': [['Debatte']],
 'addr': [[]],
 'addr_mapped': [[]],
 'message': [[]],
 'message_mapped': [[]],
 'source': [[]],
 'source_mapped': 

In [23]:
# show another training example (a long sentence)
train_ds[15]

{'Tokens': ['Dazu',
  'muss',
  'man',
  'nur',
  'mit',
  'den',
  'Landwirten',
  'sprechen',
  ',',
  'die',
  'sagen',
  ':',
  'Ja',
  ',',
  'auch',
  'früher',
  'gab',
  'es',
  'extreme',
  'Ereignisse',
  ',',
  'auch',
  'früher',
  'gab',
  'es',
  'Naturkatastrophen',
  ',',
  'aber',
  'in',
  'einem',
  'Jahr',
  'den',
  'Hagel',
  ',',
  'im',
  'anderen',
  'Jahr',
  'eine',
  'Dürre',
  'und',
  'im',
  'dritten',
  'Jahr',
  ',',
  'wie',
  'in',
  'diesem',
  'Jahr',
  ',',
  'die',
  'Frostschäden',
  ',',
  'unter',
  'denen',
  'die',
  'Obstbauern',
  'zu',
  'leiden',
  'hatten',
  ',',
  'diese',
  'Häufung',
  'hatten',
  'wir',
  'früher',
  'so',
  'nicht',
  ',',
  'also',
  'tut',
  'etwas',
  'gegen',
  'den',
  'Klimawandel',
  '.'],
 'SentenceId': 15,
 'FileName': '19002_Zusatzpunkt_3_CDUCSU_Jung_ID19209800_21.11.2017.json',
 'Sentence': 'Dazu muss man nur mit den Landwirten sprechen , die sagen : Ja , auch früher gab es extreme Ereignisse , auch früh

## Build lmsys format json


In [24]:
def map_cues_to_string(mapped):
    """Render a list of cues (each a list of tokens) as ``"[a, b], [c]"``.

    An empty list yields ``"#UNK#"``, the placeholder the prompts use for
    "no cue found".
    """
    return (
        "#UNK#"
        if mapped == []
        else ", ".join(f"[{', '.join(tokens)}]" for tokens in mapped)
    )

In [25]:
def map_roles_to_string(mapped):
    """Join the tokens of one role with ``", "``; return ``"#UNK#"`` if there are none."""
    if mapped != []:
        return ", ".join(mapped)
    return "#UNK#"

In [26]:
lmsys_data_path = "./lmsys.json"


def build_lmsys_format(train_ds, val_ds):
    """Write train + dev examples as lmsys-format conversations to ``lmsys_data_path``.

    Each sentence with cues yields one conversation per cue, in four turns:
    1. The user asks for all cues of the sentence.
    2. The assistant lists them.
    3. The user shows the extended context and asks for the roles of this cue.
    4. The assistant answers with every role, using ``#UNK#`` for empty ones.

    A sentence without cues yields a single two-turn conversation whose answer
    is ``"Cues: #UNK#"``. Conversation ids are ``identity_<n>``, numbered
    consecutively.
    """
    cue_question = (
        "A cue is the lexical items in a sentence that indicate that speech, "
        "writing, or thought is being reproduced.\n"
        "I want you to extract all cues in the text below.\n"
        "If you find multiple words for one cue, you output them separated by commas.\n"
        "If no cue can be found in the given text, you output the string #UNK# as cue.\n"
        "Now extract all cues from the following sentence.\n"
        'Use the prefix "Cues: ".\n'
        "Sentence: "
    )
    role_question = (
        "Now I give you again the sentence only in addition with the two following "
        "sentences, because the roles can be partially contained in the following "
        "sentences.\nText: "
    )
    # roles listed in the assistant's answer, after the cue line
    role_names = ("ptc", "evidence", "medium", "topic", "addr", "message", "source")

    result = []
    for row in concatenate_datasets([train_ds, val_ds]):
        cue_exchange = [
            {"from": "human", "value": cue_question + row["Sentence"]},
            {"from": "gpt", "value": "Cues: " + map_cues_to_string(row["cue_mapped"])},
        ]

        if len(row["cue_mapped"]) == 0:
            result.append(
                {"id": f"identity_{len(result)}", "conversations": cue_exchange}
            )
            continue

        for i, cue in enumerate(row["cue_mapped"]):
            cue_text = ", ".join(cue)
            answer_lines = [f"cue: {cue_text}"] + [
                f"{name}: {map_roles_to_string(row[name + '_mapped'][i])}"
                for name in role_names
            ]
            role_exchange = [
                {
                    "from": "human",
                    "value": role_question
                    + row["sentence_extended"]
                    + "\n\nNow find all roles in the sentence associated with the cue '"
                    + cue_text
                    + "' you found in the beginning sentence.",
                },
                {"from": "gpt", "value": "\n".join(answer_lines)},
            ]
            result.append(
                {
                    "id": f"identity_{len(result)}",
                    "conversations": cue_exchange + role_exchange,
                }
            )

    with open(lmsys_data_path, "w", encoding="utf8") as outfile:
        json.dump(result, outfile, indent=3, ensure_ascii=False)

In [27]:
# write the train + dev conversations to lmsys_data_path
build_lmsys_format(train_ds, val_ds)

# QLoRA Fine-Tuning

## Parse data into required format


In [28]:
parsed_cues_file = "./transformed_datasets/prompts_training/parsed_data_cues.jsonl"
parsed_roles_file = "./transformed_datasets/prompts_training/parsed_data_roles.jsonl"
os.makedirs(os.path.dirname(parsed_cues_file), exist_ok=True)
os.makedirs(os.path.dirname(parsed_roles_file), exist_ok=True)

# token to signal the end of the assistant's response
separator = "</s>"

# Reload the conversations. build_lmsys_format writes them as UTF-8 with
# ensure_ascii=False (German text), so read them back with the same encoding
# instead of the platform default.
with open(lmsys_data_path, encoding="utf8") as f:
    data = json.load(f)

# Save the parsed prompts separately. The sets mirror the lists so duplicate
# checks are O(1) instead of a linear scan over all samples collected so far.
all_prompts_cues = []
all_prompts_roles = []
seen_prompts_cues = set()
seen_prompts_roles = set()
for conversation in data:
    # keep track of the complete conversation in order to generate the input of the prompts
    complete_prompt = ""

    for i, turn in enumerate(conversation["conversations"]):
        if turn["from"] == "human":
            complete_prompt += "User: "
            complete_prompt += turn["value"]
        elif turn["from"] == "gpt":
            complete_prompt += "Assistant: "

            # idea
            # turn 0: user prompt for cues
            # turn 1: assistant response with cues
            #   --> create sample with the conversation up to this point as input and the cues as output
            # turn 2: user prompt for roles for one specific cue
            # turn 3: assistant response with roles
            #   --> create sample with the conversation up to this point as input and the roles as output
            # there should be no further turns because we split all conversations with multiple cues into separate conversations

            sample = json.dumps(
                {"input": complete_prompt, "output": turn["value"] + separator}
            )

            if i == 1:
                # turn 1: assistant response with cues. Conversations that were
                # split per cue repeat this turn, so keep only the first copy.
                if sample not in seen_prompts_cues:
                    seen_prompts_cues.add(sample)
                    all_prompts_cues.append(sample)
            elif i == 3:
                # turn 3: assistant response with roles; skip exact duplicates
                # of earlier role samples
                if sample not in seen_prompts_roles:
                    seen_prompts_roles.add(sample)
                    all_prompts_roles.append(sample)
            else:
                print(
                    "ERROR: each conversation should maximally contain 4 turns"
                    " and only turn 1 and 3 should be responses by the assistant"
                )

            complete_prompt += turn["value"] + separator
        complete_prompt += "\n"

# write parsed prompts to files (one JSON object per line)
with open(parsed_cues_file, "w", encoding="utf8") as f:
    f.write("\n".join(all_prompts_cues))

with open(parsed_roles_file, "w", encoding="utf8") as f:
    f.write("\n".join(all_prompts_roles))

In [29]:
# check that the file with the cue prompts was written correctly
with open(parsed_cues_file) as f:
    lines = f.readlines()

print(f"Number of samples: {len(lines)}\n")

print("First 5 samples:")
for l in lines[:5]:
    print("=== in: ===\n" + json.loads(l)["input"] + "\n")
    print("=== out: ===\n" + json.loads(l)["output"] + "\n")
    print()

Number of samples: 9399

First 5 samples:
=== in: ===
User: A cue is the lexical items in a sentence that indicate that speech, writing, or thought is being reproduced.
I want you to extract all cues in the text below.
If you find multiple words for one cue, you output them separated by commas.
If no cue can be found in the given text, you output the string #UNK# as cue.
Now extract all cues from the following sentence.
Use the prefix "Cues: ".
Sentence: Frau Präsidentin !
Assistant: 

=== out: ===
Cues: #UNK#</s>


=== in: ===
User: A cue is the lexical items in a sentence that indicate that speech, writing, or thought is being reproduced.
I want you to extract all cues in the text below.
If you find multiple words for one cue, you output them separated by commas.
If no cue can be found in the given text, you output the string #UNK# as cue.
Now extract all cues from the following sentence.
Use the prefix "Cues: ".
Sentence: Liebe Kolleginnen und Kollegen !
Assistant: 

=== out: ===
Cu

In [30]:
# check that the file with the role prompts was written correctly
with open(parsed_roles_file) as f:
    lines = f.readlines()

print(f"Number of samples: {len(lines)}\n")

print("First 5 samples:")
for l in lines[:5]:
    print("=== in: ===\n" + json.loads(l)["input"] + "\n")
    print("=== out: ===\n" + json.loads(l)["output"] + "\n")
    print()

Number of samples: 5914

First 5 samples:
=== in: ===
User: A cue is the lexical items in a sentence that indicate that speech, writing, or thought is being reproduced.
I want you to extract all cues in the text below.
If you find multiple words for one cue, you output them separated by commas.
If no cue can be found in the given text, you output the string #UNK# as cue.
Now extract all cues from the following sentence.
Use the prefix "Cues: ".
Sentence: Bundeskanzlerin Angela Merkel hat auf der Klimakonferenz in Bonn gesprochen .
Assistant: Cues: [gesprochen]</s>
User: Now I give you again the sentence only in addition with the two following sentences, because the roles can be partially contained in the following sentences.
Text: Bundeskanzlerin Angela Merkel hat auf der Klimakonferenz in Bonn gesprochen . Sie hat dort den Klimawandel als eine zentrale Herausforderung für die Menschheit bezeichnet . Sie hat von einer Schicksalsfrage gesprochen .

Now find all roles in the sentence ass

## Check optimal source and target lengths

This step is only required if you want to use your own data. If you use the original GermEval 2023 task 1 data, you can skip this step and use the source and target lengths that are already defined in the config files in the `configs` folder (parameters `source_max_len` and `target_max_len`).

If you want to change the maximum source or target lengths, keep in mind that longer prompts mean longer training times and higher memory requirements. While it would be best to set the maximum source/target lengths to the maximum lengths of the inputs/outputs, this is not always feasible due to memory constraints. In that case, we recommend choosing maximum lengths that truncate only a few samples.


In [31]:
# Tokenize every prompt input and output with the Llama 1 tokenizer (the same as
# the Llama 2 tokenizer), so the token counts can be compared with the
# configured source/target limits in the next cells.
tokenizer = AutoTokenizer.from_pretrained(
    "huggyllama/llama-7b", padding_side="right", use_fast=False, tokenizer_type="llama"
)

with open(parsed_cues_file) as handle:
    cue_samples = [json.loads(line) for line in handle.readlines()]
with open(parsed_roles_file) as handle:
    role_samples = [json.loads(line) for line in handle.readlines()]

encoded_inputs_cues = [tokenizer.encode(s["input"]) for s in cue_samples]
encoded_outputs_cues = [tokenizer.encode(s["output"]) for s in cue_samples]
encoded_inputs_roles = [tokenizer.encode(s["input"]) for s in role_samples]
encoded_outputs_roles = [tokenizer.encode(s["output"]) for s in role_samples]

Downloading tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/411 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/700 [00:00<?, ?B/s]

KeyboardInterrupt: 

In [None]:
# Maximum source lengths, taken from the config files.
max_length_source_cues = 256
max_length_source_roles = 640

# For each prompt type, report the longest and mean input length in tokens and
# how many inputs exceed the configured limit (those get truncated).
for position, (label, encoded, limit) in enumerate(
    (
        ("cues", encoded_inputs_cues, max_length_source_cues),
        ("roles", encoded_inputs_roles, max_length_source_roles),
    )
):
    if position:
        print()
    print(f"{label} source lengths")
    len_enc = [len(e) for e in encoded]
    print(f"max length: {max(len_enc)}")
    print(f"mean length: {np.mean(len_enc)}")
    print(
        f"number of samples longer than {limit}: {sum(np.array(len_enc) > limit)}"
    )

In [None]:
# Maximum target lengths, taken from the config files.
max_length_target_cues = 64
max_length_target_roles = 256

# For each prompt type, report the longest and mean output length in tokens and
# how many outputs exceed the configured limit (those get truncated).
for position, (label, encoded, limit) in enumerate(
    (
        ("cues", encoded_outputs_cues, max_length_target_cues),
        ("roles", encoded_outputs_roles, max_length_target_roles),
    )
):
    if position:
        print()
    print(f"{label} target lengths")
    len_enc = [len(e) for e in encoded]
    print(f"max length: {max(len_enc)}")
    print(f"mean length: {np.mean(len_enc)}")
    print(
        f"number of samples longer than {limit}: {sum(np.array(len_enc) > limit)}"
    )

## Train models

This step can be skipped if you already have trained models.

For training, you first have to prepare the Llama 2 models and adapt the configuration. To prepare the Llama 2 models, you have to make them accessible in Hugging Face (HF) format. You can either use the models directly from Hugging Face or prepare them yourself by first downloading the model weights from [the official Llama repo](https://github.com/facebookresearch/llama) and then converting these weights using their [conversion manual](https://github.com/facebookresearch/llama-recipes/#model-conversion-to-hugging-face). When using the models from Hugging Face, add the parameter `use_auth_token` with your Hugging Face token to the training configs in the `configs` folder. If you prepare the models yourself instead, update the path to the models in the config files (parameter `model_name_or_path`) in the `configs` folder so that the paths point to the folder containing the `pytorch_model-000xx-of-00015.bin` files.

Further configuration parameters:

- `per_device_train_batch_size` and `gradient_accumulation_steps`: With these two parameters you can control the batch size and the number of accumulation steps when calculating the gradients during training. Larger batch sizes should speed up training, but increase memory requirements considerably. We recommend choosing the parameters so that their product `per_device_train_batch_size * gradient_accumulation_steps` is a multiple of 16.
- `save_steps` and `max_steps`: Set `max_steps` to control the length of training; `save_steps` determines how often checkpoints are created.


In [None]:
# choose config for cue model
cues_training_config = "./configs/7b_cues.args"  # 7b model
# cues_training_config = "./configs/70b_cues.args" # 70b model

# NOTE(review): `train` comes from the local `qlora` module. Here it receives a
# path to an .args file, while a later cell passes a dict; confirm which
# argument forms `qlora.train` accepts.
train(cues_training_config)

# free vram after training
gc.collect()
torch.cuda.empty_cache()
gc.collect()


In [None]:
# choose config for roles model
roles_training_config = "./configs/7b_roles.args"  # 7b model
# roles_training_config = "./configs/70b_roles.args" # 70b model

# train the role-extraction model from the selected .args config file
train(roles_training_config)

# free vram after training: drop unreferenced Python objects, then release
# cached CUDA memory held by torch
gc.collect()
torch.cuda.empty_cache()
gc.collect()


In [None]:
# Define the training configurations inline as dicts. This is an alternative to
# the ./configs/*.args files used in the two cells above. The keys presumably
# mirror the qlora command-line arguments; TODO confirm against qlora.train.
# 7B models
cues_training_config = {
    "model_name_or_path": "meta-llama/Llama-2-7b-hf",
    "output_dir": "./output/spkatt-7b-cues",
    "data_seed": 42,
    "save_steps": 200,
    "evaluation_strategy": "no",
    "dataloader_num_workers": 4,
    "lora_modules": "all",
    "bf16": True,
    "dataset": "transformed_datasets/prompts_training/parsed_data_cues.jsonl",
    "dataset_format": "input-output",
    "source_max_len": 256,
    "target_max_len": 64,
    "per_device_train_batch_size": 16,
    "gradient_accumulation_steps": 1,
    "max_steps": 4000,
    "learning_rate": 0.0002,
    "lora_dropout": 0.1,
    "seed": 0,
}
roles_training_config = {
    "model_name_or_path": "meta-llama/Llama-2-7b-hf",
    "output_dir": "./output/spkatt-7b-roles",
    "data_seed": 42,
    "save_steps": 200,
    "evaluation_strategy": "no",
    "dataloader_num_workers": 4,
    "lora_modules": "all",
    "bf16": True,
    "dataset": "transformed_datasets/prompts_training/parsed_data_roles.jsonl",
    "dataset_format": "input-output",
    "source_max_len": 640,
    "target_max_len": 256,
    "per_device_train_batch_size": 16,
    "gradient_accumulation_steps": 1,
    "max_steps": 4000,
    "learning_rate": 0.0002,
    "lora_dropout": 0.1,
    "seed": 0,
}

# 70B models
# cues_training_config = {"model_name_or_path": "meta-llama/Llama-2-70b-hf",
#                         "output_dir": "./output/spkatt-70b-cues",
#                         "data_seed": 42,
#                         "save_steps": 500,
#                         "evaluation_strategy": "no",
#                         "dataloader_num_workers": 4,
#                         "lora_modules": "all",
#                         "bf16": True,
#                         "dataset": "transformed_datasets/prompts_training/parsed_data_cues.jsonl",
#                         "dataset_format": "input-output",
#                         "source_max_len": 256,
#                         "target_max_len": 64,
#                         "per_device_train_batch_size": 16,
#                         "gradient_accumulation_steps": 1,
#                         "max_steps": 2500,
#                         "learning_rate": 0.0001,
#                         "lora_dropout": 0.05,
#                         "seed": 0,
#                         }
# roles_training_config = {"model_name_or_path": "meta-llama/Llama-2-70b-hf",
#                          "output_dir": "./output/spkatt-70b-roles",
#                          "data_seed": 42,
#                          "save_steps": 500,
#                          "evaluation_strategy": "no",
#                          "dataloader_num_workers": 4,
#                          "lora_modules": "all",
#                          "bf16": True,
#                          "dataset": "transformed_datasets/prompts_training/parsed_data_roles.jsonl",
#                          "dataset_format": "input-output",
#                          "source_max_len": 640,
#                          "target_max_len": 256,
#                          "per_device_train_batch_size": 8,
#                          "gradient_accumulation_steps": 2,
#                          "max_steps": 2500,
#                          "learning_rate": 0.0001,
#                          "lora_dropout": 0.05,
#                          "seed": 0,
#                          }

# NOTE(review): running the notebook top to bottom trains a cue model a second
# time here (one was already trained from ./configs/7b_cues.args two cells
# above). Run only one of the two variants.
train(cues_training_config)
# train(roles_training_config)

# Inference

## Loading LLM Chain

In [None]:
# Inference setup (disabled). It loads the 70B base model, applies the trained
# cue LoRA adapter from a checkpoint, merges it into the base weights, and wraps
# the model in a text-generation pipeline for LangChain.
# NOTE(review): uncommenting this requires names that are not imported or
# defined in the visible notebook: AutoModelForCausalLM and LlamaTokenizer
# (transformers), PeftModel (peft) and `device_map`. The absolute paths below
# are machine-specific.
# model_name = "/home/ngr/models/llama-2-hf/70b"
# model = AutoModelForCausalLM.from_pretrained(
#     model_name,
#     torch_dtype=torch.bfloat16,
#     device_map=device_map,
#     cache_dir="/home/ngr/.cache/huggingface/hub",
# )
# checkpoint_dir = "/home/ngr/repos/qlora3/experiments/exp008/exp008e/output/spkatt-70b-cues/checkpoint-2000/"
# model = PeftModel.from_pretrained(model, os.path.join(checkpoint_dir, "adapter_model"))
# model = model.merge_and_unload()
# tokenizer = LlamaTokenizer.from_pretrained(model_name, legacy=False)
# tokenizer.bos_token_id = 1

# from transformers import pipeline

# pipe = pipeline(
#     task="text-generation", model=model, tokenizer=tokenizer, max_new_tokens=300
# )
# from langchain import HuggingFacePipeline

# llm = HuggingFacePipeline(pipeline=pipe)
