# TODOs

1. [ ] Config parser
2. [ ] Using tensorflow

---





# Load Data & Code

In [None]:
# Mount Google Drive at /content/drive so the files under MyDrive can be copied in the next cell.
# `google.colab` is only importable inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Copy the model / metric / utility sources from Drive next to the notebook so they can be imported.
!cp drive/MyDrive/Rec/PETER/module.py ./
!cp drive/MyDrive/Rec/PETER/bleu.py ./
!cp drive/MyDrive/Rec/PETER/rouge.py ./
!cp drive/MyDrive/Rec/PETER/main.py ./
!cp drive/MyDrive/Rec/PETER/utils.py ./
# Extract the datasets straight from Drive into the working directory.
# -o overwrites without prompting (a re-run would otherwise stop at unzip's "replace?" question),
# -q keeps the per-file listing out of the notebook output.
!unzip -o -q drive/MyDrive/Rec/PETER/Amazon.zip -d ./
!unzip -o -q drive/MyDrive/Rec/PETER/Yelp.zip -d ./
!unzip -o -q drive/MyDrive/Rec/PETER/TripAdvisor.zip -d ./
# Clean-up. -f makes these no-ops when the target is already gone: sample_data/ only exists on a
# fresh runtime, and the archives are never copied locally by this cell.
!rm -rf sample_data/
!rm -f Amazon.zip
!rm -f Yelp.zip
!rm -f TripAdvisor.zip

In [None]:
import os
import math
import torch
import torch.nn as nn
import pickle
from module import PETER
from utils import rouge_score, bleu_score, DataLoader, Batchify, now_time, ids2tokens, unique_sentence_percent, \
    root_mean_square_error, mean_absolute_error, feature_detect, feature_matching_ratio, feature_coverage_ratio, feature_diversity
# Show the installed PyTorch version so it is recorded with the notebook outputs.
print(torch.__version__)

1.12.1+cu113


# Setup Parameters

In [None]:
class Arg():
  """Hyper-parameters and paths shared by the model wrappers in this notebook.

  The values are plain class attributes so Colab can render them as form fields
  (the trailing ``#@param`` comments). The derived paths (``data_path``, ``index_dir``,
  ``checkpoint``) are computed once, when this cell runs; re-run the cell after
  changing ``dataset_path`` or ``index``.
  """
  dataset_path = "./Amazon/MoviesAndTV/" #@param ["./TripAdvisor/", "./Yelp/", "./Amazon/MoviesAndTV/", "./Amazon/ClothingShoesAndJewelry/"] {allow-input: true}
  index = 2 #@param {type:"slider", min:1, max:5, step:1}
  
  data_path = os.path.join(dataset_path,"reviews.pickle")
  index_dir = os.path.join(dataset_path,str(index))
  # Mirror the dataset path under ./Result. normpath drops a leading "./" and a trailing "/",
  # so a hand-typed path without the "./" prefix no longer loses its first two characters;
  # stripping a leading separator keeps absolute paths inside ./Result as well.
  checkpoint = os.path.join("./Result", os.path.normpath(dataset_path).lstrip(os.sep))
  outf = "generated.txt" #@param {type:"string"}
  
  emsize = 512 #@param {type:"integer"}
  nhead = 2 #@param {type:"integer"}
  nhid = 2048 #@param {type:"integer"}
  nlayers = 2 #@param {type:"integer"}
  epochs = 10 #@param {type:"slider", min:10, max:200, step:1}
  batch_size = 128 #@param {type:"integer"}
  seed = 1111 #@param {type:"slider", min:0, max:10000, step:1}
  words = 15 #@param {type:"slider", min:12, max:20, step:1}
  log_interval = 400 #@param {type:"slider", min:10, max:500, step:10}
  vocab_size = 20000 #@param {type:"integer"}
  endure_times = 5 #@param {type:"slider", min:1, max:20, step:1}

  rating_reg = 0.1 #@param {type:"number"}
  context_reg = 1.0 #@param {type:"number"}
  text_reg = 1.0 #@param {type:"number"}
  # Weight of the L2 penalty read by NRT_Model.train; 0 disables the term.
  l2_reg = 0.0 #@param {type:"number"}
  dropout = 0.2 #@param {type:"slider", min:0, max:1, step:0.05}
  lr = 1.0 #@param {type:"number"}
  clip = 1.0 #@param {type:"number"}
  cuda = True #@param {type:"boolean"}
  peter_mask = True #@param {type:"boolean"}
  use_feature = True #@param {type:"boolean"}

# Single shared configuration object passed to every model wrapper below.
args = Arg()

# Initial setup
- Set random seed
- Set the PyTorch device (CUDA or CPU)
- Create checkpoint folders

