# Exp13: Control Text Generation with a locally running LLM

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from accelerate import Accelerator
import pandas as pd
import numpy as np
from transformers import BertModel, BertTokenizer
import copy
import os

  from .autonotebook import tqdm as notebook_tqdm
Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


Configuration

In [2]:
MODEL="mistralai/Mistral-7B-Instruct-v0.2"
device = 'cuda'

Load the model and generate three sentences (as indicated by the end of sequence tokens) and print the longest sentence. This re-ranking will be based on the grammar classifiers later.

In [3]:
model = AutoModelForCausalLM.from_pretrained(MODEL, device_map="auto", torch_dtype=torch.float16, cache_dir="/mnt/qb/work/meurers/mpb672/cache")
tokenizer = AutoTokenizer.from_pretrained(MODEL, cache_dir="/mnt/qb/work/meurers/mpb672/cache")

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:15<00:00,  5.00s/it]


Load the grammar classifiers from the previous step

In [18]:
class TaskHead(torch.nn.Module):
    def __init__(self, bert_hidden_size, num_labels, dropout_rate=0.1):
        super().__init__()
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.classifier = torch.nn.Linear(bert_hidden_size, num_labels)

    def forward(self, x):
        x = self.dropout(x)
        return self.classifier(x)

class MultiTaskBERT(torch.nn.Module):
    def __init__(self, bert, task_heads):
        super().__init__()
        self.bert = bert
        self.task_heads = torch.nn.ModuleList(task_heads)

    def forward(self, input_ids, attention_mask, task_id):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        task_output = self.task_heads[task_id](pooled_output)
        return task_output

    def forward_all(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        task_outputs = torch.stack(
            [torch.argmax(self.task_heads[task_id](pooled_output), dim=1) for task_id in range(len(self.task_heads))],
            dim=1
        )
        return task_outputs

In [19]:
df = pd.read_json('../dat/egp_merged.json')
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir="/mnt/qb/work/meurers/mpb672/cache")
backbone_model = BertModel.from_pretrained('bert-base-uncased', cache_dir="/mnt/qb/work/meurers/mpb672/cache")

def load_model(level="A1"):    
    df_level = df[df['Level'] == level]
    num_classifiers = len(df_level)
    task_heads = [TaskHead(backbone_model.config.hidden_size, 2) for _ in range(num_classifiers)]
    multi_task_model = MultiTaskBERT(copy.deepcopy(backbone_model), task_heads).to(device)
    multi_task_model.load_state_dict(torch.load('../models/bert/multi_task_model_state_dict_' + level + '.pth'))
    return multi_task_model

models = {level: load_model(level) for level in ['A1', 'A2']}

In [6]:
cefr_texts = pd.read_csv("../dat/cefr_leveled_texts.csv")
cefr_texts.head()
description = {
    "C2": "Can understand and interpret critically virtually all forms of the written language including abstract, structurally complex, or highly colloquial literary and non-literary writings. Can understand a wide range of long and complex texts, appreciating subtle distinctions of style and implicit as well as explicit meaning.",
    "C1": "Can understand in detail lengthy, complex texts, whether or not they relate to his/her own area of speciality, provided he/she can reread difficult sections.",
    "B2": "Can read with a large degree of independence, adapting style and speed of reading to different texts and purposes, and using appropriate reference sources selectively. Has a broad active reading vocabulary, but may experience some difficulty with low-frequency idioms.",
    "B1": "Can read straightforward factual texts on subjects related to his/her field and interest with a satisfactory level of comprehension.",
    "A2": "Can understand short, simple texts on familiar matters of a concrete type which consist of high frequency everyday or job-related language. Can understand short, simple texts containing the highest frequency vocabulary, including a proportion of shared international vocabulary items.",
    "A1": "Can understand very short, simple texts a single phrase at a time, picking up familiar names, words and basic phrases and rereading as required."
}

In [7]:
cefr_texts

Unnamed: 0,text,label
0,Hi!\nI've been meaning to write for ages and f...,B2
1,﻿It was not so much how hard people found the ...,B2
2,Keith recently came back from a trip to Chicag...,B2
3,"The Griffith Observatory is a planetarium, and...",B2
4,-LRB- The Hollywood Reporter -RRB- It's offici...,B2
...,...,...
1489,Light propagating in the vicinity of astrophys...,C2
1490,Future of dentistry has become one of the most...,C2
1491,﻿The forests – and suburbs – of Europe are ech...,C2
1492,Hedge funds are turning bullish on oil once ag...,C2


