<a href="https://colab.research.google.com/github/navrgithub/NLP_Authorship_Attribution/blob/main/RST_BERT_BILSTM_CNN_task1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import pandas as pd
import torch
from transformers import BertTokenizer, BertModel
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pyRST import RSTParser
import spacy

In [None]:
# Define the RST+BERT+BILSTM-CNN model
class RST_BERT_BILSTM_CNN(nn.Module):
    """Classifier combining a BERT text encoder with a BiLSTM-CNN over RST features.

    The BERT pooled output and a max-pooled CNN summary of the (Bi)LSTM-encoded
    RST feature sequence are concatenated and fed to a linear classifier.

    Args:
        bert_hidden_size: Hidden size of the BERT checkpoint (768 for
            ``bert-base-uncased``); also used as the LSTM and CNN width.
        num_classes: Number of output classes.
        rst_input_size: Feature size of each RST embedding step. Defaults to
            ``bert_hidden_size`` (the previous hard-wired behaviour).
        bidirectional: Run the LSTM in both directions (the "Bi" in BiLSTM).
        bert_model_name: Hugging Face checkpoint to load.

    Raises:
        ValueError: If ``bert_hidden_size`` does not match the checkpoint's hidden size.
    """

    def __init__(self, bert_hidden_size, num_classes, rst_input_size=None,
                 bidirectional=True, bert_model_name='bert-base-uncased'):
        super().__init__()
        if rst_input_size is None:
            rst_input_size = bert_hidden_size

        self.bert = BertModel.from_pretrained(bert_model_name)
        if self.bert.config.hidden_size != bert_hidden_size:
            raise ValueError(
                f"bert_hidden_size={bert_hidden_size} does not match the hidden size "
                f"{self.bert.config.hidden_size} of '{bert_model_name}'"
            )

        self.lstm = nn.LSTM(rst_input_size, bert_hidden_size,
                            batch_first=True, bidirectional=bidirectional)
        # A bidirectional LSTM concatenates forward and backward states.
        lstm_out_size = bert_hidden_size * (2 if bidirectional else 1)
        # padding=1 keeps the kernel-3 convolution valid for RST sequences
        # shorter than 3 steps (unpadded it raises on such inputs).
        self.conv1d = nn.Conv1d(lstm_out_size, bert_hidden_size, kernel_size=3, padding=1)
        # The classifier sees [BERT pooled output ; CNN summary] -> 2 * hidden size.
        self.fc = nn.Linear(2 * bert_hidden_size, num_classes)

    def forward(self, rst_embeddings, input_ids, attention_mask):
        """Compute class logits for a batch.

        Args:
            rst_embeddings: Float tensor ``(batch, seq_len, rst_input_size)``; the
                LSTM is ``batch_first`` and its output is permuted as a 3-D tensor.
            input_ids: BERT token ids as produced by the tokenizer.
            attention_mask: BERT attention mask matching ``input_ids``.

        Returns:
            Unnormalised logits of shape ``(batch, num_classes)``.
        """
        # transformers v4+ returns a ModelOutput by default; tuple-unpacking it
        # iterates its *keys* (yielding strings). Integer indexing works for
        # both ModelOutput and legacy tuple returns; index 1 is the pooler output.
        pooled_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)[1]

        lstm_output, _ = self.lstm(rst_embeddings)
        # Conv1d expects (batch, channels, length).
        conv_output = F.relu(self.conv1d(lstm_output.permute(0, 2, 1)))
        # Global max-pool over the sequence dimension -> (batch, bert_hidden_size).
        conv_output = conv_output.max(dim=2).values

        concat_output = torch.cat((pooled_output, conv_output), dim=1)
        return self.fc(concat_output)


