In [11]:
import pandas as pd
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, EarlyStoppingCallback
from trl import SFTTrainer
from peft import LoraConfig
from datasets import Dataset
from trl import SFTConfig
from datasets import Dataset


# Number of domain names requested in the prompt text; the same value caps how
# many suggestions are kept per training example when the dataset is formatted.
DOMAIN_COUNT = 5

In [12]:
# Load the held-out test descriptions.
# The first CSV column is a previously saved pandas index (it otherwise appears
# as an "Unnamed: 0" data column), so read it back as the index.
test_set = pd.read_csv('../data/test_set.csv', index_col=0)
test_set


Unnamed: 0,description,suggestions
0,Artisan bakery specializing in sourdough and s...,[]
1,Children's educational gaming platform with in...,[]
2,Boutique law firm specializing in contractual ...,[]
3,Indie gaming podcast reviewing cozy simulation...,[]
4,Juvenile enrichment center offering poker tour...,[]
5,Avian wildlife rehabilitation facility focusin...,[]
6,Legal consultation service for intellectual pr...,[]
7,Interactive entertainment platform featuring s...,[]
8,Youth development program incorporating high-s...,[]
9,Specialized veterinary clinic treating feather...,[]


In [13]:
# Training data: the `suggestions` column is parsed as a JSON string by the
# validation helper below and by the prompt formatter later in the notebook.
dataset_v1 = pd.read_csv('../data/dataset_v1.csv')

def check_suggestions_have_5_strings_or_empty(df, expected_count=5):
    """Validate the ``suggestions`` column of ``df``.

    Every cell must be a JSON-encoded list that is either empty or contains
    exactly ``expected_count`` items, all of them strings.

    Args:
        df: DataFrame with a ``suggestions`` column of JSON strings.
        expected_count: Required length of a non-empty suggestion list.

    Returns:
        True when every row passes. False as soon as one row fails; a message
        naming the offending row is printed before returning.
    """
    for idx, row in df.iterrows():
        try:
            suggestions = json.loads(row['suggestions'])
        except (json.JSONDecodeError, TypeError):
            # TypeError covers cells that are not text at all (e.g. None/NaN
            # from an empty CSV field), which json.loads rejects before parsing.
            print(f"Row {idx}: invalid JSON in suggestions")
            return False

        if not isinstance(suggestions, list) or len(suggestions) not in (0, expected_count):
            print(f"Row {idx}: suggestions does not contain exactly {expected_count} items")
            return False

        for item in suggestions:
            if not isinstance(item, str):
                print(f"Row {idx}: suggestion item is not a string: {item}")
                return False

    print(f"All suggestions contain exactly {expected_count} strings or are empty.")
    return True

# Abort the notebook when validation fails so malformed rows never reach training;
# the helper has already printed which row is at fault.
if not check_suggestions_have_5_strings_or_empty(dataset_v1):
    raise ValueError("dataset_v1 failed the suggestions validation; see the message above")
print("Dataset V1 shape:", dataset_v1.shape)


All suggestions contain exactly 5 strings or are empty.
Dataset V1 shape: (200, 2)


In [14]:
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
print(f"Loading model: {model_name}")

# Load the base model in float16 and let `device_map="auto"` decide placement.
# `trust_remote_code` is left at its default (disabled): this checkpoint is loaded
# through the library's own model classes, so there is no reason to allow code
# shipped inside the model repository to execute locally.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Use EOS as the padding token

Loading model: mistralai/Mistral-7B-Instruct-v0.3


Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  1.71it/s]


