In [1]:
import gc
import logging
from time import time
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
import ctypes
from functools import partial

import torch
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

# For RAG
import faiss
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_from_disk, Dataset
from sentence_transformers import SentenceTransformer

# For LLM
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device
from safetensors.torch import load_file

In [2]:
# Function to clean RAM & vRAM
def clean_memory():
    """Release as much host RAM and GPU memory as possible.

    Runs the Python garbage collector, asks glibc to hand freed heap pages
    back to the OS, and empties PyTorch's CUDA caching allocator.

    The glibc step is best-effort: on platforms where ``libc.so.6`` cannot be
    loaded, or where it does not export ``malloc_trim``, that step is skipped
    instead of raising.
    """
    gc.collect()
    try:
        ctypes.CDLL("libc.so.6").malloc_trim(0)
    except (OSError, AttributeError):
        # OSError: libc.so.6 could not be loaded on this platform.
        # AttributeError: the loaded library has no malloc_trim symbol.
        pass
    torch.cuda.empty_cache()

In [3]:
class my_SentenceTransformer:
    """Small sentence encoder built directly on a Hugging Face ``AutoModel``.

    The model is loaded from ``checkpoint``, moved to ``device`` and cast to
    half precision. ``encode`` returns the L2-normalised ``pooler_output`` of
    the model for each input sentence.

    Args:
        checkpoint: Path or name passed to ``AutoModel.from_pretrained`` and
            ``AutoTokenizer.from_pretrained``.
        device: Device string the model and the tokenised batches are moved to.
        max_seq_len: Truncation length used by the tokenizer. When ``None``
            (the default) the notebook-level ``MAX_SEQ_LEN`` is used if it is
            defined, otherwise ``DEFAULT_MAX_SEQ_LEN``.
    """

    # Fallback truncation length when neither the argument nor the global is set.
    DEFAULT_MAX_SEQ_LEN = 512

    def __init__(self, checkpoint, device="cuda:0", max_seq_len=None):
        self.device = device
        self.checkpoint = checkpoint
        self.max_seq_len = max_seq_len
        self.model = AutoModel.from_pretrained(checkpoint).to(self.device).half()
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    def _resolve_max_length(self):
        """Return the truncation length to use for the next tokenizer call."""
        if self.max_seq_len is not None:
            return self.max_seq_len
        # Resolved lazily so a MAX_SEQ_LEN defined after this class (as the
        # notebook does) is still honoured, without raising NameError when it
        # is missing.
        return globals().get("MAX_SEQ_LEN", self.DEFAULT_MAX_SEQ_LEN)

    def transform(self, batch):
        """Tokenise ``batch["text"]`` and move the resulting tensors to ``self.device``."""
        tokens = self.tokenizer(
            batch["text"],
            truncation=True,
            padding=True,
            return_tensors="pt",
            max_length=self._resolve_max_length(),
        )
        return tokens.to(self.device)

    def get_dataloader(self, sentences, batch_size=32):
        """Build an unshuffled ``DataLoader`` over ``sentences``.

        Every sentence is prefixed with a fixed retrieval instruction, and
        tokenisation is applied on the fly through ``Dataset.set_transform``.
        """
        sentences = ["Represent this sentence for searching relevant passages: " + x for x in sentences]
        dataset = Dataset.from_dict({"text": sentences})
        dataset.set_transform(self.transform)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        return dataloader

    def encode(self, sentences, show_progress_bar=False, batch_size=32):
        """Embed ``sentences`` and return the embeddings as one NumPy array.

        Args:
            sentences: Iterable of strings to embed.
            show_progress_bar: Wrap the batch loop in ``tqdm`` when true.
            batch_size: Number of sentences per forward pass.

        Returns:
            The per-batch L2-normalised ``pooler_output`` arrays concatenated
            along axis 0, in input order (the loader is not shuffled).

        Raises:
            ValueError: From ``np.concatenate`` when ``sentences`` is empty.
        """
        dataloader = self.get_dataloader(sentences, batch_size=batch_size)
        pbar = tqdm(dataloader) if show_progress_bar else dataloader

        embeddings = []
        # One no_grad scope for the whole loop; tensors created inside it
        # carry no autograd graph, so no explicit detach is needed.
        with torch.no_grad():
            for batch in pbar:
                e = self.model(**batch).pooler_output
                e = F.normalize(e, p=2, dim=1)
                embeddings.append(e.cpu().numpy())
        embeddings = np.concatenate(embeddings, axis=0)
        return embeddings

