# Constrained decoding

The goal of this experiment is to reject texts that are too difficult as judged by the sentence classifiers.

In [23]:
from vertexai.preview.generative_models import GenerativeModel, Part, HarmCategory, HarmBlockThreshold
import pandas as pd
import sys
import os
sys.path.append(os.path.dirname(os.getcwd()))
import config
import pandas as pd
import random
import torch
# Run tensors on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed Python's `random` module with the seed defined in the local config module.
random.seed(config.SEED)
# Credentials path for the Google Cloud client libraries, taken from config.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = config.PATH_TO_GCP_CREDS

from sentence_transformers import SentenceTransformer

import spacy
# spaCy pipeline; get_sentences() uses it to split a text into sentences.
nlp = spacy.load("en_core_web_sm")

# Construction table. score() looks up rows by the '#' column and reads their
# 'Level' column (presumably the English Grammar Profile export -- confirm).
egp = pd.read_csv('../dat/egponline.csv')

In [17]:
class FeedforwardNN(torch.nn.Module):
    """Feed-forward network with one hidden layer and a single sigmoid output.

    Args:
        input_dim: Size of each input vector.
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Attribute names are kept as-is: score() loads whole pickled models
        # with torch.load, presumably instances of this class, and those
        # reference the sub-modules by these names.
        self.fc1 = torch.nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(in_features=hidden_dim, out_features=1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid-activated output of the network for ``x``."""
        activated = self.fc1(x)
        activated = self.relu(activated)
        logits = self.fc2(activated)
        return self.sigmoid(logits)

# Sentence-embedding model; score() encodes each sentence with it before
# passing the embeddings to the classifiers.
embeddings_model = SentenceTransformer('llmrails/ember-v1')

In [18]:
def get_sentences(text):
    """Split ``text`` into sentences with the spaCy pipeline.

    Returns:
        A list with the whitespace-stripped text of each detected sentence.
    """
    return [sentence_span.text.strip() for sentence_span in nlp(text).sents]

In [31]:
def score(text, level):
    """Score ``text`` with every construction classifier of the given level.

    The text is split into sentences, each sentence is embedded, and every
    classifier in ``../models`` whose construction (looked up in ``egp`` by
    the numeric file name) has ``Level == level`` is applied to the sentence
    embeddings. The per-classifier mean output over all sentences is summed.

    Args:
        text: The text to score.
        level: Level label to select classifiers by, compared against the
            'Level' column of ``egp`` (e.g. ``'C2'``).

    Returns:
        A 0-dim CPU tensor holding the summed mean classifier outputs. It is
        ``tensor(0.)`` when the text has no sentences or no classifier
        matches the level.
    """
    total_score = torch.tensor(0.0)

    sentences = get_sentences(text)
    if not sentences:
        # Nothing to classify; avoids a NaN mean over an empty batch.
        return total_score

    # Encode once and reuse the same input tensor for every classifier.
    inputs = torch.tensor(embeddings_model.encode(sentences), device=device)

    # Sorted so that the floating-point summation order is reproducible.
    for model_file in sorted(os.listdir("../models")):
        if not model_file.endswith(".pth"):
            continue
        construction = egp[egp['#'] == int(model_file[:-4])].iloc[0]
        if construction['Level'] != level:
            continue

        # NOTE(review): torch.load unpickles arbitrary objects -- only load
        # trusted checkpoints. On PyTorch >= 2.6 the default is
        # weights_only=True, which rejects whole-module pickles; pass
        # weights_only=False there if loading fails.
        # map_location keeps the model on the same device as the inputs, even
        # when the checkpoint was saved from a different device.
        model = torch.load(f"../models/{model_file}", map_location=device)
        model.eval()
        with torch.no_grad():  # inference only; no autograd graph needed
            outputs = model(inputs)
        total_score = total_score + outputs.cpu().mean()

    return total_score

# Quick check: score a short sample text with the classifiers of level 'C2'.
text = 'Friends are people who we like, trust, and share common interests with. We sometimes do a lot of things with them.'
score(text, 'C2')

tensor(0.4084)

