In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [60]:
class Config(object):
    def __init__(self, dataset, embedding):
        self.model_name = 'TextRNN'
        self.train_path = dataset + 'data/train.txt'
        self.dev_path = dataset + 'data/dev.txt'
        self.test_path = dataset + 'data/test.txt'
        self.class_list = [x.strip() for x in open(
            dataset + 'data/class.txt'
        ).readlines()]
        
        self.vocab_path = dataset + 'data/vocab.pkl'
        self.save_pat = dataset + 'data/saved_dict/' + self.model_name + '.ckpt'
        self.embedding_pretrained = torch.tensor(
            np.load(dataset + '/data/' + embedding)['embeddings'].astype('float32')\
            if embedding != 'random' else None
        )
        
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        self.dropout = 0.5
        self.require_improvement = 1000
        self.num_classes = len(self.class_list)
        self.n_vocab = 0
        self.num_epochs = 10
        self.batch_size = 128
        self.pad_size = 32
        self.learning_rate = 1e-3
        self.embed = self.embedding_pretrained.size(1)\
            if self.embedding_pretrained is not None else 300
        self.hidden_size = 128
        
        self.num_layers = 2

class TextRNN(nn.Module):
    def __init__(self, embedding_pretrained, num_classes):
        super(TextRNN, self).__init__()
        if embedding_pretrained is not None:
            self.embedding = nn.Embedding.from_pretrained(embedding_pretrained, freeze=False)
            
        else:
            raise 'not pretrained embedding'
#         else:
#             self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab-1)

            
        self.lstm = nn.LSTM(embedding_pretrained.size(1), 128, 2,
                           bidirectional=True, batch_first=True, dropout=0.5)
        self.fc = nn.Linear(128 * 2, num_classes)
        
    def forward(self, x):
        x, _ = x
        out = self.embedding(x)
        out, _ = self.lstm(out)
        out = self.fc(out[:, -1, :])
        return out



In [None]:
def build_dataset(config, ues_word):
    if ues_word:
        tokenizer = lambda x: x.split(' ')  # 以空格隔开，word-level
    else:
        tokenizer = lambda x: [y for y in x]  # char-level
    if os.path.exists(config.vocab_path):
        vocab = pkl.load(open(config.vocab_path, 'rb'))
    else:
        vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
        pkl.dump(vocab, open(config.vocab_path, 'wb'))
    print(f"Vocab size: {len(vocab)}")

    def biGramHash(sequence, t, buckets):
        t1 = sequence[t - 1] if t - 1 >= 0 else 0
        return (t1 * 14918087) % buckets

    def triGramHash(sequence, t, buckets):
        t1 = sequence[t - 1] if t - 1 >= 0 else 0
        t2 = sequence[t - 2] if t - 2 >= 0 else 0
        return (t2 * 14918087 * 18408749 + t1 * 14918087) % buckets

    def load_dataset(path, pad_size=32):
        contents = []
        with open(path, 'r', encoding='UTF-8') as f:
            for line in tqdm(f):
                lin = line.strip()
                if not lin:
                    continue
                content, label = lin.split('\t')
                words_line = []
                token = tokenizer(content)
                seq_len = len(token)
                if pad_size:
                    if len(token) < pad_size:
                        token.extend([vocab.get(PAD)] * (pad_size - len(token)))
                    else:
                        token = token[:pad_size]
                        seq_len = pad_size
                # word to id
                for word in token:
                    words_line.append(vocab.get(word, vocab.get(UNK)))

                # fasttext ngram
                buckets = config.n_gram_vocab
                bigram = []
                trigram = []
                # ------ngram------
                for i in range(pad_size):
                    bigram.append(biGramHash(words_line, i, buckets))
                    trigram.append(triGramHash(words_line, i, buckets))
                # -----------------
                contents.append((words_line, int(label), seq_len, bigram, trigram))
        return contents  # [([...], 0), ([...], 1), ...]
    train = load_dataset(config.train_path, config.pad_size)
    dev = load_dataset(config.dev_path, config.pad_size)
    test = load_dataset(config.test_path, config.pad_size)
    return vocab, train, dev, test


