In [1]:
import torch
from torch import optim
from torch import nn
import torch.nn.functional as F
from torch.distributions import Bernoulli, Categorical
from torchtext import datasets
import os
import time
import numpy as np 
import random
import argparse

from networks import Distilbert_LSTM, Policy_C, Policy_N, Policy_S, ValueNetwork
from utils.utils import sample_policy_c, sample_policy_n, sample_policy_s, evaluate_lm, compute_policy_value_losses
from utils.utils import cnn_cost, clstm_cost, c_cost, n_cost, s_cost, openDfFromPickle, calculate_stats_from_cm

In [2]:
print('Reading data...')
train_data = openDfFromPickle("C:\\Users\\mrbal\\Documents\\NLP\\RL\\basic_reinforcement_learning\\NLP_datasets\\imdb\\imdb_train_distilbert-base-uncased.pkl")
valid_data = openDfFromPickle("C:\\Users\\mrbal\\Documents\\NLP\\RL\\basic_reinforcement_learning\\NLP_datasets\\imdb\\imdb_val_distilbert-base-uncased.pkl")
test_data = openDfFromPickle("C:\\Users\\mrbal\\Documents\\NLP\\RL\\basic_reinforcement_learning\\NLP_datasets\\imdb\\imdb_test_distilbert-base-uncased.pkl")
print(f'Number of training examples: {len(train_data)}')
print(f'Number of validation examples: {len(valid_data)}')
print(f'Number of testing examples: {len(test_data)}')


Reading data...
Number of training examples: 20000
Number of validation examples: 5000
Number of testing examples: 25000


In [3]:
train_data

Unnamed: 0,text,label,label_str,text_bert_input_ids,text_bert_attention_mask
23311,I borrowed this movie despite its extremely lo...,1,pos,"[101, 1045, 11780, 2023, 3185, 2750, 2049, 518...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
23623,After the unexpected accident that killed an i...,1,pos,"[101, 2044, 1996, 9223, 4926, 2008, 2730, 2019...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
1020,On the summer blockbuster hit BASEketball This...,0,neg,"[101, 2006, 1996, 2621, 27858, 2718, 2918, 348...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
12645,Can Scarcely Imagine Better Movie Than This br...,1,pos,"[101, 2064, 20071, 5674, 2488, 3185, 2084, 202...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
1533,A still famous but decadent actor Morgan Freem...,0,neg,"[101, 1037, 2145, 3297, 2021, 5476, 3372, 3364...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
...,...,...,...,...,...
21575,My discovery of the cinema of Jan Svankmajer o...,1,pos,"[101, 2026, 5456, 1997, 1996, 5988, 1997, 5553...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
5390,The story is similar to ET an extraterrestrial...,0,neg,"[101, 1996, 2466, 2003, 2714, 2000, 3802, 2019...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
860,I have read the novel Reaper of Ben Mezrich fe...,0,neg,"[101, 1045, 2031, 3191, 1996, 3117, 19559, 199...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
15795,Went to see this finnish film and ve got to sa...,1,pos,"[101, 2253, 2000, 2156, 2023, 6983, 2143, 1998...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."


In [4]:
train_data["text"].iloc[1]

'After the unexpected accident that killed an inexperienced climber Michelle Joyner Eight months has passed The Rocky Mountain Rescue receive distress call set by brilliant terrorist mastermind Eric Quaien John Lithgow Quaien has lost three large cases that has millions of dollars inside Two experienced climbers Walker Sylvester Stallone and Tucker Micheal Rooker and helicopter pilot Janine Turner are to the rescue but they are set by trap by Quaien and his men Now the two climbers and pilot are forced to play deadly game of hide and seek While Quaien is trying to find the millions of dollars and he kidnapped Tucker to find the money Once Tucker finds the money Tucker will be dead Against explosive firepower bitter cold and dizzying heights Walker must outwit Quaien for survival br br Directed by Renny Harlin Driven Mindhunters Nightmare on Elm Street The Dream Master made an entertaining non stop action picture This film is spectacular exciting visually exciting action picture with pl

In [5]:
# split the datasets into batches
BATCH_SIZE = 1  # the batch size for a dataset iterator
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'device: {device}')

device: cuda


In [6]:
xtrain = torch.from_numpy(np.stack(train_data["text_bert_input_ids"].values))[:, 0:400]
xtrain_mask = torch.from_numpy(np.stack(train_data["text_bert_attention_mask"].values))[:, 0:400]
ytrain = torch.from_numpy(train_data["label"].values)
xvalid = torch.from_numpy(np.stack(valid_data["text_bert_input_ids"].values))[:, 0:400]
xvalid_mask = torch.from_numpy(np.stack(valid_data["text_bert_attention_mask"].values))[:, 0:400]
yvalid = torch.from_numpy(valid_data["label"].values)
xtest = torch.from_numpy(np.stack(test_data["text_bert_input_ids"].values))[:, 0:400]
xtest_mask = torch.from_numpy(np.stack(test_data["text_bert_attention_mask"].values))[:, 0:400]
ytest = torch.from_numpy(test_data["label"].values)

