In [1]:
import torch
from torch import optim
from torch import nn
import torch.nn.functional as F
from torch.distributions import Bernoulli, Categorical
from torchtext import datasets
import os
import time
import numpy as np 
import random
import argparse

from networks import CNN_LSTM, Policy_C, Policy_N, Policy_S, ValueNetwork
from utils.utils import sample_policy_c, sample_policy_n, sample_policy_s, evaluate, compute_policy_value_losses
from utils.utils import cnn_cost, clstm_cost, c_cost, n_cost, s_cost, openDfFromPickle, calculate_stats_from_cm

In [2]:
print('Reading data...')
train_data = openDfFromPickle("C:\\Users\\mrbal\\Documents\\NLP\\RL\\basic_reinforcement_learning\\NLP_datasets\\imdb\\imdb_train_distilbert-base-uncased.pkl")
valid_data = openDfFromPickle("C:\\Users\\mrbal\\Documents\\NLP\\RL\\basic_reinforcement_learning\\NLP_datasets\\imdb\\imdb_val_distilbert-base-uncased.pkl")
test_data = openDfFromPickle("C:\\Users\\mrbal\\Documents\\NLP\\RL\\basic_reinforcement_learning\\NLP_datasets\\imdb\\imdb_test_distilbert-base-uncased.pkl")
print(f'Number of training examples: {len(train_data)}')
print(f'Number of validation examples: {len(valid_data)}')
print(f'Number of testing examples: {len(test_data)}')


Reading data...
Number of training examples: 20000
Number of validation examples: 5000
Number of testing examples: 25000


In [3]:
train_data

Unnamed: 0,text,label,label_str,text_bert_input_ids,text_bert_attention_mask
23311,I borrowed this movie despite its extremely lo...,1,pos,"[101, 1045, 11780, 2023, 3185, 2750, 2049, 518...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
23623,After the unexpected accident that killed an i...,1,pos,"[101, 2044, 1996, 9223, 4926, 2008, 2730, 2019...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
1020,On the summer blockbuster hit BASEketball This...,0,neg,"[101, 2006, 1996, 2621, 27858, 2718, 2918, 348...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
12645,Can Scarcely Imagine Better Movie Than This br...,1,pos,"[101, 2064, 20071, 5674, 2488, 3185, 2084, 202...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
1533,A still famous but decadent actor Morgan Freem...,0,neg,"[101, 1037, 2145, 3297, 2021, 5476, 3372, 3364...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
...,...,...,...,...,...
21575,My discovery of the cinema of Jan Svankmajer o...,1,pos,"[101, 2026, 5456, 1997, 1996, 5988, 1997, 5553...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
5390,The story is similar to ET an extraterrestrial...,0,neg,"[101, 1996, 2466, 2003, 2714, 2000, 3802, 2019...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
860,I have read the novel Reaper of Ben Mezrich fe...,0,neg,"[101, 1045, 2031, 3191, 1996, 3117, 19559, 199...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
15795,Went to see this finnish film and ve got to sa...,1,pos,"[101, 2253, 2000, 2156, 2023, 6983, 2143, 1998...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."


In [4]:
train_data["text"].iloc[1]

'After the unexpected accident that killed an inexperienced climber Michelle Joyner Eight months has passed The Rocky Mountain Rescue receive distress call set by brilliant terrorist mastermind Eric Quaien John Lithgow Quaien has lost three large cases that has millions of dollars inside Two experienced climbers Walker Sylvester Stallone and Tucker Micheal Rooker and helicopter pilot Janine Turner are to the rescue but they are set by trap by Quaien and his men Now the two climbers and pilot are forced to play deadly game of hide and seek While Quaien is trying to find the millions of dollars and he kidnapped Tucker to find the money Once Tucker finds the money Tucker will be dead Against explosive firepower bitter cold and dizzying heights Walker must outwit Quaien for survival br br Directed by Renny Harlin Driven Mindhunters Nightmare on Elm Street The Dream Master made an entertaining non stop action picture This film is spectacular exciting visually exciting action picture with pl

In [5]:
# split the datasets into batches
BATCH_SIZE = 1  # the batch size for a dataset iterator
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'device: {device}')

device: cuda


In [6]:
xtrain = torch.from_numpy(np.stack(train_data["text_bert_input_ids"].values))[:, 0:400]
ytrain = torch.from_numpy(train_data["label"].values)
xvalid = torch.from_numpy(np.stack(valid_data["text_bert_input_ids"].values))[:, 0:400]
yvalid = torch.from_numpy(valid_data["label"].values)
xtest = torch.from_numpy(np.stack(test_data["text_bert_input_ids"].values))[:, 0:400]
ytest = torch.from_numpy(test_data["label"].values)

