In [None]:
import os
import sys
import torch
import torch.backends.cudnn as cudnn

# Make the project's python/ directory (two levels up from the working directory)
# importable from this notebook; appended only once so re-running the cell is harmless.
module_path = os.path.abspath('../../python')
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

from cvpr2018.feature_extractor import get_features_loader
from cvpr2018.utils.utils import register_logger
from encoder.clip_encoder import ClipEncoder

In [None]:
from langchain_community.llms import Ollama
# Create the LangChain Ollama wrapper for the "llama2" model; later cells call it as `llm`.
# NOTE(review): presumably requires a running Ollama server with llama2 available — confirm.
llm = Ollama(model="llama2")

In [None]:
# --- Logging ---------------------------------------------------------------
log_file = None  # logging file handed to register_logger
log_every = 50  # log the writing of clips every n steps

# --- Runtime ---------------------------------------------------------------
cudnn.benchmark = True
register_logger(log_file=log_file)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- Data loading ----------------------------------------------------------
dataset_path = '/home/ubuntu/repos/llm-rag/data/Anomaly-Videos-Part-1/test'  # path to the video dataset
clip_length, frame_interval = 16, 1  # length of each input sample, sampling interval between frames
batch_size = 4
num_workers = 4  # number of workers used for loading the videos


In [None]:
# Send a single prompt through `llm`; presumably a smoke test that the model answers.
llm.invoke("how are you doing today?")

In [None]:
import clip
import numpy as np
from lavis.models import load_model_and_preprocess
from torchvision.transforms import ToPILImage
import chromadb
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
from chromadb.utils.data_loaders import ImageLoader
import base64
from io import BytesIO

