In [5]:
import json
import os

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses, models
from sentence_transformers.evaluation import InformationRetrievalEvaluator
# import sys
# print(sys.executable)
# ! /anaconda/envs/azureml_py310_sdkv2/bin/python -m pip install torch

def generate_corpus_function(text_data_path):
    """Build a corpus dictionary from a UTF-8 text file, one document per line.

    Each line is stripped of surrounding whitespace and stored under the key
    ``doc_<index>``, where the index is the zero-based line number.

    Args:
        text_data_path: Path of the text file to read.

    Returns:
        dict mapping ``'doc_<i>'`` to the stripped text of line ``i``. An empty
        dict is returned when the file is missing or cannot be read; in that
        case a message is printed instead of raising.
    """
    corpus_data = {}

    try:
        with open(text_data_path, 'r', encoding='utf-8') as file:
            raw_lines = file.readlines()
        # One dictionary entry per line, keyed by its position in the file.
        corpus_data = {
            f'doc_{index}': raw_line.strip()
            for index, raw_line in enumerate(raw_lines)
        }
    except FileNotFoundError:
        print("Text data file not found.")
    except Exception as e:
        print(f"An error occurred: {e}")

    return corpus_data

if __name__ == "__main__":
    # Combined train / validation dataset files. The following cells index the
    # loaded objects by the keys 'corpus', 'queries' and 'relevant_docs'.
    TRAIN_DATASET_FPATH = '/home/azureuser/cloudfiles/code/Users/hb.suh/OUR_BERT/data/train_dataset.json'
    VAL_DATASET_FPATH = '/home/azureuser/cloudfiles/code/Users/hb.suh/OUR_BERT/data/val_dataset.json'

    batch_size = 16

    # Build the sentence-embedding model on top of the previously trained BERT
    # checkpoint stored at this path.
    model_name = '/home/azureuser/cloudfiles/code/Users/hb.suh/OUR_BERT/trained_model'

    transformer_model = models.Transformer(model_name)

    # Add a pooling layer: mean over token embeddings (CLS and max pooling disabled).
    pooling_model = models.Pooling(transformer_model.get_word_embedding_dimension(),
                                   pooling_mode_mean_tokens=True,
                                   pooling_mode_cls_token=False,
                                   pooling_mode_max_tokens=False)

    # Assemble the SentenceTransformer from the transformer encoder and the pooling layer.
    model = SentenceTransformer(modules=[transformer_model, pooling_model])

    # Paths of the generated query / relevant-document files.
    OUTPUT_DIR = '/home/azureuser/cloudfiles/code/Users/hb.suh/OUR_BERT/generated_QAdata'
    TRAIN_QUERIES_FPATH = os.path.join(OUTPUT_DIR, 'train_queries.json')
    TRAIN_DOCS_FPATH = os.path.join(OUTPUT_DIR, 'train_relevant_docs.json')
    VAL_QUERIES_FPATH = os.path.join(OUTPUT_DIR, 'val_queries.json')
    VAL_DOCS_FPATH = os.path.join(OUTPUT_DIR, 'val_relevant_docs.json')

    def _load_json(path):
        """Return the parsed content of the JSON file at ``path``.

        The file is opened read-only (the datasets are never written here, so
        no write permission should be required) and decoded as UTF-8.
        """
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    # Load datasets
    train_dataset = _load_json(TRAIN_DATASET_FPATH)
    val_dataset = _load_json(VAL_DATASET_FPATH)

    # NOTE: the later cells rebind train_queries / train_relevant_docs /
    # val_queries / val_relevant_docs from train_dataset and val_dataset.
    train_queries = _load_json(TRAIN_QUERIES_FPATH)
    train_relevant_docs = _load_json(TRAIN_DOCS_FPATH)
    val_queries = _load_json(VAL_QUERIES_FPATH)
    val_relevant_docs = _load_json(VAL_DOCS_FPATH)



No sentence-transformers model found with name /home/azureuser/cloudfiles/code/Users/hb.suh/OUR_BERT/trained_model. Creating a new one with MEAN pooling.
Some weights of BertModel were not initialized from the model checkpoint at /home/azureuser/cloudfiles/code/Users/hb.suh/OUR_BERT/trained_model and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
####### Data loader definition #########
# Prepare training data

train_corpus = train_dataset['corpus']  # dict: corpus key -> document text
train_queries = train_dataset['queries']
train_relevant_docs = train_dataset['relevant_docs']