### bge-small-faiss embedding 

In [5]:
## bge-small-faiss embedding

# Load data
df = pd.read_csv("../dataset/test.csv", index_col="id")

NUM_TITLES = 5  # number of passages retrieved per question
MAX_SEQ_LEN = 512  # tokenizer truncation length read by my_SentenceTransformer.transform
MODEL_PATH = "output/bge-small-faiss/"

## Load the embedding model
start = time()
print(f"Starting prompt embedding, t={time() - start :.1f}s")
model = my_SentenceTransformer(MODEL_PATH, device="cuda:2")  # uses the custom sentence transformer defined in this notebook

## Get query embeddings: each query is the prompt followed by the five answer options
f = lambda row : " ".join([row["prompt"], row["A"], row["B"], row["C"], row["D"], row["E"]])
inputs = df.apply(f, axis=1).values # better results than prompt only
prompt_embeddings = model.encode(inputs, show_progress_bar=False)

## Load the FAISS Wikipedia index
print(f"Loading faiss index, t={time() - start :.1f}s")
faiss_index = faiss.read_index(MODEL_PATH + '/faiss.index')
faiss_index = faiss.index_cpu_to_all_gpus(faiss_index)  # remove this line if it causes OOM

## Get the indices of the top-NUM_TITLES passages for every query
print(f"Starting text search, t={time() - start :.1f}s")
search_index = faiss_index.search(np.float32(prompt_embeddings), NUM_TITLES)[1]

## Look up the passage text behind each retrieved index
print(f"Starting context extraction, t={time() - start :.1f}s")
dataset = load_from_disk("dataset/all-paraphs-parsed-expanded")
# search_index is ordered by row position, while df is indexed by the "id"
# column. Assigning the whole column at once keeps the two aligned by
# position, instead of using a positional counter as a label with df.loc
# (which would write to the wrong row, or append rows, when ids are not 0..n-1).
df["context"] = [
    "-" + "\n-".join([dataset[int(j)]["text"] for j in hits])
    for hits in search_index
]

Starting prompt embedding, t=0.0s
Loading faiss index, t=2.6s
Starting text search, t=13.1s
Starting context extraction, t=13.2s


In [6]:
# Preview the first 10 rows, including the newly added "context" column.
df.head(10)

Unnamed: 0_level_0,prompt,A,B,C,D,E,context
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
0,Which of the following statements accurately d...,MOND is a theory that reduces the observed mis...,MOND is a theory that increases the discrepanc...,MOND is a theory that explains the missing bar...,MOND is a theory that reduces the discrepancy ...,MOND is a theory that eliminates the observed ...,-MOND is an example of a class of theories kno...
1,Which of the following is an accurate definiti...,Dynamic scaling refers to the evolution of sel...,Dynamic scaling refers to the non-evolution of...,Dynamic scaling refers to the evolution of sel...,Dynamic scaling refers to the non-evolution of...,Dynamic scaling refers to the evolution of sel...,-In such systems we can define a certain time-...
2,Which of the following statements accurately d...,The triskeles symbol was reconstructed as a fe...,The triskeles symbol is a representation of th...,The triskeles symbol is a representation of a ...,The triskeles symbol represents three interloc...,The triskeles symbol is a representation of th...,"-Classical Antiquity The triskeles proper, com..."
3,What is the significance of regularization in ...,Regularizing the mass-energy of an electron wi...,Regularizing the mass-energy of an electron wi...,Regularizing the mass-energy of an electron wi...,Regularizing the mass-energy of an electron wi...,Regularizing the mass-energy of an electron wi...,-Regularization: Classical physics theory brea...
4,Which of the following statements accurately d...,The angular spacing of features in the diffrac...,The angular spacing of features in the diffrac...,The angular spacing of features in the diffrac...,The angular spacing of features in the diffrac...,The angular spacing of features in the diffrac...,-Several qualitative observations can be made ...
5,Which of the following statements accurately d...,Gauss's law holds only for situations involvin...,"Gauss's law holds in all cases, but it is most...","Gauss's law, which applies equally to all elec...",Gauss's law only holds for electric fields wit...,"Gauss's law, which holds for all situations, i...",-While the electric flux is not affected by ch...
6,Which of the following statements accurately d...,The dimension of an object in a CW complex is ...,The dimension of an object in a CW complex is ...,The dimension of an object in a CW complex is ...,The dimension of an object in a CW complex is ...,The dimension of an object in a CW complex dep...,-An inductive dimension may be defined inducti...
7,Which of the following statements accurately d...,The blocking temperature of an antiferromagnet...,The blocking temperature of an antiferromagnet...,The blocking temperature of an antiferromagnet...,The blocking temperature of an antiferromagnet...,The blocking temperature of an antiferromagnet...,-Magnetic blocking temperature The so-called m...
8,What is the term used in astrophysics to descr...,Blueshifting,Redshifting,Reddening,Whitening,Yellowing,-The interactions and phenomena summarized in ...
9,What is the role of axioms in a formal theory?,Basis statements called axioms form the founda...,Axioms are supplementary statements added to a...,Axioms are redundant statements that can be de...,The axioms in a theory are used for experiment...,The axioms in a formal theory are added to pro...,"-Thus, an axiom is an elementary basis for a f..."


