# This notebook is based on the baseline inference scripts that Lawrence ran for full-sentence holdings; here I run the same pipeline for parenthetical holdings.

In [None]:
!pip install transformers
!pip install --upgrade transformers
!pip install textwrap
!pip install langchain
!pip install gc
!pip install torch
!pip install sentencepiece

In [None]:
# accelerate: loading the model failed without it (it is needed for device_map='auto').
# bitsandbytes: provides the 4-bit quantized loading used for the model.
!pip install accelerate
!pip install -q -U bitsandbytes

In [None]:
import os
# from dotenv import load_dotenv
from huggingface_hub import login
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline
import json
import textwrap
from langchain import HuggingFacePipeline
from langchain import PromptTemplate,  LLMChain
from langchain.memory import ConversationBufferMemory
import pandas as pd
import time
import gc

In [None]:
# Read the Hugging Face token from the HF_AUTH_TOKEN environment variable,
# falling back to hf_token.txt in the working directory, instead of pasting the
# secret into the notebook.
# NOTE(review): if neither source provides a token, login() receives None and
# presumably falls back to huggingface_hub's interactive login; confirm.
HUGGINGFACEHUB_API_TOKEN = os.getenv("HF_AUTH_TOKEN")
if not HUGGINGFACEHUB_API_TOKEN and os.path.exists("hf_token.txt"):
    with open("hf_token.txt", encoding="utf-8") as token_file:
        HUGGINGFACEHUB_API_TOKEN = token_file.read().strip()
login(token=HUGGINGFACEHUB_API_TOKEN)

In [None]:
# This code was borrowed from another notebook, as I needed a quick fix to using an accelerator: https://brev.dev/blog/fine-tuning-llama-2-your-own-data
# Builds an accelerate Accelerator with an FSDP plugin whose full model and
# optimizer state-dict configs have CPU offload enabled and rank0_only disabled.
# NOTE(review): `accelerator` is not referenced by any of the visible cells;
# inference goes through `model` / `pipe` directly. This cell may be removable,
# so confirm before deleting.
from accelerate import FullyShardedDataParallelPlugin, Accelerator
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig

fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

In [None]:
# Mount Google Drive at /content/drive; the test data and the results CSV
# both live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Test split in JSON Lines format. It is indexed by "input" (case text) and
# "output" (reference holding) columns once loaded with pandas; presumably
# every record has both fields, so verify against the data file.
file_path_test = '/content/drive/MyDrive/Lang Gen Project/qlora_data/cleaned_test_qlora.jsonl'

In [None]:
from transformers import BitsAndBytesConfig

model_directory = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_directory)

# Load the chat model quantized to 4 bits with bitsandbytes. Recent
# transformers releases (this notebook installs the latest) deprecate passing
# load_in_4bit directly to from_pretrained in favour of a BitsAndBytesConfig.
# The 4-bit layers compute in float16 to match torch_dtype; bitsandbytes'
# default float32 compute dtype only slows inference down.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(model_directory,
                                             device_map='auto',
                                             torch_dtype=torch.float16,
                                             quantization_config=quantization_config,
                                             )

In [None]:
# Text-generation pipeline around the already-loaded (4-bit, device-mapped)
# model. torch_dtype and device_map are not passed: pipeline() only uses them
# when it loads a model itself, and they are ignored for an instantiated model.
# Decoding is greedy (do_sample=False) so predictions are reproducible. LangChain's
# HuggingFacePipeline does not forward model_kwargs (e.g. a temperature) to an
# existing pipeline, so decoding behaviour has to be configured here.
pipe = pipeline("text-generation",
                model=model,
                tokenizer=tokenizer,
                max_new_tokens=1024,
                do_sample=False,
                num_return_sequences=1,
                eos_token_id=tokenizer.eos_token_id,
                )

# The cell below defines the helper functions used throughout this notebook
## Run it regardless of which task and model you're using

In [None]:
# Chat-prompt markers for the Llama-2 chat model loaded above: [INST]...[/INST]
# encloses a user turn, and <<SYS>>...<</SYS>> encloses the system prompt
# inside that turn.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""

def get_prompt(instruction, new_system_prompt=DEFAULT_SYSTEM_PROMPT):
    """Build a single-turn chat prompt from an instruction and a system prompt.

    Layout: [INST] <<SYS>>-wrapped system prompt, then the instruction, then
    [/INST]. The pieces are concatenated with no extra whitespace between them.
    """
    pieces = (B_INST, B_SYS, new_system_prompt, E_SYS, instruction, E_INST)
    return "".join(pieces)

def cut_off_text(text, prompt):
    """Return *text* truncated just before the first occurrence of *prompt*.

    When *prompt* does not occur, *text* is returned unchanged. An empty
    *prompt* matches at position 0, so the result is then the empty string.
    """
    position = text.find(prompt)
    return text if position == -1 else text[:position]

