In [23]:
import pandas as pd
import numpy as np

In [2]:
# nltk.download('stopwords')
# nltk.download('wordnet')

In [3]:
train_data = pd.read_csv('./Constraint_Train.csv')
test_data = pd.read_csv('./Constraint_Val.csv')

In [5]:
train_data.head()

Unnamed: 0,id,tweet,label
0,1,The CDC currently reports 99031 deaths. In gen...,real
1,2,States reported 1121 deaths a small rise from ...,real
2,3,Politically Correct Woman (Almost) Uses Pandem...,fake
3,4,#IndiaFightsCorona: We have 1524 #COVID testin...,real
4,5,Populous states can generate large case counts...,real


In [6]:
test_data.head()

Unnamed: 0,id,tweet,label
0,1,Chinese converting to Islam after realising th...,fake
1,2,11 out of 13 people (from the Diamond Princess...,fake
2,3,"COVID-19 Is Caused By A Bacterium, Not Virus A...",fake
3,4,Mike Pence in RNC speech praises Donald Trump’...,fake
4,5,6/10 Sky's @EdConwaySky explains the latest #C...,real


In [33]:
train_data.shape, test_data.shape

((6420, 3), (2140, 3))

In [8]:
import re
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from sklearn.base import BaseEstimator, TransformerMixin

class Preprocessor(BaseEstimator, TransformerMixin):
    def __init__(self):
        self.stop_words = set(stopwords.words('english'))
        self.lemmatizer = WordNetLemmatizer()
        
    def clean_text(self, text):
        # Remove URLs and special characters
        text = re.sub(r"http\S+|www\S+|https\S+", '', text, flags=re.MULTILINE)
        text = re.sub(r'\@w+|\#', '', text)
        text = re.sub(r"[^a-zA-Z]", " ", text)
        
        # Lowercase, remove stopwords and lemmatize
        words = text.lower().split()
        words = [self.lemmatizer.lemmatize(word) for word in words if word not in self.stop_words]
        
        return " ".join(words)
    
    def fit(self, X, y=None):
        return self
    
    def transform(self, X, y=None):
        return X.apply(self.clean_text)

In [12]:
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import BernoulliNB
from sklearn.pipeline import Pipeline
from sklearn.metrics import balanced_accuracy_score
from sklearn.base import BaseEstimator, TransformerMixin

In [13]:


X_train, X_test, y_train, y_test = train_test_split(train_data['tweet'], train_data['label'], test_size=0.2, random_state=42)

In [25]:
pipeline = Pipeline([
    ('preprocess', Preprocessor()),
    ('vectorizer', TfidfVectorizer(binary=True)),
    ('classifier', BernoulliNB())
])
pipeline

In [26]:
pipeline.fit(X_train, y_train)

vectorizer = pipeline.named_steps['vectorizer']
classifier = pipeline.named_steps['classifier']

y_pred = pipeline.predict(X_test)

In [27]:
balanced_accuracy = balanced_accuracy_score(y_test, y_pred)
print("Accuracy:", balanced_accuracy)

Accuracy: 0.900923794287498


In [29]:
# Get feature names (words) and their corresponding log probabilities
feature_names = vectorizer.get_feature_names_out()
log_probs = classifier.feature_log_prob_


# Get the words that are more important for each class
real_class_important_words = sorted(zip(log_probs[0], feature_names), reverse=True)[:10]
fake_class_important_words = sorted(zip(log_probs[1], feature_names), reverse=True)[:10]


In [31]:
print("Real class important words:")
real_class_important_words

Real class important words:


[(-0.8936376633706811, 'covid'),
 (-0.942378274753997, 'coronavirus'),
 (-2.372273374801094, 'people'),
 (-2.5792440668697028, 'pandemic'),
 (-2.589996858645965, 'trump'),
 (-2.6063459966474944, 'say'),
 (-2.6118556524584635, 'virus'),
 (-2.6342029511504608, 'claim'),
 (-2.645566709800775, 'new'),
 (-2.766927566805043, 'video')]

In [32]:
print("Fake class important words:")
fake_class_important_words

Fake class important words:


[(-0.5708037953891489, 'covid'),
 (-1.1508113980556658, 'case'),
 (-1.5723626439154543, 'new'),
 (-1.6071965965910868, 'state'),
 (-1.8182862235614836, 'test'),
 (-1.8849775980601562, 'number'),
 (-1.9303254180386151, 'death'),
 (-2.0305445341960233, 'total'),
 (-2.083188267681445, 'india'),
 (-2.1842204929130133, 'day')]