# Exp13: Controlled Text Generation with a Locally Running LLM

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from accelerate import Accelerator
import pandas as pd
import numpy as np
from transformers import BertModel, BertTokenizer
import copy
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))
import config
import random


  from .autonotebook import tqdm as notebook_tqdm
Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


Configuration

In [2]:
# Hugging Face id of the instruction-tuned LLM used for generation.
MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
# Device for the BERT grammar classifiers and for the LLM input ids.
# Fall back to CPU so the notebook does not crash on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

Load the model and generate three candidate sentences. A sentence ends at the first generated token containing '.', '!' or '?'. Rather than simply keeping the longest sentence, the candidates are re-ranked with the grammar classifiers loaded below.

Load the grammar classifiers from the previous step

In [3]:
class NonlinearTaskHead(torch.nn.Module):
    """Small MLP classification head: Linear -> ReLU -> Linear.

    Args:
        input_dim: size of each incoming feature vector (the BERT pooled output
            in this notebook).
        num_labels: number of output logits.
        hidden_dim: width of the single hidden layer.
    """

    def __init__(self, input_dim, num_labels, hidden_dim=16):
        super().__init__()
        # The attribute names below become state_dict keys of saved checkpoints,
        # so they must stay exactly as they are.
        self.fc1 = torch.nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.relu = torch.nn.ReLU()
        self.classifier = torch.nn.Linear(in_features=hidden_dim, out_features=num_labels)

    def forward(self, x):
        """Map features `x` to unnormalised logits with `num_labels` entries in the last dimension."""
        return self.classifier(self.relu(self.fc1(x)))

