In [1]:
!pip install wordcloud

from datasets import load_from_disk
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from wordcloud import WordCloud
from collections import Counter
import re
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import nltk

# Fetch the NLTK resources used by the tokenizing / stopword-filtering helpers below.
# The download is skipped when the package is already up to date (see the cell output).
# NOTE(review): newer NLTK releases may additionally need 'punkt_tab' for word_tokenize - confirm against the installed version.
nltk.download('punkt')
nltk.download('stopwords')



[nltk_data] Downloading package punkt to
[nltk_data]     /Users/imenbenammar/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/imenbenammar/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [2]:
# Defining functions to use

def load_full_dataset(dataset_path):
    """
    Reads a dataset saved on disk and merges its three splits into one shuffled DataFrame.

    Input:
    - dataset_path (str): Path handed to load_from_disk; the loaded object must expose
      'train', 'validation' and 'test' entries that support to_pandas().
    Output:
    - (pd.DataFrame): Rows of train, validation and test (in that order) concatenated,
      then shuffled with a fixed seed (42) and re-indexed from 0.
    """
    splits = load_from_disk(dataset_path)
    # convert every split to pandas (same access order as before: train, test, validation)
    frames = {name: splits[name].to_pandas() for name in ('train', 'test', 'validation')}
    stacked = pd.concat(
        [frames['train'], frames['validation'], frames['test']],
        ignore_index=True,
    )
    # fixed random_state keeps the shuffle reproducible between runs
    shuffled = stacked.sample(frac=1, random_state=42)
    return shuffled.reset_index(drop=True)
    
def display_dataset_info(df):
    """
    Prints a quick overview of a dataset: its info() summary followed by its first 5 rows.

    Input:
    - df (pd.DataFrame): The dataset.

    Output:
    - Nothing is returned; the summary, the head of the frame and two divider lines
      are written to standard output.
    """
    divider = "----------------------------------------------------------------"

    print("Dataset info: ")
    df.info()
    # divider, then the heading for the preview section
    print(divider, "\n First 5 rows of the dataset: ", sep="\n")
    # first 5 rows, closed off by a second divider
    print(df.head(), divider, sep="\n")

def count_nulls(df):
    """
    Reports how many null values each column holds.

    Input:
    - df (pd.DataFrame): The dataset.

    Output:
    - (pd.Series): Number of nulls per column, indexed by column name.
    """
    # isnull() gives a boolean frame; summing it column-wise counts the True cells
    return df.isnull().sum()

def clean_and_tokenize(text, stop_words):
    """
    Lower-cases a text, strips everything except letters, digits and whitespace,
    tokenizes it and drops stopwords.

    Input:
    - text (str): The text to clean and tokenize.
    - stop_words (collection of str): Words to discard (any container supporting `in`).

    Output:
    - (list): The remaining tokens, in their original order.
    """
    lowered = text.lower()
    # characters outside [A-Za-z0-9] and whitespace are deleted, not replaced by a space
    stripped = re.sub(r'[^A-Za-z0-9\s]', '', lowered)
    kept = []
    for token in word_tokenize(stripped):
        if token in stop_words:
            continue
        kept.append(token)
    return kept

In [3]:
# define functions for datasets 1 and 2 

def plot_toxicity_subtype_frequency(df, dataset=None):
    """
    Draws a bar chart of how many rows have a value above zero in each toxicity subtype column.

    Input:
    - df (pd.DataFrame): The dataset containing one column per toxicity subtype.
    - dataset (str): Name of the dataset ('Jigsaw' or 'civil_comments'); selects which
      subtype columns are read.

    Output:
    - Displays a bar plot with one bar per toxicity subtype.

    Raises:
    - ValueError: If dataset is neither 'Jigsaw' nor 'civil_comments'.
    """
    if dataset not in ('Jigsaw', 'civil_comments'):
        raise ValueError("Dataset must be either 'Jigsaw' or 'civil_comments'")

    if dataset == 'Jigsaw':
        toxicity_columns = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
    else:
        toxicity_columns = ['toxicity', 'severe_toxicity', 'obscene', 'threat', 'insult', 'identity_attack', 'sexual_explicit']

    # number of rows with a value > 0, one entry per subtype column
    frequencies = [(df[column] > 0).sum() for column in toxicity_columns]

    # plot
    plt.figure(figsize=(7, 4))
    sns.barplot(x=toxicity_columns, y=frequencies, palette="viridis")
    plt.title('Frequency of Each Toxicity Subtype')
    plt.xlabel('Toxicity Subtype')
    plt.ylabel('Frequency')
    plt.xticks(rotation=45)
    plt.grid()
    plt.show()


def plot_toxicity_percentage(df, dataset=None):
    """
    Draws a pie chart with the share of toxic versus non-toxic comments.
    A comment counts as toxic when the sum of its toxicity label columns is above zero.

    Input:
    - df (pd.DataFrame): The dataset containing toxicity labels. A boolean 'is_toxic'
      column is added to (or overwritten in) this frame.
    - dataset (str): Name of the dataset ('Jigsaw' or 'civil_comments'); selects which
      label columns are summed.

    Output:
    - Displays a pie chart showing the percentage of toxic and non-toxic comments.

    Raises:
    - ValueError: If dataset is neither 'Jigsaw' nor 'civil_comments'.
    """
    if dataset not in ('Jigsaw', 'civil_comments'):
        raise ValueError("Dataset must be either 'Jigsaw' or 'civil_comments'")

    if dataset == 'Jigsaw':
        toxicity_labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
    else:
        toxicity_labels = ['toxicity', 'severe_toxicity', 'obscene', 'threat', 'insult', 'identity_attack', 'sexual_explicit']

    # flag every row where at least one label contributes a positive amount
    df['is_toxic'] = df[toxicity_labels].sum(axis=1) > 0
    tally = df['is_toxic'].value_counts()

    # slice order is fixed (non-toxic first); a class with no rows gets size 0
    non_toxic_total = tally.get(False, 0)
    toxic_total = tally.get(True, 0)

    # plot
    plt.figure(figsize=(7, 4))
    plt.pie(
        [non_toxic_total, toxic_total],
        explode=(0.1, 0),
        labels=['Non-Toxic', 'Toxic'],
        colors=['#3b528b', '#5ec962'],
        autopct='%1.1f%%',
        shadow=True,
        startangle=90,
    )
    plt.title('Percentage of Toxic and Non-Toxic Comments')
    plt.axis('equal')
    plt.show()


def plot_wordclouds(df, dataset=None):
    """
    Generates and displays word clouds for toxic and non-toxic comments.
    A comment counts as toxic when the sum of its toxicity label columns is above zero.

    Input:
    - df (pd.DataFrame): The dataset containing the comments and toxicity labels. A boolean
      'is_toxic' column is added to (or overwritten in) this frame.
    - dataset (str): Name of the dataset ('Jigsaw' or 'civil_comments'); selects the label
      columns and the text column.

    Output:
    - Displays two panels side by side, one for toxic and one for non-toxic comments.
      A panel whose group has no text shows a short placeholder message instead of a cloud.

    Raises:
    - ValueError: If dataset is neither 'Jigsaw' nor 'civil_comments'.
    """
    # label columns and text column are resolved together, in a single place
    if dataset == 'Jigsaw':
        toxicity_labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
        text_col = 'comment_text'
    elif dataset == 'civil_comments':
        toxicity_labels = ['toxicity', 'severe_toxicity', 'obscene', 'threat', 'insult', 'identity_attack', 'sexual_explicit']
        text_col = 'text'
    else:
        raise ValueError("Dataset must be either 'Jigsaw' or 'civil_comments'")

    # create a new column 'is_toxic' and set it to true if one toxicity subtype is present
    df['is_toxic'] = df[toxicity_labels].sum(axis=1) > 0

    # (panel title, comments of that group, colormap)
    panels = [
        ('Word Cloud of Toxic Comments', df.loc[df['is_toxic'], text_col], 'Reds'),
        ('Word Cloud of Non-Toxic Comments', df.loc[~df['is_toxic'], text_col], 'Greens'),
    ]

    plt.figure(figsize=(14, 7))
    for position, (title, comments, colormap) in enumerate(panels, start=1):
        # missing texts are skipped so str.join does not fail on NaN
        corpus = ' '.join(comments.dropna().astype(str))
        plt.subplot(1, 2, position)
        plt.title(title, fontsize=16)
        if corpus.strip():
            cloud = WordCloud(width=800, height=400, background_color='white', colormap=colormap, max_words=100).generate(corpus)
            plt.imshow(cloud, interpolation='bilinear')
        else:
            # WordCloud cannot build a cloud from an empty corpus, so show a placeholder instead
            plt.text(0.5, 0.5, 'No comments available', ha='center', va='center')
        plt.axis('off')
    plt.tight_layout()
    plt.show()


def analyze_comment_len(df, dataset=None):
    """
    Summarizes comment lengths (in characters) and contrasts toxic with non-toxic comments.
    A comment counts as toxic when the sum of its toxicity label columns is above zero.

    Input:
    - df (pd.DataFrame): Dataset containing comments and toxicity labels. The columns
      'comment_length' and 'is_toxic' are added to (or overwritten in) this frame.
    - dataset (str): Dataset name ('Jigsaw' or 'civil_comments').

    Output:
    - Prints descriptive statistics of the lengths and displays three plots: the length
      distribution, a box plot per group (toxic / non-toxic) and the average length per group.

    Raises:
    - ValueError: If dataset is neither 'Jigsaw' nor 'civil_comments'.
    """
    if dataset not in ('Jigsaw', 'civil_comments'):
        raise ValueError("Dataset must be either 'Jigsaw' or 'civil_comments'")

    if dataset == 'Jigsaw':
        text_col = 'comment_text'
        toxicity_labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
    else:
        text_col = 'text'
        toxicity_labels = ['toxicity', 'severe_toxicity', 'obscene', 'threat', 'insult', 'identity_attack', 'sexual_explicit']

    group_names = ['Non-Toxic', 'Toxic']

    # number of characters per comment
    df['comment_length'] = df[text_col].apply(len)

    # descriptive statistics, followed by a blank spacer
    print("Basic Statistics for Comment Lengths:")
    print(df['comment_length'].describe())
    print("\n")

    # 1) histogram of all comment lengths
    plt.figure(figsize=(7,4))
    sns.histplot(df['comment_length'], bins=20, kde=True, color='#21918c')
    plt.title('Distribution of Comment Lengths')
    plt.xlabel('Comment Length')
    plt.ylabel('Frequency')
    plt.grid(True)
    plt.show()

    # 2) box plot of lengths, toxic versus non-toxic
    df['is_toxic'] = df[toxicity_labels].sum(axis=1) > 0
    plt.figure(figsize=(7, 4))
    sns.boxplot(data=df, x='is_toxic', y='comment_length', palette='viridis')
    plt.xticks([0, 1], group_names)
    plt.title('Comment Length Comparison: Toxic vs. Non-Toxic')
    plt.xlabel('Comment Type')
    plt.ylabel('Comment Length')
    plt.grid(True)
    plt.show()

    # 3) mean length per group (toxic / non-toxic)
    plt.figure(figsize=(7, 4))
    mean_by_group = df.groupby('is_toxic')['comment_length'].mean()
    toxicity_avg_length = mean_by_group.reset_index()
    sns.barplot(data=toxicity_avg_length, x='is_toxic', y='comment_length', palette='viridis')
    plt.xticks([0, 1], group_names)
    plt.title('Average Comment Length: Toxic vs. Non-Toxic')
    plt.xlabel('Comment Type')
    plt.ylabel('Average Comment Length')
    plt.grid(True)
    plt.show()


def plot_word_frequency(df, dataset=None, threshold=0.8, top_n=10):
    """
    For every toxicity label, plots the top N most frequent words among the comments whose
    value for that label is strictly above the threshold.

    Input:
    - df (pd.DataFrame): Dataset containing comments and toxicity labels.
    - dataset (str): Name of the dataset ('Jigsaw' or 'civil_comments'); selects the label
      columns and the text column.
    - threshold (float, optional): A comment is used for a label only if its value for that
      label is greater than this (default is 0.8).
    - top_n (int, optional): Number of most common words to display (default is 10).

    Output:
    - Displays one horizontal bar plot per label; prints a message instead when a label has
      no words to show.

    Raises:
    - ValueError: If dataset is neither 'Jigsaw' nor 'civil_comments'.
    """
    if dataset not in ('Jigsaw', 'civil_comments'):
        raise ValueError("Dataset must be either 'Jigsaw' or 'civil_comments'")

    if dataset == 'Jigsaw':
        text_col = 'comment_text'
        toxicity_labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
    else:
        text_col = 'text'
        toxicity_labels = ['toxicity', 'severe_toxicity', 'obscene', 'threat', 'insult', 'identity_attack', 'sexual_explicit']

    # English stopwords, as a set for fast membership checks
    stop_words = set(stopwords.words('english'))

    for label in toxicity_labels:
        # comments whose score for this label exceeds the threshold
        selected_comments = df.loc[df[label] > threshold, text_col]

        # tally the cleaned tokens comment by comment
        word_counts = Counter()
        for comment in selected_comments:
            word_counts.update(clean_and_tokenize(comment, stop_words))

        top_words = word_counts.most_common(top_n)
        if not top_words:
            print(f"No words found for {label.capitalize()} with threshold > {threshold}.")
            continue

        # plot
        words, counts = zip(*top_words)
        plt.figure(figsize=(8, 6))
        sns.barplot(x=list(counts), y=list(words), palette='viridis')
        plt.title(f'Most Common Words for {label.capitalize()} (Threshold > {threshold})')
        plt.xlabel('Frequency')
        plt.ylabel('Words')
        plt.grid(True)
        plt.show()


def plot_toxicity_vs_comment_length(df,dataset=None):
    """
    Draws one scatter plot per toxicity label, comparing comment length (in characters)
    with the value of that label.

    Input:
    - df (pd.DataFrame): Dataset containing comments and the toxicity labels. If it has no
      'comment_length' column yet, one is added to this frame.
    - dataset (str): Name of the dataset ('Jigsaw' or 'civil_comments'); selects the label
      columns and the text column.

    Output:
    - Displays a scatter plot of comment length vs. label value for every toxicity label.

    Raises:
    - ValueError: If dataset is neither 'Jigsaw' nor 'civil_comments'.
    """
    if dataset not in ('Jigsaw', 'civil_comments'):
        raise ValueError("Dataset must be either 'Jigsaw' or 'civil_comments'")

    if dataset == 'Jigsaw':
        text_col = 'comment_text'
        toxicity_labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
    else:
        text_col = 'text'
        toxicity_labels = ['toxicity', 'severe_toxicity', 'obscene', 'threat', 'insult', 'identity_attack', 'sexual_explicit']

    # reuse an existing length column; compute it only when it is missing
    length_known = 'comment_length' in df.columns
    if not length_known:
        df['comment_length'] = df[text_col].apply(len)

    for label in toxicity_labels:
        pretty_label = label.capitalize()
        # plot
        plt.figure(figsize=(10, 6))
        sns.scatterplot(x=df['comment_length'], y=df[label], alpha=0.3, color='#21918c')
        plt.title(f'Scatter Plot: Comment Length vs. {pretty_label}')
        plt.xlabel('Comment Length')
        plt.ylabel(f'{pretty_label} Level')
        plt.grid(True)
        plt.show()




In [4]:
# define functions for dataset 3

def plot_sentiment_frequency_sst2(df):
    """
    Plots the distribution of sentiment labels in the SST-2 dataset.

    Input:
    - df (pd.DataFrame): Dataset containing a 'label' column with sentiment values (0 for negative, 1 for positive).
      Rows with any other label value are not plotted.

    Output:
    - Displays a bar plot showing the frequency of negative and positive sentiments
      (a sentiment without rows is drawn with height 0).
    """
    # Fix the bar order to [negative, positive]. The tick labels below are assigned by
    # position, so a missing class or an extra label value (e.g. -1) must not shift the bars.
    sentiment_counts = df['label'].value_counts().reindex([0, 1], fill_value=0)
    #plot
    plt.figure(figsize=(7, 4))
    sns.barplot(x=sentiment_counts.index, y=sentiment_counts.values, order=[0, 1], palette='viridis')
    plt.xticks([0, 1], ['Negative', 'Positive'])
    plt.title('Sentiment Distribution')
    plt.xlabel('Sentiment')
    plt.ylabel('Frequency')
    plt.grid(True)
    plt.show()

def plot_sentiment_percentage_sst2(df):
    """
    Plots a pie chart showing the percentage of positive and negative sentiments in the SST-2 dataset.

    Input:
    - df (pd.DataFrame): Dataset containing a 'label' column with sentiment values (0 for negative, 1 for positive).
      Rows with any other label value are ignored.

    Output:
    - Displays a pie chart of the sentiment distribution, or prints a message when there is
      no row labelled 0 or 1.
    """
    # value_counts() orders by frequency, not by label value, so the counts are re-ordered
    # to [negative, positive] to line up with the slice names below. A missing class gets 0.
    label_counts = df['label'].value_counts().reindex([0, 1], fill_value=0)
    labels = ['Negative', 'Positive']
    sizes = label_counts.values

    # a pie chart cannot be drawn from all-zero sizes
    if sizes.sum() == 0:
        print("No sentences labelled 0 or 1; nothing to plot.")
        return

    # plot
    plt.figure(figsize=(7, 4))
    explode = (0.1, 0) 
    plt.pie(sizes, explode=explode, labels=labels, autopct='%1.1f%%', startangle=90, colors=['#3b528b', '#5ec962'])
    plt.title('Percentage of Positive and Negative Sentences')
    plt.axis('equal') 
    plt.show()

    
def plot_wordclouds_sst2(df):
    """
    Builds and shows two word clouds for the SST-2 dataset: positive sentences on the left,
    negative sentences on the right.

    Input:
    - df (pd.DataFrame): Dataset containing a 'label' column (0 for negative, 1 for positive) 
      and a 'sentence' column with text data.

    Output:
    - Displays two word clouds, one for positive sentences and one for negative sentences.
    """
    # all sentences of a sentiment joined into one text (positive first, then negative)
    joined = {value: ' '.join(df.loc[df['label'] == value, 'sentence']) for value in (1, 0)}

    # the negative cloud is generated first, then the positive one
    negative_cloud = WordCloud(width=800, height=400, background_color='white', colormap='Reds').generate(joined[0])
    positive_cloud = WordCloud(width=800, height=400, background_color='white', colormap='Greens').generate(joined[1])

    # plot: one panel per cloud
    panels = [
        (positive_cloud, 'Word Cloud: Positive Sentences'),
        (negative_cloud, 'Word Cloud: Negative Sentences'),
    ]
    plt.figure(figsize=(14, 7))
    for slot, (cloud, heading) in enumerate(panels, start=1):
        plt.subplot(1, 2, slot)
        plt.imshow(cloud, interpolation='bilinear')
        plt.title(heading, fontsize=16)
        plt.axis('off')
    plt.tight_layout()
    plt.show()


        
def analyze_sentence_len_sst2(df):
    """
    Summarizes and visualizes sentence lengths (in characters) in the SST-2 dataset.

    Input:
    - df (pd.DataFrame): Dataset containing a 'sentence' column with text data 
      and a 'label' column (0 for negative, 1 for positive). A 'sentence_length' column
      is added to (or overwritten in) this frame.

    Output:
    - Prints descriptive statistics of the lengths, then displays a histogram of their
      distribution and a box plot comparing negative and positive sentences.
    """
    # number of characters per sentence
    char_counts = df['sentence'].apply(len)
    df['sentence_length'] = char_counts

    # descriptive statistics, followed by a blank spacer
    print("Basic Statistics for Sentence Lengths:")
    print(df['sentence_length'].describe())
    print("\n")

    # 1) histogram of all sentence lengths
    plt.figure(figsize=(10, 4))
    sns.histplot(df['sentence_length'], bins=20, kde=True, color='#21918c')
    plt.title('Distribution of Sentence Lengths')
    plt.xlabel('Sentence Length')
    plt.ylabel('Frequency')
    plt.grid(True)
    plt.show()

    # 2) box plot of lengths per sentiment
    sentiment_names = ['Negative', 'Positive']
    plt.figure(figsize=(7, 4))
    sns.boxplot(data=df, x='label', y='sentence_length', palette='viridis')
    plt.xticks([0, 1], sentiment_names)
    plt.title('Sentence Length Comparison: Negative vs. Positive')
    plt.xlabel('Sentiment')
    plt.ylabel('Sentence Length')
    plt.grid(True)
    plt.show()

def plot_avg_sentence_length_by_sentiment_sst2(df):
    """
    Draws a bar plot of the mean sentence length (in characters) per sentiment in the SST-2 dataset.

    Input:
    - df (pd.DataFrame): Dataset containing a 'sentence' column with text data 
      and a 'label' column (0 for negative, 1 for positive). If it has no 'sentence_length'
      column yet, one is added to this frame.

    Output:
    - Displays a bar plot comparing the average sentence length for each sentiment.
    """
    # reuse an existing length column; compute it only when it is missing
    length_known = 'sentence_length' in df.columns
    if not length_known:
        df['sentence_length'] = df['sentence'].apply(len)

    # mean length per label, turned back into a two-column frame for seaborn
    mean_by_label = df.groupby('label')['sentence_length'].mean()
    avg_length = mean_by_label.reset_index()

    # plot
    plt.figure(figsize=(7, 4))
    sns.barplot(data=avg_length, x='label', y='sentence_length', palette='viridis')
    plt.xticks([0, 1], ['Negative', 'Positive'])
    plt.title('Average Sentence Length: Negative vs. Positive')
    plt.xlabel('Sentiment')
    plt.ylabel('Average Sentence Length')
    plt.grid(True)
    plt.show()

def plot_word_frequency_sst2(df, top_n=10):
    """
    Plots the most common words in the negative and in the positive sentences of the SST-2 dataset
    (one plot per sentiment, negative first).

    Input:
    - df (pd.DataFrame): Dataset containing a 'sentence' column with text data 
      and a 'label' column (0 for negative, 1 for positive).
    - top_n (int, optional): Number of most frequent words to display (default is 10).

    Output:
    - Displays one horizontal bar plot per sentiment with the frequency of its top words;
      prints a message instead when a sentiment has no words to show.
    """
    #define stop words
    stop_words = set(stopwords.words('english'))
    for label, sentiment in ((0, 'Negative'), (1, 'Positive')):
        # filter sentences by sentiment
        text = ' '.join(df[df['label'] == label]['sentence'])

        tokens = clean_and_tokenize(text, stop_words)
        # select the most common words
        common_words = Counter(tokens).most_common(top_n)

        # zip(*[]) cannot be unpacked into two names, so a sentiment without any
        # remaining word is reported and skipped (same handling as plot_word_frequency)
        if not common_words:
            print(f"No words found for {sentiment} sentences.")
            continue

        # plot
        words, counts = zip(*common_words)
        plt.figure(figsize=(7, 4))
        sns.barplot(x=list(counts), y=list(words), palette='viridis')
        plt.title(f'Most Common Words in {sentiment} Sentences')
        plt.xlabel('Frequency')
        plt.ylabel('Words')
        plt.grid(True)
        plt.show()