In [5]:
import pickle as pkl
vocab = pkl.load(open('./rnn_data/vocab.pkl','rb'))
len(vocab)

4762

In [8]:
from tqdm import tqdm

def load_dataset(path, pad_size=32):
    contents = []
    with open(path, 'r', encoding='UTF-8') as f:
        for line in tqdm(f):
            lin = line.strip()
            if not lin:
                continue
            content, label = lin.split('\t')
            words_line = []
            token = [x for x in content]
            seq_len = len(token)
            if pad_size:
                if len(token) < pad_size:
                    token.extend([vocab.get('<PAD>')] * (pad_size - len(token)))
                else:
                    token = token[:pad_size]
                    seq_len = pad_size
            # word to id
            for word in token:
                words_line.append(vocab.get(word, vocab.get('<UNK>')))
            contents.append((words_line, int(label), seq_len))
    return contents  # [([...], 0), ([...], 1), ...]

In [9]:
train_ds = load_dataset('./rnn_data/train.txt')
dev_ds = load_dataset('./rnn_data/dev.txt')
test_ds = load_dataset('./rnn_data/test.txt')


180000it [00:01, 128525.50it/s]
10000it [00:00, 54538.63it/s]
10000it [00:00, 157919.86it/s]


In [82]:
class DatasetIterater(object):
    def __init__(self, batches, batch_size, device):
        self.batches = batches
        self.batch_size = batch_size
        self.n_batches = len(batches) // batch_size
        self.residue = (len(batches) % batch_size == 0)
        
        self.index = 0
        self.device = device
        
    def _to_tensor(self, datas):
        x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
        y = torch.LongTensor([_[1] for _ in datas]).to(self.device)
        seq_len = torch.LongTensor([len(_[0]) for _ in datas]).to(self.device)
        
        return (x, seq_len),y
    
    def __next__(self):
        if self.residue and self.index == self.n_batches:
            batches = self.batches[self.index*self.batch_size : len(self.batches)]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches
        
        if self.index > self.n_batches:
            self.index = 0
            raise StopIteration
            pass
        else:
            batches = self.batches[self.index*self.batch_size: (self.index+1)*self.batch_size]
            batches = self._to_tensor(batches)
            self.index += 1
#             print('n',batches)
            return batches
    
    def __iter__(self):
        return self
    
    def __len__(self):
        return self.n_batches + (1 if self.residue else 0)

In [83]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_dl = DatasetIterater(train_ds, batch_size=128, device=device)
dev_dl = DatasetIterater(dev_ds, batch_size=128, device=device)
test_dl = DatasetIterater(test_ds, batch_size=128, device=device)


In [62]:
embeding_pretrained = torch.tensor(
    np.load('./rnn_data/embedding_Tencent.npz')['embeddings'].astype('float32')
)
embeding_pretrained.shape
class_list = [x.strip() for x in open(
        './rnn_data/class.txt'
    ).readlines()]
class_list
num_classes = len(class_list)


model_rt = TextRNN(embeding_pretrained, num_classes)
model_rt = model_rt.to(device)

In [None]:
dataset = 'THUCNews'

embedding = 'embedding_SougouNews.npz'
model_name = 'TextRNN'

config = Config(dataset, embedding)
np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed_all(1)
torch.backends.cudnn.deterministic = True

import time
start_time = time.time()
print('Load data ...')


In [95]:
from sklearn import metrics

def train(model, train_dl, dev_dl,epoches, writer = None):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    total_batch = 0
    for epoch in (range(epoches)):
        print('{}/{}'.format(epoch, epoches))
        
        for x, labels in train_dl:
            model.train()
            outputs = model(x)
            loss = F.cross_entropy(outputs, labels)
            optimizer.zero_grad()  # TODO ???
            loss.backward()
            optimizer.step()
            
            if total_batch % 100 == 0:
                true = labels.data.cpu()
                predict = torch.max(outputs.data, 1)[1].cpu()
