In [16]:
import pandas as pd
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, EarlyStoppingCallback
from trl import SFTTrainer
from peft import LoraConfig
from datasets import Dataset
from trl import SFTConfig
from datasets import Dataset


# Number of domain names requested in every prompt and kept per training answer.
DOMAIN_COUNT: int = 5

In [17]:
# Load the held-out test descriptions.
# The CSV presumably was written together with its index, which pandas reads back as an
# "Unnamed: 0" column; drop any such spurious columns so only real data columns remain.
test_set = pd.read_csv('../data/test_set.csv')
test_set = test_set.loc[:, ~test_set.columns.str.startswith('Unnamed:')]
test_set


Test set loaded:


Unnamed: 0,description,suggestions
0,Artisan bakery specializing in sourdough and s...,[]
1,Eco-friendly cleaning product line for househo...,[]
2,Mindfulness and meditation app for busy profes...,[]
3,Local hiking trail guide with offline maps and...,[]
4,Urban rooftop gardening supplies store with st...,[]
5,Boutique fitness studio offering HIIT and yoga...,[]
6,Children's educational science kits subscripti...,[]
7,Home organization consultancy for small apartm...,[]
8,Indie music discovery newsletter curating emer...,[]
9,Farmers market locator with vendor reviews and...,[]


In [18]:
# Fine-tuning data: one business description per row plus its JSON-encoded suggestions.
DATASET_V1_CSV = '../data/dataset_v1.csv'
dataset_v1 = pd.read_csv(DATASET_V1_CSV)

def check_suggestions_have_5_strings_or_empty(df, expected_count=5):
    """Validate that every ``suggestions`` cell is a JSON array of strings of the right size.

    A row passes when its ``suggestions`` text decodes to a list that is either empty
    or holds exactly ``expected_count`` strings. Checking stops at the first failing
    row, and the reason is printed.

    Args:
        df: DataFrame with a ``suggestions`` column holding JSON text.
        expected_count: Required length of a non-empty suggestions list. The default
            of 5 matches ``DOMAIN_COUNT`` in this notebook.

    Returns:
        True if every row passes, otherwise False.
    """
    for idx, raw in df['suggestions'].items():
        try:
            suggestions = json.loads(raw)
        except json.JSONDecodeError:
            print(f"Row {idx}: invalid JSON in suggestions")
            return False
        except TypeError:
            # json.loads only accepts str/bytes; read_csv yields a float NaN for empty cells.
            print(f"Row {idx}: suggestions is missing or not text: {raw!r}")
            return False

        if not isinstance(suggestions, list):
            print(f"Row {idx}: suggestions is not a JSON array")
            return False
        if len(suggestions) not in (0, expected_count):
            print(f"Row {idx}: suggestions does not contain exactly {expected_count} items")
            return False
        for item in suggestions:
            if not isinstance(item, str):
                print(f"Row {idx}: suggestion item is not a string: {item}")
                return False

    print(f"All suggestions contain exactly {expected_count} strings or are empty.")
    return True

# Stop here on bad data instead of only printing a message. format_dataset turns
# unparseable suggestions into an empty answer, so malformed rows would otherwise
# be trained on silently.
assert check_suggestions_have_5_strings_or_empty(dataset_v1), "dataset_v1 failed suggestions validation (see message above)"
print("Dataset V1 shape:", dataset_v1.shape)


All suggestions contain exactly 5 strings or are empty.
Dataset V1 shape: (200, 2)


In [19]:
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
print(f"Loading model: {model_name}")

# Mistral is implemented natively in transformers, so no Hub-hosted modelling code is
# required. Leaving trust_remote_code off avoids executing arbitrary Python from the repo.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse EOS as the padding token (presumably the tokenizer defines no pad token of its own).
tokenizer.pad_token = tokenizer.eos_token

print("Model and tokenizer loaded successfully!")


Loading model: mistralai/Mistral-7B-Instruct-v0.3


Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.25it/s]


Model and tokenizer loaded successfully!


In [20]:
def _parse_suggestions(raw):
    """Decode one suggestions cell (JSON array text) into a list of strings; [] if unusable."""
    try:
        parsed = json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        # JSONDecodeError: malformed text. TypeError: non-text cell such as a NaN from read_csv.
        return []
    if not isinstance(parsed, list):
        # A bare JSON string or object would otherwise be sliced/joined character- or key-wise.
        return []
    return [item for item in parsed if isinstance(item, str)]