In [32]:
# Levelled reference texts.
cefr_texts = pd.read_csv("../dat/cefr_leveled_texts.csv")
# NOTE(review): this is not the last expression of the cell, so the preview
# is presumably discarded rather than displayed.
cefr_texts.head()
# Reading-ability descriptor for each CEFR level, keyed by level label.
# generate() inserts the descriptor for the requested level into its prompt.
description = {
    "C2": "Can understand and interpret critically virtually all forms of the written language including abstract, structurally complex, or highly colloquial literary and non-literary writings. Can understand a wide range of long and complex texts, appreciating subtle distinctions of style and implicit as well as explicit meaning.",
    "C1": "Can understand in detail lengthy, complex texts, whether or not they relate to his/her own area of speciality, provided he/she can reread difficult sections.",
    "B2": "Can read with a large degree of independence, adapting style and speed of reading to different texts and purposes, and using appropriate reference sources selectively. Has a broad active reading vocabulary, but may experience some difficulty with low-frequency idioms.",
    "B1": "Can read straightforward factual texts on subjects related to his/her field and interest with a satisfactory level of comprehension.",
    "A2": "Can understand short, simple texts on familiar matters of a concrete type which consist of high frequency everyday or job-related language. Can understand short, simple texts containing the highest frequency vocabulary, including a proportion of shared international vocabulary items.",
    "A1": "Can understand very short, simple texts a single phrase at a time, picking up familiar names, words and basic phrases and rereading as required."
}

In [13]:
def generate(level, storyPrompt):
    """Generate a story for ``storyPrompt`` at CEFR ``level`` with Gemini.

    The prompt combines the level label, its descriptor from ``description``
    and the story prompt. The response is streamed and the text of each
    chunk is concatenated.

    Args:
        level: CEFR level label; must be a key of ``description``.
        storyPrompt: The story prompt appended to the instruction.

    Returns:
        The concatenated text of all readable response chunks. Chunks whose
        text cannot be read are printed together with the error and skipped,
        so the result may be partial or empty.
    """
    gemini = GenerativeModel("gemini-pro")
    print(level)
    print(storyPrompt)

    prompt = f"Write a story using the following prompt on CEFR level {level} (Description: {description[level]})\n\n{storyPrompt}"

    safety = {
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
    }
    sampling = {
        "max_output_tokens": 1024,
        "temperature": 1,
        "top_p": 0.9,
    }
    stream = gemini.generate_content(
        prompt,
        safety_settings=safety,
        generation_config=sampling,
        stream=True,
    )

    story_text = ""
    for chunk in stream:
        try:
            story_text += chunk.candidates[0].content.parts[0].text
        except Exception as err:
            # Best effort: report the unreadable chunk and keep what we have.
            print(chunk.candidates)
            print(err)
    return story_text

# Number of stories wanted per CEFR level, and how many candidate texts are
# sampled for each story before the best-scoring one is kept.
num_stories = 10
num_candidates = 3
generated_texts = pd.read_csv("../dat/generated_texts.csv")
storyPrompts = generated_texts.story.unique()

# Results are appended to this file one row at a time, so an interrupted run
# can be resumed: rows already present are counted per level below.
file_path = "../dat/controlled_generated_texts.csv"
if os.path.exists(file_path):
    existing_df = pd.read_csv(file_path)
else:
    existing_df = pd.DataFrame(columns=["label", "story", "text"])

story_counts = existing_df['label'].value_counts()
for level in description:
    current_count = story_counts.get(level, 0)

    # Resume after the stories already stored for this level and stop at
    # num_stories. Without the upper bound, every remaining prompt would be
    # processed whenever there are more than num_stories prompts.
    for story in storyPrompts[current_count:num_stories]:
        candidates = [generate(level, story) for _ in range(num_candidates)]
        # score() requires the level whose classifiers should be applied.
        scores = [score(candidate, level) for candidate in candidates]
        print(scores)
        # Keep the candidate with the highest score for the target level.
        text = candidates[scores.index(max(scores))]
        new_row = {"label": level, "story": story, "text": text}
        # Write the header only when the file is being created.
        pd.DataFrame([new_row]).to_csv(
            file_path,
            mode='a',
            index=False,
            header=not os.path.exists(file_path),
        )

Returned 4 candidates
Accepted 4 candidates
Returned 4 candidates
Accepted 4 candidates
Returned 4 candidates
Accepted 4 candidates
Returned 4 candidates
Accepted 4 candidates
Returned 3 candidates
Accepted 3 candidates
Returned 4 candidates
Accepted 4 candidates


In [14]:
text

'Friends are people who you like and enjoy spending time with.Friends are people who share common interests and have a mutual bond. Friends are people who care about each other and support each other through thick and thin. Friends are an important part of our lives. Friends are important because they provide emotional support, companionship, and a sense of belonging. Friends are people who you like and enjoy spending time with. Friends are important because they provide emotional support, companionship, and a sense of belonging. '