In [6]:
import numpy as np
import gym # open ai gym
import os,re
import torch
import random
from static_env import StaticEnv
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM



In [32]:
def seed_everything(seed=42):
    """Seed the Python, PyTorch and NumPy random number generators.

    Also stores ``str(seed)`` in the ``PYTHONHASHSEED`` environment variable.
    When CUDA is available, the CUDA generators are seeded as well and cuDNN
    is switched to deterministic mode with benchmarking disabled.

    Args:
        seed: Value handed to every seeding function (default 42).
    """
    # Same call order as before: stdlib random, then torch, then numpy.
    for apply_seed in (random.seed, torch.manual_seed, np.random.seed):
        apply_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Module-wide torch device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LLMQueryEnv(gym.Env, StaticEnv):
    """
    Token-level environment built around a causal language model.

    A state is a 2-D numpy array of token ids whose row 0 holds the sequence;
    the initial state is the tokenized ``orig_prompt``. An action is the id of
    the token appended to that sequence (see ``next_state``). A state is
    terminal when its decoded text ends with one of ``self.stopwords`` or when
    the depth supplied by the caller has reached ``self.depth``. Terminal
    states are scored by ``getPromptScore``.
    """

    # Reward table for getPromptScore: entry k lists the token ids that earn
    # 0.1 when found k+1 positions from the end of the sequence.
    # NOTE(review): presumably ids taken from reference "hello world"
    # completions under this tokenizer -- confirm before reusing elsewhere.
    _SCORE_TOKEN_IDS = (
        (4480, 361, 1820),
        (37881, 2, 17772),
        (1635, 611, 220),
        (422, 6, 5145),
        (995, 422, 705),
        (15496, 6894, 10603),
        (334, 1391, 6407),
        (4798, 37811, 31373),
        (50280, 50286, 50272),
        (201, 1441, 628),
    )

    def __init__(self,orig_prompt="def hello_world():"):
        """Seed the RNGs, load the tokenizer and model, and tokenize the prompt.

        Args:
            orig_prompt: Incomplete source text the model is asked to continue.
        """
        seed_everything()
        model_name = "Salesforce/codegen-350M-multi"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
        self.orig_prompt = orig_prompt
        self.init_state = self.get_tokenized_state(self.orig_prompt)
        self.num_tokens=0
        # NOTE(review): hard-coded; presumably the width of the model's logits
        # (the original had self.tokenizer.vocab_size commented out) -- confirm.
        self.n_actions = 51200
        # A decoded state ending with any of these strings is complete.
        self.stopwords = ['\n\n']
        # Depth at which is_done_state reports a state as terminal.
        self.depth=20

    def get_tokenized_state(self,prompt):
        """Tokenize ``prompt`` and return its ``input_ids`` as a numpy array."""
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        return input_ids.numpy()

    def get_initial_state(self):
        """Return the tokenized original prompt (the stored array, not a copy)."""
        return self.init_state

    def reset(self):
        """ Go back to the initial state. """
        return self.init_state

    def trim_with_stopwords(self, currentState):
        """Strip one trailing stopword from a decoded string.

        Args:
            currentState: Decoded text (a ``str``), not a token-id array.

        Returns:
            The text without its trailing stopword; the text unchanged when it
            ends with none of them (previously this case returned ``None``).
        """
        # Longest stopword first, so a stopword that is a suffix of another
        # one cannot win the match.
        for w in sorted(self.stopwords, key=len, reverse=True):
            # Skip '': every string ends with it and [:-0] would erase the text.
            if w and currentState.endswith(w):
                return currentState[:-len(w)]
        return currentState

    def isPromptComplete(self,currentState,depth):
        """Tell whether the decoded state ends with one of the stopwords.

        Args:
            currentState: Token-id array; row 0 is decoded.
            depth: Unused; kept for interface compatibility.

        Returns:
            ``True`` or ``False`` (previously ``None`` when nothing matched).
        """
        decoded = self.get_prompt_from_state(currentState)
        # Progress trace printed on every completeness check.
        print('decoded state',repr(decoded))
        return decoded.endswith(tuple(self.stopwords))

    def getPromptScore(self,currentState):
        """Score a state by comparing its last tokens with a fixed id table.

        Adds 0.1 for every one of the last ten positions whose token id is
        listed in ``_SCORE_TOKEN_IDS`` for that position. States shorter than
        ten tokens are scored on the positions they have instead of raising
        ``IndexError``.

        Args:
            currentState: Token-id array; row 0 is scored.

        Returns:
            ``0`` when nothing matches, otherwise the accumulated float score.
        """
        tokens = currentState[0]
        available = len(tokens)
        initScore = 0
        for offset, rewarded_ids in enumerate(self._SCORE_TOKEN_IDS, start=1):
            if offset > available:
                break
            if tokens[-offset] in rewarded_ids:
                initScore+=0.1
        return initScore

    def next_state(self,state,action):
        """Return a new state with token id ``action`` appended to the sequence."""
        return np.append(state,np.array([[action]]),axis=-1)

    def is_done_state(self,state,depth):
        """Return True when the prompt is complete or ``depth >= self.depth``."""
        # The completeness check runs first so its trace is always printed.
        return self.isPromptComplete(state,depth) or depth>=self.depth

    def get_prompt_from_state(self,state):
        """Decode row 0 of a token-id array back into text."""
        # Decoding needs no model, so the ids are not moved to the device.
        return self.tokenizer.decode(state[0])

    def getLLMestimates(self,state):
        """Return the model's next-token probabilities for ``state``.

        Returns:
            Numpy array holding the softmax of the logits at the last position
            of the first sequence.
        """
        with torch.no_grad():
            torchState = torch.from_numpy(state).to(device)
            output = self.model(input_ids=torchState)
            next_token_logits = output.logits[0, -1, :]
            next_token_probs = torch.softmax(next_token_logits, dim=-1)
            return next_token_probs.cpu().numpy()

    def get_best_terminal_state(self,state,depth):
        """Greedily extend ``state`` until ``is_done_state`` reports it terminal.

        Args:
            state: Token-id array to start from.
            depth: Depth already reached; incremented once per appended token.

        Returns:
            The terminal state as a numpy array (``state`` itself when it is
            already terminal).
        """
        with torch.no_grad():
            torchState = torch.from_numpy(state).to(device)
            while not self.is_done_state(state,depth):
                output = self.model(input_ids=torchState)
                next_token_logits = output.logits[0, -1, :]
                # Softmax is monotonic, so the arg-max of the logits is the
                # most probable token; no need to sort the whole vocabulary.
                best_id = torch.argmax(next_token_logits, dim=-1).reshape(1, 1)
                torchState = torch.cat([torchState, best_id], dim=-1)
                state = torchState.cpu().numpy()
                depth+=1
            return state

    def get_montecarlo_return(self,state,depth):
        """Score the terminal state reached by a greedy rollout from ``state``."""
        best_terminal_state = self.get_best_terminal_state(state,depth)
        return self.getPromptScore(best_terminal_state)

    def get_return(self,state,depth):
        """Score a state that must already be terminal.

        Raises:
            ValueError: If ``state`` is not terminal at ``depth``. The original
                printed "Serious error" and called ``exit(1)``, which ends the
                whole session.
        """
        ##Sanity Check##
        if not self.is_done_state(state,depth):
            raise ValueError("Serious error: get_return called on a non-terminal state")
        return self.getPromptScore(state)

