In [1]:
import os
from loguru import logger
from pathlib import Path
from transformers import AutoTokenizer
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
from datasets import load_dataset, load_from_disk
import deepspeed
import vllm
from torch.utils.data import DataLoader

[2023-11-28 09:44:22,747] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [62]:
from pprint import pprint
import GPUtil
import pandas as pd

def gpu_stat():
    print("  ".join([f"{i}: {gpu.memoryUtil*100:.2f}%" for i, gpu in enumerate(GPUtil.getGPUs())]))
    # pprint([": ".join([str(i), str(round(gpu.memoryUtil * 100, 2))]) for i,gpu in enumerate(GPUtil.getGPUs())])

In [3]:
os.environ['CUDA_VISIBLE_DEVICES'] = "4,5,6,7"
os.environ['WORLD_SIZE'] = "4"

In [4]:
DATA_PATH = Path("./data/samsum")
MODEL_PATH = Path("../models/gpt2/base/")
WORK_DIR = Path('results/samsum/gpt2-base-kd')
TEACHER_MODEL_PATH = Path("./results/samsum/gpt2-xlarge-sft/checkpoint-4600/")

In [5]:
dataset = load_from_disk(str(DATA_PATH))
logger.debug(dataset)

prompt_template = """[INST] <<SYS>>
Use the Input to provide a summary of a conversation.
<</SYS>>

Input:
{dialogue}

Summary:
{summary}
"""

logger.debug("Train data example:\n" + prompt_template.format(**dataset['train'][0]))


tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = tokenizer.eos_token_id

logger.debug(f"The padding token id is {tokenizer.pad_token_id}")

CUTOFF_LEN = 256
LABEL_SPLIT = "Summary:\n"

def generate_and_tokenize_prompt(instance, is_test=False):
    def tokenize(prompt, add_eos_token=True):
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=CUTOFF_LEN,
            padding=True,
            return_tensors=None
        )
        if(
            result['input_ids'][-1] != tokenizer.eos_token_id
            and len(result['input_ids']) < CUTOFF_LEN
            and add_eos_token
        ):
            result['input_ids'].append(tokenizer.eos_token_id)
            result['attention_mask'].append(1)
        result['labels'] = result['input_ids'].copy()
        return result
    tokenized_full_prompt = tokenize(prompt_template.format(**instance))
    tokenized_user_prompt = tokenize(prompt_template.format(**instance).split(LABEL_SPLIT)[0] + LABEL_SPLIT, add_eos_token=False)
    user_prompt_len = len(tokenized_user_prompt['input_ids'])
    tokenized_full_prompt['labels'] = [-100]*user_prompt_len + tokenized_full_prompt['labels'][user_prompt_len:]
    if is_test:
        tokenized_user_prompt['_id'] = instance['id']
        return tokenized_user_prompt
    return tokenized_full_prompt

tokenized_dataset = dataset.map(generate_and_tokenize_prompt, num_proc=1) \
                           .remove_columns(dataset['train'].column_names) \
                           .with_format(type='torch')

