In [1]:
# Reload imported modules automatically before each cell executes, so edits to
# local modules (such as the `train` module imported further down) are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [31]:
import pickle

def create_dummy_dataset(output_path="chat_traces.pkl", repeats=25):
    """
    Creates a simple dummy dataset for fine-tuning and pickles it to disk.

    The dataset is a list of dictionaries, where each dictionary
    represents a chat turn with a system prompt, a user query, and a response
    (keys "system", "user" and "response"). Four hand-written base samples
    are repeated ``repeats`` times, keeping their original order.

    Args:
        output_path: Where the pickle file is written. Defaults to
            "chat_traces.pkl" in the current working directory.
        repeats: Number of times the four base samples are repeated. The
            default of 25 yields 100 samples; 0 (or a negative value)
            yields an empty dataset.

    Returns:
        None. The dataset is written to ``output_path`` and a one-line
        confirmation is printed. Nothing is returned on purpose, so calling
        this as the last expression of a notebook cell shows no extra output.
    """
    base_samples = [
        {
            "system": "You are a helpful assistant that provides concise answers.",
            "user": "What is the capital of France?",
            "response": "The capital of France is Paris."
        },
        {
            "system": "You are a physics teacher who explains complex topics in simple terms.",
            "user": "Explain the theory of relativity in simple terms.",
            "response": "Einstein's theory of relativity is about how space and time are linked. It has two parts: special relativity, which deals with speed in a straight line, and general relativity, which explains gravity as a curving of space and time by mass and energy."
        },
        {
            "system": "You are a poet who responds with creative and short poems.",
            "user": "Write a short poem about the ocean.",
            "response": "Vast and blue, a world unseen,\nWhere ancient secrets lie between.\nThe waves that crash upon the shore,\nWhisper tales forevermore."
        },
        {
            "system": "You are a helpful assistant that provides concise answers.",
            "user": "What is 2+2?",
            "response": "2+2 equals 4."
        }
    ]

    # Repeat the base samples so there is enough data for a realistic
    # train/validation split (4 * repeats samples in total).
    extended_data = [sample for _ in range(repeats) for sample in base_samples]

    # Save to a pickle file
    with open(output_path, "wb") as f:
        pickle.dump(extended_data, f)

    # Report the real path and size instead of hard-coded values, so the
    # message stays truthful when non-default arguments are used.
    print(f"Dummy dataset '{output_path}' created successfully with {len(extended_data)} samples in the new format.")

# Inside a notebook kernel __name__ is "__main__", so this guard fires as soon
# as the cell is executed (the confirmation line printed right after this cell
# comes from this call).
if __name__ == "__main__":
    create_dummy_dataset()



Dummy dataset 'chat_traces.pkl' created successfully with 100 samples in the new format.


In [32]:
# Rebuild chat_traces.pkl. The __main__ guard in the defining cell already
# performs this call when that cell runs, so this is only needed to regenerate
# the file without re-running the definition.
create_dummy_dataset()

Dummy dataset 'chat_traces.pkl' created successfully with 100 samples in the new format.


# Test training script configs

In [33]:
from train import *

In [35]:
# 1. Load the run configuration
# NOTE(review): load_config, os, Dataset, create_prompt, AutoTokenizer and
# SFTDataCollator are not imported explicitly in this notebook; presumably they
# all arrive through `from train import *` — confirm against the train module.
config = load_config("./configs/qwen_conf.yml")

# Set up W&B environment variables
os.environ["WANDB_PROJECT"] = config['wandb_project']

# 2. Load and Prepare the Dataset
print("Loading and preparing dataset...")
# Reads the pickle written by create_dummy_dataset(). pickle.load can execute
# arbitrary code, so only point this at files produced locally.
with open('chat_traces.pkl', 'rb') as f:
    data = pickle.load(f)

dataset = Dataset.from_list(data)
# Pass model name to the prompt creation function to handle different prompt formats
# NOTE(review): this falls back to '' when 'model_name' is missing, whereas the
# tokenizer load below indexes config['model_name'] directly and would raise
# KeyError in that case — the two lookups are inconsistent.
dataset = dataset.map(lambda sample: create_prompt(sample, model_name=config.get('model_name', '')))
print(f"Dataset loaded with {len(dataset)} examples.")
print("Sample prompt:\n", dataset[0]['text'])

# 3. Load Tokenizer and Model
# Only the tokenizer is loaded in this cell; no model is instantiated here.
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(config['model_name'])
# Falls back to 1024 when the config does not define max_seq_length.
tokenizer.model_max_length = config.get('max_seq_length', 1024)

