In [21]:
import torch
import numpy as np

# Select the compute device

In [22]:
import os
# Expose only GPU 0 to this process; set before any CUDA work is done in the notebook.
os.environ.update(CUDA_VISIBLE_DEVICES="0")
# Device handle used by the later cells when moving models (and passed to predict).
device = torch.device('cuda', 0)

# 获取测试用户和测试新闻

In [23]:
import pickle


def _read_pickle(path):
    """Return the object deserialized from the pickle file at `path`."""
    # pickle.load can execute arbitrary code; only read files you trust.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Test-set users and samples, later handed to TestImpressionDataset.
TestUsers = _read_pickle('./data2/TestUsers.pkl')
TestSamples = _read_pickle('./data2/TestSamples.pkl')

# 加载模型

In [24]:
# NOTE(review): wildcard import. `TestImpressionDataset` and `DataLoader`, used
# further down in this notebook, are not imported anywhere else in view, so they
# presumably come from here -- confirm, and consider importing them explicitly.
from pensmodule.Generator import *
# Checkpoint file of the trained generator; passed to load_model_from_ckpt below.
model_path = './runs/seq2seq/exp/checkpoint_train_mod4_step_1500.pth'

def load_model_from_ckpt(path, map_location=None):
    """Load the model object stored in a training checkpoint.

    Args:
        path: Path of a checkpoint file readable by ``torch.load`` whose
            top-level object has a ``'model'`` entry.
        map_location: Forwarded to ``torch.load``. The default ``None`` keeps
            the previous behaviour (tensors are restored to the devices they
            were saved from). Pass e.g. ``'cpu'`` or the target device to load
            a checkpoint that was saved on a GPU which is not visible in the
            current process.

    Returns:
        The object stored under ``checkpoint['model']``, wrapped in
        ``torch.nn.DataParallel`` when more than one CUDA device is visible.

    Raises:
        KeyError: If the checkpoint has no ``'model'`` entry.
    """
    # weights_only=False unpickles arbitrary Python objects (the checkpoint
    # holds a pickled model object) -- only load checkpoints you trust.
    checkpoint = torch.load(path, map_location=map_location, weights_only=False)
    model = checkpoint['model']
    if torch.cuda.device_count() > 1:
        print('multiple gpu training')
        # Reference torch.nn explicitly so this function does not depend on a
        # bare `nn` name leaking in through a wildcard import.
        model = torch.nn.DataParallel(model)
    return model

# Restore the generator from its checkpoint, then move it onto the chosen device.
model = load_model_from_ckpt(model_path)
model = model.to(device)
# Put the model in evaluation mode; as the cell's last expression this also
# displays the module structure.
model.eval()

HeadlineGen(
  (embeddings): Embedding(141910, 300)
  (encoder): LSTMEncoder(
    (embeddings): Embedding(141910, 300)
    (rnn): LSTM(300, 64, batch_first=True, dropout=0.2, bidirectional=True)
    (bridge): ModuleList(
      (0-1): 2 x Linear(in_features=64, out_features=64, bias=True)
    )
  )
  (decoder): Decoder_P(
    (embeddings): Embedding(141910, 300)
    (dropout): Dropout(p=0.2, inplace=False)
    (rnn): LSTM(300, 128, batch_first=True)
    (attention): Attention(
      (linear_out): Linear(in_features=256, out_features=128, bias=True)
    )
    (transform): ModuleList(
      (0-1): 2 x Linear(in_features=64, out_features=128, bias=True)
    )
    (out): Linear(in_features=128, out_features=141910, bias=True)
    (p_gen_linear): Linear(in_features=128, out_features=1, bias=True)
  )
  (loss_fn): NLLLoss()
  (dropout): Dropout(p=0.2, inplace=False)
)

In [25]:
from pensmodule.UserEncoder import NRMS
embedding_matrix = np.load('./data2/embedding_matrix.npy')

