S-BERT 파인튜닝 (domain adaptation 적용 전 단계)
# KcBERT-v2023 S-BERT 파인튜닝

In [4]:
import json
import os
import torch
from sentence_transformers import SentenceTransformer, InputExample, losses, models, datasets, evaluation
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import InputExample, losses, evaluation



# Set environment variables.
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes every CUDA kernel launch synchronous.
# That helps pinpoint device-side errors but slows training; consider dropping it
# once the pipeline is stable.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# Must take effect before CUDA is first initialized in this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# NOTE(review): TORCH_USE_CUDA_DSA appears to be a PyTorch *build-time* option for
# device-side assertions; setting it at runtime likely has no effect — confirm.
os.environ["TORCH_USE_CUDA_DSA"] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# Paths to the QA datasets. Each JSON file is read below as a dict with
# 'corpus', 'queries' and 'relevant_docs' entries.
TRAIN_DATASET_FPATH = '/home/azureuser/cloudfiles/code/Users/hb.suh/OUR_BERT/0529new_QA/train_dataset.json'
VAL_DATASET_FPATH = '/home/azureuser/cloudfiles/code/Users/hb.suh/OUR_BERT/0529new_QA/val_dataset.json'
OUTPUT_DIR = '/home/azureuser/cloudfiles/code/Users/hb.suh/OUR_BERT/Sbert_KCBERT2023/exp_finetune_onNEWQA'

# Load datasets. Decode explicitly as UTF-8 (the standard encoding for JSON,
# RFC 8259) instead of relying on the platform's default locale encoding, which
# can mangle or reject the (presumably Korean) text on non-UTF-8 systems.
with open(TRAIN_DATASET_FPATH, 'r', encoding='utf-8') as f:
    train_dataset = json.load(f)
with open(VAL_DATASET_FPATH, 'r', encoding='utf-8') as f:
    val_dataset = json.load(f)

train_corpus = train_dataset['corpus']
train_queries = train_dataset['queries']
train_relevant_docs = train_dataset['relevant_docs']
val_corpus = val_dataset['corpus']
val_queries = val_dataset['queries']
val_relevant_docs = val_dataset['relevant_docs']

# Base encoder: a plain Hugging Face model and its tokenizer (the model is not
# wrapped in a SentenceTransformer here). Sentence embeddings are obtained by
# mean-pooling the token outputs, so the pooler weights that the loader reports
# as newly initialized are not used.
model_name = "beomi/KcBERT-v2023"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Move model to GPU if available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)




Some weights of RobertaModel were not initialized from the model checkpoint at beomi/KcBERT-v2023 and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


