In [1]:
from pathlib2 import Path

In [2]:
from torch.utils.data import Dataset, DataLoader

In [3]:
class SegmentationDataset(Dataset):
    """Text-segmentation dataset built from a directory of plain-text passages.

    Every regular file found recursively under ``root_dir`` (except
    ``.DS_Store``) is read as one passage, one sentence per non-empty line.
    Lines starting with ``===`` are treated as segment boundaries: they are
    removed from the sentence list, and the first sentence that follows a
    boundary is labelled ``1``. Every other sentence is labelled ``0``.

    All examples are right-padded to a common length -- the largest number of
    non-empty lines in any file, boundary lines included -- using ``PAD_STR``
    for sentences and ``0`` for targets.

    Attributes:
        filenames: paths of the files that were read, as strings, in the order
            produced by the recursive glob.
        examples: one padded list of sentence strings per file.
        targets: one padded list of 0/1 labels per file, aligned index-by-index
            with ``examples``.
    """

    PAD_STR = '<pad>'
    BOUNDARY_PREFIX = '==='
    IGNORED_FILENAMES = ('.DS_Store',)

    def __init__(self, root_dir: str):
        """Read every passage under ``root_dir`` and build the padded examples.

        Args:
            root_dir: directory searched recursively for passage files.

        Raises:
            ValueError: if no passage files are found under ``root_dir``.
        """
        self.examples = []
        self.targets = []
        self.filenames = [
            str(path)
            for path in Path(root_dir).glob('**/*')
            # Compare the final path component so the filter does not depend
            # on the platform's path separator.
            if path.is_file() and path.name not in self.IGNORED_FILENAMES
        ]
        if not self.filenames:
            # Fail with an explicit message instead of the opaque
            # "max() arg is an empty sequence" raised further down.
            raise ValueError(
                "no passage files found under {!r}".format(root_dir))

        passages = [self._read_lines(name) for name in self.filenames]
        # The pad length is the longest raw passage (boundary lines included),
        # so it is an upper bound on the number of sentences in any example.
        max_len = max(len(lines) for lines in passages)
        for lines in passages:
            sentences, target = self._label_passage(lines)
            self.examples.append(
                sentences + [self.PAD_STR] * (max_len - len(sentences)))
            self.targets.append(target + [0] * (max_len - len(target)))

    @staticmethod
    def _read_lines(filename):
        """Return the non-empty lines of a UTF-8 text file."""
        # The context manager closes the handle even if reading fails.
        with open(filename, "rt", encoding="utf8") as handle:
            raw_content = handle.read()
        return [line for line in raw_content.strip().split("\n") if line]

    @classmethod
    def _label_passage(cls, lines):
        """Split raw lines into sentences and index-aligned boundary labels."""
        sentences = []
        target = []
        starts_segment = False
        for line in lines:
            if line.startswith(cls.BOUNDARY_PREFIX):
                # Remember the boundary; it is attached to the next sentence.
                # A boundary with no sentence after it (e.g. a trailing
                # heading) is simply dropped.
                starts_segment = True
                continue
            sentences.append(line)
            target.append(1 if starts_segment else 0)
            starts_segment = False
        return sentences, target

    def __len__(self):
        """Return the number of passages in the dataset."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return the padded sentences and labels of passage ``idx`` as a dict.

        The dict has two keys: ``'sentences'`` (list of str) and ``'target'``
        (list of int), both of the same padded length.
        """
        return {
            'sentences': self.examples[idx],
            'target': self.targets[idx],
        }

In [4]:
# Load every passage file found under the relative 'data' directory.
dataset = SegmentationDataset('data')

In [5]:
# Number of examples in the dataset (delegates to SegmentationDataset.__len__).
len(dataset)

50

In [6]:
# Same count, read directly from the underlying list of padded examples.
len(dataset.examples)

50

In [7]:
# Each file read yields exactly one example, so this should equal len(dataset).
len(dataset.filenames)

50

In [8]:
# Padded sentence count of the first example.
len(dataset[0]['sentences'])

449

In [9]:
# Second example: expected to match, since every example is padded to one common length.
len(dataset[1]['sentences'])

449

In [10]:
# Label count of the first example; targets are padded to the same length as the sentences.
len(dataset[0]['target'])

449

In [11]:
# Label count of the second example.
len(dataset[1]['target'])

449

In [12]:
# Inspect the sentences of the first example.
# NOTE(review): this renders the whole padded list; consider slicing
# (e.g. [:10]) to keep the cell output readable.
dataset[0]['sentences']

['WTTE, channel 28, is a Fox-affiliated television station located in Columbus, Ohio, USA.',
 "WTTE's broadcast license is owned by Cunningham Broadcasting, while the station's operations are controlled via local marketing agreement (LMA) by the Sinclair Broadcast Group, WTTE's original owners and present proprietors of ABC affiliate WSYX (channel 6).",
 'Sinclar Broadcast Group also operates Chillicothe-licensed, CW-affiliated WWHO (channel 53), through a shared services agreement with Manhan Media.',
 "The three stations share studios on Dublin Road in Grandview Heights, a suburb of Columbus; WTTE and WSYX also share a transmitter on Columbus' west side.",
 'WTTE began operations on June 1, 1984 as the first general-entertainment independent station in central Ohio.',
 'The station was founded by the Commercial Radio Institute, a subsidiary of the Baltimore-based Sinclair Broadcast Group.',
 'WTTE quickly became the dominant independent station in the area largely because its program

In [13]:
# Inspect the 0/1 boundary labels of the first example (1 marks a sentence
# that directly follows a '===' line).
# NOTE(review): this renders the whole padded list; consider slicing
# (e.g. [:10]) to keep the cell output readable.
dataset[0]['target']

[1,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,


In [14]:
# Batch the dataset four examples at a time, reshuffled on every pass, and
# load in the main process (no worker subprocesses).
dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    num_workers=0,
)

In [15]:
# Walk the batches once and print, per batch, the length of each collated field.
# NOTE(review): both fields are plain Python lists, and the default collate
# function transposes list-valued fields, so these lengths are the padded
# passage length (449 in the recorded output) rather than the batch size of 4.
for i_batch, sample_batched in enumerate(dataloader):
    field_lengths = [len(sample_batched[key]) for key in ('sentences', 'target')]
    print(i_batch, *field_lengths)

0 449 449
1 449 449
2 449 449
3 449 449
4 449 449
5 449 449
6 449 449
7 449 449
8 449 449
9 449 449
10 449 449
11 449 449
12 449 449
