In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to local packages take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from typing import Generator, Any
import pandas as pd
import json
from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi
from dotenv import load_dotenv
import os
from rich import print as rprint
from mbay_nmt.utils import domain as d
from mbay_nmt.utils.models import new_object_id
from datasets import load_dataset, Dataset, DatasetDict
from rich import print as rprint

# Load environment variables from a .env file; MONGODB_URI is read from the
# environment further down.
load_dotenv()

True

In [4]:
# from huggingface_hub import notebook_login

# notebook_login()

In [5]:
# The connection string is taken from the environment rather than written
# into the notebook.
uri = os.environ["MONGODB_URI"]

# Open a client that targets version "1" of the server API.
client = MongoClient(uri, server_api=ServerApi("1"))

# Ping the deployment once and report whichever outcome occurs.
try:
    client.admin.command("ping")
except Exception as ping_error:
    print(ping_error)
else:
    print("Pinged your deployment. You successfully connected to MongoDB!")

Pinged your deployment. You successfully connected to MongoDB!


In [6]:
# Fetch every document from the "entries-prod" collection of the "dictionary"
# database and build a domain Entry from each document's fields.
entries_collection = client.get_database("dictionary").get_collection("entries-prod")
entries = [d.Entry(**document) for document in entries_collection.find()]

In [7]:
# Display the first loaded entry to inspect its fields.
entries[0]