In [7]:
## Inspect the retrieved contexts

context = df['context'].tolist()

# Print each question, then its retrieved context, then three blank lines.
for question, retrieved in zip(df['prompt'].tolist(), context):
    print(question)
    print(retrieved)
    print("\n\n")

Which of the following statements accurately describes the impact of Modified Newtonian Dynamics (MOND) on the observed "missing baryonic mass" discrepancy in galaxy clusters?
-MOND is an example of a class of theories known as modified gravity, and is an alternative to the hypothesis that the dynamics of galaxies are determined by massive, invisible dark matter halos. Since Milgrom's original proposal, proponents of MOND have claimed to successfully predict a variety of galactic phenomena that they state are difficult to understand as consequences of dark matter.Though MOND explains the anomalously great rotational velocities of galaxies at their perimeters, it does not fully explain the velocity dispersions of individual galaxies within galaxy clusters. MOND reduces the discrepancy between the velocity dispersions and clusters' observed missing baryonic mass from a factor of around 10 to a factor of about 2. However, the residual discrepancy cannot be accounted for by MOND, requiring

In [8]:
# Number of context strings collected from df['context'] (one per row of df).
len(context)

200

In [9]:
## Measure the length of the retrieved contexts

# Character count of every context string.
lengths = [len(c) for c in context]

# Total number of characters over all contexts.
hap = sum(lengths)

print(lengths)
# Average over the actual number of contexts rather than a hard-coded 200.
print(hap / len(lengths))

