In [1]:
import pandas as pd
import tensorflow as tf
from transformers import (
    DistilBertTokenizer, TFDistilBertForSequenceClassification,
    RobertaTokenizer, TFRobertaForSequenceClassification,
)
from openai import OpenAI




  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the fine-tuned DistilBERT model and tokenizer
distilbert_tokenizer = DistilBertTokenizer.from_pretrained('./distilbert_finetuned')
distilbert_model = TFDistilBertForSequenceClassification.from_pretrained('./distilbert_finetuned')

# Load the fine-tuned RoBERTa model and tokenizer for the first variant
roberta_tokenizer_v1 = RobertaTokenizer.from_pretrained('./roberta_v1_finetuned')
roberta_model_v1 = TFRobertaForSequenceClassification.from_pretrained('./roberta_v1_finetuned')

# Load the fine-tuned RoBERTa model and tokenizer for the second variant
roberta_tokenizer_v2 = RobertaTokenizer.from_pretrained('./roberta_v2_finetuned')
roberta_model_v2 = TFRobertaForSequenceClassification.from_pretrained('./roberta_v2_finetuned')

def ensemble_classify_news_and_evaluate_accuracy(df, column):
    # Lists to store individual model predictions
    distilbert_predictions = []
    roberta_v1_predictions = []
    roberta_v2_predictions = []
    ensemble_predictions = []
    
    for _, row in df.iterrows():
        text_input = row[column]

        # Prepare inputs and get probabilities for DistilBERT
        distilbert_inputs = distilbert_tokenizer(text_input, return_tensors="tf", truncation=True, padding='max_length', max_length=512)
        distilbert_outputs = distilbert_model(distilbert_inputs)
        distilbert_probabilities = tf.nn.softmax(distilbert_outputs.logits, axis=-1)
        distilbert_predicted_class_index = tf.argmax(distilbert_probabilities, axis=-1).numpy()[0]
        distilbert_predictions.append(True if distilbert_predicted_class_index == 1 else False)

        # Prepare inputs and get probabilities for RoBERTa variant 1
        roberta_inputs_v1 = roberta_tokenizer_v1(text_input, return_tensors="tf", truncation=True, padding='max_length', max_length=512)
        roberta_outputs_v1 = roberta_model_v1(roberta_inputs_v1)
        roberta_probabilities_v1 = tf.nn.softmax(roberta_outputs_v1.logits, axis=-1)
        roberta_v1_predicted_class_index = tf.argmax(roberta_probabilities_v1, axis=-1).numpy()[0]
        roberta_v1_predictions.append(True if roberta_v1_predicted_class_index == 1 else False)

        # Prepare inputs and get probabilities for RoBERTa variant 2
        roberta_inputs_v2 = roberta_tokenizer_v2(text_input, return_tensors="tf", truncation=True, padding='max_length', max_length=512)
        roberta_outputs_v2 = roberta_model_v2(roberta_inputs_v2)
        roberta_probabilities_v2 = tf.nn.softmax(roberta_outputs_v2.logits, axis=-1)
        roberta_v2_predicted_class_index = tf.argmax(roberta_probabilities_v2, axis=-1).numpy()[0]
        roberta_v2_predictions.append(True if roberta_v2_predicted_class_index == 1 else False)

        # Ensemble: Average the probabilities from all models
        avg_probabilities = (distilbert_probabilities + roberta_probabilities_v1 + roberta_probabilities_v2) / 3
        predicted_class_index = tf.argmax(avg_probabilities, axis=-1).numpy()[0]
        ensemble_predictions.append(True if predicted_class_index == 1 else False)

    # Adding predictions to the DataFrame
    df[f'DistilBERTPrediction_{column}'] = distilbert_predictions
    df[f'RoBERTaV1Prediction_{column}'] = roberta_v1_predictions
    df[f'RoBERTaV2Prediction_{column}'] = roberta_v2_predictions
    df[f'EnsemblePrediction_{column}'] = ensemble_predictions
    
    # Calculate and print the accuracy for the ensemble predictions
    # correct_predictions = (df[f'EnsemblePrediction_{column}'] == df['Label']).sum()
    # total_predictions = len(df)
    # accuracy = correct_predictions / total_predictions
    # print(f"Accuracy: {accuracy:.4f}")
    
    return df




All model checkpoint layers were used when initializing TFDistilBertForSequenceClassification.

All the layers of TFDistilBertForSequenceClassification were initialized from the model checkpoint at ./distilbert_finetuned.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertForSequenceClassification for predictions without further training.
All model checkpoint layers were used when initializing TFRobertaForSequenceClassification.

All the layers of TFRobertaForSequenceClassification were initialized from the model checkpoint at ./roberta_v1_finetuned.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFRobertaForSequenceClassification for predictions without further training.
All model checkpoint layers were used when initializing TFRobertaForSequenceClassification.

