<a href="https://colab.research.google.com/github/leonrafael29/W266_Final_Project/blob/main/mBART/Model_Saver.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Install requirements

In [None]:
!pip install datasets -q
!pip install sentencepiece -q
!pip install transformers -q
!pip install git+https://github.com/google-research/bleurt.git -q

# !wget -N https://storage.googleapis.com/bleurt-oss-21/BLEURT-20.zip . -q
# !unzip -q -n BLEURT-20.zip


[K     |████████████████████████████████| 451 kB 30.7 MB/s 
[K     |████████████████████████████████| 212 kB 88.9 MB/s 
[K     |████████████████████████████████| 115 kB 56.2 MB/s 
[K     |████████████████████████████████| 182 kB 64.2 MB/s 
[K     |████████████████████████████████| 127 kB 85.3 MB/s 
[K     |████████████████████████████████| 1.3 MB 25.9 MB/s 
[K     |████████████████████████████████| 5.5 MB 30.9 MB/s 
[K     |████████████████████████████████| 7.6 MB 52.5 MB/s 
[K     |████████████████████████████████| 352 kB 32.8 MB/s 
[?25h  Building wheel for BLEURT (setup.py) ... [?25l[?25hdone


Imports

In [None]:
import csv
import os

import numpy as np
import pandas as pd
import torch
from bleurt import score
from datasets import load_dataset
from transformers import MBartForConditionalGeneration, \
    MBart50TokenizerFast, MBartConfig,\
    TrainingArguments, Trainer
    

Mount Google Drive to use for file saving and loading

In [None]:
from google.colab import files, drive
drive.mount('/content/gdrive/', force_remount=True)
%cd gdrive/MyDrive

Mounted at /content/gdrive/
/content/gdrive/MyDrive


Global variables

In [None]:
# Hugging Face hub id of the pretrained many-to-many mBART-50 model.
ORIGINAL_MODEL_CHECKPOINT = 'facebook/mbart-large-50-many-to-many-mmt'
# Path (relative to the working directory) of a saved training checkpoint.
MODEL_CHECKPOINT = 'Mbart/Model/checkpoint-1000'

# Translation directions, written as '<source>-<target>'.
PAIRS = [
    'en-zh',
    'zh-en',
    'en-es',
    'es-zh',
    'es-en',
    'zh-es',
    ]

# mBART-50 language code for each two-letter language id used in PAIRS.
_LANG_TOKENS = {'en': 'en_XX', 'es': 'es_XX', 'zh': 'zh_CN'}

# One row per set of CSV files under Mbart/Data. Both directions of a language
# pair read the same files, so the figures live here once instead of being
# copied into two MBART_DATA entries.
# NOTE(review): 'size' / 'train' / 'val' are presumably sentence-pair counts
# for the full corpus and its train / validation splits — confirm.
_CORPORA = {
    'en-zh': {'size': 69020, 'train': 48444, 'val': 10381},
    'en-es': {'size': 238511, 'train': 167210, 'val': 35831},
    'es-zh': {'size': 65408, 'train': 45796, 'val': 9814},
    }


def _pair_entry(pair):
    """Build the MBART_DATA entry for one '<src>-<tgt>' direction.

    A direction that is not itself a key of ``_CORPORA`` is served by the
    corpus named with the two languages swapped; such entries get
    ``'reverse': True`` and point at that corpus's files.

    Args:
        pair: Direction string such as ``'en-zh'``.

    Returns:
        Dict with the corpus figures, the source/target language ids and
        mBART language codes, the ``reverse`` flag and the three CSV paths.

    Raises:
        KeyError: If neither the direction nor its flip is in ``_CORPORA``,
            or a language id is missing from ``_LANG_TOKENS``.
    """
    src, tgt = pair.split('-')
    reverse = pair not in _CORPORA
    corpus = f'{tgt}-{src}' if reverse else pair
    stats = _CORPORA[corpus]
    return {
        'size': stats['size'],
        'train': stats['train'],
        'val': stats['val'],
        'src': src,
        'tgt': tgt,
        'src_tkn': _LANG_TOKENS[src],
        'tgt_tkn': _LANG_TOKENS[tgt],
        # 'tkn' duplicates 'tgt_tkn'; kept so existing lookups keep working.
        'tkn': _LANG_TOKENS[tgt],
        'reverse': reverse,
        'train_path': f'Mbart/Data/{corpus}-train_pairs.csv',
        'val_path': f'Mbart/Data/{corpus}-val_pairs.csv',
        'test_path': f'Mbart/Data/{corpus}-test_pairs.csv',
        }


# Per-direction metadata, keyed and ordered like PAIRS.
MBART_DATA = {pair: _pair_entry(pair) for pair in PAIRS}

# The constants below are not referenced in the visible part of this notebook.
# NOTE(review): they look like dataset / tokenizer / generation settings shared
# with sibling notebooks — confirm before changing.
DATASET = 'news_commentary'
MAX_LENGTH = 50
MAX_NEW_TOKENS = 50
TRUNCATION = True
PADDING = True
RETURN_TENSORS = 'pt'
# NOTE(review): no visible cell downloads this checkpoint directory; it
# presumably has to exist in the working directory already — confirm.
BLEURT_CHECKPOINT = './BLEURT-20-D3'
N_EXAMPLES = 100

# Configure PyTorch's CUDA caching allocator: max_split_size_mb:128 stops it
# from splitting cached blocks larger than 128 MB.
# NOTE(review): presumably needs to be set before the first CUDA allocation
# in this kernel to take effect — confirm.
%env PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128

env: PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128


Load Model, Metrics and Tokenizer

In [None]:
# Disabled experiment: load the original checkpoint's config and shrink it to
# a single encoder layer and a single decoder layer.
# NOTE(review): presumably this is how a reduced "Tiny" model was built; while
# these lines stay commented out this cell does nothing — confirm or remove.
# config = MBartConfig.from_pretrained(ORIGINAL_MODEL_CHECKPOINT)

# config.encoder_layers = 1
# config.decoder_layers = 1
# config.num_hidden_layers = 1

In [None]:
# Load the pretrained mBART-50 weights named by ORIGINAL_MODEL_CHECKPOINT
# (fetched from the Hugging Face hub if they are not cached locally).
model = MBartForConditionalGeneration.from_pretrained(
    ORIGINAL_MODEL_CHECKPOINT,
)

Downloading:   0%|          | 0.00/1.43k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.44G [00:00<?, ?B/s]

In [None]:
# Save the baseline model as a single pickled file.
# NOTE: torch.save(model, ...) pickles the whole module object, not just its
# state_dict, so loading it later requires a compatible `transformers`
# install; only torch.load files from a trusted source.
baseline_path = 'Mbart/Model/Single/Baseline'
# torch.save does not create missing parent directories, so make sure the
# target folder exists (e.g. on a fresh Drive) before writing.
os.makedirs(os.path.dirname(baseline_path), exist_ok=True)
torch.save(model, baseline_path)

In [None]:
# Write the model in Hugging Face format (config + weights) so it can be
# reloaded with `from_pretrained`.
# NOTE(review): in the visible cells `model` is the unmodified checkpoint
# loaded from ORIGINAL_MODEL_CHECKPOINT, yet it is saved under a "Tiny" path —
# confirm this is intended.
model.save_pretrained(
    'Mbart/Model/Tiny/TinyMBart',
)