Generate candidates and rank them using the classifiers.

In [20]:
def get_scores(level_model, candidates, max_len=128):
    input_ids = []
    attention_masks = []
    
    for candidate in candidates:
        encoding = bert_tokenizer.encode_plus(
            candidate,
            add_special_tokens=True,
            max_length=max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        input_ids.append(encoding['input_ids'].squeeze(0))  # Remove the batch dimension
        attention_masks.append(encoding['attention_mask'].squeeze(0))
    
    input_ids = torch.stack(input_ids).to(device)
    attention_masks = torch.stack(attention_masks).to(device)
    
    return level_model.forward_all(input_ids, attention_mask=attention_masks)

In [35]:
def generate_candidate(input_ids, max_token_sentence = 64, tok_k=10, eos_chars = [".", "!", "?"]):
    generated_tokens = torch.tensor([[]], dtype=torch.int, device=device)
    with torch.no_grad():
        for _ in range(max_token_sentence):
            next_token_logits = model(torch.cat([input_ids, generated_tokens], dim=1)).logits
            probs = torch.nn.functional.softmax(next_token_logits[:, -1, :], dim=-1)
            top_k_probs, top_k_indices = torch.topk(probs, tok_k)
            renormalized_top_k_probs = top_k_probs / top_k_probs.sum()
            top_k_id = torch.multinomial(renormalized_top_k_probs, num_samples=1).item()
            next_token_id = top_k_indices[0, top_k_id]
            
            next_token = tokenizer.decode(next_token_id)
            generated_tokens = torch.cat([generated_tokens, torch.tensor([[next_token_id]]).to(device)], dim=1)
            #print(generated_tokens)
            if any(eos_char in next_token for eos_char in eos_chars):
                break

    return tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

def write_story(level, story, num_candidates=3, max_len = 512):
    while len(story) < max_len:
        prompt = f"<s>[INST] Continue the writing with language on CEFR level {level}. [/INST]"
        inputs = tokenizer(prompt + story, return_tensors="pt").to(device)
        candidates = [generate_candidate(inputs.input_ids) for i in range(num_candidates)]
        print(candidates)
        scores = get_scores(models[level], candidates)
        print(torch.mean(scores.float(),dim=1))
        story += candidates[torch.argmax(torch.mean(scores.float(),dim=1))] + " "
    return story

In [36]:
num_stories = 10
num_candidates = 5
storyPrompts = cefr_texts.text.apply(lambda text: text[:50].strip().lstrip('\ufeff')).unique()

In [37]:
file_path = "../dat/controlled_generated_texts_mistral.csv"
if os.path.exists(file_path):
    existing_df = pd.read_csv(file_path)
else:
    existing_df = pd.DataFrame(columns=["label", "story", "text"])
    
story_counts = existing_df['label'].value_counts()
for level in ['A1', 'A2']: # description.keys()
    print(level)
    current_count = story_counts.get(level, 0)
    stories_to_add = num_stories - current_count
    for story in storyPrompts[num_stories-stories_to_add:num_stories+1]:
        text = write_story(level, story, num_candidates)
        print(text)
        new_row = {"label": level, "story": story, "text": text}
        pd.DataFrame([new_row]).to_csv(file_path, mode='a', index=False, header=not os.path.exists(file_path))

A1
['23 p.', '23 p.', '23 p.', '23 p.', '23 p.']
tensor([0.1651, 0.1284, 0.1376, 0.1284, 0.1560], device='cuda:0')
['m.', 'm.', 'M.', 'm.', 'm.']
tensor([0.1009, 0.1101, 0.0917, 0.1101, 0.1193], device='cuda:0')
['local time.', 'local time in New York City.', 'local time.', 'local time in New York City -RRB-.', 'local time, New York City.']
tensor([0.1651, 0.2018, 0.1468, 0.0367, 0.1651], device='cuda:0')
['A big fire has started in a high-rise building in midtown Manhattan.', '\n\nA big fire in a high-rise building in Manhattan, New York City.', '-RB- A large fire has started at a skyscraper in the heart of Manhattan.', '\n-RCB- A fire has started in a high-rise building in Manhattan.', '-RRB- The famous Statue of Liberty has turned green.']
tensor([0.1284, 0.1560, 0.0183, 0.0000, 0.0000], device='cuda:0')
['\n\nFire department on the scene.', '\n\nMany people are scared and evacuated their homes.', '\n\nFirefighters are fighting the fire on the 15th floor in a tall apartment building