All the layers of TFRobertaForSequenceClassification were initialized from the model checkpoint at ./roberta_v2_

In [3]:
# # Ensure the get_gpt4_response function is adjusted if necessary
def get_gpt4_response(client, prompt):
    chat_completion = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "user", "content": prompt}
        ]
    )
    return chat_completion.choices[0].message.content

key_path = './apiKey.txt'
with open(key_path, 'r') as file:
    key = file.readline().strip()

client = OpenAI(
    api_key = key
)

def get_counterfactual(df, client):
    # Iterate over the DataFrame using .iterrows() for reading; however, to modify the original df, use .loc for assignment
    for index, row in df.iterrows():
        input_text, label = row['text'], row['label']
        
        # Your existing prompt creation and GPT-4 querying logic
        # Simplified task description without source information
        task_description = "classifying tweets on COVID19 as misinformation or reliable information"
        
        prompt1 = f"""
        You are an oracle explanation module in a machine learning pipeline. In the task of {task_description},
        a trained black-box classifier correctly predicted the label
        {label} for the following headline. Think about why the model
        predicted the {label} label and identify the latent features
        that caused the label. List ONLY the latent features
        as a comma separated list, without any explanation.
        Examples of latent features are ‘credibility’, ‘tone’, ‘ambiguity in text’, etc.
        —
        Headline: {input_text}
        —
        Begin!
        """
        latent_features = get_gpt4_response(client, prompt1)
        
        prompt2 = f"""
        Original headline: {input_text}
        Label: {label}

        Identify the words in the headline that are associated
        with the latent features: {latent_features}, and output the
        identified words as a comma separated list.
        """
        identified_words = get_gpt4_response(client, prompt2)
        
        prompt3 = f"""
        Original headline: {input_text}
        Label: {label}

        Identified words associated with latent features: {identified_words}.
        Generate a minimally edited version of the original headline
        by ONLY changing a minimal set of the words you identified, in order to change the label. It is okay if the semantic meaning of the original headline is altered. Make sure the
        generated text makes sense and is plausible. Enclose the
        generated text within <new>tags.
        """
        counterfactual = get_gpt4_response(client, prompt3)
        
        # Update the original DataFrame directly
        df.loc[index, 'Latent Features'] = latent_features
        df.loc[index, 'Identified Words'] = identified_words
        df.loc[index, 'Counterfactual Text'] = counterfactual

    # No need to return a new DataFrame; the original df has been updated
    return df

In [4]:
val_data = pd.read_csv("./Data/Processed/Constraint_Val_Labeled.csv").drop(columns=["Unnamed: 0"])
display(val_data)

