Skip to content

Commit

Permalink
Support csv and json files, csv URLs.
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkDaoust committed Oct 10, 2023
1 parent b6e39f2 commit dbeca8e
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 15 deletions.
78 changes: 63 additions & 15 deletions google/generativeai/types/model_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from collections.abc import Mapping
import dataclasses
import datetime
import json
import pathlib
import re
from typing import Any, Iterable, TypedDict, Union

Expand Down Expand Up @@ -184,35 +186,81 @@ class TuningExampleDict(TypedDict):
output: str


TuningExampleOptions = Union[TuningExampleDict, glm.TuningExample, tuple[str, str]]
TuningExampleOptions = Union[TuningExampleDict, glm.TuningExample, tuple[str, str], list[str]]
TuningDataOptions = Union[
glm.Dataset, Mapping[str, Iterable[str]], Iterable[TuningExampleOptions]
] # TODO(markdaoust): csv, drive
pathlib.Path | str | glm.Dataset, Mapping[str, Iterable[str]], Iterable[TuningExampleOptions]
]


def encode_tuning_data(
data: TuningDataOptions, input_key="text_input", output_key="output"
) -> glm.Dataset:
if isinstance(data, str):
# Strings are either URLs or system paths.
if re.match("^\w+://\S+$", data):
data = _normalize_url(data)
else:
# Normalize system paths to use pathlib
data = pathlib.Path(data)

if isinstance(data, pathlib.Path):
if data.suffix.lower() == ".json":
with open(data) as f:
data = json.load(f)

if isinstance(data, (str, pathlib.Path)):
import pandas as pd
data = pd.read_csv(data)

if isinstance(data, glm.Dataset):
return data
elif hasattr(data, "keys"):
new_data = list()
return _convert_dict(data, input_key, output_key)
else:
return _convert_iterable(data)


def _normalize_url(url: str) -> str:
sheet_base = "https://docs.google.com/spreadsheets"
if url.startswith(sheet_base):
# Normalize google-sheets URLs to download the csv.
match = re.match(f"{sheet_base}/d/[^/]+", url)
if match is None:
raise ValueError("Incomplete Google Sheets URL: {data}")
url = f"{match.group(0)}/export?format=csv"
return url


def _convert_dict(data, input_key, output_key):
new_data = list()

try:
inputs = data[input_key]
except KeyError as e:
raise KeyError(f'input_key is "{input_key}", but data has keys: {sorted(data.keys())}')

try:
outputs = data[output_key]
for i, o in zip(inputs, outputs):
new_data.append(glm.TuningExample({"text_input": i, "output": o}))
return glm.Dataset(examples=glm.TuningExamples(examples=new_data))
else:
new_data = list()
for example in data:
example = encode_tuning_example(example)
new_data.append(example)
return glm.Dataset(examples=glm.TuningExamples(examples=new_data))
except KeyError as e:
raise KeyError(f'output_key is "{output_key}", but data has keys: {sorted(data.keys())}')

for i, o in zip(inputs, outputs):
new_data.append(glm.TuningExample({"text_input": str(i), "output": str(o)}))
return glm.Dataset(examples=glm.TuningExamples(examples=new_data))


def _convert_iterable(data):
new_data = list()
for example in data:
example = encode_tuning_example(example)
new_data.append(example)
return glm.Dataset(examples=glm.TuningExamples(examples=new_data))


def encode_tuning_example(example: TuningExampleOptions):
if isinstance(example, tuple):
example = glm.TuningExample(text_input=example[0], output=example[1])
if isinstance(example, (tuple, list)):
a, b = example
example = glm.TuningExample(text_input=a, output=b)
else: # dict or glm.TuningExample
example = glm.TuningExample(example)
return example
Expand Down
4 changes: 4 additions & 0 deletions tests/test.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
text_input,output
a,1
b,2
c,3
5 changes: 5 additions & 0 deletions tests/test1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[
{"text_input": "a", "output": "1"},
{"text_input": "b", "output": "2"},
{"text_input": "c", "output": "3"}
]
1 change: 1 addition & 0 deletions tests/test2.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"text_input": ["a", "b", "c"], "output": ["1", "2", "3"]}
5 changes: 5 additions & 0 deletions tests/test3.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[
["a","1"],
["b","2"],
["c","3"]
]
16 changes: 16 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import copy
import datetime
import dataclasses
import pathlib
import pytz
from typing import Any, Union
import unittest
Expand All @@ -32,6 +33,8 @@

import pandas as pd

HERE = pathlib.Path(__file__).parent


class UnitTests(parameterized.TestCase):
def setUp(self):
Expand Down Expand Up @@ -424,6 +427,19 @@ def test_create_tuned_model_on_tuned_model(self, tuned_source):
]
),
],
["csv-path-string", str(HERE / "test.csv")],
["csv-path", HERE / "test.csv"],
["json-file-1", HERE / "test1.json"],
["json-file-2", HERE / "test2.json"],
["json-file-3", HERE / "test3.json"],
[
"sheet-share",
"https://docs.google.com/spreadsheets/d/1OffcVSqN6X-RYdWLGccDF3KtnKoIpS7O_9cZbicKK4A/edit?usp=sharing",
],
[
"sheet-export-csv",
"https://docs.google.com/spreadsheets/d/1OffcVSqN6X-RYdWLGccDF3KtnKoIpS7O_9cZbicKK4A/export?format=csv",
],
)
def test_create_dataset(self, data, ik="text_input", ok="output"):
ds = model_types.encode_tuning_data(data, input_key=ik, output_key=ok)
Expand Down

0 comments on commit dbeca8e

Please sign in to comment.