In [1]:
import warnings
warnings.filterwarnings("ignore")

import ipywidgets as widgets
from IPython.display import display
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

In [2]:
# Directory that holds the saved fine-tuned checkpoint
model_path = "./fine_tuned_bart"

# Choose the GPU when torch reports one is available, otherwise use the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Restore the saved model from model_path
model = BartForConditionalGeneration.from_pretrained(model_path)

# The tokenizer is fetched by its base checkpoint name, not from model_path.
# NOTE(review): presumably the vocabulary was not changed during fine-tuning — confirm.
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')

# Place the model on the chosen device. Being the last expression of the cell,
# this also echoes the model's repr as the cell output.
model.to(device)



BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50265, 1024, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50265, 1024, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0-11): 12 x BartEncoderLayer(
          (self_attn): BartSdpaAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (final_layer_norm): La

In [5]:
# Build the three controls; both text areas are laid out at 90% width
input_box = widgets.Textarea(
    layout=widgets.Layout(width='90%'),
    description="Query:",
)
output_box = widgets.Textarea(
    layout=widgets.Layout(width='90%'),
    description="Response:",
    disabled=True,  # not editable by the user; written to by the click handler
)
button = widgets.Button(description="Generate Response")

# Define function to generate responses
def generate_response(query, max_length=50):
    """Generate the model's reply to ``query``.

    The query is suffixed with ``" [SEP]"``, tokenized, and passed to
    ``model.generate``. The generated ids are decoded with special tokens
    skipped, and the text after the last literal ``[SEP]`` is returned.

    Args:
        query: Text the user entered.
        max_length: Value forwarded to ``model.generate`` as ``max_length``.
            Defaults to 50, the value that used to be hard-coded here.

    Returns:
        The decoded response, stripped of surrounding whitespace.
    """
    input_text = query + " [SEP]"
    # Calling the tokenizer (instead of ``encode``) returns the attention mask
    # alongside the ids. ``truncation=True`` caps the input at the tokenizer's
    # configured maximum length so an overlong query does not run past the
    # encoder's finite position table and crash ``generate``.
    # NOTE(review): truncation drops the tail of the text, which for an
    # overlong query includes the " [SEP]" suffix — confirm this is acceptable.
    encoded = tokenizer(input_text, return_tensors='pt', truncation=True).to(device)
    output = model.generate(
        encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_length=max_length,
        num_return_sequences=1,
        # NOTE(review): decoding is started from the pad token rather than the
        # model config's own decoder_start_token_id; presumably this mirrors
        # how the model was fine-tuned — confirm against the training code.
        decoder_start_token_id=model.config.pad_token_id,
        eos_token_id=model.config.eos_token_id,
        pad_token_id=model.config.pad_token_id,
    )
    response = tokenizer.decode(output[0], skip_special_tokens=True)
    # Keep only what follows the last separator, in case the model echoed it.
    return response.split('[SEP]')[-1].strip()

# Function for button click event
def on_button_clicked(b):
    """Handle a click on the button: run the model and show its answer.

    Reads the text in ``input_box``, passes it to ``generate_response`` and
    writes the result to ``output_box``.

    Args:
        b: The object passed by the click callback (not used).

    Raises:
        Exception: Whatever ``generate_response`` raises is re-raised after the
            error text has been written to ``output_box``.
    """
    query = input_box.value
    # Nothing to answer: clear any previous response instead of sending a bare
    # separator to the model.
    if not query.strip():
        output_box.value = ""
        return
    try:
        output_box.value = generate_response(query)
    except Exception as exc:
        # Show the failure in the response box so it does not keep displaying
        # a stale answer, then let the exception propagate as before.
        output_box.value = f"Error: {exc}"
        raise

# Register the click handler on the button
button.on_click(on_button_clicked)

# Render the controls stacked vertically: query, button, then response
display(
    widgets.VBox(children=(input_box, button, output_box))
)

VBox(children=(Textarea(value='', description='Query:', layout=Layout(width='90%')), Button(description='Gener…