## **Finetuning DPR(Dense Passage Retriever) model**

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install wikipedia-api
!pip install transformers

Collecting wikipedia-api
  Downloading wikipedia_api-0.7.1.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: wikipedia-api
  Building wheel for wikipedia-api (setup.py) ... [?25l[?25hdone
  Created wheel for wikipedia-api: filename=Wikipedia_API-0.7.1-py3-none-any.whl size=14346 sha256=2f31a5094432cbe8a9bbd82298e23c19db5d83b20a036c0e1807da96137ba45e
  Stored in directory: /root/.cache/pip/wheels/4c/96/18/b9201cc3e8b47b02b510460210cfd832ccf10c0c4dd0522962
Successfully built wikipedia-api
Installing collected packages: wikipedia-api
Successfully installed wikipedia-api-0.7.1


In [3]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m25.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (

In [None]:
import pandas as pd
import numpy as np
import wikipediaapi
import torch
from torch import nn
import torch.nn.functional as F
from datasets import Dataset
from transformers import get_scheduler
import re
from transformers import (DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer, DPRQuestionEncoder, DPRContextEncoder, DPRReader, Trainer, TrainingArguments)
import random

In [5]:
import torch

if torch.cuda.is_available():
    print("GPU is available:", torch.cuda.get_device_name(0))
else:
    print("GPU not available. Check runtime settings.")

GPU is available: Tesla T4


## Extraction of pages related to the given plant name from wikipedia

In [6]:
wiki = wikipediaapi.Wikipedia('plantInfoRetrieval(haricharangoudca1@gmail.com)', 'en' )

def fetch_page(plant_name):
  page = wiki.page(plant_name)
  if page.exists:
    return page.text
  else:
    return None

In [7]:
plant = 'rose'
plant_info = fetch_page(plant)
if plant_info:
  print(plant_info[:200])
else:
  print(f"No information found for {plant}")

A rose is either a woody perennial flowering plant of the genus Rosa (), in the family Rosaceae (), or the flower it bears. There are over three hundred species and tens of thousands of cultivars. The


In [8]:
def load_and_preprocess_data(csv_path):
    df = pd.read_csv(csv_path)
    df['query'] = df['Plant Name']
    df['context'] = (
        "Plant Family: " + df['Family'] + "\n"
        + "Description: " + df['Description'] + "\n"
        + "Uses: " + df['Uses']
    )
    return df[['query', 'context']]

In [21]:
def initialize_models_and_tokenizers(device):
    query_model = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base').to(device)
    passage_model = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base').to(device)

    query_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
    passage_tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')

    return query_model, passage_model, query_tokenizer, passage_tokenizer

In [22]:
def create_dense_layer(input_size, output_size):
    return nn.Sequential(
        nn.Linear(input_size, output_size * 2),
        nn.ReLU(),
        nn.Linear(output_size * 2, output_size),
        nn.GELU(),
    )


In [23]:
def batch_tokenize(df, query_tokenizer, passage_tokenizer, batch_size=2, sample_size=4):
    rand_idx = np.random.randint(0, len(df), (batch_size, sample_size))
    queries = []
    contexts = []
    true_idx = []

    for row_idx, row in enumerate(rand_idx):
        rand_query_idx = random.randint(0, len(df) - 1)
        query = df.iloc[rand_query_idx]['query']
        true_context = df.iloc[rand_query_idx]['context']
        queries.append(query)

        if rand_query_idx not in row:
            idx = random.randint(0, sample_size - 1)
            rand_idx[row_idx][idx] = rand_query_idx
            true_idx.append(idx)
        else:
            true_idx.append(np.where(rand_idx[row_idx] == rand_query_idx)[0][0])

        for col_idx in row:
            context = df.iloc[col_idx]['context']
            contexts.append(context)

    passage_tensor = passage_tokenizer(contexts, padding='longest', return_tensors="pt")
    query_tensor = query_tokenizer(queries, padding='longest', return_tensors="pt")

    return passage_tensor, query_tensor, true_idx


In [None]:
def dot_product_similarity(query_embeddings, passage_embeddings):
    query_embeddings = query_embeddings.unsqueeze(1)
    similarity_scores = torch.matmul(query_embeddings, passage_embeddings.transpose(-2, -1))
    return similarity_scores
def forward_pass(
    query_model, passage_model, query_tokenizer, passage_tokenizer,
    df, passage_dense_layer, query_dense_layer,
    device, batch_size=2, sample_size=4
):
    passage_tensor, query_tensor, true_idx = batch_tokenize(df, query_tokenizer, passage_tokenizer, batch_size, sample_size)

    passage_tensor = {key: val.to(device) for key, val in passage_tensor.items()}
    query_tensor = {key: val.to(device) for key, val in query_tensor.items()}

    dense_passage = passage_model(input_ids=passage_tensor['input_ids'], attention_mask=passage_tensor['attention_mask'])
    dense_query = query_model(input_ids=query_tensor['input_ids'], attention_mask=query_tensor['attention_mask'])

    dense_passage = dense_passage['pooler_output']
    dense_passage = dense_passage.reshape(batch_size, sample_size, -1).to(device)
    dense_query = dense_query['pooler_output'].to(device)

    dense_passage = passage_dense_layer(dense_passage)
    dense_query = query_dense_layer(dense_query)

    similarity_scores = dot_product_similarity(dense_query, dense_passage)
    similarity_scores = similarity_scores.squeeze(1)

    log_scores = F.log_softmax(similarity_scores, dim=1)

    return log_scores, true_idx

In [25]:
def compute_loss(log_scores, true_idx,device):
    true_idx = torch.tensor(true_idx, dtype=torch.long, device = device)
    loss = F.nll_loss(log_scores, true_idx)
    return loss

In [29]:
def finetune_model(
    query_model, passage_model, query_tokenizer, passage_tokenizer,
    df, passage_dense_layer, query_dense_layer,
    device, epochs, batch_size, sample_size, learning_rate,
    model_save_path="/content/drive/MyDrive/fine_tuned_model"
):
    optimizer = torch.optim.AdamW(
        list(query_model.parameters())
        + list(passage_model.parameters())
        + list(passage_dense_layer.parameters())
        + list(query_dense_layer.parameters()),
        lr=learning_rate,
        weight_decay=1e-4
    )

    total_steps = (len(df) // batch_size) * epochs
    scheduler = get_scheduler(
        "linear", optimizer=optimizer, num_warmup_steps=int(total_steps * 0.1), num_training_steps=total_steps
    )

    for epoch in range(epochs):
        query_model.train()
        passage_model.train()
        epoch_loss = 0

        for _ in range(len(df) // batch_size):
            optimizer.zero_grad()

            log_scores, true_idx = forward_pass(
                query_model, passage_model, query_tokenizer, passage_tokenizer,
                df, passage_dense_layer, query_dense_layer,
                device, batch_size, sample_size
            )

            loss = compute_loss(log_scores, true_idx, device)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(query_model.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(passage_model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
            epoch_loss += loss.item()

        print(f"Epoch {epoch + 1}/{epochs}, Training Loss: {epoch_loss:.4f}")

    print("Saving fine-tuned models...")
    query_model.save_pretrained(f"{model_save_path}/query_model")
    passage_model.save_pretrained(f"{model_save_path}/passage_model")
    torch.save(passage_dense_layer.state_dict(), f"{model_save_path}/passage_dense_layer.pth")
    torch.save(query_dense_layer.state_dict(), f"{model_save_path}/query_dense_layer.pth")
    print(f"Models saved to {model_save_path}")


In [30]:
def main():
    csv_path = '/content/drive/MyDrive/plants_dataset.csv'
    df = load_and_preprocess_data(csv_path)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    query_model, passage_model, query_tokenizer, passage_tokenizer = initialize_models_and_tokenizers(device)

    dense_size = 128
    passage_dense_layer = create_dense_layer(768, dense_size).to(device)
    query_dense_layer = create_dense_layer(768, dense_size).to(device)

    model_save_path = "/content/drive/MyDrive/fine_tuned_model"
    finetune_model(
        query_model, passage_model, query_tokenizer, passage_tokenizer,
        df, passage_dense_layer, query_dense_layer,
        device, epochs=10, batch_size=10, sample_size=8, learning_rate=5e-3,
        model_save_path=model_save_path
    )

if __name__ == "__main__":
    main()

Using device: cuda


Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the

Epoch 1/10, Training Loss: 2.0801
Epoch 2/10, Training Loss: 2.0739
Epoch 3/10, Training Loss: 2.0691
Epoch 4/10, Training Loss: 2.2256
Epoch 5/10, Training Loss: 2.0834
Epoch 6/10, Training Loss: 2.2512
Epoch 7/10, Training Loss: 2.0881
Epoch 8/10, Training Loss: 2.2328
Epoch 9/10, Training Loss: 2.0728
Epoch 10/10, Training Loss: 2.1214
Saving fine-tuned models...
Models saved to /content/drive/MyDrive/fine_tuned_model


In [36]:
import re
import torch

def extract_relevant_info(text):
    family_name = None
    description = None
    uses = None

    family_pattern = re.search(
        r"(?i)((\b(plant\s*family|family\s*name|genus)\s*[:\-]?\s*[\w\s]+)|" +
        r"belongs\s+to\s+the\s+(family|genus)\b.*?\." +
        r"|is\s+a\s+member\s+of\s+the\s+(family|genus)\b.*?\." +
        r"|classified\s+under\s+(the\s+family|the\s+genus)\b.*?\." +
        r"|is\s+from\s+the\s+(family|genus)\b.*?\." +
        r"|related\s+to\s+(the\s+family|the\s+genus)\b.*?\.)", text)

    description_pattern = re.search(
        r"(?i)(\b\w+\s+(is|are)\s+(a|an|widely|commonly|known|characterized|used|identified)\b.*?\." +
        r"|belongs\s+to\b.*?\." +
        r"|has\s+(properties|features|characteristics)\b.*?\." +
        r"|can\s+be\s+(identified|used|found)\b.*?\.)", text)

    uses_pattern = re.search(r"(?i)(uses\s*[:\-]?\s*([^.]+))", text)

    if family_pattern:
        family_name = family_pattern.group(0).strip()

    if description_pattern:
        description = description_pattern.group(0).strip()

    if uses_pattern:
        uses = uses_pattern.group(2).strip()

    if description:
        description = limit_text_length(description, min_length=50, max_length=100)

    if uses:
        uses = limit_text_length(uses, min_length=50, max_length=100)

    if not description:
        description = extract_first_sentences(text, num_sentences=2)

    return {
        "family_name": family_name,
        "description": description,
        "uses": uses
    }

def limit_text_length(text, min_length=50, max_length=100):
    words = text.split()
    if len(words) > max_length:
        text = " ".join(words[:max_length]) + "..."
    elif len(words) < min_length:
        text = " ".join(words + ['...'])
    return text

def extract_first_sentences(text, num_sentences=2):
    sentences = text.split(". ")
    return ". ".join(sentences[:num_sentences]) + ("" if text.endswith(".") else ".")

def retrieve_relevant_documents(query):
    document_text = fetch_page(query)

    if not document_text:
        return ["Sorry, no relevant information found."]

    return extract_relevant_info(document_text)

def predict(user_input):
    query_tensor = query_tokenizer(user_input, padding='longest', truncation=True, max_length=512, return_tensors="pt")
    with torch.no_grad():
        dense_query = query_model(**query_tensor)['pooler_output']
    dense_query = query_dense_layer(dense_query)

    relevant_info = retrieve_relevant_documents(user_input)

    document_list = [f"""
    Family: {relevant_info.get('family_name', 'Not available')}
    Description: {relevant_info.get('description', 'Not available')}
    Uses: {relevant_info.get('uses', 'Not available')}
    """]

    passage_tensor = passage_tokenizer(document_list, padding='longest', truncation=True, max_length=512, return_tensors="pt")
    with torch.no_grad():
        dense_passage = passage_model(input_ids=passage_tensor['input_ids'], attention_mask=passage_tensor['attention_mask'])['pooler_output']
    dense_passage = passage_dense_layer(dense_passage)

    similarity_scores = torch.matmul(dense_query, dense_passage.transpose(-2, -1)).squeeze()
    most_similar_index = torch.argmax(similarity_scores).item()

    return document_list[most_similar_index]

user_input = input("Enter your query: ")
output = predict(user_input)
print(f"\nRelevant Information:\n{output}")


Enter your query: clove

Relevant Information:

    Family: None
    Description: and are commonly used as a spice, flavoring, or fragrance in consumer products, such as toothpaste, soaps, or cosmetics. ...
    Uses: Cloves are used in the cuisine of Asian, African, Mediterranean, and the Near and Middle East countries, lending flavor to meats (such as baked ham), curries, and marinades, as well as fruit (such as apples, pears, and rhubarb) ...
    
