In [1]:
import torch
from torch import optim
from torch import nn
import torch.nn.functional as F
from torch.distributions import Bernoulli, Categorical
from torchtext import datasets
import os
import time
import numpy as np 
import random
import argparse

from buffer import Buffer
from networks import Policy_C, Policy_N, Policy_S, ValueNetwork, Transformer
from utils.utils import sample_policy_c, sample_policy_n, sample_policy_s, evaluate_transformer, compute_policy_value_losses
from utils.utils import cnn_cost, clstm_cost, c_cost, n_cost, s_cost, openDfFromPickle, calculate_stats_from_cm

In [2]:
print('Reading data...')
train_data = openDfFromPickle("C:\\Users\\mrbal\\Documents\\NLP\\RL\\basic_reinforcement_learning\\NLP_datasets\\imdb\\imdb_train_distilbert-base-uncased.pkl")
valid_data = openDfFromPickle("C:\\Users\\mrbal\\Documents\\NLP\\RL\\basic_reinforcement_learning\\NLP_datasets\\imdb\\imdb_val_distilbert-base-uncased.pkl")
test_data = openDfFromPickle("C:\\Users\\mrbal\\Documents\\NLP\\RL\\basic_reinforcement_learning\\NLP_datasets\\imdb\\imdb_test_distilbert-base-uncased.pkl")
print(f'Number of training examples: {len(train_data)}')
print(f'Number of validation examples: {len(valid_data)}')
print(f'Number of testing examples: {len(test_data)}')


Reading data...
Number of training examples: 20000
Number of validation examples: 5000
Number of testing examples: 25000


In [3]:
# split the datasets into batches
BATCH_SIZE = 1  # the batch size for a dataset iterator
device = torch.device("cpu") # torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'device: {device}')

device: cpu


In [4]:
xtrain = torch.from_numpy(np.stack(train_data["text_bert_input_ids"].values))[:, 0:400]
ytrain = torch.from_numpy(train_data["label"].values)
xvalid = torch.from_numpy(np.stack(valid_data["text_bert_input_ids"].values))[:, 0:400]
yvalid = torch.from_numpy(valid_data["label"].values)
xtest = torch.from_numpy(np.stack(test_data["text_bert_input_ids"].values))[:, 0:400]
ytest = torch.from_numpy(test_data["label"].values)

print(xtrain.shape, ytrain.shape)
print(xvalid.shape, yvalid.shape)
print(xtest.shape, ytest.shape)

torch.Size([20000, 400]) torch.Size([20000])
torch.Size([5000, 400]) torch.Size([5000])
torch.Size([25000, 400]) torch.Size([25000])


In [5]:
from torch.utils.data import DataLoader, TensorDataset
train_loader = DataLoader(TensorDataset(xtrain, ytrain), batch_size=BATCH_SIZE)
valid_loader = DataLoader(TensorDataset(xvalid, yvalid), batch_size=BATCH_SIZE)
test_loader = DataLoader(TensorDataset(xtest, ytest), batch_size=BATCH_SIZE)

In [6]:
random.seed(2023)
np.random.seed(2023)
torch.manual_seed(2023)
torch.cuda.manual_seed(2023)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [7]:
# set up parameters
INPUT_DIM = 30522
CHUNCK_SIZE = 20
EMBEDDING_DIM = 100
NUM_RNN_LAYERS = 1
KER_SIZE = 5
HIDDEN_DIM_LSTM = 128 
HIDDEN_DIM_DENSE = CHUNCK_SIZE * 10 # HIDDEN_DIM_LSTM * NUM_RNN_LAYERS
OUTPUT_DIM = 1
MAX_K = 4  # the output dimension for step size 0, 1, 2, 3
LABEL_DIM = 2
N_FILTERS = 128
BATCH_SIZE = 1
gamma = 0.99
alpha = 0.2
learning_rate = 0.001

In [8]:
# set up the criterion
criterion = nn.CrossEntropyLoss().to(device)
# set up models
transformer_config = {
    "num_blocks": 3,
    "embed_dim": 100, 
    "trns_input_dim": 20*10, # embedding dim (per word) * chunk_size
    "num_heads": 1,
    "memory_length": 20,
    "positional_encoding": "", # options: "" "relative" "learned"
    "layer_norm": "pre", # options: "" "pre" "post"
    "gtrxl": True,
    "gtrxl_bias": 0.0
}
config = {
    "n_workers": 1,
    "n_mini_batch": 1,
    "worker_steps": 20,
    "transformer": transformer_config
}

