In [3]:
import bs4
from langchain.document_loaders import WebBaseLoader
from langchain.text_splitter import HTMLHeaderTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import OpenAIEmbeddings
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.chat_models import ChatOpenAI
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings 
from langchain_chroma import Chroma 
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.chains import create_retrieval_chain, create_history_aware_retriever
from langchain.chains.combine_documents import create_stuff_documents_chain 
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory

# Step 1: Load and split blog content
loader = WebBaseLoader(web_path="https://blog.mdturp.ch/posts/2024-04-05-visual_guide_to_vision_transformer.html", 
                       bs_kwargs=dict(parse_only=bs4.SoupStrainer(
                           class_=("main")
                       )))
documents = loader.load()

USER_AGENT environment variable not set, consider setting it to identify your requests.


In [4]:
documents

[Document(metadata={'source': 'https://blog.mdturp.ch/posts/2024-04-05-visual_guide_to_vision_transformer.html'}, page_content="A Visual Guide to Vision Transformers \u200bThis is a visual guide to Vision Transformers (ViTs), a class of deep learning models that have achieved state-of-the-art performance on image classification tasks. Vision Transformers apply the transformer architecture, originally designed for natural language processing (NLP), to image data. This guide will walk you through the key components of Vision Transformers in a scroll story format, using visualizations and simple explanations to help you understand how these models work and how the flow of the data through the model looks like.Translations \u200bLanguageLinkTranslated by🇰🇷 KoreanLinkJunghwan ParkPlease enjoy and start scrolling!0) Lets start with the data \u200bLike normal convolutional neural networks, vision transformers are trained in a supervised manner. This means that the model is trained on a datase

In [5]:
url = "https://blog.mdturp.ch/posts/2024-04-05-visual_guide_to_vision_transformer.html"

headers_to_split_on = [
    ("h1", "Header 1"),
    ("h2", "Header 2"),
    ("h3", "Header 3"),
    ("h4", "Header 4")
]

html_splitter= HTMLHeaderTextSplitter(headers_to_split_on)
split_docs= html_splitter.split_text_from_url(url)

split_docs

[Document(metadata={}, page_content='[ ] [ ] [ ]  \nSkip to content  \n[ ] [ MDTURP ] [ ]  \n[ ] [ ]  \n[ [ ] [ ] ]  \nMain Navigation  \n[ ]  \nHome  \n[ ]  \nAbout  \n[ ]  \n[ ]  \nGitHub  \nX  \nLinkedIn  \n[ [ ] ]  \nAppearance  \n[ ]  \n[ ]  \nGitHub  \nX  \nLinkedIn  \nReturn to top  \n[ ] [ ]  \n[ ] [ ] [ ] [ ] [ ] [ ]  \nOn this page  \nTable of Contents for current page  \n[ ]  \n[ ] [ ]'),
 Document(metadata={'Header 1': 'A Visual Guide to Vision Transformers'}, page_content='A Visual Guide to Vision Transformers'),
 Document(metadata={'Header 1': 'A Visual Guide to Vision Transformers'}, page_content='â\x80\x8b  \nThis is a visual guide to Vision Transformers (ViTs), a class of deep learning models that have achieved state-of-the-art performance on image classification tasks. Vision Transformers apply the transformer architecture, originally designed for natural language processing (NLP), to image data. This guide will walk you through the key components of Vision Transforme

In [6]:
from dotenv import load_dotenv
import os
from langchain_community.embeddings import OllamaEmbeddings
load_dotenv()

True

In [7]:
ollama_embeddings = OllamaEmbeddings(model="llama3.2")
db = FAISS.from_documents(split_docs, ollama_embeddings)

  ollama_embeddings = OllamaEmbeddings(model="llama3.2")


In [8]:
query = "what is the effective way to train our model"
similarity = db.similarity_search(query)

similarity

[Document(id='197a89ff-a18a-44de-9874-4cc69ab8f33d', metadata={'Header 1': 'A Visual Guide to Vision Transformers', 'Header 3': '4) Flatting of the images patches'}, page_content='4) Flatting of the images patches'),
 Document(id='9b42612c-7c99-47f2-b1ce-bcd190a8ffc9', metadata={'Header 1': 'A Visual Guide to Vision Transformers', 'Header 3': '5) Creating patch embeddings'}, page_content='5) Creating patch embeddings'),
 Document(id='cdc24a1f-7854-416e-b45d-b1578ea50a74', metadata={'Header 1': 'A Visual Guide to Vision Transformers', 'Header 3': '6) Embedding all patches'}, page_content='6) Embedding all patches'),
 Document(id='9a028791-c82a-4dfc-9dc0-4b39b0fd277c', metadata={'Header 1': 'A Visual Guide to Vision Transformers', 'Header 3': '12) Identify Classification token output'}, page_content='12) Identify Classification token output')]

In [9]:
llm = OllamaEmbeddings(model="llama3.2")
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)


  memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)


In [10]:
from langchain_community.chat_models import ChatOllama
from langchain_community.embeddings import OllamaEmbeddings

ollama_embeddings = OllamaEmbeddings()  
llm = ChatOllama(model="llama3.2")        


  llm = ChatOllama(model="llama3.2")


In [11]:
retriever = db.as_retriever()

In [12]:
from langchain.chains import create_history_aware_retriever


In [13]:
from langchain_core.prompts import MessagesPlaceholder

In [14]:
## Prompt Template 
system_prompt = (
    "You are a highly capable AI assistant specializing in answering questions based on provided context."
    "Use only the given context to answer the question, and do not rely on prior knowledge or assumptions."
    "If the answer is not present in the context, clearly state, 'I don\'t know.'"
    "Keep your response concise, limited to a maximum of three sentences, and ensure clarity and relevance to the question."
    "\n\n"
    "{context}"
)

In [15]:
prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    ("human", "{input}")
])