class MultiTaskBERT(torch.nn.Module):
    """A shared BERT encoder with one classification head per task.

    Every head reads the encoder's pooled output.
    """

    def __init__(self, bert, task_heads):
        super().__init__()
        self.bert = bert
        # ModuleList registers the heads as submodules (moved by .to(), saved in state_dict).
        self.task_heads = torch.nn.ModuleList(task_heads)

    def _pooled(self, input_ids, attention_mask):
        """Encode a batch with BERT and return its pooled output."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return encoded.pooler_output

    def forward(self, input_ids, attention_mask, task_id):
        """Return the raw logits of head number `task_id` for the batch."""
        head = self.task_heads[task_id]
        return head(self._pooled(input_ids, attention_mask))

    def forward_all(self, input_ids, attention_mask):
        """Encode the batch once and return every head's predicted class.

        Returns:
            Tensor of shape (batch, num_heads); column i is the argmax of head i.
        """
        pooled = self._pooled(input_ids, attention_mask)
        predictions = [head(pooled).argmax(dim=1) for head in self.task_heads]
        return torch.stack(predictions, dim=1)

In [4]:
df = pd.read_json('../dat/egp_merged.json')
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir=config.CACHE_DIR)
backbone_model = BertModel.from_pretrained('bert-base-uncased', cache_dir=config.CACHE_DIR)

def load_model(level="A1", max_constructs=500):
    """Build the multi-task grammar classifier of one CEFR level and load its weights.

    One binary NonlinearTaskHead is created per row of `df` whose 'Level' equals
    `level`, capped at `max_constructs`. The head count must match the
    checkpoint, because `load_state_dict` is strict.

    Args:
        level: CEFR level; selects rows of `df` and the checkpoint file.
        max_constructs: upper bound on the number of task heads.

    Returns:
        A MultiTaskBERT on `device`, switched to eval mode.
    """
    df_level = df[df['Level'] == level]
    num_classifiers = min(len(df_level), max_constructs)
    task_heads = [NonlinearTaskHead(backbone_model.config.hidden_size, 2) for _ in range(num_classifiers)]
    # The strict load below also restores the bert.* weights, so each level gets
    # its own copy of the backbone instead of sharing one module.
    multi_task_model = MultiTaskBERT(copy.deepcopy(backbone_model), task_heads).to(device)
    # map_location lets a checkpoint saved on another device load onto `device`.
    state_dict = torch.load('../models/bert/multi_task_model_state_dict_' + level + '.pth', map_location=device)
    multi_task_model.load_state_dict(state_dict)
    # The model is only used for inference; do not leave any module in training mode.
    multi_task_model.eval()
    return multi_task_model

models = {level: load_model(level) for level in ["A1", "A2", "B1", "B2", "C1", "C2"]}

In [5]:
cefr_texts = pd.read_csv("../dat/cefr_leveled_texts.csv")
# Short reading-ability descriptors per CEFR level; write_story embeds the
# matching one in its prompt when add_info=True.
description = {
    "C2": "Can understand and interpret critically virtually all forms of the written language including abstract, structurally complex, or highly colloquial literary and non-literary writings. Can understand a wide range of long and complex texts, appreciating subtle distinctions of style and implicit as well as explicit meaning.",
    "C1": "Can understand in detail lengthy, complex texts, whether or not they relate to his/her own area of speciality, provided he/she can reread difficult sections.",
    "B2": "Can read with a large degree of independence, adapting style and speed of reading to different texts and purposes, and using appropriate reference sources selectively. Has a broad active reading vocabulary, but may experience some difficulty with low-frequency idioms.",
    "B1": "Can read straightforward factual texts on subjects related to his/her field and interest with a satisfactory level of comprehension.",
    "A2": "Can understand short, simple texts on familiar matters of a concrete type which consist of high frequency everyday or job-related language. Can understand short, simple texts containing the highest frequency vocabulary, including a proportion of shared international vocabulary items.",
    "A1": "Can understand very short, simple texts a single phrase at a time, picking up familiar names, words and basic phrases and rereading as required."
}

In [6]:
# Display the loaded corpus (the output below shows an unnamed index, text and label columns).
cefr_texts

Unnamed: 0,text,label
0,Hi!\nI've been meaning to write for ages and f...,B2
1,﻿It was not so much how hard people found the ...,B2
2,Keith recently came back from a trip to Chicag...,B2
3,"The Griffith Observatory is a planetarium, and...",B2
4,-LRB- The Hollywood Reporter -RRB- It's offici...,B2
...,...,...
1489,Light propagating in the vicinity of astrophys...,C2
1490,Future of dentistry has become one of the most...,C2
1491,﻿The forests – and suburbs – of Europe are ech...,C2
1492,Hedge funds are turning bullish on oil once ag...,C2


Generate candidates and rank them using the classifiers.

In [7]:
# Tokenizer and fp16 causal LM used for generation. device_map="auto" leaves the
# weight placement to accelerate.
tokenizer = AutoTokenizer.from_pretrained(MODEL, cache_dir=config.CACHE_DIR)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    device_map="auto",
    torch_dtype=torch.float16,
    cache_dir=config.CACHE_DIR,
)

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:55<00:00, 18.53s/it]


In [8]:
def get_scores(level_model, candidates, max_len=128):
    """Run every grammar classifier of `level_model` on each candidate sentence.

    Args:
        level_model: a MultiTaskBERT (as returned by load_model).
        candidates: sequence of strings to classify.
        max_len: BERT sequence length. Every candidate is padded or truncated
            to exactly this many tokens.

    Returns:
        Tensor of shape (len(candidates), num_heads) on `device`, holding the
        predicted class of each head per candidate (0/1 for the binary heads
        built by load_model). For an empty `candidates`, it returns an empty
        (0, num_heads) tensor instead of failing in torch.stack.
    """
    if len(candidates) == 0:
        return torch.empty((0, len(level_model.task_heads)), dtype=torch.long, device=device)

    # One batched tokenizer call yields the same tensors as per-sentence
    # encode_plus followed by torch.stack.
    encoding = bert_tokenizer(
        list(candidates),
        add_special_tokens=True,
        max_length=max_len,
        return_token_type_ids=False,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    input_ids = encoding['input_ids'].to(device)
    attention_masks = encoding['attention_mask'].to(device)

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        return level_model.forward_all(input_ids, attention_mask=attention_masks)

In [40]:
# Sanity check: score five short imperative sentences with the C2 classifiers.
# Each output row holds one predicted class per C2 task head.
get_scores(models["C2"], ["Don't ever touch my belongings without permission!", "Don't take another step!", "Don't question my authority again!", "Don't miss the deadline!", "Don't ever talk back to me!"])

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [23]:
def generate_candidate(input_ids, max_token_sentence = 64, tok_k=10, eos_chars = (".", "!", "?")):
    """Sample one sentence-long continuation of `input_ids` from the global LM.

    Tokens are drawn one at a time with top-k sampling. Generation stops:
      * after the first token whose decoded text contains one of `eos_chars`
        (that token is kept),
      * when the model emits its end-of-sequence token (not kept), or
      * after `max_token_sentence` new tokens.

    Args:
        input_ids: tokenized prompt. A single sequence (batch size 1) is assumed.
        max_token_sentence: maximum number of new tokens to sample.
        tok_k: sample only among the `tok_k` most probable tokens (top-k).
        eos_chars: substrings that mark the end of a sentence.

    Returns:
        The decoded continuation without special tokens. This may be "" when
        the model ends the sequence immediately.
    """
    eos_token_id = tokenizer.eos_token_id
    generated_ids = []
    step_input = input_ids
    past_key_values = None
    with torch.no_grad():
        for _ in range(max_token_sentence):
            # Reuse the KV cache so every step after the first only feeds the
            # newest token instead of re-encoding prompt + continuation.
            outputs = model(step_input, past_key_values=past_key_values, use_cache=True)
            past_key_values = outputs.past_key_values
            # Upcast the fp16 logits so the probabilities are computed in float32.
            probs = torch.softmax(outputs.logits[:, -1, :].float(), dim=-1)
            top_k_probs, top_k_indices = torch.topk(probs, tok_k, dim=-1)
            # torch.multinomial accepts unnormalised weights; no rescaling needed.
            sampled = torch.multinomial(top_k_probs, num_samples=1)
            next_token = top_k_indices.gather(-1, sampled)  # shape (1, 1)
            next_token_id = next_token.item()
            if eos_token_id is not None and next_token_id == eos_token_id:
                # The model finished its answer; sampling past EOS appends unrelated text.
                break
            generated_ids.append(next_token_id)
            if any(eos_char in tokenizer.decode(next_token_id) for eos_char in eos_chars):
                break
            step_input = next_token.to(input_ids.device)

    return tokenizer.decode(generated_ids, skip_special_tokens=True)

def write_story(level, story, num_candidates=3, max_len = 512, add_info=False):
    """Extend `story` sentence by sentence, steering it towards CEFR `level`.

    In each round it does the following:
      * samples `num_candidates` continuations with generate_candidate;
      * scores them with the grammar classifiers of `level`;
      * appends the candidate with the highest mean classifier output.

    Args:
        level: CEFR level; key into `models` and `description` (e.g. "B1").
        story: text to continue. The returned string starts with it.
        num_candidates: continuations sampled per round.
        max_len: stop once the story is at least this many characters long.
        add_info: include the level's descriptor text in the instruction.

    Returns:
        The extended story. Generation also stops early when every sampled
        candidate is blank, i.e. the model has nothing left to add.
    """
    info = f", which is described as {description[level]}" if add_info else ""
    prompt = f"<s>[INST] Continue the writing on CEFR level {level}{info}. Do not talk about the CEFR level. [/INST] "
    level_model = models[level]
    while len(story) < max_len:
        # The prompt already begins with a literal "<s>" (BOS). Letting the
        # tokenizer add its own BOS as well would feed the model two of them.
        inputs = tokenizer(prompt + story, return_tensors="pt", add_special_tokens=False).to(device)
        candidates = [generate_candidate(inputs.input_ids) for _ in range(num_candidates)]
        # Blank candidates would only append whitespace; drop them.
        candidates = [candidate for candidate in candidates if candidate.strip()]
        if not candidates:
            break
        scores = get_scores(level_model, candidates)
        mean_scores = torch.mean(scores.float(), dim=1)
        best = int(torch.argmax(mean_scores))
        story += " " + candidates[best]
    return story

In [24]:
num_stories = 5
num_candidates = 3
min_length = 50
# Seed for the prompt shuffle below and for the torch.multinomial sampling in
# generate_candidate, so reruns are reproducible (up to GPU nondeterminism).
SEED = 42
torch.manual_seed(SEED)

def _story_prompt(text):
    """Return the opening of `text`: everything up to the first space at or after `min_length` characters."""
    # Remove the BOM first: str.strip() does not treat U+FEFF as whitespace, so
    # stripping before removing it would leave any whitespace behind the BOM.
    text = text.lstrip('\ufeff')
    cut = text.find(' ', min_length)
    if cut == -1:
        # No space after min_length (short text): keep all of it.
        # A raw -1 slice would silently drop the last character.
        cut = len(text)
    return text[:cut].strip()

storyPrompts = cefr_texts.text.apply(_story_prompt).unique()
np.random.default_rng(SEED).shuffle(storyPrompts)

In [28]:
file_path = "../dat/controlled_generated_texts_mistral.csv"
# Set to True to append every generated text to file_path.
SAVE_RESULTS = False

for story in storyPrompts[:num_stories]:
    print("_" * 100)
    print(story)
    for level in models:
        print(level)
        text = write_story(level, story, num_candidates, add_info=True)
        print(text)
        if SAVE_RESULTS:
            # Append row by row, rather than once at the end, so finished texts
            # survive an interrupted run.
            new_row = {"label": level, "story": story, "text": text}
            pd.DataFrame([new_row]).to_csv(file_path, mode='a', index=False, header=not os.path.exists(file_path))

____________________________________________________________________________________________________
Lucy had a cat. His name was Pirate. Pirate was 14
A1
Lucy had a cat. His name was Pirate. Pirate was 14 years old. every day, Lucy fed Pirate and gave him water. Pirate purred loudly. He liked to sit on Lucy's lap and be stroked. Lucy loved Pirate very much. She took care of him with kindness. In the evening, Pirate slept next to her in bed. It was cozy and warm. 

Sometimes, Pirate would play with a red ball. He would bat it around the room with his paw. Lucy watched him play and laughed. Pirate brought her joy. He was a good friend. 

Lucy took Pirate to the vet for regular check-ups.
A2
Lucy had a cat. His name was Pirate. Pirate was 14 years old, grey and had a missing eyelid which made him look a little bit scary yet cute. Lucy loved Pirate dearly, and they spent most days together at home. Pirate would often lie in the sun, while Lucy read books or worked on her computer. 

Lucy'

KeyboardInterrupt: 