print(xtrain.shape, ytrain.shape)
print(xvalid.shape, yvalid.shape)
print(xtest.shape, ytest.shape)

torch.Size([20000, 400]) torch.Size([20000])
torch.Size([5000, 400]) torch.Size([5000])
torch.Size([25000, 400]) torch.Size([25000])


In [7]:
from torch.utils.data import DataLoader, TensorDataset
train_loader = DataLoader(TensorDataset(xtrain, ytrain), batch_size=BATCH_SIZE)
valid_loader = DataLoader(TensorDataset(xvalid, yvalid), batch_size=BATCH_SIZE)
test_loader = DataLoader(TensorDataset(xtest, ytest), batch_size=BATCH_SIZE)

In [8]:
random.seed(2023)
np.random.seed(2023)
torch.manual_seed(2023)
torch.cuda.manual_seed(2023)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [9]:
# set up parameters
INPUT_DIM = 30522
EMBEDDING_DIM = 100
NUM_RNN_LAYERS = 1
KER_SIZE = 5
HIDDEN_DIM_LSTM = 128 
HIDDEN_DIM_DENSE = HIDDEN_DIM_LSTM * NUM_RNN_LAYERS
OUTPUT_DIM = 1
CHUNCK_SIZE = 20
MAX_K = 4  # the output dimension for step size 0, 1, 2, 3
LABEL_DIM = 2
N_FILTERS = 128
BATCH_SIZE = 1
gamma = 0.99
alpha = 0.2
learning_rate = 0.001

In [10]:
# set up the criterion
criterion = nn.CrossEntropyLoss().to(device)
# set up models
clstm = CNN_LSTM(INPUT_DIM, EMBEDDING_DIM, KER_SIZE, N_FILTERS, NUM_RNN_LAYERS, HIDDEN_DIM_LSTM).to(device)
print(clstm)
policy_s = Policy_S(HIDDEN_DIM_DENSE, HIDDEN_DIM_DENSE, OUTPUT_DIM).to(device)
policy_n = Policy_N(HIDDEN_DIM_DENSE, HIDDEN_DIM_DENSE, MAX_K).to(device)
policy_c = Policy_C(HIDDEN_DIM_DENSE, HIDDEN_DIM_DENSE, LABEL_DIM).to(device)
value_net = ValueNetwork(HIDDEN_DIM_DENSE, HIDDEN_DIM_DENSE, OUTPUT_DIM).to(device)


CNN_LSTM(
  (embedding): Embedding(30522, 100)
  (conv): Conv2d(1, 128, kernel_size=(5, 100), stride=(1, 1))
  (lstm): LSTM(128, 128)
  (dropout): Dropout(p=0.2, inplace=False)
  (relu): ReLU()
)


In [11]:
# set up optimiser
params_pg = list(policy_s.parameters()) + list(policy_c.parameters()) + list(policy_n.parameters())
optim_loss = optim.Adam(clstm.parameters(), lr=learning_rate)
optim_policy = optim.Adam(params_pg, lr=learning_rate)
optim_value = optim.Adam(value_net.parameters(), lr=learning_rate)

In [12]:
def finish_episode(policy_loss_sum, encoder_loss_sum, baseline_value_batch):
    '''
    Called when a data sample has been processed.
    '''
    baseline_value_sum = torch.stack(baseline_value_batch).sum()
    policy_loss = torch.stack(policy_loss_sum).mean()
    encoder_loss = torch.stack(encoder_loss_sum).mean()
    objective_loss = encoder_loss - policy_loss + baseline_value_sum
    # set gradient to zero
    optim_loss.zero_grad()
    optim_policy.zero_grad()
    optim_value.zero_grad()
    # back propagation
    objective_loss.backward()
    # gradient update
    optim_loss.step()
    optim_policy.step()
    optim_value.step()

In [15]:
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [16]:
print('Training starts...')

