In [1]:
import os
from loguru import logger
from pathlib import Path
from transformers import AutoTokenizer
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
from datasets import load_dataset, load_from_disk
import deepspeed
import vllm
from torch.utils.data import DataLoader

[2023-11-28 20:45:12,350] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
from IPython import get_ipython
from IPython.core.magic import register_line_cell_magic
import GPUtil
from termcolor import colored
@register_line_cell_magic
def vram(line, cell=None):
    "monitor the usage of vram"
    if cell:
        get_ipython().run_cell(cell)
    if line:
        get_ipython().run_cell(line)
    print(colored(
        "| "+" | ".join([f"{i} @ {gpu.memoryUtil*100:.2f}%" for i, gpu in enumerate(GPUtil.getGPUs())]) + " |", 
        "green"
    ))
    # logger.debug("  ".join([f"{i}: {gpu.memoryUtil*100:.2f}%" for i, gpu in enumerate(GPUtil.getGPUs())]))

In [3]:
%%vram
os.environ['CUDA_VISIBLE_DEVICES'] = "5,7"
os.environ['WORLD_SIZE'] = "2"

[32m| 0 @ 13.74% | 1 @ 26.67% | 2 @ 15.98% | 3 @ 7.74% | 4 @ 7.75% | 5 @ 7.73% | 6 @ 7.75% | 7 @ 7.76% |[0m


In [4]:
DATA_PATH = Path("./data/samsum")
MODEL_PATH = Path("../models/gpt2/base/")
WORK_DIR = Path('results/samsum/gpt2-base-kd')
TEACHER_MODEL_PATH = Path("./results/samsum/gpt2-xlarge-sft/checkpoint-4600/")

In [5]:
dataset = load_from_disk(str(DATA_PATH))
logger.debug(dataset)

prompt_template = """[INST] <<SYS>>
Use the Input to provide a summary of a conversation.
<</SYS>>

Input:
{dialogue}

Summary:
{summary}
"""

logger.debug("Train data example:\n" + prompt_template.format(**dataset['train'][0]))


tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = tokenizer.eos_token_id

logger.debug(f"The padding token id is {tokenizer.pad_token_id}")

CUTOFF_LEN = 512
LABEL_SPLIT = "Summary:\n"

def generate_and_tokenize_prompt(instance, is_test=False):
    def tokenize(prompt, add_eos_token=True):
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=CUTOFF_LEN,
            padding=True,
            return_tensors=None
        )
        if(
            result['input_ids'][-1] != tokenizer.eos_token_id
            and len(result['input_ids']) < CUTOFF_LEN
            and add_eos_token
        ):
            result['input_ids'].append(tokenizer.eos_token_id)
            result['attention_mask'].append(1)
        result['labels'] = result['input_ids'].copy()
        return result
    tokenized_full_prompt = tokenize(prompt_template.format(**instance))
    tokenized_user_prompt = tokenize(prompt_template.format(**instance).split(LABEL_SPLIT)[0] + LABEL_SPLIT, add_eos_token=False)
    user_prompt_len = len(tokenized_user_prompt['input_ids'])
    tokenized_full_prompt['labels'] = [-100]*user_prompt_len + tokenized_full_prompt['labels'][user_prompt_len:]
    if is_test:
        tokenized_user_prompt['_id'] = instance['id']
        return tokenized_user_prompt
    
    len_labels = len(tokenizer(instance['summary'])['input_ids'])
    tokenized_full_prompt['is_label_complete'] = len(tokenized_full_prompt['labels'][user_prompt_len:]) >= len_labels
    return tokenized_full_prompt

[32m2023-11-28 20:45:12.939[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [34m[1mDatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 14732
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 819
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 818
    })
})[0m
[32m2023-11-28 20:45:12.940[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m15[0m - [34m[1mTrain data example:
[INST] <<SYS>>
Use the Input to provide a summary of a conversation.
<</SYS>>

Input:
Amanda: I baked  cookies. Do you want some?
Jerry: Sure!
Amanda: I'll bring you tomorrow :-)

Summary:
Amanda baked cookies and will bring Jerry some tomorrow.
[0m


[32m2023-11-28 20:45:13.006[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m22[0m - [34m[1mThe padding token id is 50256[0m


In [6]:
columns = ['input_ids', 'attention_mask', 'labels']

train_data = dataset['train'].map(generate_and_tokenize_prompt, num_proc=1) \
                             .filter(lambda instance: instance['is_label_complete']) \
                             .select_columns(columns) \
                             .with_format(type='torch')
                           
val_data = dataset['test'].map(generate_and_tokenize_prompt, num_proc=1) \
                          .filter(lambda instance: instance['is_label_complete']) \
                          .select_columns(columns) \
                          .with_format(type='torch', columns=columns)

Loading cached processed dataset at /home/yzhangjy/LLM/llm-kd/data/samsum/train/cache-06823d004bf2423d.arrow
Loading cached processed dataset at /home/yzhangjy/LLM/llm-kd/data/samsum/train/cache-d2b62b440b3151af.arrow
Loading cached processed dataset at /home/yzhangjy/LLM/llm-kd/data/samsum/test/cache-763244c730f8eaad.arrow
Loading cached processed dataset at /home/yzhangjy/LLM/llm-kd/data/samsum/test/cache-dfdf889bd99c1a61.arrow


In [7]:
logger.debug(f"Training data usage: {train_data.num_rows / dataset['train'].num_rows * 100:.4f}%")
logger.debug(f"Validation data usage: {val_data.num_rows / dataset['validation'].num_rows * 100:.4f}%")

[32m2023-11-28 20:45:13.150[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [34m[1mTraining data usage: 97.2305%[0m
[32m2023-11-28 20:45:13.151[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [34m[1mValidation data usage: 96.6993%[0m


In [8]:
# tokenized_summary = dataset['train'].map(lambda x: tokenizer(x['summary'])).remove_columns(dataset['train'].column_names)max([len(ids) for ids in tokenized_summary['input_ids']])
# max([len(ids) for ids in tokenized_summary['input_ids']])
# label_lens = [torch.where(lab==-100, 0, 1).sum().item for lab in tokenized_dataset['train']['labels']]

In [9]:
%%vram
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, 
                                             torch_dtype=torch.float16, 
                                             load_in_8bit=False,
                                             use_cache=False)

[32m| 0 @ 13.74% | 1 @ 26.67% | 2 @ 15.98% | 3 @ 7.74% | 4 @ 7.75% | 5 @ 7.73% | 6 @ 7.75% | 7 @ 7.76% |[0m


In [10]:
%%vram
model.cuda(5)
model.device

device(type='cuda', index=5)

[32m| 0 @ 13.74% | 1 @ 26.67% | 2 @ 15.98% | 3 @ 7.74% | 4 @ 7.75% | 5 @ 12.09% | 6 @ 7.75% | 7 @ 7.76% |[0m


In [11]:
%%vram
teacher_model = AutoModelForCausalLM.from_pretrained(TEACHER_MODEL_PATH, 
                                             torch_dtype=torch.float16, 
                                             load_in_8bit=False,
                                             use_cache=False)

[32m| 0 @ 13.74% | 1 @ 26.67% | 2 @ 15.98% | 3 @ 7.74% | 4 @ 7.76% | 5 @ 12.10% | 6 @ 8.79% | 7 @ 7.77% |[0m


In [12]:
%%vram
teacher_model.cuda(5)
teacher_model.eval()
teacher_model.device

device(type='cuda', index=5)

[32m| 0 @ 13.74% | 1 @ 26.67% | 2 @ 15.98% | 3 @ 7.74% | 4 @ 7.76% | 5 @ 24.70% | 6 @ 8.79% | 7 @ 7.77% |[0m


In [13]:
label_pad_token_id = -100
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=label_pad_token_id,
    return_tensors="pt",
    pad_to_multiple_of=8
)

# %%
dataloader = DataLoader(train_data, 
                        collate_fn=data_collator, 
                        batch_size=16)#, pin_memory=True, pin_memory_device="cuda:7")

In [14]:
%%vram
data = next(dataloader._get_iterator())

for k, v in data.items():
    data[k] = v.cuda(5)

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


[32m| 0 @ 13.74% | 1 @ 26.67% | 2 @ 15.98% | 3 @ 7.74% | 4 @ 7.76% | 5 @ 24.70% | 6 @ 8.79% | 7 @ 7.77% |[0m


In [15]:
data['input_ids'].shape

torch.Size([16, 352])

In [16]:
# Not train on input
# unsqueeze for the convenience of later computations
loss_mask = torch.where(data['labels']==-100, 0, 1).unsqueeze(-1)

In [17]:
%%vram
optimizer = AdamW(model.parameters(), lr=5e-4)

[32m| 0 @ 13.74% | 1 @ 26.67% | 2 @ 15.98% | 3 @ 7.74% | 4 @ 7.76% | 5 @ 24.70% | 6 @ 8.79% | 7 @ 7.77% |[0m


In [18]:
%%vram
output = model.forward(**data)

[32m| 0 @ 13.74% | 1 @ 26.67% | 2 @ 15.98% | 3 @ 7.74% | 4 @ 7.76% | 5 @ 47.09% | 6 @ 8.79% | 7 @ 7.77% |[0m


In [19]:
%%vram
probs = F.softmax(output.logits, dim=-1, dtype=torch.float32)

[32m| 0 @ 13.74% | 1 @ 26.67% | 2 @ 15.98% | 3 @ 7.74% | 4 @ 7.76% | 5 @ 51.49% | 6 @ 8.79% | 7 @ 7.77% |[0m


In [20]:
%%vram
with torch.no_grad():
    output_teacher = teacher_model.forward(**data)

[32m| 0 @ 13.74% | 1 @ 26.67% | 2 @ 15.98% | 3 @ 7.74% | 4 @ 7.76% | 5 @ 58.09% | 6 @ 8.79% | 7 @ 7.77% |[0m


In [21]:
%%vram
torch.cuda.empty_cache()

[32m| 0 @ 17.01% | 1 @ 26.67% | 2 @ 15.98% | 3 @ 7.74% | 4 @ 7.76% | 5 @ 51.49% | 6 @ 8.79% | 7 @ 7.77% |[0m


In [22]:
%%vram
probs_teacher = F.softmax(output_teacher['logits'], dim=-1, dtype=torch.float32)

[32m| 0 @ 17.01% | 1 @ 26.67% | 2 @ 15.98% | 3 @ 7.74% | 4 @ 7.76% | 5 @ 55.89% | 6 @ 8.79% | 7 @ 7.77% |[0m


In [23]:
# inf_mask = torch.isinf(output.logits)

`kd_loss`:
1. compute KL divergence pointwisely  (`reduction="none"`)
2. mask the input part  (`*mask`)
3. keep only output tokens and compute batchmean on them:
    - mean on output tokens (`/ mask.sum(1, keepdim=True)`)
    - batchmean (mathematically correct) on the batch (`.view(len(input),-1).sum(-1).mean()`)

In [24]:
kd_loss = lambda input, target, mask: (F.kl_div(input.log()*mask, target*mask, reduction="none") / mask.sum(1, keepdim=True)).view(len(input),-1).sum(-1).mean()

In [25]:
%%vram
# loss_kd = F.kl_div(probs.log(), probs_teacher, reduction="none")
loss_kd = kd_loss(probs, probs_teacher, loss_mask)

[32m| 0 @ 17.01% | 1 @ 26.67% | 2 @ 15.98% | 3 @ 7.74% | 4 @ 7.76% | 5 @ 77.87% | 6 @ 8.80% | 7 @ 7.77% |[0m


In [27]:
loss_kd

tensor(6.7298, device='cuda:5', grad_fn=<MeanBackward0>)

In [31]:
%%vram
torch.cuda.empty_cache()

[32m| 0 @ 17.01% | 1 @ 26.67% | 2 @ 15.99% | 3 @ 7.75% | 4 @ 7.76% | 5 @ 60.29% | 6 @ 8.80% | 7 @ 8.82% |[0m


In [32]:
%%vram
loss_kd.backward()

[32m| 0 @ 17.53% | 1 @ 26.67% | 2 @ 15.99% | 3 @ 8.79% | 4 @ 7.76% | 5 @ 69.12% | 6 @ 8.80% | 7 @ 8.82% |[0m


In [33]:
%%vram
optimizer.step()

[32m| 0 @ 17.53% | 1 @ 26.67% | 2 @ 15.99% | 3 @ 8.79% | 4 @ 7.76% | 5 @ 69.12% | 6 @ 8.80% | 7 @ 8.82% |[0m


In [34]:
%%vram
optimizer.zero_grad()

[32m| 0 @ 17.53% | 1 @ 26.67% | 2 @ 15.99% | 3 @ 8.79% | 4 @ 7.76% | 5 @ 69.12% | 6 @ 8.80% | 7 @ 8.82% |[0m


In [35]:
%%vram
torch.cuda.empty_cache()

[32m| 0 @ 17.53% | 1 @ 26.67% | 2 @ 15.99% | 3 @ 8.79% | 4 @ 7.76% | 5 @ 47.87% | 6 @ 8.80% | 7 @ 8.82% |[0m
