In [None]:
import torch
import pandas as pd
from transformers import (
    MambaConfig, MambaForCausalLM, AutoTokenizer
)
from jinja2 import Environment

In [None]:
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still loads on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Checkpoint to evaluate. Alternatives left over from earlier runs:
#   './train_exp_1b/complete'
#   '/kaggle/input/prompt_reversal_hf/transformers/1b/1'
#   './train_exp_3b'
model_path = './train_exp_4/complete'

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = MambaForCausalLM.from_pretrained(
    model_path,
    device_map=device,
    torch_dtype=torch.bfloat16
)

# Markers used to cut the generated prompt out of the decoded model output.
start_sub = '<|PROMPT|>'
end_sub = '<|END_PROMPT|>'

# Jinja template for the model input. It ends with the opening prompt marker,
# so generation continues right after `start_sub`.
prompt = """<|ORIGINAL_TEXT|>{{ original_text }}<|END_ORIGINAL_TEXT|>
<|GENERATED_TEXT|>{{ rewritten_text }}<|END_GENERATED_TEXT|>
<|PROMPT|>"""
jinja_env = Environment()
prompt_template = jinja_env.from_string(prompt)

In [None]:
# Load the rewritten-texts CSV (undecodable bytes are ignored, malformed rows
# are skipped) and expose the ground-truth prompt as `gt_rewrite_prompt`.
test_df = pd.read_csv(
    "data/kaggle_3p_data/data/juanmerinobermejo/rewritten_texts_csv.csv",
    encoding_errors='ignore',
    on_bad_lines='skip',
).rename(columns={'prompt': 'gt_rewrite_prompt'})

# Add a sequential row id as the first column.
test_df.insert(0, 'id', range(len(test_df)))

In [None]:
# Keep only the first 5 rows so the generation loop stays quick.
# `.copy()` detaches the subset from the full frame: later cells assign new
# columns (`rewrite_prompt`, `score`) into `test_df`, and doing that on a bare
# slice triggers pandas' SettingWithCopyWarning.
test_df = test_df.iloc[:5].copy()
test_df

In [None]:
from tqdm import tqdm


def _extract_between(text, start_marker, end_marker):
    """Return the part of `text` between `start_marker` and the next `end_marker`.

    Args:
        text: Decoded model output.
        start_marker: Marker that precedes the wanted span.
        end_marker: Marker that terminates the wanted span.

    Returns:
        The substring after the first `start_marker` and before the first
        `end_marker` that follows it. If `start_marker` is missing the span
        starts at the beginning of `text`; if `end_marker` is missing (for
        example when generation stopped at the token limit) the span runs to
        the end of `text`.
    """
    start = text.find(start_marker)
    begin = start + len(start_marker) if start != -1 else 0
    # Search only after the start marker; a plain `find` returning -1 must not
    # be used as a slice bound, because it would silently drop the last character.
    end = text.find(end_marker, begin)
    if end == -1:
        end = len(text)
    return text[begin:end]


# Generate a recovered prompt for every row and store it in `rewrite_prompt`.
for idx, row in tqdm(test_df.iterrows(), total=test_df.shape[0]):
    # Render the model input. A separate name is used so the template source
    # string `prompt` is not overwritten.
    rendered_prompt = prompt_template.render(
        original_text=row['original_text'],
        rewritten_text=row['rewritten_text']
    )

    input_ids = tokenizer(rendered_prompt, return_tensors="pt").input_ids.to(device)
    gen = model.generate(
        input_ids,
        max_new_tokens=200
    )
    out = tokenizer.batch_decode(gen)[0]
    res = _extract_between(out, start_sub, end_sub)
    print(row['gt_rewrite_prompt'])
    print(res)
    test_df.loc[idx, 'rewrite_prompt'] = res

In [None]:
# code to score

import numpy as np
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics.pairwise import cosine_similarity

# Sentence-embedding model used by the scoring helpers.
ST_MODEL_NAME = 'sentence-transformers/sentence-t5-base'
st_model = SentenceTransformer(ST_MODEL_NAME)

def get_sharpened_cosine_similarity(text1, text2):
    """Return the cubed cosine similarity between the embeddings of two texts.

    Both texts are encoded with the module-level `st_model`, compared with
    `util.cos_sim`, and the first row of the result is raised to the third
    power.

    Args:
        text1: First text to embed.
        text2: Second text to embed.

    Returns:
        The first element of the cubed first similarity row, converted to
        NumPy (presumably a scalar when both inputs are single strings).
    """
    similarity = util.cos_sim(st_model.encode(text1), st_model.encode(text2))
    sharpened = similarity[0] ** 3
    return sharpened.numpy()[0]

def calc_prompt_similarity(row):
    """Score one row: ground-truth prompt versus generated prompt.

    Args:
        row: Mapping-like row providing `gt_rewrite_prompt` and `rewrite_prompt`.

    Returns:
        The result of `get_sharpened_cosine_similarity` for the two prompts.
    """
    ground_truth = row['gt_rewrite_prompt']
    generated = row['rewrite_prompt']
    return get_sharpened_cosine_similarity(ground_truth, generated)

In [None]:
# Score every row with `calc_prompt_similarity`, then display the mean score.

test_df['score'] = test_df.apply(calc_prompt_similarity, axis=1)

test_df['score'].mean()


In [None]:
# Inspect the first example: ground-truth prompt, generated prompt, and score.
first_row = test_df.iloc[0]
for column in ('gt_rewrite_prompt', 'rewrite_prompt', 'score'):
    print(first_row[column])

## Results

### Model: `./train_exp_4/complete`

Leaderboard (LB) score: 0.72472763


In [None]:
# results
# Display the checkpoint the results above belong to. The bare path
# `./train_exp_4/complete` that used to sit here is not valid Python and made
# the cell raise a SyntaxError.

model_path
