In [6]:
import pandas as pd
import datasets
from typing import Dict, Sequence

In [2]:
df = pd.read_parquet("en-2-fr-translation.parquet", engine='pyarrow').rename(columns={'English words/sentences': 'input', 'French words/sentences': 'output'})
df

Unnamed: 0,input,output
0,Hi.,Salut!
1,Run!,Cours !
2,Run!,Courez !
3,Who?,Qui ?
4,Wow!,Ça alors !
...,...,...
175461,We need to uphold laws against discrimination ...,Nous devons faire respecter les lois contre la...
175462,A carbon footprint is the amount of carbon dio...,Une empreinte carbone est la somme de pollutio...
175463,Death is something that we're often discourage...,La mort est une chose qu'on nous décourage sou...
175464,Since there are usually multiple websites on a...,Puisqu'il y a de multiples sites web sur chaqu...


In [3]:
# create jsonl files
df.to_json('en-2-fr-translation.jsonl', orient='records', lines=True)

In [32]:
# play around with dataset / tokenizer
train_dataset = datasets.load_dataset('json', data_files='en-2-fr-translation.jsonl')
train_dataset

Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 5809.29it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 342.92it/s]
Generating train split: 175466 examples [00:00, 3527675.37 examples/s]


DatasetDict({
    train: Dataset({
        features: ['en', 'fr'],
        num_rows: 175466
    })
})

In [33]:
import transformers
from torch.utils.data import Dataset
import torch
from dataset import fmt_prompt
import os


In [71]:
tokenizer = transformers.AutoTokenizer.from_pretrained(
        'microsoft/phi-2',
        model_max_length=2048,
        padding_side="right",
        use_fast=False,
        pad_token="<|pad|>",
        trust_remote_code=True,
    )

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [72]:
def preprocess(
        samples: Sequence[str],
        tokenizer: transformers.PreTrainedTokenizer
    ) -> Dict:
    """Preprocess data for training by tokenizing"""
    sources = [f"{fmt_prompt(sources)}" for sources in samples["input"]]
    targets = [f"{translation}{tokenizer.eos_token}" for translation in samples["output"]]
    complete_examples = [s + t for s,t in zip(sources, targets)]
    """tokenize examples"""
    tokenized_strings = [
        tokenizer(
            example,
            return_tensors='pt',
            padding=False,
            max_length=tokenizer.model_max_length,
            truncation=True,
        ) 
        for example in complete_examples
    ]

    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_strings]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
        for tokenized in tokenized_strings
    ]
    # return dict(
    #     input_ids=input_ids,
    #     labels=labels,
    #     input_ids_lens=input_ids_lens,
    #     labels_lens=labels_lens,
    # )
    print(input_ids)



    return None

In [73]:
class MyDataSet(Dataset):
    """Dataset for fine-tuning model"""

    def __init__(self, tokenizer: transformers.PreTrainedTokenizer, paths: str, limit=3000):
        super(MyDataSet, self).__init__()
        dataset = (
            datasets.load_dataset(
            "json",
            data_files=paths,
            split=f"train[0:{limit}]" if limit else "train",
            )
            # .filter(
            #     # filter data entries
            #     )
            .map(
                lambda samples: preprocess(samples, tokenizer),
                batched=True,
                batch_size=300,
                # create a preprocessing function 
            )
        )

        self.tokenizer = tokenizer
        self.data = None 
        # self.size = len(dataframe)

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        
        return None
        

In [74]:
dataset = MyDataSet(tokenizer, ['en-2-fr-translation.jsonl'])

Map:  20%|██        | 600/3000 [00:00<00:00, 3450.68 examples/s]

