In [1]:
import os
import shutil
import random
import tqdm
import numpy as np
import pandas as pd

### Read Files

In [2]:
def read_text(path, min_length=1000):
    """Read every regular file in ``path`` into a DataFrame of texts and highlights.

    Each file is read as UTF-8. Newlines are replaced by spaces and the
    substrings ``"(CNN)"`` and ``"--"`` are removed. The cleaned content is
    then split on the ``"@highlight"`` marker: the part before the first
    marker is the text, the parts after it are the highlights.

    Args:
        path: Directory to scan. Sub-directories are skipped (not recursed).
        min_length: Files whose cleaned content (highlight section included)
            has fewer characters than this are skipped.

    Returns:
        pandas.DataFrame with one row per kept file and the columns
        ``text`` (stripped), ``highlight`` (the first four highlights, each
        followed by ``'.'``, concatenated and stripped) and
        ``highlight_1`` .. ``highlight_4`` (individual highlights exactly as
        split, i.e. not stripped; ``""`` when the file has fewer).
        Highlights beyond the fourth are ignored.
    """
    results = {'text':[], 'highlight': [], 'highlight_1':[], 'highlight_2':[], 'highlight_3':[], 'highlight_4':[]}
    # Sorted so the row order does not depend on the OS directory order.
    for file in tqdm.tqdm(sorted(os.listdir(path))):
        file_name = os.path.join(path, file)
        # Check the joined path: the bare entry name would be resolved against
        # the current working directory rather than ``path``.
        if os.path.isdir(file_name):
            continue
        with open(file_name, encoding="utf-8") as f:
            text = f.read().replace('\n', " ").replace("(CNN)", "").replace("--", "")
        if len(text) < min_length:
            continue
        text_highlights = text.split("@highlight")
        results['text'].append(text_highlights[0].strip())
        all_highlight = ""
        for i in range(1, 5):
            key = 'highlight_' + str(i)
            if i < len(text_highlights):
                results[key].append(text_highlights[i])
                all_highlight += text_highlights[i] + '.'
            else:
                # Pad so that every column ends up with the same length.
                results[key].append("")
        results['highlight'].append(all_highlight.strip())
    return pd.DataFrame(results)

In [3]:
# Directories holding the two corpus splits.
train_dir = 'train_data'
test_dir = 'test'
# The test split is read first, then the training split.
test_data, train_data = (read_text(split_dir) for split_dir in (test_dir, train_dir))

100%|████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:00<00:00, 3986.78it/s]
100%|██████████████████████████████████████████████████████████████████████████| 40000/40000 [00:06<00:00, 6565.10it/s]


### Set up vocab

In [4]:
import spacy
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pack_sequence, pad_packed_sequence
from torchtext.data.utils import get_tokenizer
from collections import Counter, OrderedDict
from torchtext.vocab import vocab
import torch

In [5]:
# Token-frequency tally, passed to token_vocab as its counter argument.
vocab_counter = Counter()
# spaCy-backed tokenizer (model 'en_core_web_sm') obtained through torchtext.
en_tokenizer = get_tokenizer(tokenizer='spacy', language='en_core_web_sm')

In [6]:
def token_vocab(vocab_counters, text_dataframe):
    """Tokenize the ``text`` and ``highlight`` columns of ``text_dataframe``.

    Only the tokens of the ``text`` column are added to ``vocab_counters``
    (updated in place); highlight tokens are returned but not counted.

    Args:
        vocab_counters: Counter-like object exposing ``update``.
        text_dataframe: DataFrame with ``text`` and ``highlight`` columns.

    Returns:
        pandas.DataFrame with the columns ``text_tokens``,
        ``highlight_tokens`` and ``tokens_num`` (length of ``text_tokens``),
        one row per input row.
    """
    article_tokens, summary_tokens, article_lengths = [], [], []
    for _, record in tqdm.tqdm(text_dataframe.iterrows()):
        words = en_tokenizer(record['text'])
        vocab_counters.update(words)
        article_tokens.append(words)
        article_lengths.append(len(words))
        summary_tokens.append(en_tokenizer(record['highlight']))

    return pd.DataFrame({
        'text_tokens': article_tokens,
        'highlight_tokens': summary_tokens,
        'tokens_num': article_lengths,
    })