usermodel = NRMS(embedding_matrix)
# Load the saved weights onto the CPU first: without map_location, torch.load
# restores tensors to the device they were saved from, which fails when that
# GPU index is not visible in this process. The model is moved to `device`
# right after the weights are copied in.
# NOTE(review): the weights file is named 'NAML-2.pkl' but is loaded into an
# NRMS model -- confirm this is the intended checkpoint.
usermodel.load_state_dict(
    torch.load('./runs/userencoder/NAML-2.pkl', map_location='cpu')
)
usermodel = usermodel.to(device)
# Evaluation mode; as the cell's last expression this also displays the module.
usermodel.eval()

NRMS(
  (embed): Embedding(141910, 300, padding_idx=0)
  (attn_word): MultiHeadAttention(
    (W_Q): Linear(in_features=300, out_features=400, bias=True)
    (W_K): Linear(in_features=300, out_features=400, bias=True)
    (W_V): Linear(in_features=300, out_features=400, bias=True)
  )
  (attn_pool_word): AttentionPooling(
    (att_fc1): Linear(in_features=400, out_features=200, bias=True)
    (att_fc2): Linear(in_features=200, out_features=1, bias=True)
    (drop_layer): Dropout(p=0.2, inplace=False)
  )
  (attn_pool_news): AttentionPooling(
    (att_fc1): Linear(in_features=64, out_features=32, bias=True)
    (att_fc2): Linear(in_features=32, out_features=1, bias=True)
    (drop_layer): Dropout(p=0.2, inplace=False)
  )
  (drop_layer): Dropout(p=0.2, inplace=False)
  (fc): Linear(in_features=400, out_features=64, bias=True)
  (criterion): CrossEntropyLoss()
)

In [26]:
# Read the two arrays that the test dataset is built from.
news_scoring, sources = (
    np.load(npy_path)
    for npy_path in ('./data2/news_scoring2.npy', './data2/sources.npy')
)

# Combine the arrays with the test users/samples loaded earlier.
# NOTE(review): TestImpressionDataset and DataLoader are not imported in this
# cell; presumably they come from the earlier wildcard import -- confirm.
i_dset = TestImpressionDataset(news_scoring, sources, TestUsers, TestSamples)
# Batches of 16 in dataset order (no shuffling).
test_iter = DataLoader(
    i_dset,
    batch_size=16,
    shuffle=False,
)

In [27]:
from pensmodule.Generator.eval import predict

# The pickle holds three objects; only word_dict is used in this cell.
with open('./data2/dict.pkl', 'rb') as dict_file:
    news_index, category_dict, word_dict = pickle.load(dict_file)

# Invert the vocabulary mapping (word -> index) into index -> word.
index2word = {idx: word for word, idx in word_dict.items()}
print(len(word_dict), embedding_matrix.shape)

# Run prediction over the test loader with beam=False.
# To try the beam variant instead, call predict with beam=True (beam_size=3).
refs, hyps, scores1, scores2, scoresf = predict(
    usermodel,
    model,
    test_iter,
    device,
    index2word,
    beam=False,
    beam_size=3,
    eos_id=2,
)

# Show the first ten references next to the first ten generated hypotheses.
print('refs:', refs[:10])
print('hyps:', hyps[:10])


141910 (141910, 300)


100%|██████████| 1288/1288 [00:58<00:00, 21.90it/s]


refs: ["legal battle looms over trump epa's rule change of obama's clean power plan rule", 'wise choices for stylish updating of old homes', 'verlander may be reconsidering his stance on mlbs juicing balls', 'infamous o.j. simpson launching official twitter account', '15 year old cori gauff beats venus williams at wimbledon', 'still much room for improvement in many us states', 'eagles still have plenty of news despite nfl dead zone', 'smart moves retirees should make', 'is the express pass in universal orlando worth the cost?', 'cbs news anchor angie miles says goodbye']
hyps: ['trump administration over obama era power plant rule', 'house styles to make your home look to freshen up', 'justin verlander s juiced balls , and he had to have known league baseball', 'o . j . simpson launched a twitter account with video post video', 'coco gauff , coco gauff , still 15 , gauff , beat one of the wimbledon', 'world will celebrate international literacy day , countries around the world will ce

In [28]:
# Mean of each score collection returned by predict, displayed as a 3-tuple.
tuple(metric.mean() for metric in (scores1, scores2, scoresf))

(0.24155567693380653, 0.0866230724526373, 0.2137680350959111)