[32m2023-11-27 22:12:33.334[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [34m[1mDatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 14732
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 819
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 818
    })
})[0m
[32m2023-11-27 22:12:33.335[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m15[0m - [34m[1mTrain data example:
[INST] <<SYS>>
Use the Input to provide a summary of a conversation.
<</SYS>>

Input:
Amanda: I baked  cookies. Do you want some?
Jerry: Sure!
Amanda: I'll bring you tomorrow :-)

Summary:
Amanda baked cookies and will bring Jerry some tomorrow.
[0m
[32m2023-11-27 22:12:33.392[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m22[0m - [34m[1mThe padding token id is 50256[0m
Loading cached pro

In [63]:
gpu_stat()

0: 9.83%  1: 21.73%  2: 0.03%  3: 0.03%  4: 0.03%  5: 0.03%  6: 0.03%  7: 33.04%


In [7]:
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, 
                                             torch_dtype=torch.float16, 
                                             load_in_8bit=False)

In [8]:
model.cuda(7)
model.device

device(type='cuda', index=7)

In [9]:
gpu_stat()

         0     1    2    3    4    5    6    7
mems   6.6  25.7  0.0  0.0  0.0  0.0  0.0  4.4
loads  0.0   0.0  0.0  0.0  0.0  0.0  0.0  6.0


In [10]:
teacher_model = AutoModelForCausalLM.from_pretrained(TEACHER_MODEL_PATH, 
                                             torch_dtype=torch.float16, 
                                             load_in_8bit=False)

In [11]:
teacher_model.cuda(7)
teacher_model.eval()
teacher_model.device

device(type='cuda', index=7)

In [12]:
gpu_stat()

         0     1    2    3    4    5    6     7
mems   6.6  25.7  0.0  0.0  0.0  0.0  0.0  17.0
loads  0.0   0.0  0.0  0.0  0.0  0.0  0.0  10.0


In [13]:
label_pad_token_id = -100
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=label_pad_token_id,
    return_tensors="pt",
    pad_to_multiple_of=8
)

# %%
dataloader = DataLoader(tokenized_dataset['train'], collate_fn=data_collator, batch_size=8, pin_memory=True, pin_memory_device="cuda:7")

In [14]:
data = next(dataloader._get_iterator())

for k, v in data.items():
    data[k] = v.cuda(7)

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [17]:
gpu_stat()

          0      1     2     3     4     5     6      7
mems   6.56  25.66  0.03  0.03  0.03  0.03  0.03  16.99
loads  0.00   0.00  0.00  0.00  0.00  0.00  0.00   0.00


In [18]:
data['input_ids'].shape

torch.Size([8, 256])

In [19]:
optimizer = AdamW(model.parameters(), lr=5e-4)

In [20]:
output = model.forward(**data)

In [21]:
gpu_stat()

          0      1     2     3     4     5     6      7
mems   6.56  25.66  0.03  0.03  0.03  0.03  0.03  26.23
loads  0.00   0.00  0.00  0.00  0.00  0.00  0.00   5.00


In [22]:
with torch.no_grad():
    output_teacher = teacher_model.forward(**data)

In [23]:
gpu_stat()

          0      1     2     3     4     5     6      7
mems   6.56  25.66  0.03  0.03  0.03  0.03  0.03  33.29
loads  0.00   0.00  0.00  0.00  0.00  0.00  0.00   0.00


In [24]:
torch.cuda.empty_cache()

In [25]:
gpu_stat()

          0      1     2     3     4     5     6      7
mems   6.56  13.11  0.03  0.03  0.03  0.03  0.03  31.58
loads  1.00   0.00  0.00  0.00  0.00  0.00  0.00   4.00


In [65]:
# inf_mask = torch.isinf(output.logits)

In [26]:
probs = F.softmax(output.logits/5, dim=-1, dtype=torch.float32)
probs_teacher = F.softmax(output_teacher.logits/5, dim=-1, dtype=torch.float32)

In [27]:
gpu_stat()

          0      1     2     3     4     5     6     7
mems   9.83  17.72  0.03  0.03  0.03  0.03  0.03  35.6
loads  0.00   0.00  0.00  0.00  0.00  0.00  0.00   0.0


In [28]:
loss_kd = F.kl_div(probs.log(), probs_teacher, log_target=False)



In [29]:
gpu_stat()

          0      1     2     3     4     5     6      7
mems   9.83  17.72  0.03  0.03  0.03  0.03  0.03  42.01
loads  0.00   0.00  0.00  0.00  0.00  0.00  0.00   0.00


In [30]:
loss_kd

tensor(5.8731e-06, device='cuda:7', grad_fn=<MeanBackward0>)

In [31]:
torch.cuda.empty_cache()

In [32]:
loss_kd.backward()

In [33]:
gpu_stat()

          0      1     2     3     4     5     6      7
mems   9.83  17.72  0.03  0.03  0.03  0.03  0.03  39.64
loads  0.00   0.00  0.00  0.00  0.00  0.00  0.00   0.00


In [34]:
optimizer.step()

In [35]:
gpu_stat()

          0      1     2     3     4     5     6      7
mems   9.83  17.72  0.03  0.03  0.03  0.03  0.03  39.64
loads  0.00   0.00  0.00  0.00  0.00  0.00  0.00   0.00


In [36]:
optimizer.zero_grad()

In [37]:
gpu_stat()

          0      1     2     3     4     5     6      7
mems   9.83  17.72  0.03  0.03  0.03  0.03  0.03  39.64
loads  0.00   0.00  0.00  0.00  0.00  0.00  0.00   0.00


In [38]:
torch.cuda.empty_cache()

In [42]:
gpu_stat()

[(0, 9.83),
 (1, 17.72),
 (2, 0.03),
 (3, 0.03),
 (4, 0.03),
 (5, 0.03),
 (6, 0.03),
 (7, 33.04)]
