# Constrained decoding

The goal of this experiment is to reject sentences that are too difficult as judged by the sentence classifiers.

In [1]:
import vertexai
import torch
import numpy as np
from vertexai.preview.language_models import TextGenerationModel
import sys
from sentence_transformers import SentenceTransformer

import os
sys.path.append(os.path.dirname(os.getcwd()))
import config
import pandas as pd
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = config.PATH_TO_GCP_CREDS

vertexai.init(project="llm-grammar", location="us-east4")
model = TextGenerationModel.from_pretrained("text-bison@001")
embeddings_model = SentenceTransformer('llmrails/ember-v1')
egp = pd.read_csv('../dat/egponline.csv')

In [2]:
class FeedforwardNN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(FeedforwardNN, self).__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_dim, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        output = self.sigmoid(self.fc2(hidden))
        return output

In [10]:
def return_accepted(sentences, accepted_levels = ['A1','A2']):
    embeddings = embeddings_model.encode(sentences)
    sentences = np.array(sentences)
    keep_indices = np.ones_like(sentences, dtype=bool)
    
    for model_file in os.listdir("../models"):
        construction = egp[egp['#']==int(model_file[:-4])].iloc[0]
        if construction['Level'] not in accepted_levels:
            model = torch.load(f"../models/{model_file}")
            model.eval()
            outputs = model(torch.Tensor(embeddings))
            
            keep_indices &= (outputs<0.5).squeeze().cpu().numpy()
        
    return list(sentences[keep_indices])

sentences = ['Friends are people who we like, trust, and share common interests with']
return_accepted(sentences)

['Friends are people who we like, trust, and share common interests with']

In [13]:
text = "Friends are people who you like and enjoy spending time with."

while len(text) < 500:
    parameters = {
        "candidate_count": 8,
        "max_output_tokens": 64,
        "stop_sequences": [
            ".",
            "!",
            "?"
        ],
        "temperature": 0.9,
        "top_p": 0.8,
        "top_k": 40
    }
    model = TextGenerationModel.from_pretrained("text-bison")
    response = model.predict(text, **parameters)
    sentences = [candidate.text.strip() for candidate in response.candidates]
    print(f"Returned {len(sentences)} candidates")
    sentences = return_accepted(sentences)
    print(f"Accepted {len(sentences)} candidates")
    if len(sentences):
        text += sentences[0] + ". "

Returned 4 candidates
Accepted 4 candidates
Returned 4 candidates
Accepted 4 candidates
Returned 4 candidates
Accepted 4 candidates
Returned 4 candidates
Accepted 4 candidates
Returned 3 candidates
Accepted 3 candidates
Returned 4 candidates
Accepted 4 candidates


In [14]:
text

'Friends are people who you like and enjoy spending time with.Friends are people who share common interests and have a mutual bond. Friends are people who care about each other and support each other through thick and thin. Friends are an important part of our lives. Friends are important because they provide emotional support, companionship, and a sense of belonging. Friends are people who you like and enjoy spending time with. Friends are important because they provide emotional support, companionship, and a sense of belonging. '