In [None]:
! pip install keras-nlp datasets wandb

In [2]:
import datasets
import keras_nlp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import wandb
from datasets import load_dataset
from wandb.integration.keras import WandbMetricsLogger

In [3]:
# Show the GPUs TensorFlow can see; an empty list means none were detected.
tf.config.list_physical_devices('GPU')

[]

In [None]:
# Authenticate with Weights & Biases before a run is started further down.
wandb.login()

### Load and preprocess data

In [6]:
# Pull the Alpaca training split and keep only the fully formatted `text`
# column as a one-column DataFrame.
dataset = datasets.load_dataset("tatsu-lab/alpaca", split="train")
df = pd.DataFrame(dataset).loc[:, ["text"]]
df.head()

Unnamed: 0,text
0,Below is an instruction that describes a task....
1,Below is an instruction that describes a task....
2,Below is an instruction that describes a task....
3,Below is an instruction that describes a task....
4,Below is an instruction that describes a task....


In [9]:
# Index the row first so only one record is read, instead of pulling the
# whole `text` column just to display its first entry.
print(dataset[0]['text'])

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Give three tips for staying healthy.

### Response:
1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. 
2. Exercise regularly to keep your body active and strong. 
3. Get enough sleep and maintain a consistent sleep schedule.


In [10]:
# Positional 90/10 split: the first `n` rows train, the remainder validate.
n = int(len(df) * 0.9)
train_examples, val_examples = df.iloc[:n], df.iloc[n:]

In [11]:
train_examples.head()

Unnamed: 0,text
0,Below is an instruction that describes a task....
1,Below is an instruction that describes a task....
2,Below is an instruction that describes a task....
3,Below is an instruction that describes a task....
4,Below is an instruction that describes a task....


In [12]:
val_examples.head()

Unnamed: 0,text
46801,"Below is an instruction that describes a task,..."
46802,"Below is an instruction that describes a task,..."
46803,"Below is an instruction that describes a task,..."
46804,"Below is an instruction that describes a task,..."
46805,"Below is an instruction that describes a task,..."


In [13]:
# Slice the values of the `text` column rather than the one-column DataFrame:
# the DataFrame converts to an (N, 1) tensor, so each element would be a
# shape-(1,) vector and each batch (batch, 1) instead of a flat batch of strings.
train_examples = tf.data.Dataset.from_tensor_slices(train_examples["text"].to_numpy())

val_examples = tf.data.Dataset.from_tensor_slices(val_examples["text"].to_numpy())

In [14]:
# Shuffle-buffer size passed to Dataset.shuffle() in make_batches.
BUFFER_SIZE = 20000
# Number of examples per batch passed to Dataset.batch() in make_batches.
BATCH_SIZE = 32

In [15]:
def make_batches(ds):
    """Shuffle `ds`, group it into batches, and prefetch the result.

    The shuffle buffer size and batch size come from the notebook-level
    BUFFER_SIZE and BATCH_SIZE constants; the prefetch depth is left to
    tf.data.AUTOTUNE.
    """
    shuffled = ds.shuffle(BUFFER_SIZE)
    batched = shuffled.batch(BATCH_SIZE)
    return batched.prefetch(tf.data.AUTOTUNE)

In [16]:
# Run both splits through the same shuffle/batch/prefetch pipeline.
train_batches, val_batches = (
    make_batches(train_examples),
    make_batches(val_examples),
)

### Train model

In [17]:
# Number of training epochs; also scales the learning-rate decay horizon below.
num_epochs = 5

In [18]:
# Decay the learning rate from 5e-5 down to 0 over the whole run:
# (batches per epoch) x (number of epochs) optimizer steps.
learning_rate = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=5e-5,
    end_learning_rate=0.0,
    decay_steps=train_batches.cardinality() * num_epochs,
)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