for epoch in range(5):
    print('\nEpoch', epoch+1)
    # log the start time of the epoch
    start = time.time()
    # set the models in training mode
    clstm.train()
    policy_s.train()
    policy_n.train()
    policy_c.train()
    # reset the count of reread_or_skim_times
    reread_or_skim_times = 0
    policy_loss_sum = []
    encoder_loss_sum = []
    baseline_value_batch = []
    pbar = tqdm(train_loader)
    cm = np.zeros((LABEL_DIM, LABEL_DIM))
    for index, (x, y) in enumerate(pbar):
        label = y.to(device)
        text = x.to(device).view(CHUNCK_SIZE, BATCH_SIZE, CHUNCK_SIZE) # transform 1*400 to 20*1*20
        curr_step = 0  # the position of the current chunk
        h_0 = torch.zeros([1,1,128]).to(device)  # run on GPU
        c_0 = torch.zeros([1,1,128]).to(device)
        count = 0  # maximum skim/reread time: 5
        baseline_value_ep = []
        saved_log_probs = []  # for the use of policy gradient update
        # collect the computational costs for every time step
        cost_ep = []  
        while curr_step < CHUNCK_SIZE and count < 5: 
            # Loop until a text can be classified or currstep is up to 20 or count reach the maximum i.e. 5.
            # update count
            count += 1
            # pass the input through cnn-lstm and policy s
            text_input = text[curr_step] # text_input 1*20
            # print(f"input h: {h_0.shape}")
            ht, ct = clstm(text_input, h_0, c_0)  #ht: NUM_RNN_LAYERS * 1 * HIDDEN_DIM_LSTM
            # separate the value which is the input of value net
            ht_ = ht.clone().detach().requires_grad_(True)
            # ht_ = ht_.view(1, ht_.shape[0] * ht_.shape[2]) # ht_: 1, NUM_RNN_LAYERS * HIDDEN_DIM_LSTM
            # compute a baseline value for the value network
            bi = value_net(ht_)
            # NUM_RNN_LAYERS * 1 * 128, next input of lstm
            h_0 = ht.unsqueeze(0)
            c_0 = ct
            # draw a stop decision
            stop_decision, log_prob_s = sample_policy_s(ht, policy_s)
            stop_decision = stop_decision.item()
            if stop_decision == 1: # classify
                break
            else: 
                reread_or_skim_times += 1
                # draw an action (reread or skip)
                step, log_prob_n = sample_policy_n(ht, policy_n)
                curr_step += int(step)  # reread or skip
                if curr_step < CHUNCK_SIZE and count < 5:
                    # If the code can still execute the next loop, it is not the last time step.
                    cost_ep.append(clstm_cost + s_cost + n_cost)
                    # add the baseline value
                    baseline_value_ep.append(bi)
                    # add the log prob for the current actions
                    saved_log_probs.append(log_prob_s + log_prob_n)
        # draw a predicted label
        output_c = policy_c(ht)
        # cross entrpy loss input shape: input(N, C), target(N)
        loss = criterion(output_c, label)  # positive value
        # draw a predicted label 
        pred_label, log_prob_c = sample_policy_c(output_c)
        # update the confusion matrix
        cm[pred_label][y] += 1
        if stop_decision == 1:
            # add the cost of the last time step
            cost_ep.append(clstm_cost + s_cost + c_cost)
            saved_log_probs.append(log_prob_s + log_prob_c)
        else:
            # add the cost of the last time step
            cost_ep.append(clstm_cost + s_cost + c_cost + n_cost)
            # At the moment, the probability of drawing a stop decision is 1,
            # so its log probability is zero which can be ignored in th sum.
            saved_log_probs.append(log_prob_c.unsqueeze(0))
        # add the baseline value
        baseline_value_ep.append(bi)
        # add the cross entropy loss
        encoder_loss_sum.append(loss)
        # compute the policy losses and value losses for the current episode
        policy_loss_ep, value_losses = compute_policy_value_losses(cost_ep, loss, saved_log_probs, baseline_value_ep, alpha, gamma)
        policy_loss_sum.append(torch.cat(policy_loss_ep).sum())
        baseline_value_batch.append(torch.cat(value_losses).sum())
        # update gradients
        if (index + 1) % 32 == 0:  # take the average of samples, backprop
            finish_episode(policy_loss_sum, encoder_loss_sum, baseline_value_batch)
            del policy_loss_sum[:], encoder_loss_sum[:], baseline_value_batch[:]
            
        if (index + 1) % 32 == 0:
            stats = calculate_stats_from_cm(cm)
            cm = np.zeros((LABEL_DIM, LABEL_DIM))
            acc = stats["accuracy"]
            recall = stats["recall"]
            precision = stats["precision"]
            f1 = stats["f1"]
            writer.add_scalar("train_accuracy", acc, len(train_loader)*epoch + index)
            writer.add_scalar("train_recall", recall,  len(train_loader)*epoch + index)
            writer.add_scalar("train_precision", precision,  len(train_loader)*epoch + index)
            writer.add_scalar("train_f1", f1,  len(train_loader)*epoch + index)
            pbar.set_description(f"episode: {index + 1}, reread_or_skim_times: {reread_or_skim_times}, accuracy: {acc:.3f}, precision: {precision:.3f}, recall: {recall:.3f}, f1: {f1:.2f}")
            
            """print(f'\n current episode: {index + 1}')
            # log the current position of the text which the agent has gone through
            print('curr_step: ', curr_step)
            # log the sum of the rereading and skimming times
            print(f'current reread_or_skim_times: {reread_or_skim_times}')"""


    print('Epoch time elapsed: %.2f s' % (time.time() - start))
    print('reread_or_skim_times in this epoch:', reread_or_skim_times)
    count_all, count_correct = evaluate(clstm, policy_s, policy_n, policy_c, valid_loader)
    print('Epoch: %s, Accuracy on the validation set: %.2f' % (epoch + 1, count_correct / count_all))
    writer.add_scalar("validation_acccuracy", count_correct / count_all,  len(train_loader)*epoch + index)
    # count_all, count_correct = evaluate(clstm, policy_s, policy_n, policy_c, train_loader)
    # print('Epoch: %s, Accuracy on the training set: %.2f' % (epoch + 1, count_correct / count_all))
    
