In [1]:
import os, glob
import numpy as np
import random
import json
import copy

Given a series of classes, these must be split equally into training validation and holdout datasets.

now, in order to generate a series of minibatches, each of these must be randomized

In [2]:
class TextDataSet(object):
    
    def __init__(self, filepath='../data/20news-18828', length=20000):
        self.basepath = filepath
        self.length=length
        self.class_map={}
        self.classes = os.listdir(filepath)
        for index, value in enumerate(self.classes):
            self.class_map[value] = index
        self.dataset = None
        
    def load(self, class_map, dataset):
        with open(class_map, 'r') as _file:
            self.class_map = copy.copy(json.load(_file))
        with open(dataset, 'r') as _file:
            self.dataset = copy.copy(json.load(_file))
        for cls in self.class_map:
            random.shuffle(self.dataset[str(self.class_map[cls])])
        
    def create_datasets(self):

        train = {}
        val = {}
        test ={}
        
        for i in self.classes:
            train[self.class_map[i]]=[]
            val[self.class_map[i]]=[]
            test[self.class_map[i]]=[]
            for filename in glob.glob(os.path.join(self.basepath, i, '*')):
                r = np.random.random_sample()
                if r > 0.95:
                    test[self.class_map[i]].append(filename)
                elif r > 0.9:
                    val[self.class_map[i]].append(filename)
                else:
                    train[self.class_map[i]].append(filename)
            random.shuffle(train[self.class_map[i]])
            random.shuffle(test[self.class_map[i]])
            random.shuffle(val[self.class_map[i]])
            
        with open('train.json', 'w') as output:
            json.dump(train, output)
        with open('test.json', 'w') as output:
            json.dump(test, output)
        with open('val.json', 'w') as output:
            json.dump(val, output)                
                
        with open('class_map.json', 'w') as output:
            json.dump(self.class_map, output)

    def get_text(self, filename):
        output= np.ndarray(shape=(self.length,), dtype=np.integer)
        index = 0
        with open(filename, 'r', encoding='utf-8', errors='ignore') as input_file:
            for line in input_file.readlines():
                for char in line:
                    if index >= self.length:
                        break
                    output[index] = self.decode_character(char)
                    index += 1
        return output
            
    def decode_character(self, char):
        try:
            return ord(char)
        except UnicodeDecodeError:
            return 0
    
    def get_random_filenames(self):
        tmp = []
        for cls in  self.class_map:
            try:
                tmp.append( (self.dataset[str(self.class_map[cls])].pop(), self.class_map[cls]))
            except IndexError:
                raise StopIteration
        random.shuffle(tmp)
        return [i[0] for i in tmp], [i[1] for i in tmp]
    
    def __iter__(self):
        return self
    
    def __next__(self):
        return self.next()
    
    def next(self):
        x, y = self.get_random_filenames()
        tmp_x = []
        for i in x:
            encoding = self.get_text(i)
            tmp_x.append(encoding)
        x = tmp_x    
        tmp_x = np.zeros(shape=(len(self.class_map), self.length))
        for index, arr in enumerate(x):
            tmp_x[index][:len(arr)] = arr
        # So, now we're going to have to create the Y matrix
        tmp_y = np.zeros(shape=(20,20))
        for index, value in enumerate(y):
                tmp_y[index, value] =1
        return tmp_x, tmp_y

In [3]:
dataset = TextDataSet()
dataset.create_datasets()

In [4]:
train_set = TextDataSet()
train_set.load('class_map.json', 'train.json')

In [5]:
import time
now = time.time()
length=0
for i in train_set:
    length+=1
print(time.time()-now)
print(length)

45.516915798187256
575
