In [None]:
# Install the Hugging Face libraries imported by the next cell (shell installs).
# NOTE(review): no versions are pinned, so a fresh run may pull releases that
# differ from the ones this notebook was written against - consider pinning.
!pip install transformers
!pip install datasets

In [None]:
import math
import sys
import nltk
import pandas as pd
from transformers import GPT2Tokenizer, GPT2Model, GPT2Config, GPT2LMHeadModel, LineByLineTextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments, pipeline, set_seed
import torch
import numpy as np
import random
from tqdm import tqdm
from pathlib import Path
from tokenizers import ByteLevelBPETokenizer
import json
import os
from datasets import load_metric

In [None]:
# Global settings shared by the cells below.
epoch_num = 30  # passed to TrainingArguments(num_train_epochs=...) in model_prep
seed_global = 42  # passed to setup_seed() and to TrainingArguments(seed=...)
# Accuracy metric object used by compute_metrics_accuracy().
# NOTE(review): `load_metric` may be unavailable in recent `datasets` releases - confirm against the installed version.
metric = load_metric('accuracy')

In [None]:
def _as_numpy(values):
    """Return ``values`` as a NumPy array, moving torch tensors to host memory first."""
    if hasattr(values, 'detach'):
        # torch.Tensor: detach from the graph and copy to CPU before converting.
        values = values.detach().cpu().numpy()
    return np.asarray(values)


def compute_metrics_accuracy(eval_pred):
    """Compute accuracy for a ``(predictions, labels)`` evaluation pair.

    Args:
        eval_pred: Two-item iterable ``(predictions, labels)``. ``predictions``
            holds per-class scores with the class dimension last; ``labels``
            holds the reference class ids. NumPy arrays, nested lists and
            torch tensors are all accepted.

    Returns:
        Whatever the module-level ``metric.compute`` returns for the arg-max
        predictions compared against ``labels``.
    """
    predictions, labels = eval_pred
    # The previous version called ``.to('cpu')`` and threw the result away:
    # that is a no-op for tensors and an AttributeError for NumPy arrays.
    # Convert explicitly instead.
    predictions = _as_numpy(predictions)
    labels = _as_numpy(labels)
    # Arg-max over the class axis (last axis; identical to axis=1 for 2-D input).
    predictions = np.argmax(predictions, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

In [None]:
def setup_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's legacy global RNG and
    Python's ``random`` module, then configures cuDNN for repeatable results.

    Args:
        seed: Integer seed applied to all generators.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # The autotuner may pick a different convolution algorithm on each run,
    # so it has to be disabled alongside `deterministic` for repeatability.
    torch.backends.cudnn.benchmark = False

In [None]:
def prep_gpt2_tokenizer():
    """Load the pretrained ``'gpt2'`` tokenizer and register two extra special tokens.

    ``'[PAD]'`` becomes the pad token and ``'[SEP]'`` the separator token.

    Returns:
        The extended ``GPT2Tokenizer`` instance.
    """
    extra_special_tokens = {'pad_token': '[PAD]', 'sep_token': '[SEP]'}
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    gpt2_tokenizer.add_special_tokens(extra_special_tokens)
    return gpt2_tokenizer

In [None]:
def model_prep(tokenizer, train_data_dir, dev_data_dir, model_train_dir):
    """Build a ``Trainer`` that fine-tunes pretrained GPT-2 as a causal LM.

    Reads the module-level ``epoch_num`` and ``seed_global`` settings.

    Args:
        tokenizer: Tokenizer used to encode both text files; the model's
            token embeddings are resized to ``len(tokenizer)``.
        train_data_dir: Path of the training text file.
        dev_data_dir: Path of the evaluation text file.
        model_train_dir: Output directory for checkpoints; logs are written
            to its ``logs`` sub-directory.

    Returns:
        A configured ``Trainer``; training itself is left to the caller.
    """
    model = GPT2LMHeadModel.from_pretrained('gpt2')
    # The tokenizer may carry added special tokens, so the embedding matrix
    # must be resized to match its vocabulary size.
    model.resize_token_embeddings(len(tokenizer))

    def _load_lines(path):
        # Both splits are encoded the same way, truncated to 128 tokens.
        return LineByLineTextDataset(
            tokenizer=tokenizer,
            file_path=path,
            block_size=128,
        )

    dataset_train = _load_lines(train_data_dir)
    dataset_dev = _load_lines(dev_data_dir)
    # mlm=False selects causal (next-token) language modelling.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False,
    )
    training_args = TrainingArguments(
        output_dir=model_train_dir,
        # A leading '/' would make os.path.join discard model_train_dir and
        # resolve to the filesystem root, so the component must be relative.
        logging_dir=os.path.join(model_train_dir, 'logs'),
        overwrite_output_dir=True,
        num_train_epochs=epoch_num,
        per_device_train_batch_size=32,  # batch size per device during training
        per_device_eval_batch_size=16,  # batch size for evaluation
        warmup_steps=500,  # number of warmup steps for learning rate scheduler
        weight_decay=0.01,  # strength of weight decay
        learning_rate=0.00005,
        # Evaluate and checkpoint once per epoch so the best checkpoint can
        # be restored when training finishes.
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        seed=seed_global,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=dataset_train,
        eval_dataset=dataset_dev,
    )
    return trainer