class ClipEncoder:
    """Embed video clips with CLIP, caption them with BLIP and upload results to Chroma.

    An instance bundles a clip data loader (from ``get_features_loader``), the
    CLIP "ViT-B/32" model, a LAVIS "blip_caption" model and a Chroma HTTP client.
    All models are placed on ``self.device`` (CUDA when available, else CPU).
    """

    def __init__(self, dataset_path, clip_length, caption_model_type, frame_interval, batch_size, num_workers,
                 chroma_host='localhost', chroma_port=8000):
        """Build the data loader, load both models and connect to Chroma.

        Args:
            dataset_path: Directory of videos handed to ``get_features_loader``.
            clip_length: Number of frames per clip; also the number of ids
                generated per clip by ``generate_document_ids``.
            caption_model_type: ``model_type`` passed to the BLIP caption loader.
            frame_interval: Sampling interval between frames.
            batch_size: Batch size handed to the loader.
            num_workers: Worker count handed to the loader.
            chroma_host: Host of the Chroma server (default keeps the old hard-coded value).
            chroma_port: Port of the Chroma server (default keeps the old hard-coded value).
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.dataset_path = dataset_path
        self.clip_length = clip_length
        self.frame_interval = frame_interval
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.chroma_client = chromadb.HttpClient(host=chroma_host, port=chroma_port)
        self.data_loader, self.data_iter = get_features_loader(dataset_path,
                                                               clip_length,
                                                               frame_interval,
                                                               batch_size,
                                                               num_workers,
                                                               "clip"
                                                               )
        # Use the instance's own device rather than a notebook-level global so the
        # class also works when imported from a module.
        self.model, self.preprocess = clip.load("ViT-B/32", device=self.device)
        self.caption_model, self.vis_processors, _ = load_model_and_preprocess(name="blip_caption",
                                                                               model_type=caption_model_type,
                                                                               is_eval=True,
                                                                               device=self.device)

    def _clip_frames(self, idx):
        """Return clip ``idx`` from the loader with its first two axes swapped.

        NOTE(review): presumably converts (C, T, H, W) to (T, C, H, W) so that
        index 0 selects the first frame — confirm against the loader.
        """
        return self.data_loader[idx][0].permute(1, 0, 2, 3)

    def encode_image(self, idx):
        """Run CLIP's image encoder (without gradients) on every frame of clip ``idx``."""
        frame_tensor = self._clip_frames(idx)
        with torch.no_grad():
            # .to(self.device) instead of .cuda(): keeps the CPU fallback usable.
            frame_embeddings = self.model.encode_image(frame_tensor.to(self.device))
        return frame_embeddings

    def get_all_image_embeddings(self):
        """Return a list with the ``encode_image`` result of every clip in the loader."""
        return [self.encode_image(idx) for idx in range(len(self.data_loader))]

    def export_tensor_to_np(self):
        """Return the first frame of every clip as a numpy array (via a PIL image)."""
        to_pil = ToPILImage()
        return [np.array(to_pil(self._clip_frames(idx)[0])) for idx in range(len(self.data_loader))]

    def export_tensor_to_base64(self):
        """Return the first frame of every clip as a base64-encoded JPEG string."""
        to_pil = ToPILImage()
        arr = []
        for idx in range(len(self.data_loader)):
            buffered = BytesIO()
            to_pil(self._clip_frames(idx)[0]).save(buffered, format="JPEG")
            arr.append(base64.b64encode(buffered.getvalue()).decode())
        return arr

    def get_captions(self):
        """Caption the first frame of every clip with the BLIP caption model.

        Returns:
            A list with one entry per clip, each being whatever
            ``caption_model.generate`` returns for that frame.
        """
        to_pil = ToPILImage()
        captions_list = []
        for idx in range(len(self.data_loader)):
            pil_image = to_pil(self._clip_frames(idx)[0])
            image = self.vis_processors["eval"](pil_image).unsqueeze(0).to(self.device)
            # Inference only: no need to track gradients.
            with torch.no_grad():
                generated_captions = self.caption_model.generate({"image": image})
            captions_list.append(generated_captions)
        return captions_list

    def get_all_caption_embeddings(self, captions_list):
        """Encode captions with CLIP's text encoder.

        Args:
            captions_list: Iterable whose entries are either a collection of
                caption strings or a single caption string. ``None``, empty
                entries and empty captions are skipped.

        Returns:
            A flat list with one text embedding per successfully encoded
            caption. Captions that fail to encode are reported and skipped.
        """
        # Future improvements: Maybe multiple captions per image; Think about a way how to add anomalous captions / features
        caption_embeddings = []
        for caption_set in captions_list or []:
            if not caption_set:
                continue
            # A bare string is one caption; iterating it would encode single characters.
            captions = [caption_set] if isinstance(caption_set, str) else caption_set
            for caption in captions:
                if not caption:
                    continue
                try:
                    with torch.no_grad():
                        caption_features = clip.tokenize(caption).to(self.device)
                        caption_embeddings.append(self.model.encode_text(caption_features))
                except Exception as e:
                    # Best effort: report and continue with the remaining captions.
                    print(f"Error encoding text for caption: {caption}")
                    print(f"Error details: {e}")
        return caption_embeddings

    def generate_document_ids(self):
        """Build one id per frame of every clip.

        Ids have the form ``"<item[3]>_<item[1]>-<frame index>"`` where ``item``
        is the result of ``data_loader.getitem_from_raw_video``.

        Returns:
            ``(document_ids, batched_ids)``: the flat list of ids and the same
            ids grouped into one list of ``self.clip_length`` entries per clip.
        """
        document_ids = []
        batched_ids = []
        for i in range(len(self.data_loader)):
            item = self.data_loader.getitem_from_raw_video(idx=i)
            prefix = f"{item[3]}_{item[1]}"
            # self.clip_length (not a notebook global) decides the group size.
            clip_ids = [f"{prefix}-{j}" for j in range(self.clip_length)]
            document_ids.extend(clip_ids)
            batched_ids.append(clip_ids)

        return document_ids, batched_ids

    def get_or_create_chroma_collection(self, collection_name, embedding_function=None, data_loader=None):
        """Fetch or create a Chroma collection.

        ``embedding_function`` and ``data_loader`` are only forwarded when an
        embedding function is given.

        Returns:
            The collection, or ``None`` when Chroma raised (the error is printed).
        """
        extra = {}
        if embedding_function:
            extra = {"embedding_function": embedding_function, "data_loader": data_loader}
        try:
            return self.chroma_client.get_or_create_collection(name=collection_name, **extra)
        except Exception as e:
            print(f"Error creating collection: {collection_name}")
            print(f"Error details: {e}")
            return None

    def upload_embeddings_to_chroma(self, collection_name, img_data, ids, multi_modal= False, captions=None, documents=None, metadata=None):
        """Add items to a Chroma collection one by one.

        Args:
            collection_name: Target collection (created if missing).
            img_data: Images (multi-modal mode) or embeddings, one entry per id.
            ids: Ids matching ``img_data`` entry by entry.
            multi_modal: When true, use an OpenCLIP embedding function with an
                image loader and add ``item[0]`` / ``id_[0]`` as images.
            captions: Currently unused; kept for interface compatibility.
            documents: Forwarded as ``documents`` in embedding mode.
            metadata: Forwarded as ``metadatas`` on every add.

        Raises:
            ValueError: If ``img_data`` and ``ids`` differ in length.
        """
        if len(img_data) != len(ids):
            raise ValueError("data and ids must have the same length")

        if multi_modal:
            embedding_function = OpenCLIPEmbeddingFunction("ViT-H-14", "laion2b_s32b_b79k")
            collection = self.get_or_create_chroma_collection(collection_name, embedding_function, ImageLoader())
        else:
            collection = self.get_or_create_chroma_collection(collection_name)

        if collection is None:
            # The creation error was already printed; avoid one failure line per item.
            print(f"Skipping upload: collection {collection_name} is not available")
            return
        if multi_modal:
            print("Multi Modal Collection created")

        for item, id_ in zip(img_data, ids):
            try:
                if multi_modal:
                    # NOTE(review): only the first element of each entry/id is uploaded — confirm intended.
                    collection.add(images=item[0], metadatas=metadata, ids=id_[0])
                else:
                    collection.add(documents=documents, embeddings=item, metadatas=metadata, ids=id_)
            except Exception as e:
                # Best effort: report the failing id and keep uploading the rest.
                print(f"Failed to add document with ID {id_}: {str(e)}")
            

