<a href="https://colab.research.google.com/github/jared-ni/6.8610-project/blob/main/new_pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install datasets
!pip install spacy
!pip install scispacy
!pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_sm-0.5.4.tar.gz
!pip install googletrans==4.0.0-rc1
!pip install deep-translator
!pip install transformers
!pip install torch

Collecting https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_sm-0.5.4.tar.gz
  Using cached https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_sm-0.5.4.tar.gz (14.8 MB)
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [3]:
from datasets import load_dataset
import pandas as pd
import spacy
from deep_translator import GoogleTranslator
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Which LLM back-ends run_pipeline loads and queries.
USE_LLAMA, USE_MISTRAL, USE_FALCON = True, False, False

# Scale factor applied to len(prompt) to obtain the `max_length` passed to each generation call.
MAX_LENGTH_MULTIPLIER = 2

In [4]:
# Dataset loaders: each returns a pandas Series of raw texts.
def load_law_dataset():
    """Load every CaseHOLD split and return its `citing_prompt` column.

    Returns:
        pandas.Series of citing prompts from the train, test and validation
        splits (concatenated in that order) with a fresh 0..N-1 index.
    """
    casehold = load_dataset("casehold/casehold", "all")
    split_frames = [pd.DataFrame(casehold[split]) for split in ("train", "test", "validation")]
    combined = pd.concat(split_frames, ignore_index=True)
    return combined['citing_prompt']

def load_medical_dataset():
    """Load PMC-Patients and return the `patient` column of its train split.

    Returns:
        pandas.Series of patient texts.
    """
    pmc_patients = load_dataset("zhengyun21/PMC-Patients")
    return pd.DataFrame(pmc_patients['train'])['patient']

# Combine datasets
def load_all_datasets():
    """Load every corpus used by the pipeline.

    Returns:
        list ``[law_series, medical_series]`` (law first, then medical).
    """
    return [load_corpus() for load_corpus in (load_law_dataset, load_medical_dataset)]


In [5]:
def load_spacy_model(model_path='en_core_sci_sm'):
    """Load a spaCy pipeline.

    Args:
        model_path: installed package name or path of the spaCy model; defaults
            to the scispaCy ``en_core_sci_sm`` model installed at the top of the notebook.

    Returns:
        The loaded spaCy ``Language`` object.
    """
    nlp = spacy.load(model_path)
    return nlp

def extract_entities(nlp, text):
    """Run `nlp` over `text` and collect the recognised entities.

    Args:
        nlp: spaCy pipeline (any callable returning an object with ``.ents``).
        text: input string.

    Returns:
        list of entity surface strings, in the order the pipeline reports them.
    """
    entity_texts = []
    for span in nlp(text).ents:
        entity_texts.append(span.text)
    return entity_texts

# Translate entities to a target language
def translate_entities(entities, target_lang):
    """Translate each entity string into `target_lang` with Google Translate.

    Args:
        entities: iterable of strings to translate (e.g. output of extract_entities).
        target_lang: target language name or code accepted by deep_translator's
            GoogleTranslator.

    Returns:
        list of translated strings, one per input entity, in input order.
    """
    entities = list(entities)
    if not entities:
        # Nothing to translate: skip constructing (and validating) a translator.
        return []
    # Source and target never change inside the loop, so one translator serves all entities.
    translator = GoogleTranslator(source='auto', target=target_lang)
    return [translator.translate(entity) for entity in entities]


In [9]:
from huggingface_hub import login
import os
from getpass import getpass

# Never hard-code access tokens in a notebook: read the Hugging Face token from
# the HF_TOKEN environment variable, or prompt for it interactively.
# Treat any token previously hard-coded here as leaked and revoke it.
llama_token = (os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")).strip()
login(llama_token)

def load_llama_model():
    """Load the Llama-2 7B chat tokenizer and model from the Hugging Face Hub.

    Returns:
        (tokenizer, model) tuple; the model uses ``torch_dtype="auto"`` and
        ``device_map="auto"``.
    """
    checkpoint = "meta-llama/Llama-2-7b-chat-hf"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype="auto",
        device_map="auto",
    )
    return tokenizer, model