In [None]:
# Start the W&B run and record the hyperparameters. `learning_rate` is a
# PolynomialDecay schedule object, not a number, so log its plain numeric
# settings (plus the schedule's class name) rather than the object itself.
wandb.init(project="gpt2-instruct-tune",
           config={
               "learning_rate": learning_rate.initial_learning_rate,
               "end_learning_rate": learning_rate.end_learning_rate,
               "lr_schedule": type(learning_rate).__name__,
               "architecture": "gpt2",
               "dataset": "tatsu-lab/alpaca",
               "epochs": num_epochs,
               "batch_size": BATCH_SIZE,
               }
           )

In [20]:
# Loss used by compile() below; from_logits=True means it is fed raw logits.
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [21]:
# `loss` is a Keras Loss object, not a metric value, so it does not belong in
# wandb.log(); record which loss function is used in the run config instead.
wandb.config.update({"Loss": loss.name})

In [23]:
# Preprocessor for the "gpt2_base_en" preset, configured with sequence_length=300
# (presumably the token length inputs are padded/truncated to — TODO confirm
# against the KerasNLP docs).
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_base_en",
    sequence_length=300
)

# GPT-2 causal LM built from the same preset with the preprocessor attached;
# the preset files are downloaded on first use (see the output below).
generator = keras_nlp.models.GPT2CausalLM.from_preset(
    "gpt2_base_en",
    preprocessor=preprocessor,
)

Downloading from https://www.kaggle.com/api/v1/models/keras/gpt2/keras/gpt2_base_en/2/download/preprocessor.json...
Downloading from https://www.kaggle.com/api/v1/models/keras/gpt2/keras/gpt2_base_en/2/download/task.json...
Downloading from https://www.kaggle.com/api/v1/models/keras/gpt2/keras/gpt2_base_en/2/download/config.json...
100%|██████████| 484/484 [00:00<00:00, 617kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/gpt2/keras/gpt2_base_en/2/download/model.weights.h5...
100%|██████████| 475M/475M [00:09<00:00, 51.6MB/s]


In [29]:
# Attach the optimizer and loss defined in earlier cells and track accuracy
# as a weighted metric.
generator.compile(
    optimizer=optimizer,
    loss=loss,
    weighted_metrics=["accuracy"],
    )

In [None]:
# Fine-tune the model. The WandbMetricsLogger callback sends the training and
# validation metrics to the active W&B run; without it the run started above
# would be closed without any training curves.
history = generator.fit(
    train_batches,
    validation_data=val_batches,
    epochs=num_epochs,
    callbacks=[WandbMetricsLogger()],
)

Epoch 1/5


In [None]:
# Close the W&B run opened by wandb.init() earlier in the notebook.
wandb.finish()

In [None]:
# Tabulate the metrics Keras recorded during fit(): one column per metric.
metrics_df = pd.DataFrame.from_dict(history.history)
metrics_df.head()

In [None]:
# Plot training vs. validation curves side by side on explicit axes, with
# titles and axis labels so each panel can be read on its own.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

metrics_df[["loss", "val_loss"]].plot(ax=loss_ax)
loss_ax.set(title="Loss per epoch", xlabel="epoch", ylabel="loss")

metrics_df[["accuracy", "val_accuracy"]].plot(ax=acc_ax)
acc_ax.set(title="Accuracy per epoch", xlabel="epoch", ylabel="accuracy")

# The trailing semicolon suppresses the repr of the last expression.
fig.tight_layout();

In [None]:
# Free-form completion of a short prefix (no instruction template), capped at max_length=100.
output = generator.generate("Formula 1 is a ", max_length=100)
print(output)

In [None]:
prompt = "Imagine you're a detective solving a mystery in a futuristic city. Describe your first clue."

# Build the prompt with the same template the training texts follow (preamble,
# blank line, "### Instruction:", blank line, "### Response:"), so inference
# matches what the model saw during fine-tuning.
alpaca_prompt = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    f"### Instruction:\n{prompt}\n\n"
    "### Response:\n"
)

# The full template makes the prompt longer, so allow more room for the answer
# (assumes max_length bounds prompt + completion — TODO confirm in KerasNLP docs).
output = generator.generate(alpaca_prompt, max_length=200)

print(output)

### Save model

In [None]:
# Save the fine-tuned generator to a .keras file in the current working directory.
generator.save('gpt2-alpaca.keras')