# Load Data & Code

In [None]:
# Copy the PETER source files into the working directory; `module` and `utils`
# are imported by the next cell.
# NOTE(review): `drive/MyDrive/...` presumably refers to a mounted Google Drive — confirm the mount step runs first.
!cp drive/MyDrive/Rec/PETER/module.py ./
!cp drive/MyDrive/Rec/PETER/bleu.py ./
!cp drive/MyDrive/Rec/PETER/rouge.py ./
!cp drive/MyDrive/Rec/PETER/main.py ./
!cp drive/MyDrive/Rec/PETER/utils.py ./
# Extract the dataset archives into the working directory; `Arg.dataset_path`
# below points at one of the extracted folders.
!unzip drive/MyDrive/Rec/PETER/Amazon.zip -d ./
!unzip drive/MyDrive/Rec/PETER/Yelp.zip -d ./
!unzip drive/MyDrive/Rec/PETER/TripAdvisor.zip -d ./

In [None]:
import os
import math
import torch
import torch.nn as nn
from module import PETER
from utils import rouge_score, bleu_score, DataLoader, Batchify, now_time, ids2tokens, unique_sentence_percent, \
    root_mean_square_error, mean_absolute_error, feature_detect, feature_matching_ratio, feature_coverage_ratio, feature_diversity
print(torch.__version__)

1.12.1+cu113


# Setup Parameters

In [None]:
class Arg():
  """Run configuration for this notebook (stands in for command-line arguments).

  The `#@param` trailers are Colab form annotations and must stay on the
  same line as the value they describe.

  The derived paths (`data_path`, `index_dir`, `checkpoint`) are computed once,
  while the class body executes. Assigning a new `dataset_path` or `index` on
  the class or on an instance afterwards does NOT update them; edit the values
  here and re-run the cell instead.
  """
  # Dataset folder; expected to contain `reviews.pickle` and one sub-folder per split index.
  dataset_path = "./Amazon/MoviesAndTV/" #@param ["./TripAdvisor/", "./Yelp/", "./Amazon/MoviesAndTV/", "./Amazon/ClothingShoesAndJewelry/"] {allow-input: true}
  # Which split sub-folder of `dataset_path` to use.
  index = 2 #@param {type:"slider", min:1, max:5, step:1}
  
  # Derived paths (see the class docstring for the evaluation-order caveat).
  data_path = os.path.join(dataset_path,"reviews.pickle")
  index_dir = os.path.join(dataset_path,str(index))
  # `dataset_path[2:]` drops the first two characters (the leading "./" of the
  # offered choices) so the results land under ./Result/<dataset>/.
  checkpoint = os.path.join("./Result",dataset_path[2:])
  # File name (inside `checkpoint`) that the generated text is written to.
  outf = "generated.txt" #@param {type:"string"}
  
  # Model sizes handed to the PETER constructor.
  emsize = 512 #@param {type:"integer"}
  nhead = 2 #@param {type:"integer"}
  nhid = 2048 #@param {type:"integer"}
  nlayers = 2 #@param {type:"integer"}
  # Upper bound on training epochs (early stopping may end training sooner).
  epochs = 100 #@param {type:"slider", min:50, max:200, step:1}
  batch_size = 128 #@param {type:"integer"}
  # Seed passed to torch.manual_seed.
  seed = 1111 #@param {type:"slider", min:0, max:10000, step:1}
  # Number of words generated per explanation; also passed to Batchify as the sequence length.
  words = 15 #@param {type:"slider", min:12, max:20, step:1}
  # Training progress is printed every `log_interval` batches.
  log_interval = 400 #@param {type:"slider", min:10, max:500, step:10}
  # Passed to utils.DataLoader — presumably caps the vocabulary size; confirm in utils.py.
  vocab_size = 20000 #@param {type:"integer"}
  # Training stops once validation loss has failed to improve this many times.
  endure_times = 5 #@param {type:"slider", min:1, max:20, step:1}

  # Weights of the rating, context and text terms in the training loss.
  rating_reg = 0.1 #@param {type:"number"}
  context_reg = 1.0 #@param {type:"number"}
  text_reg = 1.0 #@param {type:"number"}
  dropout = 0.2 #@param {type:"slider", min:0, max:1, step:0.05}
  # Initial SGD learning rate.
  lr = 1.0 #@param {type:"number"}
  # Maximum gradient norm used for clipping.
  clip = 1.0 #@param {type:"number"}
  # Select the CUDA device when True, the CPU otherwise.
  cuda = True #@param {type:"boolean"}
  # Forwarded as the first PETER constructor argument — NOTE(review): exact masking semantics live in module.py.
  peter_mask = True #@param {type:"boolean"}
  # When True the feature token(s) are prepended to the text fed to the model.
  use_feature = True #@param {type:"boolean"}

