In [1]:
from data_module import TextDataset
from dataloader import compute_loss, custom_data_collator_forget, compute_idk_loss
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed, BitsAndBytesConfig
import transformers
import os
from peft import LoraConfig, get_peft_model
import wandb
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import get_scheduler
import torch.nn.functional as F
from tqdm import tqdm
from accelerate import Accelerator
import copy
from torch.cuda.amp import GradScaler, autocast

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def train_loop(peft_model, base_model, dataloader, optimizer, scheduler, num_epochs, gradient_accumulation_steps, save_dir, start_epoch):


    peft_model.train()
    base_model.eval()


    for epoch in range(start_epoch, num_epochs):
        epoch_loss = 0.0
        progress_bar = tqdm(
            dataloader,
            desc=f"Epoch {epoch + 1}/{num_epochs}",
        )
        optimizer.zero_grad()

        for step, batch in enumerate(progress_bar):
            
            inputs = batch


            # Compute custom loss
            loss = compute_loss(peft_model, inputs, base_model)
            scaled_loss = loss / gradient_accumulation_steps  # Scale loss for accumulation

            scaled_loss.backward()

            # Perform optimizer and scheduler step
            if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(dataloader):
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()


            epoch_loss += loss.item()
            progress_bar.set_postfix(loss=epoch_loss / len(dataloader))
        torch.save(
        {
            "epoch": epoch + 1,
            "model_state_dict": peft_model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "loss": epoch_loss,
        },
        os.path.join(save_dir, f"checkpoint_epoch_{epoch + 1}.pt")
        )
        peft_model.save_pretrained(f"results/peft_checkpoint_{epoch+1}/")
        print(f"Epoch {epoch + 1} Loss: {epoch_loss / len(dataloader)}")

In [None]:
seed_num = 42
set_seed(seed_num)

batch_size = 2
gradient_accumulation_steps = 4
num_epochs = 5
max_length = 256
data_path = 'data/'
lr = 1e-4
weight_decay = 0.01
save_dir = 'results/'
start_epoch = 0

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="float16", bnb_4bit_use_double_quant=True
)
model = AutoModelForCausalLM.from_pretrained('PKU-Alignment/alpaca-7b-reproduced', quantization_config=bnb_config, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained('PKU-Alignment/alpaca-7b-reproduced')


Loading checkpoint shards: 100%|██████████| 7/7 [01:20<00:00, 11.53s/it]


In [4]:
base_model = AutoModelForCausalLM.from_pretrained('PKU-Alignment/alpaca-7b-reproduced', quantization_config=bnb_config, device_map='auto')

Loading checkpoint shards: 100%|██████████| 7/7 [00:03<00:00,  1.82it/s]


In [5]:
torch_format_dataset = TextDataset(data_path, tokenizer=tokenizer, max_length=max_length)
dataloader = DataLoader(torch_format_dataset, batch_size=batch_size, collate_fn=custom_data_collator_forget)


peft_config = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM"
)

peft_model = get_peft_model(model, peft_config)
peft_model.print_trainable_parameters()

# Define optimizer and scheduler
optimizer = AdamW(peft_model.parameters(), lr=lr, weight_decay=weight_decay)
num_training_steps = num_epochs * (len(torch_format_dataset) // batch_size)
scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

trainable params: 8,388,608 || all params: 6,746,812,416 || trainable%: 0.1243


In [6]:
print('Loading saved weights')
desired_epoch = 3
checkpoint = torch.load(os.path.join(save_dir, f"checkpoint_epoch_{desired_epoch}.pt"))
peft_model.load_state_dict(checkpoint["model_state_dict"], strict=False)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
start_epoch = desired_epoch

Loading saved weights


  checkpoint = torch.load(os.path.join(save_dir, f"checkpoint_epoch_{desired_epoch}.pt"))


In [7]:
print('Starting training')
train_loop(
        model=peft_model,
        base_model=base_model,
        dataloader=dataloader,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=num_epochs,
        gradient_accumulation_steps=gradient_accumulation_steps,
        save_dir = save_dir,
        start_epoch = start_epoch
    )

Starting training


Epoch 4/5: 100%|██████████| 2400/2400 [1:56:59<00:00,  2.92s/it, loss=0.143]   


Epoch 4 Loss: 0.14339224822634908


Epoch 5/5: 100%|██████████| 2400/2400 [1:56:54<00:00,  2.92s/it, loss=0.123]   


Epoch 5 Loss: 0.12344643197832435


In [8]:
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained(save_dir+"final_result/final_merged_model")
tokenizer.save_pretrained(save_dir+"final_result/final_merged_model")



('/home/hice1/hkhanuja3/scratch/alpaca_accelerate/altpo_loss/wr_1/final_result/final_merged_model/tokenizer_config.json',
 '/home/hice1/hkhanuja3/scratch/alpaca_accelerate/altpo_loss/wr_1/final_result/final_merged_model/special_tokens_map.json',
 '/home/hice1/hkhanuja3/scratch/alpaca_accelerate/altpo_loss/wr_1/final_result/final_merged_model/tokenizer.json')