# Notebook 01 - Prototyping  
Author: Ngo Van Anh Kiet (nvakiet)

### 1. Problem description  
>Given a list of images or texts as inputs, build a machine learning solution to generate top N topics that summarize the information. A topic may not be any keyword in the given text.  
>
>Tech stack requirements:  
>- Framework: PyTorch or Tensorflow  
>- API: FastAPI  
>- Frontend: Streamlit  
>- Database: SQL or NoSQL  

At a glance, this looks like an NLP-focused problem. After spending 1.5 days researching solutions in various papers, I learned that this problem domain is called Topic Modelling. Additionally, since our inputs include both images and texts, this is also a multimodal problem. And because a topic may not appear in the input text, I think zero-shot learning models are needed.  

Here are some key points that I got from the above observation:  
- Topic modelling is an unsupervised learning problem, so it will be hard to evaluate the system's performance.  
- We can leverage pre-trained models to handle the multimodal inputs and zero-shot learning; the main component to train here is the topic model itself.  

### 2. Solution ideas  
There are currently a few well-known Topic Modelling algorithms for textual data, mainly LDA (Latent Dirichlet Allocation) and BERTopic. LDA is a traditional probabilistic machine learning algorithm, while BERTopic is a deep learning solution built on Transformer embeddings. After reading through the documentation of both, I decided to go with BERTopic for two reasons. First, its framework makes it easy to swap out components. Second, because it works on feature embeddings, it seems to generalize better than LDA and to produce more coherent topics. Additionally, according to its documentation, BERTopic needs little data preprocessing and can be customized for multimodal tasks.  

Below are the solution ideas I came up with. To meet all the problem requirements, I think a multi-model approach is more intuitive than an end-to-end one. (If the drawings do not render, they are kept in "reports/figures".)

Idea 1
![Topic Extraction Idea 1](../reports/figures/topic_extraction_idea_1.png)  
  
Idea 2
![Topic Extraction Idea 2](../reports/figures/topic_extraction_idea_2.png)  
  
Idea 3
![Topic Extraction Idea 3](../reports/figures/topic_extraction_idea_3.png)  
  


The first idea was easy to come up with. However, it relies too heavily on the image captions, so it can lose contextual information about the whole image that the caption does not capture.  

In the second idea, I encode both images and texts into feature vectors in the same vector space. Then I use BERTopic to cluster these vectors and extract the topics. The problem is how to recover a textual topic representation from the image embeddings in each topic. If I computed the cosine similarity between each image feature vector and the feature vectors of a vocabulary collection, I might find the words that best describe the image. But that would be no different from doing image recognition or captioning as in the first idea.  
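To make the cosine-similarity lookup concrete, here is a minimal NumPy sketch. It assumes the image embeddings and the vocabulary embeddings already live in the same space. For example, both could come from CLIP, since the `clip-ViT-B-32` SentenceTransformer can encode both images and text. The helper name `top_k_words` and the toy 2-D vectors are hypothetical.

```python
from typing import List

import numpy as np


def top_k_words(query_emb: np.ndarray, vocab_emb: np.ndarray,
                vocab: List[str], k: int = 3) -> List[List[str]]:
    """Return the k vocabulary words most cosine-similar to each query vector.

    query_emb: (n_queries, dim), e.g. image embeddings (assumption)
    vocab_emb: (n_words, dim), embeddings of `vocab` in the same space
    """
    # L2-normalise the rows so that a dot product equals cosine similarity
    q = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)
    v = vocab_emb / np.linalg.norm(vocab_emb, axis=1, keepdims=True)
    sims = q @ v.T  # (n_queries, n_words)
    # Sort each row by descending similarity and keep the first k indices
    best = np.argsort(-sims, axis=1)[:, :k]
    return [[vocab[j] for j in row] for row in best]


# Toy example with hypothetical 2-D "embeddings"
vocab = ["dog", "beach", "car"]
vocab_emb = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
queries = np.array([[0.9, 0.1], [0.1, 0.9]])
print(top_k_words(queries, vocab_emb, vocab, k=1))
```

As the paragraph above notes, this amounts to zero-shot image tagging against a fixed vocabulary. That is why the idea ends up close to captioning.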

The third idea combines the first two. For the image set, I use OpenAI's CLIP model to embed the images. CLIP can also embed text, but I noticed that it only supports a maximum sequence length of 77 tokens, since it was trained on short image-text pairs. Using CLIP for medium or long documents would therefore truncate too much information, so I use a different model to embed the documents. A PCA model reduces the image embeddings to the same dimensionality as the text embeddings. I then stack both feature matrices row-wise and pass them to BERTopic. For the topic representation of the image topics, I generate a caption for each image and add it to the document set that BERTopic fits/transforms on. Finally, I think we can use a zero-shot text classification model to assign topic labels without relying on existing keywords in the input data, then aggregate and filter for the top N topics. However, I'm not sure how efficiently current pretrained zero-shot text classification models run at inference time.
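The shape bookkeeping of the third idea can be sketched in plain NumPy. The image embeddings start at 512 dimensions (CLIP ViT-B/32). PCA reduces them to 384, the output size of `all-MiniLM-L6-v2`, and they are then stacked with the text embeddings. The notebook below uses scikit-learn's `PCA`. Here PCA is written out as an SVD of the centred data so the steps are visible. The random matrices are placeholders for real embeddings, and `fit_pca`/`project` are hypothetical helpers.

```python
import numpy as np


def fit_pca(x: np.ndarray, n_components: int):
    """Fit PCA via an SVD of the centred data; return (mean, components)."""
    mean = x.mean(axis=0)
    # Rows of vt are the principal axes, ordered by decreasing singular value
    _, _, vt = np.linalg.svd(x - mean, full_matrices=False)
    return mean, vt[:n_components]


def project(x: np.ndarray, mean: np.ndarray, components: np.ndarray) -> np.ndarray:
    """Project data onto the fitted principal axes."""
    return (x - mean) @ components.T


rng = np.random.default_rng(0)
img_emb = rng.normal(size=(600, 512))  # stand-in for CLIP image embeddings
txt_emb = rng.normal(size=(900, 384))  # stand-in for MiniLM text embeddings

mean, comps = fit_pca(img_emb, n_components=384)
img_384 = project(img_emb, mean, comps)
# Stack row-wise: images first, then texts, so row i lines up with docs[i]
final = np.concatenate([img_384, txt_emb], axis=0)
print(img_384.shape, final.shape)
```

PCA only makes the shapes compatible. Nothing ties principal axis j of the image space to dimension j of the MiniLM space, so the two blocks of rows are not in a shared semantic space.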

### 3. Prototyping

In [None]:
# Initialize some barebone folder structure on Colab
!mkdir -p reports/figures
!mkdir -p data/raw

In [None]:
!pip install bertopic # for running on colab

In [None]:
!nvidia-smi

In [None]:
# Utilities
import os
import glob
import zipfile
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image
import pickle as pkl
import random

# Core frameworks
# Based on Torch and TF
import torch
from sentence_transformers import SentenceTransformer, util
from bertopic import BERTopic
from transformers import pipeline
# Traditional ML
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import PCA

# Datasets
from sklearn.datasets import fetch_20newsgroups
import nltk

In [None]:
# global variables
BATCH_SIZE = 32
DATA_DIR = "../data/"  # relative to the notebook's folder; on Colab, use "data/" to match the folders created above
RAW_DATA_DIR = DATA_DIR + "raw/"
# Check if GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Flickr 8k images
img_folder = RAW_DATA_DIR + 'photos/'
caps_folder = RAW_DATA_DIR + 'captions/'
img_zip_path = img_folder + 'Flickr8k_Dataset.zip'
caps_zip_path = caps_folder + 'Flickr8k_text.zip'
img_dataset_path = img_folder + "Flicker8k_Dataset/"

if not os.path.exists(img_folder) or len(os.listdir(img_folder)) == 0:
    os.makedirs(img_folder, exist_ok=True)
    os.makedirs(caps_folder, exist_ok=True)

    if not os.path.exists(img_zip_path):   # Download the dataset if it does not exist yet
        util.http_get('https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip', img_zip_path)
        util.http_get('https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip', caps_zip_path)

    # Extract each archive into its own folder
    for folder, file in [(img_folder, img_zip_path), (caps_folder, caps_zip_path)]:
        with zipfile.ZipFile(file, 'r') as zf:
            for member in tqdm(zf.infolist(), desc='Extracting'):
                zf.extract(member, folder)

images = list(glob.glob(img_dataset_path + '*.jpg'))


In [None]:
# Prepare dataframe
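captions = pd.read_csv(caps_folder + "Flickr8k.lemma.token.txt", sep='\t', names=["img_id", "img_caption"])
# Turn "<name>.jpg#<n>" into the full image path by dropping everything after ".jpg"
captions["img_id"] = captions["img_id"].apply(lambda img_id: img_dataset_path + img_id.split(".jpg")[0] + ".jpg")
# One comma-joined caption document per image
captions = captions.groupby("img_id")["img_caption"].apply(','.join).reset_index()
# Keep only captions whose image file actually exists
captions = pd.merge(captions, pd.Series(images, name="img_id"), on="img_id")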

# Extract images together with their documents/captions
images = captions.img_id.to_list()
images_captions = captions.img_caption.to_list()
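The pandas cell above builds one caption document per image. The same grouping in plain Python looks like this, assuming the caption file has lines of the form `<name>.jpg#<n>\t<caption>`. The helper `group_captions` and the sample lines are hypothetical.

```python
from collections import defaultdict
from typing import Dict, Iterable


def group_captions(lines: Iterable[str], img_dir: str) -> Dict[str, str]:
    """Group Flickr8k-style caption lines into one comma-joined document per image."""
    grouped = defaultdict(list)
    for line in lines:
        img_id, caption = line.rstrip("\n").split("\t", 1)
        # Drop the "#<n>" caption index (and any suffix) after ".jpg"
        path = img_dir + img_id.split(".jpg")[0] + ".jpg"
        grouped[path].append(caption)
    return {path: ",".join(caps) for path, caps in grouped.items()}


sample = [
    "123_abc.jpg#0\tA dog run on the beach .",
    "123_abc.jpg#1\tA brown dog play in the sand .",
    "456_def.jpg#0\tTwo child ride a bike .",
]
print(group_captions(sample, "photos/"))
```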

In [None]:
img_model = SentenceTransformer('clip-ViT-B-32')

In [None]:
# Prepare images
nr_iterations = int(np.ceil(len(images) / BATCH_SIZE))

# Embed images per batch
img_embeddings = []
for i in tqdm(range(nr_iterations)):
    start_index = i * BATCH_SIZE
    end_index = (i * BATCH_SIZE) + BATCH_SIZE

    images_to_embed = [Image.open(filepath) for filepath in images[start_index:end_index]]
    
    img_emb = img_model.encode(images_to_embed, show_progress_bar=False)
    img_embeddings.extend(img_emb.tolist())

    # Close images
    for image in images_to_embed:
        image.close()
        
img_embeddings = np.array(img_embeddings)
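The start/end index arithmetic in this loop, and in its copy in the inference section, could be factored into a small generator. `batches` is a hypothetical helper, not part of any library.

```python
from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")


def batches(items: Sequence[T], size: int) -> Iterator[Sequence[T]]:
    """Yield consecutive slices of `items` with at most `size` elements each."""
    for start in range(0, len(items), size):
        yield items[start:start + size]


print([list(b) for b in batches(list(range(7)), 3)])
```

In the notebook, the loop header could then become `for paths in tqdm(batches(images, BATCH_SIZE)):`, which removes the need for `nr_iterations` and `np.ceil`.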

In [None]:
print(len(images))
print(len(images_captions))
print(img_embeddings.shape)

In [None]:
# Perform PCA to reduce image embedding dimensionality to 384 (the embedding size of all-MiniLM-L6-v2)
pca = PCA(n_components=384)
img_embeddings = pca.fit_transform(img_embeddings)
print(img_embeddings.shape)

In [None]:
# Fetch news dataset
news_docs = fetch_20newsgroups(subset='train',  remove=('headers', 'footers', 'quotes'))['data']

In [None]:
# Create text embeddings
sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
text_embeddings = sentence_model.encode(news_docs, show_progress_bar=True)
text_embeddings = np.array(text_embeddings)
print(text_embeddings.shape)

In [None]:
# Stack image and text embeddings row-wise (images first, then news documents)
final_embeddings = np.concatenate([img_embeddings, text_embeddings])
print(final_embeddings.shape)

In [None]:
# Concatenate image captions and documents
docs = images_captions + news_docs
print(len(docs))

In [None]:
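# Fit BERTopic model
# n-grams are set on the CountVectorizer: BERTopic ignores its own n_gram_range when a vectorizer_model is passed
vectorizer_model = CountVectorizer(stop_words="english", ngram_range=(1, 3))
topic_model = BERTopic(calculate_probabilities=False, 
                       low_memory=True, 
                       diversity=0.5,  # older-API argument; newer BERTopic releases may not accept it
                       vectorizer_model=vectorizer_model)
topics, probs = topic_model.fit_transform(docs, final_embeddings)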

In [None]:
# Check results
topic_model.get_topic_info()

It seems like there are some topics with weird gibberish words. Let's inspect the documents of those topics.

In [None]:
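# One-word labels; [1:] drops the outlier topic -1 (listed first), so list index == topic id
topic_names = topic_model.generate_topic_labels(nr_words=1, topic_prefix=False, separator="-")[1:]
weird_topic_index = topic_names.index("maxaxaxaxaxaxaxaxaxaxaxaxaxaxax")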
weird_topic = topic_model.get_topic(weird_topic_index)
weird_topic

In [None]:
topic_model.get_representative_docs(weird_topic_index)

The gibberish topics come from the input documents themselves. Although BERTopic can handle input data with little preprocessing, some data cleaning should still be done before passing the documents to the model. However, since there are only a few topics like this and this one ranks low (beyond 100th), I will treat it as an outlier for now.  
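The following is a sketch of the kind of cleaning this calls for, based on a simple heuristic of my own choosing. It removes very long non-whitespace runs and short patterns repeated many times, such as "axaxax", then drops documents with too few words left. Both regexes and the thresholds (25 characters, 4 repeats, 3 words) are assumptions, not tuned values.

```python
import re

# Heuristic noise patterns (assumed thresholds):
LONG_TOKEN = re.compile(r"\S{25,}")     # any run of 25+ non-whitespace characters
REPEATS = re.compile(r"(\S\S)\1{3,}")   # a 2-char pattern repeated 4+ times in a row


def clean_doc(text: str) -> str:
    """Remove long runs and repetitive patterns, then collapse whitespace."""
    text = LONG_TOKEN.sub(" ", text)
    text = REPEATS.sub(" ", text)
    return " ".join(text.split())


def keep_doc(text: str, min_words: int = 3) -> bool:
    """Keep a document only if enough words survive cleaning."""
    return len(clean_doc(text).split()) >= min_words


print(clean_doc("Hello world " + "m" + "ax" * 15 + " foo"))
print(keep_doc("m" + "ax" * 15))
```

Any filtering has to happen before the text embeddings are computed, so that `docs` and `final_embeddings` stay row-aligned.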

In [None]:
# Reduce outliers and update topic frequencies & representations
new_topics = topic_model.reduce_outliers(docs, topics)
# Pass the same vectorizer; otherwise update_topics may fall back to a default CountVectorizer without stop-word removal
topic_model.update_topics(docs, topics=new_topics, vectorizer_model=vectorizer_model)
# After some testing, this seemed to make the results slightly worse,
# possibly due to the lack of hyperparameter tuning.
# If there are only a few outliers, I think this step isn't needed.

In [None]:
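# Refresh topic sizes so get_topic_info() reflects the reassigned topics
# (_update_topic_size is a private BERTopic method and may change between versions)
documents = pd.DataFrame({"Document": docs, "Topic": new_topics})
topic_model._update_topic_size(documents)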

In [None]:
# Check the topic info again
topic_model.get_topic_info()

I would like to evaluate these models, but there isn't much time left. Besides, most models used here are pretrained, and BERTopic itself is unsupervised, so there is no ground truth to measure the topic extraction results against.
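One commonly used label-free proxy is topic diversity: the fraction of unique words among the top-k words of all topics, where 1.0 means no topic shares a word with another. It measures redundancy between topics, not their coherence, so it is only a partial signal. This minimal sketch uses a hypothetical helper, `topic_diversity`. Excluding the outlier topic -1 is my assumption.

```python
from typing import Dict, List


def topic_diversity(topic_words: Dict[int, List[str]], top_k: int = 10) -> float:
    """Share of unique words among the top_k words of every non-outlier topic."""
    words = [w for topic, ws in topic_words.items() if topic != -1 for w in ws[:top_k]]
    return len(set(words)) / len(words) if words else 0.0


print(topic_diversity({-1: ["x", "y"], 0: ["a", "b"], 1: ["b", "c"]}, top_k=2))
```

With BERTopic, `topic_model.get_topics()` maps each topic id to a list of `(word, score)` pairs. The input could therefore be built as `{t: [w for w, _ in ws] for t, ws in topic_model.get_topics().items()}`.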

In [None]:
!mkdir -p models # for running on Colab

In [None]:
# Save the fitted PCA model (BERTopic is saved in the next cell)
with open("models/pca_384_v1.0.pkl","wb") as pklFile:
  pkl.dump(pca, pklFile)

In [None]:
topic_model.save("models/bertopic_flickr8k_20newsgroups_pre_embed_v1.0")

In [None]:
# Clean up used resources from previous phases
img_embeddings = None
text_embeddings = None
final_embeddings = None
img_model = None
sentence_model = None
docs = None
news_docs = None
topics = None
probs = None
new_topics = None
documents = None
captions = None

In [None]:
import gc
gc.collect()

Now we try to classify the topic labels using a pretrained zero-shot classification model. (EXPERIMENTAL: THIS PART MAY CRASH THE KERNEL)

In [None]:
nltk.download("wordnet")
nltk.download('omw-1.4')
from nltk.corpus import wordnet as wn

In [None]:
# Try to classify the topic labels using words from an English dictionary
# Single-word WordNet nouns, de-duplicated (the same lemma can appear in many synsets)
all_nouns = sorted({word for synset in wn.all_synsets('n') for word in synset.lemma_names() if "_" not in word})
selected_nouns = random.sample(all_nouns, 100)
all_nouns = None

In [None]:
classifier = pipeline("zero-shot-classification", model='valhalla/distilbart-mnli-12-1', device=0 if device == "cuda" else -1)

In [None]:
topic_info = topic_model.get_topic_info()[1:11]
topic_labels = topic_model.generate_topic_labels(nr_words=3, topic_prefix=False, separator=" ")[1:11]

In [None]:
topic_labels

In [None]:
results = classifier(topic_labels, selected_nouns)
topic_labels = [result["labels"][0] for result in results]
topic_labels

After experimenting, I think zero-shot classification is too slow at this point. Most available pretrained models are so large that even running them in the cloud can exhaust memory. It took over 5 minutes just to infer labels for 10 topics with only 100 candidate labels. The NLI-based pipeline runs one forward pass per (topic, candidate label) pair, which means 1,000 passes here. For now, I don't think using zero-shot classification to create topic labels outside the input data is practical.

Now we try to do inference using BERTopic on new data

In [None]:
# Load the models
img_model = SentenceTransformer('clip-ViT-B-32')
sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
bertopic = BERTopic.load("models/bertopic_flickr8k_20newsgroups_pre_embed_v1.0")
image_captioning = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning", device=0 if device == "cuda" else -1)
with open("models/pca_384_v1.0.pkl","rb") as pklFile:
  pca_model = pkl.load(pklFile)

In [None]:
# Fetch the test news documents and test images
# Download the COCO128 dataset from Kaggle manually and put coco128.zip in data/raw
test_img_zip_path = RAW_DATA_DIR + "coco128.zip"
test_img_zip_dir = RAW_DATA_DIR + "coco128/"
with zipfile.ZipFile(test_img_zip_path, 'r') as zf:
    for member in tqdm(zf.infolist(), desc='Extracting'):
        zf.extract(member, test_img_zip_dir)


In [None]:
test_images = list(glob.glob(test_img_zip_dir + "coco128/images/train2017/" + '*.jpg'))
test_news = fetch_20newsgroups(subset='test',  remove=('headers', 'footers', 'quotes'))['data']

In [None]:
# Create captions for each image
captions = image_captioning(test_images)
captions = [result[0]["generated_text"] for result in captions]

In [None]:
# Redo the feature embedding for images and texts
# Prepare images
nr_iterations = int(np.ceil(len(test_images) / BATCH_SIZE))

# Embed images per batch
img_embeddings = []
for i in tqdm(range(nr_iterations)):
    start_index = i * BATCH_SIZE
    end_index = (i * BATCH_SIZE) + BATCH_SIZE

    images_to_embed = [Image.open(filepath) for filepath in test_images[start_index:end_index]]
    
    img_emb = img_model.encode(images_to_embed, show_progress_bar=False)
    img_embeddings.extend(img_emb.tolist())

    # Close images
    for image in images_to_embed:
        image.close()
        
img_embeddings = np.array(img_embeddings)

In [None]:
# PCA on image embeddings
img_embeddings = pca_model.transform(img_embeddings)

In [None]:
# Create text embeddings
text_embeddings = sentence_model.encode(test_news, show_progress_bar=True)
text_embeddings = np.array(text_embeddings)

In [None]:
# Concatenate everything together for inference on BERTopic
final_embeddings = np.concatenate([img_embeddings, text_embeddings])
docs = captions + test_news

In [None]:
topics, _ = bertopic.transform(docs, final_embeddings)
topic_labels = bertopic.generate_topic_labels(nr_words=3, topic_prefix=False, separator=" ")

In [None]:
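# Label each document by the top 3 words of its topic, skipping outliers (topic -1).
# Looking labels up by topic id avoids assuming topic -1 sits at index 0 of topic_labels.
labels = {t: " ".join(word for word, _ in bertopic.get_topic(t)[:3]) for t in set(topics) if t != -1}
df = pd.DataFrame({"Topic": [labels[t] for t in topics if t != -1]})
df = df.groupby(["Topic"], sort=False)["Topic"].agg(Count="count").sort_values("Count", ascending=False).reset_index()

In [None]:
df.head(10) # Get the top 10 topics from the test data (outliers already excluded)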

In [None]:
bertopic.get_topic_info()[1:11] # Check with the total set of topics in the model

### References  
1. BERTopic Documentation - https://maartengr.github.io/BERTopic/index.html#example  
2. HuggingFace CLIP Model - https://huggingface.co/sentence-transformers/clip-ViT-B-32  
3. HuggingFace Image Captioning Model - https://huggingface.co/nlpconnect/vit-gpt2-image-captioning  
4. HuggingFace Zero-Shot Classification Pipeline - https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.ZeroShotClassificationPipeline