In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
from typing import Generator, Any
import pandas as pd
import json
from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi
from dotenv import load_dotenv
import os
from rich import print as rprint
from mbay_nmt.utils import domain as d
from mbay_nmt.utils.models import new_object_id
from datasets import load_dataset, Dataset, DatasetDict

# Load variables from a local .env file into os.environ (MONGODB_URI is read in a later cell).
load_dotenv()

True

In [4]:
# from huggingface_hub import notebook_login

# notebook_login()

In [5]:
# The connection string comes from the environment, never from the notebook itself.
uri = os.environ["MONGODB_URI"]
client = MongoClient(uri, server_api=ServerApi("1"))  # pin the MongoDB Stable API, version "1"

# Best-effort connectivity check: report success or the error, but keep the notebook running.
try:
    client.admin.command("ping")
except Exception as e:
    print(e)
else:
    print("Pinged your deployment. You successfully connected to MongoDB!")

Pinged your deployment. You successfully connected to MongoDB!


In [6]:
# Build a domain `Entry` from every document in the production entries collection;
# each document's fields are passed straight through as keyword arguments.
entries = [
    d.Entry(**document)
    for document in client["dictionary"]["entries-prod"].find()
]

In [7]:
# Inspect one loaded entry to see which fields are available for flattening.
entries[0]