print('Compute the accuracy on the testing set...')
count_all, count_correct = evaluate(clstm, policy_s, policy_n, policy_c, test_loader)
print('Accuracy on the testing set: %.2f' % (count_correct / count_all))

Training starts...

Epoch 1


episode: 20000, reread_or_skim_times: 99996, accuracy: 0.594, precision: 0.623, recall: 0.606, f1: 0.58: 100%|██████████| 20000/20000 [14:50<00:00, 22.46it/s]


Epoch time elapsed: 890.50 s
reread_or_skim_times in this epoch: 99996
Evaluation time elapsed: 85.76 s
Average FLOPs per sample:  7072170
Epoch: 1, Accuracy on the validation set: 0.50

Epoch 2


episode: 20000, reread_or_skim_times: 99990, accuracy: 0.500, precision: 0.511, recall: 0.510, f1: 0.49: 100%|██████████| 20000/20000 [14:54<00:00, 22.36it/s]


Epoch time elapsed: 894.66 s
reread_or_skim_times in this epoch: 99990
Evaluation time elapsed: 88.50 s
Average FLOPs per sample:  7072170
Epoch: 2, Accuracy on the validation set: 0.52

Epoch 3


episode: 20000, reread_or_skim_times: 100000, accuracy: 0.812, precision: 0.830, recall: 0.820, f1: 0.81: 100%|██████████| 20000/20000 [16:37<00:00, 20.06it/s]


Epoch time elapsed: 997.09 s
reread_or_skim_times in this epoch: 100000
Evaluation time elapsed: 89.26 s
Average FLOPs per sample:  7072170
Epoch: 3, Accuracy on the validation set: 0.66

Epoch 4


episode: 20000, reread_or_skim_times: 99989, accuracy: 0.906, precision: 0.906, recall: 0.908, f1: 0.91: 100%|██████████| 20000/20000 [14:42<00:00, 22.66it/s]


Epoch time elapsed: 882.76 s
reread_or_skim_times in this epoch: 99989
Evaluation time elapsed: 82.24 s
Average FLOPs per sample:  7072170
Epoch: 4, Accuracy on the validation set: 0.70

Epoch 5


episode: 20000, reread_or_skim_times: 100000, accuracy: 0.781, precision: 0.790, recall: 0.786, f1: 0.78: 100%|██████████| 20000/20000 [14:42<00:00, 22.68it/s]


Epoch time elapsed: 882.01 s
reread_or_skim_times in this epoch: 100000
Evaluation time elapsed: 84.74 s
Average FLOPs per sample:  7072170
Epoch: 5, Accuracy on the validation set: 0.73
Compute the accuracy on the testing set...
Evaluation time elapsed: 410.76 s
Average FLOPs per sample:  7072170
Accuracy on the testing set: 0.72


In [17]:
print('Training starts...')

