In [1]:
from accelerate.utils import BnbQuantizationConfig
from accelerate import Accelerator, notebook_launcher
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, \
                        get_cosine_schedule_with_warmup, set_seed
import transformers
import optimum

from datasets import load_dataset,Dataset
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, \
                        get_cosine_schedule_with_warmup, set_seed
from peft import LoraConfig, TaskType, get_peft_model
from accelerate import Accelerator, notebook_launcher
from accelerate.utils import set_seed
import logging
from torch.utils.data import DataLoader
from torch.optim import AdamW, SGD
from tqdm.notebook import tqdm
import torch
from torch.nn.utils.rnn import pad_sequence
import glob
from collections import OrderedDict
import re

import os

In [2]:
# Wall-clock reference taken when the notebook starts running.
# train_epoch compares against it to stop training after 6 hours.
import datetime
start_time = datetime.datetime.now()

In [3]:
def truncate_txt(text, length):
    """Limit `text` to at most `length` whitespace-separated words.

    Args:
        text: The string to shorten.
        length: Maximum number of words to keep.

    Returns:
        `text` itself (original whitespace untouched) when it already has
        `length` words or fewer; otherwise the first `length` words joined
        by single spaces.
    """
    words = text.split()
    exceeds_limit = len(words) > length
    return " ".join(words[:length]) if exceeds_limit else text


def gen_prompt(og_text, rewritten_text, max_words=170):
    """Build the instruction prompt asking a model to recover the rewrite prompt.

    Args:
        og_text: The original essay text.
        rewritten_text: The rewritten version of the essay.
        max_words: Maximum number of whitespace-separated words kept from each
            text before it is embedded in the prompt (defaults to 170).

    Returns:
        The prompt string with leading/trailing whitespace stripped. Interior
        lines keep the four-space indentation of the template.
    """
    # Cap both texts at `max_words` words to keep the prompt short;
    # longer prompts caused memory issues on Mixtral8x7b.
    og_text = truncate_txt(og_text, max_words)
    rewritten_text = truncate_txt(rewritten_text, max_words)
    
    return f"""    
    Original Text:
    \"""{og_text}\"""
    
    Rewritten Text:
    \"""{rewritten_text}\"""
    
    You are given 2 essays, the Rewritten text was created from the Original text using the google Gemma model.
    Analyzing the changes in style, theme, etc., please come up with a prompt that must have been used to guide the transformation from the original to the rewritten text.
    Start directly with the prompt, output should be one line only.
    """.strip()

In [4]:
import pandas as pd
from sklearn.model_selection import train_test_split

# Number of rows kept for this (overfitting) experiment.
N_SAMPLES = 6

filename1 = "./input/gemma-rewrite-nbroad/nbroad-v1.csv"
filename2 = "./input/gemma-rewrite-nbroad/nbroad-v2.csv"
df = pd.concat([pd.read_csv(filename1), pd.read_csv(filename2)]).reset_index(drop=True)

# Subsample *before* building prompts so gen_prompt only runs on the rows that
# are actually kept; .copy() avoids assigning a column onto a slice view.
df = df.iloc[:N_SAMPLES].copy()
df['reverse_prompt'] = df.apply(lambda x: gen_prompt(x.original_text, x.rewritten_text), axis=1)

# Model input is `reverse_prompt`; the training target is `rewrite_prompt`.
data = Dataset.from_pandas(df[['reverse_prompt', 'rewrite_prompt']],split='train')

In [5]:
# Model selection: exactly one MODEL_PATH assignment is active; the commented
# alternatives are other checkpoints (Hub ids or local Kaggle paths).
# MODEL_PATH is also used as a key into lora_target_modules_dict and as a
# sub-directory of the save path further down the notebook.
# MODEL_PATH = "/kaggle/input/gemma/transformers/7b-it/2"
MODEL_PATH = "distilbert/distilgpt2"
# MODEL_PATH = "gpt2"
# MODEL_PATH = "distilbert/distilroberta-base"
# MODEL_PATH = 'mistralai/Mistral-7B-Instruct-v0.2'
# MODEL_PATH = "google/gemma-2b-it"
# MODEL_PATH = "google/flan-t5-small"
# MODEL_PATH = "/kaggle/input/mistral/pytorch/7b-instruct-v0.1-hf/1"
# MODEL_PATH = "/kaggle/input/mixtral/pytorch/8x7b-instruct-v0.1-hf/1"
# MODEL_PATH = "/kaggle/input/llama-2/pytorch/7b-chat-hf/1"
# MODEL_PATH = "/kaggle/input/llama-2/pytorch/13b-chat-hf/1"

