# Notebook para geração do módulo de Dataset usando PyTorch Lightning

O objetivo deste notebook é gerar um módulo de dataset com as seguintes funções:

- Separar o texto em treino, teste e validação.
- Gerar os labels automaticamente a partir dos textos do dataset em `data/IWSLT/raw`.

In [1]:
import pytorch_lightning as pl
import os
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from typing import Dict, List, Tuple
from torch.utils.data import DataLoader, random_split

# Directory that holds the raw IWSLT text files, relative to this notebook.
dataset_path = Path('../../data/IWSLT/raw/')

In [4]:
import transformers
# NOTE(review): `x` is not read by any cell visible here -- confirm before removing.
x = os.walk(dataset_path)
# Pretrained uncased BERT tokenizer shared by the cells below.
tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')

# Path of the training split inside the raw dataset directory.
filepath = os.path.join(dataset_path, 'train.txt')

# Integer label assigned to each punctuation marker token; the preprocessing
# code in this notebook uses label 0 for words not followed by a marker.
punc_dict = {
    ',COMMA':        1,
    '.PERIOD':       2,
    '?QUESTIONMARK': 3,
}

In [25]:
class IWSLTDataset(Dataset):
    """Map-style dataset of word chunks with one punctuation label per word.

    The file at ``path`` is read as whitespace-separated tokens. Tokens that
    are keys of ``punc_dict`` are punctuation markers: they are removed from
    the text and their value becomes the label of the word that precedes
    them. Every other word gets label ``0``. The word stream is then cut into
    consecutive chunks of ``max_len`` words (the last chunk may be shorter).

    Args:
        path: Path of the text file to load.
        tokenizer: Tokenizer exposing ``encode_plus`` (a Hugging Face
            tokenizer in this notebook).
        max_len: Number of words per chunk.
        punc_dict: Mapping from punctuation marker token to integer label.
        tok_max_len: Length every encoded chunk is padded/truncated to.

    NOTE(review): ``target`` holds one label per *word*, while ``input_ids``
    holds ``tok_max_len`` word-piece positions, so the two tensors are not
    aligned position by position -- confirm how the consumer aligns them.
    """

    def __init__(self, path, tokenizer, max_len, punc_dict, tok_max_len=278):

        self.path = path
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.punc_dict = punc_dict
        self.tok_max_len = tok_max_len

        data = self._load_data(path)
        token_list, punc_list = self._preprocess_IWSLT(data)
        self.data, self.labels = self._tokens_to_sentence(token_list, punc_list, max_len)

    def _load_data(self, path):
        """Read the file at ``path`` and return its whitespace-separated tokens."""
        with open(path) as f:
            return f.read().split()

    def _preprocess_IWSLT(self, data):
        """Separate words from punctuation markers.

        Args:
            data: Iterable of tokens (words and punctuation markers).

        Returns:
            ``(token_list, punc_list)``: the words in order and, for each
            word, the label of the marker that follows it (``0`` if none).
        """
        token_list = []
        punc_list = []

        for token in data:
            # Use the mapping given to this instance, not a notebook global.
            if token in self.punc_dict:
                # A marker labels the preceding word; a marker that appears
                # before any word has nothing to label and is skipped.
                if punc_list:
                    punc_list[-1] = self.punc_dict[token]
            else:
                token_list.append(token)
                punc_list.append(0)
        return token_list, punc_list

    def _tokens_to_sentence(self, token_list, punc_list, max_len):
        """Cut the word and label streams into chunks of ``max_len`` words.

        Returns:
            ``(phrases, labels)``: space-joined chunks and the matching label
            lists. The final chunk holds whatever words remain.
        """
        phrases = []
        labels = []

        # Slices clip at the end of the list, so no explicit bound is needed.
        for start in range(0, len(token_list), max_len):
            stop = start + max_len
            phrases.append(' '.join(token_list[start:stop]))
            labels.append(punc_list[start:stop])
        return phrases, labels

    def __len__(self):
        """Return the number of chunks."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the chunk at ``idx`` as text, encoded tensors and labels.

        Returns:
            Dict with the raw ``sentence``, the tokenizer's ``input_ids`` and
            ``attention_mask`` (as returned with ``return_tensors='pt'``) and
            ``target``, a ``LongTensor`` with one label per word.
        """
        # Use the tokenizer given to this instance, not a notebook global.
        encoded = self.tokenizer.encode_plus(
            self.data[idx],
            max_length=self.tok_max_len,
            # Replaces the deprecated `pad_to_max_length=True`.
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        target = torch.LongTensor(self.labels[idx])

        return {
            'sentence': self.data[idx],
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
            'target': target,
        }


In [26]:
# Build the training dataset: chunks of 200 words taken from train.txt.
ds = IWSLTDataset(
    path='../../data/IWSLT/raw/train.txt',
    tokenizer=tokenizer,
    max_len=200,
    punc_dict=punc_dict,
)
        

In [28]:
# Inspect the last sample of the dataset.
ds[-1]

{'sentence': "not the i have a plan speech listen to politicians now with their comprehensive 12-point plans they 're not inspiring anybody because there are leaders and there are those who lead leaders hold a position of power or authority but those who lead inspire us whether they 're individuals or organizations we follow those who lead not because we have to but because we want to we follow those who lead not for them but for ourselves and it 's those who start with why that have the ability to inspire those around them or find others who inspire them thank you very much",
 'input_ids': tensor([[  101,  2025,  1996,  1045,  2031,  1037,  2933,  4613,  4952,  2000,
           8801,  2085,  2007,  2037,  7721,  2260,  1011,  2391,  3488,  2027,
           1005,  2128,  2025, 18988, 10334,  2138,  2045,  2024,  4177,  1998,
           2045,  2024,  2216,  2040,  2599,  4177,  2907,  1037,  2597,  1997,
           2373,  2030,  3691,  2021,  2216,  2040,  2599, 18708,  2149,  3251,
   

In [41]:
class IWSLTDataset2(pl.LightningDataModule):
    """LightningDataModule serving the IWSLT punctuation text files.

    Each split file is read as whitespace-separated tokens. Tokens that are
    keys of ``punc_dict`` are punctuation markers: they are removed from the
    text and their value becomes the label of the preceding word (``0`` for
    words not followed by a marker). The word stream is cut into chunks of
    ``sentence_size`` words, which are encoded with ``tokenizer``.

    Args:
        dataset_path: Directory containing the split files.
        tokenizer: Tokenizer exposing ``batch_encode_plus`` (a Hugging Face
            tokenizer in this notebook).
        ref: Stored on the instance as ``self.ref``.

    NOTE(review): labels are per *word* while ``input_ids`` are per
    word-piece, so the two are not aligned position by position -- confirm
    how the consumer aligns them.
    """

    def __init__(self, dataset_path, tokenizer, ref=True):
        super().__init__()
        self.tokenizer = tokenizer
        # NOTE(review): `ref` is stored but never read in this class;
        # presumably it was meant to choose between the 'ref' and 'asr'
        # test files -- confirm.
        self.ref = ref
        self.data = None
        self.dataset_path = dataset_path
        self.filenames = {
            'train': 'train.txt',
            'ref': 'ref.txt',
            'dev': 'dev.txt',
            'asr': 'asr.txt',
        }
        self.punc_dict = {
            ',COMMA':        1,
            '.PERIOD':       2,
            '?QUESTIONMARK': 3,
        }

    def _load_data(self, split):
        """Read the file of ``split`` into ``self.data`` as a token list.

        Args:
            split: Key of ``self.filenames`` ('train', 'ref', 'dev', 'asr').
        """
        datapath = os.path.join(self.dataset_path, self.filenames[split])
        with open(datapath) as f:
            self.data = f.read().split()

    def _preprocess_IWSLT(self):
        """Separate the words in ``self.data`` from the punctuation markers.

        Returns:
            ``(token_list, punc_list)``: the words in order and, for each
            word, the label of the marker that follows it (``0`` if none).
        """
        token_list = []
        punc_list = []

        for token in self.data:
            if token in self.punc_dict:
                # A marker labels the preceding word; a marker that appears
                # before any word has nothing to label and is skipped.
                if punc_list:
                    punc_list[-1] = self.punc_dict[token]
            else:
                token_list.append(token)
                punc_list.append(0)
        return token_list, punc_list

    def _tokens_to_sentence(self, token_list, punc_list, sentence_size):
        """Cut the word and label streams into chunks of ``sentence_size`` words.

        Returns:
            ``(phrases, labels)``: space-joined chunks and the matching label
            lists. The final chunk holds whatever words remain.
        """
        phrases = []
        labels = []

        # Slices clip at the end of the list, so no explicit bound is needed.
        for start in range(0, len(token_list), sentence_size):
            stop = start + sentence_size
            phrases.append(' '.join(token_list[start:stop]))
            labels.append(punc_list[start:stop])
        return phrases, labels

    def _get_data(self, sentence_size, split='train'):
        """Load ``split`` and return its ``(phrases, labels)`` chunks."""
        self._load_data(split)
        token_list, punc_list = self._preprocess_IWSLT()
        return self._tokens_to_sentence(token_list, punc_list, sentence_size)

    def _encode_split(self, split, sentence_size):
        """Return ``(encodings, labels)`` for one split, ready for training."""
        phrases, labels = self._get_data(split=split, sentence_size=sentence_size)
        # The final chunk is always dropped: it is normally shorter than
        # `sentence_size`, and torch.LongTensor needs rows of equal length.
        phrases = phrases[:-1]
        labels = labels[:-1]
        # `padding=True` pads to the longest sequence of the split; it
        # replaces the deprecated `pad_to_max_length=True`, which behaves the
        # same way when no `max_length` is given.
        encodings = self.tokenizer.batch_encode_plus(phrases, padding=True)
        return encodings, torch.LongTensor(labels)

    def prepare_data(self, sentence_size=200):
        """Load, chunk and encode the train, dev and test ('ref') splits.

        Sets ``train_X``/``train_y``, ``dev_X``/``dev_y`` and
        ``test_X``/``test_y``: the ``*_X`` attributes are the tokenizer's
        batch encodings and the ``*_y`` attributes are ``LongTensor`` labels.

        Args:
            sentence_size: Number of words per chunk.
        """
        self.train_X, self.train_y = self._encode_split('train', sentence_size)
        self.dev_X, self.dev_y = self._encode_split('dev', sentence_size)
        self.test_X, self.test_y = self._encode_split('ref', sentence_size)

    def setup(self, stage=None):
        """No-op; accepts ``stage`` because Lightning passes it to this hook."""
        pass

    def _make_loader(self, encodings, labels):
        """Wrap one encoded split in a DataLoader.

        Batches are ``(input_ids, attention_mask, labels)``. The encodings
        object itself is not a per-sample dataset, so the tensors are built
        here and paired with the labels.
        """
        dataset = torch.utils.data.TensorDataset(
            torch.tensor(encodings['input_ids'], dtype=torch.long),
            torch.tensor(encodings['attention_mask'], dtype=torch.long),
            labels,
        )
        return DataLoader(dataset, batch_size=1)

    def train_dataloader(self):
        """Return the DataLoader over the training split."""
        return self._make_loader(self.train_X, self.train_y)

    def val_dataloader(self):
        """Return the DataLoader over the dev (validation) split."""
        return self._make_loader(self.dev_X, self.dev_y)

    def test_dataloader(self):
        """Return the DataLoader over the test ('ref') split."""
        return self._make_loader(self.test_X, self.test_y)

In [42]:
# IWSLTDataset2 is the data module that takes (dataset_path, tokenizer) and
# provides prepare_data(); IWSLTDataset requires max_len/punc_dict and has no
# prepare_data(), so calling it here would fail.
ds = IWSLTDataset2(dataset_path, tokenizer)
ds.prepare_data()

In [39]:
# Number of token ids in the first encoded training chunk.
len(ds.train_X['input_ids'][0])

276

In [53]:
# Trainer created with its default arguments.
trainer = pl.Trainer()

GPU available: True, used: False
TPU available: False, using: 0 TPU cores


In [13]:
# Read the whole training file and break it into whitespace-separated tokens.
with open(filepath) as f:
    data = f.read().split()

# Integer label for each punctuation marker token.
punc_dict = {',COMMA': 1, '.PERIOD': 2, '?QUESTIONMARK': 3}

def preprocess_IWSLT(data: List[str], punc: Dict[str, int]) -> Tuple[List[str], List[int]]:
    """Split a token stream into words and per-word punctuation labels.

    Tokens that are keys of ``punc`` are punctuation markers: they are left
    out of the word list and their value becomes the label of the word that
    precedes them. Every other token is a word with initial label ``0``.

    Args:
        data: Tokens in reading order (words and punctuation markers).
        punc: Mapping from punctuation marker token to integer label.

    Returns:
        ``(token_list, punc_list)``: the words in order and one label per word.
    """
    token_list = []
    punc_list = []

    for token in data:
        # Look the marker up in the mapping passed by the caller.
        if token in punc:
            # A marker labels the preceding word; a marker that appears
            # before any word has nothing to label and is skipped.
            if punc_list:
                punc_list[-1] = punc[token]
        else:
            token_list.append(token)
            punc_list.append(0)
    return token_list, punc_list


# Words of the training file and, for each word, its punctuation label.
token_list, punc_list = preprocess_IWSLT(data, punc_dict)

In [14]:
# Preview the first `max_len` words.
max_len = 50
# NOTE: only the last expression of a cell is displayed, so the zipped
# (word, label) pairs below are built and then discarded.
list(zip(token_list[:max_len], punc_list[:max_len]))
token_list[:max_len]


['it',
 'can',
 'be',
 'a',
 'very',
 'complicated',
 'thing',
 'the',
 'ocean',
 'and',
 'it',
 'can',
 'be',
 'a',
 'very',
 'complicated',
 'thing',
 'what',
 'human',
 'health',
 'is',
 'and',
 'bringing',
 'those',
 'two',
 'together',
 'might',
 'seem',
 'a',
 'very',
 'daunting',
 'task',
 'but',
 'what',
 'i',
 "'m",
 'going',
 'to',
 'try',
 'to',
 'say',
 'is',
 'that',
 'even',
 'in',
 'that',
 'complexity',
 'there',
 "'s",
 'some']

In [15]:
def tokens_to_sentence(token_list, punc_list, sentence_size=200):
    """Group words and their labels into consecutive fixed-size chunks.

    Args:
        token_list: Words in reading order.
        punc_list: One label per word, in the same order.
        sentence_size: Number of words per chunk.

    Returns:
        ``(phrases, labels)``: each phrase is a chunk of words joined by a
        single space, and ``labels`` holds the matching slices of
        ``punc_list``. The final chunk holds whatever words remain.
    """
    total = len(token_list)
    # When one chunk covers everything, the slice is bounded by the number
    # of words; otherwise every slice spans `sentence_size` positions.
    width = sentence_size if sentence_size < total else total
    starts = range(0, total, sentence_size)

    phrases = [' '.join(token_list[s:s + width]) for s in starts]
    labels = [punc_list[s:s + width] for s in starts]
    return phrases, labels


In [16]:
# Chunk the training words with the default sentence_size.
phrases, labels = tokens_to_sentence(token_list, punc_list)

In [17]:
# Inspect the second chunk together with its labels.
phrases[1], labels[1]

("into the water rolf bolin who was a professor at the hopkin 's marine station where i work wrote in the 1940s that the fumes from the scum floating on the inlets of the bay were so bad they turned lead-based paints black people working in these canneries could barely stay there all day because of the smell but you know what they came out saying they say you know what you smell you smell money that pollution was money to that community and those people dealt with the pollution and absorbed it into their skin and into their bodies because they needed the money we made the ocean unhappy we made people very unhappy and we made them unhealthy the connection between ocean health and human health is actually based upon another couple simple adages and i want to call that pinch a minnow hurt a whale the pyramid of ocean life now when an ecologist looks at the ocean i have to tell you we look at the ocean in a very different way and we see different things than when a regular person looks at 

In [11]:
# Encode only the first chunk to inspect the tokenizer output.
tokenizer.batch_encode_plus(phrases[:1])

{'input_ids': [[101, 2009, 2064, 2022, 1037, 2200, 8552, 2518, 1996, 4153, 1998, 2009, 2064, 2022, 1037, 2200, 8552, 2518, 2054, 2529, 2740, 2003, 1998, 5026, 2216, 2048, 2362, 2453, 4025, 1037, 2200, 4830, 16671, 2075, 4708, 2021, 2054, 1045, 1005, 1049, 2183, 2000, 3046, 2000, 2360, 2003, 2008, 2130, 1999, 2008, 11619, 2045, 1005, 1055, 2070, 3722, 6991, 2008, 1045, 2228, 2065, 2057, 3305, 2057, 2064, 2428, 2693, 2830, 1998, 2216, 3722, 6991, 2024, 1050, 1005, 1056, 2428, 6991, 2055, 1996, 3375, 2671, 1997, 2054, 1005, 1055, 2183, 2006, 2021, 2477, 2008, 2057, 2035, 3492, 2092, 2113, 1998, 1045, 1005, 1049, 2183, 2000, 2707, 2007, 2023, 2028, 2065, 23603, 9932, 1050, 1005, 1056, 3407, 9932, 1050, 1005, 1056, 6343, 3407, 2057, 2113, 2008, 2157, 2057, 1005, 2310, 5281, 2008, 1998, 2065, 2057, 2074, 2202, 2008, 1998, 2057, 3857, 2013, 2045, 2059, 2057, 2064, 2175, 2000, 1996, 2279, 3357, 2029, 2003, 2008, 2065, 1996, 4153, 9932, 1050, 1005, 1056, 3407, 9932, 1050, 1005, 1056, 6343, 3407