In [None]:
import os
import galai as gal
import torch

from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Seed PyTorch's global RNG so the sampled generations in later cells
# (top_p sampling for Galactica, do_sample=True for BioMedLM) can be repeated
# from a fresh kernel with Restart & Run All.
# NOTE(review): presumably both models draw from torch's default generators;
# GPU kernels may still introduce run-to-run variation — confirm if exact
# reproducibility matters.
SEED = 42
torch.manual_seed(SEED)

# All BioMedLM tensors and weights are placed on the CUDA device.
device = torch.device("cuda")

# Galactica "standard" checkpoint, loaded through galai on a single GPU.
galactica_model = gal.load_model("standard", num_gpus=1)

# BioMedLM tokenizer and language-model weights from the
# "stanford-crfm/BioMedLM" checkpoint; the model is moved to `device`.
biomedlm_tokenizer = GPT2Tokenizer.from_pretrained("stanford-crfm/BioMedLM")
biomedlm_model = GPT2LMHeadModel.from_pretrained("stanford-crfm/BioMedLM").to(device)

In [None]:
import pandas as pd

from tqdm.auto import tqdm
# Register tqdm's pandas integration; this is what makes
# DataFrame.progress_apply (used in later cells) available.
tqdm.pandas()

# Load the input table from the CSV next to this notebook (presumably Cochrane
# review metadata, judging by the file name). Later cells read its 'title' column.
# index_col=False: per the pandas docs, this stops read_csv from using the
# first column as the index.
df = pd.read_csv("./cochrane_reviews_latest_by_topic_20230223.csv", index_col=False)

In [None]:
# Preview the first rows to sanity-check the loaded columns before generating.
df.head()

In [None]:
def get_galactica_output_title(row):
    """Generate Galactica text for one row using a ``Title: <title>`` prompt.

    Args:
        row: Mapping-like object whose ``'title'`` entry is a string, e.g. a
            DataFrame row handed over by ``df.apply(..., axis=1)``.

    Returns:
        Whatever ``galactica_model.generate`` returns for the prompt
        (presumably the generated text — confirm against galai).

    Raises:
        TypeError: if ``row['title']`` is not a string (e.g. a missing value).
    """
    # max_length is 2048 because that is the maximum for Galactica; the other
    # generate() arguments follow the examples in Galactica's GitHub repository
    # (https://github.com/paperswithcode/galai).
    generation_kwargs = dict(new_doc=True, top_p=0.7, max_length=2048)

    # The prompt is the title behind a "Title: " prefix, closed by a blank line.
    prompt_text = ''.join(('Title: ', row['title'], '\n\n'))
    return galactica_model.generate(prompt_text, **generation_kwargs)

In [None]:
def get_galactica_output_hashtag(row):
    """Generate Galactica text for one row using a ``# <title>`` heading prompt.

    Args:
        row: Mapping-like object whose ``'title'`` entry is a string, e.g. a
            DataFrame row handed over by ``df.apply(..., axis=1)``.

    Returns:
        Whatever ``galactica_model.generate`` returns for the prompt
        (presumably the generated text — confirm against galai).

    Raises:
        TypeError: if ``row['title']`` is not a string (e.g. a missing value).
    """
    # max_length is 2048 because that is the maximum for Galactica; the other
    # generate() arguments follow the examples in Galactica's GitHub repository
    # (https://github.com/paperswithcode/galai).
    generation_kwargs = dict(new_doc=True, top_p=0.7, max_length=2048)

    # The prompt is the title behind a "# " heading marker, closed by a blank line.
    prompt_text = ''.join(('# ', row['title'], '\n\n'))
    return galactica_model.generate(prompt_text, **generation_kwargs)

In [None]:
def get_biomedlm_output(row):
    """Sample a BioMedLM continuation for one row using a ``Title: <title>`` prompt.

    Args:
        row: Mapping-like object whose ``'title'`` entry is a string, e.g. a
            DataFrame row handed over by ``df.apply(..., axis=1)``.

    Returns:
        str: The first generated sequence decoded with special tokens skipped.
        NOTE(review): for a decoder-only model the decoded text normally
        starts with the prompt itself — confirm if downstream code needs the
        continuation only.

    Raises:
        TypeError: if ``row['title']`` is not a string (e.g. a missing value).
    """
    title = row['title']
    prompt = 'Title: ' + title

    # Calling the tokenizer (instead of .encode) returns the attention mask
    # alongside the input ids, so generate() does not have to infer it.
    encoded = biomedlm_tokenizer(prompt, return_tensors="pt").to(device)

    # max_length is 1024 because that is the maximum for BioMedLM.
    # pad_token_id is set explicitly because generate() otherwise falls back
    # to the EOS id and logs a warning on every call, which floods the output
    # when this function is applied row by row.
    output = biomedlm_model.generate(
        encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        do_sample=True,
        max_length=1024,
        top_k=50,
        pad_token_id=biomedlm_tokenizer.eos_token_id,
    )

    return biomedlm_tokenizer.decode(output[0], skip_special_tokens=True)

In [None]:
# Run the "Title: ..." Galactica prompt once per row (axis=1 passes each row to
# the function) and store the results in a new column. progress_apply shows a
# tqdm progress bar; expect this to be slow, since it is one generate() call per row.
df['galactica_output_title'] = df.progress_apply(get_galactica_output_title, axis=1)

In [None]:
# Same per-row Galactica generation, but with the "# <title>" heading-style
# prompt; results go into their own column so both prompt styles can be compared.
df['galactica_output_hashtag'] = df.progress_apply(get_galactica_output_hashtag, axis=1)

In [None]:
# Run the BioMedLM sampler once per row (axis=1) with a tqdm progress bar and
# store the decoded text in a new column.
df['biomedlm_output'] = df.progress_apply(get_biomedlm_output, axis=1)

In [None]:
# Display the frame, which now carries the three generated-output columns.
df

In [None]:
# Persist the inputs together with all generated outputs; index=False keeps the
# DataFrame index out of the written file.
df.to_csv('./llm_outputs.csv', index=False)