# Tokenizer matching the selected model; shared (as a global) by
# tokenize_samples, collate_fn and main. A pad token is assigned in a later cell.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
def tokenize_samples(samples):
    """Tokenize the prompt and target columns of one sample or a batch of samples.

    Args:
        samples: Mapping with "reverse_prompt" (model input text) and
            "rewrite_prompt" (target text) entries.

    Returns:
        Dict with 'input_ids' and 'attention_mask' from the tokenized
        "reverse_prompt", and 'labels' holding the token ids of the tokenized
        "rewrite_prompt". Both sides are truncated to 512 tokens.
    """
    tokenize_kwargs = dict(max_length=512, truncation=True)
    prompt_enc = tokenizer(samples["reverse_prompt"], **tokenize_kwargs)
    target_enc = tokenizer(samples["rewrite_prompt"], **tokenize_kwargs)
    return {
        'input_ids': prompt_enc['input_ids'],
        'attention_mask': prompt_enc['attention_mask'],
        'labels': target_enc['input_ids'],
    }
    
# Tokenize the whole dataset in batches; adds the 'input_ids', 'attention_mask'
# and 'labels' entries returned by tokenize_samples to every row.
data = data.map(tokenize_samples, batched=True)

Map:   0%|          | 0/6 [00:00<?, ? examples/s]

In [6]:
# Sanity check: tokenize the first row and decode its labels next to the raw
# rewrite_prompt; the two displayed strings should be identical.
encoded = tokenize_samples(df.iloc[0])
tokenizer.decode(encoded['labels']), df.iloc[0].rewrite_prompt

('Regency Romance: Model the text on a Regency romance novel, focusing on social gatherings, romantic pursuits, and the strict manners of the era.',
 'Regency Romance: Model the text on a Regency romance novel, focusing on social gatherings, romantic pursuits, and the strict manners of the era.')

In [7]:
def collate_fn(batch):
    """Pad a list of tokenized samples into a single batch of tensors.

    Args:
        batch: List of dicts, each with 'input_ids' and 'labels' token-id lists.

    Returns:
        Dict of LongTensors, all shaped (len(batch), longest input length):
          * 'input_ids': inputs right-padded with ``tokenizer.pad_token_id``.
          * 'attention_mask': 1 on real input tokens, 0 on padding.
          * 'labels': label ids right-padded with -100 (the default
            ``ignore_index`` of PyTorch's cross-entropy loss). Labels longer
            than the longest input are truncated to that length so the
            labels always match the shape of ``input_ids``.
    """
    inputs = [torch.tensor(b['input_ids']) for b in batch]
    labels = [torch.tensor(b['labels']) for b in batch]

    max_length = max(len(input_) for input_ in inputs)

    input_ids = pad_sequence(inputs, batch_first=True, padding_value=tokenizer.pad_token_id)

    padded_labels = torch.full((len(batch), max_length), fill_value=-100, dtype=torch.long)
    for row, label in enumerate(labels):
        # Truncate instead of raising when a target is longer than every input
        # in the batch.
        kept = label[:max_length]
        padded_labels[row, :len(kept)] = kept

    # Build the mask from the true sequence lengths rather than by comparing
    # against pad_token_id: this notebook reuses EOS as the pad token, so a
    # value comparison would also mask genuine EOS tokens inside the input.
    attention_mask = torch.zeros(input_ids.shape, dtype=torch.long)
    for row, input_ in enumerate(inputs):
        attention_mask[row, :len(input_)] = 1

    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': padded_labels}


In [8]:
# Debug check: prints the function repr to confirm collate_fn is defined in this kernel.
print(collate_fn)

<function collate_fn at 0x76851d8a3370>


In [9]:
# import torch
# from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig


# # Found a good blog to catch me up fast!
# # https://huggingface.co/blog/4bit-transformers-bitsandbytes
# # https://huggingface.co/docs/transformers/v4.38.1/en/quantization#compute-data-type
# quantization_config = BitsAndBytesConfig(
#     load_in_4bit = True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16,
#     bnb_4bit_use_double_quant=True,
# )



# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_PATH,
#     device_map = "auto",
#     trust_remote_code = True,
#     quantization_config=quantization_config,
# )

# # model = model.to_bettertransformer()
# model = accelerator.prepare(model)

In [10]:
def del_past_models(save_path, file_exten='pth'):
    """Delete every previously saved model file found directly in `save_path`.

    Args:
        save_path: Directory that holds the saved model files.
        file_exten: Extension (without the dot) of the files to delete;
            pass another value if models are saved with a different extension.

    Each deleted file is reported through the module-level `logger`.
    """
    pattern = os.path.join(save_path, f'*.{file_exten}')
    for stale_path in glob.glob(pattern):
        os.remove(stale_path)
        logger.info(f'Remove model {stale_path}!')
        
