Adapted implementation of "Learning Structured Representation for Text Classification via Reinforcement Learning" (https://ojs.aaai.org/index.php/AAAI/article/view/12047)

# Bonaventure Dossou

In [None]:
import os
import sys
from sklearn.model_selection import train_test_split
import torch
from torch.nn import functional as F
import numpy as np
from torchtext import data
from torchtext import datasets
import pandas as pd
from torchtext.legacy import data
from torchtext import datasets
from torchtext.vocab import Vectors
import torch.nn as nn
from gensim.models import Word2Vec
from copy import deepcopy
import random
from tqdm import tqdm
import time
from torch.autograd import Variable
import torch.optim as optim

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score, roc_curve, auc, roc_auc_score
from sklearn.metrics import classification_report

from google.colab import drive
# Mount Google Drive so the dataset and the output folders below are reachable under /content/drive.
drive.mount('/content/drive')

# Folder prefix the training cells prepend to checkpoint file names (note the trailing slash).
models = '/content/drive/MyDrive/models/'
# Folder prefix for plots; not referenced by the code visible in this notebook.
plots = '/content/drive/MyDrive/plots/'

# CSV file handed to load_dataset().
path = '/content/drive/MyDrive/data_label_1.csv'


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
def load_dataset(filename, batch_size):
    """Read the labelled CSV and build vocabulary, embeddings and batch iterators.

    The file is parsed with its header row skipped and its first two columns
    mapped to ``text`` and ``label``. Text is lower-cased and split on
    whitespace. The data is split 90/10 (stratified on ``label``) into
    train+validation and test; the vocabulary and the ``glove.6B.300d``
    vectors are built from the train+validation portion, which is then split
    90/10 again (stratified) into training and validation sets.

    NOTE(review): commented-out code previously kept in this function dropped
    NaN/duplicate rows and kept only texts of at most 150 whitespace tokens
    before writing the CSV; the file is presumably already cleaned that way.

    Args:
        filename: Path of the CSV file.
        batch_size: Batch size used for all three iterators.

    Returns:
        Tuple ``(TEXT, vocab_size, word_embeddings, train_iter, valid_iter,
        test_iter)`` where ``TEXT`` is the text field, ``word_embeddings`` is
        ``TEXT.vocab.vectors`` and the iterators are bucket iterators.
    """
    def whitespace_tokenize(sentence):
        return sentence.split()

    text_field = data.Field(
        sequential=True,
        tokenize=whitespace_tokenize,
        lower=True,
        include_lengths=True,
        batch_first=True,
    )
    label_field = data.LabelField()

    dataset = data.TabularDataset(
        path=filename,
        format='csv',
        fields=[('text', text_field), ('label', label_field)],
        skip_header=True,
    )

    # Both splits use the same recipe: 90% / 10%, stratified on the label column.
    split_options = dict(split_ratio=0.9, stratified=True, strata_field='label')

    # First split: (train + validation) versus test.
    train_and_valid, test_data = dataset.split(**split_options)

    # The vocabulary only sees the non-test portion.
    text_field.build_vocab(train_and_valid, vectors="glove.6B.300d")
    label_field.build_vocab(train_and_valid)

    vocab_size = len(text_field.vocab)
    word_embeddings = text_field.vocab.vectors

    print ("Vocabulary size: {}".format(vocab_size))
    print ("Vector Shape: {}".format(word_embeddings.size()))
    print ("Number of labels: {}".format(len(label_field.vocab)))

    # Second split: real training data versus validation data.
    train_data, valid_data = train_and_valid.split(**split_options)

    iterators = data.BucketIterator.splits(
        (train_data, valid_data, test_data),
        batch_size=batch_size,
        sort_key=lambda example: len(example.text),
        repeat=False,
        shuffle=True,
    )
    train_iter, valid_iter, test_iter = iterators

    return text_field, vocab_size, word_embeddings, train_iter, valid_iter, test_iter


Constants

In [None]:
# Step size of every SGD optimizer created in train_model.
learning_rate = 3e-3
# Mini-batch size of the iterators; the train/eval loops skip batches of any other size.
batch_size = 5
# Not referenced by the code visible in this notebook.
global_batch_size = 5
# Number of classes, i.e. width of the classifier's final linear layer.
output_size = 2
# LSTM hidden-state width (the policy network consumes 2 * hidden_size state features).
hidden_size = 300
# Word-embedding dimensionality; load_dataset loads the 300-d GloVe vectors.
embedding_length = 300
# Number of action sequences sampled per sentence during RL training.
samplecnt = 5
# Passed to Sampling_RL: probability of drawing the action with the policy's probabilities swapped.
epsilon = 0.05
# Sampling_RL pads each selected-token sequence up to this many tokens.
maxlength = 150
# Scale applied to (sample loss - average loss) when forming the policy-gradient reward.
alpha = 0.1
# Interpolation weight of update_target_network: target <- tau * active + (1 - tau) * target.
tau = 0.1
# When True, the classifier pre-training cell runs its training loop.
delay_critic = True

Simple LSTM Classifier

In [None]:
# Build the text field, vocabulary, embedding matrix and the three iterators once;
# vocab_size and word_embeddings are read as globals when the critic is constructed.
TEXT, vocab_size, word_embeddings, train_iter, valid_iter, test_iter = load_dataset(path, batch_size=batch_size)

Vocabulary size: 35984
Vector Shape: torch.Size([35984, 300])
Number of labels: 2


Early Stopping to Control Under-/Overfitting

In [None]:
import numpy as np
import torch

class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    Call the instance once per epoch with the current validation loss. After
    ``patience`` consecutive calls without a sufficient improvement,
    ``early_stop`` becomes True.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        path: Checkpoint path. Only stored; writing the checkpoint is
            currently disabled (see ``save_checkpoint``).
        verbose: If True, report every improvement through ``trace_func``.
        delta: Margin added to the best score; a call counts as an
            improvement only when ``-val_loss >= best_score + delta``.
        trace_func: Callable used for all messages (defaults to ``print``).
    """

    def __init__(self, patience, path, verbose=True, delta=0, trace_func=print):

        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf rather than the alias np.Inf, which NumPy 2.0 removed.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss and update the stop flag."""
        # Higher score is better, so the loss is negated.
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Record a new best validation loss.

        The model is NOT written to disk: the ``torch.save`` call below is
        intentionally left disabled and ``model`` is unused.
        """
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f})')
        # torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [None]:
class FakeNewsClassifier(nn.Module):
    """Single-layer LSTM classifier over frozen, pre-trained word embeddings.

    Args:
        batch_size: Stored on the instance only; ``forward`` sizes the LSTM
            state from its input.
        output_size: Number of output classes (width of the final linear layer).
        hidden_size: LSTM hidden-state width.
        vocab_size: Number of rows of the embedding table.
        embedding_length: Embedding dimensionality (LSTM input width).
        weights: Pre-trained embedding tensor, installed as a non-trainable
            parameter (presumably shaped ``(vocab_size, embedding_length)``).
    """

    def __init__(self, batch_size, output_size, hidden_size, vocab_size, embedding_length, weights):
        super(FakeNewsClassifier, self).__init__()

        self.batch_size = batch_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.embedding_length = embedding_length

        self.word_embeddings = nn.Embedding(vocab_size, embedding_length)
        # Replace the random table by the pre-trained vectors and freeze it.
        self.word_embeddings.weight = nn.Parameter(weights, requires_grad=False)
        self.lstm = nn.LSTM(embedding_length, hidden_size)
        self.Dropout = nn.Dropout(0.3)
        self.label = nn.Linear(hidden_size, output_size)

    def forward(self, input_sentence, batch_size=None):
        """Classify a batch of token-id sequences.

        Args:
            input_sentence: Batch-first tensor of token ids, ``(batch, seq_len)``.
            batch_size: Accepted for backward compatibility and ignored; the
                batch size is always taken from ``input_sentence`` (the
                previous code overwrote this argument before reading it).

        Returns:
            Tensor of shape ``(batch, output_size)`` with unnormalised class scores.
        """
        # (batch, seq, emb) -> (seq, batch, emb), the layout nn.LSTM expects by default.
        embedded = self.word_embeddings(input_sentence).permute(1, 0, 2)
        actual_batch = embedded.size(1)
        # Zero initial states created on the input's device/dtype, so the
        # model runs on CPU as well as on CUDA.
        h_0 = embedded.new_zeros(1, actual_batch, self.hidden_size)
        c_0 = embedded.new_zeros(1, actual_batch, self.hidden_size)
        _, (final_hidden_state, _) = self.lstm(embedded, (h_0, c_0))
        # final_hidden_state is (1, batch, hidden); [-1] selects the only layer.
        final_output = self.label(self.Dropout(final_hidden_state[-1]))

        return final_output

    def wordvector_find(self, x):
        """Return the embedding vectors of the token ids in ``x``."""
        return self.word_embeddings(x)

    def getNextHiddenState(self, hc, x):
        """Advance the LSTM by one token.

        Args:
            hc: Tensor of shape ``(1, 2 * hidden_size)``; the first half is
                the hidden state and the second half the cell state.
            x: Id of a single token.

        Returns:
            ``(out, state)``: the dropout-regularised LSTM output of shape
            ``(1, 1, hidden_size)`` and the dropout-regularised concatenation
            of the new hidden and cell states, shape ``(1, 2 * hidden_size)``.
        """
        hidden = hc[0, 0:self.hidden_size].view(1, 1, self.hidden_size)
        cell = hc[0, self.hidden_size:].view(1, 1, self.hidden_size)
        token_embedding = self.word_embeddings(x).view(1, 1, -1)
        out, (next_hidden, next_cell) = self.lstm(token_embedding, (hidden, cell))
        packed_state = torch.cat((next_hidden, next_cell), -1).view(1, -1)
        return self.Dropout(out), self.Dropout(packed_state)


In [None]:
def clip_gradient(model, clip_value):
    """Clamp every existing gradient of ``model`` into [-clip_value, clip_value], in place.

    Parameters whose ``grad`` is None are left untouched.
    """
    for parameter in list(model.parameters()):
        if parameter.grad is None:
            continue
        parameter.grad.data.clamp_(-clip_value, clip_value)

def Sampling_RL(actor, critic, inputs, vector, length, epsilon, Random = True):
    """Sample a keep/skip decision for each of the first ``length`` tokens.

    For every position the target policy is queried with the current LSTM
    state and the token embedding. Action 1 keeps the token and advances the
    state through the critic's target LSTM; action 0 leaves the state unchanged.

    Args:
        actor: Object exposing ``get_target_output(h, x, scope)``.
        critic: Object exposing ``forward_lstm(hc, x, scope)``.
        inputs: Token-id tensor indexed as ``inputs[0][pos]``.
        vector: Embeddings of ``inputs``, indexed as ``vector[0][pos]``.
        length: Number of leading positions to decide on.
        epsilon: With probability ``1 - epsilon`` the action is drawn from the
            policy's probabilities; otherwise it is drawn with them swapped.
        Random: If False, always take the more probable action.

    Returns:
        ``(actions, states, Rinput, Rlength)``: the list of 0/1 actions, the
        list of ``[state, embedding]`` pairs seen at each position, a
        ``(1, N)`` long tensor holding the kept token ids padded with id 1
        up to ``maxlength``, and the number of kept tokens.
    """
    # Follow the device of the embeddings (i.e. of the model) instead of
    # hard-coding CUDA, so sampling also works on a CPU-only runtime.
    device = vector.device
    current_lower_state = torch.zeros(1, 2 * hidden_size, device=device)
    actions = []
    states = []
    for pos in range(length):
        predicted = actor.get_target_output(current_lower_state, vector[0][pos], scope = "target")
        states.append([current_lower_state, vector[0][pos]])
        if Random:
            prob_skip = float(predicted[0].item())  # predicted[0] is P(action == 0)
            # Two draws, in this order: exploration first, then the action itself.
            if random.random() > epsilon:
                action = 0 if random.random() < prob_skip else 1
            else:
                action = 1 if random.random() < prob_skip else 0
        else:
            action = np.argmax(predicted.detach().cpu().numpy()).item()
        actions.append(action)
        if action == 1:
            _, current_lower_state = critic.forward_lstm(current_lower_state, inputs[0][pos], scope = "target")

    Rinput = [int(inputs[0][i].item()) for i, a in enumerate(actions) if a == 1]
    Rlength = len(Rinput)
    if Rlength == 0:
        # Nothing was kept: force one token. The clamp keeps one-token
        # sentences from indexing position -1, which would select the last
        # (possibly padded) column of `inputs`. The id is stored as a plain
        # int like every other entry of Rinput.
        fallback = max(length - 2, 0)
        actions[fallback] = 1
        Rinput.append(int(inputs[0][fallback].item()))
        Rlength = 1
    # Pad with id 1 — presumably the vocabulary's padding index; confirm against TEXT.vocab.
    Rinput += [1] * (maxlength - Rlength)
    Rinput = torch.tensor(Rinput, device=device).view(1, -1)
    return actions, states, Rinput, Rlength


class policyNet(nn.Module):
    """Logistic policy giving the probability of each of the two actions.

    Parameters are drawn uniformly from [-0.5, 0.5). They are created on CUDA
    when it is available (as before) and on the CPU otherwise, instead of
    through the legacy ``torch.cuda.FloatTensor`` constructor, which fails on
    a runtime without a GPU.
    """

    def __init__(self):
        super(policyNet, self).__init__()
        self.hidden = hidden_size
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.W1 = nn.Parameter(torch.empty(2 * self.hidden, 1, device=device).uniform_(-0.5, 0.5))
        self.W2 = nn.Parameter(torch.empty(embedding_length, 1, device=device).uniform_(-0.5, 0.5))
        self.b = nn.Parameter(torch.empty(1, 1, device=device).uniform_(-0.5, 0.5))

    def forward(self, h, x):
        """Return a ``(2, 1)`` tensor ``[1 - p, p]`` with ``p = sigmoid(h.W1 + x.W2 + b)``.

        Args:
            h: State tensor with ``2 * hidden_size`` elements.
            x: Word-embedding tensor with ``embedding_length`` elements.
        """
        state_term = torch.matmul(h.view(1, -1), self.W1)  # 1x1
        word_term = torch.matmul(x.view(1, -1), self.W2)   # 1x1
        prob = torch.sigmoid(state_term + word_term + self.b)  # 1x1
        # Keep p away from 0 and 1 so that taking the log stays finite.
        prob = torch.clamp(prob, min=1e-5, max=1 - 1e-5)
        return torch.cat([1.0 - prob, prob], 0)



class critic(nn.Module):
    """Pair of identical classifiers: a ``target`` copy and an ``active`` copy.

    Methods taking ``scope`` accept ``"target"`` or ``"active"`` and raise
    ``ValueError`` for anything else (previously an unknown scope surfaced as
    an ``UnboundLocalError``).
    """

    def __init__(self):
        super(critic, self).__init__()
        self.target_pred = FakeNewsClassifier(batch_size, output_size, hidden_size, vocab_size, embedding_length, word_embeddings)
        self.active_pred = FakeNewsClassifier(batch_size, output_size, hidden_size, vocab_size, embedding_length, word_embeddings)

    def _network_for(self, scope):
        """Map a scope name to the corresponding classifier."""
        if scope == "target":
            return self.target_pred
        if scope == "active":
            return self.active_pred
        raise ValueError("scope must be 'target' or 'active', got {!r}".format(scope))

    @staticmethod
    def _copy_parameters(source, destination):
        """Overwrite destination's parameter values with copies of source's, pairing by position."""
        for src, dst in zip(source.parameters(), destination.parameters()):
            dst.data = src.data.clone()

    def forward(self, x, scope):
        """Run the classifier selected by ``scope`` on ``x``."""
        return self._network_for(scope)(x)

    def assign_target_network(self):
        """Copy the active network's parameter values into the target network."""
        self._copy_parameters(self.active_pred, self.target_pred)

    def update_target_network(self):
        """Soft update: target <- tau * active + (1 - tau) * target."""
        for src, dst in zip(self.active_pred.parameters(), self.target_pred.parameters()):
            # The arithmetic already yields a fresh tensor, so no extra copy is needed.
            dst.data = src.data * (tau) + dst.data * (1 - tau)

    def assign_active_network(self):
        """Copy the target network's parameter values into the active network."""
        self._copy_parameters(self.target_pred, self.active_pred)

    def assign_active_network_gradients(self):
        """Move the target network's gradients onto the active network.

        Each active parameter receives a copy of the matching target gradient
        (None where the target has none, e.g. for the frozen embeddings);
        afterwards all target gradients are cleared.
        """
        for src, dst in zip(self.target_pred.parameters(), self.active_pred.parameters()):
            dst.grad = None if src.grad is None else src.grad.detach().clone()
        for param in self.target_pred.parameters():
            param.grad = None

    def forward_lstm(self, hc, x, scope):
        """Advance the selected classifier's LSTM by one token; returns ``(out, state)``."""
        return self._network_for(scope).getNextHiddenState(hc, x)

    def wordvector_find(self, x):
        """Look up embeddings for ``x`` using the target network's embedding table."""
        return self.target_pred.wordvector_find(x)


class actor(nn.Module):
    """Pair of policy networks: a ``target`` copy and an ``active`` copy.

    Methods taking ``scope`` accept ``"target"`` or ``"active"`` and raise
    ``ValueError`` for anything else (previously an unknown scope surfaced as
    an ``UnboundLocalError``).
    """

    def __init__(self):
        super(actor, self).__init__()
        self.target_policy = policyNet()
        self.active_policy = policyNet()

    def get_target_logOutput(self, h, x):
        """Return the element-wise log of the target policy's output."""
        return torch.log(self.target_policy(h, x))

    def get_target_output(self, h, x, scope):
        """Return the action probabilities of the policy selected by ``scope``."""
        if scope == "target":
            return self.target_policy(h, x)
        if scope == "active":
            return self.active_policy(h, x)
        raise ValueError("scope must be 'target' or 'active', got {!r}".format(scope))

    def get_gradient(self, h, x, reward, scope):
        """Reward-weighted gradient of the log-probability of the taken action.

        Args:
            h: State tensor fed to the policy.
            x: Word-embedding tensor fed to the policy.
            reward: Two-element list; the slot of the action taken holds the
                reward and the other slot holds 0.
            scope: ``"target"`` returns the gradients described below.
                ``"active"`` returns the active policy's forward output
                instead — kept as-is for backward compatibility.

        Returns:
            For ``"target"``: a tuple of three tensors, the gradients of
            ``log p(action)`` with respect to the target policy's parameters,
            each multiplied by the action's reward.
        """
        if scope == "target":
            logout = torch.log(self.target_policy(h, x)).view(-1)
            # The taken action is the slot that is NOT the first zero entry.
            # If both entries are zero this picks index 1, whose reward is 0.
            index = (reward.index(0) + 1) % 2
            grads = torch.autograd.grad(logout[index].view(-1), tuple(self.target_policy.parameters()))
            return tuple(g * reward[index] for g in grads)
        if scope == "active":
            return self.active_policy(h, x)
        raise ValueError("scope must be 'target' or 'active', got {!r}".format(scope))

    def assign_active_network_gradients(self, grad1, grad2, grad3):
        """Install copies of the three gradients on the active policy's parameters, in order."""
        gradients = (grad1, grad2, grad3)
        for grad, param in zip(gradients, self.active_policy.parameters()):
            param.grad = None if grad is None else grad.detach().clone()

    def update_target_network(self):
        """Soft update: target <- tau * active + (1 - tau) * target."""
        for src, dst in zip(self.active_policy.parameters(), self.target_policy.parameters()):
            # The arithmetic already yields a fresh tensor, so no extra copy is needed.
            dst.data = src.data * (tau) + dst.data * (1 - tau)

    def assign_active_network(self):
        """Copy the target policy's parameter values into the active policy."""
        for src, dst in zip(self.target_policy.parameters(), self.active_policy.parameters()):
            dst.data = src.data.clone()

def train_model(criticModel, actorModel, train_iter, epoch, RL_train = True, LSTM_train = True):
    """Run one pass over ``train_iter`` and return ``(mean loss, mean accuracy %)``.

    Modes:
        * ``RL_train=False``: train the classifier on the full sentences.
        * ``RL_train=True``: for every sentence, sample ``samplecnt`` keep/skip
          structures, score each with the target classifier (cross-entropy
          plus a length penalty) and update the policy using the losses'
          deviation from their mean as reward. With ``LSTM_train=True`` the
          classifier is additionally trained on the last sampled structure.

    Fixes relative to the previous version:
        * the sentence index is no longer clobbered by the inner sampling
          loops, so predictions land in the right row of ``pred``;
        * the baseline (average sample loss) is reset for every sentence —
          it used to accumulate across the batch because of an
          ``aveLoss``/``aveloss`` typo;
        * the batch size is compared by value (``!=``) instead of identity;
        * in RL mode the reported loss and accuracy are populated (they were
          always 0 / computed from an all-zero ``pred``);
        * averages are taken over the batches actually processed;
        * devices are chosen with a CPU fallback instead of hard-coded CUDA.

    Args:
        criticModel: The ``critic`` module (target/active classifiers).
        actorModel: The ``actor`` module (target/active policies).
        train_iter: Iterator yielding batches with ``text`` (ids, lengths) and ``label``.
        epoch: Unused; kept for interface compatibility.
        RL_train: Train the policy (see above).
        LSTM_train: In RL mode, also train the classifier.

    Returns:
        Tuple of floats: mean per-batch loss and mean per-batch accuracy in
        percent. Batches whose size differs from ``batch_size`` are skipped.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    criticModel.to(device)
    actorModel.to(device)

    # Fresh optimizers on every call, so momentum does not carry over between epochs.
    critic_target_optimizer = torch.optim.SGD(criticModel.target_pred.parameters(), lr=learning_rate, momentum=0.9)
    critic_active_optimizer = torch.optim.SGD(criticModel.active_pred.parameters(), lr=learning_rate, momentum=0.9)
    actor_target_optimizer = torch.optim.SGD(actorModel.target_policy.parameters(), lr=learning_rate, momentum=0.9)
    actor_active_optimizer = torch.optim.SGD(actorModel.active_policy.parameters(), lr=learning_rate, momentum=0.9)

    total_epoch_loss = 0.
    total_epoch_acc = 0.
    processed_batches = 0

    for batch in train_iter:
        text, lengths = batch.text
        if text.size(0) != batch_size:
            # The per-sentence loop below assumes full batches.
            continue
        text = text.to(device)
        target = batch.label.long().to(device)
        pred = torch.zeros(batch_size, 2, device=device)

        # Start every batch with the active networks equal to the target networks.
        criticModel.assign_active_network()
        actorModel.assign_active_network()

        avgloss = 0.

        for row in range(batch_size):
            x = text[row].view(1, -1)
            y = target[row].view(1)
            length = int(lengths[row])

            if RL_train:
                criticModel.train(False)
                actorModel.train()
                actionlist, statelist, losslist = [], [], []
                aveloss = 0.  # per-sentence baseline

                # Sampling and scoring need no autograd graph; the policy
                # gradient is recomputed from the stored states below.
                with torch.no_grad():
                    for _ in range(samplecnt):
                        actions, states, Rinput, Rlength = Sampling_RL(actorModel, criticModel, x, criticModel.wordvector_find(x), length, epsilon, Random=True)
                        actionlist.append(actions)
                        statelist.append(states)
                        out = criticModel(Rinput, scope = "target")
                        loss_ = loss_fn(out, y)
                        # Penalise structures that keep a large share of the sentence.
                        loss_ += (float(Rlength) / length) **2 *0.15
                        aveloss += loss_
                        losslist.append(loss_)
                    aveloss /= samplecnt
                    # Score the last sampled structure unless the classifier
                    # step below overwrites this row.
                    pred[row] = out.view(-1)
                avgloss += aveloss.item()

                if LSTM_train:
                    criticModel.train()
                    actorModel.train()
                    critic_active_optimizer.zero_grad()
                    critic_target_optimizer.zero_grad()
                    prediction = criticModel(Rinput, scope = "target")
                    pred[row] = prediction.detach().view(-1)
                    loss = loss_fn(prediction, y)
                    loss.backward()

                    # Gradients were computed on the target net; apply them to the active net.
                    criticModel.assign_active_network_gradients()
                    critic_active_optimizer.step()

                grad1 = grad2 = grad3 = None
                for sample in range(samplecnt):
                    reward_value = ((losslist[sample] - aveloss) * alpha).cpu().item()
                    for pos in range(len(actionlist[sample])):
                        rr = [0, 0]
                        rr[actionlist[sample][pos]] = reward_value
                        g = actorModel.get_gradient(statelist[sample][pos][0], statelist[sample][pos][1], rr, scope = "target")
                        if grad1 is None:
                            grad1, grad2, grad3 = g[0], g[1], g[2]
                        else:
                            grad1 += g[0]
                            grad2 += g[1]
                            grad3 += g[2]
                actor_target_optimizer.zero_grad()
                actor_active_optimizer.zero_grad()
                actorModel.assign_active_network_gradients(grad1, grad2, grad3)
                actor_active_optimizer.step()
            else:
                criticModel.train()
                actorModel.train(False)
                critic_active_optimizer.zero_grad()
                critic_target_optimizer.zero_grad()
                prediction = criticModel(x, scope = "target")
                pred[row] = prediction.detach().view(-1)
                loss = loss_fn(prediction, y)
                avgloss += loss.item()
                loss.backward()
                criticModel.assign_active_network_gradients()
                critic_active_optimizer.step()

        # Fold the batch's active-network updates back into the target networks.
        if RL_train:
            criticModel.train(False)
            actorModel.train()
            actorModel.update_target_network()
            if LSTM_train:
                criticModel.train()
                actorModel.train()
                criticModel.update_target_network()
        else:
            criticModel.train()
            actorModel.train(False)
            criticModel.assign_target_network()

        avgloss /= batch_size
        num_corrects = (torch.max(pred, 1)[1].view(target.size()) == target).float().sum()
        acc = 100.0 * num_corrects/len(batch)

        total_epoch_loss += avgloss
        total_epoch_acc += acc.item()
        processed_batches += 1

    denominator = max(processed_batches, 1)
    return total_epoch_loss/denominator, total_epoch_acc/denominator

def eval_model(model, val_iter):
    """Evaluate the critic's target classifier on full sentences.

    Batches whose size differs from ``batch_size`` are skipped. Averages are
    taken over the batches actually evaluated (previously over
    ``len(val_iter)``, which also counted skipped batches), and the batch
    size is compared by value rather than identity.

    Args:
        model: The ``critic`` module; called as ``model(text, scope="target")``.
        val_iter: Iterator yielding batches with ``text`` (ids, lengths) and ``label``.

    Returns:
        ``(mean batch loss, mean batch accuracy in percent, classification report string)``.
    """
    total_epoch_loss = 0
    total_epoch_acc = 0
    evaluated_batches = 0
    model.eval()
    # Evaluate wherever the model lives instead of assuming CUDA.
    device = next(model.parameters()).device
    y_pred, y_true = [], []
    with torch.no_grad():
        for batch in val_iter:
            text = batch.text[0]
            if text.size(0) != batch_size:
                continue
            text = text.to(device)
            target = batch.label.long().to(device)
            prediction = model(text, scope = "target")

            pred_idx = torch.max(prediction, 1)[1]
            y_true += target.cpu().tolist()
            y_pred += pred_idx.cpu().tolist()

            loss = loss_fn(prediction, target)
            num_corrects = (pred_idx.view(target.size()) == target).sum()
            acc = 100.0 * num_corrects/len(batch)
            total_epoch_loss += loss.item()
            total_epoch_acc += acc.item()
            evaluated_batches += 1

    denominator = max(evaluated_batches, 1)
    return total_epoch_loss/denominator, total_epoch_acc/denominator, classification_report(y_true, y_pred)

def eval_model_RL(criticModel, actorModel, val_iter):
    """Evaluate the classifier on the structures selected greedily by the policy.

    For each sentence the policy picks tokens deterministically
    (``Random=False``) and the critic's target classifier scores the
    resulting padded sequence.

    Batches whose size differs from ``batch_size`` are skipped. The batch
    loss is the MEAN over the batch's sentences, matching ``eval_model`` and
    ``train_model`` (it was previously the sum), and averages are taken over
    the batches actually evaluated.

    Args:
        criticModel: The ``critic`` module.
        actorModel: The ``actor`` module.
        val_iter: Iterator yielding batches with ``text`` (ids, lengths) and ``label``.

    Returns:
        ``(mean batch loss, mean batch accuracy in percent, classification report string)``.
    """
    total_epoch_loss = 0
    total_epoch_acc = 0
    evaluated_batches = 0
    criticModel.eval()
    actorModel.eval()
    # Evaluate wherever the model lives instead of assuming CUDA.
    device = next(criticModel.parameters()).device
    y_pred, y_true = [], []
    with torch.no_grad():
        for batch in val_iter:
            text, lengths = batch.text
            if text.size(0) != batch_size:
                continue
            text = text.to(device)
            target = batch.label.long().to(device)

            batch_loss = 0.
            pred = torch.zeros(batch_size, 2, device=device)

            for row in range(batch_size):
                x = text[row].view(1, -1)
                y = target[row].view(1)
                length = int(lengths[row])

                _, _, Rinput, _ = Sampling_RL(actorModel, criticModel, x, criticModel.wordvector_find(x), length, epsilon, Random=False)

                prediction = criticModel(Rinput, scope = "target")
                batch_loss += loss_fn(prediction, y).item()
                pred[row] = prediction.view(-1)

            pred_idx = torch.max(pred, 1)[1]
            y_true += target.cpu().tolist()
            y_pred += pred_idx.cpu().tolist()

            num_corrects = (pred_idx.view(target.size()) == target).sum()
            acc = 100.0 * num_corrects/len(batch)
            total_epoch_loss += batch_loss / batch_size
            total_epoch_acc += acc.item()
            evaluated_batches += 1

    denominator = max(evaluated_batches, 1)
    return total_epoch_loss/denominator, total_epoch_acc/denominator, classification_report(y_true, y_pred)

# Instantiate the classifier pair (critic) and the word-selection policy pair (actor).
criticModel = critic()
actorModel = actor()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Using {} to train".format(device))
# .to(device) already places both models. The unconditional .cuda() calls that
# used to follow were redundant on a GPU and raised on a CPU-only runtime.
criticModel.to(device)
actorModel.to(device)

# Loss shared by the training and evaluation functions.
loss_fn = F.cross_entropy
# Best validation accuracy seen so far by the classifier training cell.
best_val_acc = 0.

# Number of epochs for the classifier training cell.
epoch = 30

Using cuda to train


# Training the LSTM Model

In [None]:
# Per-epoch history of the classifier pre-training run.
lstm_train_loss, lstm_train_acc, lstm_val_loss, lstm_val_acc = [], [], [], []
if delay_critic:
    print("Training Classifier")
    early_stopping = EarlyStopping(patience=10, path='')
    # A separate loop variable keeps the global `epoch` (the epoch budget)
    # intact; iterating `for epoch in range(epoch)` overwrote it, so
    # re-running this cell trained for fewer epochs each time.
    for epoch_idx in range(epoch):
        train_loss, train_acc = train_model(criticModel, actorModel, train_iter, epoch_idx, RL_train = False)
        val_loss, val_acc, _= eval_model(criticModel, valid_iter)
        lstm_train_loss.append(train_loss)
        lstm_train_acc.append(train_acc)
        lstm_val_loss.append(val_loss)
        lstm_val_acc.append(val_acc)
        # Keep the weights with the best validation accuracy on disk.
        if val_acc > best_val_acc:
            torch.save(criticModel.state_dict(), models+'classifier.pt')
            best_val_acc = val_acc
            print("Saved Classifier Model with acc: ", val_acc)

        early_stopping(val_loss, criticModel)
        if early_stopping.early_stop:
            print("Early stopping")
            break

        # `.3f` for the validation loss too (it was `3f`, i.e. width 3 with six decimals).
        print(f'Epoch: {epoch_idx+1:02}, Train Loss: {train_loss:.3f}, Train Acc: {train_acc:.2f}%, Val. Loss: {val_loss:.3f}, Val. Acc: {val_acc:.2f}%')


Classifier Performance on Testing Set

In [None]:
# Evaluate the classifier on the held-out test iterator. criticModel is used as
# left by the training loop, i.e. with the last epoch's weights; the best
# checkpoint saved to classifier.pt is not reloaded here.
test_loss, test_acc, class_report = eval_model(criticModel, test_iter)
print("Classifier Performance on Testing Set: {}".format(test_acc))
print("\n\n")
print("Classifier Classification Report on Testing Set\n: {}".format(class_report))

  self.dropout, self.training, self.bidirectional, self.batch_first)


Classifier Performance on Testing Set: 99.68911917098445



Classifier Classification Report on Testing Set
:               precision    recall  f1-score   support

           0       1.00      1.00      1.00       724
           1       1.00      0.99      0.99       241

    accuracy                           1.00       965
   macro avg       1.00      1.00      1.00       965
weighted avg       1.00      1.00      1.00       965



Classifier Training Curves

Use trained Classifier for the RL agent training

In [None]:
# Epoch budget for the RL agent (named instead of a bare literal in the loop).
rl_epochs = 30
best_val_acc1 = 0
# Per-epoch history of the RL training run.
rl_train_loss, rl_train_acc, rl_val_loss, rl_val_acc = [], [], [], []
# Start from the best classifier checkpoint; the classifier stays frozen (LSTM_train=False).
criticModel.load_state_dict(torch.load(models+'classifier.pt'))
early_stopping = EarlyStopping(patience=10, path='')
for i in range(rl_epochs):
    print("Epoch: {}".format(i+1))
    train_loss, train_acc = train_model(criticModel, actorModel, train_iter, i, LSTM_train = False)
    val_loss, val_acc, _= eval_model_RL(criticModel, actorModel,  valid_iter)
    rl_train_loss.append(train_loss)
    rl_train_acc.append(train_acc)
    rl_val_loss.append(val_loss)
    rl_val_acc.append(val_acc)
    # Keep the policy with the best validation accuracy on disk.
    if val_acc > best_val_acc1:
        torch.save(actorModel.state_dict(), models+'rl_agent.pt')
        best_val_acc1 = val_acc

    early_stopping(val_loss, criticModel)
    if early_stopping.early_stop:
        print("Early stopping")
        break

    # `.3f` for the validation loss too (it was `3f`, i.e. width 3 with six decimals).
    print(f'Epoch: {i+1:02}, Train Loss: {train_loss:.3f}, Train Acc: {train_acc:.2f}%, Val. Loss: {val_loss:.3f}, Val. Acc: {val_acc:.2f}%')


RL Agent Performance on the Testing Set

In [None]:
# Reload the checkpoints written by the two training cells: the classifier with
# the best validation accuracy and the policy with the best validation accuracy.
criticModel.load_state_dict(torch.load(models+'classifier.pt'))
actorModel.load_state_dict(torch.load(models+'rl_agent.pt'))

# Evaluate the classifier on the token sequences selected by the policy.
test_loss, test_acc, class_report = eval_model_RL(criticModel, actorModel, test_iter)
print("RL Agent Performance on Testing Set: {}".format(test_acc))
print("\n\n")
print("RL Agent Classification Report on Testing Set\n: {}".format(class_report))

RL Agent Performance on Testing Set: 99.89637305699482



RL Agent Classification Report on Testing Set
:               precision    recall  f1-score   support

           0       1.00      1.00      1.00       724
           1       1.00      1.00      1.00       241

    accuracy                           1.00       965
   macro avg       1.00      1.00      1.00       965
weighted avg       1.00      1.00      1.00       965

