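# Submit

Predict a category for every sample in `data/origin/Test.txt` with the trained RNN checkpoint (`rnn.pth`) and write one predicted label per line to `data/result.txt`.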

## Data

In [1]:
from utils.submit import create_test_indices

# Produce data/process/test.txt, which TestNewsDataset reads in the next cell.
# Presumably this maps the raw texts in data/origin/Test.txt to token ids using
# data/process/vocab.txt (argument order as defined in utils.submit — TODO confirm).
create_test_indices('data/process/test.txt', 'data/origin/Test.txt', 'data/process/vocab.txt')

In [3]:
import torch
from utils.submit import TestNewsDataset, collate_fn_test

# Evaluation loader over the indexed test set. shuffle=False keeps batches in dataset
# order, so pred_list (later written line-by-line to result.txt) follows the test file.
batch_size = 64
test_dataset = TestNewsDataset("./data/process/test.txt")
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn_test,
)

In [4]:
# Sanity-check the loader: show only the first batch and its shape, then stop.
print('=============test_loader =============')
for data in test_loader:
    print(data, data.shape, sep='\n')
    break

tensor([[ 197,  241, 1342,  ...,    1,    1,    1],
        [ 417,  824,  170,  ...,    1,    1,    1],
        [ 248,   41,  167,  ...,    1,    1,    1],
        ...,
        [   9,  130,   76,  ...,    1,    1,    1],
        [ 383,  254,  199,  ...,    1,    1,    1],
        [ 901,  336,  156,  ...,    1,    1,    1]])
torch.Size([64, 32])


## Predict

In [None]:
from models.rnn import RNN
from utils.tokenizer import Vocabulary

# Vocabulary produced during preprocessing; its size sets the embedding table size.
vocab = Vocabulary.load('./data/process/vocab.txt')
vocab_size = len(vocab)

# Class names, indexed by the model's predicted class id.
# NOTE(review): this order must match the label order used at training time — confirm.
label_list = ['财经', '彩票', '房产', '股票', '家居', '教育', '科技', '社会', '时尚', '时政', '体育', '星座', '游戏', '娱乐']
num_classes = len(label_list)

# Architecture hyper-parameters. They must equal the values rnn.pth was trained with,
# otherwise load_state_dict (strict by default) will fail on missing keys or size mismatches.
embedding_dim = 128
hidden_dim = 128
num_layers = 1

# Prefer the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Load the model: build the architecture, restore trained weights, and place it on `device`.
model = RNN(vocab_size, embedding_dim, hidden_dim, num_layers, num_classes)
# torch.load unpickles the file — only load checkpoints from a trusted source.
# map_location=device lets a GPU-saved checkpoint load on a CPU-only machine and vice versa.
state_dict = torch.load('rnn.pth', map_location=device)
model.load_state_dict(state_dict)
# The prediction loop sends every batch to `device`; the model must live there too,
# otherwise inference on a CUDA machine fails with a device-mismatch RuntimeError.
model = model.to(device)

In [None]:
# Run inference over the (unshuffled) test loader, collecting one class id per sample.
pred_list = []
model.eval()  # inference mode for layers such as dropout/batch-norm, if the model has any
with torch.no_grad():
    for batch in test_loader:
        logits = model(batch.to(device))
        # Index of the highest-scoring class in each row (first one on ties).
        pred_list.extend(logits.argmax(dim=1).cpu().tolist())

In [None]:
# Write predictions to the result file: one label name per line, in test-set order.
with open('./data/result.txt', 'w', encoding='utf-8') as f_result:
    f_result.writelines(f'{label_list[pred]}\n' for pred in pred_list)

## Take a look: raw test samples vs. predicted labels

In [None]:
def _read_and_preview(path, n=10):
    """Read every line of `path`, print the line count and the first `n` stripped lines.

    Returns the list of raw lines (newlines included), as produced by readlines().
    """
    with open(path, 'r', encoding='utf-8') as fh:
        lines = fh.readlines()
    print(len(lines))
    for line in lines[:n]:
        print(line.strip())
    return lines


# Eyeball the raw test samples next to the predicted labels.
# NOTE(review): the two counts should agree if Test.txt holds exactly one sample per line — confirm.
test = _read_and_preview('./data/origin/Test.txt')
res = _read_and_preview('./data/result.txt')