In [8]:
# Start from an empty tally so re-running this cell does not double-count tokens.
vocab_counter.clear()
# Only the training split feeds the vocabulary counter, so the vocabulary and
# its min_freq cut-off are not influenced by test articles.
train_token = token_vocab(vocab_counter, train_data)
# token_vocab always updates the counter it is given; pass a throwaway one
# for the test split.
test_token = token_vocab(Counter(), test_data)
# Tokens seen fewer than 2 times in training are left out of the vocabulary.
text_vocab = vocab(vocab_counter, min_freq = 2, specials=['<pad>','<unk>', '<bos>', '<eos>'])

39515it [02:04, 317.95it/s]
1982it [00:05, 357.45it/s]


In [9]:
# Keep rows with more than 100 text tokens, ordered by ascending token count.
sorted_train_df = train_token.loc[lambda frame: frame['tokens_num'] > 100].sort_values('tokens_num')
sorted_test_df = test_token.loc[lambda frame: frame['tokens_num'] > 100].sort_values('tokens_num')

In [10]:
# Integer ids of the special tokens in the vocabulary.
UNK_ID, BOS_ID, EOS_ID = (text_vocab[special] for special in ('<unk>', '<bos>', '<eos>'))
# Lookups of tokens missing from the vocabulary return the <unk> id.
text_vocab.set_default_index(UNK_ID)

In [11]:
def get_ids(sorted_token_df):
    """Convert the token lists of every row into vocabulary-id lists.

    Args:
        sorted_token_df: DataFrame with ``text_tokens`` and
            ``highlight_tokens`` columns holding token sequences.

    Returns:
        A list with one ``(text_ids, highlight_ids)`` tuple per row, in row
        order. Each id list starts with ``BOS_ID`` and ends with ``EOS_ID``.
    """
    def encode(tokens):
        # <bos> id, one vocabulary id per token, <eos> id.
        return [BOS_ID, *(text_vocab[tok] for tok in tokens), EOS_ID]

    return [
        (encode(entry['text_tokens']), encode(entry['highlight_tokens']))
        for _, entry in tqdm.tqdm(sorted_token_df.iterrows())
    ]

In [12]:
# Encode the training frame first, then the test frame.
ids_train, ids_test = (get_ids(frame) for frame in (sorted_train_df, sorted_test_df))

39515it [00:25, 1573.20it/s]
1982it [00:01, 1609.13it/s]


In [13]:
def collate_fn(data, pad_id=0):
    """Turn a batch of ``(text_ids, target_ids)`` pairs into two padded tensors.

    The batch is ordered by descending ``text_ids`` length; targets follow the
    same order. The caller's ``data`` is left untouched.

    Args:
        data: Iterable of samples where ``sample[0]`` is the text id sequence
            and ``sample[1]`` the target id sequence.
        pad_id: Value used to fill the padded positions. The default ``0``
            matches ``pad_sequence``'s own default.

    Returns:
        ``(text, target)``: two ``torch.long`` tensors with the batch as the
        first dimension, each padded to the longest sequence in its group.
    """
    # sorted() instead of list.sort(): does not reorder the caller's list and
    # also accepts tuples or other iterables.
    ordered = sorted(data, key=lambda sample: len(sample[0]), reverse=True)
    text_data = [torch.tensor(sample[0], dtype=torch.long) for sample in ordered]
    target_data = [torch.tensor(sample[1], dtype=torch.long) for sample in ordered]
    text = pad_sequence(text_data, batch_first=True, padding_value=pad_id)
    target = pad_sequence(target_data, batch_first=True, padding_value=pad_id)
    return text, target

In [14]:
# Training batches are drawn in a new random order on every pass.
train_data_loader = DataLoader(ids_train, batch_size=2, shuffle=True, collate_fn=collate_fn)
# Evaluation keeps a fixed order so test-time outputs are reproducible and
# can be matched back to their source rows.
test_data_loader = DataLoader(ids_test, batch_size=2, shuffle=False, collate_fn=collate_fn)

In [15]:
# torch.save does not create missing parent directories.
os.makedirs('LSTM_data', exist_ok=True)
# NOTE(review): the DataLoader objects are pickled whole, and pickle stores
# collate_fn by reference, so whatever loads these files presumably needs a
# collate_fn with the same name importable at load time. Confirm with the consumer.
torch.save(train_data_loader, 'LSTM_data/train_data_loader.pth')
torch.save(test_data_loader, 'LSTM_data/test_data_loader.pth')
torch.save(text_vocab, 'LSTM_data/text_vocab.pth')