In [15]:
def format_dataset(examples):
    """Build the training text for a batch of rows.

    Each row becomes one string in the ``<s>[INST] ... [/INST] ...</s>`` layout:
    the instruction embeds the business description, and the answer is the
    first ``DOMAIN_COUNT`` suggestions, one per line.

    Args:
        examples: Batched mapping with ``description`` and ``suggestions``
            sequences; each suggestions entry is expected to be a JSON list.

    Returns:
        ``{"text": [...]}`` with one formatted string per input row.
    """
    texts = []
    for description, suggestions_json_str in zip(examples['description'], examples['suggestions']):

        try:
            suggestions_list = json.loads(suggestions_json_str)
        except (json.JSONDecodeError, TypeError):
            # Best effort: an unparseable or missing (None) cell yields no suggestions.
            # TypeError is what json.loads raises for non-text input.
            suggestions_list = []

        # Valid JSON that is not a list (string, object, number) cannot be
        # sliced and joined meaningfully, so treat it like a parse failure.
        if not isinstance(suggestions_list, list):
            suggestions_list = []

        # Join the list of domains with a newline character - JSON array pattern is too hard to learn
        # For empty lists, this will correctly result in an empty string
        suggestions_formatted = "\n".join(suggestions_list[:DOMAIN_COUNT])

        text = f"""<s>[INST] Generate {DOMAIN_COUNT} creative domain name(s) (without TLD extensions like .com) for the following business description:

{description} [/INST] {suggestions_formatted}</s>"""

        texts.append(text)
    return {"text": texts}



# Wrap the pandas frame as a Dataset and hold out 10% for validation (fixed seed).
full_train_dataset = Dataset.from_pandas(dataset_v1)
split = full_train_dataset.train_test_split(test_size=0.1, seed=42)


def _as_prompt_text(raw_split):
    """Map ``raw_split`` through ``format_dataset``, dropping all of its original columns."""
    return raw_split.map(
        format_dataset,
        batched=True,
        remove_columns=raw_split.column_names,
    )


train_dataset = _as_prompt_text(split['train'])
validation_dataset = _as_prompt_text(split['test'])

# Show the first two formatted training examples as a sanity check.
print("Dataset formatted successfully!")
for label, position in (("Example 1:", 0), ("Example 2:", 1)):
    if position > 0:
        print("\n" + "="*50 + "\n")
    print(label)
    print(train_dataset[position]['text'])


Map: 100%|██████████| 180/180 [00:00<00:00, 115298.52 examples/s]
Map: 100%|██████████| 20/20 [00:00<00:00, 20641.26 examples/s]

Dataset formatted successfully!
Example 1:
<s>[INST] Generate 5 creative domain name(s) (without TLD extensions like .com) for the following business description:

Bicycle-powered smoothie cart for events. [/INST] pedalpour
crankcup
blendbike
wheelwhisk
spinSip</s>


Example 2:
<s>[INST] Generate 5 creative domain name(s) (without TLD extensions like .com) for the following business description:

Neighborhood bread share and bake sale. [/INST] loaflist
crumbcart
bakebarter
ovenout
sliceSwap</s>





In [16]:
# LoRA adapter settings.
# Adapters are attached to the attention projections and the MLP projections.
_attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
_mlp_projections = ["gate_proj", "up_proj", "down_proj"]

lora_config = LoraConfig(
    # Rank of the update matrices and its scaling factor; both still need a
    # double check and/or hyperparameter optimization.
    r=128,
    lora_alpha=256,
    target_modules=_attention_projections + _mlp_projections,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

print("LoRA configuration created")


LoRA configuration created


In [17]:
# Early-stopping callback with a patience of 2 evaluations; handed to the trainer below.
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=2)

In [18]:
# Mixed-precision choice: bf16 when Apple's MPS backend is available, fp16 otherwise.
_on_mps = torch.backends.mps.is_available()

sft_config = SFTConfig(
    output_dir="../models/model_v1",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,  # 1 x 4 = effective batch of 4 per device
    learning_rate=5e-6,
    num_train_epochs=10,
    max_steps=-1,
    logging_steps=10,
    save_steps=100,  # NOTE(review): likely unused while save_strategy="epoch" - confirm
    save_total_limit=3,
    warmup_steps=50,
    lr_scheduler_type="cosine",
    # The string "none" is what disables logging integrations; passing None
    # falls back to the library default, which can mean "all installed".
    report_to="none",
    remove_unused_columns=False,
    max_length=1024,

    # Fix for MPS device - disable fp16 there and use bf16 instead
    fp16=not _on_mps,
    bf16=_on_mps,

    # Evaluate and checkpoint once per epoch so the best checkpoint can be restored.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

In [19]:
# Create the SFTTrainer.
# The tokenizer prepared earlier (pad token set to EOS) is passed explicitly so
# the trainer tokenizes with the same object that is later saved next to the
# model and used for generation, instead of loading a separate copy.
trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    processing_class=tokenizer,
    peft_config=lora_config,
    callbacks=[early_stopping_callback],
)