In [None]:
def read_file(file, encoding='utf-8'):
    """Read a text file and return its lines (line endings kept).

    Args:
        file: Path of the file to read.
        encoding: Text encoding used to decode the file. Defaults to UTF-8 so
            the result does not depend on the platform's default encoding.

    Returns:
        List of lines as returned by ``readlines()``.
    """
    with open(file, encoding=encoding) as f:
        return f.readlines()

In [None]:
def write_output(file, str_in, encoding='utf-8'):
    """Write ``str_in`` to ``file``, replacing any existing content.

    Args:
        file: Destination path.
        str_in: Text to write.
        encoding: Text encoding used for the file. Defaults to UTF-8 so the
            bytes written do not depend on the platform's default encoding.
    """
    with open(file, 'w', encoding=encoding) as f:
        f.write(str_in)

In [None]:
# Mount Google Drive at /content/drive (Colab only); force_remount=True
# remounts even when a mount is already present.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

In [None]:
# Copy the MacWhinney corpus folder from Drive to /content/data (-a: archive, -v: verbose).
# NOTE(review): later cells read the relative path "data", which presumably relies on the
# working directory being /content; re-running this when /content/data already exists
# would nest the folder inside it - confirm.
!cp -av /content/drive/MyDrive/CUHK/ling6920/MacWhinney /content/data

In [None]:
# Gather the path of every CHAT transcript (*.cha) directly inside "data".
list_cha_file = [
    os.path.join("data", entry_name)
    for entry_name in os.listdir("data")
    if entry_name.endswith(".cha")
]

In [None]:
# Sort the paths in place so transcripts are processed in a stable order,
# then show the resulting list.
list_cha_file.sort()
print(list_cha_file)

In [None]:
from collections import OrderedDict

In [None]:
# Parse every transcript into episodes of raw main-tier lines ('*SPK:\t...').
dict_data = OrderedDict()  # '<file path><episode index>' -> list of utterance lines
dict_chi_age = dict()      # file path -> age field of the CHI participant ('-1' if absent)
dict_mar_age = dict()      # file path -> age field of the MAR participant ('-1' if absent)
for file in tqdm(list_cha_file):
    idx_sub = 0
    input_lines = read_file(file)
    list_sent = []
    tmp_save = ''  # main-tier line waiting to be stored
    age_chi_tmp = '-1'
    age_mar_tmp = '-1'
    for line in input_lines:
        line = line.strip()
        if line.startswith('*'):
            # A new main tier begins. If the previous one had no %mor tier it
            # is still pending: store it first so that two utterances are
            # never glued into a single string.
            if tmp_save != '':
                list_sent.append(tmp_save)
            tmp_save = line
        elif line.startswith('%mor'):
            # The %mor tier closes the pending utterance; skip a stray %mor
            # with nothing pending instead of storing an empty string.
            if tmp_save != '':
                list_sent.append(tmp_save)
            tmp_save = ''
        elif line.startswith('@New Episode'):
            # Episode boundary: flush what is pending and start a new sub-list.
            if tmp_save != '':
                list_sent.append(tmp_save)
            tmp_save = ''
            dict_data[file + str(idx_sub)] = list_sent
            list_sent = []
            idx_sub += 1
        elif '|MacWhinney|CHI|' in line:
            # Participant header: the 4th '|'-separated field is kept as the age.
            age_chi_tmp = line.split('|')[3]
        elif '|MacWhinney|MAR|' in line:
            age_mar_tmp = line.split('|')[3]
    # End of file: flush a still-pending utterance, mirroring what is done at
    # '@New Episode', so the last line of a transcript is not lost.
    if tmp_save != '':
        list_sent.append(tmp_save)
    dict_data[file + str(idx_sub)] = list_sent
    dict_chi_age[file] = age_chi_tmp
    dict_mar_age[file] = age_mar_tmp

In [None]:
# Display the per-file CHI age strings collected by the parsing cell.
dict_chi_age

In [None]:
import re