print(xtrain.shape, ytrain.shape)
print(xvalid.shape, yvalid.shape)
print(xtest.shape, ytest.shape)

torch.Size([20000, 400]) torch.Size([20000])
torch.Size([5000, 400]) torch.Size([5000])
torch.Size([25000, 400]) torch.Size([25000])


In [7]:
from torch.utils.data import DataLoader, TensorDataset
train_loader = DataLoader(TensorDataset(xtrain, xtrain_mask, ytrain), batch_size=BATCH_SIZE)
valid_loader = DataLoader(TensorDataset(xvalid, xvalid_mask, yvalid), batch_size=BATCH_SIZE)
test_loader = DataLoader(TensorDataset(xtest, xtest_mask, ytest), batch_size=BATCH_SIZE)

In [8]:
random.seed(2023)
np.random.seed(2023)
torch.manual_seed(2023)
torch.cuda.manual_seed(2023)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [9]:
# set up parameters
NUM_RNN_LAYERS = 4
HIDDEN_DIM_LSTM = 512 
HIDDEN_DIM_DENSE = 1 * HIDDEN_DIM_LSTM # HIDDEN_DIM_LSTM * NUM_RNN_LAYERS
OUTPUT_DIM = 1
CHUNCK_SIZE = 20
MAX_K = 4  # the output dimension for step size 0, 1, 2, 3
LABEL_DIM = 2
BATCH_SIZE = 1
gamma = 0.99
alpha = 0.2
learning_rate = 0.001

In [10]:
# set up the criterion
criterion = nn.CrossEntropyLoss().to(device)
# set up models
trns_lstm = Distilbert_LSTM(NUM_RNN_LAYERS, HIDDEN_DIM_LSTM, bert_checkpoint="lvwerra/distilbert-imdb").to(device)
print(trns_lstm)
policy_s = Policy_S(HIDDEN_DIM_DENSE, HIDDEN_DIM_DENSE, OUTPUT_DIM).to(device)
policy_n = Policy_N(HIDDEN_DIM_DENSE, HIDDEN_DIM_DENSE, MAX_K).to(device)
policy_c = Policy_C(HIDDEN_DIM_DENSE, HIDDEN_DIM_DENSE, LABEL_DIM).to(device)
value_net = ValueNetwork(HIDDEN_DIM_DENSE, HIDDEN_DIM_DENSE, OUTPUT_DIM).to(device)


Downloading (…)lve/main/config.json:   0%|          | 0.00/735 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