Unnamed: 0,text,label,DistilBERTPrediction_Text,RoBERTaV1Prediction_Text,RoBERTaV2Prediction_Text,EnsemblePrediction_Text
0,Chinese converting to Islam after realising th...,False,False,False,False,False
1,11 out of 13 people (from the Diamond Princess...,False,False,False,False,False
2,"COVID-19 Is Caused By A Bacterium, Not Virus A...",False,False,False,False,False
3,Mike Pence in RNC speech praises Donald Trump’...,False,False,False,False,False
4,6/10 Sky's @EdConwaySky explains the latest #C...,True,True,True,True,True
...,...,...,...,...,...,...
2135,Donald Trump wrongly claimed that New Zealand ...,False,False,False,False,False
2136,Current understanding is #COVID19 spreads most...,True,True,True,True,True
2137,Nothing screams “I am sat around doing fuck al...,False,False,False,False,False
2138,Birx says COVID-19 outbreak not under control ...,False,False,True,False,False


In [5]:
test = val_data.sample(10) # change to sample as many rows as needed
test = get_counterfactual(test, client)
display(test)

Unnamed: 0,text,label,DistilBERTPrediction_Text,RoBERTaV1Prediction_Text,RoBERTaV2Prediction_Text,EnsemblePrediction_Text,Latent Features,Identified Words,Counterfactual Text
1811,@ePearce4Q @das_realestate_ In Djibouti where ...,False,False,False,False,False,"credibility, statistics, treatment effectiveness","Chloroquine, treated, statistics, population, ...",<new>@ePearce4Q @das_realestate_ In Djibouti w...
757,How has alcohol consumption changed during loc...,False,True,False,False,False,"ambiguity in text, informal tone, credibility","lockdown, drinking habits, pandemic, short-ter...",<new>How has alcohol consumption changed durin...
1019,354 new cases of #COVID19Nigeria; FCT-78 Lagos...,True,True,True,True,True,"high number of new cases, specific locations m...","new cases, FCT, Lagos, Kaduna, Ebonyi, Oyo, Na...",<new> 354 new cases of #COVID19Nigeria; FCT-78...
1211,A doctor who went to Uttar Pradesh (a state in...,False,False,False,False,False,"controversial tone, religious bias, lack of cr...","attacked, stones, Muslims, succumbed to her in...",<new>A doctor who went to Uttar Pradesh (a sta...
725,Multiple Facebook posts shared hundreds of tim...,False,False,False,False,False,"sensationalism, lack of credible sources, inac...","sensationalism, lack of credible sources, inac...",<new>Multiple Facebook posts shared hundreds o...
1545,Italy's Ministry of Health has discovered that...,False,False,False,False,False,"credibility, misinformation, scientific inaccu...","False, misinformation, scientific inaccuracy",<new>Italy's Ministry of Health has discovered...
1264,Twenty new cases of #COVID19 have been reporte...,True,True,True,True,True,"location, numbers, statistics, time, discharge...","location, numbers, statistics, time, discharge...",<new> Ten new cases of #COVID19 have been repo...
1947,#Vermont has a low percentage (0.6%) of positi...,True,True,True,True,True,"credibility, testing coverage, prevention stra...","aggressive testing, widespread testing, identi...",<new>Vermont has a high percentage (0.6%) of p...
1126,CDC has new info to help camps youth sports K1...,True,True,True,True,True,"credibility, expertise, collaboration, prevent...","CDC, new info, camps, youth sports, K12 school...",<new>CDC has outdated info to help camps youth...
263,Coronavirus: Head of Test and Trace says rise ...,True,True,True,True,True,"credibility, expertise, unexpected event","Test and Trace, rise in demand, COVID tests, u...",<new>Coronavirus: Boss of Test and Trace says ...


In [6]:
test = ensemble_classify_news_and_evaluate_accuracy(test, 'Counterfactual Text')
display(test)

Unnamed: 0,text,label,DistilBERTPrediction_Text,RoBERTaV1Prediction_Text,RoBERTaV2Prediction_Text,EnsemblePrediction_Text,Latent Features,Identified Words,Counterfactual Text,DistilBERTPrediction_Counterfactual Text,RoBERTaV1Prediction_Counterfactual Text,RoBERTaV2Prediction_Counterfactual Text,EnsemblePrediction_Counterfactual Text
1811,@ePearce4Q @das_realestate_ In Djibouti where ...,False,False,False,False,False,"credibility, statistics, treatment effectiveness","Chloroquine, treated, statistics, population, ...",<new>@ePearce4Q @das_realestate_ In Djibouti w...,False,False,False,False
757,How has alcohol consumption changed during loc...,False,True,False,False,False,"ambiguity in text, informal tone, credibility","lockdown, drinking habits, pandemic, short-ter...",<new>How has alcohol consumption changed durin...,True,False,True,True
1019,354 new cases of #COVID19Nigeria; FCT-78 Lagos...,True,True,True,True,True,"high number of new cases, specific locations m...","new cases, FCT, Lagos, Kaduna, Ebonyi, Oyo, Na...",<new> 354 new cases of #COVID19Nigeria; FCT-78...,True,True,True,True
1211,A doctor who went to Uttar Pradesh (a state in...,False,False,False,False,False,"controversial tone, religious bias, lack of cr...","attacked, stones, Muslims, succumbed to her in...",<new>A doctor who went to Uttar Pradesh (a sta...,False,False,False,False
725,Multiple Facebook posts shared hundreds of tim...,False,False,False,False,False,"sensationalism, lack of credible sources, inac...","sensationalism, lack of credible sources, inac...",<new>Multiple Facebook posts shared hundreds o...,False,False,False,False
1545,Italy's Ministry of Health has discovered that...,False,False,False,False,False,"credibility, misinformation, scientific inaccu...","False, misinformation, scientific inaccuracy",<new>Italy's Ministry of Health has discovered...,False,False,False,False
1264,Twenty new cases of #COVID19 have been reporte...,True,True,True,True,True,"location, numbers, statistics, time, discharge...","location, numbers, statistics, time, discharge...",<new> Ten new cases of #COVID19 have been repo...,True,True,True,True
1947,#Vermont has a low percentage (0.6%) of positi...,True,True,True,True,True,"credibility, testing coverage, prevention stra...","aggressive testing, widespread testing, identi...",<new>Vermont has a high percentage (0.6%) of p...,True,True,True,True
1126,CDC has new info to help camps youth sports K1...,True,True,True,True,True,"credibility, expertise, collaboration, prevent...","CDC, new info, camps, youth sports, K12 school...",<new>CDC has outdated info to help camps youth...,True,True,True,True
263,Coronavirus: Head of Test and Trace says rise ...,True,True,True,True,True,"credibility, expertise, unexpected event","Test and Trace, rise in demand, COVID tests, u...",<new>Coronavirus: Boss of Test and Trace says ...,True,True,False,True