[tensor([21017, 27759,    25,   198,  1680,   345,  3387, 15772,   428,  9546,
          393,  1573,   284, 48718,    30,   220,   198, 15902,    13,   198,
          198, 21017, 18261,    25,   198,  3363,   286,  1781,     0,  3423,
          318,   257, 48718, 11059,   286,   326,  9546,    25,   220,   198,
        19221,   315,     0, 50256]), tensor([21017, 27759,    25,   198,  1680,   345,  3387, 15772,   428,  9546,
          393,  1573,   284, 48718,    30,   220,   198,  5660,     0,   198,
          198, 21017, 18261,    25,   198,  3363,   286,  1781,     0,  3423,
          318,   257, 48718, 11059,   286,   326,  9546,    25,   220,   198,
           34,  4662,   447,   107,     0, 50256]), tensor([21017, 27759,    25,   198,  1680,   345,  3387, 15772,   428,  9546,
          393,  1573,   284, 48718,    30,   220,   198,  5660,     0,   198,
          198, 21017, 18261,    25,   198,  3363,   286,  1781,     0,  3423,
          318,   257, 48718, 11059,   286,   326,  

Map:  60%|██████    | 1800/3000 [00:00<00:00, 3429.12 examples/s]

[tensor([21017, 27759,    25,   198,  1680,   345,  3387, 15772,   428,  9546,
          393,  1573,   284, 48718,    30,   220,   198, 16160,   866,     0,
          198,   198, 21017, 18261,    25,   198,  3363,   286,  1781,     0,
         3423,   318,   257, 48718, 11059,   286,   326,  9546,    25,   220,
          198, 19452,    68, 26605,   747,  2634,    13, 50256]), tensor([21017, 27759,    25,   198,  1680,   345,  3387, 15772,   428,  9546,
          393,  1573,   284, 48718,    30,   220,   198, 16160,   866,     0,
          198,   198, 21017, 18261,    25,   198,  3363,   286,  1781,     0,
         3423,   318,   257, 48718, 11059,   286,   326,  9546,    25,   220,
          198, 19452,  8471, 26605,   747,  2634,    13, 50256]), tensor([21017, 27759,    25,   198,  1680,   345,  3387, 15772,   428,  9546,
          393,  1573,   284, 48718,    30,   220,   198, 16160,   866,    13,
          198,   198, 21017, 18261,    25,   198,  3363,   286,  1781,     0,
         

Map:  80%|████████  | 2400/3000 [00:00<00:00, 3521.55 examples/s]

[tensor([21017, 27759,    25,   198,  1680,   345,  3387, 15772,   428,  9546,
          393,  1573,   284, 48718,    30,   220,   198,  1867,   318,   340,
           30,   198,   198, 21017, 18261,    25,   198,  3363,   286,  1781,
            0,  3423,   318,   257, 48718, 11059,   286,   326,  9546,    25,
          220,   198,  4507,     6,   395,    12,   344,  5633, 50256]), tensor([21017, 27759,    25,   198,  1680,   345,  3387, 15772,   428,  9546,
          393,  1573,   284, 48718,    30,   220,   198,  1867,   318,   340,
           30,   198,   198, 21017, 18261,    25,   198,  3363,   286,  1781,
            0,  3423,   318,   257, 48718, 11059,   286,   326,  9546,    25,
          220,   198,  4507,     6,   395,  2906,   269,     6,   395,  5633,
        50256]), tensor([21017, 27759,    25,   198,  1680,   345,  3387, 15772,   428,  9546,
          393,  1573,   284, 48718,    30,   220,   198,  1867,   338,   649,
           30,   198,   198, 21017, 18261,    25,  

Map: 100%|██████████| 3000/3000 [00:00<00:00, 3545.79 examples/s]

[tensor([21017, 27759,    25,   198,  1680,   345,  3387, 15772,   428,  9546,
          393,  1573,   284, 48718,    30,   220,   198, 12346,   257,  3128,
           13,   198,   198, 21017, 18261,    25,   198,  3363,   286,  1781,
            0,  3423,   318,   257, 48718, 11059,   286,   326,  9546,    25,
          220,   198,  1925, 10924,   271, 17809,  3128,  5145, 50256]), tensor([21017, 27759,    25,   198,  1680,   345,  3387, 15772,   428,  9546,
          393,  1573,   284, 48718,    30,   220,   198, 39269,  1663,    13,
          198,   198, 21017, 18261,    25,   198,  3363,   286,  1781,     0,
         3423,   318,   257, 48718, 11059,   286,   326,  9546,    25,   220,
          198, 35882,  4618,   274,   279,   516, 34086,    13, 50256]), tensor([21017, 27759,    25,   198,  1680,   345,  3387, 15772,   428,  9546,
          393,  1573,   284, 48718,    30,   220,   198, 39269,  1663,    13,
          198,   198, 21017, 18261,    25,   198,  3363,   286,  1781,   