Entry(id=ObjectId('64eca312f6197fd20d762cf5'), created_at=datetime.datetime(2023, 9, 9, 12, 27, 55, tzinfo=TzInfo(UTC)), updated_at=datetime.datetime(2023, 9, 9, 12, 27, 55, tzinfo=TzInfo(UTC)), headword='àlmbétɨ̀, àlmétɨ̀', part_of_speech='NI', sound_filename='NewExpSS2215.mp3', french=Translation(translation='match', key='m'), english=Translation(translation='match.', key='m'), related_word=None, grammatical_note=None, examples=[Example(id=ObjectId('64eca312f6197fd20d76096f'), created_at=datetime.datetime(2023, 9, 9, 12, 27, 55, tzinfo=TzInfo(UTC)), updated_at=datetime.datetime(2023, 9, 9, 12, 27, 55, tzinfo=TzInfo(UTC)), parent_id=ParentId(id=ObjectId('64eca312f6197fd20d762cf5'), type='entry'), mbay='gà àlmbétɨ̀', english=Translation(translation='-light a match', key='l'), french=Translation(translation='allumer une allumette', key='a'), sound_filename=None), Example(id=ObjectId('64eca312f6197fd20d760970'), created_at=datetime.datetime(2023, 9, 9, 12, 27, 55, tzinfo=TzInfo(UTC)), up

In [12]:
from typing import Literal, TypedDict


# One flattened translation row: an Mbay text with its French and English
# renderings, tagged with the kind of dictionary item it was taken from.
Record = TypedDict(
    "Record",
    {
        "type": Literal["entry", "example", "expression"],
        "mbay": str,
        "french": str,
        "english": str,
    },
)

In [9]:
# NOTE(review): bare expression that is not the last statement of its cell,
# so it displays nothing; looks like a leftover.
entries[0]


def entry_to_records(entry: d.Entry) -> Generator[Record, Any, None]:
    """Flatten one dictionary entry into translation records.

    Yields the entry itself first (its headword as the Mbay text), then one
    record per item in ``entry.examples``, then one per item in
    ``entry.expressions``. Each record carries the keys ``type``, ``mbay``,
    ``french`` and ``english``; the French and English values are the
    ``.translation`` attribute of the item's ``french`` / ``english`` fields.
    """

    def _row(kind: str, mbay: str, source):
        # Every item kind exposes .french.translation and .english.translation,
        # so a single builder covers the entry, its examples and its expressions.
        return {
            "type": kind,
            "mbay": mbay,
            "french": source.french.translation,
            "english": source.english.translation,
        }

    yield _row("entry", entry.headword, entry)
    yield from (_row("example", item.mbay, item) for item in entry.examples)
    yield from (_row("expression", item.mbay, item) for item in entry.expressions)


# Run the generator on the first entry and display the resulting records.
list(entry_to_records(entries[0]))

[{'type': 'entry',
  'mbay': 'àlmbétɨ̀, àlmétɨ̀',
  'french': 'match',
  'english': 'match.'},
 {'type': 'example',
  'mbay': 'gà àlmbétɨ̀',
  'french': 'allumer une allumette',
  'english': '-light a match'},
 {'type': 'example',
  'mbay': 'ī-gá àlmbétɨ̀ ādɨ̄-m̄.',
  'french': 'Allumez une allumette pour moi.',
  'english': 'Light a match for me.'},
 {'type': 'example',
  'mbay': 'kùm-àlmbétɨ̀',
  'french': 'allumette non allumée',
  'english': '-unlit match stick'},
 {'type': 'example',
  'mbay': 'Màn̄ à ɔ̀dɨ̀ kùm-àlmbétɨ̀ ànḛ̄ à ùnjɨ̄ àĺ.',
  'french': "Si l'eau touche une allumette, elle ne s'allumera pas.",
  'english': "If water touches a matchstick it won't light."},
 {'type': 'example',
  'mbay': 'kāgɨ̄-àlmbétɨ̀',
  'french': '-allumette utilisée',
  'english': '-used match stick'},
 {'type': 'example',
  'mbay': 'ādɨ̄-m̄ kāgɨ̄-àlmbétɨ̀ kɨ́rā m̄-ɗāa-ň mbī-ḿ.',
  'french': "Donne-moi une allumette utilisée pour que je puisse me nettoyer l'oreille avec.",
  'english': 'Give me a 

In [10]:
# Flatten every entry into its translation records, keeping entry order.
records: list[Record] = [
    record for entry in entries for record in entry_to_records(entry)
]

# Peek at the first few rows as a sanity check.
records[:5]

[{'type': 'entry',
  'mbay': 'àlmbétɨ̀, àlmétɨ̀',
  'french': 'match',
  'english': 'match.'},
 {'type': 'example',
  'mbay': 'gà àlmbétɨ̀',
  'french': 'allumer une allumette',
  'english': '-light a match'},
 {'type': 'example',
  'mbay': 'ī-gá àlmbétɨ̀ ādɨ̄-m̄.',
  'french': 'Allumez une allumette pour moi.',
  'english': 'Light a match for me.'},
 {'type': 'example',
  'mbay': 'kùm-àlmbétɨ̀',
  'french': 'allumette non allumée',
  'english': '-unlit match stick'},
 {'type': 'example',
  'mbay': 'Màn̄ à ɔ̀dɨ̀ kùm-àlmbétɨ̀ ànḛ̄ à ùnjɨ̄ àĺ.',
  'french': "Si l'eau touche une allumette, elle ne s'allumera pas.",
  'english': "If water touches a matchstick it won't light."}]

In [11]:
# Average number of flattened records produced per dictionary entry.
len(records) / len(entries)

2.0945046586803575

In [4]:
# All dataset artefacts live under one directory, relative to this notebook.
DATASETS_DIR = "../../datasets"

# Flattened records as a gzip-compressed CSV.
CSV_DATASET_PATH = f"{DATASETS_DIR}/mbay-translations-flattened.csv.gzip"
# Train / test / validation splits saved to disk.
SPLIT_DATASET_PATH = f"{DATASETS_DIR}/mbay-translations/"
# Tokenized version of the splits.
TOKENIZED_DATASET_PATH = f"{DATASETS_DIR}/mbay-translations-tokenized/"

In [14]:
# Write the flattened records to a gzip-compressed CSV, leaving out the
# DataFrame index column.
df = pd.DataFrame(records)
df.to_csv(CSV_DATASET_PATH, index=False, compression="gzip")

NameError: name 'records' is not defined

In [5]:
# Load the flattened CSV as a Dataset and display its summary.
dst = Dataset.from_csv(CSV_DATASET_PATH)
dst

Dataset({
    features: ['type', 'mbay', 'french', 'english'],
    num_rows: 11015
})

In [6]:
# Spot-check a single row of the loaded dataset.
dst[10]

{'type': 'example',
 'mbay': 'ī-ɗāa àngérì nà̰ wétɨ́ ī-sō hólēe tɨ́ nò.',
 'french': 'Faites attention de ne pas tomber dans le trou.',
 'english': 'Be careful lest you fall into the hole.'}

In [7]:
# Fixed seed so the shuffled split is reproducible; without it every re-run
# assigns rows to different splits.
SPLIT_SEED = 42

# Hold out 20% of the rows, then halve that hold-out into test and validation.
train_test = dst.train_test_split(test_size=0.2, seed=SPLIT_SEED)
test_valid = train_test["test"].train_test_split(test_size=0.5, seed=SPLIT_SEED)

train_test_valid_dst = DatasetDict(
    {
        "train": train_test["train"],
        "test": test_valid["test"],
        "validation": test_valid["train"],
    }
)

# The three splits together must account for every row of the source dataset.
assert sum(len(split) for split in train_test_valid_dst.values()) == len(dst)

In [30]:
# Look at one row of the test split.
train_test_valid_dst["test"][2]

{'type': 'example',
 'mbay': 'Ngɔ̀r làā làmíǹ gásɨ̀ ngá̰y.',
 'french': 'Récemment, les citrons sont très difficiles à trouver.',
 'english': 'Recently lemons are very hard to find.'}

In [22]:
# Write the train / test / validation splits to SPLIT_DATASET_PATH.
train_test_valid_dst.save_to_disk(SPLIT_DATASET_PATH)

Saving the dataset (0/1 shards):   0%|          | 0/8812 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1102 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1101 [00:00<?, ? examples/s]

In [10]:
# T5Tokenizer.__call__?

In [31]:
from mbay_nmt.fine_tune_t5.utils import preprocess_records

In [32]:
from transformers import AutoTokenizer

# Tokenizer of the "google/mt5-base" checkpoint; passed to preprocess_records
# when the splits are tokenized further down.
t5_tokenizer = AutoTokenizer.from_pretrained("google/mt5-base")

In [33]:
# Column names of the train split (the raw columns removed again after tokenization).
train_test_valid_dst.column_names["train"]

['type', 'mbay', 'french', 'english']

In [24]:
# NOTE(review): loads the same "google/mt5-base" tokenizer as t5_tokenizer;
# it is used by the manual tokenization check in the following cells.
test_tokenizer = AutoTokenizer.from_pretrained("google/mt5-base")

Downloading (…)okenizer_config.json:   0%|          | 0.00/376 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/702 [00:00<?, ?B/s]

Downloading (…)ve/main/spiece.model:   0%|          | 0.00/4.31M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/65.0 [00:00<?, ?B/s]



In [28]:
# Hand-written French -> Mbay pair for inspecting the tokenizer output manually.
# A plain string literal is enough: the text has no placeholders to format.
input_txt = "Translate French to Mbay: Récemment, les citrons sont très difficiles à trouver."
target_txt = "Ngɔ̀r làā làmíǹ gásɨ̀ ngá̰y."

In [25]:
# Tokenize the sample pair: the source text becomes the inputs and the target
# text the labels, both truncated / padded to exactly 512 tokens.
v = test_tokenizer(
    input_txt,
    text_target=target_txt,
    padding="max_length",
    truncation=True,
    max_length=512,
)

print(v)

{'input_ids': [89349, 21273, 288, 352, 13921, 267, 8661, 297, 53675, 261, 520, 96910, 263, 259, 2759, 259, 7711, 263, 259, 31268, 299, 259, 369, 13417, 295, 260, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [26]:
# Count the ids that are not 0 (the padding id) in the inputs and in the labels.
non_pad_counts = [
    sum(token_id != 0 for token_id in v[field])
    for field in ("input_ids", "labels")
]
print(*non_pad_counts)

27 18


In [27]:
# Decode the label ids back to text (special and padding tokens included) to
# check that the target survives tokenization.
test_tokenizer.decode(v["labels"])

'Ngɔ̀r làā làmíǹ gásɨ̀ ngá̰y.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><p

In [34]:
from functools import partial

# Bind the tokenizer so the mapper only needs the batch argument.
tokenize_batch = partial(preprocess_records, t5_tokenizer)

# The raw text columns are dropped once the mapped output replaces them.
raw_columns = train_test_valid_dst["train"].column_names

final_dst = train_test_valid_dst.map(
    tokenize_batch,
    batched=True,
    remove_columns=raw_columns,
)

Map:   0%|          | 0/8812 [00:00<?, ? examples/s]

Map:   0%|          | 0/1102 [00:00<?, ? examples/s]

Map:   0%|          | 0/1101 [00:00<?, ? examples/s]

In [30]:
# Padding token id of the tokenizer used for preprocessing. Read from
# t5_tokenizer, the tokenizer instance this notebook defines.
t5_tokenizer.pad_token_id

0

In [None]:
        # NOTE(review): pasted reference fragment, not runnable as a cell. It is
        # indented and relies on `padding`, `data_args`, `labels` and `tokenizer`,
        # none of which are defined in the visible notebook.
        # When padding to max_length with ignore_pad_token_for_loss set, every pad
        # token id in the labels is replaced by -100 (presumably so the loss skips
        # padded positions; confirm against the training code).
        if padding == "max_length" and data_args.ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
            ]

In [33]:
# Pretty-print the first raw (untokenized) training row.
rprint(train_test_valid_dst["train"][0])

In [29]:
# rprint(final_dst["train"][0])

In [39]:
# t5_tokenizer.decode(final_dst["train"][0]["labels"])

In [40]:
# Write the tokenized splits to TOKENIZED_DATASET_PATH.
final_dst.save_to_disk(TOKENIZED_DATASET_PATH)

Saving the dataset (0/1 shards):   0%|          | 0/35248 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/4408 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/4404 [00:00<?, ? examples/s]

In [None]:
from transformers import DataCollatorForSeq2Seq

# Collator that pads inputs and labels per batch for seq2seq training. It is
# built from t5_tokenizer, the tokenizer this notebook defines. The optional
# `model` argument expects a loaded model instance (not a checkpoint name);
# pass it here once a model has been loaded.
data_collator = DataCollatorForSeq2Seq(tokenizer=t5_tokenizer)