#                 print(true)
#                 print(predict)
                train_acc = metrics.accuracy_score(true, predict)
                print('train_acc = {:.4f}'.format(train_acc))
                dev_acc, dev_loss = evalute(model, dev_dl )
                print('dev_acc = {:.4f}, dev_loss={:.4f}'.format(dev_acc, dev_loss))
                if writer is not None:
                    writer.add_scalar('loss/train', loss.item(), total_batch)
                    writer.add_scalar('loss/dev', dev_loss, total_batch)
                    writer.add_scalar('acc/train', train_acc, total_batch)
                    writer.add_scalar('acc/dev', dev_acc, total_batch)
                
            total_batch += 1
            

In [96]:
from tensorboardX import SummaryWriter
import time
writer = SummaryWriter(log_dir='./log/' + time.strftime('%m-%d_%H.%M', time.localtime()))

train(model_rt, train_dl, dev_dl,100, writer)
# type(train_dl)
# for x, y in train_dl:
#     print(y)

0/100
train_acc = 1.0000
dev_acc = 0.8855, dev_loss=0.6683
train_acc = 0.9844
dev_acc = 0.8882, dev_loss=0.6510
train_acc = 0.9922
dev_acc = 0.8887, dev_loss=0.6322
train_acc = 0.9844
dev_acc = 0.8854, dev_loss=0.6910
train_acc = 0.9844
dev_acc = 0.8878, dev_loss=0.6773
train_acc = 0.9922
dev_acc = 0.8871, dev_loss=0.6927
train_acc = 0.9844
dev_acc = 0.8855, dev_loss=0.6654
train_acc = 0.9609
dev_acc = 0.8859, dev_loss=0.6355
train_acc = 0.9922
dev_acc = 0.8872, dev_loss=0.6720
train_acc = 0.9844
dev_acc = 0.8845, dev_loss=0.6962
train_acc = 0.9844
dev_acc = 0.8837, dev_loss=0.6584
1/100
train_acc = 1.0000
dev_acc = 0.8859, dev_loss=0.6824
train_acc = 0.9922
dev_acc = 0.8866, dev_loss=0.6924
train_acc = 0.9922
dev_acc = 0.8861, dev_loss=0.6552
train_acc = 0.9922
dev_acc = 0.8866, dev_loss=0.6683
train_acc = 1.0000
dev_acc = 0.8850, dev_loss=0.6979
train_acc = 0.9766
dev_acc = 0.8847, dev_loss=0.6766
train_acc = 0.9922
dev_acc = 0.8858, dev_loss=0.6813
train_acc = 0.9844
dev_acc = 0.886

dev_acc = 0.8883, dev_loss=0.6978
train_acc = 1.0000
dev_acc = 0.8882, dev_loss=0.6753
train_acc = 1.0000
dev_acc = 0.8888, dev_loss=0.6712
train_acc = 1.0000
dev_acc = 0.8839, dev_loss=0.7041
train_acc = 1.0000
dev_acc = 0.8861, dev_loss=0.6760
train_acc = 1.0000
dev_acc = 0.8851, dev_loss=0.7031
train_acc = 0.9844
dev_acc = 0.8875, dev_loss=0.6804
train_acc = 0.9844
dev_acc = 0.8882, dev_loss=0.6963
train_acc = 0.9844
dev_acc = 0.8880, dev_loss=0.6898
train_acc = 0.9922
dev_acc = 0.8872, dev_loss=0.6837
train_acc = 0.9844
dev_acc = 0.8866, dev_loss=0.7096
train_acc = 1.0000
dev_acc = 0.8903, dev_loss=0.7181
train_acc = 1.0000
dev_acc = 0.8872, dev_loss=0.7058
12/100
train_acc = 1.0000
dev_acc = 0.8834, dev_loss=0.7167
train_acc = 0.9922
dev_acc = 0.8858, dev_loss=0.7253
train_acc = 0.9922
dev_acc = 0.8872, dev_loss=0.7228
train_acc = 0.9922
dev_acc = 0.8850, dev_loss=0.7585
train_acc = 0.9922
dev_acc = 0.8852, dev_loss=0.7053
train_acc = 0.9844
dev_acc = 0.8838, dev_loss=0.6971
train

