In [1]:
import numpy as np
import torch
from tqdm import tqdm

In [2]:
%%capture
from transformers import AutoTokenizer, AutoModelForMultipleChoice

model = AutoModelForMultipleChoice.from_pretrained("danlou/albert-xxlarge-v2-finetuned-csqa")

tokenizer = AutoTokenizer.from_pretrained("danlou/albert-xxlarge-v2-finetuned-csqa")

In [3]:
import logging

# Send records to a log file that is overwritten on every run (filemode='w'),
# then grab the root logger and let INFO-level records through.
logging.basicConfig(
    format='%(asctime)s : %(name)s : %(levelname)s : %(message)s',
    filemode='w',
    filename='inferece_albert.log',
)
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [4]:
# Pick the compute device. The second GPU ("cuda:1") is preferred, as before,
# but that ordinal only exists when at least two GPUs are visible; on a
# single-GPU host fall back to "cuda:0" instead of failing later with an
# invalid device ordinal. Without CUDA, run on the CPU.
if torch.cuda.is_available():
    device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    device = "cpu"
logger.info("Current Device: " + device)

In [9]:
# Move the model onto the selected device. Binding the result to `_` keeps the
# cell from echoing the returned object as its output.
_ = model.to(device)

In [6]:
def inference(Question, Choices):
    """Score comma-separated answer options for a question.

    Args:
        Question: The question / riddle text.
        Choices: Candidate answers as a single comma-separated string,
            e.g. ``"throw, bit, hole"``. Whitespace around each option is
            ignored and empty options (e.g. from a trailing comma) are dropped.

    Returns:
        A summary string containing the zero-based index of the highest-scoring
        option and the softmax probability of every option, or an explanatory
        message when fewer than two usable options were supplied.
    """
    # Normalise the options so stray spaces do not reach the tokenizer or the
    # printed distribution, and blank entries are not scored as real choices.
    choices = [option.strip() for option in Choices.split(",")]
    choices = [option for option in choices if option]
    option_count = len(choices)

    if option_count < 2:
        return "At least 2 options should be provided"

    # Pair the same question with every option. Plain list repetition is safe
    # for any question text, unlike joining and splitting on a marker string.
    questions = [Question] * option_count

    encoding = tokenizer(questions, choices, return_tensors="pt", padding=True).to(device)

    # The tokenizer output has one row per option; unsqueeze(0) adds a leading
    # batch dimension of size 1 in front of the per-option rows.
    # no_grad: this is pure inference, so skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**{k: v.unsqueeze(0) for k, v in encoding.items()})

    probability = torch.softmax(outputs.logits, dim=1)
    prediction = torch.argmax(probability, dim=1).item()
    probability = probability.tolist()[0]

    # "<option>:<probability>" entries separated by " | ".
    proba_dist = " " + " | ".join(
        choice + ":" + str(round(score, 4))
        for choice, score in zip(choices, probability)
    )

    complete_output = (
        "Prediction --> " + str(prediction) + "  || "
        + "Probability Distribution --> " + proba_dist
    )

    logger.info("Question:: " + Question + "    " + complete_output)

    return complete_output



In [84]:
# Sample riddle kept in dictionary form: the question text plus its candidate answers.
riddle = dict(
    question=(
        'A man is incarcerated in prison, and as his punishment he has to carry '
        'a one tonne bag of sand backwards and forwards across a field the size '
        'of a football pitch.  What is the one thing he can put in it to make it '
        'lighter?'
    ),
    choices=['throw', 'bit', 'gallon', 'mouse', 'hole'],
)


In [11]:
# The same sample riddle in the plain-string form that inference() takes:
# the question text and a comma-separated list of options.
question = (
    'A man is incarcerated in prison, and as his punishment he has to carry '
    'a one tonne bag of sand backwards and forwards across a field the size '
    'of a football pitch.  What is the one thing he can put in it to make it '
    'lighter?'
)
choices = 'throw ,  bit ,  gallon, mouse , hole'
            
            

In [12]:
# Run the sample riddle through the model; the summary string is the cell output.
inference(Question=question, Choices=choices)

'Prediction --> 4  || Probability Distribution -->  throw :0.1484 |   bit :0.0632 |   gallon:0.0838 |  mouse :0.2529 |  hole:0.4517'

In [7]:
import gradio as gr

In [34]:
# Display the riddle text currently bound to `question`.
question

'A man is incarcerated in prison, and as his punishment he has to carry a one tonne bag of sand backwards and forwards across a field the size of a football pitch.  What is the one thing he can put in it to make it lighter?'

In [8]:
# Wrap inference() in a small web UI: two free-text inputs (passed to
# inference in order as the question and the options) and one plain-text output.
demo = gr.Interface(
    title="Riddle Solver",
    description="Question: must be in string,  Choices: Options seperated by comma",
    fn=inference,
    inputs=["text", "text"],
    outputs="text",
)

# share=True requests a temporary public link in addition to the local server.
demo.launch(share=True)

IMPORTANT: You are using gradio version 3.0, however version 3.14.0 is available, please upgrade.
--------
Running on local URL:  http://127.0.0.1:7860/
Running on public URL: https://b59b9b5d3ee6a2c4.gradio.app

This share link expires in 72 hours. For free permanent hosting, check out Spaces (https://www.huggingface.co/spaces)


(<gradio.routes.App at 0x7f91e58ecb50>,
 'http://127.0.0.1:7860/',
 'https://b59b9b5d3ee6a2c4.gradio.app')

In [11]:
!pip install gradio==3.0

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting gradio==3.0
  Downloading gradio-3.0-py3-none-any.whl (5.6 MB)
[K     |████████████████████████████████| 5.6 MB 16.6 MB/s eta 0:00:01
Collecting paramiko
  Downloading paramiko-3.1.0-py3-none-any.whl (211 kB)
[K     |████████████████████████████████| 211 kB 117.6 MB/s eta 0:00:01
[?25hCollecting pycryptodome
  Downloading pycryptodome-3.17-cp35-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.1 MB)
[K     |████████████████████████████████| 2.1 MB 111.6 MB/s eta 0:00:01
Collecting analytics-python
  Downloading analytics_python-1.4.post1-py2.py3-none-any.whl (23 kB)
Collecting monotonic>=1.5
  Downloading monotonic-1.6-py2.py3-none-any.whl (8.2 kB)
Collecting backoff==1.10.0
  Downloading backoff-1.10.0-py2.py3-none-any.whl (31 kB)
Collecting bcrypt>=3.2
  Downloading bcrypt-4.0.1-cp36-abi3-manylinux_2_28_x86_64.whl (593 kB)
[K     |████████████████████████████████| 593 kB 118.8 MB/s eta 0:00: