In [None]:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer

class PolyEncoder(nn.Module):
    """BERT encoder that pools a token sequence into ``poly_m`` context vectors.

    Each of the ``poly_m`` learned codes attends over the encoder's final
    hidden states (dot-product scores, softmax over the sequence axis) and the
    resulting weights are used to average those hidden states, following the
    context-side attention of the Poly-encoder architecture.
    """

    def __init__(self, encoder_name, poly_m=16):
        """Load the encoder and create the learnable poly codes.

        Args:
            encoder_name: Checkpoint name or path given to
                ``BertModel.from_pretrained``.
            poly_m: Number of poly codes, i.e. context vectors produced per
                input sequence.
        """
        super().__init__()
        self.encoder = BertModel.from_pretrained(encoder_name)
        self.poly_m = poly_m
        hidden_size = self.encoder.config.hidden_size
        self.poly_code_embeddings = nn.Embedding(self.poly_m, hidden_size)
        # Draw the initial codes with the std declared by the encoder config.
        nn.init.normal_(
            self.poly_code_embeddings.weight,
            mean=0.0,
            std=self.encoder.config.initializer_range,
        )

    def forward(self, input_ids, attention_mask):
        """Encode the inputs and pool them with the poly codes.

        Args:
            input_ids: Token ids, batch on the first axis.
            attention_mask: Mask aligned with ``input_ids`` where 0 marks
                padding; may be ``None`` to attend over every position.

        Returns:
            Tuple ``(poly_contexts, poly_mask)``: ``poly_contexts`` has shape
            (batch, poly_m, hidden) and ``poly_mask`` is an all-ones tensor of
            shape (batch, poly_m).
        """
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoder_outputs.last_hidden_state

        # Share the same codes across the batch without copying them.
        batch_size = input_ids.size(0)
        poly_codes = self.poly_code_embeddings.weight.unsqueeze(0).expand(batch_size, -1, -1)
        poly_mask = torch.ones(poly_codes.size()[:2], device=poly_codes.device)

        # Dot-product score of every code against every token: (batch, poly_m, seq_len).
        scores = torch.bmm(poly_codes, hidden_states.transpose(1, 2))
        if attention_mask is not None:
            # Padding tokens must not receive attention. Using the dtype's
            # minimum (not -inf) keeps a fully padded row free of NaNs.
            padding = attention_mask.unsqueeze(1) == 0
            scores = scores.masked_fill(padding, torch.finfo(scores.dtype).min)
        weights = torch.softmax(scores, dim=-1)

        # Weighted average of the hidden states: (batch, poly_m, hidden).
        poly_contexts = torch.bmm(weights, hidden_states)
        return poly_contexts, poly_mask

# TODO: look for a BERT checkpoint that handles Korean well (to be changed);
# the same checkpoint name is used for both the tokenizer and the encoder.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")
poly_encoder = PolyEncoder(encoder_name="bert-base-uncased")

In [None]:
from peft import LoraConfig, get_peft_model

# LoRA configuration
lora_config = LoraConfig(
    r=16,  # Low-rank dimension
    lora_alpha=32,  # Scaling factor
    lora_dropout=0.1,  # Dropout rate
    bias="none",  # Bias setting
    # The wrapped model is a custom nn.Module, so PEFT cannot infer which
    # layers to adapt and needs them named explicitly.
    # NOTE(review): assumes the decoder uses Llama-style attention projection
    # names (q_proj / v_proj) — confirm against the loaded checkpoint.
    target_modules=["q_proj", "v_proj"],
)

# Apply PEFT. `model` is created in a later cell of this notebook, so on a
# fresh top-to-bottom run it does not exist yet; fail with an actionable
# message instead of a bare NameError.
if "model" not in globals():
    raise NameError(
        "`model` is not defined yet: run the PolyLlamaModel initialization cell "
        "before applying LoRA."
    )
model = get_peft_model(model, lora_config)

In [None]:
# Example data
examples = [
    {"context": "여기에 문서 내용", "query": "질문 내용", "summary": "정답 요약"},
    # additional examples...
]

# Preprocessing: one tokenizer call per example, so every entry is its own
# encoding (no padding across examples is applied here).
inputs = [
    tokenizer(text=sample["context"], text_pair=sample["query"], truncation=True, return_tensors="pt")
    for sample in examples
]
labels = [
    tokenizer(text=sample["summary"], truncation=True, return_tensors="pt")
    for sample in examples
]

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Checkpoint used for the decoder side.
llama_model_name = "vaiv/llamion-14b-base"

# NOTE(review): PolyLlamaModel loads its own copy of the decoder from a
# checkpoint name, and `llama_model` is not referenced again in the visible
# cells — confirm whether this second full load is needed.
llama_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=llama_model_name)
llama_model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=llama_model_name)