22/100
train_acc = 1.0000
dev_acc = 0.8888, dev_loss=0.7169
train_acc = 1.0000
dev_acc = 0.8915, dev_loss=0.6963
train_acc = 1.0000
dev_acc = 0.8858, dev_loss=0.7093
train_acc = 1.0000
dev_acc = 0.8904, dev_loss=0.7312
train_acc = 1.0000
dev_acc = 0.8896, dev_loss=0.7294
train_acc = 0.9922
dev_acc = 0.8887, dev_loss=0.7362
train_acc = 0.9922
dev_acc = 0.8892, dev_loss=0.7202
train_acc = 1.0000
dev_acc = 0.8877, dev_loss=0.7271
train_acc = 1.0000
dev_acc = 0.8897, dev_loss=0.6931
train_acc = 1.0000
dev_acc = 0.8879, dev_loss=0.7004
train_acc = 1.0000
dev_acc = 0.8892, dev_loss=0.6853
train_acc = 1.0000
dev_acc = 0.8894, dev_loss=0.6937
train_acc = 0.9844
dev_acc = 0.8865, dev_loss=0.7508
train_acc = 0.9922
dev_acc = 0.8907, dev_loss=0.7279
23/100
train_acc = 1.0000
dev_acc = 0.8882, dev_loss=0.7324
train_acc = 0.9844
dev_acc = 0.8910, dev_loss=0.7238
train_acc = 1.0000
dev_acc = 0.8883, dev_loss=0.7347
train_acc = 0.9844
dev_acc = 0.8881, dev_loss=0.7309
train_acc = 1.0000
dev_acc = 0.8

dev_acc = 0.8884, dev_loss=0.7657
33/100
train_acc = 0.9844
dev_acc = 0.8882, dev_loss=0.7476
train_acc = 1.0000
dev_acc = 0.8864, dev_loss=0.7515
train_acc = 1.0000
dev_acc = 0.8889, dev_loss=0.7485
train_acc = 0.9844
dev_acc = 0.8889, dev_loss=0.7371
train_acc = 0.9922
dev_acc = 0.8888, dev_loss=0.7401
train_acc = 0.9922
dev_acc = 0.8888, dev_loss=0.7603
train_acc = 0.9922
dev_acc = 0.8880, dev_loss=0.7529
train_acc = 1.0000
dev_acc = 0.8881, dev_loss=0.7727
train_acc = 1.0000
dev_acc = 0.8857, dev_loss=0.7873
train_acc = 1.0000
dev_acc = 0.8856, dev_loss=0.7962
train_acc = 0.9844
dev_acc = 0.8875, dev_loss=0.7588
train_acc = 1.0000
dev_acc = 0.8871, dev_loss=0.7669
train_acc = 0.9922
dev_acc = 0.8896, dev_loss=0.7293
train_acc = 1.0000
dev_acc = 0.8880, dev_loss=0.7367
train_acc = 1.0000
dev_acc = 0.8898, dev_loss=0.7387
34/100
train_acc = 0.9922
dev_acc = 0.8901, dev_loss=0.7396
train_acc = 0.9922
dev_acc = 0.8904, dev_loss=0.7372
train_acc = 1.0000
dev_acc = 0.8882, dev_loss=0.747