for epoch in range(15, 25):
    print('\nEpoch', epoch+1)
    # log the start time of the epoch
    start = time.time()
    # set the models in training mode
    clstm.train()
    policy_s.train()
    policy_n.train()
    policy_c.train()
    # reset the count of reread_or_skim_times
    reread_or_skim_times = 0
    policy_loss_sum = []
    encoder_loss_sum = []
    baseline_value_batch = []
    pbar = tqdm(train_loader)
    cm = np.zeros((LABEL_DIM, LABEL_DIM))
    for index, (x, y) in enumerate(pbar):
        label = y.to(device)
        text = x.to(device).view(CHUNCK_SIZE, BATCH_SIZE, CHUNCK_SIZE) # transform 1*400 to 20*1*20
        curr_step = 0  # the position of the current chunk
        h_0 = torch.zeros([1,1,128]).to(device)  # run on GPU
        c_0 = torch.zeros([1,1,128]).to(device)
        count = 0  # maximum skim/reread time: 5
        baseline_value_ep = []
        saved_log_probs = []  # for the use of policy gradient update
        # collect the computational costs for every time step
        cost_ep = []  
        while curr_step < CHUNCK_SIZE and count < 5: 
            # Loop until a text can be classified or currstep is up to 20 or count reach the maximum i.e. 5.
            # update count
            count += 1
            # pass the input through cnn-lstm and policy s
            text_input = text[curr_step] # text_input 1*20
            # print(f"input h: {h_0.shape}")
            ht, ct = clstm(text_input, h_0, c_0)  #ht: NUM_RNN_LAYERS * 1 * HIDDEN_DIM_LSTM
            # separate the value which is the input of value net
            ht_ = ht.clone().detach().requires_grad_(True)
            # ht_ = ht_.view(1, ht_.shape[0] * ht_.shape[2]) # ht_: 1, NUM_RNN_LAYERS * HIDDEN_DIM_LSTM
            # compute a baseline value for the value network
            bi = value_net(ht_)
            # NUM_RNN_LAYERS * 1 * 128, next input of lstm
            h_0 = ht.unsqueeze(0)
            c_0 = ct
            # draw a stop decision
            stop_decision, log_prob_s = sample_policy_s(ht, policy_s)
            stop_decision = stop_decision.item()
            if stop_decision == 1: # classify
                break
            else: 
                reread_or_skim_times += 1
                # draw an action (reread or skip)
                step, log_prob_n = sample_policy_n(ht, policy_n)
                curr_step += int(step)  # reread or skip
                if curr_step < CHUNCK_SIZE and count < 5:
                    # If the code can still execute the next loop, it is not the last time step.
                    cost_ep.append(clstm_cost + s_cost + n_cost)
                    # add the baseline value
                    baseline_value_ep.append(bi)
                    # add the log prob for the current actions
                    saved_log_probs.append(log_prob_s + log_prob_n)
        # draw a predicted label
        output_c = policy_c(ht)
        # cross entrpy loss input shape: input(N, C), target(N)
        loss = criterion(output_c, label)  # positive value
        # draw a predicted label 
        pred_label, log_prob_c = sample_policy_c(output_c)
        # update the confusion matrix
        cm[pred_label][y] += 1
        if stop_decision == 1:
            # add the cost of the last time step
            cost_ep.append(clstm_cost + s_cost + c_cost)
            saved_log_probs.append(log_prob_s + log_prob_c)
        else:
            # add the cost of the last time step
            cost_ep.append(clstm_cost + s_cost + c_cost + n_cost)
            # At the moment, the probability of drawing a stop decision is 1,
            # so its log probability is zero which can be ignored in th sum.
            saved_log_probs.append(log_prob_c.unsqueeze(0))
        # add the baseline value
        baseline_value_ep.append(bi)
        # add the cross entropy loss
        encoder_loss_sum.append(loss)
        # compute the policy losses and value losses for the current episode
        policy_loss_ep, value_losses = compute_policy_value_losses(cost_ep, loss, saved_log_probs, baseline_value_ep, alpha, gamma)
        policy_loss_sum.append(torch.cat(policy_loss_ep).sum())
        baseline_value_batch.append(torch.cat(value_losses).sum())
        # update gradients
        if (index + 1) % 32 == 0:  # take the average of samples, backprop
            finish_episode(policy_loss_sum, encoder_loss_sum, baseline_value_batch)
            del policy_loss_sum[:], encoder_loss_sum[:], baseline_value_batch[:]
            
        if (index + 1) % 32 == 0:
            stats = calculate_stats_from_cm(cm)
            cm = np.zeros((LABEL_DIM, LABEL_DIM))
            acc = stats["accuracy"]
            recall = stats["recall"]
            precision = stats["precision"]
            f1 = stats["f1"]
            writer.add_scalar("train_accuracy", acc, len(train_loader)*epoch + index)
            writer.add_scalar("train_recall", recall,  len(train_loader)*epoch + index)
            writer.add_scalar("train_precision", precision,  len(train_loader)*epoch + index)
            writer.add_scalar("train_f1", f1,  len(train_loader)*epoch + index)
            pbar.set_description(f"episode: {index + 1}, reread_or_skim_times: {reread_or_skim_times}, accuracy: {acc:.3f}, precision: {precision:.3f}, recall: {recall:.3f}, f1: {f1:.2f}")
            
            """print(f'\n current episode: {index + 1}')
            # log the current position of the text which the agent has gone through
            print('curr_step: ', curr_step)
            # log the sum of the rereading and skimming times
            print(f'current reread_or_skim_times: {reread_or_skim_times}')"""


    print('Epoch time elapsed: %.2f s' % (time.time() - start))
    print('reread_or_skim_times in this epoch:', reread_or_skim_times)
    count_all, count_correct = evaluate(clstm, policy_s, policy_n, policy_c, valid_loader)
    print('Epoch: %s, Accuracy on the validation set: %.2f' % (epoch + 1, count_correct / count_all))
    writer.add_scalar("validation_acccuracy", count_correct / count_all,  len(train_loader)*epoch + index)
    # count_all, count_correct = evaluate(clstm, policy_s, policy_n, policy_c, train_loader)
    # print('Epoch: %s, Accuracy on the training set: %.2f' % (epoch + 1, count_correct / count_all))
    
