In [5]:
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import os
import numpy as np

In [6]:
#for now using google flan, will switch to llama once we set it up on dsmlp
model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

In [7]:
def classify_long_text_with_flan(text, token_limit=512):
    """given a scientific paper as text input, determines if a paper is relevant.
    splits the text into chunks of 512 tokens by default and if any of the chunks is relevant, returns 1"""
    # 
    prompt_template = (
    "Is there any mention of passivating molecules or techniques aimed at improving perovskite solar cell stability in this research paper? "
    "Classify as 'Relevant' if the text refers to either of these topics, even indirectly. Otherwise, answer 'Not Relevant'.\n\nContent: {}"
)
    
    # Tokenize and split into chunks
    tokens = tokenizer(text, return_tensors="pt", truncation=False)["input_ids"][0]
    chunk_size = token_limit - len(tokenizer(prompt_template.format(""))["input_ids"])  # Adjust for prompt length
    chunks = [tokens[i:i + chunk_size] for i in range(0, len(tokens), chunk_size)]

    predictions = []
    for chunk in chunks:
        # Decode chunk and add to prompt
        content = tokenizer.decode(chunk, skip_special_tokens=True)
        prompt = prompt_template.format(content)
        
        # Generate response using FLAN
        inputs = tokenizer(prompt, return_tensors="pt", max_length=token_limit, truncation=True)
        outputs = model.generate(**inputs, max_length=50)
        prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)

        if prediction == "Relevant":
            predictions.append(1)
        else:
            predictions.append(0)
    
    # Majority voting
    final_prediction = 1 if predictions.count(1) > 0 else 0
    return final_prediction

In [8]:
#loading in test dataset
df = pd.read_csv('../data/merged_label.csv')

In [9]:
#testing the model on 5 rows we know are relevant
df_good = df[df['label'] == 1]
df_good_sample = df_good.head()

df_good_sample['pred'] = df_good_sample['text'].apply(classify_long_text_with_flan)
df_good_sample

Token indices sequence length is longer than the specified maximum sequence length for this model (13857 > 512). Running this sequence through the model will result in indexing errors
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_good_sample['pred'] = df_good_sample['text'].apply(classify_long_text_with_flan)


Unnamed: 0,link,label,text,pred
0,https://www.nature.com/articles/s41566-019-0398-2,1,Surface passivation of perovskite film for eff...,1
1,https://www.nature.com/articles/s41560-020-007...,1,Intact 2D/3D halide junction perovskite solar ...,1
2,https://www.nature.com/articles/s41467-021-236...,1,Multication perovskite 2D/3D interfaces form v...,0
3,https://doi.org/10.1038%2Fs41586-022-04604-5,1,Stability-limiting heterointerfaces of perovsk...,0
4,https://doi.org/10.1038%2Fs41467-022-30426-0,1,Imaging and quantifying non-radiative losses a...,1


In [10]:
df['pred'] = df['text'].apply(classify_long_text_with_flan)

In [12]:
accuracy = np.mean(df['pred'] == df['label'])
tp = ((df['pred'] == 1) & (df['label'] == 1)).sum()
fn = ((df['pred'] == 0) & (df['label'] == 1)).sum()
tn = ((df['pred'] == 0) & (df['label'] == 0)).sum()
fp = ((df['pred'] == 1) & (df['label'] == 0)).sum()

sensitivity = tp / (tp + fn) if (tp + fn) != 0 else 0
specificity = tn / (tn + fp) if (tn + fp) != 0 else 0
recall = tp / (tp + fn)

ber = 1 - 0.5 * (sensitivity + specificity)

In [21]:
#writing results to model result folder
results_path = '../data/model_results/flan.csv'
new_row = pd.DataFrame({ 'model_name': 'FLAN-T5', 'accuracy': accuracy, 'recall': recall, 'BER': ber }, index=[0])
new_row.to_csv(results_path, index=False)