In [None]:
from .data_handler import DataHandler
import re
import json
from sklearn.model_selection import train_test_split

In [None]:
def clean_texts(texts):
    """Strip forum boilerplate and punctuation from each text.

    Every input string is passed through a fixed sequence of regex
    removals. The sequence is order-sensitive: structural markers are
    matched first, while their spaces and punctuation are still present,
    and the generic character filter runs near the end.

    Args:
        texts: Iterable of strings to clean.

    Returns:
        A new list with one cleaned string per input, in the same order.
    """
    # (pattern, flags) pairs, applied top to bottom; each match is deleted.
    removal_steps = (
        # "※ [...] " bracketed markers.
        (r'※ \[.*?\] ', 0),
        # Author / board / title / timestamp header.
        (r'作者: .*? \(.*?\) 看板: .*? 標題: .*? 時間:\s+\w{3}\s+\w{3}\s+\d{1,2}\s+\d{2}:\d{2}:\d{2}\s+\d{4}', 0),
        # http / https URLs.
        (r'http[s]?://\S+', 0),
        # Whole lines of the form "※ ... 之銘言：" (quote attributions).
        (r'^※.*?之銘言：$', re.MULTILINE),
        # Blog / illustrated-version labels plus up to 5 preceding characters.
        (r'.{0,5}?(網誌：|網誌版：|圖文：|圖文版：)', 0),
        # Everything from "※ 發信站:" through the end of the text.
        (r'※ 發信站:.*', re.DOTALL),
        # Any character that is not a word character or one of 。，!?"
        (r'[^\w。，!?"]', re.UNICODE),
        # Underscores count as word characters, so drop them separately.
        (r'_+', 0),
    )

    def _scrub(raw):
        # Feed the output of each removal into the next one.
        for pattern, flags in removal_steps:
            raw = re.sub(pattern, '', raw, flags=flags)
        return raw

    return [_scrub(entry) for entry in texts]

def load_json(file_name):
    """Read a UTF-8 JSON file and collect each record's 'content' value.

    Args:
        file_name: Path to a JSON file whose top-level value is a sequence
            of mappings, each having a 'content' key.

    Returns:
        List of the 'content' values, in file order.
    """
    with open(file_name, 'r', encoding='utf-8') as handle:
        records = json.load(handle)

    return [record['content'] for record in records]

In [None]:
def preprocess_travel_data(travel_data_dir, test_size=0.2):
    """Load, clean and split the travel / non-travel corpora.

    Reads ``non_travel_related.json`` (label 0) and ``travel_related.json``
    (label 1) from ``travel_data_dir``, cleans the texts with
    ``clean_texts`` and splits each class separately with a fixed seed.

    Args:
        travel_data_dir: Directory containing the two JSON files.
        test_size: Passed through to ``train_test_split`` for each class.

    Returns:
        Tuple ``(train_texts, train_labels, test_texts, test_labels)``.
        Within each list the non-travel samples come first, followed by
        the travel samples; the combined lists are not shuffled.
    """
    # (file name, class label), in the order the splits are concatenated.
    class_sources = (
        ('non_travel_related.json', 0),
        # The extension previously read ",json", which pointed at a
        # file name that does not match the sibling corpus.
        ('travel_related.json', 1),
    )

    train_texts, train_labels = [], []
    test_texts, test_labels = [], []

    for source_name, label in class_sources:
        cleaned = clean_texts(load_json(travel_data_dir + '/' + source_name))

        # Split per class with the same seed so both splits keep each
        # class's own train/test proportion.
        class_train_texts, class_test_texts, class_train_labels, class_test_labels = train_test_split(
            cleaned, [label] * len(cleaned), test_size=test_size, random_state=42)

        train_texts += class_train_texts
        train_labels += class_train_labels
        test_texts += class_test_texts
        test_labels += class_test_labels

    return train_texts, train_labels, test_texts, test_labels

In [None]:
# Directory that preprocess_travel_data reads the two JSON corpora from.
travel_data_dir = 'output'
# Cleaned texts and 0/1 labels (0 = non-travel, 1 = travel), split per class.
train_texts, train_labels, test_texts, test_labels = preprocess_travel_data(travel_data_dir, test_size=0.2)

# DataHandler comes from the local data_handler module; its internals are not
# visible here. It is configured with the 'bert-base-chinese' tokenizer name.
data_handler = DataHandler(tokenizer_name='bert-base-chinese')

# Encode both splits with the same max_length.
# NOTE(review): presumably max_length is the tokenizer truncation/padding
# length — confirm against DataHandler.gen_encoded_data.
train_encodings = data_handler.gen_encoded_data(train_texts, train_labels, max_length=256)
test_encodings = data_handler.gen_encoded_data(test_texts, test_labels, max_length=256)

# Persist encodings together with their labels, one target per split.
# NOTE(review): whether 'encoded_data/' must already exist depends on
# DataHandler.save_encoded_data — verify before a fresh run.
data_handler.save_encoded_data('encoded_data/train', train_encodings, train_labels)
data_handler.save_encoded_data('encoded_data/test', test_encodings, test_labels)