In [None]:
class PolyLlamaModel(nn.Module):
    """Pairs a ``PolyEncoder`` with a causal-LM decoder loaded from a checkpoint.

    NOTE(review): the decoder is called with ``encoder_hidden_states`` and
    ``encoder_attention_mask``; decoder-only causal LMs typically do not accept
    these arguments — confirm that the chosen checkpoint's model class does.
    """

    def __init__(self, encoder_name, decoder_name, poly_m=16):
        """Build both halves of the model.

        Args:
            encoder_name: Checkpoint name forwarded to ``PolyEncoder``.
            decoder_name: Checkpoint name given to
                ``AutoModelForCausalLM.from_pretrained``.
            poly_m: Number of poly codes forwarded to ``PolyEncoder``.
        """
        super().__init__()
        self.encoder = PolyEncoder(encoder_name, poly_m)
        self.decoder = AutoModelForCausalLM.from_pretrained(decoder_name)

    def forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask):
        """Encode the source inputs, then run the decoder conditioned on them.

        Returns:
            Whatever the decoder returns for the given inputs.
        """
        # Source side: poly contexts and their mask.
        contexts, contexts_mask = self.encoder(input_ids, attention_mask)

        # Target side: hand the contexts to the decoder as encoder states.
        return self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=contexts,
            encoder_attention_mask=contexts_mask,
        )

# Initialize the combined encoder/decoder model
model = PolyLlamaModel(
    encoder_name="bert-base-uncased",
    decoder_name="vaiv/llamion-14b-base",
)

In [None]:
from transformers import Trainer, TrainingArguments

# Training hyper-parameters and checkpointing policy
training_args = TrainingArguments(
    save_total_limit=2,
    save_steps=10_000,
    per_device_train_batch_size=2,
    num_train_epochs=3,
    output_dir="./results",  # needs to be changed
)

# Trainer setup
# NOTE(review): `inputs` holds the tokenized (context, query) pairs and
# `labels` holds the tokenized summaries of the same examples, so the
# evaluation set is the targets rather than held-out data. The original
# author also marked both assignments as unverified — confirm the wiring.
trainer = Trainer(
    args=training_args,
    model=model,
    eval_dataset=labels,  # answer dataset (not sure; needs to be double-checked)
    train_dataset=inputs,  # training dataset (not sure; needs to be double-checked)
)

# Train the model
trainer.train()

# Evaluate the model
results = trainer.evaluate()
print(f"Evaluation results: {results}")

In [None]:
# Evaluation function
def evaluate_model(model, tokenizer, dataset):
    """Return the exact-match accuracy of generated text against the summaries.

    Args:
        model: Object exposing ``eval()`` and ``generate(input_ids, ...)``.
        tokenizer: Callable that encodes ``(context, query)`` and exposes
            ``decode(ids, skip_special_tokens=...)``.
        dataset: Sized iterable of dicts with ``"context"``, ``"query"`` and
            ``"summary"`` keys.

    Returns:
        Fraction of examples whose decoded generation equals ``"summary"``
        exactly; ``0.0`` when ``dataset`` is empty.
    """
    model.eval()
    total = len(dataset)
    if total == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0

    # Decoder-only models return prompt + continuation from generate(), so the
    # prompt has to be cut off before comparing. Only do this when the model
    # config explicitly says it is not an encoder-decoder.
    config = getattr(model, "config", None)
    echoes_prompt = getattr(config, "is_encoder_decoder", None) is False

    correct = 0
    for example in dataset:
        encoded = tokenizer(example["context"], example["query"], return_tensors="pt")

        # Forward the attention mask when the tokenizer produced one.
        generate_kwargs = {}
        attention_mask = getattr(encoded, "attention_mask", None)
        if attention_mask is not None:
            generate_kwargs["attention_mask"] = attention_mask

        with torch.no_grad():
            outputs = model.generate(encoded.input_ids, **generate_kwargs)

        generated = outputs[0]
        if echoes_prompt:
            generated = generated[encoded.input_ids.shape[-1]:]
        prediction = tokenizer.decode(generated, skip_special_tokens=True)

        if prediction == example["summary"]:
            correct += 1

    return correct / total

# Run the evaluation
# NOTE(review): `model` is a PolyLlamaModel at this point, which defines only
# __init__ and forward — confirm it exposes the generate() that
# evaluate_model calls.
accuracy = evaluate_model(model=model, tokenizer=tokenizer, dataset=examples)
print(f"Accuracy: {accuracy * 100:.2f}%")

## bash 파일에 넣어야 할 것들 (학습된 모델을 Hugging Face Hub에 업로드하기 위함)
- huggingface-cli login
- transformers-cli repo create my-awesome-model
- transformers-cli repo upload --repo my-awesome-model --path ./results