RobertaModel(
  (embeddings): RobertaEmbeddings(
    (word_embeddings): Embedding(50265, 768, padding_idx=1)
    (position_embeddings): Embedding(514, 768, padding_idx=1)
    (token_type_embeddings): Embedding(1, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): RobertaEncoder(
    (layer): ModuleList(
      (0-11): 12 x RobertaLayer(
        (attention): RobertaAttention(
          (self): RobertaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): RobertaSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (dropou

In [5]:
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions where ``attention_mask`` is set.

    ``model_output[0]`` is used as the token-level hidden states (assumed to be
    ``(batch, seq_len, hidden)``, i.e. a Hugging Face ``last_hidden_state``).
    The mask is broadcast over the hidden dimension, masked tokens are summed
    along the sequence axis, and the sum is divided by the number of unmasked
    tokens. The divisor is clamped at 1e-9, so an all-zero mask yields zeros
    instead of NaN.
    """
    hidden_states = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    token_counts = weights.sum(1).clamp(min=1e-9)
    weighted_total = (hidden_states * weights).sum(1)
    return weighted_total / token_counts

def encode_texts(texts, model, tokenizer, device, max_length=512):
    """Tokenize ``texts`` and return one mean-pooled embedding per text.

    The batch is padded to its longest member and truncated to ``max_length``
    tokens, moved to ``device``, and passed through ``model``. Gradients are not
    disabled here, so the result can be backpropagated through (it is used
    inside the training loss).
    """
    encoded = tokenizer(
        texts,
        max_length=max_length,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    encoded = encoded.to(device)
    hidden_output = model(**encoded)
    sentence_vectors = mean_pooling(hidden_output, encoded['attention_mask'])
    return sentence_vectors.to(device)


# Prepare training / validation examples.
def _build_pair_examples(queries_by_id, relevant_docs, corpus, split_name):
    """Expand one QA split into positive (query, document) InputExamples.

    Args:
        queries_by_id: maps query_id to either a single query string or a list
            of query strings.
        relevant_docs: maps query_id to an iterable of relevant doc_ids.
        corpus: maps doc_id to document text.
        split_name: prefix used in warning messages ("train" / "val").

    Returns:
        A list with one ``InputExample(texts=[query, doc_text], label=1.0)`` per
        (query string, relevant document found in ``corpus``) combination.
        Entries that cannot be paired are reported with a printed warning.
    """
    examples = []
    for query_id, queries in queries_by_id.items():
        if query_id not in relevant_docs:
            print(f"Warning: query_id '{query_id}' not found in {split_name}_relevant_docs")
            continue
        # A bare string would otherwise be iterated character by character,
        # silently producing one-character "queries" that still pass the
        # isinstance(str) check below.
        if isinstance(queries, str):
            queries = [queries]
        for onequery in queries:
            for doc_id in relevant_docs[query_id]:
                if doc_id not in corpus:
                    print(f"Warning: doc_id '{doc_id}' for query_id '{query_id}' not found in {split_name}_corpus")
                    continue
                doc_text = corpus[doc_id]
                if isinstance(doc_text, str) and isinstance(onequery, str):
                    examples.append(InputExample(texts=[onequery, doc_text], label=1.0))
                else:
                    print(f"Warning: query or doc_text is not a string. Query: {onequery}, Doc: {doc_text}")
    return examples


train_examples = _build_pair_examples(train_queries, train_relevant_docs, train_corpus, "train")
val_examples = _build_pair_examples(val_queries, val_relevant_docs, val_corpus, "val")

        
def custom_collate(batch):
    """Collate InputExample-like items into ``(first_texts, second_texts, labels)``.

    Each item must expose ``texts`` (at least two strings) and ``label``.
    Returns two Python lists of strings and a float32 tensor of labels.
    """
    triples = ((example.texts[0], example.texts[1], example.label) for example in batch)
    first_texts, second_texts, label_values = zip(*triples)
    label_tensor = torch.tensor(label_values, dtype=torch.float)
    return list(first_texts), list(second_texts), label_tensor

class TextDataset(Dataset):
    """Map-style dataset that indexes directly into a prebuilt list of examples."""

    def __init__(self, examples):
        # Kept by reference; the list is not copied.
        self.examples = examples

    def __getitem__(self, idx):
        return self.examples[idx]

    def __len__(self):
        return len(self.examples)

# Wrap the positive (query, document) pairs for shuffled mini-batching. This
# rebinds `train_dataset`, which earlier held the parsed training JSON.
train_dataset = TextDataset(train_examples)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=custom_collate,
)

class CosineSimilarityLoss(torch.nn.Module):
    """Loss on the cosine similarity between two batches of encoded texts.

    Default mode (``in_batch_negatives=False``): mean squared error between the
    per-pair cosine similarity ``cos(texts1[i], texts2[i])`` and ``labels[i]``.

    NOTE(review): this notebook labels every (query, document) pair 1.0. With
    only positive targets, the MSE objective can be driven to zero simply by
    mapping every text to the same direction. A near-zero training loss therefore
    does not by itself show that the embeddings are useful for retrieval.

    In-batch-negatives mode (``in_batch_negatives=True``): builds the
    (batch, batch) matrix of cosine similarities between every text in
    ``texts1`` and every text in ``texts2``, multiplies it by ``scale``, and
    applies cross-entropy with the diagonal (the paired text) as the target
    class. The other texts in the batch therefore act as negatives.
    ``labels`` is ignored in this mode. Duplicate documents within a batch
    become false negatives.

    Args:
        model: encoder passed to ``encode_texts``.
        tokenizer: tokenizer passed to ``encode_texts``.
        device: device the tokenized batches are moved to.
        in_batch_negatives: select the contrastive objective described above.
        scale: multiplier (inverse temperature) applied to the cosine
            similarities in in-batch-negatives mode. 20.0 mirrors the default
            of sentence-transformers' MultipleNegativesRankingLoss.
    """

    def __init__(self, model, tokenizer, device, in_batch_negatives=False, scale=20.0):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.in_batch_negatives = in_batch_negatives
        self.scale = scale
        self.cosine_similarity = torch.nn.CosineSimilarity(dim=1)

    def forward(self, texts1, texts2, labels):
        """Return the scalar loss for one batch.

        Args:
            texts1: list of first texts (queries), one per pair.
            texts2: list of second texts (documents), aligned with ``texts1``.
            labels: float tensor of target similarities, one per pair. It is
                unused when ``in_batch_negatives`` is True.
        """
        embeddings1 = encode_texts(texts1, self.model, self.tokenizer, self.device)
        embeddings2 = encode_texts(texts2, self.model, self.tokenizer, self.device)
        if self.in_batch_negatives:
            # scores[i, j] = scale * cos(texts1[i], texts2[j]); column i is the positive for row i.
            scores = self.scale * torch.nn.functional.cosine_similarity(
                embeddings1.unsqueeze(1), embeddings2.unsqueeze(0), dim=-1
            )
            targets = torch.arange(scores.size(0), device=scores.device)
            return torch.nn.functional.cross_entropy(scores, targets)
        similarities = self.cosine_similarity(embeddings1, embeddings2)
        return torch.nn.functional.mse_loss(similarities, labels)

# Loss for train_model: MSE between each pair's cosine similarity and its label.
train_loss = CosineSimilarityLoss(
    model=model,
    tokenizer=tokenizer,
    device=device,
)

def train_model(model, dataloader, loss_fn, optimizer, epochs, device=None):
    """Run a simple training loop and return the trained model.

    Args:
        model: torch.nn.Module being trained. It is put in train mode.
        dataloader: iterable of ``(texts1, texts2, labels)`` batches.
        loss_fn: callable ``loss_fn(texts1, texts2, labels)`` returning a scalar
            loss tensor.
        optimizer: torch optimizer (presumably over ``model``'s parameters).
        epochs: number of full passes over ``dataloader``.
        device: where ``labels`` are moved before the loss is computed. It
            defaults to the device of the model's first parameter (CPU if the
            model has none), rather than relying on a module-level ``device``
            global.

    Returns:
        The same ``model`` object, left in training mode.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for texts1, texts2, labels in dataloader:
            labels = labels.to(device)
            optimizer.zero_grad()
            loss = loss_fn(texts1, texts2, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        # Average over the batches actually seen. This avoids ZeroDivisionError
        # on an empty dataloader and works for iterables without __len__.
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss}")
    return model

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# Train the model
model = train_model(model, train_dataloader, train_loss, optimizer, epochs=12)

# Save as a plain Hugging Face checkpoint (weights + config) plus the tokenizer.
model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

# Evaluation.
# Every validation pair is labeled 1.0, so correlation-based metrics (as in
# EmbeddingSimilarityEvaluator) would be undefined against a constant gold score.
# Instead, score retrieval: for each validation query, rank the validation corpus
# and check where its relevant documents land.
val_ir_corpus = {doc_id: text for doc_id, text in val_corpus.items() if isinstance(text, str)}
val_ir_queries = {}
val_ir_relevant_docs = {}
for query_id, raw_queries in val_queries.items():
    if query_id not in val_relevant_docs:
        continue
    # A query entry may be a single string or a list of strings.
    query_list = [raw_queries] if isinstance(raw_queries, str) else raw_queries
    for idx, query_text in enumerate(query_list):
        if not isinstance(query_text, str):
            continue
        ir_query_id = f"{query_id}__{idx}"
        val_ir_queries[ir_query_id] = query_text
        val_ir_relevant_docs[ir_query_id] = set(val_relevant_docs[query_id])

evaluator = evaluation.InformationRetrievalEvaluator(
    val_ir_queries, val_ir_corpus, val_ir_relevant_docs, name='val-evaluator'
)

# `model` is a raw Hugging Face module, which the evaluator cannot consume.
# Reload the saved checkpoint as a SentenceTransformer instead. A directory
# without sentence-transformers module files is wrapped with mean pooling,
# consistent with the mean_pooling used during training.
eval_model = SentenceTransformer(OUTPUT_DIR, device=str(device))
print(evaluator(eval_model))


Epoch 1/12, Loss: 0.0001007616694575364
Epoch 2/12, Loss: 1.405664930161307e-06
Epoch 3/12, Loss: 8.067726413350186e-07
Epoch 4/12, Loss: 5.206104806019757e-07
Epoch 6/12, Loss: 2.2335395309888968e-07
Epoch 7/12, Loss: 1.1469863654005527e-07
Epoch 8/12, Loss: 3.87722350432929e-08
Epoch 9/12, Loss: 1.5137988742839738e-08
Epoch 10/12, Loss: 7.525394704733149e-09
Epoch 12/12, Loss: 3.0099128149811864e-09
