In [1]:
import os
import torch
import warnings
from datasets import load_dataset
from dotenv import load_dotenv
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, 
    DataCollatorForLanguageModeling, BitsAndBytesConfig
)
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model, PeftModel
from tqdm import tqdm

In [2]:
# Suppress Warnings
warnings.filterwarnings("ignore")

# Load Environment Variables
env_path = "/media/volume/LegalEase/Repos/CPSC5830-Team1/.env"
load_dotenv(env_path)
HF_READ_TOKEN = os.getenv("BENS_HUGGING_FACE_READ_TOKEN")
HF_WRITE_TOKEN = os.getenv("BENS_HUGGING_FACE_WRITE_TOKEN")
# Never print the token values themselves: cell outputs are stored in the
# notebook file and would leak the credentials to anyone who can read it.
# Only report whether each token was found in the environment.
print(f"Read Token: {'loaded' if HF_READ_TOKEN else 'MISSING'}")
print(f"Write Token: {'loaded' if HF_WRITE_TOKEN else 'MISSING'}")

Read Token: [REDACTED]
Write Token: [REDACTED]


In [3]:
# Device Setup
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
# Release any cached CUDA allocations before the model is loaded.
torch.cuda.empty_cache()
print(f"Using device: {device}")

Using device: cuda


In [4]:
# Hugging Face Hub id of the base checkpoint to fine-tune; this cell only
# records the name, the model and tokenizer are loaded from it in later cells.
model_name = "mistralai/Mistral-7B-Instruct-v0.2"

In [5]:
# bitsandbytes settings used when loading the base model: 4-bit weights in the
# "nf4" quantization type with double quantization, computing in bfloat16.
quantization_settings = {
    "load_in_4bit": True,
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": torch.bfloat16,
}
bnb_config = BitsAndBytesConfig(**quantization_settings)

In [6]:
# Load the base causal LM with the quantization config defined above.
# `device_map="auto"` leaves device placement to the library; the read token
# authenticates the download and weights are cached on the shared volume.
base_model_kwargs = {
    "quantization_config": bnb_config,
    "device_map": "auto",
    "token": HF_READ_TOKEN,
    "cache_dir": "/media/volume/LegalEaseMaxim/cache",
}
model = AutoModelForCausalLM.from_pretrained(model_name, **base_model_kwargs)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [7]:
# Enable Gradient Checkpointing
# Turn on gradient checkpointing on the loaded model, then run PEFT's k-bit
# training preparation helper on it; the helper's return value rebinds `model`.
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

In [8]:
# Load Tokenizer & Set Padding
# Pass the same read token used for the model download so the tokenizer can be
# fetched from the same repository when it requires authentication.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    token=HF_READ_TOKEN,
    cache_dir="/media/volume/LegalEaseMaxim/cache",
)
# Reuse the end-of-sequence token as the padding token and pad on the left.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

In [9]:
# Load Dataset
# A single local JSON Lines file is registered under the "train" split name;
# the validation subset is split off from it in a later cell.
dataset_path = {"train": "my_dataset1.jsonl"}
data = load_dataset("json", data_files=dataset_path)

In [10]:
# Hold out 10% of the examples for validation. The seed is fixed so that the
# split (and therefore the reported validation loss) is reproducible across
# kernel restarts instead of changing on every run.
split_data = data["train"].train_test_split(test_size=0.1, seed=42)
print(f"Training Data: {len(split_data['train'])}")
print(f"Validation Data: {len(split_data['test'])}")

Training Data: 558
Validation Data: 62


In [11]:
def format_prompt(example):
    """Flatten one chat-style example into a single training string.

    Args:
        example: Mapping with a "messages" list; every message is a mapping
            with "role" and "content" entries.

    Returns:
        A dict with one key, "formatted_text". System turns are wrapped in
        [SYSTEM] ... [/SYSTEM], user turns in [INST] ... [/INST], assistant
        turns are emitted as-is; each rendered turn ends with a newline.
        Messages with any other role contribute nothing.
    """
    # One template per recognised role; "{}" is where the content goes.
    templates = {
        "system": "[SYSTEM] {} [/SYSTEM]\n",
        "user": "[INST] {} [/INST]\n",
        "assistant": "{}\n",
    }
    pieces = []
    for turn in example["messages"]:
        template = templates.get(turn["role"])
        text = turn["content"]
        if template is not None:
            pieces.append(template.format(text))
    return {"formatted_text": "".join(pieces)}