def format_dataset(examples, domain_count=None):
    """Render a batch of rows as Mistral-instruct training texts.

    Written for ``Dataset.map(..., batched=True)``. ``examples`` maps column names to
    equal-length lists, and the result maps ``"text"`` to one string per row.

    Args:
        examples: Batch with ``"description"`` and ``"suggestions"`` columns, where
            suggestions is a JSON array of domain-name strings stored as text.
        domain_count: Number of names requested in the prompt and kept in the answer.
            Defaults to the notebook-level ``DOMAIN_COUNT``.

    Returns:
        ``{"text": [...]}`` with one ``<s>[INST] ... [/INST] answer</s>`` string per row.
        Rows whose suggestions are missing, malformed, or not a JSON array get an empty
        answer.
    """
    if domain_count is None:
        domain_count = DOMAIN_COUNT

    texts = []
    for description, suggestions_json_str in zip(examples['description'], examples['suggestions']):
        suggestions_list = _parse_suggestions(suggestions_json_str)

        # One name per line, because a JSON array in the answer proved too hard to learn.
        # An empty list correctly yields an empty answer.
        suggestions_formatted = "\n".join(suggestions_list[:domain_count])

        text = f"""<s>[INST] Generate {domain_count} creative domain name(s) (without TLD extensions like .com) for the following business description:

{description} [/INST] {suggestions_formatted}</s>"""

        texts.append(text)
    return {"text": texts}



# 90/10 train/validation split of dataset_v1. The fixed seed keeps the split stable across runs.
full_train_dataset = Dataset.from_pandas(dataset_v1)
split = full_train_dataset.train_test_split(test_size=0.1, seed=42)

# Format both splits in one pass, replacing the raw columns with the single "text" column.
formatted_splits = split.map(
    format_dataset,
    batched=True,
    remove_columns=split["train"].column_names,
)
train_dataset, validation_dataset = formatted_splits["train"], formatted_splits["test"]

print("Dataset formatted successfully!")
for example_number in (1, 2):
    if example_number > 1:
        print("\n" + "=" * 50 + "\n")
    print(f"Example {example_number}:")
    print(train_dataset[example_number - 1]["text"])


Map: 100%|██████████| 180/180 [00:00<00:00, 116203.59 examples/s]
Map: 100%|██████████| 20/20 [00:00<00:00, 17623.13 examples/s]

Dataset formatted successfully!
Sample formatted example:
<s>[INST] Generate 5 creative domain name(s) (without TLD extensions like .com) for the following business description:

Bicycle-powered smoothie cart for events. [/INST] pedalpour
crankcup
blendbike
wheelwhisk
spinSip</s>


Another example:
<s>[INST] Generate 5 creative domain name(s) (without TLD extensions like .com) for the following business description:

Neighborhood bread share and bake sale. [/INST] loaflist
crumbcart
bakebarter
ovenout
sliceSwap</s>





In [21]:
# LoRA adapters are attached to these projection modules (attention q/k/v/o and MLP
# gate/up/down, using the module names of the Mistral architecture in transformers).
LORA_TARGET_MODULES = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]

# The adapter update is scaled by lora_alpha / r = 256 / 128 = 2.
# NOTE(review): rank 128 is large for ~180 training rows. The validation loss in the
# training log rises after epoch 3, so r/alpha are worth tuning.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=128,
    lora_alpha=256,
    lora_dropout=0.05,
    bias="none",
    target_modules=LORA_TARGET_MODULES,
)

print("LoRA configuration created")


LoRA configuration created


In [22]:
# Halt training once the tracked evaluation metric has failed to improve for two
# consecutive evaluations (evaluation runs once per epoch, per eval_strategy in the SFTConfig cell).
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=2)

In [23]:
# Mixed precision: bf16 on Apple Silicon (MPS), fp16 only where CUDA is available.
# transformers rejects fp16=True on a CPU-only machine, so that case falls back to fp32.
# NOTE(review): other accelerators (e.g. XPU/NPU) also fall back to fp32 here; adjust if needed.
use_mps = torch.backends.mps.is_available()
use_cuda_fp16 = torch.cuda.is_available() and not use_mps