def save_checkpoint(path, model, optim, sched, epoch, iters):
    """Write the training state to `path` as two identical checkpoint files.

    Only parameters with ``requires_grad`` set are stored for the model, next
    to the optimizer state, the scheduler state and the epoch number.

    Args:
        path: Directory receiving the files.
        model: Model whose trainable parameters are saved.
        optim: Optimizer; its state dict is saved and its first param group's
            learning rate is logged.
        sched: Learning-rate scheduler; its state dict is saved.
        epoch: Epoch number recorded in the checkpoint.
        iters: Step counter; the numbered file is ``checkpoint_{iters + 1}.pth``.

    Files written: ``checkpoint_{iters + 1}.pth`` and ``checkpoint.pth``.
    """
    current_lr = optim.param_groups[0]['lr']

    # Keep only the trainable parameters instead of the full state_dict.
    trainable_params = OrderedDict()
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_params[name] = param

    numbered_name = f'checkpoint_{iters + 1}.pth'
    logger.info(f"Model of epoch {epoch} saved at {numbered_name}, lr={current_lr:.3e}")

    # Same payload goes to the numbered file and to the fixed-name file.
    for file_name in (numbered_name, 'checkpoint.pth'):
        payload = {
            'model': trainable_params,
            'optimizer': optim.state_dict(),
            'scheduler': sched.state_dict(),
            'epoch': epoch,
        }
        torch.save(payload, os.path.join(path, file_name))

    
def load_checkpoint(checkpoint_path, model, optim=None, sched=None):
    """Restore training state from a checkpoint file written by save_checkpoint.

    Args:
        checkpoint_path: Path of the checkpoint file.
        model: Model receiving the saved parameters. Loading is non-strict
            because the checkpoint only holds the trainable parameters.
        optim: Optional optimizer to restore.
        sched: Optional scheduler to restore.

    Returns:
        ``(model, optim, sched, epoch)`` when both `optim` and `sched` are
        given; otherwise ``(model, epoch)`` and neither optimizer nor
        scheduler state is touched.
    """
    # Deserialize onto the CPU so a checkpoint written on a GPU can be resumed
    # on a machine without that device; the load_state_dict calls below copy
    # the values onto whatever device the model/optimizer already live on.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    logger.info(f"Model of epoch {checkpoint['epoch']} is loaded")
    
    model.load_state_dict(checkpoint['model'], strict=False)
    if optim is not None and sched is not None:
        optim.load_state_dict(checkpoint['optimizer'])
        sched.load_state_dict(checkpoint['scheduler'])
        return model, optim, sched, checkpoint['epoch']
    else:
        return model, checkpoint['epoch']