In [None]:
# Define a custom dataset
class CustomDataset(Dataset):
    """Map-style dataset over a DataFrame of RST features, texts and labels.

    Each item is ``(rst_embeddings, generation_text, label)`` read from the
    ``'RST_Embeddings'``, ``'Generation'`` and ``'label'`` columns.

    Args:
        data: DataFrame containing the three columns above.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # DataLoader samplers yield positions 0..len-1, so index positionally
        # with .iloc; label-based ``series[index]`` raises KeyError whenever the
        # DataFrame index is not a default RangeIndex (e.g. after filtering).
        data = self.data
        return (
            data['RST_Embeddings'].iloc[index],
            data['Generation'].iloc[index],
            data['label'].iloc[index],
        )

In [None]:
# Load the dataframe
df = pd.read_csv('final_task1_data.csv')

# Initialize the RSTParser and spaCy English model
rst_parser = RSTParser()
nlp = spacy.load('en_core_web_sm')

# Extract RST features from the original (cased) text. spaCy's statistical
# pipeline uses capitalisation cues (e.g. for sentence boundaries), so parsing
# lowercased text likely degrades the discourse structure the RST parser sees.
# nlp.pipe processes the texts in batches instead of one nlp() call per text.
df['RST_Embeddings'] = [
    rst_parser.extract_features(rst_parser.parse(doc))
    for doc in nlp.pipe(df['Generation'].tolist())
]

# Lowercase the text afterwards for the uncased BERT tokenizer used later.
df['Generation'] = df['Generation'].str.lower()

# Save the updated dataframe.
# NOTE(review): to_csv writes each RST feature object as its string repr, so
# reloading this CSV will not restore the original feature objects.
df.to_csv('rst_embd_updated_data_task1.csv', index=False)

In [None]:
# Wrap the DataFrame so a DataLoader can iterate over it
dataset = CustomDataset(df)

# WordPiece tokenizer for the same checkpoint the model loads
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Encode every text at once: pad to the longest text, truncate to the model limit
encoded_inputs = tokenizer(
    text=df['Generation'].tolist(),
    padding=True,
    truncation=True,
    return_tensors='pt',
)

In [None]:
# Training hyperparameters: examples per batch and passes over the data
batch_size, num_epochs = 16, 10
# Model dimensions: bert-base-uncased hidden width and number of output classes
bert_hidden_size, num_classes = 768, 2

In [None]:
# Batch the dataset and reshuffle it every epoch
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Instantiate the RST+BERT+BILSTM-CNN model
model = RST_BERT_BILSTM_CNN(bert_hidden_size, num_classes)

# Cross-entropy on the raw logits returned by the model's final linear layer
criterion = nn.CrossEntropyLoss()

# AdamW's default lr (1e-3) is far above the 2e-5 - 5e-5 range recommended for
# fine-tuning BERT and can destabilise the pretrained encoder. Give the encoder
# (the model's `bert` submodule) a fine-tuning rate, and keep 1e-3 for the
# freshly initialised LSTM/CNN/FC head.
bert_params = [p for name, p in model.named_parameters() if name.startswith('bert.')]
head_params = [p for name, p in model.named_parameters() if not name.startswith('bert.')]
optimizer = torch.optim.AdamW([
    {'params': bert_params, 'lr': 2e-5},
    {'params': head_params, 'lr': 1e-3},
])

In [None]:
# Training loop
for epoch in range(num_epochs):
    # Re-enable dropout every epoch: model.eval() is called after training
    # (and by evaluation), so re-running this cell would otherwise train in eval mode.
    model.train()
    epoch_loss = 0.0

    for rst_embeddings, generations, labels in dataloader:
        # Default collation delivers the batch's texts as a list of strings,
        # which cannot index the pre-computed `encoded_inputs` tensors, so
        # tokenize this batch directly.
        # NOTE(review): assumes rst_embeddings collates into a float tensor of
        # shape (batch, seq_len, features); variable-length RST features need a
        # custom collate_fn with padding - confirm against RSTParser's output.
        batch_encoding = tokenizer(list(generations), padding=True, truncation=True, return_tensors='pt')
        input_ids = batch_encoding['input_ids']
        attention_mask = batch_encoding['attention_mask']

        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        # Forward pass and loss on the raw logits
        outputs = model(rst_embeddings, input_ids, attention_mask)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"Epoch {epoch + 1}/{num_epochs} - mean loss: {epoch_loss / max(len(dataloader), 1):.4f}")

# Evaluate the model
model.eval()

In [None]:
def _binary_classification_metrics(predicted, labels, positive_label=1):
    """Return ``(accuracy, precision, f1)`` for 1-D tensors of class indices.

    Precision and F1 treat ``positive_label`` as the positive class, mirroring
    scikit-learn's binary ``precision_score``/``f1_score`` defaults. A 0/0
    ratio is reported as 0.0; empty inputs yield all zeros.
    """
    total = labels.numel()
    if total == 0:
        return 0.0, 0.0, 0.0
    pred_pos = predicted == positive_label
    actual_pos = labels == positive_label
    tp = (pred_pos & actual_pos).sum().item()
    fp = (pred_pos & ~actual_pos).sum().item()
    fn = (~pred_pos & actual_pos).sum().item()
    accuracy = (predicted == labels).sum().item() / total
    precision = tp / (tp + fp) if tp + fp else 0.0
    # F1 = 2PR / (P + R) = 2TP / (2TP + FP + FN)
    f1 = 2 * tp / (2 * tp + fp + fn) if 2 * tp + fp + fn else 0.0
    return accuracy, precision, f1


# Evaluate over the whole DataLoader. Predictions are accumulated across
# batches so the metrics cover every example, not just the last batch.
# The metrics are computed in torch because sklearn's precision_score/f1_score
# are not imported in this notebook.
# NOTE(review): `dataloader` wraps the training data, so these are training-set
# metrics; evaluate on a held-out split for a meaningful score.
model.eval()
all_predicted, all_labels = [], []
with torch.no_grad():
    for rst_embeddings, generations, labels in dataloader:
        # Texts arrive as a list of strings; tokenize this batch directly.
        batch_encoding = tokenizer(list(generations), padding=True, truncation=True, return_tensors='pt')
        outputs = model(rst_embeddings, batch_encoding['input_ids'], batch_encoding['attention_mask'])
        # Predicted class = index of the largest logit
        all_predicted.append(outputs.argmax(dim=1))
        all_labels.append(labels)

if all_labels:
    accuracy, precision, f1 = _binary_classification_metrics(torch.cat(all_predicted), torch.cat(all_labels))
else:
    accuracy = precision = f1 = 0.0

# Print performance metrics
print(f"Accuracy: {accuracy}")
print(f"Precision: {precision}")
print(f"F1 Score: {f1}")