sft_config = SFTConfig(
    output_dir="../models/model_v1",
    # Effective batch size: 1 sequence x 4 accumulation steps = 4 per optimizer step.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    # NOTE(review): 5e-6 is low for LoRA fine-tuning; worth including in a sweep.
    learning_rate=5e-6,
    num_train_epochs=10,  # upper bound; EarlyStoppingCallback may end training sooner
    max_steps=-1,
    logging_steps=10,
    save_total_limit=3,
    warmup_steps=50,
    lr_scheduler_type="cosine",
    # "none" disables logging integrations. report_to=None means "all installed" in transformers 4.x.
    report_to="none",
    remove_unused_columns=False,
    max_length=1024,

    fp16=use_cuda_fp16,
    bf16=use_mps,

    # Evaluate and checkpoint every epoch. load_best_model_at_end requires matching strategies.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    # Explicit form of the default used with load_best_model_at_end: lowest validation loss wins.
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

In [24]:
# Create the SFTTrainer. With peft_config set, TRL wraps `model` with the LoRA adapters
# described by lora_config before training.
trainer = SFTTrainer(
    model=model,
    # Use the notebook's tokenizer (pad_token set when it was loaded) rather than letting
    # TRL load a second copy, so training and the saved tokenizer are the same object.
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    peft_config=lora_config,
    args=sft_config,
    callbacks=[early_stopping_callback],
)

print("SFTTrainer initialized successfully!")


Adding EOS to train dataset: 100%|██████████| 180/180 [00:00<00:00, 70789.94 examples/s]
Tokenizing train dataset: 100%|██████████| 180/180 [00:00<00:00, 15094.36 examples/s]
Truncating train dataset: 100%|██████████| 180/180 [00:00<00:00, 161561.04 examples/s]
Adding EOS to eval dataset: 100%|██████████| 20/20 [00:00<00:00, 21856.72 examples/s]
Tokenizing eval dataset: 100%|██████████| 20/20 [00:00<00:00, 7590.13 examples/s]
Truncating eval dataset: 100%|██████████| 20/20 [00:00<00:00, 26596.73 examples/s]

SFTTrainer initialized successfully!





In [25]:
# Run fine-tuning. The returned TrainOutput (global step, training loss, metrics) is displayed as the cell result.
train_output = trainer.train()
train_output

Starting training...




Epoch,Training Loss,Validation Loss,Entropy,Num Tokens,Mean Token Accuracy
1,2.1642,1.980727,1.418927,10848.0,0.698598
2,1.3136,1.449548,1.44789,21696.0,0.702705
3,0.8115,1.42309,1.071236,32544.0,0.719623
4,0.6184,1.61111,0.968871,43392.0,0.706959
5,0.5303,1.762759,0.853456,54240.0,0.718829
6,0.4263,1.90657,0.785758,65088.0,0.717357




Training completed!


In [28]:
save_path = "../models/model_v1"

# Write the trained model (presumably just the LoRA adapter, since the trainer was built
# with a peft_config) and then the tokenizer into the same directory, reporting each step.
for save_fn, message in (
    (trainer.save_model, f"Model saved to: {save_path}"),
    (tokenizer.save_pretrained, "Tokenizer saved as well"),
):
    save_fn(save_path)
    print(message)


Model saved to: ../models/model_v1
Tokenizer saved as well


In [39]:
# Generate domain suggestions with the fine-tuned model for a single description.
description = "Artisanal coffee roastery with single-origin beans"
# Other descriptions worth trying instead:
#   "Gluten free bakery online delivery"
#   "Hit man for hire"  (presumably probes how the model handles an inappropriate request)

# This instruction must match the one in format_dataset exactly; otherwise the model
# sees a prompt it was not fine-tuned on.
prompt = f"""Generate {DOMAIN_COUNT} creative domain name(s) (without TLD extensions like .com) for the following business description:

{description}"""

# NOTE(review): the literal "<s>" plus the BOS the tokenizer adds by default likely yields
# two BOS tokens. Training used the same "<s>"-prefixed template, so change both together.
test_input = f"<s>[INST] {prompt} [/INST]"

inputs = tokenizer(test_input, return_tensors="pt").to(model.device)
input_token_length = inputs.input_ids.shape[1]

# Make sure dropout (including LoRA dropout) is off after training, and seed sampling
# so repeated runs on the same machine give the same suggestions.
model.eval()
torch.manual_seed(42)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )

# Keep only the newly generated tokens (drop the echoed prompt).
generated_token_ids = outputs[0, input_token_length:]

# Decode only the new tokens to get the clean response.
response = tokenizer.decode(generated_token_ids, skip_special_tokens=True)

print(f"Generated response:\n{response}")

Generated response:
roastriver
beanbloom
groundgrove
cupcraft
brewbatch
