# In [None]:
import shap
import cv2
import os
import torch
import numpy as np
from PIL import Image
import pytesseract
import tensorflow as tf
from transformers import TFRobertaForSequenceClassification, RobertaTokenizer, ViTForImageClassification, ViTImageProcessor
from google.colab import drive

# Mount Google Drive so the fine-tuned checkpoints and the test image stored
# under /content/drive can be read by the code below.
drive.mount('/content/drive')

# Set Tesseract path (the executable pytesseract is told to use for OCR).
pytesseract.pytesseract.tesseract_cmd = '/usr/bin/tesseract'

# Load RoBERTa tokenizer (used by classify_text and explain_text_with_shap).
roberta_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Load ViT model: the image processor and a 4-label classification model whose
# weights are overwritten further down by the fine-tuned state dict.
# NOTE(review): the processor comes from 'google/vit-base-patch16-224' while the
# backbone comes from 'google/vit-base-patch16-224-in21k' -- confirm that the
# processor settings match the ones used when the model was fine-tuned.
vit_processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
vit_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=4)

# Load RoBERTa model (fine-tuned TensorFlow text classifier saved on Drive).
text_model = TFRobertaForSequenceClassification.from_pretrained('/content/drive/My Drive/Colab Notebooks/public-data/models/roberta-cyberbullying-classifier')

# Load Image model: replace the ViT weights with the fine-tuned state dict,
# mapping every tensor to the CPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
image_model_path = '/content/drive/My Drive/Colab Notebooks/public-data/models/my_vit_model.pth'
vit_model.load_state_dict(torch.load(image_model_path, map_location=torch.device('cpu')))

# Define function to preprocess the image
def preprocess_final(im):
    """Prepare a BGR image for OCR: smooth, convert to grayscale, binarise.

    Args:
        im: Colour image as a BGR array, as returned by ``cv2.imread``.

    Returns:
        Single-channel image produced by an inverse binary threshold at 240:
        pixels brighter than 240 become 0, all other pixels become 255.
    """
    # Bilateral filter: diameter 5, sigmaColor 55, sigmaSpace 60.
    smoothed = cv2.bilateralFilter(im, 5, 55, 60)
    gray = cv2.cvtColor(smoothed, cv2.COLOR_BGR2GRAY)
    # cv2.threshold returns (threshold_used, image); only the image is needed.
    binarised = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY_INV)[1]
    return binarised

# Define function to extract text from an image
def extract_text(image_path, custom_config=r"--oem 3 --psm 11 -c tessedit_char_whitelist= 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz '"):
    """Run Tesseract OCR on an image file and return the text on one line.

    Args:
        image_path: Path of the image file to read with OpenCV.
        custom_config: Extra command-line options forwarded to Tesseract.
            NOTE(review): the default contains a space between
            ``tessedit_char_whitelist=`` and the quoted character list, which
            probably makes the list a separate argument so the whitelist is
            not applied -- confirm before relying on it. The default is left
            unchanged here to keep existing OCR output stable.

    Returns:
        The recognised text with every newline replaced by a space. The
        result is not stripped, so it may consist of whitespace only.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
        ValueError: If the file exists but OpenCV cannot decode it.
    """
    img = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising; fail
    # here with a clear message rather than deep inside the OpenCV filters.
    if img is None:
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path!r}")
        raise ValueError(f"Could not decode image file: {image_path!r}")

    img = preprocess_final(img)
    text = pytesseract.image_to_string(img, lang='eng', config=custom_config)
    return text.replace('\n', ' ')

# Define function to classify text using RoBERTa
def classify_text(text):
    """Predict the class index of ``text`` with the fine-tuned RoBERTa model.

    Args:
        text: String to classify. It is truncated / padded to 512 tokens.

    Returns:
        Index of the class with the highest softmax probability.
    """
    encoded = roberta_tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors="tf",
    )
    model_output = text_model(
        encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
    )
    class_probabilities = tf.nn.softmax(model_output.logits, axis=1)
    # A single text was encoded, so take the first (only) row of the batch.
    predicted_indices = np.argmax(class_probabilities, axis=1)
    return predicted_indices[0]

# Define function to classify image using ViT
def classify_image(image_path):
    """Predict the class index of an image file with the fine-tuned ViT model.

    Args:
        image_path: Path of the image to open with PIL; it is converted to RGB.

    Returns:
        Index (Python ``int``) of the class with the highest softmax
        probability.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    processed = vit_processor(images=rgb_image, return_tensors="pt")
    pixel_values = processed['pixel_values']

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = vit_model(pixel_values).logits
        class_probabilities = torch.nn.functional.softmax(logits, dim=1)
        predicted = torch.argmax(class_probabilities, dim=1).item()
    return predicted

# Define function for late fusion
def late_fusion(text_class, image_class):
    """Combine the text and image class labels into one verdict message.

    Args:
        text_class: Class index predicted from the extracted text.
        image_class: Class index predicted from the image.

    Returns:
        A message string: "no cyberbullying" when both labels are 0, the
        shared class when both agree on a non-zero label, and otherwise a
        message reporting both labels.
    """
    # The two modalities disagree: report both labels.
    if text_class != image_class:
        return f"Input contains cyberbullying. Text label is: {text_class} and Image label is: {image_class}"

    # Both agree on class 0, i.e. nothing was detected.
    if text_class == 0:
        return "Input does not contain any Cyber-bullying."

    return f"Input contains this class {text_class} of cyberbullying."

# Main function to handle input and perform classification
def process_input(image_path):
    """Run OCR, both classifiers, late fusion and SHAP on a single image.

    Args:
        image_path: Path of the image file to analyse.

    Returns:
        A dict with the keys ``extracted_text``, ``text_label``,
        ``image_label``, ``fusion_message``, ``text_shap_values`` and
        ``image_shap_values``. When OCR yields no usable text the text label
        reads "No text prediction", the fusion message reads
        "No text found to classify." and ``text_shap_values`` is ``None``.
    """
    extracted_text = extract_text(image_path)

    # OCR output made only of spaces / control characters is truthy but holds
    # nothing to classify, so decide on the stripped text.
    has_text = bool(extracted_text.strip())

    text_class = classify_text(extracted_text) if has_text else None
    image_class = classify_image(image_path)

    if has_text:
        fusion_message = late_fusion(text_class, image_class)
    else:
        fusion_message = "No text found to classify."

    # SHAP explanation for text classification (skipped when there is no text
    # to explain).
    text_shap_values = explain_text_with_shap(extracted_text) if has_text else None

    # SHAP explanation for image classification
    image_shap_values = explain_image_with_shap(image_path)

    return {
        'extracted_text': extracted_text,
        'text_label': f"Text label: {text_class}" if text_class is not None else "No text prediction",
        'image_label': f"Image label: {image_class}",
        'fusion_message': fusion_message,
        'text_shap_values': text_shap_values,
        'image_shap_values': image_shap_values
    }

# Function to explain text predictions using SHAP
def explain_text_with_shap(text):
    """Compute and display SHAP token attributions for the text classifier.

    SHAP's text explainers perturb the raw string, so the model is wrapped in
    a function that maps a batch of strings to class probabilities and the
    tokenizer is handed to SHAP as the masker. The previous version passed the
    token-id tensor as the masker and called ``shap.visualize``, which is not
    part of SHAP's public API.

    Args:
        text: The string to explain.

    Returns:
        The SHAP explanation object for a batch containing ``text`` only, or
        ``None`` when ``text`` is empty or whitespace-only.
    """
    # Nothing to attribute when there are no tokens to mask.
    if not text or not text.strip():
        return None

    def _predict_proba(texts):
        # SHAP passes an array of (partially masked) strings.
        batch = [str(t) for t in texts]
        encoded = roberta_tokenizer(
            batch, add_special_tokens=True, max_length=512,
            padding=True, truncation=True, return_tensors="tf"
        )
        logits = text_model(
            encoded['input_ids'], attention_mask=encoded['attention_mask']
        ).logits
        return tf.nn.softmax(logits, axis=1).numpy()

    # Create SHAP explainer for the text classification model; passing the
    # tokenizer as masker lets SHAP build a text masker from it.
    explainer = shap.Explainer(_predict_proba, roberta_tokenizer)
    shap_values = explainer([text])

    # Visualize the SHAP values
    shap.initjs()
    shap.plots.text(shap_values[0])
    return shap_values

# Function to explain image predictions using SHAP
def explain_image_with_shap(image_path):
    """Compute and display SHAP pixel attributions for the image classifier.

    The image is preprocessed exactly as in ``classify_image``. SHAP's image
    masker works on channels-last arrays, so the model is wrapped in a
    function that accepts (N, H, W, C) arrays and returns class
    probabilities. The previous version passed the input tensor as the masker
    and called ``shap.visualize``, which is not part of SHAP's public API.

    Args:
        image_path: Path of the image to explain; it is converted to RGB.

    Returns:
        The SHAP explanation object for a batch containing this image only.
    """
    image = Image.open(image_path).convert("RGB")
    input_tensor = vit_processor(images=image, return_tensors="pt")['pixel_values']

    # (N, C, H, W) -> (N, H, W, C), the layout expected by shap.maskers.Image.
    pixels_nhwc = input_tensor.permute(0, 2, 3, 1).contiguous().numpy()

    def _predict_proba(batch_nhwc):
        batch = torch.as_tensor(np.asarray(batch_nhwc), dtype=torch.float32)
        with torch.no_grad():
            # Back to channels-first for the ViT model.
            logits = vit_model(batch.permute(0, 3, 1, 2)).logits
            return torch.nn.functional.softmax(logits, dim=1).numpy()

    # Create SHAP explainer for the image classification model; hidden
    # regions are replaced by a blurred version of the image.
    masker = shap.maskers.Image("blur(16,16)", pixels_nhwc[0].shape)
    explainer = shap.Explainer(_predict_proba, masker)
    shap_values = explainer(pixels_nhwc)

    # Visualize the SHAP values for image
    shap.initjs()
    # Undo the processor's normalisation so the plot shows the picture in
    # [0, 1]. Assumes the processor rescaled to [0, 1] and then normalised
    # with image_mean / image_std -- TODO confirm against its config.
    mean = np.asarray(vit_processor.image_mean, dtype=np.float32)
    std = np.asarray(vit_processor.image_std, dtype=np.float32)
    display_pixels = np.clip(pixels_nhwc * std + mean, 0.0, 1.0)
    # image_plot takes one (N, H, W, C) attribution array per output class.
    per_class_values = [
        shap_values.values[..., k]
        for k in range(shap_values.values.shape[-1])
    ]
    shap.image_plot(per_class_values, display_pixels)
    return shap_values

# Example usage: run the whole pipeline (OCR, text and image classification,
# late fusion, SHAP explanations) on one image from the mounted Drive and print
# the resulting dictionary. process_input also renders the SHAP plots as a
# side effect of calling the explain_* helpers.
image_path = '/content/drive/My Drive/Colab Notebooks/public-data/image/net/test.jpg'
results = process_input(image_path)
print(results)
