In [19]:
import gc
import logging
from time import time
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
import ctypes
from functools import partial

import torch
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

# For RAG
import faiss
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_from_disk, Dataset
from sentence_transformers import SentenceTransformer

# For LLM
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device
from safetensors.torch import load_file

In [20]:
class my_SentenceTransformer:
    def __init__(self, checkpoint, device="cuda:0"):
        self.device = device
        self.checkpoint = checkpoint
        self.model = AutoModel.from_pretrained(checkpoint).to(self.device).half()
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    def transform(self, batch):
        tokens = self.tokenizer(batch["text"], truncation=True, padding=True, return_tensors="pt", max_length=MAX_SEQ_LEN)
        return tokens.to(self.device)  

    def get_dataloader(self, sentences, batch_size=32):
        sentences = ["Represent this sentence for searching relevant passages: " + x for x in sentences]
        dataset = Dataset.from_dict({"text": sentences})
        dataset.set_transform(self.transform)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        return dataloader

    def encode(self, sentences, show_progress_bar=False, batch_size=32):
        dataloader = self.get_dataloader(sentences, batch_size=batch_size)
        pbar = tqdm(dataloader) if show_progress_bar else dataloader

        embeddings = []
        for batch in pbar:
            with torch.no_grad():
                e = self.model(**batch).pooler_output
                e = F.normalize(e, p=2, dim=1)
                embeddings.append(e.detach().cpu().numpy())
        embeddings = np.concatenate(embeddings, axis=0)
        return embeddings

In [58]:
import os
import re

files=[i for i in os.listdir('../mmlu-dataset') if re.findall('csv', i)]
files

['MMLU_train_set.csv',
 'MMLU_science_set.csv',
 'MMLU_physics_set.csv',
 'MMLU_dev_set.csv',
 'MMLU_sampling_train_set.csv']

In [71]:
## bge-small-faiss embedding 

# Load data
df = pd.read_csv(f"../mmlu-dataset/{files[4]}")
df['A']=df['A'].apply(str)
df['B']=df['B'].apply(str)
df['C']=df['C'].apply(str)
df['D']=df['D'].apply(str)
# df['E']=df['E'].apply(str) ## sienceexam

NUM_TITLES = 5
MAX_SEQ_LEN = 512
MODEL_PATH = "output/bge-small-faiss/"

## load embedding model
start = time()
print(f"Starting prompt embedding, t={time() - start :.1f}s")
model = my_SentenceTransformer(MODEL_PATH, device="cuda:2") ## 직접 정의한 sentencetransformer 사용

## Get query embedding
# f = lambda row : " ".join([row["prompt"], row["A"], row["B"], row["C"], row["D"], row["E"]]) ## scienceexam
f = lambda row : " ".join([row["question"], row["A"], row["B"], row["C"], row["D"]]) ## MMLU
inputs = df.apply(f, axis=1).values # better results than prompt only
prompt_embeddings = model.encode(inputs, show_progress_bar=False)

## faiss wikipedia index 불러오기
print(f"Loading faiss index, t={time() - start :.1f}s")
faiss_index = faiss.read_index(MODEL_PATH + '/faiss.index')
faiss_index = faiss.index_cpu_to_all_gpus(faiss_index) # OOM이 일어날 때는 지우기 

## top-5의 관련있는 인덱스 가져오기 
print(f"Starting text search, t={time() - start :.1f}s")
search_index = faiss_index.search(np.float32(prompt_embeddings), NUM_TITLES)[1]

## 인덱스를 찾아서 실제 문서 가져오기 
print(f"Starting context extraction, t={time() - start :.1f}s")
dataset = load_from_disk("dataset/all-paraphs-parsed-expanded")
for i in range(len(df)):
    df.loc[i, "context"] = "-" + "\n-".join([dataset[int(j)]["text"] for j in search_index[i]])

Starting prompt embedding, t=0.0s
Loading faiss index, t=45.3s
Starting text search, t=49.7s
Starting context extraction, t=50.4s


In [None]:
df = df.drop(columns='id')

In [72]:
df.head()

Unnamed: 0,question,A,B,C,D,answer,context
0,"CHANGSHA,Feb.14(Xinhua)----Areas of China affe...",one week,two weeks,one month,two months,C,"-The provinces of Hubei, Henan, Shandong, Jian..."
1,1English people have three meals a day. They ...,"cakes, fruit or ice cream",hamburgers or sandwiches,soup and rice,"some porridge, eggs and meat",D,-People usually have two or three meals a day....
2,XI'AN---Seven people died in a fire early on W...,The news report didn't mention the loss caused...,After reading the report we know how the fire ...,The reporter tended to think the bomb had some...,The police refused to admit the bomb had anyth...,C,-Ronan Point was a 23-story council tower bloc...
3,Winter is dangerous because it's so difficult ...,Traffic accidents take place easily in winter.,Fog and melting snow often cause car accidents.,The stopping distance on ice is as long as the...,In winter you should drive your car with great...,C,-Land travel Ice forming on roads is a dangero...
4,Telepathy: Mind-to-mind Contact Telepathy is t...,Help them have a strong desire to communicate.,Separate them all the time.,Help them link up their unconscious minds.,Let them spend much time together.,B,-Telepathy (from Ancient Greek τῆλε (têle) 'di...


In [74]:
df.to_csv(f"../mmlu-dataset/{files[4]}")

In [37]:
## 뽑힌 context의 내용 확인하기 

context=df['context'].tolist()

for idx,q in enumerate(df['prompt'].tolist()):
    print (q)
    print(context[idx])
    print()
    print()
    print()

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [38]:
## context의 길이 알아보기 

lengths=[]
for c in context:
    lengths.append(len(c))

hap=0
for l in lengths:
    hap+=l

print(lengths)
print(hap/200)

[2374, 1301, 3455, 4670, 10452, 4127, 4777, 7681, 3167, 2940, 2488, 6425, 3673, 3493, 10356, 3375, 2482, 2916, 4378, 3397, 3673, 6444, 4134, 2942, 4719, 3193, 6614, 4204, 2803, 3773, 3237, 2787, 3039, 3270, 2974, 2847, 2892, 4385, 4191, 3060, 2683, 5736, 2526, 5083, 4654, 5091, 5980, 5918, 4403, 3716, 3488, 3292, 3559, 2375, 3419, 549, 634, 4314, 4079, 5660, 4543, 5049, 6895, 1420, 9507, 3025, 2707, 3025, 2994, 2922, 2324, 2324, 7101, 6268, 3068, 3163, 4301, 2435, 3377, 2676, 5560, 5724, 7228, 6120, 3149, 3019, 2997, 2061, 2988, 3119, 6337, 8559, 12943, 3177, 4661, 4153, 3169, 3384, 3210, 3354, 2778, 3129, 4629, 5263, 4091, 3027, 1982, 3952, 3685, 4206, 4510, 4539, 6632, 3854, 4938, 5167, 3176, 5510, 2559, 5564, 3302, 4537, 2693, 5864, 10407, 3162, 2159, 3021, 2692, 3952, 5576, 5011, 3413, 3689, 4776, 2547, 5453, 2396, 2383, 3173, 2786, 4966, 4394, 3780, 4801, 2754, 1383, 4452, 4014, 5787, 7748, 1839, 5194, 9359, 6065, 2645, 3047, 5707, 4428, 5802, 2693, 4376, 3300, 2967, 3163, 2841, 7