In [None]:
def proc_sent_format(str_in):
    """Split a raw main-tier line into its speaker code and utterance text.

    Args:
        str_in: A line such as ``'*CHI:\\tmore cookie . \\x15123_456\\x15'``.

    Returns:
        Tuple ``(speaker, text)``: ``speaker`` is the part before the first
        ``':\\t'`` without its leading character (the ``'*'``), and ``text`` is
        everything after it up to the first media-bullet delimiter.
    """
    sent_speaker_split = str_in.split(':\t')
    speaker = sent_speaker_split[0][1:]
    sent_raw = ' '.join(sent_speaker_split[1:])
    # CHAT wraps media time codes in U+0015 characters. The separator literal
    # had been reduced to an empty string, which makes str.split raise
    # ValueError; use an explicit escape so it cannot be stripped again.
    media_bullet = '\x15'
    split_play_icon = sent_raw.split(media_bullet)[0]
    # TODO for now I keep all that's in [], () and <>
    return speaker, split_play_icon

In [None]:
def dict_add_one(dict_in, item_in):
    """Increment, in place, the tally kept under ``item_in`` in ``dict_in``.

    A key that is not present yet starts from zero, so its first call
    leaves it at 1. Nothing is returned.
    """
    dict_in[item_in] = dict_in.get(item_in, 0) + 1

In [None]:
# do not differentiate ages for the time being
# Build one training sample per exchange: child-side turns (tagged with
# '[AGE=..]'), then '[SEP]', then the parent-side turns that follow.
dict_speaker = dict()  # raw speaker code -> number of utterances seen
dict_counter = dict()  # age_chi value -> number of samples emitted
processed_sent = []    # finished samples, one string each
for idx_name, sent_list in dict_data.items():
    # Keys are '<file path><episode index>'; cut the index off to get the
    # key used by the age dictionaries.
    file_key = idx_name.split('.cha')[0] + '.cha'
    # Keep only the part of the age string before the first ';'.
    age_chi = dict_chi_age[file_key].split(';')[0]
    age_mar = dict_mar_age[file_key].split(';')[0]

    if len(sent_list) == 0:
        continue

    last_speaker = -1      # role id of the previous utterance (-1: none yet)
    last_speaker_str = ''  # raw speaker code of the previous utterance
    tmp_save_list = []     # pieces of the sample being assembled

    for sent in sent_list:
        speaker, text = proc_sent_format(sent)
        dict_add_one(dict_speaker, speaker)
        # TODO only use FAT, CHI, MOT, MAR
        speaker_str = speaker
        # Role ids: 1 = child side (CHI/MAR), 0 = parent (FAT/MOT), 2 = anyone else.
        if speaker == 'CHI' or speaker == 'MAR':
            speaker = 1
        elif speaker == 'FAT' or speaker == 'MOT':
            speaker = 0
        else:
            speaker = 2

        # A parent turn just ended: the sample is complete.
        if speaker != 0 and last_speaker == 0:
            processed_sent.append(' '.join(tmp_save_list))
            tmp_save_list = []
            dict_add_one(dict_counter, age_chi)
        # Tag the start of each run of CHI / MAR utterances with that child's age.
        if speaker_str == 'CHI' and last_speaker_str != speaker_str:
            tmp_save_list.append('[AGE=' + age_chi + ']')
        elif speaker_str == 'MAR' and last_speaker_str != speaker_str:
            tmp_save_list.append('[AGE=' + age_mar + ']')
        # Parent starts answering directly after a child: insert the separator.
        if speaker == 0 and last_speaker != 0 and last_speaker_str in ('CHI', 'MAR'):
            tmp_save_list.append('[SEP]')
        # Utterances of other speakers are counted above but not kept.
        if speaker != 2:
            tmp_save_list.append(text)
        last_speaker = speaker
        last_speaker_str = speaker_str

    # The episode ended on a parent turn: no later speaker will trigger the
    # flush above, so emit the final exchange here instead of dropping it.
    if last_speaker == 0 and tmp_save_list:
        processed_sent.append(' '.join(tmp_save_list))
        dict_add_one(dict_counter, age_chi)


In [None]:
# Sanity check: print every sample that carries two age tags starting with
# digits two apart (i and i+2, for i in 0..7).
for sample_text in processed_sent:
    for base_age in range(8):
        lower_tag = '[AGE=' + str(base_age)
        upper_tag = '[AGE=' + str(base_age + 2)
        if lower_tag in sample_text and upper_tag in sample_text:
            print(sample_text)

In [None]:
# Display how many samples were emitted per age_chi value.
dict_counter

In [None]:
# Display how many utterances were seen per speaker code.
dict_speaker