train_acc = 1.0000
dev_acc = 0.8889, dev_loss=0.7483
train_acc = 0.9922
dev_acc = 0.8907, dev_loss=0.7382
44/100
train_acc = 0.9922
dev_acc = 0.8918, dev_loss=0.7562
train_acc = 1.0000
dev_acc = 0.8903, dev_loss=0.7882
train_acc = 1.0000
dev_acc = 0.8888, dev_loss=0.7745
train_acc = 0.9922
dev_acc = 0.8869, dev_loss=0.7919
train_acc = 1.0000
dev_acc = 0.8878, dev_loss=0.7886
train_acc = 1.0000
dev_acc = 0.8901, dev_loss=0.7695
train_acc = 1.0000
dev_acc = 0.8891, dev_loss=0.7695
train_acc = 1.0000
dev_acc = 0.8901, dev_loss=0.7519
train_acc = 0.9922
dev_acc = 0.8899, dev_loss=0.7591
train_acc = 0.9922
dev_acc = 0.8886, dev_loss=0.7572
train_acc = 1.0000
dev_acc = 0.8871, dev_loss=0.7772
train_acc = 1.0000
dev_acc = 0.8883, dev_loss=0.7565
train_acc = 1.0000
dev_acc = 0.8890, dev_loss=0.7584
train_acc = 1.0000
dev_acc = 0.8904, dev_loss=0.7566
45/100
train_acc = 0.9922
dev_acc = 0.8901, dev_loss=0.7579
train_acc = 1.0000
dev_acc = 0.8874, dev_loss=0.7817
train_acc = 0.9922
dev_acc = 0.8

dev_acc = 0.8881, dev_loss=0.7527
train_acc = 0.9922
dev_acc = 0.8893, dev_loss=0.7432
train_acc = 0.9922
dev_acc = 0.8874, dev_loss=0.7468
train_acc = 1.0000
dev_acc = 0.8892, dev_loss=0.7324
55/100
train_acc = 1.0000
dev_acc = 0.8898, dev_loss=0.7483
train_acc = 1.0000
dev_acc = 0.8896, dev_loss=0.7760
train_acc = 1.0000
dev_acc = 0.8888, dev_loss=0.7428
train_acc = 1.0000
dev_acc = 0.8885, dev_loss=0.7590
train_acc = 1.0000
dev_acc = 0.8873, dev_loss=0.7967
train_acc = 1.0000
dev_acc = 0.8896, dev_loss=0.7816
train_acc = 0.9922
dev_acc = 0.8870, dev_loss=0.7957
train_acc = 0.9922
dev_acc = 0.8898, dev_loss=0.7673
train_acc = 0.9922
dev_acc = 0.8887, dev_loss=0.7767
train_acc = 1.0000
dev_acc = 0.8876, dev_loss=0.7616
train_acc = 1.0000
dev_acc = 0.8859, dev_loss=0.7696
train_acc = 0.9922
dev_acc = 0.8870, dev_loss=0.7641
train_acc = 1.0000
dev_acc = 0.8867, dev_loss=0.7746
train_acc = 0.9922
dev_acc = 0.8901, dev_loss=0.7733
56/100
train_acc = 1.0000
dev_acc = 0.8900, dev_loss=0.768

train_acc = 0.9922
dev_acc = 0.8890, dev_loss=0.7456
train_acc = 0.9922
dev_acc = 0.8915, dev_loss=0.7561
train_acc = 1.0000
dev_acc = 0.8923, dev_loss=0.7513
train_acc = 1.0000
dev_acc = 0.8942, dev_loss=0.7656
train_acc = 1.0000
dev_acc = 0.8913, dev_loss=0.7377
66/100
train_acc = 0.9922
dev_acc = 0.8917, dev_loss=0.7560
train_acc = 0.9844
dev_acc = 0.8917, dev_loss=0.7453
train_acc = 0.9922
dev_acc = 0.8897, dev_loss=0.7605
train_acc = 1.0000
dev_acc = 0.8914, dev_loss=0.7623
train_acc = 0.9922
dev_acc = 0.8908, dev_loss=0.7606
train_acc = 1.0000
dev_acc = 0.8895, dev_loss=0.7569
train_acc = 1.0000
dev_acc = 0.8873, dev_loss=0.7631
train_acc = 1.0000
dev_acc = 0.8901, dev_loss=0.7737
train_acc = 1.0000
dev_acc = 0.8898, dev_loss=0.7505
train_acc = 1.0000
dev_acc = 0.8904, dev_loss=0.7540
train_acc = 0.9844
dev_acc = 0.8882, dev_loss=0.7779
train_acc = 1.0000
dev_acc = 0.8886, dev_loss=0.7516
train_acc = 1.0000
dev_acc = 0.8857, dev_loss=0.7772
train_acc = 1.0000
dev_acc = 0.8895, de