In [12]:
# Apply `format_prompt` to every example in both splits, then drop the raw
# "messages" column now that its content lives in "formatted_text".
formatted_data = split_data.map(format_prompt).remove_columns(["messages"])

Map:   0%|          | 0/558 [00:00<?, ? examples/s]

Map:   0%|          | 0/62 [00:00<?, ? examples/s]

In [13]:
# Tokenization Function
def tokenize_function(examples):
    """Tokenize the "formatted_text" field with the notebook-level tokenizer.

    Every sequence is truncated and padded to exactly 1024 tokens, with the
    tokenizer's special tokens added.

    Args:
        examples: Mapping that provides the text(s) under "formatted_text".

    Returns:
        Whatever the tokenizer call returns for those texts.
    """
    encode_options = dict(
        truncation=True,
        padding="max_length",
        max_length=1024,
        add_special_tokens=True,
    )
    texts = examples["formatted_text"]
    return tokenizer(texts, **encode_options)


In [14]:
# Tokenize both splits; with batched=True `tokenize_function` receives a batch
# of rows per call rather than a single example.
tokenized_data = formatted_data.map(tokenize_function, batched=True)

Map:   0%|          | 0/558 [00:00<?, ? examples/s]

Map:   0%|          | 0/62 [00:00<?, ? examples/s]

In [15]:
# LoRA Configuration
# Names of the sub-modules that receive LoRA adapters.
lora_target_modules = ["q_proj", "v_proj", "k_proj", "o_proj"]
lora_settings = {
    "r": 16,
    "lora_alpha": 32,
    "target_modules": lora_target_modules,
    "lora_dropout": 0.1,
    "bias": "none",
    "task_type": "CAUSAL_LM",
}
config = LoraConfig(**lora_settings)

In [16]:
# Wrap the prepared base model with the LoRA adapters described by `config`;
# from here on `model` refers to the PEFT-wrapped model.
model = get_peft_model(model, config)

In [17]:
# Data Collator
# mlm=False selects causal language modeling (no masked-token objective) when
# batches are assembled for the Trainer.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)


In [18]:
# Training Arguments
# The quantized base model computes in bfloat16 (`bnb_4bit_compute_dtype`), so
# mixed-precision training should use bf16 as well whenever the GPU supports
# it. fp16 is kept only as the fallback for hardware without bf16 support.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

training_args = TrainingArguments(
    output_dir="/media/volume/LegalEaseMaxim/output",
    learning_rate=2e-4,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=2,
    num_train_epochs=5,
    weight_decay=0.01,
    # Log, evaluate and checkpoint once per epoch; the strategies must match
    # for `load_best_model_at_end` to pick the best checkpoint afterwards.
    logging_strategy="epoch",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    # Effective batch size per device = 4 (batch) * 4 (accumulation steps).
    gradient_accumulation_steps=4,
    warmup_steps=30,
    bf16=use_bf16,
    fp16=not use_bf16,
    optim="adamw_bnb_8bit"
)


In [19]:
# Trainer Setup
# The "test" split produced by train_test_split serves as the validation set.
train_split = tokenized_data["train"]
eval_split = tokenized_data["test"]

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_split,
    eval_dataset=eval_split,
    data_collator=data_collator,
)


No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [20]:
# Train Model
# The generation cache is switched off for the duration of training and is
# restored in `finally`, so an exception or a manual interrupt during
# training does not leave the model with caching disabled.
model.config.use_cache = False
try:
    trainer.train()
finally:
    model.config.use_cache = True

Epoch,Training Loss,Validation Loss
1,1.1492,0.678636
2,0.616,0.625532
3,0.4935,0.613867
4,0.3844,0.636441
5,0.3059,0.675321


In [21]:
# Fold the LoRA adapter weights into the base model so that `model` is a
# plain model again; anything that is not a PEFT wrapper is left untouched.
if not isinstance(model, PeftModel):
    print("Model is not a PEFT model, skipping merge.")
else:
    model = model.merge_and_unload()
    print("LoRA merged successfully!")