print("SFTTrainer initialized successfully!")


Adding EOS to train dataset: 100%|██████████| 180/180 [00:00<00:00, 75129.34 examples/s]
Tokenizing train dataset: 100%|██████████| 180/180 [00:00<00:00, 13614.68 examples/s]
Truncating train dataset: 100%|██████████| 180/180 [00:00<00:00, 115971.54 examples/s]
Adding EOS to eval dataset: 100%|██████████| 20/20 [00:00<00:00, 18653.79 examples/s]
Tokenizing eval dataset: 100%|██████████| 20/20 [00:00<00:00, 6324.82 examples/s]
Truncating eval dataset: 100%|██████████| 20/20 [00:00<00:00, 14523.21 examples/s]

SFTTrainer initialized successfully!





In [20]:
# Launch fine-tuning. The trainer was built with an EarlyStoppingCallback
# (patience=2) and per-epoch evaluation, so the run may stop before num_train_epochs.
trainer.train()



Epoch,Training Loss,Validation Loss,Entropy,Num Tokens,Mean Token Accuracy
1,2.1708,1.986558,1.384448,10848.0,0.698622
2,1.343,1.482552,1.431394,21696.0,0.705437
3,0.8195,1.417125,1.068874,32544.0,0.721006
4,0.6331,1.589967,0.987004,43392.0,0.708432
5,0.5245,1.770476,0.849815,54240.0,0.713863




TrainOutput(global_step=225, training_loss=1.3633195527394613, metrics={'train_runtime': 439.5229, 'train_samples_per_second': 4.095, 'train_steps_per_second': 1.024, 'total_flos': 2424316510863360.0, 'train_loss': 1.3633195527394613, 'epoch': 5.0})

In [21]:
# Persist the trained model and then the tokenizer into the same directory,
# printing a confirmation after each step.
save_path = "../models/model_v1"
for _persist, _confirmation in (
    (trainer.save_model, f"Model saved to: {save_path}"),
    (tokenizer.save_pretrained, "Tokenizer saved as well"),
):
    _persist(save_path)
    print(_confirmation)


Model saved to: ../models/model_v1
Tokenizer saved as well


In [22]:
# Quick smoke test of the fine-tuned model on a single description.
description = "Artisanal coffee roastery with single-origin beans"
# Alternative probes - uncomment one to try a different input:
#description = "Gluten free bakery online delivery"
#description = "Hit man for hire"

# The wording must stay identical to the instruction used in format_dataset.
prompt = f"""Generate {DOMAIN_COUNT} creative domain name(s) (without TLD extensions like .com) for the following business description:

{description}"""

# NOTE(review): the literal <s> mirrors the training text; the tokenizer may also
# prepend its own BOS token - confirm whether that duplication is intended.
test_input = f"<s>[INST] {prompt} [/INST]"

inputs = tokenizer(test_input, return_tensors="pt").to(model.device)
input_token_length = inputs.input_ids.shape[1]

# Put the model in evaluation mode so dropout (including LoRA dropout) is
# disabled; generate() does not change the module's train/eval mode itself.
model.eval()

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )


# Slice off the prompt tokens so only the newly generated continuation remains
generated_token_ids = outputs[0, input_token_length:]

# Decode only the new tokens to get the clean response
response = tokenizer.decode(generated_token_ids, skip_special_tokens=True)

print(f"Generated response:\n{response}")

Generated response:
beanbloom
roastriver
groundgrove
brewbarn
cupcraft
