In [7]:
import tensorflow as tf
import tensorflow.keras.losses
from transformers import AutoModelForCausalLM, AutoTokenizer, TFAutoModelForCausalLM
import datasets

# Load DialoGPT and tokenizer
model_name = "microsoft/DialoGPT-medium"
model = TFAutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

# Load and tokenize your custom dataset
dataset = datasets.load_dataset("garage-bAInd/Open-Platypus")

# Define a tokenization function for training data
def tokenize_training_data(examples):
    input_text = examples["input"]
    output_text = examples["output"]

    # Tokenize the input and output text separately
    input_ids = tokenizer(input_text, padding="max_length", truncation=True, max_length=926, return_tensors="pt")["input_ids"]
    output_ids = tokenizer(output_text, padding="max_length", truncation=True, max_length=1000, return_tensors="pt")["input_ids"]

    # Determine the actual length of the input sequence
    actual_length = input_ids.shape[1]

    # Ensure that input_ids are within the model's vocabulary range
    max_allowed_token_id = tokenizer.vocab_size - 1
    input_ids = [id if id <= max_allowed_token_id else tokenizer.pad_token_id for id in input_ids[0]]

    return {
        "input_ids": input_ids[:actual_length],  # Truncate to actual length
        "output_ids": output_ids[0].tolist(),
    }



# Tokenize your training dataset
tokenized_training_data = dataset["train"].map(tokenize_training_data, batched=True)

# Define training arguments
training_args = tf.data.Dataset.from_tensor_slices(
    {
        "input_ids": tokenized_training_data["input_ids"],
        "labels": tokenized_training_data["labels"],
    }
)
training_args = training_args.batch(4)

# Create a Trainer instance for fine-tuning
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(input_ids, labels):
    with tf.GradientTape() as tape:
        logits = model(input_ids)["logits"]
        # Use a loss function that doesn't require specifying from_logits
        loss_value = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(labels, logits)

    grads = tape.gradient(loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss_value



# Training loop
num_epochs = 3
for epoch in range(num_epochs):
    for batch in training_args:
        input_ids = batch["input_ids"]
        labels = batch["labels"]
        loss = train_step(input_ids, labels)
        print(f"Epoch {epoch + 1}, Loss: {loss.numpy()}")

# Save the fine-tuned model
model.save_pretrained("./fine-tuned-dialoGPT")

# Define a function for running the chatbot
def run_chatbot():
    # Load the fine-tuned model for inference
    chatbot = TFAutoModelForCausalLM.from_pretrained("./fine-tuned-dialoGPT")
    chatbot_tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Initialize the conversation history
    conversation_history = ["Hello, how can I help you today?"]

    while True:
        user_input = input("User: ")
        conversation_history.append(user_input)
        input_ids = chatbot_tokenizer.encode(conversation_history, return_tensors="tf")
        response_ids = chatbot.generate(input_ids, max_length=256, pad_token_id=tokenizer.pad_token_id)
        response_text = chatbot_tokenizer.decode(response_ids[0], skip_special_tokens=True)
        print(f"Chatbot: {response_text}")
        conversation_history.append(response_text)

# Call the chatbot function to start chatting
run_chatbot()


All model checkpoint layers were used when initializing TFGPT2LMHeadModel.

All the layers of TFGPT2LMHeadModel were initialized from the model checkpoint at microsoft/DialoGPT-medium.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFGPT2LMHeadModel for predictions without further training.


Map:   0%|          | 0/24926 [00:00<?, ? examples/s]

ArrowInvalid: Column 3 named input_ids expected length 1000 but got length 926