In [11]:
def train_epoch(epoch, model, accelerator, 
                train_dataloader, checkpointing_steps, 
                optimizer, lr_scheduler, save_path):
    """Run one training pass over `train_dataloader`.

    Reads and updates the module-level `overall_step` counter (one increment
    per batch) and reads the module-level `start_time` to stop early once six
    hours have elapsed since the notebook started.

    Args:
        epoch: Index of the epoch being run (used for logging/checkpoints).
        model: Model as returned by ``accelerator.prepare``.
        accelerator: The ``Accelerator`` driving accumulation, backward,
            gradient clipping and cross-process gathering.
        train_dataloader: Iterable of dict batches that are passed to the
            model as keyword arguments and contain an 'input_ids' tensor.
        checkpointing_steps: A checkpoint is written every time
            `overall_step` is a multiple of this value.
        optimizer: Optimizer to step.
        lr_scheduler: Learning-rate scheduler to step.
        save_path: Directory where checkpoints are written.
    """
    global overall_step
    model.train()
    epoch_loss = []
    pbar = tqdm(train_dataloader)
    grad = torch.tensor(0.0)
    
    # Checkpoint at the start of the epoch. Only the local main process
    # writes, matching the other save_checkpoint call sites and avoiding
    # several processes writing the same file at once.
    if accelerator.is_local_main_process:
        unwrapped_model = accelerator.unwrap_model(model)
        save_checkpoint(save_path, unwrapped_model, optimizer, lr_scheduler, 
                       epoch, overall_step)
    
    for step, batch in enumerate(pbar):
        # Wall-clock budget: stop after six hours since `start_time`.
        if (datetime.datetime.now() - start_time) > datetime.timedelta(hours=6):
            break
        with accelerator.accumulate(model):
            # Gradient accumulation
            # with accelerator.autocast():
            outputs = model(**batch)
            loss = outputs.loss
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                grad = accelerator.clip_grad_norm_(parameters=model.parameters(), 
                                                   max_norm=2.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        overall_step += 1
        lr = optimizer.param_groups[0]['lr']
        
        pbar.set_description(
            f'Epoch {epoch}: loss = {loss.item(): .3f}, grad = {grad.item(): .3f}, lr = {lr: .3e}')
        
        with torch.no_grad():
            # Repeat the scalar loss once per sample so the gathered mean is
            # weighted by batch size. `len(batch)` would be the number of
            # dict keys, not the number of samples.
            batch_size = batch['input_ids'].shape[0]
            avg_loss = accelerator.gather(loss.repeat(batch_size)).mean()
        epoch_loss.append(avg_loss.item() / accelerator.gradient_accumulation_steps)
        
        accelerator.wait_for_everyone()
        if overall_step % checkpointing_steps == 0:
            logger.info(
                f'Epoch {epoch}: loss = {loss.item(): .3f}, grad = {grad.item(): .3f}, lr = {lr: .3e}')
            if accelerator.is_local_main_process:
                # Clear all of the current models
                del_past_models(save_path)

                unwrapped_model = accelerator.unwrap_model(model)
                save_checkpoint(save_path, unwrapped_model, optimizer, lr_scheduler, 
                               epoch, overall_step)
    
    # Just log the loss of the main process.
    # `lr` is only bound once a batch has run, which the guard guarantees.
    if len(epoch_loss) > 0:
        logger.info(f'Epoch {epoch}: loss = {sum(epoch_loss) / len(epoch_loss): .3f}, lr = {lr: .3e}')
    
            
def main(batch_size: int, num_epochs: int, lr: float, grad_accumulation_steps: int, 
         checkpointing_steps: int, save_path: str, ckpt_path: str,
         num_warmup_steps: int=0, r: int=4, lora_alpha: int=32, lora_dropout: float=0.1):
    """Fine-tune MODEL_PATH with LoRA on the module-level `data` dataset.

    Loads the model 4-bit quantized, wraps it with a LoRA adapter, optionally
    resumes from `ckpt_path`, trains for the remaining epochs and writes a
    final checkpoint. Uses the module-level `tokenizer`, `data`, `collate_fn`,
    `lora_target_modules_dict` and `logger`, and sets the module-level
    `overall_step` counter.

    Args:
        batch_size: Per-process batch size of the training dataloader.
        num_epochs: Total number of epochs (including already-trained ones
            when resuming).
        lr: Base learning rate; it is multiplied by the number of processes
            and by the gradient-accumulation steps before use.
        grad_accumulation_steps: Gradient-accumulation steps for Accelerate.
        checkpointing_steps: Checkpoint interval, in steps.
        save_path: Directory for the saved base model, tokenizer and checkpoints.
        ckpt_path: Checkpoint file to resume from; ignored if it is not a file.
        num_warmup_steps: Warmup steps of the cosine schedule.
        r: LoRA rank.
        lora_alpha: LoRA alpha.
        lora_dropout: LoRA dropout.

    Raises:
        ValueError: If the resumed epoch count already reaches `num_epochs`.
    """
    set_seed(1234)
    
    accelerator = Accelerator(gradient_accumulation_steps=grad_accumulation_steps)
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True, 
        # bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        # bnb_4bit_use_double_quant=True,
    )
    
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, 
                                                 quantization_config=quantization_config, 
                                                 torch_dtype=torch.float16)
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    
    
    accelerator.print(model)
    
    # Instantiate dataloaders. (We do not split the test data)
    train_dataloader = DataLoader(
        data, shuffle=True, collate_fn=collate_fn, batch_size=batch_size
    )
    
    # Per-model LoRA target modules, with a default list of projection layer
    # names for models that have no entry in the dict.
    peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, 
                             inference_mode=False, r=r, 
                             lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                             target_modules=
                            lora_target_modules_dict.get(MODEL_PATH, ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],)
                            )
    model = get_peft_model(model, peft_config)
    
    if accelerator.is_local_main_process:
        model.print_trainable_parameters()
    
    lr = lr * accelerator.num_processes * accelerator.gradient_accumulation_steps
    
    optimizer = AdamW(params=model.parameters(), lr=lr)
    
    total_steps = len(train_dataloader) * num_epochs * \
            accelerator.num_processes * accelerator.gradient_accumulation_steps
    
    # Instantiate scheduler
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        # to ensure the lr will not become zero at the end of training
        # (You can adjust this param)
        num_training_steps=total_steps * 1.1,
    )
    
    current_epochs = 0
    global overall_step
    overall_step = 0
    
    # Load checkpoint
    print(ckpt_path)
    if os.path.isfile(ckpt_path):
        print('loading')
        model, optimizer, lr_scheduler, current_epochs = \
                    load_checkpoint(ckpt_path, model, optimizer, lr_scheduler)
        # Recover the step counter from the checkpoint *file name* only.
        # Searching the whole path would pick up digits from directory names
        # (e.g. the "2" in ".../distilgpt2/checkpoint.pth").
        c = re.search(r'\d+', os.path.basename(ckpt_path))
        if c is None: overall_step = 1
        else: overall_step = int(c.group())
        logger.info(
            f'Checkpoint {ckpt_path} loaded at epoch {current_epochs}, the training will resume from epoch {current_epochs + 1}!')
        current_epochs += 1
    
    
    if current_epochs >= num_epochs:
        raise ValueError('The num_epochs should be larger than the saved epochs!!')
    
    # Prepare everything
    # There is no specific order to remember, 
    # we just need to unpack the objects in the same order we gave them to the prepare method.
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )
    
    
    logger.info('*********************** Start training! **************************')
    
    for epoch in range(current_epochs, num_epochs):
        train_epoch(epoch, model, accelerator, train_dataloader, 
                    checkpointing_steps, optimizer, lr_scheduler, save_path)
        
    # Save the final model.
    # `epoch` is always bound here: the ValueError above guarantees the loop
    # ran at least once.
    accelerator.wait_for_everyone()
    if accelerator.is_local_main_process:
        unwrapped_model = accelerator.unwrap_model(model)
        save_checkpoint(save_path, unwrapped_model, optimizer, lr_scheduler, 
                       epoch, overall_step)