def remove_substring(string, substring):
    """Return a copy of *string* with every non-overlapping *substring* deleted."""
    cleaned = string.replace(substring, "")
    return cleaned



def generate(text):
    """Generate the model's reply to *text* wrapped in the default chat prompt.

    The instruction is wrapped with get_prompt() (default system prompt), and
    up to 1024 new tokens are generated with the global `model` / `tokenizer`.
    Only the newly generated tokens are decoded and returned.

    :param text: user instruction to send to the model.
    :return: the decoded continuation, with special tokens removed.
    """
    prompt = get_prompt(text)
    # Put inputs on the device holding the model's parameters (the model is
    # loaded with device_map='auto') instead of assuming 'cuda'.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[1]

    with torch.autocast('cuda', dtype=torch.bfloat16):
        outputs = model.generate(**inputs,
                                 max_new_tokens=1024,
                                 eos_token_id=tokenizer.eos_token_id,
                                 pad_token_id=tokenizer.eos_token_id,
                                 )

    # Slice off the prompt tokens before decoding. String-replacing the prompt
    # in the fully decoded text fails whenever decode() does not reproduce the
    # prompt byte-for-byte. A '</s>' cut-off is unnecessary because
    # skip_special_tokens already drops it.
    new_tokens = outputs[0, prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

def parse_text(text):
    """Print *text* word-wrapped to 100 columns, followed by two blank lines.

    Despite its name, this only prints; it returns None.
    """
    print(textwrap.fill(text, width=100), end="\n\n\n")

def count_words(input_string):
    """Return the number of whitespace-separated words in *input_string*.

    Splits on any run of whitespace (spaces, tabs, newlines), the same
    ``len(text.split())`` count this notebook uses when inspecting inputs.
    Repeated spaces and empty strings therefore no longer count phantom
    empty "words": ``split(" ")`` returned 1 for "" and 3 for "a  b".
    """
    return len(input_string.split())

def summarize_chunks(chunks, model, tokenizer, chain=None):
    """Run an LLM chain on each chunk and return the outputs in order.

    :param chunks: iterable of text chunks.
    :param model: unused; kept so existing callers keep working.
    :param tokenizer: unused; kept so existing callers keep working.
    :param chain: object exposing ``run(text) -> str``. Defaults to the
        notebook-global ``llm_chain``, so existing calls behave as before.
    :return: list with one chain output per chunk.
    """
    if chain is None:
        chain = llm_chain  # notebook-global LLMChain built in a later cell
    return [chain.run(chunk) for chunk in chunks]

def create_final_summary(summaries):
    """Combine per-chunk summaries into one string, separated by single spaces.

    A second summarization pass over the joined text (e.g. through generate())
    was considered and deliberately not used, because re-summarizing may
    degrade quality.
    """
    return " ".join(summaries)

def chunk_text_with_overlap(text, chunk_word_count, overlap_word_count):
    """Split *text* into chunks of at most *chunk_word_count* words.

    Consecutive chunks share *overlap_word_count* words. Words are split on any
    whitespace and re-joined with single spaces. Chunking stops as soon as a
    chunk reaches the last word, so no trailing chunk that is already fully
    contained in its predecessor is emitted. A negative overlap is treated as 0.

    :raises ValueError: if chunk_word_count < 1, or if overlap_word_count >=
        chunk_word_count (the window would never advance and loop forever).
    """
    if chunk_word_count < 1:
        raise ValueError("chunk_word_count must be at least 1")
    if overlap_word_count >= chunk_word_count:
        raise ValueError("overlap_word_count must be smaller than chunk_word_count")

    words = text.split()
    step = chunk_word_count - max(overlap_word_count, 0)
    chunks = []
    for start in range(0, len(words), step):
        end = min(start + chunk_word_count, len(words))
        chunks.append(" ".join(words[start:end]))
        if end == len(words):
            break  # this chunk already covers the tail of the text
    return chunks

def load_and_extract_data(file_path):
    """Return the text of the majority opinion in a case JSON file.

    Reads the JSON document at *file_path* and scans
    ``data["casebody"]["data"]["opinions"]`` for the first opinion whose
    ``"type"`` is ``"majority"``.

    :param file_path: path to a UTF-8 JSON case file.
    :return: that opinion's ``"text"``, or None when no majority opinion exists.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    # Keep scanning past non-majority opinions: the majority opinion is not
    # necessarily listed first.
    for opinion in data["casebody"]["data"]["opinions"]:
        if opinion["type"] == "majority":
            return opinion["text"]
    return None

def save_summary_to_text(summary, output_folder, file_path, condensed=False):
    """
    Write *summary* into *output_folder* under a name derived from *file_path*.

    The target is ``<stem>_condensed_summary.txt`` when *condensed* is truthy,
    otherwise ``<stem>_summary.txt``, where <stem> is the base name of
    *file_path* without its extension. Failures are reported with print()
    rather than raised.
    """
    stem, _ext = os.path.splitext(os.path.basename(file_path))
    suffix = "_condensed_summary.txt" if condensed else "_summary.txt"
    target_name = stem + suffix
    target_path = os.path.join(output_folder, target_name)

    try:
        with open(target_path, 'w', encoding='utf-8') as handle:
            handle.write(summary)
        print(f"Summary successfully written to {target_name}")
    except OSError as err:  # IOError is an alias of OSError
        print(f"Unable to write to file: {err}")
    except Exception as err:
        print(f"An unexpected error occurred: {err}")

def read_file(file_path):
    """
    Return the full contents of a UTF-8 text file.

    :param file_path: str, path of the file to read.
    :return: str with the file contents, or None if the file could not be
             opened/read (the error is printed).
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            return handle.read()
    except IOError as err:
        print(f"Error reading file {file_path}: {err}")
        return None

# Parenthetical generation
## - Using Llama2

In [None]:
# Wrap the existing text-generation pipeline for LangChain. All decoding
# settings (sampling, max_new_tokens, ...) come from `pipe` itself. LangChain's
# HuggingFacePipeline only applies model_kwargs when it builds the pipeline from
# a model id, so passing e.g. a temperature here would have no effect.
llm = HuggingFacePipeline(pipeline=pipe)
# Instruction and prompt slightly rephrased. The {text} placeholder is filled
# with the case document by the PromptTemplate below.
instruction = "Use the case document to extract the concise holding and phrase it as a parenthetical, which should look something like this: holding that the balance between costs and benefits comes out against applying the exclusionary rule in civil deportation hearings. {text}"
system_prompt = "You are a legal expert who specializes in extracting accurate and concise parenthetical holdings from case documents. Give only the holdings, no other breakdowns or extra text."

template = get_prompt(instruction, system_prompt)
print(template)

prompt = PromptTemplate(template=template, input_variables=["text"])
llm_chain = LLMChain(prompt=prompt, llm=llm)

In [None]:
# Show the PromptTemplate object the chain will use.
print("Prompt is:", prompt)

In [None]:
# Load the JSON Lines test split (one record per line) into a DataFrame.
test_df = pd.read_json(file_path_test, lines=True)

In [None]:
len(test_df)

In [None]:
# Sanity check on the first example: the case text fed to the model...
test_input = test_df.iloc[0]["input"]
print(test_input)

In [None]:
# ...and its reference output.
test_input_reference = test_df.iloc[0]["output"]
print(test_input_reference)

In [None]:
# Run the chain once on the first example before the full evaluation loop.
test_output = llm_chain.run(test_input)

In [None]:
print(test_output)

In [None]:
# Drop the row labelled 13 and renumber the remaining rows 0..n-1.
# NOTE(review): row 13 is presumably excluded because its input is too long to
# generate on (its word count is inspected near the end of the notebook); confirm.
test_df1 = test_df.drop(13)
test_df1 = test_df1.reset_index(drop=True)

In [None]:
# Run the chain over every example of test_df1 (test_df with row 13 removed)
# and collect Input/Prediction/Reference rows. Rows are read from test_df1
# itself; indexing test_df here would re-include the dropped row and skip the
# last example. A generation that raises RuntimeError (e.g. CUDA out of memory)
# is recorded as the string "NULL" so the run can continue.
# Rows are accumulated in a list and converted once at the end, avoiding the
# quadratic cost of pd.concat per iteration. If the loop is interrupted,
# the partial results are still in `records`.
records = []
num_nulls = 0

for i, (input_txt, reference_txt) in enumerate(zip(test_df1["input"], test_df1["output"])):
    print(f"Predicting on input number: {i}")

    try:
        output_txt = llm_chain.run(input_txt)
    except RuntimeError:
        print("Generation failed, inserting NULL value")
        output_txt = "NULL"
        num_nulls += 1

    records.append({"Input": input_txt, "Prediction": output_txt, "Reference": reference_txt})

    # Collect unreachable Python objects first so their CUDA tensors are freed,
    # then release the cached GPU memory before the next example.
    gc.collect()
    torch.cuda.empty_cache()

results_df = pd.DataFrame(records, columns=["Input", "Prediction", "Reference"])
print("Inference has finished")

In [None]:
# Number of examples whose generation raised RuntimeError and were stored as "NULL".
print(f"The number of null values inserted was {num_nulls}")

In [None]:
# Inspect the example excluded from the evaluation set (position 13 of test_df).
check_input = test_df.iloc[13]["input"]
print(check_input)

In [None]:
# Its length in whitespace-separated words.
word_count = len(check_input.split())
print(word_count)

In [None]:
# Destination CSV for the Input/Prediction/Reference table.
out_path = "/content/drive/MyDrive/Lang Gen Project/Results/llama2_predictions.csv"

In [None]:
# Create the Results folder if it does not exist yet so to_csv does not fail on
# a missing directory, then save the predictions without the DataFrame index.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
results_df.to_csv(out_path, index=False)