In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
import os
# Keep Hugging Face model and dataset downloads in a project-local cache
# directory (paths are relative to the notebook's working directory).
os.environ.update({
    'TRANSFORMERS_CACHE': '../huggingface_cache/model_cache',
    'HF_DATASETS_CACHE': '../huggingface_cache/data_cache',
})

In [2]:
# Load the FAQ pairs; the CSV column named 'Unnamed: 0' becomes the row index.
data_path = Path("..") / "data" / "pairs1.csv"
data = pd.read_csv(
    data_path,
    index_col='Unnamed: 0',
)
data.head()

Unnamed: 0,q,a
0,1. How expensive is NMSU?,New Mexico State University is proud to offer ...
1,2. What kind of university is NMSU?,"NMSU is a public, land-grant university, which..."
2,3. Is NMSU ranked?,Yes! Visit this page for NMSU rankings and exp...
3,4. What are the most popular fields of study?,NMSU is a comprehensive university. Popular un...
4,5. What is the academic calendar at NMSU?,"The academic calendar at NMSU, like most U.S. ..."


In [3]:
import re

def remove_number_start(sentence):
    """Strip a leading enumeration such as ``"12. "`` from a sentence.

    Only a number at the very start of the string is removed; numbers that
    appear later in the text (e.g. ``"Is version 2. better?"``) are kept.

    Args:
        sentence: The text to clean. Non-string values (for example the NaN
            pandas uses for missing cells) are returned unchanged.

    Returns:
        The sentence without its leading ``<digits>. `` prefix, or the input
        itself when there is no such prefix or the input is not a string.
    """
    if not isinstance(sentence, str):
        # Missing values would make re.sub raise; pass them through instead.
        return sentence
    # Raw string avoids invalid-escape warnings; '^' restricts the match to
    # the start so digits inside the sentence are never touched.
    return re.sub(r'^\s*[0-9]+\.\s+', '', sentence)

# Sample of a raw question as it appears in the CSV.
s = "1. How expensive is NMSU?"

# Clean every question in place, then preview the table.
data['q'] = data['q'].map(remove_number_start)
data.head(20)

Unnamed: 0,q,a
0,How expensive is NMSU?,New Mexico State University is proud to offer ...
1,What kind of university is NMSU?,"NMSU is a public, land-grant university, which..."
2,Is NMSU ranked?,Yes! Visit this page for NMSU rankings and exp...
3,What are the most popular fields of study?,NMSU is a comprehensive university. Popular un...
4,What is the academic calendar at NMSU?,"The academic calendar at NMSU, like most U.S. ..."
5,How many international students are there at N...,Currently NMSU enrolls more than 700 internati...
6,"How do NMSU students pay their tuition, fees, ...",Payment for tuition (course costs) and fees ar...
7,Is financial aid available to international st...,There are limited numbers of scholarships and ...
8,What is the health insurance requirement? How ...,All F-1 students and J-1 students/scholars are...
9,Can I purchase a health insurance plan from my...,If you would like to have a non-NMSU health in...


In [4]:
import torch
from sentence_transformers import SentenceTransformer
from transformers import pipeline

# Use the GPU when CUDA is available, otherwise run on the CPU.
# To force CPU execution, assign torch.device('cpu') here instead.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Sentence encoder whose `encode` output is used for question matching.
qa_model = SentenceTransformer(
    'sentence-transformers/multi-qa-MiniLM-L6-cos-v1',
)
# Text-generation pipeline, placed on DEVICE.
general_model = pipeline(
    'text-generation',
    model='EleutherAI/gpt-neo-1.3B',
    device=DEVICE,
)

