<a href="https://colab.research.google.com/github/leonrafael29/W266_Final_Project/blob/main/mBART/MBart_Data_Generator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Install requirements

In [None]:
!pip install datasets -q

[K     |████████████████████████████████| 441 kB 4.4 MB/s 
[K     |████████████████████████████████| 212 kB 46.8 MB/s 
[K     |████████████████████████████████| 115 kB 70.4 MB/s 
[K     |████████████████████████████████| 95 kB 4.8 MB/s 
[K     |████████████████████████████████| 163 kB 61.4 MB/s 
[K     |████████████████████████████████| 127 kB 67.4 MB/s 
[K     |████████████████████████████████| 115 kB 75.5 MB/s 
[?25h

Imports

In [None]:
import csv
import pandas as pd 

from datasets import load_dataset
from google.colab import files, drive  

drive.mount('/content/gdrive/', force_remount=True)
%cd gdrive/MyDrive  

Mounted at /content/gdrive/
/content/gdrive/MyDrive


In [None]:
MODEL_CHECKPOINT = 'facebook/mbart-large-cc25'

# Language pairs, in the order they are processed.
PAIRS = ['en-zh', 'zh-en', 'en-es', 'es-zh', 'es-en', 'zh-es']

# Per-pair settings keyed by '<src>-<tgt>'.
# Each row below is (size, src, tgt, tkn, reverse, file stem). Pairs flagged
# reverse=True use the same CSV file stem as their forward counterpart.
MBART_DATA = {
    f'{src}-{tgt}': {
        'size': size,
        'src': src,
        'tgt': tgt,
        'tkn': tkn,
        'reverse': reverse,
        'train_path': f'Mbart/Data/{stem}-train_pairs.csv',
        'val_path': f'Mbart/Data/{stem}-val_pairs.csv',
        'test_path': f'Mbart/Data/{stem}-test_pairs.csv',
    }
    for size, src, tgt, tkn, reverse, stem in (
        (69020, 'en', 'zh', 'zh_CN', False, 'en-zh'),
        (69020, 'zh', 'en', 'en_XX', True, 'en-zh'),
        (238511, 'en', 'es', 'es_XX', False, 'en-es'),
        (65408, 'es', 'zh', 'zh_CN', False, 'es-zh'),
        (238511, 'es', 'en', 'en_XX', True, 'en-es'),
        (65408, 'zh', 'es', 'es_XX', True, 'es-zh'),
    )
}

# Name of the dataset handed to load_dataset.
DATASET = 'news_commentary'

# Remaining settings are defined here but not read by the cells of this notebook.
MAX_LENGTH = 50
MAX_NEW_TOKENS = 50
TRUNCATION = True
PADDING = True
RETURN_TENSORS = 'pt'
BLEURT_CHECKPOINT = './BLEURT-20-D6'
N_EXAMPLES = 200


In [None]:
def create_datasets(pair_index, dataset = DATASET,):
  """
  Load one language pair, split it into train/validation/test and save each
  split as a CSV file.

  The "train" split of the loaded dataset is divided 70% / 30% with a fixed
  seed; the 30% remainder is then cut in half to form the validation and test
  sets (the test set receives the extra row when the remainder is odd).
  Pairs whose MBART_DATA entry has reverse=True are skipped and nothing is
  written for them.

  Args:
    pair_index: Position in PAIRS of the language pair to process.
    dataset: Name of the dataset passed to load_dataset. Defaults to DATASET.

  Returns:
    None. The splits are written to the train_path / val_path / test_path
    configured for the pair in MBART_DATA.
  """
  import os  # function-scope import: only used to create the output folders

  pair = PAIRS[pair_index]
  settings = MBART_DATA[pair]

  # Nothing to generate for reversed pairs
  if settings["reverse"]:
    print(f'skipping {pair}')
    return None

  src = settings["src"]
  tgt = settings["tgt"]

  # Load dataset (honours the `dataset` argument instead of always reading
  # the DATASET global)
  df = load_dataset(dataset, f"{src}-{tgt}")

  # Split dataset
  split_df = df["train"].train_test_split(train_size=0.70, seed=20)

  # Pull out the translation records of each part
  train = split_df["train"]["translation"]
  rest = split_df["test"]["translation"]

  # split rest into validation and test
  rest_half = len(rest)//2
  val = rest[:rest_half]
  test = rest[rest_half:]

  print(len(train))
  print(len(val))
  print(len(test))

  # Write every split to the path that carries its own name. (The train and
  # test destinations used to be swapped, so the 70% split ended up in the
  # "test" file.)
  outputs = (
      (train, settings["train_path"]),
      (val, settings["val_path"]),
      (test, settings["test_path"]),
  )
  for rows, path in outputs:
    parent = os.path.dirname(path)
    if parent:
      # Create the output folder on first use so to_csv does not fail
      os.makedirs(parent, exist_ok=True)
    pd.DataFrame(rows).to_csv(path)

  print(f'saved {pair}')
  return None

In [None]:
# Run the generator once per entry of PAIRS, addressing each pair by position.
pair_count = len(PAIRS)
for pair_position in range(pair_count):
  create_datasets(pair_index=pair_position)

Downloading builder script:   0%|          | 0.00/5.36k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/116k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/20.9k [00:00<?, ?B/s]

Downloading and preparing dataset news_commentary/en-zh to /root/.cache/huggingface/datasets/news_commentary/en-zh/11.0.0/cfab724ce975dc2da51cdae45302389860badc88b74db8570d561ced6004f8b4...


Downloading data:   0%|          | 0.00/18.1M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/69206 [00:00<?, ? examples/s]

Dataset news_commentary downloaded and prepared to /root/.cache/huggingface/datasets/news_commentary/en-zh/11.0.0/cfab724ce975dc2da51cdae45302389860badc88b74db8570d561ced6004f8b4. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

48444
10381
10381
saved en-zh
skipping zh-en
Downloading and preparing dataset news_commentary/en-es to /root/.cache/huggingface/datasets/news_commentary/en-es/11.0.0/cfab724ce975dc2da51cdae45302389860badc88b74db8570d561ced6004f8b4...


Downloading data:   0%|          | 0.00/28.1M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/238872 [00:00<?, ? examples/s]

Dataset news_commentary downloaded and prepared to /root/.cache/huggingface/datasets/news_commentary/en-es/11.0.0/cfab724ce975dc2da51cdae45302389860badc88b74db8570d561ced6004f8b4. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

167210
35831
35831
saved en-es
Downloading and preparing dataset news_commentary/es-zh to /root/.cache/huggingface/datasets/news_commentary/es-zh/11.0.0/cfab724ce975dc2da51cdae45302389860badc88b74db8570d561ced6004f8b4...


Downloading data:   0%|          | 0.00/17.4M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/65424 [00:00<?, ? examples/s]

Dataset news_commentary downloaded and prepared to /root/.cache/huggingface/datasets/news_commentary/es-zh/11.0.0/cfab724ce975dc2da51cdae45302389860badc88b74db8570d561ced6004f8b4. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

45796
9814
9814
saved es-zh
skipping es-en
skipping zh-es


In [None]:
# List the contents of Mbart/Data (the folder the CSV paths in MBART_DATA point to).
!ls Mbart/Data

en-es-test_pairs.csv   en-zh-test_pairs.csv   es-zh-test_pairs.csv
en-es-train_pairs.csv  en-zh-train_pairs.csv  es-zh-train_pairs.csv
en-es-val_pairs.csv    en-zh-val_pairs.csv    es-zh-val_pairs.csv