In [None]:
# One encoder per video subset; both share the sampling and loading settings
# from the configuration cell and use the 'base_coco' BLIP caption model.
normal_videos = "/home/ubuntu/repos/llm-rag/data/subset_normal"
anomalous_videos = "/home/ubuntu/repos/llm-rag/data/subset_anomalous"

normal_encoder = ClipEncoder(
    dataset_path=normal_videos,
    clip_length=clip_length,
    caption_model_type='base_coco',
    frame_interval=frame_interval,
    batch_size=batch_size,
    num_workers=num_workers,
)
anomalous_encoder = ClipEncoder(
    dataset_path=anomalous_videos,
    clip_length=clip_length,
    caption_model_type='base_coco',
    frame_interval=frame_interval,
    batch_size=batch_size,
    num_workers=num_workers,
)


In [None]:
# normal_caption_list = normal_encoder.get_captions()
# anomalous_caption_list = anomalous_encoder.get_captions()

# generate_document_ids() returns a (flat_ids, batched_ids) tuple, so in later cells
# index [0] is the flat list of per-frame ids and [1] the ids grouped per clip.
normal_doc_ids = normal_encoder.generate_document_ids()
anomalous_doc_ids = anomalous_encoder.generate_document_ids()

In [None]:
import pandas as pd

# # Convert lists to dataframes
# df_normal = pd.DataFrame(normal_caption_list, columns=['Captions'])
# df_anomalous = pd.DataFrame(anomalous_caption_list, columns=['Captions'])

# # Write dataframes to CSV
# df_normal.to_csv('normal_captions.csv', index=False)
# df_anomalous.to_csv('anomalous_captions.csv', index=False)

In [None]:
# read csv to list of captions

# Normal captions: CSV with a header row, captions live in the 'Captions' column.
df_normal = pd.read_csv('normal_captions.csv')
normal_caption_list = df_normal.loc[:, 'Captions'].to_list()

# Weakly labeled anomalous captions: header-less CSV, so the single column is labeled 0.
df_anomalous = pd.read_csv('weakly_labeled_anomalous_captions.csv', header=None)
weakly_labeled_anomalous_caption_list = df_anomalous.loc[:, 0].to_list()

#weakly_labeled_anomalous_caption_list = [str(anomalous_encoder.data_loader.getitem_from_raw_video(idx=i)[2]) + ' ' + caption for i, caption in enumerate(anomalous_caption_list)]

#print(anomalous_encoder.data_loader.getitem_from_raw_video(idx=0)[2] )

In [None]:
# Persist the weakly labeled captions without header or index.
# Note: this writes the same file the loading cell reads with header=None.
df_weakly_labeled_anomalous = pd.DataFrame({'Captions': weakly_labeled_anomalous_caption_list})
df_weakly_labeled_anomalous.to_csv(
    'weakly_labeled_anomalous_captions.csv',
    header=False,
    index=False,
)

