In [1]:
import os
import time

import import_ipynb
import ujson as json
import copy
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


class MySet(Dataset):
    """Line-delimited JSON dataset with a random train/validation split.

    Each line of *file* is kept as raw text and parsed with ``json.loads``
    only when the record is requested. At construction time ``n // 5``
    distinct line indices (``n`` = number of lines) are drawn at random and
    marked as validation; every record returned by ``__getitem__`` carries an
    ``is_train`` flag: 0 for validation lines, 1 for all others.
    """

    def __init__(self, file):
        super(MySet, self).__init__()
        # Read eagerly and close the handle right away; parsing is deferred
        # to __getitem__.
        with open(file) as fh:
            self.content = fh.readlines()
        indices = np.arange(len(self.content))
        # replace=False: sampling with replacement produces duplicates, which
        # would shrink the validation split below the intended n // 5 records.
        val_indices = np.random.choice(indices, len(self.content) // 5, replace=False)
        self.val_indices = set(val_indices.tolist())

    def __len__(self):
        """Return the number of lines (records) in the source file."""
        return len(self.content)

    def __getitem__(self, idx):
        """Parse line *idx* and tag it with ``is_train`` (0 = validation, 1 = train)."""
        rec = json.loads(self.content[idx])
        rec['is_train'] = 0 if idx in self.val_indices else 1
        return rec

def collate_fn(recs):
    """Collate a batch of records into per-direction float tensors.

    Each record in *recs* is a dict whose ``'forward'`` and ``'backward'``
    entries are sequences of per-step dicts, each step holding the keys
    ``'values'``, ``'masks'``, ``'deltas'``, ``'evals'`` and ``'eval_masks'``.
    For each direction every field is gathered across the batch into a single
    ``torch.FloatTensor`` indexed ``[record, step, ...]``.

    Returns:
        ``{'forward': {...}, 'backward': {...}}``, each inner dict mapping the
        five field names above to their batched tensors.

    NOTE(review): records produced by ``MySet`` also carry ``'is_train'``
    (and possibly a ``'label'``), but neither is included in the returned
    batch — confirm this is intended.
    """
    # Materialize once so the batch can be traversed for both directions.
    recs = list(recs)

    def to_tensor_dict(direction):
        # Nested list is [record][step] -> field value; torch.FloatTensor
        # copies the data, so no defensive copy of recs is needed.
        return {
            key: torch.FloatTensor([[step[key] for step in rec[direction]] for rec in recs])
            for key in ('values', 'masks', 'deltas', 'evals', 'eval_masks')
        }

    return {'forward': to_tensor_dict('forward'), 'backward': to_tensor_dict('backward')}

def get_loader(file, batch_size=64, shuffle=True, num_workers=0, pin_memory=True):
    """Build a DataLoader over the line-delimited JSON records in *file*.

    Args:
        file: path of the dataset file, passed to ``MySet``.
        batch_size: number of records per batch.
        shuffle: reshuffle the record order at every epoch.
        num_workers: number of loader worker processes (0 loads in the
            main process, the previous hard-coded behaviour).
        pin_memory: place batches in page-locked memory before returning
            them (only beneficial when moving batches to a CUDA device;
            defaults to True as before).

    Returns:
        A ``torch.utils.data.DataLoader`` whose batches are built by
        ``collate_fn``.
    """
    data_set = MySet(file)
    return DataLoader(
        dataset=data_set,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=collate_fn,
    )