LoRA merged successfully!


In [22]:
# Save Model & Tokenizer
# Both artifacts are written to the same local directory.
for artifact in (model, tokenizer):
    artifact.save_pretrained("./business_llm")

# Move the model off the GPU and release cached CUDA memory.
model.cpu()
torch.cuda.empty_cache()

In [23]:
# Push to Hugging Face Hub
PUBLISH_TO_HUB = True
# Defined outside the `if` because the inference cell further down reads
# `repo_name` even when publishing is switched off.
repo_name = "XCIT3D247/LegalEaseV2"
if PUBLISH_TO_HUB:
    # The write token authenticates both uploads to the target repository.
    model.push_to_hub(repo_name, token=HF_WRITE_TOKEN)
    tokenizer.push_to_hub(repo_name, token=HF_WRITE_TOKEN)
    print(f"Model uploaded to: https://huggingface.co/{repo_name}")

model.safetensors:   0%|          | 0.00/4.65G [00:00<?, ?B/s]

No files have been modified since last commit. Skipping to prevent empty commit.


Model uploaded to: https://huggingface.co/XCIT3D247/LegalEaseV2


In [24]:
# Load Model for Inference
# The checkpoint pushed to `repo_name` already has the LoRA weights merged in
# (`merge_and_unload()` ran before saving and pushing), so it is loaded as a
# plain causal LM. It must not be wrapped in `PeftModel.from_pretrained` again:
# this notebook never uploads a separate adapter, so that call would either
# fail or stack a stale adapter on top of the merged weights.
tokenizer = AutoTokenizer.from_pretrained(repo_name)
model = AutoModelForCausalLM.from_pretrained(repo_name, torch_dtype=torch.bfloat16, device_map="auto")


config.json:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the cpu.
Some parameters are on the meta device because they were offloaded to the cpu.
Some parameters are on the meta device because they were offloaded to the cpu.


In [25]:
# Generate Responses
def generate_response(prompt, max_length=1028, strip_prompt=False):
    """Generate a completion for `prompt` with the notebook-level model.

    Args:
        prompt: Text to feed to the model.
        max_length: Passed to `model.generate` as `max_length`, i.e. the cap
            applies to prompt tokens plus generated tokens together.
        strip_prompt: When True, only the newly generated tokens are decoded.
            The default (False) keeps the original behaviour, where the
            returned text starts with the prompt itself.

    Returns:
        The decoded text with special tokens removed.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model.generate(**inputs, max_length=max_length, pad_token_id=tokenizer.eos_token_id)
    generated = output[0]
    if strip_prompt:
        # The returned sequence begins with the prompt tokens; skip them.
        generated = generated[inputs["input_ids"].shape[1]:]
    return tokenizer.decode(generated, skip_special_tokens=True)


In [26]:
# Test the Model
# Sample questions fed one by one to `generate_response` in the next cell as a
# quick manual check of the fine-tuned model.
test_questions = [
    "What type of business entity should I choose for a tech startup?",
    "What are the tax implications of forming an LLC?",
    "How does Delaware compare to Washington for incorporating a business?"
]

In [27]:
# Print each sample question followed by the model's answer.
# NOTE(review): questions are passed to the model as raw text, without the
# "[INST] ... [/INST]" wrapper that `format_prompt` applies to user turns in
# the training data - confirm whether that mismatch is intended.
for question in test_questions:
    print(f"Q: {question}")
    answer = generate_response(question)
    print(f"A: {answer}\n")

Q: What type of business entity should I choose for a tech startup?
A: What type of business entity should I choose for a tech startup?

The choice of business entity for a tech startup depends on several factors, including your goals, liability concerns, tax implications, and regulatory requirements. Here's a breakdown of the most common options:

1. **Sole Proprietorship:** This is the simplest structure, but it offers the least protection. You are personally liable for all business debts and obligations. It's best for very small, low-risk startups.

2. **Partnership:** If you're starting with one or more co-founders, a partnership might be suitable. You'll need to define the roles and responsibilities of each partner and address issues of profit and loss sharing.  Partnerships can be general (unlimited liability) or limited (limited liability).

3. **Limited Liability Company (LLC):** This is a popular choice for many startups. It offers limited liability protection, meaning your pe