print('Compute the accuracy on the testing set...')
count_all, count_correct = evaluate(clstm, policy_s, policy_n, policy_c, test_loader)
print('Accuracy on the testing set: %.2f' % (count_correct / count_all))

Training starts...

Epoch 6


episode: 20000, reread_or_skim_times: 100000, accuracy: 0.844, precision: 0.853, recall: 0.849, f1: 0.84: 100%|██████████| 20000/20000 [14:35<00:00, 22.84it/s]


Epoch time elapsed: 875.56 s
reread_or_skim_times in this epoch: 100000
Evaluation time elapsed: 83.85 s
Average FLOPs per sample:  7072170
Epoch: 6, Accuracy on the validation set: 0.75

Epoch 7


episode: 20000, reread_or_skim_times: 100000, accuracy: 0.781, precision: 0.790, recall: 0.786, f1: 0.78: 100%|██████████| 20000/20000 [14:37<00:00, 22.78it/s]


Epoch time elapsed: 877.90 s
reread_or_skim_times in this epoch: 100000
Evaluation time elapsed: 83.29 s
Average FLOPs per sample:  7072170
Epoch: 7, Accuracy on the validation set: 0.75

Epoch 8


episode: 20000, reread_or_skim_times: 100000, accuracy: 0.969, precision: 0.969, recall: 0.971, f1: 0.97: 100%|██████████| 20000/20000 [14:43<00:00, 22.64it/s]


Epoch time elapsed: 883.37 s
reread_or_skim_times in this epoch: 100000
Evaluation time elapsed: 81.30 s
Average FLOPs per sample:  7072170
Epoch: 8, Accuracy on the validation set: 0.76

Epoch 9


episode: 20000, reread_or_skim_times: 100000, accuracy: 0.969, precision: 0.972, recall: 0.967, f1: 0.97: 100%|██████████| 20000/20000 [14:37<00:00, 22.79it/s]


Epoch time elapsed: 877.39 s
reread_or_skim_times in this epoch: 100000
Evaluation time elapsed: 84.78 s
Average FLOPs per sample:  7072170
Epoch: 9, Accuracy on the validation set: 0.76

Epoch 10


episode: 20000, reread_or_skim_times: 100000, accuracy: 0.906, precision: 0.917, recall: 0.912, f1: 0.91: 100%|██████████| 20000/20000 [14:40<00:00, 22.72it/s]


Epoch time elapsed: 880.20 s
reread_or_skim_times in this epoch: 100000
Evaluation time elapsed: 83.43 s
Average FLOPs per sample:  7072170
Epoch: 10, Accuracy on the validation set: 0.77

Epoch 11


episode: 20000, reread_or_skim_times: 100000, accuracy: 1.000, precision: 1.000, recall: 1.000, f1: 1.00: 100%|██████████| 20000/20000 [14:37<00:00, 22.79it/s]


Epoch time elapsed: 877.47 s
reread_or_skim_times in this epoch: 100000
Evaluation time elapsed: 81.46 s
Average FLOPs per sample:  7072170
Epoch: 11, Accuracy on the validation set: 0.77

Epoch 12


