In [None]:
!pip install \
accelerate==0.23.0 bitsandbytes==0.41.1 datasets==2.13.0 openai==0.28.1 \
peft==0.4.0 safetensors==0.4.0 transformers==4.34.0 trl==0.4.7


In [34]:
import json

from datasets import Dataset

# Column buffers, filled in file order so that position i in every list
# belongs to the same record.
ids = []
dialogues = []
summaries = []

# Read the JSONL file one record per line.
# The encoding is given explicitly so that parsing does not depend on the
# machine's locale default.
with open("train.jsonl", "r", encoding="utf-8") as file:
    for line in file:
        # Tolerate blank lines (e.g. a trailing newline at the end of the
        # file); json.loads would raise on an empty string.
        if not line.strip():
            continue
        # Parse the JSON object on this line.
        data = json.loads(line)
        # Pull out the three fields used downstream; a missing key raises
        # KeyError rather than silently producing a ragged dataset.
        idx = data["idx"]
        dialogue = data["inputs"]
        summary = data["target"]
        ids.append(idx)
        dialogues.append(dialogue)
        summaries.append(summary)

# Build a Hugging Face dataset whose columns are renamed from the raw keys:
# idx -> id, inputs -> dialogue, target -> summary.
dataset = Dataset.from_dict({
    "id": ids,
    "dialogue": dialogues,
    "summary": summaries
})
# The whole file is used for training; no split is made here.
train_dataset = dataset

In [33]:
# Display the first record to sanity-check the id / dialogue / summary columns.
dataset[0]

{'id': 0,
 'dialogue': 'The lungs are clear, and without focal air space opacity. The cardiomediastinal silhouette is normal in size and contour, and stable. There is no pneumothorax or large pleural effusion.',
 'summary': 'No acute cardiopulmonary abnormality.'}

In [38]:
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training

model_id = "meta-llama/Llama-2-7b-chat-hf"

#
# Load the base model with 4-bit NF4 weight quantization and bfloat16 as the
# compute dtype.
# Double quantization is NOT enabled in this run: the flag is passed
# explicitly as False (the library default) so that this comment and the
# configuration cannot drift apart. Set it to True to turn it on.
#
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# No explicit device_map is passed; placement is left to the library.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
)
# Turn off the generation key/value cache on the model config before training.
model.config.use_cache = False

Loading checkpoint shards: 100%|██████████| 2/2 [00:07<00:00,  3.81s/it]


In [36]:
def prompt_formatter(sample):
    """Render one dataset record as a single training prompt string.

    Args:
        sample: Mapping with a 'dialogue' entry (the text to summarize) and
            a 'summary' entry (the reference summary).

    Returns:
        A string made of three sections separated by blank lines:
        instruction, dialogue and summary. The text starts with the literal
        "<s>" marker and ends with " </s>".

    Raises:
        KeyError: If 'dialogue' or 'summary' is missing from ``sample``.
    """
    # Fixed system-style instruction; adjacent literals form one line of text.
    instruction = (
        "You are a helpful, respectful and honest assistant. "
        "Your task is to summarize the following dialogue. "
        "Your answer should be based on the provided dialogue only."
    )
    dialogue_text = sample['dialogue']
    summary_text = sample['summary']
    # Each section is "header\nbody"; sections are separated by one blank line.
    sections = (
        f"<s>### Instruction:\n{instruction}",
        f"### Dialogue:\n{dialogue_text}",
        f"### Summary:\n{summary_text} </s>",
    )
    return "\n\n".join(sections)

# Index of the training record to preview.
n = 0
# Print (rather than display) so the newlines in the prompt are rendered.
print(prompt_formatter(train_dataset[n]))

<s>### Instruction:
You are a helpful, respectful and honest assistant. Your task is to summarize the following dialogue. Your answer should be based on the provided dialogue only.

### Dialogue:
The lungs are clear, and without focal air space opacity. The cardiomediastinal silhouette is normal in size and contour, and stable. There is no pneumothorax or large pleural effusion.

### Summary:
No acute cardiopulmonary abnormality. </s>


In [41]:
from transformers import TrainingArguments, AutoTokenizer
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer

#
# LoRA adapter configuration.
# The QLoRA paper recommends LoRA dropout = 0.05 for small models (7B, 13B).
# The config is handed to SFTTrainer below instead of wrapping the model
# with get_peft_model in this cell.
#
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
)

