### For this project, we have decide to use a pretrained model NLLB-200 and finetune it to fit our project needs.
Model link: https://huggingface.co/facebook/nllb-200-distilled-600M

In [180]:
#Import packages
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq, NllbTokenizerFast
import numpy
import evaluate
import pandas
from sklearn.model_selection import train_test_split
from datasets import load_dataset, Dataset, DatasetDict

In [181]:
#Load model directly
tokenizer = NllbTokenizerFast.from_pretrained("facebook/nllb-200-distilled-600M", src_lang = 'en', tgt_lang = 'zh')
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

We will be using 2 datasets from Amazon to finetune the NLLB-200 model. The source dataset will be in English and the target dataset will be in Chinese.

In [145]:
#load the source dataset
en_data = pandas.read_json('en-US.jsonl', lines=True)
print(en_data.shape)
print(en_data.columns)
en_data.sample(10)

(16521, 8)
Index(['id', 'locale', 'partition', 'scenario', 'intent', 'utt', 'annot_utt',
       'worker_id'],
      dtype='object')


Unnamed: 0,id,locale,partition,scenario,intent,utt,annot_utt,worker_id
15749,16375,en-US,test,email,email_sendemail,compose an email to bob,compose an email to [person : bob],249
15357,15970,en-US,test,email,email_query,what was the subject of the last email from mom,what was the subject of the last email from [r...,35
6278,6503,en-US,test,general,general_quirky,i was chased by a dog,i was chased by a dog,521
5714,5919,en-US,test,news,news_query,what is the latest update on the new york brid...,what is the latest update on the [news_topic :...,672
5116,5296,en-US,test,news,news_query,olly get me the popular news from b. b. c.,olly get me the popular news from [media_type ...,580
13262,13754,en-US,test,qa,qa_stock,tell me the current price of exxon mobil stock,tell me the current price of [business_name : ...,432
4985,5160,en-US,test,play,play_music,could you please play the song of michael jackson,could you please play the song of [artist_name...,51
2714,2812,en-US,test,iot,iot_coffee,can you make me a cup of coffee,can you make me a cup of coffee,249
15888,16519,en-US,test,email,email_sendemail,i need to add a new email to my contacts,i need to add a new email to my contacts,313
6006,6221,en-US,test,general,general_quirky,i do not know how to answer this question you ...,i do not know how to answer this question you ...,30


In [146]:
#load target dataset
zh_data = pandas.read_json('zh-CN.jsonl', lines=True)
print(zh_data.shape)
print(zh_data.columns)
zh_data.sample(10)

(16521, 10)
Index(['id', 'locale', 'partition', 'scenario', 'intent', 'utt', 'annot_utt',
       'worker_id', 'slot_method', 'judgments'],
      dtype='object')


Unnamed: 0,id,locale,partition,scenario,intent,utt,annot_utt,worker_id,slot_method,judgments
5395,5588,zh-CN,train,weather,weather_query,我需要戴个帽子吗,我需要戴个 [weather_descriptor : 帽子] 吗,23,"[{'slot': 'weather_descriptor', 'method': 'tra...","[{'worker_id': '36', 'intent_score': 1, 'slots..."
15103,15700,zh-CN,test,social,social_post,近期市场上的过期产品,近期市场上的过期产品,30,[],"[{'worker_id': '26', 'intent_score': 1, 'slots..."
12236,12678,zh-CN,train,transport,transport_query,有火车提供的前往北京的行程吗,有 [transport_type : 火车] 提供的前往北京的行程吗,38,"[{'slot': 'transport_type', 'method': 'transla...","[{'worker_id': '21', 'intent_score': 1, 'slots..."
3762,3893,zh-CN,train,audio,audio_volume_mute,请把音量控制静音,请把音量控制静音,35,[],"[{'worker_id': '0', 'intent_score': 1, 'slots_..."
10185,10564,zh-CN,train,lists,lists_remove,请删除任务清单,请删除任务清单,30,[],"[{'worker_id': '10', 'intent_score': 1, 'slots..."
3495,3617,zh-CN,train,takeaway,takeaway_query,这附近有什么外卖吗,这附近有什么 [order_type : 外卖] 吗,30,"[{'slot': 'order_type', 'method': 'translation'}]","[{'worker_id': '26', 'intent_score': 1, 'slots..."
922,959,zh-CN,train,iot,iot_cleaning,通过蓝牙启用设备机器人,通过蓝牙启用设备机器人,38,[],"[{'worker_id': '36', 'intent_score': 1, 'slots..."
9875,10239,zh-CN,train,cooking,cooking_recipe,给我一个烹饪鸡肉的教程,给我一个烹饪 [food_type : 鸡肉] 的教程,15,"[{'slot': 'food_type', 'method': 'translation'}]","[{'worker_id': '2', 'intent_score': 1, 'slots_..."
12983,13464,zh-CN,test,general,general_quirky,我们为什么会在这里,我们为什么会在这里,5,[],"[{'worker_id': '0', 'intent_score': 1, 'slots_..."
10762,11156,zh-CN,train,lists,lists_createoradd,请创建一个新的清单,请创建一个新的清单,30,[],"[{'worker_id': '10', 'intent_score': 1, 'slots..."


In [147]:
data = pandas.concat([en_data['utt'], zh_data['utt']], axis=1, keys=['en', 'zh'])
print(data.shape)
print(data.columns)
data.sample(50)

(16521, 2)
Index(['en', 'zh'], dtype='object')


Unnamed: 0,en,zh
11557,what restaurant is open after midnight,午夜后有什么餐馆还开吗
5788,if it is six am here what time is it in tokyo,如果是这里早晨六点东京几点
8097,hey siri clear all my calendar appoints for today,清除我的今天所有的预约
15366,add doctor rosenstock as an email contact,将李雷医生添加到电子邮件联系人中
4725,did you meet her,你见过她吗
4845,olly make my usual,常规设置
5025,turn my lights down to a lower level of bright...,把我的灯向下调到低亮度
8837,do deletion of next calendar event,删除下一个日历事件
2593,confirm to buy laptop,确定购买笔记本电脑
12710,check for all the movie theater prices and ava...,查一下我所在位置所有电影院的价格和是否有新上映的电影


In [176]:
train_data, test_data = train_test_split(data)
dataset = DatasetDict({
                        "train": Dataset.from_pandas(train_data),
                        "test" : Dataset.from_pandas(test_data)
                    })
dataset = dataset.remove_columns(["__index_level_0__"])
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['en', 'zh'],
        num_rows: 12390
    })
    test: Dataset({
        features: ['en', 'zh'],
        num_rows: 4131
    })
})


We will be following the tutorial from Hugging Face to use PyTorch Trainer to finetune the pretrained model. https://huggingface.co/docs/transformers/en/training

In [164]:
metric = evaluate.load('accuracy')

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = numpy.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

In [182]:
training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch")
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset.data["train"],
    eval_dataset=dataset.data["test"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [184]:
trainer.train()

  0%|          | 0/4647 [01:28<?, ?it/s]


IndexError: index out of bounds