Downloading pytorch_model.bin:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of the model checkpoint at lvwerra/distilbert-imdb were not used when initializing DistilBertModel: ['classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight', 'classifier.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


dummy start: torch.Size([1, 20])
Transformer: torch.Size([1, 20, 768])
permute: torch.Size([20, 1, 768])
dummy start: torch.Size([1, 20])
Transformer: torch.Size([1, 20, 768])
permute: torch.Size([20, 1, 768])
lstm out: torch.Size([20, 1, 512]), hidden: torch.Size([4, 1, 512]), cell: torch.Size([4, 1, 512])
torch.Size([4, 1, 512])
Distilbert_LSTM(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
      

In [11]:
# set up optimiser
params_pg = list(policy_s.parameters()) + list(policy_c.parameters()) + list(policy_n.parameters())
optim_loss = optim.Adam(trns_lstm.parameters(), lr=learning_rate)
optim_policy = optim.Adam(params_pg, lr=learning_rate)
optim_value = optim.Adam(value_net.parameters(), lr=learning_rate)

In [12]:
def finish_episode(policy_loss_sum, encoder_loss_sum, baseline_value_batch):
    '''
    Called when a data sample has been processed.
    '''
    baseline_value_sum = torch.stack(baseline_value_batch).sum()
    policy_loss = torch.stack(policy_loss_sum).mean()
    encoder_loss = torch.stack(encoder_loss_sum).mean()
    objective_loss = encoder_loss - policy_loss + baseline_value_sum
    # set gradient to zero
    optim_loss.zero_grad()
    optim_policy.zero_grad()
    optim_value.zero_grad()
    # back propagation
    objective_loss.backward()
    # gradient update
    optim_loss.step()
    optim_policy.step()
    optim_value.step()

In [13]:
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [14]:
trns_lstm.distilbert.requires_grad_(False)

DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0-5): 6 x TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Li

In [15]:
print('Training starts...')

for epoch in range(5):
    print('\nEpoch', epoch+1)
    # log the start time of the epoch
    start = time.time()
    # set the models in training mode
    trns_lstm.train()
    policy_s.train()
    policy_n.train()
    policy_c.train()
    # reset the count of reread_or_skim_times
    reread_or_skim_times = 0
    policy_loss_sum = []
    encoder_loss_sum = []
    baseline_value_batch = []
    pbar = tqdm(train_loader)
    cm = np.zeros((LABEL_DIM, LABEL_DIM))
    for index, (x, xmask, y) in enumerate(pbar):
        label = y.to(device)
        text = x.to(device).view(CHUNCK_SIZE, BATCH_SIZE, CHUNCK_SIZE) # transform 1*400 to 20*1*20
        text_mask = xmask.to(device).view(CHUNCK_SIZE, BATCH_SIZE, CHUNCK_SIZE)
        curr_step = 0  # the position of the current chunk
        h_0 = torch.zeros([NUM_RNN_LAYERS,1,HIDDEN_DIM_LSTM]).to(device)  # run on GPU
        c_0 = torch.zeros([NUM_RNN_LAYERS,1,HIDDEN_DIM_LSTM]).to(device)
        count = 0  # maximum skim/reread time: 5
        baseline_value_ep = []
        saved_log_probs = []  # for the use of policy gradient update
        # collect the computational costs for every time step
        cost_ep = []  
        while curr_step < CHUNCK_SIZE and count < 5: 
            # Loop until a text can be classified or currstep is up to 20 or count reach the maximum i.e. 5.
            # update count
            count += 1
            # pass the input through cnn-lstm and policy s
            text_input = text[curr_step] # text_input 1*20
            text_mask_input = text_mask[curr_step]
            # print(f"input h: {h_0.shape}")
            ht, ct = trns_lstm(text_input, text_mask_input, h_0, c_0)  #ht: NUM_RNN_LAYERS * 1 * HIDDEN_DIM_LSTM
            # separate the value which is the input of value net
            ht_ = ht.clone().detach().requires_grad_(True)
            
            # NUM_RNN_LAYERS * 1 * 128, next input of lstm
            h_0 = ht # .unsqueeze(0)
            c_0 = ct
            
            ht_ = ht_[-1, :, :]# .view(1, ht_.shape[0] * ht_.shape[2]) # ht_: 1, NUM_RNN_LAYERS * HIDDEN_DIM_LSTM
            ht = ht[-1, :, :] # .view(1, ht.shape[0] * ht.shape[2])
            # compute a baseline value for the value network
            bi = value_net(ht_)
            # draw a stop decision
            stop_decision, log_prob_s = sample_policy_s(ht, policy_s)
            stop_decision = stop_decision.item()
            if stop_decision == 1: # classify
                break
            else: 
                reread_or_skim_times += 1
                # draw an action (reread or skip)
                step, log_prob_n = sample_policy_n(ht, policy_n)
                curr_step += int(step)  # reread or skip
                if curr_step < CHUNCK_SIZE and count < 5:
                    # If the code can still execute the next loop, it is not the last time step.
                    cost_ep.append(clstm_cost + s_cost + n_cost)
                    # add the baseline value
                    baseline_value_ep.append(bi)
                    # add the log prob for the current actions
                    saved_log_probs.append(log_prob_s + log_prob_n)
        # draw a predicted label
        output_c = policy_c(ht)
        # cross entrpy loss input shape: input(N, C), target(N)
        loss = criterion(output_c, label)  # positive value
        # draw a predicted label 
        pred_label, log_prob_c = sample_policy_c(output_c)
        # update the confusion matrix
        cm[pred_label][y] += 1
        if stop_decision == 1:
            # add the cost of the last time step
            cost_ep.append(clstm_cost + s_cost + c_cost)
            saved_log_probs.append(log_prob_s + log_prob_c)
        else:
            # At the moment, the probability of drawing a stop decision is 1,
            # so its log probability is zero which can be ignored in th sum.
            saved_log_probs.append(log_prob_c.unsqueeze(0))
        # add the baseline value
        baseline_value_ep.append(bi)
        # add the cross entropy loss
        encoder_loss_sum.append(loss)
        # set reward for the current data sample
        if pred_label.item() == label:
            reward = 1 
        else:
            reward = -1 
        # compute the policy losses and value losses for the current episode
        policy_loss_ep = []
        value_losses = []
        for i, log_prob in enumerate(saved_log_probs):
            # baseline_value_ep[i].item(): updating the policy loss doesn't include the gradient of baseline values
            advantage = reward - baseline_value_ep[i].item()
            policy_loss_ep.append(log_prob * advantage)
            value_losses.append((reward - baseline_value_ep[i]) ** 2)
        policy_loss_sum.append(torch.cat(policy_loss_ep).sum())
        baseline_value_batch.append(torch.cat(value_losses).sum())
        # update gradients
        if (index + 1) % 32 == 0:  # take the average of samples, backprop
            finish_episode(policy_loss_sum, encoder_loss_sum, baseline_value_batch)
            del policy_loss_sum[:], encoder_loss_sum[:], baseline_value_batch[:]
            
        if (index + 1) % 32 == 0:
            stats = calculate_stats_from_cm(cm)
            cm = np.zeros((LABEL_DIM, LABEL_DIM))
            acc = stats["accuracy"]
            recall = stats["recall"]
            precision = stats["precision"]
            f1 = stats["f1"]
            writer.add_scalar("train_accuracy", acc, len(train_loader)*epoch + index)
            writer.add_scalar("train_recall", recall,  len(train_loader)*epoch + index)
            writer.add_scalar("train_precision", precision,  len(train_loader)*epoch + index)
            writer.add_scalar("train_f1", f1,  len(train_loader)*epoch + index)
            pbar.set_description(f"episode: {index + 1}, reread_or_skim_times: {reread_or_skim_times}, accuracy: {acc:.3f}, precision: {precision:.3f}, recall: {recall:.3f}, f1: {f1:.2f}")
            
            """print(f'\n current episode: {index + 1}')
            # log the current position of the text which the agent has gone through
            print('curr_step: ', curr_step)
            # log the sum of the rereading and skimming times
            print(f'current reread_or_skim_times: {reread_or_skim_times}')"""


    print('Epoch time elapsed: %.2f s' % (time.time() - start))
    print('reread_or_skim_times in this epoch:', reread_or_skim_times)
    count_all, count_correct = evaluate_lm(trns_lstm, policy_s, policy_n, policy_c, valid_loader)
    print('Epoch: %s, Accuracy on the validation set: %.2f' % (epoch + 1, count_correct / count_all))
    writer.add_scalar("validation_acccuracy", count_correct / count_all,  len(train_loader)*epoch + index)
    # count_all, count_correct = evaluate(clstm, policy_s, policy_n, policy_c, train_loader)
    # print('Epoch: %s, Accuracy on the training set: %.2f' % (epoch + 1, count_correct / count_all))
    
print('Compute the accuracy on the testing set...')
count_all, count_correct = evaluate_lm(trns_lstm, policy_s, policy_n, policy_c, test_loader)
print('Accuracy on the testing set: %.2f' % (count_correct / count_all))

Training starts...

Epoch 1


episode: 20000, reread_or_skim_times: 87513, accuracy: 0.719, precision: 0.742, recall: 0.727, f1: 0.72: 100%|██████████| 20000/20000 [32:23<00:00, 10.29it/s]


Epoch time elapsed: 1943.52 s
reread_or_skim_times in this epoch: 87513


Evaluating...: 100%|██████████| 5000/5000 [06:02<00:00, 13.80it/s]


Evaluation time elapsed: 362.44 s
Average FLOPs per sample:  7072170
Epoch: 1, Accuracy on the validation set: 0.79

Epoch 2


episode: 20000, reread_or_skim_times: 99996, accuracy: 0.719, precision: 0.742, recall: 0.727, f1: 0.72: 100%|██████████| 20000/20000 [38:00<00:00,  8.77it/s]


Epoch time elapsed: 2280.81 s
reread_or_skim_times in this epoch: 99996


Evaluating...: 100%|██████████| 5000/5000 [05:47<00:00, 14.37it/s]


Evaluation time elapsed: 347.99 s
Average FLOPs per sample:  7072170
Epoch: 2, Accuracy on the validation set: 0.80

Epoch 3


episode: 20000, reread_or_skim_times: 100000, accuracy: 0.875, precision: 0.878, recall: 0.878, f1: 0.87: 100%|██████████| 20000/20000 [36:45<00:00,  9.07it/s]


Epoch time elapsed: 2205.86 s
reread_or_skim_times in this epoch: 100000


Evaluating...: 100%|██████████| 5000/5000 [05:42<00:00, 14.61it/s]


Evaluation time elapsed: 342.15 s
Average FLOPs per sample:  7072170
Epoch: 3, Accuracy on the validation set: 0.79

Epoch 4


episode: 20000, reread_or_skim_times: 99819, accuracy: 0.875, precision: 0.895, recall: 0.882, f1: 0.87: 100%|██████████| 20000/20000 [36:05<00:00,  9.24it/s]


Epoch time elapsed: 2165.53 s
reread_or_skim_times in this epoch: 99819


Evaluating...: 100%|██████████| 5000/5000 [06:01<00:00, 13.84it/s]


Evaluation time elapsed: 361.29 s
Average FLOPs per sample:  7072170
Epoch: 4, Accuracy on the validation set: 0.81

Epoch 5


episode: 20000, reread_or_skim_times: 100000, accuracy: 0.781, precision: 0.781, recall: 0.782, f1: 0.78: 100%|██████████| 20000/20000 [36:07<00:00,  9.23it/s]


Epoch time elapsed: 2167.28 s
reread_or_skim_times in this epoch: 100000


Evaluating...: 100%|██████████| 5000/5000 [05:40<00:00, 14.67it/s]


Evaluation time elapsed: 340.78 s
Average FLOPs per sample:  7072170
Epoch: 5, Accuracy on the validation set: 0.78
Compute the accuracy on the testing set...


Evaluating...: 100%|██████████| 25000/25000 [28:46<00:00, 14.48it/s]

Evaluation time elapsed: 1726.08 s
Average FLOPs per sample:  7072170
Accuracy on the testing set: 0.78





In [20]:
print('Training starts...')

for epoch in range(5, 30):
    print('\nEpoch', epoch+1)
    # log the start time of the epoch
    start = time.time()
    # set the models in training mode
    trns_lstm.train()
    policy_s.train()
    policy_n.train()
    policy_c.train()
    # reset the count of reread_or_skim_times
    reread_or_skim_times = 0
    policy_loss_sum = []
    encoder_loss_sum = []
    baseline_value_batch = []
    pbar = tqdm(train_loader)
    cm = np.zeros((LABEL_DIM, LABEL_DIM))
    for index, (x, xmask, y) in enumerate(pbar):
        label = y.to(device)
        text = x.to(device).view(CHUNCK_SIZE, BATCH_SIZE, CHUNCK_SIZE) # transform 1*400 to 20*1*20
        text_mask = xmask.to(device).view(CHUNCK_SIZE, BATCH_SIZE, CHUNCK_SIZE)
        curr_step = 0  # the position of the current chunk
        h_0 = torch.zeros([NUM_RNN_LAYERS,1,HIDDEN_DIM_LSTM]).to(device)  # run on GPU
        c_0 = torch.zeros([NUM_RNN_LAYERS,1,HIDDEN_DIM_LSTM]).to(device)
        count = 0  # maximum skim/reread time: 5
        baseline_value_ep = []
        saved_log_probs = []  # for the use of policy gradient update
        # collect the computational costs for every time step
        cost_ep = []  
        while curr_step < CHUNCK_SIZE and count < 5: 
            # Loop until a text can be classified or currstep is up to 20 or count reach the maximum i.e. 5.
            # update count
            count += 1
            # pass the input through cnn-lstm and policy s
            text_input = text[curr_step] # text_input 1*20
            text_mask_input = text_mask[curr_step]
            # print(f"input h: {h_0.shape}")
            ht, ct = trns_lstm(text_input, text_mask_input, h_0, c_0)  #ht: NUM_RNN_LAYERS * 1 * HIDDEN_DIM_LSTM
            # separate the value which is the input of value net
            ht_ = ht.clone().detach().requires_grad_(True)
            
            # NUM_RNN_LAYERS * 1 * 128, next input of lstm
            h_0 = ht # .unsqueeze(0)
            c_0 = ct
            
            ht_ = ht_[-1, :, :]# .view(1, ht_.shape[0] * ht_.shape[2]) # ht_: 1, NUM_RNN_LAYERS * HIDDEN_DIM_LSTM
            ht = ht[-1, :, :] # .view(1, ht.shape[0] * ht.shape[2])
            # compute a baseline value for the value network
            bi = value_net(ht_)
            # draw a stop decision
            stop_decision, log_prob_s = sample_policy_s(ht, policy_s)
            stop_decision = stop_decision.item()
            if stop_decision == 1: # classify
                break
            else: 
                reread_or_skim_times += 1
                # draw an action (reread or skip)
                step, log_prob_n = sample_policy_n(ht, policy_n)
                curr_step += int(step)  # reread or skip
                if curr_step < CHUNCK_SIZE and count < 5:
                    # If the code can still execute the next loop, it is not the last time step.
                    cost_ep.append(clstm_cost + s_cost + n_cost)
                    # add the baseline value
                    baseline_value_ep.append(bi)
                    # add the log prob for the current actions
                    saved_log_probs.append(log_prob_s + log_prob_n)
        # draw a predicted label
        output_c = policy_c(ht)
        # cross entrpy loss input shape: input(N, C), target(N)
        loss = criterion(output_c, label)  # positive value
        # draw a predicted label 
        pred_label, log_prob_c = sample_policy_c(output_c)
        # update the confusion matrix
        cm[pred_label][y] += 1
        if stop_decision == 1:
            # add the cost of the last time step
            cost_ep.append(clstm_cost + s_cost + c_cost)
            saved_log_probs.append(log_prob_s + log_prob_c)
        else:
            # At the moment, the probability of drawing a stop decision is 1,
            # so its log probability is zero which can be ignored in th sum.
            saved_log_probs.append(log_prob_c.unsqueeze(0))
        # add the baseline value
        baseline_value_ep.append(bi)
        # add the cross entropy loss
        encoder_loss_sum.append(loss)
        # set reward for the current data sample
        if pred_label.item() == label:
            reward = 1 
        else:
            reward = -1 
        # compute the policy losses and value losses for the current episode
        policy_loss_ep = []
        value_losses = []
        for i, log_prob in enumerate(saved_log_probs):
            # baseline_value_ep[i].item(): updating the policy loss doesn't include the gradient of baseline values
            advantage = reward - baseline_value_ep[i].item()
            policy_loss_ep.append(log_prob * advantage)
            value_losses.append((reward - baseline_value_ep[i]) ** 2)
        policy_loss_sum.append(torch.cat(policy_loss_ep).sum())
        baseline_value_batch.append(torch.cat(value_losses).sum())
        # update gradients
        if (index + 1) % 32 == 0:  # take the average of samples, backprop
            finish_episode(policy_loss_sum, encoder_loss_sum, baseline_value_batch)
            del policy_loss_sum[:], encoder_loss_sum[:], baseline_value_batch[:]
            
        if (index + 1) % 32 == 0:
            stats = calculate_stats_from_cm(cm)
            cm = np.zeros((LABEL_DIM, LABEL_DIM))
            acc = stats["accuracy"]
            recall = stats["recall"]
            precision = stats["precision"]
            f1 = stats["f1"]
            writer.add_scalar("train_accuracy", acc, len(train_loader)*epoch + index)
            writer.add_scalar("train_recall", recall,  len(train_loader)*epoch + index)
            writer.add_scalar("train_precision", precision,  len(train_loader)*epoch + index)
            writer.add_scalar("train_f1", f1,  len(train_loader)*epoch + index)
            pbar.set_description(f"episode: {index + 1}, reread_or_skim_times: {reread_or_skim_times}, accuracy: {acc:.3f}, precision: {precision:.3f}, recall: {recall:.3f}, f1: {f1:.2f}")
            
            """print(f'\n current episode: {index + 1}')
            # log the current position of the text which the agent has gone through
            print('curr_step: ', curr_step)
            # log the sum of the rereading and skimming times
            print(f'current reread_or_skim_times: {reread_or_skim_times}')"""


    print('Epoch time elapsed: %.2f s' % (time.time() - start))
    print('reread_or_skim_times in this epoch:', reread_or_skim_times)
    count_all, count_correct = evaluate_lm(trns_lstm, policy_s, policy_n, policy_c, valid_loader)
    print('Epoch: %s, Accuracy on the validation set: %.3f' % (epoch + 1, count_correct / count_all))
    writer.add_scalar("validation_acccuracy", count_correct / count_all,  len(train_loader)*epoch + index)
    # count_all, count_correct = evaluate(clstm, policy_s, policy_n, policy_c, train_loader)
    # print('Epoch: %s, Accuracy on the training set: %.2f' % (epoch + 1, count_correct / count_all))
    
print('Compute the accuracy on the testing set...')
count_all, count_correct = evaluate_lm(trns_lstm, policy_s, policy_n, policy_c, test_loader)
print('Accuracy on the testing set: %.2f' % (count_correct / count_all))

Training starts...

Epoch 6


episode: 20000, reread_or_skim_times: 100000, accuracy: 0.812, precision: 0.816, recall: 0.816, f1: 0.81: 100%|██████████| 20000/20000 [36:19<00:00,  9.18it/s]


Epoch time elapsed: 2179.48 s
reread_or_skim_times in this epoch: 100000


Evaluating...: 100%|██████████| 5000/5000 [05:54<00:00, 14.09it/s]


Evaluation time elapsed: 354.81 s
Average FLOPs per sample:  7072170
Epoch: 6, Accuracy on the validation set: 0.829

Epoch 7


episode: 20000, reread_or_skim_times: 100000, accuracy: 0.781, precision: 0.841, recall: 0.794, f1: 0.78: 100%|██████████| 20000/20000 [36:27<00:00,  9.14it/s]


Epoch time elapsed: 2187.63 s
reread_or_skim_times in this epoch: 100000


Evaluating...: 100%|██████████| 5000/5000 [05:57<00:00, 13.97it/s]


Evaluation time elapsed: 357.93 s
Average FLOPs per sample:  7072170
Epoch: 7, Accuracy on the validation set: 0.826

Epoch 8


episode: 4800, reread_or_skim_times: 24000, accuracy: 0.812, precision: 0.802, recall: 0.817, f1: 0.81:  24%|██▍       | 4800/20000 [08:43<27:38,  9.16it/s]  


KeyboardInterrupt: 

In [16]:
trns_lstm.distilbert = trns_lstm.distilbert.requires_grad_(False)

In [13]:
# count_all, count_correct = evaluate(clstm, policy_s, policy_n, policy_c, valid_loader)
# print('Epoch: %s, Accuracy on the validation set: %.2f' % (epoch + 1, count_correct / count_all))
count_all, count_correct = evaluate_lm(trns_lstm, policy_s, policy_n, policy_c, train_loader)
print('Epoch: %s, Accuracy on the training set: %.2f' % (epoch + 1, count_correct / count_all))

Evaluating...: 100%|██████████| 20000/20000 [05:23<00:00, 61.92it/s]


Evaluation time elapsed: 323.01 s
Average FLOPs per sample:  2639562


NameError: name 'epoch' is not defined

In [17]:
from datetime import datetime

now = datetime.now()
now_time = str(now.day) + "_" + str(now.month) + "_" + str(now.year) + "_" + str(now.hour) + "_" + str(now.minute)
now_date = str(now.day) + "_" + str(now.month) + "_" + str(now.year)
now_date, now_time

('31_5_2023', '31_5_2023_10_1')

In [18]:
import os

def save_models(time_str, date_str, path: str = "."):
    os.makedirs(f"{path}\\{date_str}", exist_ok=True)
    torch.save(trns_lstm, f"{path}\\{date_str}\\clstm_{time_str}.pth")
    torch.save(policy_s, f"{path}\\{date_str}\\policy_s_{time_str}.pth")
    torch.save(policy_n, f"{path}\\{date_str}\\policy_n_{time_str}.pth")
    torch.save(policy_c, f"{path}\\{date_str}\\policy_c_{time_str}.pth")


In [19]:
save_models(now_time, now_date, "saved_models\\distilbert_lstm")

In [25]:
trns_lstm.eval()
policy_s.eval()
policy_n.eval()
policy_c.eval()
action_logs = []
seen_logs = []
writer = SummaryWriter()
for i, (x, xm, y) in enumerate(valid_loader):
    print(i)
    print(valid_data["text"].iloc[i])
    print(y)
    action_log_batch = []
    seen_batch = []
    label = y.to(device).long() # for cross entropy loss, the long type is required
    text = x.to(device).view(CHUNCK_SIZE, BATCH_SIZE, CHUNCK_SIZE) # transform 1*400 to 20*1*20
    text_mask = xm.to(device).view(CHUNCK_SIZE, BATCH_SIZE, CHUNCK_SIZE)
    curr_step = 0
    h_0 = torch.zeros([NUM_RNN_LAYERS,1,HIDDEN_DIM_LSTM]).to(device)
    c_0 = torch.zeros([NUM_RNN_LAYERS,1,HIDDEN_DIM_LSTM]).to(device)
    count = 0
    while curr_step < 20 and count < 5: # loop until a text can be classified or currstep is up to 20
        count += 1
        # pass the input through cnn-lstm and policy s
        text_input = text[curr_step] # text_input 1*20
        text_input_mask = text_mask[curr_step]
        text_str = train_data["text"].iloc[i]
        seen_batch.append(text_str.split()[curr_step * 20: (curr_step+1)*20])
        ht, ct = trns_lstm(text_input, text_input_mask, h_0, c_0)  # 1 * 128
        # if count == 1 and i == 0:
        #     writer.add_graph(clstm, [text_input, h_0, c_0], verbose=True)
        h_0 = ht #.unsqueeze(0) # NUM_RNN_LAYERS * 1 * LSTM_HIDDEN_DIM, next input of lstm
        c_0 = ct

        ht = ht[-1, :, :] # .view(1, ht.shape[0] * ht.shape[2])
        # ht_ = ht.view(1, ht.shape[0] * ht.shape[2])
        # draw a stop decision
        stop_decision, log_prob_s = sample_policy_s(ht, policy_s)
        # if count == 1 and i == 1:
        #     writer.add_graph(policy_s, ht)
        stop_decision = stop_decision.item()
        if stop_decision == 1: # classify
            break
        else:
            # draw an action (reread or skip)
            step, log_prob_n = sample_policy_n(ht, policy_n)
            # if count == 1 and i == 2:
            #     writer.add_graph(policy_n, ht)
            curr_step += int(step)  # reread or skip
            action_log_batch.append({"skip/reread": step})
    # draw a predicted label
    output_c = policy_c(ht)
    if i == 3:
        writer.add_graph(policy_c, ht, verbose=True)
        
    # draw a predicted label 
    pred_label, log_prob_c = sample_policy_c(output_c)
    action_log_batch.append({"prediction": pred_label, "real": label})
    if pred_label.item() == label:
        count_correct += 1
    count_all += 1
    action_logs.append(action_log_batch)
    seen_logs.append(seen_batch)
    if i == 10:
        break

writer.close()

0
Dumb is as dumb does in this thoroughly uninteresting supposed black comedy Essentially what starts out as Chris Klein trying to maintain low profile eventually morphs into an uninspired version of The Three Amigos only without any laughs In order for black comedy to work it must be outrageous which Play Dead is not In order for black comedy to work it cannot be mean spirited which Play Dead is What Play Dead really is is town full of nut jobs Fred Dunst does however do pretty fair imitation of Billy Bob Thornton character from Simple Plan while Jake Busey does pretty fair imitation of well Jake Busey MERK
tensor([0])
1
I dug out from my garage some old musicals and this is another one of my favorites It was written by Jay Alan Lerner and directed by Vincent Minelli It won two Academy Awards for Best Picture of and Best Screenplay The story of an American painter in Paris who tries to make it big Nina Foch is sophisticated lady of means and is very interested in helping him but soon 

  if a.grad is not None:


graph(%self.1 : __torch__.networks.Policy_C,
      %ht : Float(1, 512, strides=[512, 1], requires_grad=1, device=cuda:0)):
  %fc2 : __torch__.torch.nn.modules.linear.___torch_mangle_0.Linear = prim::GetAttr[name="fc2"](%self.1)
  %relu : __torch__.torch.nn.modules.activation.ReLU = prim::GetAttr[name="relu"](%self.1)
  %dropout : __torch__.torch.nn.modules.dropout.Dropout = prim::GetAttr[name="dropout"](%self.1)
  %fc1 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="fc1"](%self.1)
  %bias.1 : Tensor = prim::GetAttr[name="bias"](%fc1)
  %weight.1 : Tensor = prim::GetAttr[name="weight"](%fc1)
  %input.1 : Float(1, 512, strides=[512, 1], requires_grad=1, device=cuda:0) = aten::linear(%ht, %weight.1, %bias.1), scope: __module.fc1 # c:\Users\mrbal\anaconda3\envs\blaze\lib\site-packages\torch\nn\modules\linear.py:114:0
  %49 : bool = prim::Constant[value=0](), scope: __module.dropout # c:\Users\mrbal\anaconda3\envs\blaze\lib\site-packages\torch\nn\functional.py:1252:0
  %50 