# Single configuration object read by every later cell.
args = Arg()

# Initial setup
- Set random seed
- Set the PyTorch device (CUDA or CPU)
- Create checkpoint folders

In [None]:
# Set the random seed manually for reproducibility.
torch.manual_seed(args.seed)

# Resolve the compute device. CUDA is used only when it is both requested and
# actually present; otherwise we fall back to the CPU instead of failing later
# at the first `.to(device)` call.
cuda_available = torch.cuda.is_available()
if cuda_available and not args.cuda:
    print(now_time() + 'WARNING: You have a CUDA device, so you should probably run with --cuda')
if args.cuda and not cuda_available:
    print(now_time() + 'WARNING: CUDA was requested but is not available, falling back to CPU')
device = torch.device('cuda' if (args.cuda and cuda_available) else 'cpu')

# Create the checkpoint folder; exist_ok makes re-running this cell harmless.
os.makedirs(args.checkpoint, exist_ok=True)
model_path = os.path.join(args.checkpoint, 'model.pt')  # best model saved during training
prediction_path = os.path.join(args.checkpoint, args.outf)  # generated explanations

## Load Data

In [None]:
###############################################################################
# Load data
###############################################################################

print(now_time() + 'Loading data: ', args.data_path, args.index_dir)
# utils.DataLoader reads the review file and the split directory.
# NOTE(review): it presumably builds the word/user/item dictionaries and the
# train/valid/test splits used below — confirm in utils.py.
corpus = DataLoader(args.data_path, args.index_dir, args.vocab_size)
# Token <-> id lookups, reused by the model-building and generation cells.
word2idx = corpus.word_dict.word2idx
idx2word = corpus.word_dict.idx2word
# Feature collection handed to the feature-based text metrics in generate().
feature_set = corpus.feature_set
# Only the training batches are shuffled; validation and test use the
# Batchify default for `shuffle`.
# NOTE(review): generate() pairs predictions with data.rating / data.seq by position,
# so it relies on that default keeping the original order — confirm in utils.py.
train_data = Batchify(corpus.train, word2idx, args.words, args.batch_size, shuffle=True)
val_data = Batchify(corpus.valid, word2idx, args.words, args.batch_size)
test_data = Batchify(corpus.test, word2idx, args.words, args.batch_size)

[2022-10-28 00:14:16.653643]: Loading data:  ./Amazon/MoviesAndTV/reviews.pickle ./Amazon/MoviesAndTV/2


In [None]:
# Peek at the first id stored in the validation split file.
with open(os.path.join(args.index_dir, 'validation.index'), 'r') as f:
  # split() with no argument tolerates repeated spaces and the trailing newline,
  # which split(' ') would turn into tokens that int() cannot parse.
  valid_index = [int(x) for x in f.readline().split()]
# Guard the preview so an empty index file reports itself instead of raising IndexError.
print(valid_index[0] if valid_index else 'validation.index is empty')

262148


In [None]:
import pickle

# Inspect the first few raw review records.
# NOTE(review): pickle.load can execute arbitrary code while loading; only use it
# on dataset files from a trusted source.
with open(args.data_path, 'rb') as f:  # context manager closes the file handle
  reviews = pickle.load(f)
for i,review in enumerate(reviews[:10]):
  print(i,review)
  print('user:\t',review['user'],'\nitem:\t',review['item'],'\ntemp:\t',review['template'],'\nrating:\t',review['rating'])
  # 'template' is a 4-tuple, unpacked here into the names used by the print below.
  (fea, adj, tem, sco) = review['template']
  print('fea:\t',fea,'\nadj:\t', adj, '\ntem:\t',tem,'\nsco:\t', sco)

