In [1]:
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [2]:
pretrained_embedding_path = "./data/pretrained_wordvector/sgns.sogou.char"
embed = []
word2idx = dict()
idx2word = dict()

size = None
with open(pretrained_embedding_path, "r") as f:
    idx = 0
    for line in tqdm(f):
        x = line.strip().split(' ')
        word = x[0]
        vector = np.asarray(x[1:], dtype=np.float32)
        size = vector.shape
        embed.append(vector)
        word2idx[word] = idx
        idx2word[idx] = word
        idx += 1

365077it [00:13, 27965.93it/s]


In [3]:
avg = sum(embed) / len(embed)
embed.append(avg)
word2idx['<UNK>'] = idx
idx2word[idx] = '<UNK>'
idx += 1

embed.append(np.random.normal(size=size))
word2idx['<PAD>'] = idx
idx2word[idx] = '<PAD>'
idx += 1



In [4]:
class MyDataset(Dataset):
    def __init__(self, path):
        self.data = []
        with open(path, "r") as f:
            for line in tqdm(f):
                x = line.split('\t')
                sen, label = x[0], x[1]
                self.data.append((sen, int(label)))
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.data[index]
                
    def collate(self, batchs):
        sen_out = []
        tot_sen = [pair[0] for pair in batchs]
        tot_label = [pair[1] for pair in batchs]
        max_len = max([len(sen) for sen in tot_sen])
        for sen in tot_sen:
            temp = []
            for ch in sen:
                if (ch not in word2idx):
                    temp.append(word2idx['<UNK>'])
                else:
                    temp.append(word2idx[ch])
            temp += [word2idx['<PAD>']] * (max_len - len(sen))
            sen_out.append(temp)
        return np.array(sen_out), np.array(tot_label)
                
            
            
            
        

In [5]:
test_dataset = MyDataset("./data/test.txt")
valid_dataset = MyDataset("./data/valid.txt")

83607it [00:00, 1334042.85it/s]
83606it [00:00, 1573826.39it/s]


In [6]:
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=True, collate_fn=test_dataset.collate)

In [14]:
for batch in test_dataloader:
    print(batch[0], batch[1])
    print(batch[0].shape)
    print(batch[1].shape)
    break

[[    49   1320   7302 ... 365078 365078 365078]
 [  2111    921   1993 ... 365078 365078 365078]
 [ 55875   9033  20355 ... 365078 365078 365078]
 ...
 [  8155  26842   8361 ... 365078 365078 365078]
 [   221   1123     35 ... 365078 365078 365078]
 [ 64215  44836   5652 ... 365078 365078 365078]] [ 0 11 11 11 11 11  7 11  3  3  7  3  3 11  3 13  8  7  8 13  2  2  8  7
  3  3 10  7 12  8  1  0  0  7  8 11  3 11  5  8  3 11  7  3  3 11  7 10
  7  7  2  3 11  8  7  3  7 11  3 12 11 11  6  2  1 11  7  7  2  8  7 11
 11  7  7  7  1 11  3  3 13  7 11  8  8  3  3 10 10 11  8  3 13 12  7  8
  1 11 11  2  3 13 11  8  3  2  0  3  8  3  6 11  0  0  0  8  7 11  8  3
  2  2  7 10  8  8  1  7]
(128, 26)
(128,)