Entry(id=ObjectId('64eca312f6197fd20d762cf5'), created_at=datetime.datetime(2023, 9, 9, 12, 27, 55, tzinfo=TzInfo(UTC)), updated_at=datetime.datetime(2023, 9, 9, 12, 27, 55, tzinfo=TzInfo(UTC)), headword='àlmbétɨ̀, àlmétɨ̀', part_of_speech='NI', sound_filename='NewExpSS2215.mp3', french=Translation(translation='match', key='m'), english=Translation(translation='match.', key='m'), related_word=None, grammatical_note=None, examples=[Example(id=ObjectId('64eca312f6197fd20d76096f'), created_at=datetime.datetime(2023, 9, 9, 12, 27, 55, tzinfo=TzInfo(UTC)), updated_at=datetime.datetime(2023, 9, 9, 12, 27, 55, tzinfo=TzInfo(UTC)), parent_id=ParentId(id=ObjectId('64eca312f6197fd20d762cf5'), type='entry'), mbay='gà àlmbétɨ̀', english=Translation(translation='-light a match', key='l'), french=Translation(translation='allumer une allumette', key='a'), sound_filename=None), Example(id=ObjectId('64eca312f6197fd20d760970'), created_at=datetime.datetime(2023, 9, 9, 12, 27, 55, tzinfo=TzInfo(UTC)), up

In [8]:
from typing import Literal, TypedDict


# One flattened translation row: aligned Mbay / French / English text, tagged with
# the part of a dictionary entry it came from ("entry", "example" or "expression").
Record = TypedDict(
    "Record",
    {
        "type": Literal["entry", "example", "expression"],
        "mbay": str,
        "french": str,
        "english": str,
    },
)

In [9]:
# Flatten a single dictionary entry into translation records (headword, examples, expressions).


def entry_to_records(entry: d.Entry) -> Generator[Record, Any, None]:
    """Flatten one dictionary entry into parallel-text records.

    Yields, in order: one record for the headword, then one per example in
    ``entry.examples``, then one per expression in ``entry.expressions``.
    Each record holds the Mbay text plus the French and English translations.
    """

    def make_record(kind, mbay_text, item):
        # `item` carries `.french.translation` / `.english.translation`.
        return {
            "type": kind,
            "mbay": mbay_text,
            "french": item.french.translation,
            "english": item.english.translation,
        }

    # The headword plays the role of the Mbay text for the entry itself.
    yield make_record("entry", entry.headword, entry)
    yield from (make_record("example", ex.mbay, ex) for ex in entry.examples)
    yield from (
        make_record("expression", expr.mbay, expr) for expr in entry.expressions
    )


# Preview the records produced for the first entry.
list(entry_to_records(entries[0]))

[{'type': 'entry',
  'mbay': 'àlmbétɨ̀, àlmétɨ̀',
  'french': 'match',
  'english': 'match.'},
 {'type': 'example',
  'mbay': 'gà àlmbétɨ̀',
  'french': 'allumer une allumette',
  'english': '-light a match'},
 {'type': 'example',
  'mbay': 'ī-gá àlmbétɨ̀ ādɨ̄-m̄.',
  'french': 'Allumez une allumette pour moi.',
  'english': 'Light a match for me.'},
 {'type': 'example',
  'mbay': 'kùm-àlmbétɨ̀',
  'french': 'allumette non allumée',
  'english': '-unlit match stick'},
 {'type': 'example',
  'mbay': 'Màn̄ à ɔ̀dɨ̀ kùm-àlmbétɨ̀ ànḛ̄ à ùnjɨ̄ àĺ.',
  'french': "Si l'eau touche une allumette, elle ne s'allumera pas.",
  'english': "If water touches a matchstick it won't light."},
 {'type': 'example',
  'mbay': 'kāgɨ̄-àlmbétɨ̀',
  'french': '-allumette utilisée',
  'english': '-used match stick'},
 {'type': 'example',
  'mbay': 'ādɨ̄-m̄ kāgɨ̄-àlmbétɨ̀ kɨ́rā m̄-ɗāa-ň mbī-ḿ.',
  'french': "Donne-moi une allumette utilisée pour que je puisse me nettoyer l'oreille avec.",
  'english': 'Give me a 

In [10]:
# Flatten every dictionary entry into a single list of translation records.
records: list[Record] = []
for entry in entries:
    records += entry_to_records(entry)

# Preview the first few records
records[:5]

[{'type': 'entry',
  'mbay': 'àlmbétɨ̀, àlmétɨ̀',
  'french': 'match',
  'english': 'match.'},
 {'type': 'example',
  'mbay': 'gà àlmbétɨ̀',
  'french': 'allumer une allumette',
  'english': '-light a match'},
 {'type': 'example',
  'mbay': 'ī-gá àlmbétɨ̀ ādɨ̄-m̄.',
  'french': 'Allumez une allumette pour moi.',
  'english': 'Light a match for me.'},
 {'type': 'example',
  'mbay': 'kùm-àlmbétɨ̀',
  'french': 'allumette non allumée',
  'english': '-unlit match stick'},
 {'type': 'example',
  'mbay': 'Màn̄ à ɔ̀dɨ̀ kùm-àlmbétɨ̀ ànḛ̄ à ùnjɨ̄ àĺ.',
  'french': "Si l'eau touche une allumette, elle ne s'allumera pas.",
  'english': "If water touches a matchstick it won't light."}]

In [11]:
# Average number of translation records produced per dictionary entry.
len(records) / len(entries)

2.0945046586803575

In [14]:
# Output locations (relative to the notebook's working directory).
# Gzip-compressed CSV of all flattened records.
# NOTE(review): ".gzip" appears not to be a suffix pandas' compression="infer"
# recognises (".gz" is) — reading this back with a plain pd.read_csv would need
# compression="gzip" explicitly; confirm before reusing elsewhere.
CSV_DATASET_PATH = "../../datasets/mbay-translations-flattened.csv.gzip"
# DatasetDict with train/test/validation splits of the raw records.
SPLIT_DATASET_PATH = "../../datasets/mbay-translations/"
# Tokenized version of the split DatasetDict.
TOKENIZED_DATASET_PATH = "../../datasets/mbay-translations-tokenized/"

In [15]:
# Tabulate the records (one column per Record key) and persist them as gzip-compressed CSV.
df = pd.DataFrame(data=records)
df.to_csv(path_or_buf=CSV_DATASET_PATH, compression="gzip", index=False)

In [16]:
# Reload the flattened CSV as a Hugging Face `Dataset`.
# NOTE(review): CSV parsing presumably goes through pandas.read_csv, whose default
# NA handling turns cells such as "NA", "nan" or "None" into missing values —
# confirm no Mbay/French/English text collides with those, or build the Dataset
# directly from `records` instead.
dst = Dataset.from_csv(CSV_DATASET_PATH)

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [12]:
# Fixed seed so the train/test/validation partition is identical on every re-run
# (without it, train_test_split shuffles differently each time).
SPLIT_SEED = 42

# Hold out 20% of the records, then halve that hold-out:
# roughly 80% train / 10% test / 10% validation.
# NOTE(review): records are split independently, so a headword and its own
# examples/expressions can land in different splits — consider splitting by entry
# if that leakage matters for evaluation.
train_test = dst.train_test_split(test_size=0.2, seed=SPLIT_SEED)
test_valid = train_test["test"].train_test_split(test_size=0.5, seed=SPLIT_SEED)

train_test_valid_dst = DatasetDict(
    {
        "train": train_test["train"],
        "test": test_valid["test"],
        # The "train" half of the second split serves as the validation set.
        "validation": test_valid["train"],
    }
)

In [14]:
# Persist the train/test/validation splits under SPLIT_DATASET_PATH.
train_test_valid_dst.save_to_disk(SPLIT_DATASET_PATH)

Saving the dataset (0/1 shards):   0%|          | 0/8812 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1102 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1101 [00:00<?, ? examples/s]

In [24]:
from transformers import AutoTokenizer

# Tokenizer for the "t5-small" checkpoint; `checkpoint` and `tokenizer` are reused
# further down (local preprocess_records, collator, compute_metrics).
# NOTE(review): the dataset that is actually saved is tokenized with `t5_tokenizer`
# ("google/t5-v1_1-base") — confirm both tokenizers share a vocabulary before mixing them.
checkpoint = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

Downloading (…)ve/main/spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

In [56]:
from typing import Iterable

# The three language columns every record carries.
Lang = Literal["mbay", "french", "english"]

# Default single-direction setup (English -> Mbay).
prefix = "Translate English to Mbay: "
source_lang: Lang = "english"
target_lang: Lang = "mbay"


def prepare_pair(examples, prefix: str, source_lang: Lang, target_lang: Lang):
    """Build one direction's (inputs, targets) from a batched column mapping.

    Every input is ``prefix`` concatenated with a source-language text; the targets
    are the target-language texts copied into a new list, in the same order.
    """
    source_texts = examples[source_lang]
    inputs = []
    for text in source_texts:
        inputs.append(prefix + text)
    targets = list(examples[target_lang])
    return inputs, targets


# Every direction the model is trained on, as (task prefix, source column, target column).
# Order matters: it fixes the order of the tokenized rows produced for each batch.
TRANSLATION_DIRECTIONS: tuple[tuple[str, str, str], ...] = (
    ("Translate English to Mbay: ", "english", "mbay"),
    ("Translate Mbay to English: ", "mbay", "english"),
    ("Translate French to Mbay: ", "french", "mbay"),
    ("Translate Mbay to French: ", "mbay", "french"),
)


def preprocess_records(examples, max_length: int = 128):
    """Tokenize a batch of records into multi-direction translation pairs.

    For each direction in ``TRANSLATION_DIRECTIONS`` (in order), the whole batch is
    turned into prefixed source texts and target texts via ``prepare_pair``. A batch
    of n records therefore yields 4 * n tokenized rows.

    Args:
        examples: batched mapping of column name -> list of strings; must contain
            the "mbay", "french" and "english" columns.
        max_length: truncation length applied to inputs and targets (default 128).

    Returns:
        The module-level ``tokenizer``'s output for the inputs, with
        ``text_target`` set to the targets.

    NOTE(review): a later cell rebinds this name via
    ``from mbay_nmt.fine_tune_t5.utils import preprocess_records``, and that version
    is called with the tokenizer as its first argument — this local definition is
    shadowed once that cell runs.
    """
    inputs: list[str] = []
    targets: list[str] = []
    for task_prefix, src_col, tgt_col in TRANSLATION_DIRECTIONS:
        direction_inputs, direction_targets = prepare_pair(
            examples, task_prefix, src_col, tgt_col
        )
        inputs.extend(direction_inputs)
        targets.extend(direction_targets)

    return tokenizer(
        inputs, text_target=targets, max_length=max_length, truncation=True
    )

In [26]:
from mbay_nmt.fine_tune_t5.utils import preprocess_records

In [27]:
from transformers import AutoTokenizer

# Tokenizer used (via the imported preprocess_records) to build `final_dst` below.
t5_tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-base")

In [33]:
# Raw text columns of the train split; the tokenizing map below removes them.
train_test_valid_dst.column_names["train"]

['type', 'mbay', 'french', 'english']

In [34]:
from functools import partial

# Tokenize every split with the imported `preprocess_records`, binding `t5_tokenizer`
# as its first argument. The raw text columns are removed, so only the tokenizer's
# outputs remain. Row counts grow 4x (e.g. 8812 -> 35248 for train in the saved output),
# presumably one row per translation direction.
final_dst = train_test_valid_dst.map(
    partial(preprocess_records, t5_tokenizer),
    batched=True,
    remove_columns=train_test_valid_dst["train"].column_names,
)

Map:   0%|          | 0/8812 [00:00<?, ? examples/s]

Map:   0%|          | 0/1102 [00:00<?, ? examples/s]

Map:   0%|          | 0/1101 [00:00<?, ? examples/s]

In [36]:
# Persist the tokenized splits under TOKENIZED_DATASET_PATH.
final_dst.save_to_disk(TOKENIZED_DATASET_PATH)

Saving the dataset (0/1 shards):   0%|          | 0/35248 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/4408 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/4404 [00:00<?, ? examples/s]

In [None]:
from transformers import DataCollatorForSeq2Seq

# Pads each batch dynamically (labels are padded with -100 by default so padding is
# ignored by the loss).
# `model=checkpoint` was dropped: the collator only uses `model` when it is a model
# object exposing `prepare_decoder_input_ids_from_labels`, so passing the checkpoint
# *string* was silently ignored. Pass the loaded model here once one exists.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

In [None]:
import evaluate

# sacreBLEU metric, used by compute_metrics further down.
metric = evaluate.load("sacrebleu")

In [None]:
import numpy as np


def postprocess_text(preds, labels):
    """Normalise decoded strings into the shape sacreBLEU expects.

    Strips surrounding whitespace from every prediction and label, and wraps
    each label in a one-element list (a single reference per prediction).
    """
    cleaned_preds = []
    for pred in preds:
        cleaned_preds.append(pred.strip())

    references = []
    for label in labels:
        references.append([label.strip()])

    return cleaned_preds, references


def compute_metrics(eval_preds):
    """Score generated translations against references with sacreBLEU.

    ``eval_preds`` unpacks into ``(preds, labels)`` token-id arrays; ``preds`` may
    itself be a tuple whose first element holds the ids.

    Returns:
        dict with ``"bleu"`` (the sacreBLEU corpus score) and ``"gen_len"`` (mean
        number of non-padding tokens per prediction), each rounded to 4 decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 marks ignored positions. Labels use it for loss masking, and the Trainer's
    # evaluation loop can also use it to pad generated predictions of different lengths
    # when concatenating batches. It is not a valid token id, so map it back to the pad
    # token before decoding (decoding -100 can fail) and before counting generated tokens.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    bleu = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": bleu["score"]}

    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}