In [None]:
#!pip install datasets vertexai mercury

In [None]:
#! gcloud auth list

In [None]:
from datasets import load_dataset
import random
import time
import vertexai
from vertexai.preview.tuning import sft
import json
import utils
import mercury as mr

In [None]:
# Load the LegalBench "contract_nli_explicit_identification" task via the `datasets` library.
dataset = load_dataset("nguha/legalbench", "contract_nli_explicit_identification")

# Pool the published splits so the notebook can carve out its own train/test partition.
data = dataset["train"].to_list() + dataset["test"].to_list()  # Convert to lists before concatenating

# Shuffle with a dedicated, seeded generator: an unseeded shuffle would put
# different examples into the train/test files on every kernel restart,
# making tuning runs impossible to compare.
SHUFFLE_SEED = 42
random.Random(SHUFFLE_SEED).shuffle(data)

# Record each example's position after shuffling; the split below keys off this index.
for idx, d in enumerate(data):
    d["new_index"] = idx

In [None]:
# Only the last expression of a cell is rendered, so a bare `len(data)` above
# the viewer would be computed and silently dropped; print it explicitly.
print(len(data))
# Interactive JSON viewer (mercury) over the pooled, shuffled examples.
mr.JSON(data)

In [None]:
# Zero-shot task instruction; it is attached to every example as the system prompt.
# Split across lines for readability; adjacent literals concatenate to one string.
base_prompt_zero_shot = (
    "Identify if the clause provides that all Confidential Information "
    "shall be expressly identified by the Disclosing Party. "
    "Answer with only `Yes` or `No`"
)

In [None]:
# Number of shuffled examples (by `new_index`) reserved for tuning.
n_train = 30

# Fail fast if the dataset is too small: otherwise `n_test` would silently be
# zero or negative and the held-out file would be empty.
assert 0 < n_train < len(data), (
    f"n_train={n_train} must be positive and smaller than the dataset size ({len(data)})"
)

# Everything that is not used for tuning is held out.
n_test = len(data) - n_train

In [None]:
# Build one tuning record per example and route it to the train or test list.
#
# Record layout follows the Gemini supervised-tuning JSONL schema: the system
# prompt goes in the top-level `systemInstruction` field, and `contents` holds
# only the `user` / `model` turns. (A `system` role inside `contents` is not
# part of that schema.)
train_messages = []
test_messages = []

for d in data:
    # Conversation turns: the clause text as the user turn, the gold label as the model turn.
    prompts = [
        {"role": "user", "parts": [{"text": d["text"]}]},
        {"role": "model", "parts": [{"text": d["answer"]}]},
    ]
    record = {
        # Fresh dict per record so no two records share mutable state.
        "systemInstruction": {"role": "system", "parts": [{"text": base_prompt_zero_shot}]},
        "contents": prompts,
    }

    # The first `n_train` shuffled examples are used for tuning; the rest are held out.
    if int(d["new_index"]) < n_train:
        train_messages.append(record)
    else:
        test_messages.append(record)

# Sanity check: split sizes, expected test size, and one sample training record.
len(train_messages), len(test_messages), n_test, train_messages[5]

In [None]:
# Build one tuning record per example for the whole (unsplit) dataset.
#
# NOTE(review): this cell previously read an undefined `system_instructions`
# name (NameError on a fresh kernel), reset the list on every iteration, and
# appended a stale `prompts` object. `base_prompt_zero_shot` is the only
# system prompt defined in this notebook, so it is used here; confirm that
# was the intent.
tuningdataset = []

for d in data:
    tuningdataset.append(
        {
            # Gemini tuning schema: system prompt at top level, user/model turns in `contents`.
            "systemInstruction": {"role": "system", "parts": [{"text": base_prompt_zero_shot}]},
            "contents": [
                {"role": "user", "parts": [{"text": d["text"]}]},
                {"role": "model", "parts": [{"text": d["answer"]}]},
            ],
        }
    )

In [None]:
# Serialize both splits with the project-local `utils.dicts_to_jsonl` helper.
# NOTE(review): the upload step reads "<stem>.jsonl", so the helper presumably
# appends the extension, and the trailing `False` is presumably a compression
# flag. Confirm both against utils.py.
for _records, _stem in (
    (train_messages, "train_contents"),
    (test_messages, "test_contents"),
):
    utils.dicts_to_jsonl(_records, _stem, False)

In [None]:
# Replace the tuning files in the GCS bucket with the freshly written ones.
# Helper signatures, for reference:
#   utils.upload_blob(bucket_name, source_file_name, destination_blob_name)
#   utils.delete_blob(bucket_name, blob_name)
_blob_pairs = (
    # (local file, destination blob)
    ("train_contents.jsonl", "legalbench/contract_nli_explicit_identification/train_contents.jsonl"),
    ("test_contents.jsonl", "legalbench/contract_nli_explicit_identification/test_contents.jsonl"),
)

# Remove both existing blobs first, then upload both new files (same call order as before).
for _local, _remote in _blob_pairs:
    utils.delete_blob("mchrestkha-sample-data", _remote)
for _local, _remote in _blob_pairs:
    utils.upload_blob("mchrestkha-sample-data", _local, _remote)

In [None]:
# Point the Vertex AI SDK at the target project and region.
vertexai.init(project="mchrestkha-sandbox", location="us-central1")

# GCS locations of the JSONL files uploaded above.
# (Google's public sample, for comparison:
#  gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl)
train_dataset_uri = "gs://mchrestkha-sample-data/legalbench/contract_nli_explicit_identification/train_contents.jsonl"
validation_dataset_uri = "gs://mchrestkha-sample-data/legalbench/contract_nli_explicit_identification/test_contents.jsonl"

# Launch the supervised fine-tuning job.
sft_tuning_job = sft.train(
    source_model="gemini-1.5-pro-001",
    train_dataset=train_dataset_uri,
    # Everything below is optional.
    validation_dataset=validation_dataset_uri,
    epochs=5,
    adapter_size=4,
    learning_rate_multiplier=1.0,
    # NOTE(review): the display name says "1.5_flash" while source_model is
    # gemini-1.5-pro-001; confirm which one is intended.
    tuned_model_display_name="1.5_flash_tuned_legalbench_tuned_nli_explicit_identificationv2",
)

# Block until the job reaches a terminal state, checking once a minute.
while True:
    if sft_tuning_job.has_ended:
        break
    time.sleep(60)
    sft_tuning_job.refresh()

# Report the tuned model, its endpoint, and the associated experiment.
for _attr in ("tuned_model_name", "tuned_model_endpoint_name", "experiment"):
    print(getattr(sft_tuning_job, _attr))

In [None]:
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix