<a href="https://colab.research.google.com/github/michaelachmann/social-media-lab/blob/main/notebooks/2024_01_19_Classification_With_CLIP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Classification using CLIP [![DOI](https://zenodo.org/badge/660157642.svg)](https://zenodo.org/badge/latestdoi/660157642)
![Notes on (Computational) Social Media Research Banner](https://raw.githubusercontent.com/michaelachmann/social-media-lab/main/images/banner.png)

## Overview

This Jupyter notebook is a part of the social-media-lab.net project, which is a work-in-progress textbook on computational social media analysis. The notebook is intended for use in my classes.

The **Classification using CLIP** notebook uses the [CLIP](https://openai.com/research/clip) neural network for zero-shot image classification. I recommend **running this notebook on a GPU**.
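
The core idea of zero-shot classification with CLIP can be sketched with toy numbers. CLIP embeds the image and every text prompt in a shared space, compares them by cosine similarity, multiplies the similarities by a scale factor, and applies a softmax over the prompts. The embeddings below are made-up 3-dimensional vectors (the real ViT-B/32 embeddings have 512 dimensions). The scale of 100 is an assumption meant to approximate CLIP's learned temperature.

```python
import numpy as np

def zero_shot_probs(image_emb, text_embs, scale=100.0):
    """Softmax over scaled cosine similarities between one image and several prompts."""
    # L2-normalise so that the dot product equals the cosine similarity
    image_emb = image_emb / np.linalg.norm(image_emb)
    text_embs = text_embs / np.linalg.norm(text_embs, axis=1, keepdims=True)
    logits = scale * text_embs @ image_emb
    # Subtract the maximum before exponentiating for numerical stability
    exp = np.exp(logits - logits.max())
    return exp / exp.sum()

# Hypothetical toy embeddings: one image and three prompts
image = np.array([0.9, 0.1, 0.0])
prompts = np.array([
    [1.0, 0.0, 0.0],  # points in almost the same direction as the image
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
])
probs = zero_shot_probs(image, prompts)
print(probs.round(3), probs.argmax())
```

Because the softmax is computed over the prompt list, the resulting probabilities are relative to the prompts in use: adding or removing prompts changes them.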

### Project Information

- Project Website: [social-media-lab.net](https://social-media-lab.net/)
- GitHub Repository: [https://github.com/michaelachmann/social-media-lab](https://github.com/michaelachmann/social-media-lab)

## License Information

This notebook, along with all other notebooks in the project, is licensed under the following terms:

- License: [GNU General Public License version 3.0 (GPL-3.0)](https://www.gnu.org/licenses/gpl-3.0.de.html)
- License File: [LICENSE.md](https://github.com/michaelachmann/social-media-lab/blob/main/LICENSE.md)


## Citation

If you use or reference this notebook in your work, please cite it appropriately. Here is an example of the citation:

```
Michael Achmann. (2024). michaelachmann/social-media-lab: 2024-1-22 (v0.0.10). Zenodo. https://doi.org/10.5281/zenodo.8199901
```

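Import the visual data by unzipping the exported archive. The path assumes that Google Drive is mounted at `/content/drive`.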

In [None]:
!unzip /content/drive/MyDrive/2024-01-19-AfD-Stories-Exported.zip

Load the model and its dependencies. The CLIP classification code was inspired by [this Medium story](https://medium.com/@JettChenT/image-classification-with-openai-clip-3ab5f1c23e35).

In [None]:
!pip -q install ftfy regex tqdm
!pip -q install git+https://github.com/openai/CLIP.git
!pip -q install pyarrow

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/53.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m20.5/53.4 kB[0m [31m629.8 kB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━[0m [32m41.0/53.4 kB[0m [31m522.3 kB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.4/53.4 kB[0m [31m543.9 kB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for clip (setup.py) ... [?25l[?25hdone


In [None]:
import torch
import clip
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
device

100%|████████████████████████████████████████| 338M/338M [00:03<00:00, 114MiB/s]


'cuda'

In [None]:
df = pd.read_csv('/content/drive/MyDrive/2024-01-19-AfD-Stories-Exported.csv')

In [None]:
df = df[df['Username'] == "afd.bund"].copy()  # .copy() avoids pandas' SettingWithCopyWarning when adding columns below

In [None]:
df['Image'] = df.apply(lambda row: f"/content/media/images/{row['Username']}/{row['ID']}.jpg", axis=1)
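
Some Stories in the CSV may have no matching image file. This is one possible cause of the "Error loading image" messages printed during classification. A quick check before classifying can list these files. The sketch below is minimal: `find_missing_files` is a hypothetical helper, and the temporary directory only stands in for `/content/media/images/...`. In the notebook, you would call it as `find_missing_files(df['Image'].unique())`.

```python
import os
import tempfile

def find_missing_files(paths):
    """Return the paths that do not point to an existing file."""
    return [p for p in paths if not os.path.isfile(p)]

# Demonstration: a temporary directory stands in for the media folder
with tempfile.TemporaryDirectory() as tmp:
    existing = os.path.join(tmp, "123_456.jpg")
    with open(existing, "wb") as f:
        f.write(b"placeholder")
    missing = find_missing_files([existing, os.path.join(tmp, "789_456.jpg")])
    print([os.path.basename(p) for p in missing])
```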

In [None]:
df.head()

| Unnamed: 0.1 | Unnamed: 0 | ID | Time of Posting | Type of Content | video_url | image_url | Username | Video Length (s) | Expiration | Caption | Is Verified | Stickers | Accessibility Caption | Attribution URL | Image |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 1 | 2125373886060513565_1484534097 | 2019-09-04 08:05:27 | Image | | | afd.bund | | 2019-09-05 08:05:27 | | True | [] | Photo by Alternative für Deutschland on Septem... | | /content/media/images/afd.bund/212537388606051... |
| 2 | 2 | 2125374701022077222_1484534097 | 2019-09-04 08:07:04 | Image | | | afd.bund | | 2019-09-05 08:07:04 | | True | [] | Photo by Alternative für Deutschland on Septem... | | /content/media/images/afd.bund/212537470102207... |
| 3 | 3 | 2490851226217175299_1484534097 | 2021-01-20 14:23:30 | Image | | | afd.bund | | 2021-01-21 14:23:30 | | True | [] | Photo by Alternative für Deutschland on Januar... | | /content/media/images/afd.bund/249085122621717... |
| 4 | 4 | 2600840011884997131_1484534097 | 2021-06-21 08:31:45 | Image | | | afd.bund | | 2021-06-22 08:31:45 | | True | [] | Photo by Alternative für Deutschland on June 2... | | /content/media/images/afd.bund/260084001188499... |
| 5 | 5 | 2600852794831609459_1484534097 | 2021-06-21 08:57:09 | Image | | | afd.bund | | 2021-06-22 08:57:09 | | True | [] | Photo by Alternative für Deutschland in Berlin... | | /content/media/images/afd.bund/260085279483160... |


In [None]:
classification_dict = {
    "Collages": [
        "A screenshot with multiple visual elements such as text, graphics, and images combined.",
    ],
    "Campaign Material": [
        "An image primarily showcasing election-related flyers, brochures, or handouts.",
        "A distinct promotional poster for a political event or campaign.",
        "Visible printed material urging people to vote or join a political cause."
    ],
    "Political Events": [
        "An image distinctly capturing the essence of a political campaign event.",
        "A location set for a political event, possibly without a crowd.",
        "A large assembly of supporters or participants at an open-air political rally.",
        "Clear visuals of a venue set for a significant political gathering or convention.",
        "Focused visuals of attendees or participants of a political rally or event.",
        "Inside ambiance of a political convention or major political conference.",
        "Prominent figures or speakers on stage addressing a political audience.",
        "A serene image primarily focused on landscapes, travel.",
        "Food, beverages, or generic shots."
    ],
    "Individual Contact": [
        "A politician genuinely engaging or interacting with individuals or small groups.",
        "A close-up or selfie, primarily showcasing an individual, possibly with political affiliations.",
        "An informal or candid shot with emphasis on individual engagement, perhaps in a political setting."
    ],
    "Interviews & Media": [
        "An indoor setting, well-lit, designed for professional media interviews or broadcasts.",
        "Clear visuals of an interviewee in a controlled studio environment.",
        "Microphone or recording equipment predominantly in front of a speaker indoors.",
        "Behind-the-scenes ambiance of a media setup or broadcast preparation.",
        "Visuals from a TV or media broadcast, including distinct channel or media branding.",
        "Significant media branding or logos evident, possibly during an interview or panel discussion.",
        "Structured indoor setting of a press conference or media event with multiple participants."
    ],
    "Social Media Moderation": [
        "Face-centric visual with the individual addressing or connecting with the camera.",
        "Emphasis on facial features, minimal background distractions, typical of online profiles.",
        "Portrait-style close-up of a face, without discernible logos, graphics, or overlays."
    ],
}
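
The notebook assigns each image the category of its single best-matching prompt. The categories have different numbers of prompts, however (one for "Collages", nine for "Political Events"). One possible variant, which this notebook does not use, is to sum the softmax probabilities of all prompts that belong to the same category. Below is a minimal sketch with hypothetical prompts and made-up probabilities. With the notebook's objects, `flat_labels` and `labels_map` would come from `flatten_classification_dict(classification_dict)`.

```python
import numpy as np

def category_probabilities(probs, flat_labels, labels_map):
    """Sum prompt-level probabilities per category.

    probs: array of shape (n_images, n_prompts), rows summing to 100 (percent)
    flat_labels: prompts in the same order as the columns of probs
    labels_map: dict mapping each prompt to its category
    """
    # Keep categories in first-seen order
    categories = list(dict.fromkeys(labels_map[label] for label in flat_labels))
    totals = np.zeros((probs.shape[0], len(categories)))
    for col, label in enumerate(flat_labels):
        totals[:, categories.index(labels_map[label])] += probs[:, col]
    return categories, totals

# Hypothetical prompts and probabilities for two images
flat_labels = ["poster", "flyer", "rally", "stage", "studio"]
labels_map = {"poster": "Campaign Material", "flyer": "Campaign Material",
              "rally": "Political Events", "stage": "Political Events",
              "studio": "Interviews & Media"}
probs = np.array([[30.0, 25.0, 20.0, 5.0, 20.0],
                  [10.0, 5.0, 15.0, 30.0, 40.0]])
categories, totals = category_probabilities(probs, flat_labels, labels_map)
print(categories)
print(totals)
print([categories[i] for i in totals.argmax(axis=1)])
```

For the second image, the single best prompt ("studio") belongs to "Interviews & Media", while the summed probabilities favour "Political Events". Summing can in turn favour categories with many prompts, so neither rule is neutral with respect to how the prompt list is built.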

In [None]:
from tqdm import tqdm
import numpy as np
import pandas as pd
import torch
import clip
from PIL import Image

# Assumes `model`, `preprocess`, and `device` were initialized with clip.load() above


def classify_images_with_clip(image_files, classification_dict, column_name, batch_size=500):
    """Classify images zero-shot against all prompts in classification_dict."""
    labels_map, flat_labels = flatten_classification_dict(classification_dict)
    # Tokenize all prompts once and reuse them for every batch
    text = clip.tokenize(flat_labels).to(device)

    results = []
    for batch_start in tqdm(range(0, len(image_files), batch_size)):
        batch_files = image_files[batch_start:batch_start + batch_size]
        # Keep only the files that were loaded successfully, so that each
        # prediction stays aligned with its file name
        images, loaded_files = preprocess_images(batch_files)
        if not images:
            continue
        image_input = torch.stack(images).to(device)

        probs = model_inference(image_input, text)
        update_results(probs, loaded_files, flat_labels, labels_map, results, column_name)

    return pd.DataFrame(results)


def flatten_classification_dict(classification_dict):
    """Return a prompt -> category map and the flat list of prompts."""
    labels_map = {}
    flat_labels = []
    for category, phrases in classification_dict.items():
        for phrase in phrases:
            flat_labels.append(phrase)
            labels_map[phrase] = category
    return labels_map, flat_labels


def preprocess_images(image_files):
    """Load and preprocess images, skipping (and reporting) files that cannot be opened."""
    images, loaded_files = [], []
    for img_file in image_files:
        try:
            with Image.open(img_file) as img:
                images.append(preprocess(img))
            loaded_files.append(img_file)
        except OSError:  # covers missing files and unreadable image data
            print(f"Error loading image: {img_file}")
    return images, loaded_files


def model_inference(image_input, text):
    """Return softmax probabilities (in percent) over all prompts for each image."""
    with torch.no_grad():
        logits_per_image, _ = model(image_input, text)
        # Cast to float32: on a GPU the model runs in float16, which makes rounding imprecise
        return logits_per_image.float().softmax(dim=-1).cpu().numpy() * 100


def update_results(probs, batch_files, flat_labels, labels_map, results, column_name):
    """Append the best-matching prompt, its category, and its probability for each image."""
    max_indices = np.argmax(probs, axis=1)
    for idx, (file, top_index) in enumerate(zip(batch_files, max_indices)):
        label = flat_labels[top_index]
        results.append({
            "Image": file,
            f"{column_name}": labels_map[label],
            f"{column_name} Label": label,
            f"{column_name} Probability": probs[idx, top_index].round(2),
        })



In [None]:
image_files = df['Image'].unique()

# Perform the classification and get the results as a DataFrame
classified_df = classify_images_with_clip(image_files, classification_dict, 'Classification')

  0%|          | 0/1 [00:00<?, ?it/s]

Error loading image: /content/media/images/afd.bund/2632909594311219564_1484534097.jpg
Error loading image: /content/media/images/afd.bund/2637169242765597715_1484534097.jpg
Error loading image: /content/media/images/afd.bund/2637310044636651340_1484534097.jpg
Error loading image: /content/media/images/afd.bund/2640856259194124126_1484534097.jpg
Error loading image: /content/media/images/afd.bund/2643802824089930195_1484534097.jpg
Error loading image: /content/media/images/afd.bund/2653863205891438589_1484534097.jpg
Error loading image: /content/media/images/afd.bund/2664113842957989541_1484534097.jpg
Error loading image: /content/media/images/afd.bund/2671444844831156334_1484534097.jpg


100%|██████████| 1/1 [00:04<00:00,  4.17s/it]
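
The classification keeps only the single most likely prompt per image. To see how close the runner-up is, you can read the indices of the two best prompts per image from `np.argsort`. Below is a minimal sketch on a made-up probability matrix; `top_k_prompts` is a hypothetical helper.

```python
import numpy as np

def top_k_prompts(probs, flat_labels, k=2):
    """Return, per image, the k best (prompt, probability) pairs in descending order."""
    # argsort sorts ascending, so take the last k columns and reverse them
    top = np.argsort(probs, axis=1)[:, -k:][:, ::-1]
    return [[(flat_labels[j], float(probs[i, j])) for j in row] for i, row in enumerate(top)]

flat_labels = ["rally", "studio", "selfie"]
probs = np.array([[60.0, 30.0, 10.0],
                  [20.0, 35.0, 45.0]])
for ranking in top_k_prompts(probs, flat_labels):
    print(ranking)
```

A small gap between the first and second probability can help flag images that are worth checking by hand.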


In [None]:
classified_df.head()

| | Image | Classification | Classification Label | Classification Probability |
|---|---|---|---|---|
| 0 | /content/media/images/afd.bund/212537388606051... | Political Events | Focused visuals of attendees or participants o... | 26.78125 |
| 1 | /content/media/images/afd.bund/212537470102207... | Interviews & Media | Visuals from a TV or media broadcast, includin... | 71.3125 |
| 2 | /content/media/images/afd.bund/249085122621717... | Social Media Moderation | Emphasis on facial features, minimal backgroun... | 29.21875 |
| 3 | /content/media/images/afd.bund/260084001188499... | Interviews & Media | Clear visuals of an interviewee in a controlle... | 79.625 |
| 4 | /content/media/images/afd.bund/260085279483160... | Interviews & Media | Clear visuals of an interviewee in a controlle... | 48.0 |


In [None]:
df = pd.merge(df, classified_df, on="Image", how="left")
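
Images that could not be loaded are absent from `classified_df`, so after the left merge their classification columns are `NaN`. Passing `indicator=True` to `pd.merge` adds a `_merge` column that makes these rows easy to find. Here is a small sketch with made-up rows:

```python
import pandas as pd

stories = pd.DataFrame({"Image": ["a.jpg", "b.jpg", "c.jpg"]})
classified = pd.DataFrame({"Image": ["a.jpg", "c.jpg"],
                           "Classification": ["Political Events", "Collages"]})

merged = pd.merge(stories, classified, on="Image", how="left", indicator=True)
# "left_only" marks rows without a classification result
unclassified = merged.loc[merged["_merge"] == "left_only", "Image"].tolist()
print(unclassified)
```

If you do not need the `_merge` column afterwards, drop it before saving the data.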

In [None]:
df.head()

In [None]:
#@title Qualitative Evaluation #1: Display Images
#@markdown Running this cell creates a visual overview of the classification results: *n* images are sampled and displayed per class. The qualitative evaluation **does not replace proper external validation!**
#@markdown The overview is saved to `{current_date}-CLIP-Classification-Result.html`. Download the file and open it in your browser for a better layout.

import pandas as pd
from IPython.display import HTML, display
import datetime
import base64

#@markdown Select the sample size per class
sample_size = 1 # @param {type: "slider", min: 1, max: 25}

def get_base64_encoded_image(image_path):
    try:
        with open(image_path, "rb") as img_file:
            return base64.b64encode(img_file.read()).decode('utf-8')
    except IOError:
        print(f"Error loading image: {image_path}")
        return None

def create_html_card(row, label_column, image_type_column, probability_column):
    base64_image = get_base64_encoded_image(row['Image'])
    if base64_image is None:
        return ""

    return f"""
    <div class='col-lg-2 col-md-4 col-sm-6 mb-4'>
        <div class='card h-100'>
            <img src="data:image/jpeg;base64,{base64_image}" class="card-img-top">
            <div class='card-body'>
                <p class='card-text'><strong>🤖 {row[image_type_column]}</strong></p>
                <p class='card-text'>💬 {row[label_column]}</p>

            </div>
            <div class="card-footer">
                <small class="text-muted">🎯: {row[probability_column]}</small>
            </div>
        </div>
    </div>
    """

def generate_html(df, sample_size, label_column, image_type_column, probability_column):
    # The meta tag makes browsers read the saved file as UTF-8, so the emoji render correctly
    bootstrap_link = ('<meta charset="utf-8">'
                      '<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@5.2.3/dist/css/bootstrap.min.css" integrity="sha384-rbsA2VBKQhggwzxH7pPCaAqO46MgnOM80zW1RWuH61DGLwZJEdK2Kadq2F9CUG65" crossorigin="anonymous">')

    sample = df.groupby(image_type_column).sample(n=sample_size, replace=True)
    html_content = [bootstrap_link, "<div class='container-fluid mt-3'>"]

    for image_type, group_df in sample.groupby(image_type_column):
        html_content.append(f"<h2>{image_type}</h2>")
        html_content.append("<div class='row row-cols-1 row-cols-md-6 g-3'>")  # Bootstrap 5 gutters range from g-0 to g-5
        html_content.extend([create_html_card(row, label_column, image_type_column, probability_column) for _, row in group_df.iterrows()])
        html_content.append("</div>")

    html_content.append("</div>")
    return "\n".join(html_content)

# Define column names
image_type_column = 'Classification'
label_column =  f"{image_type_column} Label"
probability_column = f"{image_type_column} Probability"

# Generate and display HTML
final_html = generate_html(df, sample_size, label_column, image_type_column, probability_column)
display(HTML(final_html))

# Save to an HTML file
current_date = datetime.datetime.now().strftime('%Y-%m-%d')
html_file_name = f"{current_date}-CLIP-Classification-Result.html"

with open(html_file_name, "w", encoding="utf-8") as file:
    file.write(final_html)
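
`groupby(...).sample(n=sample_size, replace=True)` draws with replacement, so the overview can show the same image more than once, especially for small classes. If that is unwanted, one option is to draw without replacement and cap the sample at the group size. The sketch below implements that rule in plain Python; `stratified_sample` is a hypothetical helper, and the rows are made up.

```python
import random
from collections import defaultdict

def stratified_sample(rows, key, n, seed=None):
    """Draw up to n rows per group without replacement."""
    groups = defaultdict(list)
    for row in rows:
        groups[row[key]].append(row)
    rng = random.Random(seed)
    # min() caps the sample size for groups with fewer than n rows
    return {group: rng.sample(items, min(n, len(items))) for group, items in groups.items()}

rows = [{"Image": f"{i}.jpg", "Classification": c}
        for i, c in enumerate(["Collages", "Collages", "Political Events",
                               "Political Events", "Political Events", "Interviews & Media"])]
sample = stratified_sample(rows, "Classification", n=2, seed=42)
for group, items in sample.items():
    print(group, [r["Image"] for r in items])
```

In pandas, you can apply the same rule to each group with `g.sample(n=min(sample_size, len(g)))`.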


In [None]:
from IPython.display import display, clear_output, HTML
from PIL import Image as PILImage
from io import BytesIO
import pandas as pd
import ipywidgets as widgets
import base64

# Assuming df is your dataframe with the images, labels, and probabilities
df = df.sample(frac=1).reset_index(drop=True)
image_type_column = 'Classification'  # Make sure this is the correct column name
label = f"{image_type_column} Label"
prob = f"{image_type_column} Probability"

idx = 0
out = widgets.Output()
df['correct'] = None

def show_counts():
    value_counts = df["correct"].value_counts()
    print("Number of correct labels:", value_counts.get(True, 0))
    print("Number of incorrect labels:", value_counts.get(False, 0))

def on_button_click(is_correct):
    global idx
    df.at[idx, "correct"] = is_correct
    idx += 1
    if idx < len(df):
        show_next_image()
    else:
        clear_output(wait=True)
        print("No more images.")
        show_counts()

def display_encoded_image(image_path):
    try:
        pil_image = PILImage.open(image_path)
        bio = BytesIO()
        pil_image.save(bio, 'PNG')
        encoded_image = base64.b64encode(bio.getvalue()).decode('utf-8')
        display(HTML(f'<img src="data:image/png;base64,{encoded_image}" style="max-height: 90vh;">'))
    except Exception as e:
        print(f"Error displaying image: {e}")

def show_next_image():
    clear_output(wait=True)
    # Also clear the Output widget; otherwise each new image is appended below the previous ones
    out.clear_output(wait=True)
    with out:
        row = df.iloc[idx]
        display_encoded_image(row["Image"])
        display(HTML(f'<h3 style="color: blue;">Label: {row[label]}</h3>'))
        display(HTML(f'<h3 style="color: blue;">Type: {row[image_type_column]}</h3>'))
        display(HTML(f'<h4 style="color: green;">Probability: {row[prob]:.2f}</h4>'))
        display(widgets.HBox([correct_button, wrong_button]))
    display(out)
    show_counts()

# Button event handlers
def correct_click(b):
    on_button_click(True)

def wrong_click(b):
    on_button_click(False)

# Button setup
button_layout = widgets.Layout(width='100px', height='40px', margin='5px 10px 5px 10px')
correct_button = widgets.Button(description="Correct", layout=button_layout, button_style='success')
correct_button.on_click(correct_click)

wrong_button = widgets.Button(description="Wrong", layout=button_layout, button_style='danger')
wrong_button.on_click(wrong_click)

show_next_image()
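
Once images have been marked as correct or wrong, you can summarise the `correct` column per predicted category. This gives a rough per-class precision for the annotated sample. Below is a minimal sketch with made-up annotations; `precision_per_class` is a hypothetical helper. In the notebook, the pairs could come from `zip(df["Classification"], df["correct"])`. Rows that have not been annotated yet (`None`) are skipped.

```python
from collections import Counter

def precision_per_class(annotations):
    """Share of annotated images per predicted class that were marked correct.

    annotations: iterable of (predicted_class, is_correct) pairs; pairs whose
    is_correct is None or NaN (not yet annotated) are ignored.
    """
    total, correct = Counter(), Counter()
    for predicted, is_correct in annotations:
        if is_correct is None or is_correct != is_correct:  # NaN != NaN
            continue
        total[predicted] += 1
        correct[predicted] += bool(is_correct)
    return {cls: correct[cls] / total[cls] for cls in total}

annotations = [("Political Events", True), ("Political Events", False),
               ("Political Events", True), ("Collages", True), ("Collages", None)]
print(precision_per_class(annotations))
```

With only a few annotated images per class, these shares are very noisy. They only become informative once larger samples have been annotated.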