In [12]:
# LoRA target modules per model id; main() looks MODEL_PATH up here and falls
# back to its own default list for models without an entry.
lora_target_modules_dict = {
  'gpt2': ['c_attn'],
  'distilbert/distilgpt2': ['c_attn'],
  'distilbert/distilroberta-base': ['query', 'key', 'value'],
}
import json
os.makedirs('./settings', exist_ok=True)
# Persist the mapping; the context manager closes (and flushes) the file
# instead of leaving the handle open until garbage collection.
with open('./settings/lora_target_modules.json', 'w') as settings_file:
    json.dump(lora_target_modules_dict, settings_file)

In [13]:
# Fall back to the EOS token for padding when the tokenizer defines no pad
# token; collate_fn pads input_ids with tokenizer.pad_token_id.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
print(tokenizer.pad_token_id)

50256


In [14]:
import logging
import os
import datetime
os.makedirs('./logs/', exist_ok=True)
# One log file per run, named with a yymmddHHMMSS timestamp.
# `%M` is minutes; the lowercase `%m` is the month and would repeat it in
# place of the minutes.
logging.basicConfig(
  level=logging.INFO,
  filename='./logs/log_%s.txt' % datetime.datetime.now().strftime('%y%m%d%H%M%S'), filemode='a',
  datefmt='%H:%M:%S',
  format='%(asctime)s - %(levelname)s - %(message)s',
)

# Module-level logger used by the checkpoint and training helpers.
logger = logging.getLogger(__name__)
# basicConfig above is given only a file name, so records go to the log file;
# no stream handler is added here. The message below is a smoke test.
logger.info("abc")

In [15]:
import jupyter_capture_output

Jupyter Capture Output v0.0.11


In [None]:
%%capture_text --path "cap/overfitting.txt"

import os

batch_size = 2
grad_accumulation_steps = 1
num_epochs = 600
lr = 5e-5
checkpointing_steps = 500
save_path = os.path.join('./working/trained_models/', MODEL_PATH)
r = 64
lora_alpha = 64
lora_dropout = 0.05

# If ckpt_path is a real path (os.path.isfile(ckpt_path) is True),
# then the checkpoint will be loaded
ckpt_path = os.path.join(save_path, 'checkpoint.pth')
print(ckpt_path)
kwargs = {
'batch_size':batch_size, 'num_epochs':num_epochs, 'lr':lr, 'grad_accumulation_steps':grad_accumulation_steps, 
'checkpointing_steps':checkpointing_steps, 'save_path':save_path, 'ckpt_path':ckpt_path, 'r':r, 'lora_alpha':lora_alpha, 'lora_dropout':lora_dropout
}

if not os.path.exists(save_path):
    os.makedirs(save_path, exist_ok=True)
# notebook_launcher(main, args, num_processes=1)
main(**kwargs)