[9562, 3858, 6799, 4290, 3539, 2936, 3111, 3021, 6251, 2681, 7054, 4999, 2457, 3272, 4448, 4768, 3332, 3100, 4957, 1840, 6876, 2498, 3524, 8186, 3273, 2832, 3692, 3777, 4511, 4766, 4283, 2735, 3589, 4186, 3943, 3983, 3342, 5039, 4856, 4949, 4480, 4083, 6015, 6461, 4481, 3866, 4414, 4485, 7593, 5835, 6118, 4296, 2887, 3045, 2768, 3039, 6931, 5938, 4752, 3111, 3753, 5259, 4809, 4598, 4959, 4977, 2875, 3318, 5459, 8200, 5055, 8067, 3393, 6227, 2798, 5116, 2871, 3624, 4826, 6056, 3589, 2312, 5429, 3992, 1977, 3621, 3811, 2766, 3817, 2927, 2003, 2826, 2427, 5226, 4841, 3139, 2931, 3607, 3777, 3720, 5573, 4040, 4078, 7686, 3759, 4121, 5226, 4844, 3474, 5653, 4289, 6084, 3535, 4864, 2780, 2004, 4768, 4803, 4502, 3094, 3208, 1837, 3053, 3962, 5512, 2271, 4660, 2996, 8884, 6903, 3261, 7200, 6230, 2998, 4039, 14339, 5961, 2901, 5034, 3586, 5015, 7332, 3157, 3879, 5621, 2322, 5950, 9691, 3634, 6744, 3670, 6665, 3962, 2391, 5690, 9661, 4928, 3631, 8049, 6471, 3198, 2036, 2774, 2086, 4411, 3602, 23

### all-MiniLM-L6-v2 embedding

In [10]:
# Load data
df = pd.read_csv("../dataset/test.csv", index_col="id")

## all-MiniLM-L6-v2 embedding

NUM_TITLES = 5  # number of passages retrieved per question
MAX_SEQ_LEN = 512
MODEL_PATH = "output/all-MiniLM-L6-v2"

## Load the embedding model
start = time()
print(f"Starting prompt embedding, t={time() - start :.1f}s")
model = SentenceTransformer(MODEL_PATH, device="cuda:2")

## Get query embeddings: each query is the prompt followed by the five answer options
f = lambda row : " ".join([row["prompt"], row["A"], row["B"], row["C"], row["D"], row["E"]])
inputs = df.apply(f, axis=1).values # better results than prompt only
prompt_embeddings = model.encode(inputs, show_progress_bar=False)

## Load the FAISS Wikipedia index
print(f"Loading faiss index, t={time() - start :.1f}s")
faiss_index = faiss.read_index(MODEL_PATH + '/faiss.index')
faiss_index = faiss.index_cpu_to_all_gpus(faiss_index)  # remove this line if it causes OOM

## Get the indices of the top-NUM_TITLES passages for every query
print(f"Starting text search, t={time() - start :.1f}s")
search_index = faiss_index.search(np.float32(prompt_embeddings), NUM_TITLES)[1]

## Look up the passage text behind each retrieved index
print(f"Starting context extraction, t={time() - start :.1f}s")
dataset = load_from_disk("dataset/all-paraphs-parsed-expanded")
# search_index is ordered by row position, while df is indexed by the "id"
# column. Assigning the whole column at once keeps the two aligned by
# position, instead of using a positional counter as a label with df.loc
# (which would write to the wrong row, or append rows, when ids are not 0..n-1).
df["context"] = [
    "-" + "\n-".join([dataset[int(j)]["text"] for j in hits])
    for hits in search_index
]

No sentence-transformers model found with name output/all-MiniLM-L6-v2. Creating a new one with MEAN pooling.


Starting prompt embedding, t=0.0s
Loading faiss index, t=0.6s
Starting text search, t=5.2s
Starting context extraction, t=5.2s


In [11]:
# Preview the first 10 rows, including the newly added "context" column.
df.head(10)

Unnamed: 0_level_0,prompt,A,B,C,D,E,context
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
0,Which of the following statements accurately d...,MOND is a theory that reduces the observed mis...,MOND is a theory that increases the discrepanc...,MOND is a theory that explains the missing bar...,MOND is a theory that reduces the discrepancy ...,MOND is a theory that eliminates the observed ...,-MOND is an example of a class of theories kno...
1,Which of the following is an accurate definiti...,Dynamic scaling refers to the evolution of sel...,Dynamic scaling refers to the non-evolution of...,Dynamic scaling refers to the evolution of sel...,Dynamic scaling refers to the non-evolution of...,Dynamic scaling refers to the evolution of sel...,-In such systems we can define a certain time-...
2,Which of the following statements accurately d...,The triskeles symbol was reconstructed as a fe...,The triskeles symbol is a representation of th...,The triskeles symbol is a representation of a ...,The triskeles symbol represents three interloc...,The triskeles symbol is a representation of th...,"-Classical Antiquity The triskeles proper, com..."
3,What is the significance of regularization in ...,Regularizing the mass-energy of an electron wi...,Regularizing the mass-energy of an electron wi...,Regularizing the mass-energy of an electron wi...,Regularizing the mass-energy of an electron wi...,Regularizing the mass-energy of an electron wi...,-Regularization: Classical physics theory brea...
4,Which of the following statements accurately d...,The angular spacing of features in the diffrac...,The angular spacing of features in the diffrac...,The angular spacing of features in the diffrac...,The angular spacing of features in the diffrac...,The angular spacing of features in the diffrac...,-Several qualitative observations can be made ...
5,Which of the following statements accurately d...,Gauss's law holds only for situations involvin...,"Gauss's law holds in all cases, but it is most...","Gauss's law, which applies equally to all elec...",Gauss's law only holds for electric fields wit...,"Gauss's law, which holds for all situations, i...",-While the electric flux is not affected by ch...
6,Which of the following statements accurately d...,The dimension of an object in a CW complex is ...,The dimension of an object in a CW complex is ...,The dimension of an object in a CW complex is ...,The dimension of an object in a CW complex is ...,The dimension of an object in a CW complex dep...,-An inductive dimension may be defined inducti...
7,Which of the following statements accurately d...,The blocking temperature of an antiferromagnet...,The blocking temperature of an antiferromagnet...,The blocking temperature of an antiferromagnet...,The blocking temperature of an antiferromagnet...,The blocking temperature of an antiferromagnet...,"-Antiferromagnets can couple to ferromagnets, ..."
8,What is the term used in astrophysics to descr...,Blueshifting,Redshifting,Reddening,Whitening,Yellowing,-The interactions and phenomena summarized in ...
9,What is the role of axioms in a formal theory?,Basis statements called axioms form the founda...,Axioms are supplementary statements added to a...,Axioms are redundant statements that can be de...,The axioms in a theory are used for experiment...,The axioms in a formal theory are added to pro...,"-In mathematics and logic, an axiomatic system..."


In [13]:
## Inspect the retrieved contexts

context = df['context'].tolist()

# Print each question, then its retrieved context, then three blank lines.
for question, retrieved in zip(df['prompt'].tolist(), context):
    print(question)
    print(retrieved)
    print("\n\n")

Which of the following statements accurately describes the impact of Modified Newtonian Dynamics (MOND) on the observed "missing baryonic mass" discrepancy in galaxy clusters?
-MOND is an example of a class of theories known as modified gravity, and is an alternative to the hypothesis that the dynamics of galaxies are determined by massive, invisible dark matter halos. Since Milgrom's original proposal, proponents of MOND have claimed to successfully predict a variety of galactic phenomena that they state are difficult to understand as consequences of dark matter.Though MOND explains the anomalously great rotational velocities of galaxies at their perimeters, it does not fully explain the velocity dispersions of individual galaxies within galaxy clusters. MOND reduces the discrepancy between the velocity dispersions and clusters' observed missing baryonic mass from a factor of around 10 to a factor of about 2. However, the residual discrepancy cannot be accounted for by MOND, requiring

In [14]:
## Measure the length of the retrieved contexts

# Character count of every context string.
lengths = [len(c) for c in context]

# Total number of characters over all contexts.
hap = sum(lengths)

print(lengths)
# Average over the actual number of contexts rather than a hard-coded 200.
print(hap / len(lengths))

[8194, 3445, 5457, 4509, 3363, 3681, 3733, 2815, 4696, 3033, 8755, 3588, 3971, 3179, 5189, 3361, 3181, 3096, 4139, 2319, 3507, 2205, 2394, 4869, 2482, 4139, 3025, 3311, 5814, 3741, 4304, 3214, 4307, 4813, 2964, 1916, 5524, 4514, 3228, 4849, 2687, 3363, 4806, 7207, 5749, 2638, 3201, 3749, 7376, 4750, 4073, 2719, 2282, 7790, 4060, 3720, 1851, 5461, 5800, 6631, 3443, 5448, 4656, 3081, 3350, 4710, 2809, 3661, 5519, 2135, 4272, 4303, 4272, 6363, 2026, 5632, 3977, 2943, 3576, 5665, 2900, 4943, 4323, 2355, 2340, 3666, 3271, 2728, 4283, 4428, 1440, 3375, 2252, 3226, 3058, 2543, 2312, 4201, 4460, 3125, 4224, 6313, 2766, 5487, 2607, 2678, 4067, 4273, 3371, 4147, 3306, 2836, 5684, 3359, 2274, 2090, 5050, 5363, 3339, 2210, 2938, 3036, 2808, 2856, 2905, 2338, 1737, 4366, 4141, 2707, 2534, 8359, 6470, 1972, 2026, 14370, 2782, 3395, 4968, 5521, 4270, 7742, 3541, 2796, 3044, 1890, 5299, 6270, 3003, 4359, 6465, 7460, 2748, 1952, 3516, 4564, 4110, 2990, 8481, 3794, 2492, 2568, 2695, 2366, 3640, 2664, 14