In [None]:
class PETER_Model():
  """Notebook wrapper that loads data for, builds, trains, evaluates and samples from PETER.

  Expected call order (each step stores what the following ones read from ``self``):
  ``LoadData`` -> ``BuildModel`` -> ``TRAIN`` -> ``TEST``.
  """

  def __init__(self,args):
    """Seed torch, select the device and prepare the checkpoint / prediction paths.

    Reads ``args.seed``, ``args.cuda``, ``args.checkpoint`` and ``args.outf``.
    """
    # Set the random seed manually for reproducibility.
    torch.manual_seed(args.seed)
    cuda_available = torch.cuda.is_available()
    if cuda_available and not args.cuda:
      print(now_time() + 'WARNING: You have a CUDA device, so you should probably run with --cuda')
    if args.cuda and not cuda_available:
      # Without this fallback the first `.to(device)` call fails on a CPU-only runtime.
      print(now_time() + 'WARNING: CUDA was requested but is not available, falling back to CPU')
    self.device = torch.device('cuda' if args.cuda and cuda_available else 'cpu')

    os.makedirs(args.checkpoint, exist_ok=True)
    self.model_path = os.path.join(args.checkpoint, 'model.pt')
    self.prediction_path = os.path.join(args.checkpoint, args.outf)

  def LoadData(self,args):
    """Build the corpus and the train / validation / test batch iterators.

    Stores ``corpus``, ``word2idx``, ``idx2word``, ``feature_set``, ``train_data``,
    ``val_data`` and ``test_data`` on the instance. Only the training split is shuffled.
    """
    print(now_time() + 'Loading data: ', args.data_path, args.index_dir)
    self.corpus = DataLoader(args.data_path, args.index_dir, args.vocab_size)
    self.word2idx = self.corpus.word_dict.word2idx
    self.idx2word = self.corpus.word_dict.idx2word
    self.feature_set = self.corpus.feature_set
    self.train_data = Batchify(self.corpus.train, self.word2idx, args.words, args.batch_size, shuffle=True)
    self.val_data = Batchify(self.corpus.valid, self.word2idx, args.words, args.batch_size)
    self.test_data = Batchify(self.corpus.test, self.word2idx, args.words, args.batch_size)

  def BuildModel(self,args):
    """Create the PETER network, its loss functions, optimizer and LR scheduler.

    Must be called after ``LoadData`` because it reads ``self.corpus``, ``self.word2idx``
    and ``self.train_data``.
    """
    if args.use_feature:
      src_len = 2 + self.train_data.feature.size(1)  # [u, i, f]
    else:
      src_len = 2  # [u, i]
    self.tgt_len = args.words + 1  # added <bos> or <eos>
    self.ntokens = len(self.corpus.word_dict)
    nuser = len(self.corpus.user_dict)
    nitem = len(self.corpus.item_dict)
    pad_idx = self.word2idx['<pad>']
    self.model = PETER(args.peter_mask, src_len, self.tgt_len, pad_idx, nuser, nitem, self.ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout).to(self.device)
    self.text_criterion = nn.NLLLoss(ignore_index=pad_idx)  # ignore the padding when computing loss
    self.rating_criterion = nn.MSELoss()
    # The optimizer must be bound to *this* model's parameters (and the scheduler to *this*
    # optimizer). Bare `model` / `optimizer` names would either raise NameError on a fresh
    # kernel or silently update a stale object left over in the notebook namespace.
    self.optimizer = torch.optim.SGD(self.model.parameters(), lr=args.lr)
    self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, 1, gamma=0.25)

  def predict(self, log_context_dis, topk):
    """Return the indices of the ``topk`` highest-scoring tokens for every row.

    ``log_context_dis`` is indexed as (batch_size, ntoken); the result is (batch_size, topk).
    The ranking is taken on the log values directly: ``exp`` is monotonic, so the order is
    the same and very small probabilities cannot underflow to ties.
    """
    if topk == 1:
      context = torch.argmax(log_context_dis, dim=1, keepdim=True)  # (batch_size, 1)
    else:
      context = torch.topk(log_context_dis, topk, 1)[1]  # (batch_size, topk)
    return context  # (batch_size, topk)

  def _batch_losses(self, args, batch):
    """Move one batch to the device, run the model and compute the three losses.

    Returns ``(batch_size, c_loss, t_loss, r_loss)``. The losses are returned as tensors so
    that ``train`` can back-propagate through them.
    """
    user, item, rating, seq, feature = batch
    batch_size = user.size(0)
    user = user.to(self.device)  # (batch_size,)
    item = item.to(self.device)
    rating = rating.to(self.device)
    seq = seq.t().to(self.device)  # (tgt_len + 1, batch_size)
    feature = feature.t().to(self.device)  # (1, batch_size)
    if args.use_feature:
      text = torch.cat([feature, seq[:-1]], 0)  # (src_len + tgt_len - 2, batch_size)
    else:
      text = seq[:-1]  # (src_len + tgt_len - 2, batch_size)
    log_word_prob, log_context_dis, rating_p, _ = self.model(user, item, text)  # (tgt_len, batch_size, ntoken) vs. (batch_size, ntoken) vs. (batch_size,)
    # The single context distribution is scored against every target word except <bos>/<eos>.
    context_dis = log_context_dis.unsqueeze(0).repeat((self.tgt_len - 1, 1, 1))  # (batch_size, ntoken) -> (tgt_len - 1, batch_size, ntoken)
    c_loss = self.text_criterion(context_dis.view(-1, self.ntokens), seq[1:-1].reshape((-1,)))
    r_loss = self.rating_criterion(rating_p, rating)
    t_loss = self.text_criterion(log_word_prob.view(-1, self.ntokens), seq[1:].reshape((-1,)))
    return batch_size, c_loss, t_loss, r_loss

  def train(self, args, data):
    """Run one training epoch over ``data`` and print running losses.

    Losses are averaged per sample and reported every ``args.log_interval`` steps and at the
    last step; the running sums are reset after each report.
    """
    # Turn on training mode which enables dropout.
    self.model.train()
    context_loss = 0.
    text_loss = 0.
    rating_loss = 0.
    total_sample = 0
    while True:
      batch = data.next_batch()  # data.step += 1
      # Clear the gradients accumulated by the previous step.
      self.optimizer.zero_grad()
      batch_size, c_loss, t_loss, r_loss = self._batch_losses(args, batch)
      loss = args.text_reg * t_loss + args.context_reg * c_loss + args.rating_reg * r_loss
      loss.backward()

      # `clip_grad_norm` helps prevent the exploding gradient problem.
      torch.nn.utils.clip_grad_norm_(self.model.parameters(), args.clip)
      self.optimizer.step()

      context_loss += batch_size * c_loss.item()
      text_loss += batch_size * t_loss.item()
      rating_loss += batch_size * r_loss.item()
      total_sample += batch_size

      if data.step % args.log_interval == 0 or data.step == data.total_step:
        cur_c_loss = context_loss / total_sample
        cur_t_loss = text_loss / total_sample
        cur_r_loss = rating_loss / total_sample
        print(now_time() + 'context ppl {:4.4f} | text ppl {:4.4f} | rating loss {:4.4f} | {:5d}/{:5d} batches'.format(
              math.exp(cur_c_loss), math.exp(cur_t_loss), cur_r_loss, data.step, data.total_step))
        context_loss = 0.
        text_loss = 0.
        rating_loss = 0.
        total_sample = 0
      if data.step == data.total_step:
        break

  def evaluate(self,args, data):
    """Return the per-sample average ``(context_loss, text_loss, rating_loss)`` over ``data``."""
    # Turn on evaluation mode which disables dropout.
    self.model.eval()
    context_loss = 0.
    text_loss = 0.
    rating_loss = 0.
    total_sample = 0
    with torch.no_grad():
      while True:
        batch_size, c_loss, t_loss, r_loss = self._batch_losses(args, data.next_batch())  # data.step += 1

        context_loss += batch_size * c_loss.item()
        text_loss += batch_size * t_loss.item()
        rating_loss += batch_size * r_loss.item()
        total_sample += batch_size

        if data.step == data.total_step:
          break
    return context_loss / total_sample, text_loss / total_sample, rating_loss / total_sample

  def generate(self,args,data):
    """Greedily decode ``args.words`` tokens per sample, print the metrics, return the text dump.

    The returned string holds one record per sample in the form
    ``"<reference>\\n<predicted context words>\\n<generated text>\\n\\n"``.
    """
    # Turn on evaluation mode which disables dropout.
    self.model.eval()
    idss_predict = []
    context_predict = []
    rating_predict = []
    with torch.no_grad():
      while True:
        user, item, rating, seq, feature = data.next_batch()
        user = user.to(self.device)  # (batch_size,)
        item = item.to(self.device)
        bos = seq[:, 0].unsqueeze(0).to(self.device)  # (1, batch_size)
        feature = feature.t().to(self.device)  # (1, batch_size)
        if args.use_feature:
          text = torch.cat([feature, bos], 0)  # (src_len - 1, batch_size)
        else:
          text = bos  # (src_len - 1, batch_size)
        start_idx = text.size(0)
        for idx in range(args.words):
          # produce a word at each step
          if idx == 0:
            # Rating and context are predicted once, on the first step only.
            log_word_prob, log_context_dis, rating_p, _ = self.model(user, item, text, False)  # (batch_size, ntoken) vs. (batch_size, ntoken) vs. (batch_size,)
            rating_predict.extend(rating_p.tolist())
            context = self.predict(log_context_dis, topk=args.words)  # (batch_size, words)
            context_predict.extend(context.tolist())
          else:
            log_word_prob, _, _, _ = self.model(user, item, text, False, False, False)  # (batch_size, ntoken)
          # Greedy choice; argmax over log-probabilities equals argmax over probabilities.
          word_idx = torch.argmax(log_word_prob, dim=1)  # (batch_size,)
          text = torch.cat([text, word_idx.unsqueeze(0)], 0)  # (len++, batch_size)
        ids = text[start_idx:].t().tolist()  # (batch_size, seq_len)
        idss_predict.extend(ids)

        if data.step == data.total_step:
          break

    # rating
    predicted_rating = list(zip(data.rating.tolist(), rating_predict))
    RMSE = root_mean_square_error(predicted_rating, self.corpus.max_rating, self.corpus.min_rating)
    print(now_time() + 'RMSE {:7.4f}'.format(RMSE))
    MAE = mean_absolute_error(predicted_rating, self.corpus.max_rating, self.corpus.min_rating)
    print(now_time() + 'MAE {:7.4f}'.format(MAE))
    # text
    tokens_test = [ids2tokens(ids[1:], self.word2idx, self.idx2word) for ids in data.seq.tolist()]
    tokens_predict = [ids2tokens(ids, self.word2idx, self.idx2word) for ids in idss_predict]
    BLEU1 = bleu_score(tokens_test, tokens_predict, n_gram=1, smooth=False)
    print(now_time() + 'BLEU-1 {:7.4f}'.format(BLEU1))
    BLEU4 = bleu_score(tokens_test, tokens_predict, n_gram=4, smooth=False)
    print(now_time() + 'BLEU-4 {:7.4f}'.format(BLEU4))
    USR, USN = unique_sentence_percent(tokens_predict)
    print(now_time() + 'USR {:7.4f} | USN {:7}'.format(USR, USN))
    feature_batch = feature_detect(tokens_predict, self.feature_set)
    DIV = feature_diversity(feature_batch)  # time-consuming
    print(now_time() + 'DIV {:7.4f}'.format(DIV))
    FCR = feature_coverage_ratio(feature_batch, self.feature_set)
    print(now_time() + 'FCR {:7.4f}'.format(FCR))
    feature_test = [self.idx2word[i] for i in data.feature.squeeze(1).tolist()]  # ids to words
    FMR = feature_matching_ratio(feature_batch, feature_test)
    print(now_time() + 'FMR {:7.4f}'.format(FMR))
    text_test = [' '.join(tokens) for tokens in tokens_test]
    text_predict = [' '.join(tokens) for tokens in tokens_predict]
    # Use the instance vocabulary; a bare `idx2word` does not exist in this scope.
    tokens_context = [' '.join([self.idx2word[i] for i in ids]) for ids in context_predict]
    ROUGE = rouge_score(text_test, text_predict)  # a dictionary
    for (k, v) in ROUGE.items():
      print(now_time() + '{} {:7.4f}'.format(k, v))
    # Join once instead of growing a string record by record.
    return ''.join('{}\n{}\n{}\n\n'.format(real, ctx, fake)
                   for (real, ctx, fake) in zip(text_test, tokens_context, text_predict))

  def TRAIN(self,args):
    """Train for up to ``args.epochs`` epochs with early stopping on the validation loss.

    The model is saved to ``self.model_path`` whenever the validation loss improves. Every
    epoch without improvement counts towards ``args.endure_times`` and steps the LR
    scheduler; training stops when the endurance budget is used up.
    """
    # Loop over epochs.
    best_val_loss = float('inf')
    endure_count = 0
    for epoch in range(1, args.epochs + 1):
      print(now_time() + 'epoch {}'.format(epoch))
      self.train(args,self.train_data)
      val_c_loss, val_t_loss, val_r_loss = self.evaluate(args,self.val_data)
      if args.rating_reg == 0:
        val_loss = val_t_loss
      else:
        val_loss = val_t_loss + val_r_loss
      print(now_time() + 'context ppl {:4.4f} | text ppl {:4.4f} | rating loss {:4.4f} | valid loss {:4.4f} on validation'.format(
        math.exp(val_c_loss), math.exp(val_t_loss), val_r_loss, val_loss))
      # Save the model if the validation loss is the best we've seen so far.
      if val_loss < best_val_loss:
        best_val_loss = val_loss
        with open(self.model_path, 'wb') as f:
          torch.save(self.model, f)
      else:
        endure_count += 1
        print(now_time() + 'Endured {} time(s)'.format(endure_count))
        if endure_count == args.endure_times:
          print(now_time() + 'Cannot endure it anymore | Exiting from early stop')
          break
        # Anneal the learning rate if no improvement has been seen in the validation dataset.
        self.scheduler.step()
        print(now_time() + 'Learning rate set to {:2.8f}'.format(self.scheduler.get_last_lr()[0]))

  def TEST(self,args):
    """Reload the best checkpoint, report test losses, generate text and save it to disk."""
    # Load the best saved model. map_location lets a checkpoint written on GPU be opened on
    # a CPU-only runtime. The file is a pickle: only load checkpoints you created yourself.
    with open(self.model_path, 'rb') as f:
      self.model = torch.load(f, map_location=self.device).to(self.device)

    # Run on test data.
    test_c_loss, test_t_loss, test_r_loss = self.evaluate(args,self.test_data)
    print('=' * 89)
    print(now_time() + 'context ppl {:4.4f} | text ppl {:4.4f} | rating loss {:4.4f} on test | End of training'.format(
        math.exp(test_c_loss), math.exp(test_t_loss), test_r_loss))

    print(now_time() + 'Generating text')
    text_o = self.generate(args,self.test_data)
    with open(self.prediction_path, 'w', encoding='utf-8') as f:
      f.write(text_o)
    print(now_time() + 'Generated text saved to ({})'.format(self.prediction_path))

    self.ShowGenSent(text_o,3)
  
  def ShowDataset(self,args,n=5):
    """Print the first entry of the validation index file and the first ``n`` raw reviews."""
    with open(os.path.join(args.index_dir, 'validation.index'), 'r') as f:
      valid_index = [int(x) for x in f.readline().split(' ')]
    # NOTE(review): this prints the first integer of the index line under the label
    # "validation size" - confirm that is the intended value rather than len(valid_index).
    print('validation size: ',valid_index[0])

    # pickle executes arbitrary code on load: only open review files from a trusted source.
    with open(args.data_path, 'rb') as f:
      reviews = pickle.load(f)
    for i,review in enumerate(reviews[:n]):
      print(i,review)
      print('user:\t',review['user'],'\nitem:\t',review['item'],'\ntemp:\t',review['template'],'\nrating:\t',review['rating'])
      (fea, adj, tem, sco) = review['template']
      print('fea:\t',fea,'\nadj:\t', adj, '\ntem:\t',tem,'\nsco:\t', sco)

  def ShowGenSent(self, gen_text,id):
    """Print record number ``id`` (reference, context, generated) of a ``generate`` dump.

    Records are separated by blank lines. The empty piece that follows the final separator
    is not counted, and a missing index prints a notice instead of raising IndexError.
    """
    records = [rec for rec in gen_text.split('\n\n') if rec.strip()]
    print(id,'/',len(records))
    if not -len(records) <= id < len(records):
      print('No generated sample with index', id)
      return
    # Pad so that a short record still prints instead of failing on a missing line.
    ts = (records[id].split('\n') + ['', '', ''])[:3]
    print('GT : ',ts[0],'\nCTX: ',ts[1],'\nGEN: ',ts[2])