In [27]:
action_logs[0]

[{'skip/reread': 1},
 {'skip/reread': 1},
 {'skip/reread': 1},
 {'skip/reread': 1},
 {'skip/reread': 1},
 {'prediction': tensor([0], device='cuda:0'),
  'real': tensor([0], device='cuda:0')}]

In [91]:
ix = 3
action_logs[ix], " ".join(train_data["text"].iloc[ix].split()[0:400])

([{'skip/reread': 2},
  {'skip/reread': 2},
  {'skip/reread': 2},
  {'skip/reread': 2},
  {'skip/reread': 2},
  {'prediction': tensor([1], device='cuda:0'),
   'real': tensor([1], device='cuda:0')}],
 'Can Scarcely Imagine Better Movie Than This br br Hey before you all go Chick Flick on me am very Large Strong Masculine Macho Man who happens to think this was one of the better movies of the last years br br The acting was Superb and the Story was Marvelous This is wonderful medicine for the heart and soul The Acting could not have been better nor the movie better cast br br have known for Good while that Mercedes Ruehl along with Holly Hunter Joan Plowright Dame Edith Evans Sissy Spacek Judi Dench is among the greatest actresses ever to appear on film And of course Cloris Leachman also in this film in my view may in fact exceed them all in the shear magnum of her talent and varied roles she has appeared in over the years At any rate this was an Amazing cast This film was like book tha

In [100]:
" ".join(seen_logs[ix][5])

IndexError: list index out of range

In [46]:
valid_data.iloc[0]

text                        Dumb is as dumb does in this thoroughly uninte...
label                                                                       0
label_str                                                                 neg
text_bert_input_ids         [101, 12873, 2003, 2004, 12873, 2515, 1999, 20...
text_bert_attention_mask    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
Name: 6868, dtype: object

In [103]:
count_correct / count_all

0.9831278390655419