#
# Tokenizer: reuse the EOS token as the padding token and pad on the right.
#
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

#
# Training hyperparameters.
#
args = TrainingArguments(
    # Output location and checkpointing / logging cadence.
    output_dir="llama2-7b-chat-opr",
    save_strategy="epoch",
    logging_steps=4,
    disable_tqdm=False,
    # Schedule length and batch shape
    # (4 samples per device x 2 accumulation steps per optimizer step).
    num_train_epochs=2,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    # Optimizer and learning-rate schedule.
    # NOTE(review): warmup_ratio is combined with the plain "constant"
    # scheduler; confirm whether warmup actually takes effect or whether
    # "constant_with_warmup" was intended.
    optim="paged_adamw_32bit",
    learning_rate=2e-4,
    lr_scheduler_type="constant",
    warmup_ratio=0.03,
    max_grad_norm=0.3,
    # Numeric precision: bf16 mixed precision with TF32 matmuls, fp16 off.
    bf16=True,
    fp16=False,
    tf32=True,
)

#
# Supervised fine-tuning trainer: prompts come from prompt_formatter and are
# packed into sequences of up to 1024 tokens.
#
trainer = SFTTrainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    peft_config=peft_config,
    train_dataset=train_dataset,
    formatting_func=prompt_formatter,
    packing=True,
    max_seq_length=1024,
)

ImportError: cannot import name '_LazyModule' from 'trl.import_utils' (/home/hrudayte.akkalad/.local/lib/python3.8/site-packages/trl/import_utils.py)

In [40]:
!pip install -q -U trl transformers accelerate git+https://github.com/huggingface/peft.git
!pip install -q datasets bitsandbytes einops wandb

  You can safely remove it manually.[0m[33m
  You can safely remove it manually.[0m[33m
[0m[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
llama-index-llms-huggingface 0.1.4 requires huggingface-hub<0.21.0,>=0.20.3, but you have huggingface-hub 0.22.2 which is incompatible.
llama-index-llms-huggingface 0.1.4 requires torch<3.0.0,>=2.1.2, but you have torch 2.0.1 which is incompatible.[0m[31m
[0m

In [42]:
!pip install 'lightning-flash[text]' --upgrade

Defaulting to user installation because normal site-packages is not writeable
Collecting lightning-flash[text]
  Downloading lightning_flash-0.8.2-py3-none-any.whl.metadata (27 kB)
Collecting torchmetrics<0.11.0,>0.7.0 (from lightning-flash[text])
  Downloading torchmetrics-0.10.3-py3-none-any.whl.metadata (15 kB)
Collecting pytorch-lightning<2.0.0,>1.8.0 (from lightning-flash[text])
  Downloading pytorch_lightning-1.9.5-py3-none-any.whl.metadata (23 kB)
Collecting pyDeprecate>0.2.0 (from lightning-flash[text])
  Downloading pyDeprecate-0.3.2-py3-none-any.whl.metadata (10 kB)
Collecting jsonargparse>=4.22.0 (from jsonargparse[signatures]>=4.22.0->lightning-flash[text])
  Downloading jsonargparse-4.27.7-py3-none-any.whl.metadata (12 kB)
Collecting lightning-utilities>=0.4.1 (from lightning-flash[text])
  Downloading lightning_utilities-0.11.2-py3-none-any.whl.metadata (4.7 kB)
Collecting ftfy (from lightning-flash[text])
  Downloading ftfy-6.2.0-py3-none-any.whl.metadata (7.3 kB)
Collec

In [None]:
!pip install -U torch

Defaulting to user installation because normal site-packages is not writeable
Collecting torch
  Downloading torch-2.2.2-cp38-cp38-manylinux1_x86_64.whl.metadata (25 kB)
Collecting triton==2.2.0 (from torch)
  Using cached triton-2.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Downloading torch-2.2.2-cp38-cp38-manylinux1_x86_64.whl (755.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m755.5/755.5 MB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hUsing cached triton-2.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (167.9 MB)
Installing collected packages: triton, torch
  Attempting uninstall: triton
    Found existing installation: triton 2.0.0
    Uninstalling triton-2.0.0:
      Successfully uninstalled triton-2.0.0
  You can safely remove it manually.[0m[33m
[0m  Attempting uninstall: torch
    Found existing installation: torch 2.0.1
    Uninstalling torch-2.0.1:
      Successfully uninstalled 