question_answer_chain = create_stuff_documents_chain(llm, prompt)

rag_chain = create_retrieval_chain(retriever, question_answer_chain)

response = rag_chain.invoke({"input":"What is the purpose of Vision Transformers?"})

#for index, res in enumerate(response, 1):
   # print(f"retrieval{index}", res["answer"])
   
print(response)
print(response.get("answer"))   

{'input': 'What is the purpose of Vision Transformers?', 'context': [Document(id='92b59d29-8996-4653-a1c9-67652b0b04ad', metadata={'Header 1': 'A Visual Guide to Vision Transformers'}, page_content='A Visual Guide to Vision Transformers'), Document(id='c2494ec1-f982-4ea6-9834-57cbfca53f8f', metadata={'Header 1': 'A Visual Guide to Vision Transformers', 'Header 3': '14) Training of the Vision Transformer'}, page_content='14) Training of the Vision Transformer'), Document(id='ec73269b-417b-4eb3-86e6-69fd271a810b', metadata={'Header 1': 'A Visual Guide to Vision Transformers', 'Header 3': '10.12)Transformer: Final Result'}, page_content='After the transformer step there is another residual connections which we will skip here for brevity. And so the last step concluded the transformer layer. In the end the transformer produced outputs of the same size as input.'), Document(id='e0c32b75-7935-4a9b-81ca-4270cb7af2c9', metadata={'Header 1': 'A Visual Guide to Vision Transformers', 'Header 3': 

In [16]:

contextualize_q_system_prompt = (
"You are a question reformulation assistant. Given a chat history and the latest user " "question, your task is to reformulate the latest user question into a standalone question ""that can be understood without any reference to the previous conversation. "
"If the question already makes complete sense on its own, return it as is. Do not answer the "
"question. Your task is only to reformulate the question if necessary."
)

print(type(contextualize_q_system_prompt))

<class 'str'>


In [17]:
prompt = ChatPromptTemplate.from_messages([
    ("system", contextualize_q_system_prompt),
    MessagesPlaceholder("chat_history"), 
    ("human", "{input}")
])

#retrieval_chain = create_retrieval_chain(retriever, document_chain)

history_aware_retriever = create_history_aware_retriever(llm, retriever, prompt)
history_aware_retriever

RunnableBinding(bound=RunnableBranch(branches=[(RunnableLambda(lambda x: not x.get('chat_history', False)), RunnableLambda(lambda x: x['input'])
| VectorStoreRetriever(tags=['FAISS', 'OllamaEmbeddings'], vectorstore=<langchain_community.vectorstores.faiss.FAISS object at 0x0000029418AEE5D0>, search_kwargs={}))], default=ChatPromptTemplate(input_variables=['chat_history', 'input'], input_types={'chat_history': list[typing.Annotated[typing.Union[typing.Annotated[langchain_core.messages.ai.AIMessage, Tag(tag='ai')], typing.Annotated[langchain_core.messages.human.HumanMessage, Tag(tag='human')], typing.Annotated[langchain_core.messages.chat.ChatMessage, Tag(tag='chat')], typing.Annotated[langchain_core.messages.system.SystemMessage, Tag(tag='system')], typing.Annotated[langchain_core.messages.function.FunctionMessage, Tag(tag='function')], typing.Annotated[langchain_core.messages.tool.ToolMessage, Tag(tag='tool')], typing.Annotated[langchain_core.messages.ai.AIMessageChunk, Tag(tag='AIMess

In [18]:
qa_prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    MessagesPlaceholder("chat_history"), 
    ("human", "{input}")
])

question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)


In [19]:
from langchain_core.messages import AIMessage, HumanMessage 
chat_history = []
question = "What is the purpose of vision transformers?"
response1 = rag_chain.invoke({"input": question, "chat_history": chat_history})
chat_history.extend(
    [
        HumanMessage(content=question),
        AIMessage(content=response1["answer"])
    ]
)
question2 = "What kind of embeddings are needed for vision transformers"
response2 = rag_chain.invoke({"input": question2, "chat_history": chat_history})
print(response2)

{'input': 'What kind of embeddings are needed for vision transformers', 'chat_history': [HumanMessage(content='What is the purpose of vision transformers?', additional_kwargs={}, response_metadata={}), AIMessage(content='I don\'t know. The provided context only mentions "Training of the Vision Transformer" and "Transformer: Residual connections", but it doesn\'t explicitly state the purpose of vision transformers. However, based on general knowledge, vision transformers are a type of neural network architecture designed for image processing tasks.', additional_kwargs={}, response_metadata={})], 'context': [Document(id='f0c78451-4781-4b7a-a7f1-dc92f41ed0e5', metadata={'Header 1': 'A Visual Guide to Vision Transformers', 'Header 3': '1) Focus on one data point'}, page_content='To get a better understanding of what happens inside a vision transformer lets focus on a single data point (batch size of 1). And lets ask the question: How is this data point prepared in order to be consumed by a

In [20]:
print(chat_history)

[HumanMessage(content='What is the purpose of vision transformers?', additional_kwargs={}, response_metadata={}), AIMessage(content='I don\'t know. The provided context only mentions "Training of the Vision Transformer" and "Transformer: Residual connections", but it doesn\'t explicitly state the purpose of vision transformers. However, based on general knowledge, vision transformers are a type of neural network architecture designed for image processing tasks.', additional_kwargs={}, response_metadata={})]


In [21]:
store = {}

def get_session_history(session_id: str) -> BaseChatMessageHistory:
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]

conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)

In [22]:
conversational_rag_chain.invoke(
    {"input": "What kind of embeddings are needed for vision transformers"},
    config={"configurable":{"session_id": "xyz123"}}
)["answer"]


'Vision transformers typically use patch embeddings. Patch embeddings involve dividing the input image into smaller patches, then embedding each patch separately before being concatenated to form the final sequence.'

In [23]:
conversational_rag_chain.invoke(
    {"input": "Tell me more about it?"},
    config={"configurable":{"session_id": "xyz123"}}
)["answer"]

"I don't know. The context only mentions that vision transformers are trained in a supervised manner on a dataset of images and their corresponding labels, but it doesn't provide further details on the specifics of patch embeddings used in them."

In [24]:
store

{'xyz123': InMemoryChatMessageHistory(messages=[HumanMessage(content='What kind of embeddings are needed for vision transformers', additional_kwargs={}, response_metadata={}), AIMessage(content='Vision transformers typically use patch embeddings. Patch embeddings involve dividing the input image into smaller patches, then embedding each patch separately before being concatenated to form the final sequence.', additional_kwargs={}, response_metadata={}), HumanMessage(content='Tell me more about it?', additional_kwargs={}, response_metadata={}), AIMessage(content="I don't know. The context only mentions that vision transformers are trained in a supervised manner on a dataset of images and their corresponding labels, but it doesn't provide further details on the specifics of patch embeddings used in them.", additional_kwargs={}, response_metadata={})])}