0 {'user': 'A3UT41TWD7N0D5', 'item': '0307142493', 'rating': 5, 'template': ('classic', 'when', 'i suppose the best thing about this type of christmas classic is that when people my age view it they can escape back to their innocent youth for a short while', 1), 'predicted': 'television'}
user:	 A3UT41TWD7N0D5 
item:	 0307142493 
temp:	 ('classic', 'when', 'i suppose the best thing about this type of christmas classic is that when people my age view it they can escape back to their innocent youth for a short while', 1) 
rating:	 5
fea:	 classic 
adj:	 when 
tem:	 i suppose the best thing about this type of christmas classic is that when people my age view it they can escape back to their innocent youth for a short while 
sco:	 1
1 {'user': 'A3H82LUT1EC655', 'item': '0307142493', 'rating': 5, 'template': ('movie', 'fantastic', 'this is a fantastic movie for kids and adults of all ages', 1), 'predicted': 'thrill'}
user:	 A3H82LUT1EC655 
item:	 0307142493 
temp:	 ('movie', 'fantastic', 't

In [None]:
# Show how many item ids there are plus a short preview; printing the whole
# list floods the notebook output.
item_id_preview = corpus.item_dict.idx2entity[:20]
print(len(corpus.item_dict.idx2entity), item_id_preview)

['1068719', '302331', '300697', '187686', '305371', '1223686', '648659', '2031570', '195200', '195198', '6484754', '305068', '195203', '5786212', '1148363', '300844', '199157', '603403', '300319', '5823268', '305813', '305877', '583732', '305913', '192063', '192048', '187685', '93450', '529404', '548045', '193105', '2048999', '93457', '239656', '1149402', '1783324', '590144', '195216', '798951', '302990', '313127', '1371326', '1837031', '1438231', '302322', '195210', '8541913', '1946018', '289182', '305383', '672863', '2401829', '302174', '2079052', '1590819', '126260', '1415169', '122005', '3349465', '1139717', '575732', '193045', '192036', '305892', '208454', '1161242', '305907', '1210756', '2699176', '1083482', '305893', '1513860', '3199601', '302133', '113317', '93618', '2016610', '193113', '1631877', '93475', '188029', '3523344', '308694', '1938661', '677369', '578305', '1718977', '309532', '195205', '195155', '188961', '224224', '192044', '187735', '2532186', '187567', '192118', 

## Build the model

In [None]:
###############################################################################
# Build the model
###############################################################################

# Source length: user + item, plus the feature columns when features are used.
if args.use_feature:
    src_len = 2 + train_data.feature.size(1)  # [u, i, f]
else:
    src_len = 2  # [u, i]
tgt_len = args.words + 1  # added <bos> or <eos>
# Sizes of the token, user and item dictionaries built while loading the data.
ntokens = len(corpus.word_dict)
nuser = len(corpus.user_dict)
nitem = len(corpus.item_dict)
pad_idx = word2idx['<pad>']
model = PETER(args.peter_mask, src_len, tgt_len, pad_idx, nuser, nitem, ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout).to(device)
# NLLLoss expects log-probabilities as input; it is applied to both the word and context outputs of the model.
text_criterion = nn.NLLLoss(ignore_index=pad_idx)  # ignore the padding when computing loss
rating_criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
# step_size=1 with gamma=0.25: every scheduler.step() call multiplies the learning rate by 0.25.
# The training loop only calls it after an epoch without validation improvement.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.25)

## Training 

In [None]:
###############################################################################
# Training code
###############################################################################

def predict(log_context_dis, topk):
    """Return the indices of the `topk` most probable tokens for each row.

    Args:
        log_context_dis: log-probabilities, (batch_size, ntoken).
        topk: number of token indices to return per row.

    Returns:
        LongTensor of shape (batch_size, topk), ordered from most to least probable.
    """
    probabilities = torch.exp(log_context_dis)  # (batch_size, ntoken)
    if topk != 1:
        _, top_indices = probabilities.topk(topk, dim=1)  # (batch_size, topk)
        return top_indices
    # Single best token: keep the column dimension so the shape is (batch_size, 1).
    return probabilities.argmax(dim=1, keepdim=True)


def train(data):
    """Train the global `model` for one pass over `data`.

    Reads the module-level `model`, `optimizer`, `device`, `args`, `tgt_len`,
    `ntokens`, `text_criterion` and `rating_criterion`. Prints running
    perplexities / rating loss every `args.log_interval` batches and after the
    final batch; returns nothing.

    Args:
        data: batch provider exposing `next_batch()`, `step` and `total_step`.
    """
    model.train()  # training mode (dropout active)
    running = {'context': 0., 'text': 0., 'rating': 0.}  # batch-size-weighted loss sums
    seen = 0  # samples accumulated since the last log line
    finished = False
    while not finished:
        user, item, rating, seq, feature = data.next_batch()  # also advances data.step
        n_samples = user.size(0)
        user, item, rating = user.to(device), item.to(device), rating.to(device)
        seq = seq.t().to(device)  # (tgt_len + 1, batch_size)
        feature = feature.t().to(device)  # (1, batch_size)
        # Decoder input: everything but the last token, optionally prefixed by the feature.
        text = torch.cat([feature, seq[:-1]], 0) if args.use_feature else seq[:-1]

        optimizer.zero_grad()
        # (tgt_len, batch_size, ntoken) vs. (batch_size, ntoken) vs. (batch_size,)
        log_word_prob, log_context_dis, rating_p, _ = model(user, item, text)
        # Tile the single context distribution so it is scored against every target word.
        tiled_context = log_context_dis.unsqueeze(0).repeat((tgt_len - 1, 1, 1))
        c_loss = text_criterion(tiled_context.view(-1, ntokens), seq[1:-1].reshape((-1,)))
        r_loss = rating_criterion(rating_p, rating)
        t_loss = text_criterion(log_word_prob.view(-1, ntokens), seq[1:].reshape((-1,)))
        weighted_loss = args.text_reg * t_loss + args.context_reg * c_loss + args.rating_reg * r_loss
        weighted_loss.backward()

        # Clip the gradient norm to guard against exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        running['context'] += n_samples * c_loss.item()
        running['text'] += n_samples * t_loss.item()
        running['rating'] += n_samples * r_loss.item()
        seen += n_samples

        finished = data.step == data.total_step
        if data.step % args.log_interval == 0 or finished:
            print(now_time() + 'context ppl {:4.4f} | text ppl {:4.4f} | rating loss {:4.4f} | {:5d}/{:5d} batches'.format(
                math.exp(running['context'] / seen), math.exp(running['text'] / seen), running['rating'] / seen,
                data.step, data.total_step))
            # Start a fresh logging window.
            running = dict.fromkeys(running, 0.)
            seen = 0


def evaluate(data):
    """Compute average losses of the global `model` over one pass of `data`.

    Runs without gradient tracking and with the model in eval mode. Reads the
    module-level `model`, `device`, `args`, `tgt_len`, `ntokens`,
    `text_criterion` and `rating_criterion`.

    Args:
        data: batch provider exposing `next_batch()`, `step` and `total_step`.

    Returns:
        Tuple `(context_loss, text_loss, rating_loss)`, each the batch-size-weighted
        mean of the per-batch loss.
    """
    model.eval()  # evaluation mode (dropout disabled)
    context_sum, text_sum, rating_sum = 0., 0., 0.
    seen = 0
    with torch.no_grad():
        finished = False
        while not finished:
            user, item, rating, seq, feature = data.next_batch()  # also advances data.step
            n_samples = user.size(0)
            user, item, rating = user.to(device), item.to(device), rating.to(device)
            seq = seq.t().to(device)  # (tgt_len + 1, batch_size)
            feature = feature.t().to(device)  # (1, batch_size)
            # Decoder input: everything but the last token, optionally prefixed by the feature.
            text = torch.cat([feature, seq[:-1]], 0) if args.use_feature else seq[:-1]
            # (tgt_len, batch_size, ntoken) vs. (batch_size, ntoken) vs. (batch_size,)
            log_word_prob, log_context_dis, rating_p, _ = model(user, item, text)
            # Tile the single context distribution so it is scored against every target word.
            tiled_context = log_context_dis.unsqueeze(0).repeat((tgt_len - 1, 1, 1))
            c_loss = text_criterion(tiled_context.view(-1, ntokens), seq[1:-1].reshape((-1,)))
            r_loss = rating_criterion(rating_p, rating)
            t_loss = text_criterion(log_word_prob.view(-1, ntokens), seq[1:].reshape((-1,)))

            context_sum += n_samples * c_loss.item()
            text_sum += n_samples * t_loss.item()
            rating_sum += n_samples * r_loss.item()
            seen += n_samples

            finished = data.step == data.total_step
    return context_sum / seen, text_sum / seen, rating_sum / seen


def generate(data):
    """Greedily decode an explanation for every sample in `data` and report metrics.

    For each batch the model is fed `[feature,] <bos>` and extended one token at
    a time for `args.words` steps, always taking the most probable word. The
    rating prediction and the top-`args.words` context words come from the first
    step. Afterwards RMSE/MAE, BLEU, USR/USN, DIV, FCR, FMR and ROUGE are printed.

    Reads the module-level `model`, `device`, `args`, `corpus`, `word2idx`,
    `idx2word` and `feature_set`. Predictions are paired with `data.rating`,
    `data.seq` and `data.feature` by position, so `data` must yield its samples
    in stored order.

    Args:
        data: batch provider exposing `next_batch()`, `step`, `total_step`,
            `rating`, `seq` and `feature`.

    Returns:
        str: one record per sample, formatted as
        `'<ground truth>\\n<context words>\\n<generated text>\\n\\n'`.
    """
    # Turn on evaluation mode which disables dropout.
    model.eval()
    idss_predict = []
    context_predict = []
    rating_predict = []
    with torch.no_grad():
        while True:
            # The ground-truth rating of the batch is not needed while decoding.
            user, item, _, seq, feature = data.next_batch()
            user = user.to(device)  # (batch_size,)
            item = item.to(device)
            bos = seq[:, 0].unsqueeze(0).to(device)  # (1, batch_size)
            feature = feature.t().to(device)  # (1, batch_size)
            if args.use_feature:
                text = torch.cat([feature, bos], 0)  # (src_len - 1, batch_size)
            else:
                text = bos  # (src_len - 1, batch_size)
            start_idx = text.size(0)  # generated tokens start after the prompt
            for idx in range(args.words):
                # produce a word at each step
                if idx == 0:
                    # First step also yields the rating and the context distribution.
                    log_word_prob, log_context_dis, rating_p, _ = model(user, item, text, False)  # (batch_size, ntoken) vs. (batch_size, ntoken) vs. (batch_size,)
                    rating_predict.extend(rating_p.tolist())
                    context = predict(log_context_dis, topk=args.words)  # (batch_size, words)
                    context_predict.extend(context.tolist())
                else:
                    log_word_prob, _, _, _ = model(user, item, text, False, False, False)  # (batch_size, ntoken)
                word_prob = log_word_prob.exp()  # (batch_size, ntoken)
                word_idx = torch.argmax(word_prob, dim=1)  # (batch_size,), pick the one with the largest probability
                text = torch.cat([text, word_idx.unsqueeze(0)], 0)  # (len++, batch_size)
            ids = text[start_idx:].t().tolist()  # (batch_size, seq_len)
            idss_predict.extend(ids)

            if data.step == data.total_step:
                break

    # rating
    predicted_rating = list(zip(data.rating.tolist(), rating_predict))
    RMSE = root_mean_square_error(predicted_rating, corpus.max_rating, corpus.min_rating)
    print(now_time() + 'RMSE {:7.4f}'.format(RMSE))
    MAE = mean_absolute_error(predicted_rating, corpus.max_rating, corpus.min_rating)
    print(now_time() + 'MAE {:7.4f}'.format(MAE))
    # text
    tokens_test = [ids2tokens(ids[1:], word2idx, idx2word) for ids in data.seq.tolist()]  # ids[1:] skips <bos>
    tokens_predict = [ids2tokens(ids, word2idx, idx2word) for ids in idss_predict]
    BLEU1 = bleu_score(tokens_test, tokens_predict, n_gram=1, smooth=False)
    print(now_time() + 'BLEU-1 {:7.4f}'.format(BLEU1))
    BLEU4 = bleu_score(tokens_test, tokens_predict, n_gram=4, smooth=False)
    print(now_time() + 'BLEU-4 {:7.4f}'.format(BLEU4))
    USR, USN = unique_sentence_percent(tokens_predict)
    print(now_time() + 'USR {:7.4f} | USN {:7}'.format(USR, USN))
    feature_batch = feature_detect(tokens_predict, feature_set)
    DIV = feature_diversity(feature_batch)  # time-consuming
    print(now_time() + 'DIV {:7.4f}'.format(DIV))
    FCR = feature_coverage_ratio(feature_batch, feature_set)
    print(now_time() + 'FCR {:7.4f}'.format(FCR))
    feature_test = [idx2word[i] for i in data.feature.squeeze(1).tolist()]  # ids to words
    FMR = feature_matching_ratio(feature_batch, feature_test)
    print(now_time() + 'FMR {:7.4f}'.format(FMR))
    text_test = [' '.join(tokens) for tokens in tokens_test]
    text_predict = [' '.join(tokens) for tokens in tokens_predict]
    tokens_context = [' '.join([idx2word[i] for i in ids]) for ids in context_predict]
    ROUGE = rouge_score(text_test, text_predict)  # a dictionary
    for (k, v) in ROUGE.items():
        print(now_time() + '{} {:7.4f}'.format(k, v))
    # Build the dump with a single join instead of repeated string concatenation,
    # which can degrade to quadratic time on large test sets.
    return ''.join(
        '{}\n{}\n{}\n\n'.format(real, ctx, fake)
        for (real, ctx, fake) in zip(text_test, tokens_context, text_predict)
    )


In [None]:
# Epoch loop with checkpointing on improvement, learning-rate annealing on
# stagnation, and early stopping after `args.endure_times` non-improving epochs.
best_val_loss = float('inf')
endure_count = 0
for epoch in range(1, args.epochs + 1):
    print(now_time() + 'epoch {}'.format(epoch))
    train(train_data)
    val_c_loss, val_t_loss, val_r_loss = evaluate(val_data)
    # The rating loss only counts towards model selection when the rating task is trained.
    val_loss = val_t_loss if args.rating_reg == 0 else val_t_loss + val_r_loss
    print(now_time() + 'context ppl {:4.4f} | text ppl {:4.4f} | rating loss {:4.4f} | valid loss {:4.4f} on validation'.format(
        math.exp(val_c_loss), math.exp(val_t_loss), val_r_loss, val_loss))

    if val_loss < best_val_loss:
        # New best validation loss: remember it and checkpoint the model.
        best_val_loss = val_loss
        with open(model_path, 'wb') as f:
            torch.save(model, f)
        continue

    # No improvement this epoch.
    endure_count += 1
    print(now_time() + 'Endured {} time(s)'.format(endure_count))
    if endure_count == args.endure_times:
        print(now_time() + 'Cannot endure it anymore | Exiting from early stop')
        break
    # Anneal the learning rate before trying another epoch.
    scheduler.step()
    print(now_time() + 'Learning rate set to {:2.8f}'.format(scheduler.get_last_lr()[0]))

[2022-10-28 01:22:04.736443]: epoch 1
[2022-10-28 01:22:21.496513]: context ppl 2221.1663 | text ppl 1430.7640 | rating loss 3.7824 |   200/ 2762 batches
[2022-10-28 01:22:35.873731]: context ppl 909.7735 | text ppl 419.4971 | rating loss 1.7960 |   400/ 2762 batches
[2022-10-28 01:22:50.363296]: context ppl 806.7235 | text ppl 291.7831 | rating loss 1.4394 |   600/ 2762 batches
[2022-10-28 01:23:04.966700]: context ppl 751.1223 | text ppl 236.5549 | rating loss 1.1892 |   800/ 2762 batches
[2022-10-28 01:23:19.745184]: context ppl 711.6663 | text ppl 196.6061 | rating loss 1.2655 |  1000/ 2762 batches
[2022-10-28 01:23:34.637404]: context ppl 678.8927 | text ppl 170.1068 | rating loss 1.4964 |  1200/ 2762 batches
[2022-10-28 01:23:49.614659]: context ppl 676.4737 | text ppl 159.5947 | rating loss 1.5253 |  1400/ 2762 batches
[2022-10-28 01:24:04.766256]: context ppl 666.7142 | text ppl 149.7882 | rating loss 1.5510 |  1600/ 2762 batches
[2022-10-28 01:24:20.251931]: context ppl 663.16

## Test the model

In [None]:
# Load the best saved model.
# map_location remaps the stored tensors onto the current device, so a checkpoint
# written on a GPU runtime can still be loaded on a CPU-only one.
# NOTE(review): the checkpoint is a fully pickled model object; only load files you trust.
with open(model_path, 'rb') as f:
    model = torch.load(f, map_location=device).to(device)

# Run on test data.
test_c_loss, test_t_loss, test_r_loss = evaluate(test_data)
print('=' * 89)
print(now_time() + 'context ppl {:4.4f} | text ppl {:4.4f} | rating loss {:4.4f} on test | End of training'.format(
    math.exp(test_c_loss), math.exp(test_t_loss), test_r_loss))

# Decode explanations for the test set, print the metrics and save the text dump.
print(now_time() + 'Generating text')
text_o = generate(test_data)
with open(prediction_path, 'w', encoding='utf-8') as f:
    f.write(text_o)
print(now_time() + 'Generated text saved to ({})'.format(prediction_path))

[2022-10-25 02:04:25.496466]: context ppl 491.2008 | text ppl 56.4413 | rating loss 0.9243 on test | End of training
[2022-10-25 02:04:25.496580]: Generating text
[2022-10-25 02:04:39.006644]: RMSE  0.9614
[2022-10-25 02:04:39.017312]: MAE  0.7116
[2022-10-25 02:04:41.014030]: BLEU-1 12.8172
[2022-10-25 02:04:44.584114]: BLEU-4  1.1692
[2022-10-25 02:05:00.831474]: USR  0.1920 | USN    8481
[2022-10-25 02:10:23.069716]: DIV  1.9166
[2022-10-25 02:10:23.481523]: FCR  0.1422
[2022-10-25 02:10:23.494143]: FMR  0.1203
[2022-10-25 02:10:28.078829]: rouge_1/f_score 15.3442
[2022-10-25 02:10:28.079018]: rouge_1/r_score 13.9346
[2022-10-25 02:10:28.079750]: rouge_1/p_score 19.8911
[2022-10-25 02:10:28.079789]: rouge_2/f_score  2.2119
[2022-10-25 02:10:28.079817]: rouge_2/r_score  2.0991
[2022-10-25 02:10:28.079842]: rouge_2/p_score  2.7884
[2022-10-25 02:10:28.080301]: rouge_l/f_score 11.8085
[2022-10-25 02:10:28.080348]: rouge_l/r_score 12.5235
[2022-10-25 02:10:28.080375]: rouge_l/p_score 16

## Display the Generated Text

In [None]:
def dispRes(gen_text,id):
  """Print one record of a generated-text dump.

  `gen_text` is expected to hold records separated by a blank line, each made
  of three lines: ground truth, context words, generated text.

  Args:
    gen_text: the whole dump as a single string.
    id: index of the record to display.

  Raises:
    IndexError: if `id` is outside the available records, or the selected
      record has fewer than three lines.
  """
  texts = gen_text.split('\n\n')
  # A dump whose last record is also followed by a blank line makes split()
  # return a trailing empty string; drop it so the count is right and the last
  # valid index is the last real record.
  if texts and texts[-1] == '':
    texts.pop()
  print(id,'/',len(texts))
  ts = texts[id].split('\n')
  print('GT : ',ts[0],'\nCTX: ',ts[1],'\nGEN: ',ts[2])
# Show the record at index 3 of the dump returned by generate() in the test cell.
dispRes(text_o,3)

3 / 32004
GT :  the rooms are spacious and the bathroom has a large tub 
CTX:  <eos> the and a was with pool is good to very in room nice were 
GEN:  the rooms are spacious and clean