# Generate text with Llama
def llama_generate_text(tokenizer, model, prompt, max_length):
    """Sample a completion for `prompt` from a Llama-style causal LM.

    Args:
        tokenizer: Hugging Face tokenizer matching `model`.
        model: Hugging Face causal LM; inputs are moved to ``model.device``.
        prompt: input text.
        max_length: maximum total tokens (prompt + generation) passed to generate().

    Returns:
        Decoded first returned sequence with special tokens stripped (for
        decoder-only models this includes the prompt text).
    """
    # Place inputs wherever the model lives instead of assuming CUDA.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # If the tokenizer defines no pad token (common for Llama), fall back to EOS
    # explicitly; generate() would do the same but emit a warning.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    outputs = model.generate(
        inputs["input_ids"],
        # Pass the mask explicitly so generate() does not have to infer it.
        attention_mask=inputs["attention_mask"],
        max_length=max_length,
        pad_token_id=pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        do_sample=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


In [None]:
def load_mistral_model():
    """Load the Mistral-7B-Instruct v0.3 tokenizer and model.

    The model is loaded in float16 with automatic device placement. If the
    tokenizer has no pad token, EOS is reused, and the resulting pad id is
    copied into the model config.

    Returns:
        (tokenizer, model) tuple.
    """
    checkpoint = "mistralai/Mistral-7B-Instruct-v0.3"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    # mistral_generate_text tokenizes with padding=True, so a pad token must exist.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.pad_token_id
    return tokenizer, model

def mistral_generate_text(tokenizer, model, prompt, max_length):
    """Greedily generate a completion for `prompt` with a Mistral-style causal LM.

    Args:
        tokenizer: Hugging Face tokenizer with a pad token set (see load_mistral_model).
        model: Hugging Face causal LM; inputs are moved to ``model.device``.
        prompt: input text.
        max_length: maximum total tokens (prompt + generation) passed to generate().

    Returns:
        Decoded first returned sequence with special tokens stripped (for
        decoder-only models this includes the prompt text).
    """
    # Follow the model's device rather than assuming CUDA.
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=max_length,
        pad_token_id=model.config.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        # Greedy decoding: sampling knobs such as top_k would be ignored, so none are passed.
        do_sample=False,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