dev_acc = 0.8906, dev_loss=0.7531
train_acc = 1.0000
dev_acc = 0.8876, dev_loss=0.7360
train_acc = 1.0000
dev_acc = 0.8902, dev_loss=0.7539
train_acc = 1.0000
dev_acc = 0.8898, dev_loss=0.7682
train_acc = 1.0000
dev_acc = 0.8885, dev_loss=0.7907
train_acc = 1.0000
dev_acc = 0.8897, dev_loss=0.7721
train_acc = 1.0000
dev_acc = 0.8888, dev_loss=0.7504
77/100
train_acc = 1.0000
dev_acc = 0.8922, dev_loss=0.7494
train_acc = 0.9922
dev_acc = 0.8902, dev_loss=0.7448
train_acc = 1.0000
dev_acc = 0.8920, dev_loss=0.7342
train_acc = 1.0000
dev_acc = 0.8902, dev_loss=0.7616
train_acc = 0.9922
dev_acc = 0.8908, dev_loss=0.7354
train_acc = 0.9922
dev_acc = 0.8866, dev_loss=0.7535
train_acc = 1.0000
dev_acc = 0.8856, dev_loss=0.7753
train_acc = 0.9844
dev_acc = 0.8890, dev_loss=0.7926
train_acc = 1.0000
dev_acc = 0.8874, dev_loss=0.7659
train_acc = 1.0000
dev_acc = 0.8900, dev_loss=0.7555
train_acc = 1.0000
dev_acc = 0.8883, dev_loss=0.7712
train_acc = 1.0000
dev_acc = 0.8899, dev_loss=0.7719
train

train_acc = 0.9844
dev_acc = 0.8919, dev_loss=0.7586
train_acc = 0.9922
dev_acc = 0.8893, dev_loss=0.7625
train_acc = 1.0000
dev_acc = 0.8894, dev_loss=0.7625
train_acc = 1.0000
dev_acc = 0.8883, dev_loss=0.7595
train_acc = 0.9922
dev_acc = 0.8894, dev_loss=0.7860
train_acc = 1.0000
dev_acc = 0.8889, dev_loss=0.7809
train_acc = 1.0000
dev_acc = 0.8894, dev_loss=0.7679
88/100
train_acc = 1.0000
dev_acc = 0.8894, dev_loss=0.7765
train_acc = 0.9922
dev_acc = 0.8868, dev_loss=0.7753
train_acc = 1.0000
dev_acc = 0.8875, dev_loss=0.7541
train_acc = 1.0000
dev_acc = 0.8904, dev_loss=0.7670
train_acc = 0.9922
dev_acc = 0.8885, dev_loss=0.7666
train_acc = 1.0000
dev_acc = 0.8886, dev_loss=0.7921
train_acc = 1.0000
dev_acc = 0.8888, dev_loss=0.7818
train_acc = 1.0000
dev_acc = 0.8870, dev_loss=0.8156
train_acc = 0.9922
dev_acc = 0.8897, dev_loss=0.7930
train_acc = 1.0000
dev_acc = 0.8882, dev_loss=0.7906
train_acc = 1.0000
dev_acc = 0.8905, dev_loss=0.7986
train_acc = 1.0000
dev_acc = 0.8897, de