# Some tokenizers (the original note names Llama 3) have no pad token, so add one.
# NOTE(review): adding a new special token can grow the vocabulary; a model
# loaded later presumably needs its embeddings resized to len(tokenizer) —
# confirm where the model is created.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Determine the response template based on model name to correctly mask labels
# NOTE(review): response_template is assigned here but is not passed to
# SFTDataCollator below nor used anywhere else in this cell — confirm whether
# the collator is supposed to receive it or whether this branch is dead code.
model_name_lower = config['model_name'].lower()
if "qwen" in model_name_lower:
    response_template = "<|im_start|>assistant\n"
else: # Default to Llama 3
    response_template = "<|start_header_id|>assistant<|end_header_id|>\n\n"

# The collator is built from the tokenizer alone.
data_collator = SFTDataCollator(tokenizer=tokenizer)

Loading and preparing dataset...


Map: 100%|██████████| 100/100 [00:00<00:00, 20672.80 examples/s]

Dataset loaded with 100 examples.
Sample prompt:
 <|im_start|>system
You are a helpful assistant that provides concise answers.<|im_end|>
<|im_start|>user
What is the capital of France?<|im_end|>
<|im_start|>assistant
The capital of France is Paris.<|im_end|>
Loading model and tokenizer...





In [36]:
# Peek at the first five rows; slicing the Dataset yields a dict mapping each
# column name to a list of values.
dataset[:5]

{'system': ['You are a helpful assistant that provides concise answers.',
  'You are a physics teacher who explains complex topics in simple terms.',
  'You are a poet who responds with creative and short poems.',
  'You are a helpful assistant that provides concise answers.',
  'You are a helpful assistant that provides concise answers.'],
 'user': ['What is the capital of France?',
  'Explain the theory of relativity in simple terms.',
  'Write a short poem about the ocean.',
  'What is 2+2?',
  'What is the capital of France?'],
 'response': ['The capital of France is Paris.',
  "Einstein's theory of relativity is about how space and time are linked. It has two parts: special relativity, which deals with speed in a straight line, and general relativity, which explains gravity as a curving of space and time by mass and energy.",
  'Vast and blue, a world unseen,\nWhere ancient secrets lie between.\nThe waves that crash upon the shore,\nWhisper tales forevermore.',
  '2+2 equals 4.',


In [40]:
# Look at one mapped example: the three original fields plus the 'prompt' and
# 'text' columns that are present after the create_prompt map step.
dataset[0]

{'system': 'You are a helpful assistant that provides concise answers.',
 'user': 'What is the capital of France?',
 'response': 'The capital of France is Paris.',
 'prompt': '<|im_start|>system\nYou are a helpful assistant that provides concise answers.<|im_end|>\n<|im_start|>user\nWhat is the capital of France?<|im_end|>\n',
 'text': '<|im_start|>system\nYou are a helpful assistant that provides concise answers.<|im_end|>\n<|im_start|>user\nWhat is the capital of France?<|im_end|>\n<|im_start|>assistant\nThe capital of France is Paris.<|im_end|>'}

In [38]:
# Collate a one-example batch to inspect what the collator produces for a
# single sample.
a = data_collator([dataset[0]])

In [41]:
# Show the collated batch. In the recorded output the first 27 label positions
# are -100 and the remaining 11 labels equal input_ids, starting at the third
# <|im_start|> id (151644), i.e. the opening of the assistant turn.
a

{'input_ids': tensor([[151644,   8948,    198,   2610,    525,    264,  10950,  17847,    429,
           5707,  63594,  11253,     13, 151645,    198, 151644,    872,    198,
           3838,    374,    279,   6722,    315,   9625,     30, 151645,    198,
         151644,  77091,    198,    785,   6722,    315,   9625,    374,  12095,
             13, 151645]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[  -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
         151644,  77091,    198,    785,   6722,    315,   9625,    374,  12095,
             13, 151645]])}

In [39]:
# Token ids only, for comparison with the decoded text in the following cells.
a['input_ids']

tensor([[151644,   8948,    198,   2610,    525,    264,  10950,  17847,    429,
           5707,  63594,  11253,     13, 151645,    198, 151644,    872,    198,
           3838,    374,    279,   6722,    315,   9625,     30, 151645,    198,
         151644,  77091,    198,    785,   6722,    315,   9625,    374,  12095,
             13, 151645]])

In [42]:
# Decode a single id to confirm which special token 151644 stands for.
tokenizer.decode(151644)

'<|im_start|>'

In [30]:
# Decode the whole sequence to check that it round-trips to the chat-formatted text.
# NOTE(review): the recorded output comes from execution count 30, i.e. before
# the dataset was regenerated, and shows a shorter system prompt than
# dataset[0]['text'] — re-run this cell after the collator cell so the output
# matches the current data.
tokenizer.decode(a['input_ids'][0])

'<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat is the capital of France?<|im_end|>\n<|im_start|>assistant\nThe capital of France is Paris.<|im_end|>'