episode: 20000, reread_or_skim_times: 100000, accuracy: 0.969, precision: 0.972, recall: 0.967, f1: 0.97: 100%|██████████| 20000/20000 [14:36<00:00, 22.83it/s]


Epoch time elapsed: 876.09 s
reread_or_skim_times in this epoch: 100000
Evaluation time elapsed: 83.87 s
Average FLOPs per sample:  7072170
Epoch: 12, Accuracy on the validation set: 0.78

Epoch 13


episode: 20000, reread_or_skim_times: 100000, accuracy: 0.969, precision: 0.969, recall: 0.971, f1: 0.97: 100%|██████████| 20000/20000 [14:37<00:00, 22.79it/s]


Epoch time elapsed: 877.66 s
reread_or_skim_times in this epoch: 100000
Evaluation time elapsed: 82.93 s
Average FLOPs per sample:  7072170
Epoch: 13, Accuracy on the validation set: 0.77

Epoch 14


episode: 20000, reread_or_skim_times: 100000, accuracy: 1.000, precision: 1.000, recall: 1.000, f1: 1.00: 100%|██████████| 20000/20000 [14:37<00:00, 22.79it/s]


Epoch time elapsed: 877.75 s
reread_or_skim_times in this epoch: 100000
Evaluation time elapsed: 80.54 s
Average FLOPs per sample:  7072170
Epoch: 14, Accuracy on the validation set: 0.78

Epoch 15


episode: 20000, reread_or_skim_times: 100000, accuracy: 0.938, precision: 0.937, recall: 0.937, f1: 0.94: 100%|██████████| 20000/20000 [14:37<00:00, 22.79it/s]


Epoch time elapsed: 877.72 s
reread_or_skim_times in this epoch: 100000
Evaluation time elapsed: 84.21 s
Average FLOPs per sample:  7072170
Epoch: 15, Accuracy on the validation set: 0.78
Compute the accuracy on the testing set...
Evaluation time elapsed: 412.45 s
Average FLOPs per sample:  7072170
Accuracy on the testing set: 0.76


In [56]:
# count_all, count_correct = evaluate(clstm, policy_s, policy_n, policy_c, valid_loader)
# print('Epoch: %s, Accuracy on the validation set: %.2f' % (epoch + 1, count_correct / count_all))
count_all, count_correct = evaluate(clstm, policy_s, policy_n, policy_c, train_loader)
print('Epoch: %s, Accuracy on the training set: %.2f' % (epoch + 1, count_correct / count_all))

Evaluation time elapsed: 318.69 s
Average FLOPs per sample:  7072170
Epoch: 15, Accuracy on the training set: 0.98


In [115]:
from datetime import datetime

now = datetime.now()
now_time = str(now.day) + "_" + str(now.month) + "_" + str(now.year) + "_" + str(now.hour) + "_" + str(now.minute)
now_date = str(now.day) + "_" + str(now.month) + "_" + str(now.year)
now_date, now_time

('15_4_2023', '15_4_2023_14_51')

In [114]:
import os

def save_models(time_str, date_str, path: str = "."):
    os.makedirs(f"saved_models\\{date_str}", exist_ok=True)
    torch.save(clstm, f"{date_str}\\clstm_{time_str}.pth")
    torch.save(policy_s, f"{date_str}\\policy_s_{time_str}.pth")
    torch.save(policy_n, f"{date_str}\\policy_n_{time_str}.pth")
    torch.save(policy_c, f"{date_str}\\policy_c_{time_str}.pth")


In [33]:
save_models(now_time, now_date, "saved_models")