In [None]:
# Sanity check: number of per-clip id batches (normal set), number of weakly
# labeled anomalous captions, and the first normal caption — one value per line.
print(
    len(normal_doc_ids[1]),
    len(weakly_labeled_anomalous_caption_list),
    sep="\n",
)

print(normal_caption_list[0])

In [None]:
from langchain_community.llms import Ollama
# Re-creates `llm` with the same arguments as the earlier Ollama cell, so the
# RAG cells below can run without executing the top of the notebook again.
llm = Ollama(model="llama2")

In [None]:
# First chroma retriever
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import OllamaEmbeddings
from langchain_core.documents import Document
from langchain.storage import LocalFileStore
from langchain.retrievers.multi_vector import MultiVectorRetriever
from pathlib import Path


# Vector store for the caption texts; persist_directory points at a local Chroma directory.
embeddings = OllamaEmbeddings()
vectore_store = Chroma(
    collection_name='text_summary_db',
    persist_directory='/home/ubuntu/chroma_db/',
    embedding_function=embeddings,
)

# The storage layer for the parent documents
root_path = Path.cwd().joinpath("data", "doc_store_text_summary")
store = LocalFileStore(root_path)
id_key = "doc_id"

# The retriever (empty to start)
# NOTE(review): `store` is a file-backed store and the docstore cell writes pickled
# bytes into it; verify that retrieval hands back Document objects rather than raw
# bytes (MultiVectorRetriever presumably offers a byte-store option for this case).
retriever = MultiVectorRetriever(vectorstore=vectore_store, docstore=store, id_key=id_key)


def _caption_documents(captions, batched_ids):
    """Pair caption i with every id in batched_ids[i], producing one Document per id."""
    documents = []
    for i, caption in enumerate(captions):
        for doc_id in batched_ids[i]:
            documents.append(Document(page_content=caption, metadata={id_key: doc_id}))
    return documents


normal_summary_texts = _caption_documents(normal_caption_list, normal_doc_ids[1])
anomalous_summary_texts = _caption_documents(weakly_labeled_anomalous_caption_list, anomalous_doc_ids[1])

In [None]:
# Index both caption sets in the retriever's vector store.
# NOTE(review): no explicit ids are passed, so re-running this cell presumably adds
# the same documents again to the persisted collection — confirm before re-running.
retriever.vectorstore.add_documents(normal_summary_texts)
retriever.vectorstore.add_documents(anomalous_summary_texts)


In [None]:
# retriever.docstore.mset(list(zip(normal_doc_ids[0], normal_summary_texts))) # Eventually run llava for long description and add to docstore
# retriever.docstore.mset(list(zip(anomalous_doc_ids[0], anomalous_summary_texts))) 

import pickle

# Write every Document into the docstore under its flat per-frame id, pickled to bytes.
# zip() pairs ids and documents positionally and stops at the shorter sequence.
# NOTE(review): the values stored are pickled bytes; confirm the retriever un-pickles
# them on lookup, and never unpickle files from an untrusted docstore directory.
for _flat_ids, _documents in (
    (normal_doc_ids[0], normal_summary_texts),
    (anomalous_doc_ids[0], anomalous_summary_texts),
):
    retriever.docstore.mset(
        [(doc_id, pickle.dumps(doc)) for doc_id, doc in zip(_flat_ids, _documents)]
    )

In [None]:
# Query the retriever directly and show at most the first 10 hits.
# Uses the Runnable `invoke` entry point (as the notebook already does for `llm`
# and `chain`) instead of the deprecated `get_relevant_documents`.
retriever.invoke(
    "traffic"
)[:10]

In [None]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough

# Prompt template
# {context} is filled with the retriever output and {question} with the raw user input
# (see the mapping at the head of the chain below).
template = """Answer the question based only on the following context, which are short descriptions of videos:
{context}
Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)

# Option 1: LLM
# Option 2: Multi-modal LLM
# model = GPT4-V or LLaVA

# RAG pipeline
# Steps, left to right: build the prompt inputs (retrieved context + the question
# passed through unchanged), render the prompt, call `llm`, parse the result to a string.
chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)

In [None]:
# Example query 1: ask the RAG chain to locate a specific scene among the captions.
chain.invoke(
    "Which document shows people walking next to lockers"
)

In [None]:
# Example query 2: an aggregate question; the answer can only reflect the captions
# the retriever returns for this query, not the whole collection.
chain.invoke(
    "How many videos  show indoor scenes?"
)

In [None]:
#normal_encoder.chroma_client.delete_collection('text_summary_db')