In [9]:
class FAQBot:
    """Answer a question from an FAQ table, or generate text when nothing matches.

    The questions in ``database['q']`` are embedded once at construction time.
    ``get_answer`` embeds the user's question, picks the stored question with
    the highest cosine similarity and returns its answer when the similarity
    reaches ``threshold``; otherwise it asks ``general_model`` for free text.
    """

    def __init__(
        self,
        qa_model,
        general_model,
        database: pd.DataFrame,
        threshold=0.8,
        device=None,
    ):
        """Store the models and pre-compute the FAQ question embeddings.

        Args:
            qa_model: Object with an ``encode(sentences, device=...,
                convert_to_tensor=True)`` method returning torch tensors.
            general_model: Callable taking a prompt plus generation keyword
                arguments and returning a list whose first item is a dict
                with a ``'generated_text'`` entry.
            database: Table with a ``'q'`` (question) and an ``'a'`` (answer)
                column.
            threshold: Minimum cosine similarity for an FAQ hit.
            device: Device handed to ``qa_model.encode``. When ``None`` the
                notebook-level ``DEVICE`` is used.
        """
        self.qa_model = qa_model
        self.general_model = general_model
        self.database = database
        self.threshold = threshold
        # Only touch the notebook global when the caller gave no device, so the
        # class can be used without it.
        self.device = DEVICE if device is None else device

        # Hand the encoder a plain list so it sees purely positional data,
        # whatever index the DataFrame carries.
        questions = self.database['q'].tolist()
        if questions:
            self.q_embeddings = self.qa_model.encode(
                questions,
                device=self.device,
                convert_to_tensor=True,
            )
        else:
            # Nothing to match against; get_answer goes straight to generation.
            self.q_embeddings = None

    @staticmethod
    def _extract_generated_answer(question, generated_text):
        """Pick the answer paragraph out of the generator's raw output."""
        # Preferred form: text after the first blank line (first non-empty
        # paragraph following the leading one).
        for paragraph in generated_text.split('\n\n')[1:]:
            if paragraph.strip():
                return paragraph
        # No usable paragraph break: return the continuation, dropping the
        # prompt if the generator echoed it at the start.
        if question and generated_text.startswith(question):
            generated_text = generated_text[len(question):]
        return generated_text.strip()

    def get_answer(self, question: str):
        """Return the best FAQ answer, or generated text if no FAQ entry fits.

        Args:
            question: The user's question.

        Returns:
            A dict. On an FAQ hit it holds ``'Score'`` (cosine similarity as a
            Python float), ``'Question'`` (the matched stored question) and
            ``'Answer'``. Otherwise it holds only ``'Answer'`` with text
            produced by ``general_model``.
        """
        result = {}
        if self.q_embeddings is not None:
            user_qe = self.qa_model.encode(
                question,
                device=self.device,
                convert_to_tensor=True,
            )
            # Reshape to one row so it can be compared with every stored
            # embedding (assumes q_embeddings has one row per question).
            user_qe = user_qe.view(1, -1)
            calc_score = torch.nn.CosineSimilarity(dim=1, eps=1e-08)
            scores = calc_score(user_qe, self.q_embeddings)
            best_pos = int(torch.argmax(scores))
            best_score = float(scores[best_pos])
            if best_score >= self.threshold:
                # Embeddings are positional, so look the row up by position in
                # this bot's own table rather than by index label.
                result['Score'] = best_score
                result['Question'] = self.database['q'].iloc[best_pos]
                result['Answer'] = self.database['a'].iloc[best_pos]
                return result

        ans = self.general_model(question, do_sample=True, min_length=10, max_new_tokens=100)[0]
        result['Answer'] = self._extract_generated_answer(question, ans['generated_text'])
        return result

# Build the bot over the FAQ table, keeping FAQBot's default threshold.
faq_bot = FAQBot(qa_model=qa_model, general_model=general_model, database=data)

In [12]:
# Quick manual check of the bot with a question that is not about NMSU.
faq_bot.get_answer(question="What is your purpose?")

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


{'Answer': 'We don’t go around thinking we have purpose, but there are always reasons why we go to work each day. What is it that makes you want to walk down the road? How does it feel to drive your car? Or walk down the sidewalk? Or take the stairs of a building? Or eat lunch at work? Or read the newspaper? Or watch TV?'}

In [6]:
import ipywidgets as widgets

In [12]:
# Interactive front end: type a question, press Enter, and the bot's answer
# (plus the FAQ match score, when there is one) is shown underneath.
input_question = widgets.Text(
    placeholder='Type something',
    description='User question:',
    continuous_update=False,  # fire on Enter / focus loss, not on every keystroke
)
score_text = widgets.Label()
answer_output = widgets.Output()


def _show_answer(change):
    """Query the bot with the submitted text and refresh the display widgets."""
    question = change['new'].strip()
    answer_output.clear_output()
    if not question:
        # Nothing was typed: do not send an empty prompt to the models.
        score_text.value = ''
        return
    result = faq_bot.get_answer(question)
    score = result.get('Score')
    if score is None:
        # get_answer only reports a score when an FAQ entry matched.
        score_text.value = 'No FAQ match - generated answer'
    else:
        score_text.value = 'FAQ match score: {:.3f}'.format(float(score))
    with answer_output:
        print(result['Answer'])


input_question.observe(_show_answer, names='value')
display(input_question, score_text, answer_output)

Text(value='', description='String:', placeholder='Type something')

In [9]:
# Read back the text currently held by the question widget.
input_question.value

'Hello World'