In [101]:
clstm.eval()
policy_s.eval()
policy_n.eval()
policy_c.eval()
action_logs = []
seen_logs = []
writer = SummaryWriter()
for i, (x, y) in enumerate(valid_loader):
    print(i)
    print(valid_data["text"].iloc[i])
    print(y)
    action_log_batch = []
    seen_batch = []
    label = y.to(device).long() # for cross entropy loss, the long type is required
    text = x.to(device).view(CHUNCK_SIZE, BATCH_SIZE, CHUNCK_SIZE) # transform 1*400 to 20*1*20
    curr_step = 0
    n_rnn_layers = clstm.n_rnn_layers
    lstm_hidden_dim = clstm.lstm_hidden_dim
    h_0 = torch.zeros([n_rnn_layers,1,lstm_hidden_dim]).to(device)
    c_0 = torch.zeros([n_rnn_layers,1,lstm_hidden_dim]).to(device)
    count = 0
    while curr_step < 20 and count < 5: # loop until a text can be classified or currstep is up to 20
        count += 1
        # pass the input through cnn-lstm and policy s
        text_input = text[curr_step] # text_input 1*20
        text_str = train_data["text"].iloc[i]
        seen_batch.append(text_str.split()[curr_step * 20: (curr_step+1)*20])
        ht, ct = clstm(text_input, h_0, c_0)  # 1 * 128
        # if count == 1 and i == 0:
        #     writer.add_graph(clstm, [text_input, h_0, c_0], verbose=True)
        h_0 = ht.unsqueeze(0) # NUM_RNN_LAYERS * 1 * LSTM_HIDDEN_DIM, next input of lstm
        c_0 = ct
        # ht_ = ht.view(1, ht.shape[0] * ht.shape[2])
        # draw a stop decision
        stop_decision, log_prob_s = sample_policy_s(ht, policy_s)
        # if count == 1 and i == 1:
        #     writer.add_graph(policy_s, ht)
        stop_decision = stop_decision.item()
        if stop_decision == 1: # classify
            break
        else:
            # draw an action (reread or skip)
            step, log_prob_n = sample_policy_n(ht, policy_n)
            # if count == 1 and i == 2:
            #     writer.add_graph(policy_n, ht)
            curr_step += int(step)  # reread or skip
            action_log_batch.append({"skip/reread": step})
    # draw a predicted label
    output_c = policy_c(ht)
    if i == 3:
        writer.add_graph(policy_c, ht, verbose=True)
        
    # draw a predicted label 
    pred_label, log_prob_c = sample_policy_c(output_c)
    action_log_batch.append({"prediction": pred_label, "real": label})
    if pred_label.item() == label:
        count_correct += 1
    count_all += 1
    action_logs.append(action_log_batch)
    seen_logs.append(seen_batch)
    if i == 10:
        break

writer.close()

0
Dumb is as dumb does in this thoroughly uninteresting supposed black comedy Essentially what starts out as Chris Klein trying to maintain low profile eventually morphs into an uninspired version of The Three Amigos only without any laughs In order for black comedy to work it must be outrageous which Play Dead is not In order for black comedy to work it cannot be mean spirited which Play Dead is What Play Dead really is is town full of nut jobs Fred Dunst does however do pretty fair imitation of Billy Bob Thornton character from Simple Plan while Jake Busey does pretty fair imitation of well Jake Busey MERK
tensor([0])
1
I dug out from my garage some old musicals and this is another one of my favorites It was written by Jay Alan Lerner and directed by Vincent Minelli It won two Academy Awards for Best Picture of and Best Screenplay The story of an American painter in Paris who tries to make it big Nina Foch is sophisticated lady of means and is very interested in helping him but soon 

  if a.grad is not None:


In [113]:
action_logs[9]

[{'skip/reread': 2},
 {'skip/reread': 2},
 {'skip/reread': 2},
 {'skip/reread': 2},
 {'skip/reread': 2},
 {'prediction': tensor([1], device='cuda:0'),
  'real': tensor([0], device='cuda:0')}]

In [91]:
ix = 3
action_logs[ix], " ".join(train_data["text"].iloc[ix].split()[0:400])

([{'skip/reread': 2},
  {'skip/reread': 2},
  {'skip/reread': 2},
  {'skip/reread': 2},
  {'skip/reread': 2},
  {'prediction': tensor([1], device='cuda:0'),
   'real': tensor([1], device='cuda:0')}],
 'Can Scarcely Imagine Better Movie Than This br br Hey before you all go Chick Flick on me am very Large Strong Masculine Macho Man who happens to think this was one of the better movies of the last years br br The acting was Superb and the Story was Marvelous This is wonderful medicine for the heart and soul The Acting could not have been better nor the movie better cast br br have known for Good while that Mercedes Ruehl along with Holly Hunter Joan Plowright Dame Edith Evans Sissy Spacek Judi Dench is among the greatest actresses ever to appear on film And of course Cloris Leachman also in this film in my view may in fact exceed them all in the shear magnum of her talent and varied roles she has appeared in over the years At any rate this was an Amazing cast This film was like book tha

In [100]:
" ".join(seen_logs[ix][5])

IndexError: list index out of range

In [46]:
valid_data.iloc[0]

text                        Dumb is as dumb does in this thoroughly uninte...
label                                                                       0
label_str                                                                 neg
text_bert_input_ids         [101, 12873, 2003, 2004, 12873, 2515, 1999, 20...
text_bert_attention_mask    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
Name: 6868, dtype: object

In [103]:
count_correct / count_all

0.9831278390655419