In [None]:
# End-to-end PETER run. The order matters: BuildModel reads the corpus and batches that
# LoadData stores on the instance, and TEST reloads the checkpoint that TRAIN writes.
peter = PETER_Model(args)
peter.LoadData(args)
peter.ShowDataset(args)  # prints the first reviews (n defaults to 5) as a sanity check
peter.BuildModel(args)
peter.TRAIN(args)
peter.TEST(args)

[2022-11-02 22:23:39.002609]: Loading data:  ./Amazon/MoviesAndTV/reviews.pickle ./Amazon/MoviesAndTV/2
[2022-11-02 22:23:52.885782]: epoch 1
[2022-11-02 22:24:20.819452]: context ppl 47382.5616 | text ppl 45035.2354 | rating loss 13.6080 |   400/ 2762 batches
[2022-11-02 22:24:49.352403]: context ppl 47463.1084 | text ppl 44949.0962 | rating loss 13.4653 |   800/ 2762 batches
[2022-11-02 22:25:18.522203]: context ppl 47495.7530 | text ppl 45202.9664 | rating loss 12.7859 |  1200/ 2762 batches
[2022-11-02 22:25:48.196921]: context ppl 47747.5363 | text ppl 44977.4790 | rating loss 11.4419 |  1600/ 2762 batches
[2022-11-02 22:26:18.378987]: context ppl 47365.7255 | text ppl 44894.6036 | rating loss 11.1933 |  2000/ 2762 batches
[2022-11-02 22:26:48.807786]: context ppl 47359.3785 | text ppl 45218.6345 | rating loss 10.8277 |  2400/ 2762 batches
[2022-11-02 22:27:16.126388]: context ppl 47469.5643 | text ppl 45208.3515 | rating loss 11.1706 |  2762/ 2762 batches
[2022-11-02 22:27:25.4902



[2022-11-02 22:31:36.037226]: context ppl 47326.8571 | text ppl 44997.0674 | rating loss 12.0830 |   400/ 2762 batches
[2022-11-02 22:32:06.462627]: context ppl 47621.6900 | text ppl 45055.1683 | rating loss 12.0415 |   800/ 2762 batches
[2022-11-02 22:32:36.920785]: context ppl 47473.9108 | text ppl 45055.4206 | rating loss 12.1151 |  1200/ 2762 batches
[2022-11-02 22:33:07.385099]: context ppl 47416.6424 | text ppl 44933.5833 | rating loss 12.1282 |  1600/ 2762 batches
[2022-11-02 22:33:37.882602]: context ppl 47468.7321 | text ppl 44947.4373 | rating loss 12.0592 |  2000/ 2762 batches
[2022-11-02 22:34:08.344511]: context ppl 47387.7166 | text ppl 44980.1070 | rating loss 12.0610 |  2400/ 2762 batches
[2022-11-02 22:34:35.854143]: context ppl 47375.6929 | text ppl 45041.3135 | rating loss 12.1004 |  2762/ 2762 batches
[2022-11-02 22:34:45.271417]: context ppl 47739.8632 | text ppl 44897.7224 | rating loss 11.9134 | valid loss 22.6256 on validation
[2022-11-02 22:34:45.271567]: Endur

# NRT

In [None]:
from module import NRT
from utils import DataLoader, Batchify

## NRT Class

In [None]:
class NRT_Model():
  """Notebook wrapper that loads data for, builds, trains, evaluates and samples from NRT.

  Expected call order: ``LoadData`` -> ``BuildModel`` -> ``TRAIN`` -> ``TEST``.

  NOTE(review): ``model_path`` and ``prediction_path`` are built from the same
  ``args.checkpoint`` / ``args.outf`` as in ``PETER_Model``, so running both wrappers with
  one ``args`` object overwrites the other's ``model.pt`` and generated text.
  """

  def __init__(self,args):
    """Seed torch, select the device and prepare the checkpoint / prediction paths."""
    torch.manual_seed(args.seed)
    cuda_available = torch.cuda.is_available()
    if cuda_available and not args.cuda:
      print(now_time() + 'WARNING: You have a CUDA device, so you should probably run with --cuda')
    if args.cuda and not cuda_available:
      # Without this fallback the first `.to(device)` call fails on a CPU-only runtime.
      print(now_time() + 'WARNING: CUDA was requested but is not available, falling back to CPU')
    self.device = torch.device('cuda' if args.cuda and cuda_available else 'cpu')

    os.makedirs(args.checkpoint, exist_ok=True)
    self.model_path = os.path.join(args.checkpoint, 'model.pt')
    self.prediction_path = os.path.join(args.checkpoint, args.outf)
    self.NAME = 'NRT'

  def LoadData(self,args):
    """Build the corpus and the train / validation / test batch iterators."""
    print(now_time() + self.NAME,'Loading data: ', args.data_path, args.index_dir)
    self.corpus = DataLoader(args.data_path, args.index_dir, args.vocab_size)
    self.word2idx = self.corpus.word_dict.word2idx
    self.idx2word = self.corpus.word_dict.idx2word
    self.feature_set = self.corpus.feature_set
    self.train_data = Batchify(self.corpus.train, self.word2idx, args.words, args.batch_size, shuffle=True)
    self.val_data = Batchify(self.corpus.valid, self.word2idx, args.words, args.batch_size)
    self.test_data = Batchify(self.corpus.test, self.word2idx, args.words, args.batch_size)

  def BuildModel(self,args):
    """Create the NRT network, its loss functions and the Adam optimizer.

    Must be called after ``LoadData`` because it reads ``self.corpus`` and ``self.word2idx``.
    """
    nuser = len(self.corpus.user_dict)
    nitem = len(self.corpus.item_dict)
    self.ntoken = len(self.corpus.word_dict)
    pad_idx = self.word2idx['<pad>']
    # Pass the attribute: a bare `ntoken` is not defined in this scope.
    self.model = NRT(nuser, nitem, self.ntoken, args.emsize, args.nhid, args.nlayers, self.corpus.max_rating, self.corpus.min_rating).to(self.device)
    self.text_criterion = nn.NLLLoss(ignore_index=pad_idx)  # ignore the padding when computing loss
    self.rating_criterion = nn.MSELoss()
    # NOTE(review): `args.lr` is shared with the SGD-trained PETER wrapper (1.0 in Arg);
    # confirm that this value is intended for Adam.
    self.optimizer = torch.optim.Adam(self.model.parameters(), lr=args.lr)

  def train(self, args, data):
    """Run one training epoch and return the per-sample ``(text_loss, rating_loss)``.

    The total loss is ``text_reg * text + rating_reg * rating (+ l2_reg * L2)``. ``l2_reg``
    is optional on ``args`` and treated as 0 when absent.
    """
    self.model.train()
    l2_reg = getattr(args, 'l2_reg', 0.0)
    text_loss = 0.
    rating_loss = 0.
    total_sample = 0
    while True:
      # Ignore any extra per-batch fields: PETER_Model unpacks a fifth `feature` entry from
      # the same Batchify class, which NRT does not use.
      user, item, rating, seq, *_extra = data.next_batch()  # data.step += 1
      batch_size = user.size(0)
      user = user.to(self.device)  # (batch_size,)
      item = item.to(self.device)
      rating = rating.to(self.device)
      seq = seq.to(self.device)  # (batch_size, seq_len + 2)
      # Clear the gradients accumulated by the previous step.
      self.optimizer.zero_grad()
      rating_p, log_word_prob = self.model(user, item, seq[:, :-1])  # (batch_size,) vs. (batch_size, seq_len + 1, ntoken)
      r_loss = self.rating_criterion(rating_p, rating)
      t_loss = self.text_criterion(log_word_prob.view(-1, self.ntoken), seq[:, 1:].reshape((-1,)))
      loss = args.text_reg * t_loss + args.rating_reg * r_loss
      if l2_reg:
        # Sum of squared parameters; skipped entirely when the weight is 0.
        l2_loss = torch.cat([x.view(-1) for x in self.model.parameters()]).pow(2.).sum()
        loss = loss + l2_reg * l2_loss
      loss.backward()
      self.optimizer.step()

      text_loss += batch_size * t_loss.item()
      rating_loss += batch_size * r_loss.item()
      total_sample += batch_size

      if data.step == data.total_step:
        break
    return text_loss / total_sample, rating_loss / total_sample


  def evaluate(self, args, data):
    """Return the per-sample average ``(text_loss, rating_loss)`` over ``data``."""
    self.model.eval()
    text_loss = 0.
    rating_loss = 0.
    total_sample = 0
    with torch.no_grad():
      while True:
        user, item, rating, seq, *_extra = data.next_batch()  # data.step += 1
        batch_size = user.size(0)
        user = user.to(self.device)  # (batch_size,)
        item = item.to(self.device)
        rating = rating.to(self.device)
        seq = seq.to(self.device)  # (batch_size, seq_len + 2)
        rating_p, log_word_prob = self.model(user, item, seq[:, :-1])  # (batch_size,) vs. (batch_size, seq_len + 1, ntoken)
        r_loss = self.rating_criterion(rating_p, rating)
        t_loss = self.text_criterion(log_word_prob.view(-1, self.ntoken), seq[:, 1:].reshape((-1,)))

        text_loss += batch_size * t_loss.item()
        rating_loss += batch_size * r_loss.item()
        total_sample += batch_size

        if data.step == data.total_step:
          break
    return text_loss / total_sample, rating_loss / total_sample


  def generate(self, args, data):
    """Greedily decode ``args.words`` tokens per sample.

    Returns ``(idss_predict, rating_predict)``: the generated token ids (without <bos>) and
    the predicted ratings, both as Python lists in dataset order.
    """
    self.model.eval()
    idss_predict = []
    rating_predict = []
    with torch.no_grad():
      while True:
        user, item, _rating, seq, *_extra = data.next_batch()  # data.step += 1
        user = user.to(self.device)  # (batch_size,)
        item = item.to(self.device)
        inputs = seq[:, :1].to(self.device)  # (batch_size, 1)
        ids = inputs
        # The encoder runs once per batch; its hidden state seeds the decoder.
        rating_p, hidden = self.model.encoder(user, item)
        rating_predict.extend(rating_p.tolist())
        for _step in range(args.words):
          # produce a word at each step
          log_word_prob, hidden = self.model.decoder(inputs, hidden)  # (batch_size, 1, ntoken)
          # squeeze(1) drops only the length-1 step axis; a bare squeeze() would also drop
          # the batch axis when the final batch holds a single sample.
          # Argmax over log-probabilities equals argmax over probabilities.
          inputs = torch.argmax(log_word_prob.squeeze(1), dim=1, keepdim=True)  # (batch_size, 1)
          ids = torch.cat([ids, inputs], 1)  # (batch_size, len++)
        ids = ids[:, 1:].tolist()  # remove bos
        idss_predict.extend(ids)

        if data.step == data.total_step:
          break
    return idss_predict, rating_predict


  def TRAIN(self,args):
    """Train for up to ``args.epochs`` epochs with early stopping on the validation loss.

    The model is saved to ``self.model_path`` whenever the validation loss improves;
    training stops after ``args.endure_times`` epochs without improvement.
    """
    # Loop over epochs.
    best_val_loss = float('inf')
    endure_count = 0
    for epoch in range(1, args.epochs + 1):
      print(now_time() + 'epoch {}'.format(epoch))
      train_t_loss, train_r_loss = self.train(args,self.train_data)
      print(now_time() + 'text ppl {:4.4f} | rating loss {:4.4f} | total loss {:4.4f} on train'.format(
          math.exp(train_t_loss), train_r_loss, train_t_loss + train_r_loss))
      val_t_loss, val_r_loss = self.evaluate(args,self.val_data)
      val_loss = val_t_loss + val_r_loss
      print(now_time() + 'text ppl {:4.4f} | rating loss {:4.4f} | total loss {:4.4f} on validation'.format(
          math.exp(val_t_loss), val_r_loss, val_loss))
      if val_loss < best_val_loss:
        best_val_loss = val_loss
        with open(self.model_path, 'wb') as f:
          torch.save(self.model, f)
      else:
        endure_count += 1
        print(now_time() + 'Endured {} time(s)'.format(endure_count))
        if endure_count == args.endure_times:
          print(now_time() + 'Cannot endure it anymore | Exiting from early stop')
          break

  def TEST(self,args):
    """Reload the best checkpoint, report test metrics and save the generated text."""
    # Load the best saved model. map_location lets a checkpoint written on GPU be opened on
    # a CPU-only runtime. The file is a pickle: only load checkpoints you created yourself.
    with open(self.model_path, 'rb') as f:
      self.model = torch.load(f, map_location=self.device).to(self.device)

    test_data = self.test_data
    # Run on test data.
    test_t_loss, test_r_loss = self.evaluate(args,test_data)
    print('=' * 89)
    print(now_time() + 'text ppl {:4.4f} | rating loss {:4.4f} | total loss {:4.4f} on test | End of training'.format(
            math.exp(test_t_loss), test_r_loss, test_t_loss + test_r_loss))
    print(now_time() + 'Generating text')
    idss_predicted, rating_predicted = self.generate(args,test_data)
    # rating
    predicted_rating = list(zip(test_data.rating.tolist(), rating_predicted))
    RMSE = root_mean_square_error(predicted_rating, self.corpus.max_rating, self.corpus.min_rating)
    print(now_time() + 'RMSE {:7.4f}'.format(RMSE))
    MAE = mean_absolute_error(predicted_rating, self.corpus.max_rating, self.corpus.min_rating)
    print(now_time() + 'MAE {:7.4f}'.format(MAE))
    # text
    tokens_test = [ids2tokens(ids[1:], self.word2idx, self.idx2word) for ids in test_data.seq.tolist()]
    tokens_predict = [ids2tokens(ids, self.word2idx, self.idx2word) for ids in idss_predicted]
    BLEU1 = bleu_score(tokens_test, tokens_predict, n_gram=1, smooth=False)
    print(now_time() + 'BLEU-1 {:7.4f}'.format(BLEU1))
    BLEU4 = bleu_score(tokens_test, tokens_predict, n_gram=4, smooth=False)
    print(now_time() + 'BLEU-4 {:7.4f}'.format(BLEU4))
    USR, USN = unique_sentence_percent(tokens_predict)
    print(now_time() + 'USR {:7.4f} | USN {:7}'.format(USR, USN))
    feature_batch = feature_detect(tokens_predict, self.feature_set)
    DIV = feature_diversity(feature_batch)  # time-consuming
    print(now_time() + 'DIV {:7.4f}'.format(DIV))
    FCR = feature_coverage_ratio(feature_batch, self.feature_set)
    print(now_time() + 'FCR {:7.4f}'.format(FCR))
    # Same id -> word conversion as PETER_Model.generate, which reads the same Batchify class.
    # Assumes `feature` holds one feature id per sample - TODO confirm for this utils version.
    feature_test = [self.idx2word[i] for i in test_data.feature.squeeze(1).tolist()]
    FMR = feature_matching_ratio(feature_batch, feature_test)
    print(now_time() + 'FMR {:7.4f}'.format(FMR))
    text_test = [' '.join(tokens) for tokens in tokens_test]
    text_predict = [' '.join(tokens) for tokens in tokens_predict]
    ROUGE = rouge_score(text_test, text_predict)  # a dictionary
    for (k, v) in ROUGE.items():
      print(now_time() + '{} {:7.4f}'.format(k, v))
    # One "<reference>\n<generated>\n\n" record per sample, joined once.
    text_out = ''.join('{}\n{}\n\n'.format(real, fake) for (real, fake) in zip(text_test, text_predict))
    with open(self.prediction_path, 'w', encoding='utf-8') as f:
      f.write(text_out)
    print(now_time() + 'Generated text saved to ({})'.format(self.prediction_path))

    self.ShowGenSent(text_out,3)

  def ShowGenSent(self, gen_text,id):
    """Print record number ``id`` of a generated-text dump.

    NRT records have two lines (reference, generated). Three-line records (reference,
    context, generated) are printed with their context line as well. The empty piece after
    the final separator is not counted, and a missing index prints a notice instead of
    raising IndexError.
    """
    records = [rec for rec in gen_text.split('\n\n') if rec.strip()]
    print(id,'/',len(records))
    if not -len(records) <= id < len(records):
      print('No generated sample with index', id)
      return
    ts = records[id].split('\n')
    if len(ts) >= 3:
      print('GT : ',ts[0],'\nCTX: ',ts[1],'\nGEN: ',ts[2])
    else:
      gt, gen = (ts + [''])[:2]
      print('GT : ',gt,'\nGEN: ',gen)


In [None]:
# End-to-end NRT run with the same `args` object as the PETER run above.
# BuildModel reads what LoadData stores on the instance; TEST reloads the checkpoint TRAIN writes.
nrt = NRT_Model(args)
nrt.LoadData(args)
nrt.BuildModel(args)
nrt.TRAIN(args)
nrt.TEST(args)