def _find_corpus_key(node_id, corpus):
    """Return the key of ``corpus`` that corresponds to ``node_id``.

    An exact match on ``str(node_id)`` is preferred, so an ID that happens to
    be a substring of another key can no longer resolve to the wrong document.
    If there is no exact match, fall back to the first key that contains
    ``str(node_id)`` as a substring (the original lookup rule, which assumes
    the document number is embedded in the key).

    Returns:
        The matching corpus key, or ``None`` when no key matches.
    """
    needle = str(node_id)
    if needle in corpus:
        return needle
    for key in corpus:
        if needle in key:
            return key
    return None


# Build one positive (query, document) pair per relevant document of each query.
examples = []
for query_id, query in train_queries.items():
    if query_id not in train_relevant_docs:
        print(f"Query ID {query_id} not found in train_relevant_docs.")
        continue

    for node_id in train_relevant_docs[query_id]:  # list of document IDs for this query
        uuid_key = _find_corpus_key(node_id, train_corpus)
        if uuid_key is None:
            print(f"Node ID {node_id} not found in corpus for Query ID {query_id}.")
            continue

        text = train_corpus[uuid_key]  # document text for the resolved key
        examples.append(InputExample(texts=[query, text], label=1))


def custom_collate_fn(batch):
    """Split a batch of examples into parallel lists of texts and labels.

    Args:
        batch: Sequence of objects exposing ``texts`` and ``label`` attributes.

    Returns:
        Tuple ``(texts, labels)`` where ``texts[i]`` and ``labels[i]`` come from
        the i-th item of ``batch``.
    """
    texts = []
    for item in batch:
        texts.append(item.texts)

    labels = []
    for item in batch:
        labels.append(item.label)

    return texts, labels

# Shuffle on every epoch. `examples` is built query by query, so without
# shuffling pairs that share a query sit next to each other and every epoch
# sees identical batches; MultipleNegativesRankingLoss treats the other pairs
# of a batch as negatives, so batch composition matters.
train_loader = DataLoader(
    examples,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=custom_collate_fn,
)

# Define the loss function
loss = losses.MultipleNegativesRankingLoss(model)

In [6]:
######## Define evaluator ##########

val_corpus = val_dataset['corpus']
val_queries = val_dataset['queries']
val_relevant_docs = val_dataset['relevant_docs']

# InformationRetrievalEvaluator documents relevant_docs as dict[str, set[str]],
# while JSON yields lists. Build a set-valued copy for the evaluator and leave
# val_relevant_docs itself untouched.
# NOTE(review): assumes every value is a list of document IDs - confirm against the dataset.
val_relevant_doc_sets = {
    query_id: {str(doc_id) for doc_id in doc_ids}
    for query_id, doc_ids in val_relevant_docs.items()
}

evaluator = InformationRetrievalEvaluator(val_queries, val_corpus, val_relevant_doc_sets)

# Training
EPOCHS = 10
# Warm up over the first 10% of all training steps (batches per epoch * epochs).
warmup_steps = int(len(train_loader) * EPOCHS * 0.1)

model.fit(
    train_objectives=[(train_loader, loss)],
    epochs=EPOCHS,
    warmup_steps=warmup_steps,
    evaluator=evaluator,
    evaluation_steps=50,  # run the evaluator every 50 training steps
    output_path=os.path.join(OUTPUT_DIR, 'exp_finetune'),  # Save the fine-tuned model
    show_progress_bar=True
)

Epoch:   0%|          | 0/10 [00:00<?, ?it/s]
Iteration:   0%|          | 0/119 [00:00<?, ?it/s][A
Iteration:   1%|          | 1/119 [00:00<00:29,  4.00it/s][A
Iteration:   2%|▏         | 2/119 [00:00<00:28,  4.09it/s][A
Iteration:   3%|▎         | 3/119 [00:00<00:29,  3.98it/s][A
Iteration:   3%|▎         | 4/119 [00:01<00:29,  3.94it/s][A
Iteration:   4%|▍         | 5/119 [00:01<00:29,  3.93it/s][A
Iteration:   5%|▌         | 6/119 [00:01<00:28,  3.97it/s][A
Iteration:   6%|▌         | 7/119 [00:01<00:28,  3.99it/s][A
Iteration:   7%|▋         | 8/119 [00:02<00:27,  3.98it/s][A
Iteration:   8%|▊         | 9/119 [00:02<00:27,  3.97it/s][A
Iteration:   8%|▊         | 10/119 [00:02<00:27,  3.94it/s][A
Iteration:   9%|▉         | 11/119 [00:02<00:27,  3.93it/s][A
Iteration:  10%|█         | 12/119 [00:03<00:27,  3.95it/s][A
Iteration:  11%|█         | 13/119 [00:03<00:26,  3.93it/s][A
Iteration:  12%|█▏        | 14/119 [00:03<00:26,  3.93it/s][A
Iteration:  13%|█▎        |