In [2]:
# coding: UTF-8
import time
import torch
import numpy as np
from train_eval import train, init_network
from importlib import import_module
from utils import build_dataset, build_iterator, get_time_dif
import argparse
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is
    ``True`` because every non-empty string is truthy, so the flag could
    never be switched off from the command line.

    Args:
        value: The raw argument string (an actual ``bool`` is passed through).

    Returns:
        ``True`` or ``False`` according to the spelled-out value.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, as it was under ``type=bool``.
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (True/False), got %r' % (value,))


# Command-line interface for the text-classification driver.
parser = argparse.ArgumentParser(description='Chinese Text Classification')
parser.add_argument(
    '--model', type=str, required=True,
    help='choose a model: TextCNN, TextRNN,TextRCNN, '
         'TextRNN_Att, DPCNN, Transformer')
parser.add_argument(
    '--embedding', default='pre_trained', type=str,
    help='random or pre_trained')
parser.add_argument(
    '--word', default=False, type=_str2bool,
    help='True for word, False for char')
# Parsing is left disabled: the script body below uses hard-coded settings
# instead of command-line arguments.
# args = parser.parse_args()

if __name__ == '__main__':
    # Script entry point: load the data, build the selected model, train it.
    dataset = '/home/kayzhou/zhangyue/text'  # dataset root

    # Embedding choice: "embedding_tweet.npz" (tweet embeddings initialised
    # from GloVe) or "random" for random initialisation.
    embedding = "embedding_tweet.npz"
    # Other model names listed for this project: TextCNN, TextRNN, FastText,
    # TextRCNN, TextRNN_Att, DPCNN.
    model_name = "Transformer"
    x = import_module("models.{}".format(model_name))
    config = x.Config(dataset, embedding)

    # Seed NumPy, PyTorch (CPU) and every CUDA device with the same value and
    # force deterministic cuDNN so that each run gives the same result.
    for seed_rng in (np.random.seed, torch.manual_seed,
                     torch.cuda.manual_seed_all):
        seed_rng(1)
    torch.backends.cudnn.deterministic = True

    start_time = time.time()
    print("Loading data...")
    # NOTE(review): the second argument is presumably the word-level switch
    # (cf. the --word option) -- confirm against utils.build_dataset.
    vocab, train_data, dev_data, test_data = build_dataset(config, True)
    # One iterator per split, built in train / dev / test order.
    train_iter, dev_iter, test_iter = [
        build_iterator(split, config)
        for split in (train_data, dev_data, test_data)
    ]
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

    # Training: the vocabulary size must be known before the model is built.
    config.n_vocab = len(vocab)
    model = x.Model(config).to(config.device)
    # init_network is applied to every model except the Transformer.
    if model_name != 'Transformer':
        init_network(model)
    # Prints the repr of the bound ``parameters`` method, which embeds the
    # module structure of the model.
    print(model.parameters)
    train(config, model, train_iter, dev_iter, test_iter)

Loading data...
Vocab size: 100002


3760566it [01:21, 46370.58it/s]
417841it [00:07, 57656.94it/s]
4000it [00:00, 58102.31it/s]


Time usage: 0:01:30
<bound method Module.parameters of Model(
  (embedding): Embedding(100002, 200)
  (postion_embedding): Positional_Encoding(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (encoder): Encoder(
    (attention): Multi_Head_Attention(
      (fc_Q): Linear(in_features=200, out_features=200, bias=True)
      (fc_K): Linear(in_features=200, out_features=200, bias=True)
      (fc_V): Linear(in_features=200, out_features=200, bias=True)
      (attention): Scaled_Dot_Product_Attention()
      (fc): Linear(in_features=200, out_features=200, bias=True)
      (dropout): Dropout(p=0.5, inplace=False)
      (layer_norm): LayerNorm((200,), eps=1e-05, elementwise_affine=True)
    )
    (feed_forward): Position_wise_Feed_Forward(
      (fc1): Linear(in_features=200, out_features=1024, bias=True)
      (fc2): Linear(in_features=1024, out_features=200, bias=True)
      (dropout): Dropout(p=0.5, inplace=False)
      (layer_norm): LayerNorm((200,), eps=1e-05, elementwise_affine=True)