dev_acc = 0.8894, dev_loss=0.7946
train_acc = 1.0000
dev_acc = 0.8899, dev_loss=0.7921
train_acc = 1.0000
dev_acc = 0.8879, dev_loss=0.7942
train_acc = 1.0000
dev_acc = 0.8905, dev_loss=0.8129
train_acc = 1.0000
dev_acc = 0.8916, dev_loss=0.8171
train_acc = 1.0000
dev_acc = 0.8874, dev_loss=0.8124
train_acc = 1.0000
dev_acc = 0.8896, dev_loss=0.7847
train_acc = 0.9922
dev_acc = 0.8903, dev_loss=0.7959
train_acc = 0.9844
dev_acc = 0.8892, dev_loss=0.7847
99/100
train_acc = 1.0000
dev_acc = 0.8907, dev_loss=0.7927
train_acc = 1.0000
dev_acc = 0.8878, dev_loss=0.8070
train_acc = 1.0000
dev_acc = 0.8892, dev_loss=0.7875
train_acc = 1.0000
dev_acc = 0.8902, dev_loss=0.7796
train_acc = 0.9766
dev_acc = 0.8902, dev_loss=0.7744
train_acc = 0.9922
dev_acc = 0.8891, dev_loss=0.7735
train_acc = 0.9922
dev_acc = 0.8896, dev_loss=0.7593
train_acc = 1.0000
dev_acc = 0.8902, dev_loss=0.7903
train_acc = 1.0000
dev_acc = 0.8888, dev_loss=0.7798
train_acc = 0.9922
dev_acc = 0.8865, dev_loss=0.8037
train

In [76]:
import numpy as np
def evalute(model, data_dl, test=False):
    model.eval()
    loss_total = 0
    predict_all = np.array([], dtype = int)
    labels_all = np.array([], dtype=int)
    with torch.no_grad():
        for texts, labels in data_dl:
            outputs = model(texts)
            loss = F.cross_entropy(outputs, labels)
            loss_total += loss
            labels = labels.data.cpu()
            predict = torch.max(outputs.data.cpu(), 1)[1].cpu().numpy()
            labels_all = np.append(labels_all, labels)
            predict_all = np.append(predict_all, predict)
            
        acc = metrics.accuracy_score(labels_all, predict_all)
        
    return acc, loss_total / len(data_dl)

In [122]:
# x = (torch.rand(1, 32, embeding_pretrained.size(1)).to(device), torch.LongTensor([1]).to(device))
x
with SummaryWriter(comment='Net1') as w:
    w.add_graph(model_rt, (x,))

  if self.input_size != input.size(-1):
  if hx.size() != expected_hidden_size:


(tensor([[[                  1,      94149973749648,                  24,
            ...,      94149973751056, 8243124912085009483,
                94149973751088],
          [7016431369869852682,      94149973751184, 4193527942359286840,
            ...,      94149973751088,      94149973752448,
           3268325118455054353],
          [     94149887632056,      94149973752792,                   0,
            ..., 5575186034917900313,      94149973754392,
           3180189357294223361],
          ...,
          [8387201278994620457, 3327632609719377755, 7305999817851800110,
            ..., 7214801925396114990, 8316288633764079205,
           7310231100686167914],
          [5269275830158911084, 6585789199896700276, 2317486216248775775,
            ..., 2314885530453827951, 7955995078896805220,
           7810775745979970665],
          [8390876208520244326, 8295679392196665402, 8389758742978129780,
            ..., 2916483597133571428, 7089073068528197471,
           82431225503

In [117]:
for x, label in train_dl:
    print(x)
    break

(tensor([[   5,  167,    2,  ..., 4760, 4760, 4760],
        [  20, 1473, 1456,  ..., 4760, 4760, 4760],
        [ 217,  647,    4,  ..., 4760, 4760, 4760],
        ...,
        [ 220,  617,   15,  ..., 4760, 4760, 4760],
        [  14,    6,   54,  ..., 4760, 4760, 4760],
        [  20,  435,  979,  ..., 4760, 4760, 4760]], device='cuda:0'), tensor([32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
        32, 32], device='cuda:0'))