In [None]:
def load_falcon_model():
    """Load the Falcon-7B-Instruct tokenizer and model.

    The checkpoint's remote code is trusted, and the model is loaded in float16
    with automatic device placement. When the tokenizer has no pad token, EOS is
    reused for padding and the pad id is mirrored into the model config.

    Returns:
        (tokenizer, model) tuple.
    """
    repo_id = "tiiuae/falcon-7b-instruct"
    remote_code = dict(trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(repo_id, **remote_code)
    model = AutoModelForCausalLM.from_pretrained(
        repo_id, torch_dtype=torch.float16, device_map="auto", **remote_code
    )
    missing_pad = tokenizer.pad_token is None
    if missing_pad:
        tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.pad_token_id
    return tokenizer, model

def falcon_generate_text(tokenizer, model, prompt, max_length):
    """Greedily generate a completion for `prompt` with a Falcon-style causal LM.

    Args:
        tokenizer: Hugging Face tokenizer with a pad token set (see load_falcon_model).
        model: Hugging Face causal LM; inputs are moved to ``model.device`` and the
            pad/eos ids are read from ``model.config``.
        prompt: input text.
        max_length: maximum total tokens (prompt + generation) passed to generate().

    Returns:
        Decoded first returned sequence with special tokens stripped (for
        decoder-only models this includes the prompt text).
    """
    # Use the arguments, not notebook globals: falcon_tokenizer / falcon_model only
    # exist as locals inside run_pipeline, so referencing them raised NameError.
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=max_length,
        pad_token_id=model.config.pad_token_id,
        eos_token_id=model.config.eos_token_id,
        # Greedy decoding: temperature / top_k / top_p would have no effect, so they are omitted.
        do_sample=False,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [26]:
def calculate_JTC(translations):
    """Placeholder for the JTC score of a group of translations.

    Args:
        translations: list of generated translation strings (not used yet).

    Returns:
        Always 0 until the real metric is implemented.
    """
    # TODO: Replace with actual JTC score calculation logic
    del translations  # unused until the metric exists
    return 0

In [29]:
def _generate_translations(generate, tokenizer, model, prompts, n_repeats):
    """Call `generate` `n_repeats` times per prompt and return all outputs in order."""
    outputs = []
    for prompt in prompts:
        # NOTE(review): len(prompt) counts characters, whereas generate()'s
        # max_length counts tokens (prompt included) - a generous budget, not a token count.
        max_length = len(prompt) * MAX_LENGTH_MULTIPLIER
        for _ in range(n_repeats):
            outputs.append(generate(tokenizer, model, prompt, max_length))
    return outputs


def run_pipeline(n_samples=10, max_words=30, n_repeats=3,
                 target_languages=("Chinese", "French")):
    """Translate dataset samples with each enabled LLM, with and without entity hints.

    For every dataset returned by load_all_datasets(), the first `n_samples` texts
    are truncated to `max_words` words, their entities are extracted with spaCy,
    and each enabled LLM (USE_LLAMA / USE_MISTRAL / USE_FALCON) translates the
    text into every language in `target_languages`, `n_repeats` times, once with
    a plain prompt ("regular") and once with the extracted entities in the prompt
    ("LEAP"). Both groups of outputs are scored with calculate_JTC.

    Args:
        n_samples: number of leading texts per dataset to process (None = all).
        max_words: number of leading whitespace-separated words kept from each text.
        n_repeats: generations per (prompt, language) pair.
        target_languages: language names interpolated into the prompts.

    Returns:
        pandas.DataFrame with one row per (dataset, text, LLM) holding the indices,
        LLM name, truncated text, entities, both translation lists and both JTC scores.
    """
    datasets = load_all_datasets()
    nlp = load_spacy_model()

    # Load only the enabled LLMs (Llama, Mistral, Falcon order) and pair each one
    # with its own generation function.
    llm_specs = [
        ("Llama", USE_LLAMA, load_llama_model, llama_generate_text),
        ("Mistral", USE_MISTRAL, load_mistral_model, mistral_generate_text),
        ("Falcon", USE_FALCON, load_falcon_model, falcon_generate_text),
    ]
    active_llms = []
    for llm_name, enabled, load_model, generate in llm_specs:
        if enabled:
            tokenizer, model = load_model()
            active_llms.append((llm_name, tokenizer, model, generate))

    records = []
    for dataset_idx, dataset in enumerate(datasets):
        for row_idx, raw_text in enumerate(dataset[:n_samples]):
            text = " ".join(raw_text.split()[:max_words])
            entities = extract_entities(nlp, text)
            regular_prompts = [
                f"Translate the following text to {target_lang}: {text}"
                for target_lang in target_languages
            ]
            # NOTE(review): `entities` is a plain list of entity strings, not a
            # source->target mapping; translate_entities is not used here yet.
            leap_prompts = [
                f"Translate the following text to {target_lang} using these mappings {entities}: {text}"
                for target_lang in target_languages
            ]

            for llm_name, tokenizer, model, generate in active_llms:
                regular = _generate_translations(generate, tokenizer, model, regular_prompts, n_repeats)
                leap = _generate_translations(generate, tokenizer, model, leap_prompts, n_repeats)
                records.append({
                    "dataset": dataset_idx,
                    "row": row_idx,
                    "llm": llm_name,
                    "text": text,
                    "entities": entities,
                    "regular_translations": regular,
                    "leap_translations": leap,
                    "regular_jtc": calculate_JTC(regular),
                    "leap_jtc": calculate_JTC(leap),
                })

    return pd.DataFrame(records)


In [None]:
# Execute the experiment end to end: load both datasets, the spaCy model and the
# enabled LLMs, then generate and score the regular and entity-augmented translations.
run_pipeline()



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