In [None]:
# Shuffle reproducibly. The global RNGs are only seeded further down
# (setup_seed is called after this cell), so an unseeded shuffle here would
# produce a different train/dev/test split on every run. Sorting first makes
# the result independent of the incoming order, so re-running this cell
# yields the same split.
processed_sent.sort()
random.Random(seed_global).shuffle(processed_sent)
print(len(processed_sent))
# Despite their names these are slice boundaries, not lengths:
# train = [0, 80%), dev = [80%, 90%), test = [90%, end).
len_processed_sent = int(len(processed_sent) * 0.8)
len_processed_sent_test = int(len(processed_sent) * 0.9)
print(len_processed_sent)
print(len_processed_sent_test)

In [None]:
# Write the three splits to disk, one sample per line:
# train = [0, a), dev = [a, b), test = [b, end).
split_slices = (
    ('data/train.txt', slice(None, len_processed_sent)),
    ('data/dev.txt', slice(len_processed_sent, len_processed_sent_test)),
    ('data/test.txt', slice(len_processed_sent_test, None)),
)
for split_path, split_range in split_slices:
    write_output(split_path, '\n'.join(processed_sent[split_range]))

In [None]:
# Back up the three split files to the yc-data-4-23-2 folder on the mounted Drive.
!cp /content/data/train.txt /content/drive/MyDrive/CUHK/ling6920/yige/yc-data-4-23-2/train.txt
!cp /content/data/dev.txt /content/drive/MyDrive/CUHK/ling6920/yige/yc-data-4-23-2/dev.txt
!cp /content/data/test.txt /content/drive/MyDrive/CUHK/ling6920/yige/yc-data-4-23-2/test.txt

In [None]:
# Seed the torch / numpy / random generators before the tokenizer and model are built.
setup_seed(seed_global)

In [None]:
# Build the GPT-2 tokenizer with the extra [PAD] / [SEP] special tokens.
print('preparing the tokenizer...')
# tokenizer = prep_tokenizer(os.path.join(sys.argv[2], 'stock_return_str_conv.txt'))
# tokenizer.save_model(sys.argv[4])
tokenizer = prep_gpt2_tokenizer()

In [None]:
# Encode a sample line and inspect the token id at position 10, both raw and decoded.
# NOTE(review): position 10 is presumably where '[SEP]' is expected to land - confirm.
output = tokenizer('well (.) how_about to your friends ?  [SEP] them too . ')["input_ids"]
print(output[10])
print(tokenizer.decode(output[10]))

In [None]:
# Report whether PyTorch can see a CUDA device.
print('checking the device...')
cuda_visible = torch.cuda.is_available()
print(f'CUDA: {cuda_visible}')

In [None]:
# Fine-tune on the train/dev split files; intermediate output goes to 'tmp'
# and the final model is saved to 'model'.
print('preparing the model...')
trainer = model_prep(tokenizer, 'data/train.txt', 'data/dev.txt', 'tmp')
print('training the model...')
trainer.train()
print('saving the model...')
trainer.save_model('model')

print('training all set!')

In [None]:
# Prompt for the generation cells below: an age-tagged utterance followed by '[SEP]',
# i.e. the same layout as the samples built above, with the text after '[SEP]' left open.
str_check = "[AGE=3] Daddy did they have shoes .  [SEP] "

In [None]:
# Reload the model saved by the training cell and wrap it in a text-generation pipeline.
# (pipeline and set_seed are already imported in the main import cell.)
from transformers import pipeline, set_seed

model = GPT2LMHeadModel.from_pretrained('model')

generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
# Fix the generation seed so sampled outputs are repeatable.
set_seed(42)

In [None]:
# Generate one continuation of the prompt, at most 40 tokens long in total.
generator(str_check, max_length=40, num_return_sequences=1)

In [None]:
# load checkpoint
# Rebuilds tokenizer, model and pipeline from the 'model-4-23-2' directory.
# NOTE(review): nothing in this notebook writes 'model-4-23-2' (training saves to 'model');
# presumably it has to be copied in by hand before this cell runs - confirm.
tokenizer = prep_gpt2_tokenizer()

model = GPT2LMHeadModel.from_pretrained('model-4-23-2')

generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
set_seed(42)

In [None]:
# Generate a reply to the same utterance for every age tag from 0 to 7.
# str_check already begins with its own '[AGE=..]' tag; prepending another
# one would give every prompt two conflicting age tags, so strip the
# existing tag before adding the one under test.
prompt_body = str_check
if prompt_body.startswith('[AGE=') and ']' in prompt_body:
    prompt_body = prompt_body.split(']', 1)[1].lstrip()
for i in range(8):
    str_concat = '[AGE=' + str(i) + '] ' + prompt_body
    print(i)
    print(generator(str_concat, max_length=40, num_return_sequences=1))