trnsxl = Transformer(transformer_config, transformer_config["embed_dim"], 20).to(device)
print(trnsxl)
policy_s = Policy_S(HIDDEN_DIM_DENSE, HIDDEN_DIM_DENSE, OUTPUT_DIM).to(device)
policy_n = Policy_N(HIDDEN_DIM_DENSE, HIDDEN_DIM_DENSE, MAX_K).to(device)
policy_c = Policy_C(HIDDEN_DIM_DENSE, HIDDEN_DIM_DENSE, LABEL_DIM).to(device)
value_net = ValueNetwork(HIDDEN_DIM_DENSE, HIDDEN_DIM_DENSE, OUTPUT_DIM).to(device)


Transformer(
  (activation): ReLU()
  (linear_embedding): Embedding(30528, 100)
  (conv): Conv2d(100, 10, kernel_size=(4, 4), stride=(1, 1), padding=same)
  (transformer_blocks): ModuleList(
    (0-2): 3 x TransformerBlock(
      (attention): MultiHeadAttention(
        (values): Linear(in_features=200, out_features=200, bias=False)
        (keys): Linear(in_features=200, out_features=200, bias=False)
        (queries): Linear(in_features=200, out_features=200, bias=False)
        (fc_out): Linear(in_features=200, out_features=200, bias=True)
      )
      (gate1): GRUGate(
        (Wr): Linear(in_features=200, out_features=200, bias=False)
        (Ur): Linear(in_features=200, out_features=200, bias=False)
        (Wz): Linear(in_features=200, out_features=200, bias=False)
        (Uz): Linear(in_features=200, out_features=200, bias=False)
        (Wg): Linear(in_features=200, out_features=200, bias=False)
        (Ug): Linear(in_features=200, out_features=200, bias=False)
        (si

In [9]:
# set up optimiser
params_pg = list(policy_s.parameters()) + list(policy_c.parameters()) + list(policy_n.parameters())
optim_loss = optim.Adam(trnsxl.parameters(), lr=learning_rate)
optim_policy = optim.Adam(params_pg, lr=learning_rate)
optim_value = optim.Adam(value_net.parameters(), lr=learning_rate)

In [10]:
def finish_episode(policy_loss_sum, encoder_loss_sum, baseline_value_batch):
    '''
    Called when a data sample has been processed.
    '''
    baseline_value_sum = torch.stack(baseline_value_batch).sum()
    policy_loss = torch.stack(policy_loss_sum).mean()
    encoder_loss = torch.stack(encoder_loss_sum).mean()
    objective_loss = encoder_loss - policy_loss + baseline_value_sum
    # set gradient to zero
    optim_loss.zero_grad()
    optim_policy.zero_grad()
    optim_value.zero_grad()
    # back propagation
    objective_loss.backward()
    # gradient update
    optim_loss.step()
    optim_policy.step()
    optim_value.step()

In [11]:
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [12]:
memory_length = transformer_config["memory_length"]
num_blocks = transformer_config["num_blocks"]
embed_dim = transformer_config["embed_dim"]
trns_input_dim = transformer_config["trns_input_dim"]
max_episode_length = 20
num_workers = 1

In [13]:
def batched_index_select(input, dim, index):
    """
    Selects values from the input tensor at the given indices along the given dimension.
    This function is similar to torch.index_select, but it supports batched indices.
    The input tensor is expected to be of shape (batch_size, ...), where ... means any number of additional dimensions.
    The indices tensor is expected to be of shape (batch_size, num_indices), where num_indices is the number of indices to select for each element in the batch.
    The output tensor is of shape (batch_size, num_indices, ...), where ... means any number of additional dimensions that were present in the input tensor.

    Arguments:
        input {torch.tensor} -- Input tensor
        dim {int} -- Dimension along which to select values
        index {torch.tensor} -- Tensor containing the indices to select

    Returns:
        {torch.tensor} -- Output tensor
    """
    for ii in range(1, len(input.shape)):
        if ii != dim:
            index = index.unsqueeze(ii)
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.expand(expanse)
    return torch.gather(input, dim, index)

In [14]:
print('Training starts...')
worker_current_episode_step = torch.zeros((num_workers, ), dtype=torch.long)
for epoch in range(5):
    print('\nEpoch', epoch+1)
    # log the start time of the epoch
    start = time.time()
    # set the models in training mode
    trnsxl.train()
    policy_s.train()
    policy_n.train()
    policy_c.train()
    # reset the count of reread_or_skim_times
    reread_or_skim_times = 0
    policy_loss_sum = []
    encoder_loss_sum = []
    baseline_value_batch = []
    pbar = tqdm(train_loader)
    cm = np.zeros((LABEL_DIM, LABEL_DIM))
   
    for index, (x, y) in enumerate(pbar):
        label = y.to(device)
        text = x.to(device).view(CHUNCK_SIZE, BATCH_SIZE, CHUNCK_SIZE) # transform 1*400 to 20*1*20
        curr_step = 0  # the position of the current chunk
       
        # Setup placeholders for each worker's current episodic memory
        memory = torch.zeros((num_workers, memory_length, num_blocks, trns_input_dim), dtype=torch.float32).to(device)
        # Generate episodic memory mask used in attention
        memory_mask = torch.tril(torch.ones((num_workers, memory_length)), diagonal=-1).to(device)

        """ e.g. memory mask tensor looks like this if memory_length = 6
        0, 0, 0, 0, 0, 0
        1, 0, 0, 0, 0, 0
        1, 1, 0, 0, 0, 0
        1, 1, 1, 0, 0, 0
        1, 1, 1, 1, 0, 0
        1, 1, 1, 1, 1, 0
        """         
        # setup buffer for memory
        buffer = {"memories": [memory[w] for w in range(num_workers)], 
                  "memory_mask": torch.zeros((num_workers, memory_length, memory_length), dtype=torch.bool).to(device),
                  "memory_index": torch.zeros((num_workers, memory_length), dtype=torch.long).to(device), 
                  "memory_indices": torch.zeros((num_workers, memory_length, memory_length), dtype=torch.long).to(device)}
        # Setup memory window indices to support a sliding window over the episodic memory
        repetitions = torch.repeat_interleave(torch.arange(0, memory_length).unsqueeze(0), memory_length - 1, dim = 0).long()
        memory_indices = torch.stack([torch.arange(i, i + memory_length) for i in range(max_episode_length - memory_length + 1)]).long()
        memory_indices = torch.cat((repetitions, memory_indices)).to(device)
        count = 0  # maximum skim/reread time: 5
        baseline_value_ep = []
        saved_log_probs = []  # for the use of policy gradient update
        # collect the computational costs for every time step
        cost_ep = []
        # torch.autograd.set_detect_anomaly(True)
        while curr_step < CHUNCK_SIZE and count < 5: 
            # Loop until a text can be classified or currstep is up to 20 or count reach the maximum i.e. 5.
            # update count
            count += 1
            # pass the input through cnn-lstm and policy s
            text_input = text[curr_step] # text_input 1*20
            # print(f"input text_input: {text_input.shape}, memory: {memory.shape}")
            # with torch.no_grad():
            buffer["memory_mask"][:, curr_step] = memory_mask[torch.clip(worker_current_episode_step, 0, memory_length - 1)]
            print("======> ", buffer["memory_indices"][:, curr_step].shape, memory_indices[worker_current_episode_step].shape)
            buffer["memory_indices"][:, curr_step] = memory_indices[worker_current_episode_step]
            # Retrieve the memory window from the entire episodic memory
            sliced_memory = batched_index_select(memory, 1, buffer["memory_indices"][:,curr_step]).to(device)
            ht, memory_t = trnsxl(text_input, sliced_memory, buffer["memory_mask"][:, curr_step], buffer["memory_indices"][:,curr_step])  # text_input: CHUNK_SIZE * 10 (10: conv out filters), memory: num_workers, CHUNK_SIZE, num_blocks, CHUNK_SIZE*10
            # print(f"ht memory: {ht.shape, memory_t.shape}") # memory: num_workers, num_blocks, CHUNK_SIZe*10
            # separate the value which is the input of value net
            ht_ = ht.clone().detach().requires_grad_(True)
            memory[:, worker_current_episode_step] = memory_t # .clone().detach().requires_grad_(False)
            # print(f"After memory update: {memory.shape}")
            # ht_ = ht_.view(1, ht_.shape[0] * ht_.shape[2]) # ht_: 1, NUM_RNN_LAYERS * HIDDEN_DIM_LSTM
            # compute a baseline value for the value network
            bi = value_net(ht_)
            # NUM_RNN_LAYERS * 1 * 128, next input of lstm
            # draw a stop decision
            stop_decision, log_prob_s = sample_policy_s(ht, policy_s)
            stop_decision = stop_decision.item()
            if stop_decision == 1: # classify
                worker_current_episode_step[0] = 0
                break
            else: 
                reread_or_skim_times += 1
                worker_current_episode_step[0] += 1
                # draw an action (reread or skip)
                step, log_prob_n = sample_policy_n(ht, policy_n)
                curr_step += int(step)  # reread or skip
                if curr_step < CHUNCK_SIZE and count < 5:
                    # If the code can still execute the next loop, it is not the last time step.
                    cost_ep.append(clstm_cost + s_cost + n_cost)
                    # add the baseline value
                    baseline_value_ep.append(bi)
                    # add the log prob for the current actions
                    saved_log_probs.append(log_prob_s + log_prob_n)
        
        # draw a predicted label
        output_c = policy_c(ht)
        # cross entrpy loss input shape: input(N, C), target(N)
        loss = criterion(output_c, label)  # positive value
        # draw a predicted label 
        pred_label, log_prob_c = sample_policy_c(output_c)
        # update the confusion matrix
        cm[pred_label][y] += 1
        if stop_decision == 1:
            # add the cost of the last time step
            cost_ep.append(clstm_cost + s_cost + c_cost)
            saved_log_probs.append(log_prob_s + log_prob_c)
        else:
            # add the cost of the last time step
            cost_ep.append(clstm_cost + s_cost + c_cost + n_cost)
            # At the moment, the probability of drawing a stop decision is 1,
            # so its log probability is zero which can be ignored in th sum.
            saved_log_probs.append(log_prob_c.unsqueeze(0))
        # add the baseline value
        baseline_value_ep.append(bi)
        # add the cross entropy loss
        encoder_loss_sum.append(loss)
        # compute the policy losses and value losses for the current episode
        policy_loss_ep, value_losses = compute_policy_value_losses(cost_ep, loss, saved_log_probs, baseline_value_ep, alpha, gamma)
        policy_loss_sum.append(torch.cat(policy_loss_ep).sum())
        baseline_value_batch.append(torch.cat(value_losses).sum())
        # update gradients
        if (index + 1) % 32 == 0:  # take the average of samples, backprop
            finish_episode(policy_loss_sum, encoder_loss_sum, baseline_value_batch)
            del policy_loss_sum[:], encoder_loss_sum[:], baseline_value_batch[:]

        # log and print out info     
        if (index + 1) % 32 == 0:
            stats = calculate_stats_from_cm(cm)
            cm = np.zeros((LABEL_DIM, LABEL_DIM))
            acc = stats["accuracy"]
            recall = stats["recall"]
            precision = stats["precision"]
            f1 = stats["f1"]
            writer.add_scalar("train_accuracy", acc, len(train_loader)*epoch + index)
            writer.add_scalar("train_recall", recall,  len(train_loader)*epoch + index)
            writer.add_scalar("train_precision", precision,  len(train_loader)*epoch + index)
            writer.add_scalar("train_f1", f1,  len(train_loader)*epoch + index)
            pbar.set_description(f"episode: {index + 1}, reread_or_skim_times: {reread_or_skim_times}, accuracy: {acc:.3f}, precision: {precision:.3f}, recall: {recall:.3f}, f1: {f1:.2f}")
            
            """print(f'\n current episode: {index + 1}')
            # log the current position of the text which the agent has gone through
            print('curr_step: ', curr_step)
            # log the sum of the rereading and skimming times
            print(f'current reread_or_skim_times: {reread_or_skim_times}')"""


    print('Epoch time elapsed: %.2f s' % (time.time() - start))
    print('reread_or_skim_times in this epoch:', reread_or_skim_times)
    count_all, count_correct = evaluate_transformer(trnsxl, policy_s, policy_n, policy_c, valid_loader)
    print('Epoch: %s, Accuracy on the validation set: %.2f' % (epoch + 1, count_correct / count_all))
    writer.add_scalar("validation_acccuracy", count_correct / count_all,  len(train_loader)*epoch + index)
    # count_all, count_correct = evaluate(clstm, policy_s, policy_n, policy_c, train_loader)
    # print('Epoch: %s, Accuracy on the training set: %.2f' % (epoch + 1, count_correct / count_all))
    
print('Compute the accuracy on the testing set...')
count_all, count_correct = evaluate_transformer(trnsxl, policy_s, policy_n, policy_c, test_loader)
print('Accuracy on the testing set: %.2f' % (count_correct / count_all))

Training starts...

Epoch 1


  0%|          | 0/20000 [00:00<?, ?it/s]



  return F.conv2d(input, weight, bias, self.stride,
  0%|          | 0/20000 [00:00<?, ?it/s]


IndexError: index 1 is out of bounds for dimension 0 with size 1

In [None]:
%tb

In [None]:
print('Training starts...')

for epoch in range(5, 15):
    print('\nEpoch', epoch+1)
    # log the start time of the epoch
    start = time.time()
    # set the models in training mode
    trnsxl.train()
    policy_s.train()
    policy_n.train()
    policy_c.train()
    # reset the count of reread_or_skim_times
    reread_or_skim_times = 0
    policy_loss_sum = []
    encoder_loss_sum = []
    baseline_value_batch = []
    pbar = tqdm(train_loader)
    cm = np.zeros((LABEL_DIM, LABEL_DIM))
    for index, (x, y) in enumerate(pbar):
        label = y.to(device)
        text = x.to(device).view(CHUNCK_SIZE, BATCH_SIZE, CHUNCK_SIZE) # transform 1*400 to 20*1*20
        curr_step = 0  # the position of the current chunk
        # Setup placeholders for each worker's current episodic memory
        memory = torch.zeros((num_workers, memory_length, num_blocks, trns_input_dim), dtype=torch.float32).to(device)
        # Generate episodic memory mask used in attention
        memory_mask = torch.tril(torch.ones((num_workers, memory_length)), diagonal=-1).to(device)
        """ e.g. memory mask tensor looks like this if memory_length = 6
        0, 0, 0, 0, 0, 0
        1, 0, 0, 0, 0, 0
        1, 1, 0, 0, 0, 0
        1, 1, 1, 0, 0, 0
        1, 1, 1, 1, 0, 0
        1, 1, 1, 1, 1, 0
        """         
        # Setup memory window indices to support a sliding window over the episodic memory
        repetitions = torch.repeat_interleave(torch.arange(0, memory_length).unsqueeze(0), memory_length - 1, dim = 0).long()
        memory_indices = torch.stack([torch.arange(i, i + memory_length) for i in range(max_episode_length - memory_length + 1)]).long()
        memory_indices = torch.cat((repetitions, memory_indices)).to(device)
        count = 0  # maximum skim/reread time: 5
        baseline_value_ep = []
        saved_log_probs = []  # for the use of policy gradient update
        # collect the computational costs for every time step
        cost_ep = []
        with torch.no_grad():
            while curr_step < CHUNCK_SIZE and count < 5: 
                # Loop until a text can be classified or currstep is up to 20 or count reach the maximum i.e. 5.
                # update count
                count += 1
                # pass the input through cnn-lstm and policy s
                text_input = text[curr_step] # text_input 1*20
                # print(f"input text_input: {text_input.shape}, memory: {memory.shape}")
                ht, memory_t = trnsxl(text_input, memory, memory_mask, memory_indices)  # text_input: CHUNK_SIZE * 10 (10: conv out filters), memory: num_workers, CHUNK_SIZE, num_blocks, CHUNK_SIZE*10
                # print(f"ht memory: {ht.shape, memory_t.shape}") # memory: num_workers, num_blocks, CHUNK_SIZe*10
                # separate the value which is the input of value net
                ht_ = ht.clone().detach().requires_grad_(True)
                memory[:, curr_step, :, :] = memory_t.clone()
                # ht_ = ht_.view(1, ht_.shape[0] * ht_.shape[2]) # ht_: 1, NUM_RNN_LAYERS * HIDDEN_DIM_LSTM
                # compute a baseline value for the value network
                bi = value_net(ht_)
                # NUM_RNN_LAYERS * 1 * 128, next input of lstm
                # draw a stop decision
                stop_decision, log_prob_s = sample_policy_s(ht, policy_s)
                stop_decision = stop_decision.item()
                if stop_decision == 1: # classify
                    break
                else: 
                    reread_or_skim_times += 1
                    # draw an action (reread or skip)
                    step, log_prob_n = sample_policy_n(ht, policy_n)
                    curr_step += int(step)  # reread or skip
                    if curr_step < CHUNCK_SIZE and count < 5:
                        # If the code can still execute the next loop, it is not the last time step.
                        cost_ep.append(clstm_cost + s_cost + n_cost)
                        # add the baseline value
                        baseline_value_ep.append(bi)
                        # add the log prob for the current actions
                        saved_log_probs.append(log_prob_s + log_prob_n)
            
        # draw a predicted label
        output_c = policy_c(ht)
        # cross entrpy loss input shape: input(N, C), target(N)
        loss = criterion(output_c, label)  # positive value
        # draw a predicted label 
        pred_label, log_prob_c = sample_policy_c(output_c)
        # update the confusion matrix
        cm[pred_label][y] += 1
        if stop_decision == 1:
            # add the cost of the last time step
            cost_ep.append(clstm_cost + s_cost + c_cost)
            saved_log_probs.append(log_prob_s + log_prob_c)
        else:
            # add the cost of the last time step
            cost_ep.append(clstm_cost + s_cost + c_cost + n_cost)
            # At the moment, the probability of drawing a stop decision is 1,
            # so its log probability is zero which can be ignored in th sum.
            saved_log_probs.append(log_prob_c.unsqueeze(0))
        # add the baseline value
        baseline_value_ep.append(bi)
        # add the cross entropy loss
        encoder_loss_sum.append(loss)
        # compute the policy losses and value losses for the current episode
        policy_loss_ep, value_losses = compute_policy_value_losses(cost_ep, loss, saved_log_probs, baseline_value_ep, alpha, gamma)
        policy_loss_sum.append(torch.cat(policy_loss_ep).sum())
        baseline_value_batch.append(torch.cat(value_losses).sum())
        # update gradients
        if (index + 1) % 32 == 0:  # take the average of samples, backprop
            finish_episode(policy_loss_sum, encoder_loss_sum, baseline_value_batch)
            del policy_loss_sum[:], encoder_loss_sum[:], baseline_value_batch[:]

        # log and print out info     
        if (index + 1) % 32 == 0:
            stats = calculate_stats_from_cm(cm)
            cm = np.zeros((LABEL_DIM, LABEL_DIM))
            acc = stats["accuracy"]
            recall = stats["recall"]
            precision = stats["precision"]
            f1 = stats["f1"]
            writer.add_scalar("train_accuracy", acc, len(train_loader)*epoch + index)
            writer.add_scalar("train_recall", recall,  len(train_loader)*epoch + index)
            writer.add_scalar("train_precision", precision,  len(train_loader)*epoch + index)
            writer.add_scalar("train_f1", f1,  len(train_loader)*epoch + index)
            pbar.set_description(f"episode: {index + 1}, reread_or_skim_times: {reread_or_skim_times}, accuracy: {acc:.3f}, precision: {precision:.3f}, recall: {recall:.3f}, f1: {f1:.2f}")
            
            """print(f'\n current episode: {index + 1}')
            # log the current position of the text which the agent has gone through
            print('curr_step: ', curr_step)
            # log the sum of the rereading and skimming times
            print(f'current reread_or_skim_times: {reread_or_skim_times}')"""


    print('Epoch time elapsed: %.2f s' % (time.time() - start))
    print('reread_or_skim_times in this epoch:', reread_or_skim_times)
    count_all, count_correct = evaluate_transformer(trnsxl, policy_s, policy_n, policy_c, valid_loader)
    print('Epoch: %s, Accuracy on the validation set: %.2f' % (epoch + 1, count_correct / count_all))
    writer.add_scalar("validation_acccuracy", count_correct / count_all,  len(train_loader)*epoch + index)
    # count_all, count_correct = evaluate(clstm, policy_s, policy_n, policy_c, train_loader)
    # print('Epoch: %s, Accuracy on the training set: %.2f' % (epoch + 1, count_correct / count_all))
    
print('Compute the accuracy on the testing set...')
count_all, count_correct = evaluate_transformer(trnsxl, policy_s, policy_n, policy_c, test_loader)
print('Accuracy on the testing set: %.2f' % (count_correct / count_all))
writer.close()

In [None]:
count_all, count_correct = evaluate_transformer(trnsxl, policy_s, policy_n, policy_c, valid_loader)

In [None]:
count_correct / count_all

In [None]:
from datetime import datetime

now = datetime.now()
now_time = str(now.day) + "_" + str(now.month) + "_" + str(now.year) + "_" + str(now.hour) + "_" + str(now.minute)
now_date = str(now.day) + "_" + str(now.month) + "_" + str(now.year)
now_date, now_time

In [None]:
import os

def save_models(time_str, date_str, path: str = "."):
    os.makedirs(f"saved_models\\{date_str}", exist_ok=True)
    torch.save(clstm, f"{date_str}\\clstm_{time_str}.pth")
    torch.save(policy_s, f"{date_str}\\policy_s_{time_str}.pth")
    torch.save(policy_n, f"{date_str}\\policy_n_{time_str}.pth")
    torch.save(policy_c, f"{date_str}\\policy_c_{time_str}.pth")


In [None]:
save_models(now_time, now_date, "saved_models")

In [None]:
clstm.eval()
policy_s.eval()
policy_n.eval()
policy_c.eval()
action_logs = []
seen_logs = []
writer = SummaryWriter()
for i, (x, y) in enumerate(valid_loader):
    print(i)
    print(valid_data["text"].iloc[i])
    print(y)
    action_log_batch = []
    seen_batch = []
    label = y.to(device).long() # for cross entropy loss, the long type is required
    text = x.to(device).view(CHUNCK_SIZE, BATCH_SIZE, CHUNCK_SIZE) # transform 1*400 to 20*1*20
    curr_step = 0
    n_rnn_layers = clstm.n_rnn_layers
    lstm_hidden_dim = clstm.lstm_hidden_dim
    h_0 = torch.zeros([n_rnn_layers,1,lstm_hidden_dim]).to(device)
    c_0 = torch.zeros([n_rnn_layers,1,lstm_hidden_dim]).to(device)
    count = 0
    while curr_step < 20 and count < 5: # loop until a text can be classified or currstep is up to 20
        count += 1
        # pass the input through cnn-lstm and policy s
        text_input = text[curr_step] # text_input 1*20
        text_str = train_data["text"].iloc[i]
        seen_batch.append(text_str.split()[curr_step * 20: (curr_step+1)*20])
        ht, ct = clstm(text_input, h_0, c_0)  # 1 * 128
        # if count == 1 and i == 0:
        #     writer.add_graph(clstm, [text_input, h_0, c_0], verbose=True)
        h_0 = ht.unsqueeze(0) # NUM_RNN_LAYERS * 1 * LSTM_HIDDEN_DIM, next input of lstm
        c_0 = ct
        # ht_ = ht.view(1, ht.shape[0] * ht.shape[2])
        # draw a stop decision
        stop_decision, log_prob_s = sample_policy_s(ht, policy_s)
        # if count == 1 and i == 1:
        #     writer.add_graph(policy_s, ht)
        stop_decision = stop_decision.item()
        if stop_decision == 1: # classify
            break
        else:
            # draw an action (reread or skip)
            step, log_prob_n = sample_policy_n(ht, policy_n)
            # if count == 1 and i == 2:
            #     writer.add_graph(policy_n, ht)
            curr_step += int(step)  # reread or skip
            action_log_batch.append({"skip/reread": step})
    # draw a predicted label
    output_c = policy_c(ht)
    if i == 3:
        writer.add_graph(policy_c, ht, verbose=True)
        
    # draw a predicted label 
    pred_label, log_prob_c = sample_policy_c(output_c)
    action_log_batch.append({"prediction": pred_label, "real": label})
    if pred_label.item() == label:
        count_correct += 1
    count_all += 1
    action_logs.append(action_log_batch)
    seen_logs.append(seen_batch)
    if i == 10:
        break

writer.close()

In [None]:
action_logs[9]

In [None]:
ix = 3
action_logs[ix], " ".join(train_data["text"].iloc[ix].split()[0:400])

In [None]:
" ".join(seen_logs[ix][5])

In [None]:
valid_data.iloc[0]

In [None]:
count_correct / count_all