In [33]:

# Build the environment around an incomplete function header.
env = LLMQueryEnv(orig_prompt="def hello_world()")
init_state = env.get_initial_state()
depth = 0

### Rollout return ###
# Greedy rollout from the initial state, decoded and stripped of its stopword.
finalState = env.get_best_terminal_state(init_state, depth)
promptGen = env.get_prompt_from_state(finalState)
filteredGen = env.trim_with_stopwords(promptGen)
print('Filtered relevant code with stop words {}-->\n{}\n'.format(env.stopwords, filteredGen))

### Two single greedy steps from the initial state, printed as '2' and '3' ###
next_state = init_state
for step_label in ('2', '3'):
    best_prediction = np.argmax(env.getLLMestimates(next_state))
    next_state = env.next_state(next_state, best_prediction)
    prompt = env.get_prompt_from_state(next_state)
    depth += 1
    print(step_label, prompt)


decoded state 'def hello_world()'
decoded state 'def hello_world() {'
decoded state 'def hello_world() {\n'
decoded state 'def hello_world() {\n    '
decoded state 'def hello_world() {\n    return'
decoded state 'def hello_world() {\n    return "'
decoded state 'def hello_world() {\n    return "Hello'
decoded state 'def hello_world() {\n    return "Hello World'
decoded state 'def hello_world() {\n    return "Hello World!"'
decoded state 'def hello_world() {\n    return "Hello World!";'
decoded state 'def hello_world() {\n    return "Hello World!";\n'
decoded state 'def hello_world() {\n    return "Hello World!";\n}'
decoded state 'def hello_world() {\n    return "Hello World!";\n}\n'
decoded state 'def hello_world() {\n    return "Hello World!";\n}\n\n'
Filtered relevant code with stop words ['\n\n']-->
def hello_world() {
    return "Hello World!";
}

2 def hello_world() {
3 def hello_world() {



In [31]:
# Run LLMQueryEnv.py with python3 through the shell.
# NOTE(review): presumably the script version of the environment cell above -- confirm.
!python3 LLMQueryEnv.py

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
[[ 4299 23748    62  6894 33529]]
decoded state 'def hello_world():'
  attn_weights = torch.where(causal_mask, attn_weights, mask_value)
decoded state 'def hello_world():\n'
decoded state 'def hello_world():\n    '
decoded state 'def hello_world():\n    return'
decoded state "def hello_world():\n    return '"
decoded state "def hello_world():\n    return 'Hello"
decoded state "def hello_world():\n    return 'Hello World"
decoded state "def hello_world():\n    return 'Hello World!'"
decoded state "def hello_world():\n    return 'Hello World!'\n"
decoded state "def hello_world():\n    return 'Hello World!'\n\n"
Filtered relevant code with stop words ['\n\n']-->
def hello_world():
    return 'Hello World!'

1 d