In [3]:
import os
import torch
import sklearn
import numpy as np
import pandas as pd
import torch.nn.functional as F
import json
import gc

from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch import nn
from torch.optim import AdamW
from lightning import Fabric

from huggingface_hub import login
from transformers import get_scheduler, pipeline, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, Trainer, EarlyStoppingCallback
from peft import get_peft_model, LoraConfig, TaskType, PeftModelForCausalLM

from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split

from pprint import pprint
from tqdm import tqdm
# import warnings
# warnings.simplefilter("ignore", UserWarning)

# Render inline matplotlib figures in SVG format.
%config InlineBackend.figure_formats = ['svg']

In [4]:
# Read the train/test dataset file as UTF-8 text and parse the JSON into `data`.
with open('../datasets/train_test_data.json', mode='r', encoding='utf-8') as f:
    data = json.loads(f.read())

In [6]:
# Select the "medium" float32 matmul precision setting globally.
torch.set_float32_matmul_precision("medium")

# One CUDA device, bf16 mixed precision.
fabric = Fabric(
    accelerator="cuda",
    devices=1,
    precision="bf16-mixed",
)

# Keep a handle on the device Fabric selected so later cells can use it.
device = fabric.device
fabric.launch()

Using bfloat16 Automatic Mixed Precision (AMP)


In [None]:
# The teacher checkpoint is currently disabled (commented out); only the
# student model is loaded in this cell.
# teacher_model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
student_model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

# 4-bit NF4 quantization with double quantization; compute dtype is bfloat16.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# LoRA settings targeting the modules named "q_proj" and "k_proj".
# NOTE: this config is only defined here; it is not applied to the model in this cell.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.2,
    bias="none",
    target_modules=["q_proj", "k_proj"],
)

# teacher = AutoModelForCausalLM.from_pretrained(teacher_model_name, device_map="auto", quantization_config=quant_config,).eval()
student_model = AutoModelForCausalLM.from_pretrained(
    student_model_name,
    quantization_config=quant_config,
    device_map="auto",
)
student_model.train()  # switch the loaded model to training mode

tokenizer = AutoTokenizer.from_pretrained(student_model_name)

# Use the EOS token as the padding token, and pad on the right-hand side.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"