In [None]:
def show(plot_type, outfile=None, dataset=None, **kwargs):
    """
    Loads a dataset, prints a short summary of it and dispatches to the visualization
    function registered for `plot_type`.

    Parameters:
    - plot_type (str): The type of plot to generate. Valid values depend on the dataset:
        * "Jigsaw" / "civil_comments": "toxicity_frequency", "toxicity_percentage", "wordclouds",
          "comment_length_analysis", "toxicity_vs_length", "word_frequency"
        * "SST2": "sentiment_frequency", "sentiment_percentage", "wordclouds",
          "sentence_length_analysis", "avg_sentence_length", "word_frequency"
    - outfile (str, optional): Path to save the plot (default is None, nothing is saved).
      When a plot function draws several figures, only the last one is saved.
    - dataset (str, optional): Specifies the dataset ("Jigsaw", "civil_comments", or "SST2").
    - **kwargs: Optional save settings. Only `dpi` (default 300) is read, and only when
      `outfile` is given.

    Output:
    - Prints dataset info and the null count per column, displays the plot and optionally
      saves it to a file.

    Raises:
    - ValueError: If the dataset or plot_type is not supported, or if loading the data fails
      (the original error is chained as the cause).
    """
    # assign function mappings for the datasets
    jigsaw_cc_plots = {
        "toxicity_frequency": plot_toxicity_subtype_frequency,
        "toxicity_percentage": plot_toxicity_percentage,
        "wordclouds": plot_wordclouds,
        "comment_length_analysis": analyze_comment_len,
        "toxicity_vs_length": plot_toxicity_vs_comment_length,
        "word_frequency": plot_word_frequency,
    }

    sst2_plots = {
        "sentiment_frequency": plot_sentiment_frequency_sst2,
        "sentiment_percentage": plot_sentiment_percentage_sst2,
        "wordclouds": plot_wordclouds_sst2,
        "sentence_length_analysis": analyze_sentence_len_sst2,
        "avg_sentence_length": plot_avg_sentence_length_by_sentiment_sst2,
        "word_frequency": plot_word_frequency_sst2,
    }

    # validate both arguments up front so a typo fails before any data is read from disk
    if dataset == "Jigsaw" or dataset == "civil_comments":
        plot_functions = jigsaw_cc_plots
    elif dataset == "SST2":
        plot_functions = sst2_plots
    else:
        raise ValueError(f"Unsupported dataset '{dataset}'. Available datasets: Jigsaw, civil_comments, or SST2.")

    if plot_type not in plot_functions:
        available_types = list(plot_functions.keys())
        raise ValueError(f"Unsupported plot_type '{plot_type}'. Available types: {available_types}")
    plot_func = plot_functions[plot_type]

    # load the corresponding dataset
    try:
        if dataset == "Jigsaw":
            toxicity_labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
            train_df = pd.read_csv('../data/jigsaw_toxicity_pred/train.csv')
            initial_test_df = pd.read_csv('../data/jigsaw_toxicity_pred/test.csv')
            test_labels_df = pd.read_csv('../data/jigsaw_toxicity_pred/test_labels.csv')

            # filter out rows with toxicity labels = -1
            valid_labels_df = test_labels_df[~test_labels_df[toxicity_labels].eq(-1).any(axis=1)]
            # merge test data with the filtered valid labels only
            test_df = pd.merge(initial_test_df, valid_labels_df, on="id", how="inner")
            # combine both train and test sets
            df = pd.concat([train_df, test_df], ignore_index=True)
        elif dataset == "civil_comments":
            df = load_full_dataset("../data/finetuning_datasets/civil_comments")
        else:
            df = load_full_dataset("../data/finetuning_datasets/sst2")

            # determine the number of hidden labels (-1)
            hidden_label_count = (df['label'] == -1).sum()
            print(f"Number of hidden labels in the dataset: {hidden_label_count}")
            # filter out hidden labels; copy so the plot functions can add columns to the
            # result without pandas warning about writing to a slice
            df = df[df['label'] != -1].copy()
    except Exception as e:
        raise ValueError(f"Error loading dataset: {e}") from e

    display_dataset_info(df)

    # check for null values (printed explicitly: a bare expression is not echoed inside a function)
    print("Null values per column:")
    print(count_nulls(df))

    def _draw():
        # SST-2 plot functions only take the frame; the toxicity ones also need the dataset name
        if dataset == "SST2":
            plot_func(df)
        else:
            plot_func(df, dataset)

    if not outfile:
        _draw()
        return

    # The plot functions finish with plt.show(). On interactive / notebook-inline backends the
    # figure is released once it has been shown, so saving afterwards would write an empty image.
    # show() is therefore held back while the plot is drawn, the figure is saved, and only then
    # is it displayed.
    real_show = plt.show
    plt.show = lambda *args, **kw: None
    try:
        _draw()
    finally:
        plt.show = real_show

    plt.savefig(outfile, dpi=kwargs.get("dpi", 300))
    print(f"Visualization saved to {outfile}")
    plt.show()



In [None]:
# Examples
# Note: Running the following lines will print the dataset info many times, as the show() function is intended to be run once
# per plot (each call reloads the dataset and prints its summary again)

show('toxicity_frequency', outfile=None, dataset="Jigsaw")
show('toxicity_frequency', outfile=None, dataset="civil_comments")
#show('toxicity_percentage', outfile=None, dataset="Jigsaw")
#show('wordclouds', outfile=None, dataset="Jigsaw")
#show('comment_length_analysis', outfile=None, dataset="Jigsaw")
#show('toxicity_vs_length', outfile=None, dataset="Jigsaw")
#show('word_frequency', outfile=None, dataset="Jigsaw")

show('sentiment_frequency', outfile=None, dataset="SST2")
#show('sentiment_percentage', outfile=None, dataset="SST2")
#show('word_frequency', outfile=None, dataset="SST2")
