In [1]:
import os
import random
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import logging
from sklearn.metrics import accuracy_score

from transformers import TrainingArguments, Trainer, TrainerCallback, DefaultDataCollator
from transformers.trainer_pt_utils import _get_learning_rate
from transformers import AutoConfig, AutoTokenizer, AutoModel
import logging

logging.disable(logging.INFO) # disable INFO and DEBUG logging everywhere
logging.disable(logging.WARNING) # disable WARNING, INFO and DEBUG logging everywhere

os.environ["TOKENIZERS_PARALLELISM"] = "false"

## Seed 고정

In [2]:
def seed_everything(seed:int = 1004):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # current gpu seed
    torch.cuda.manual_seed_all(seed) # All gpu seed
    torch.backends.cudnn.deterministic = True  
    torch.backends.cudnn.benchmark = False  # True로 하면 gpu에 적합한 알고리즘을 선택함.

seed_everything(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


### Config & Tokenizer

In [3]:
KOELECTRA = 'monologg/koelectra-base-v3-discriminator'
KOBIGBIRD = 'monologg/kobigbird-bert-base'
KOBERT = 'monologg/kobert'
KOROBERTA = 'klue/roberta-base'
KOROBERTA_checkpoint = '/mnt/HDD8TB/PersonalityAI/koRoBERTa_base-checkpoint16000'

# Hidden_size = 768

In [5]:
electra_tokenizer = AutoTokenizer.from_pretrained(KOELECTRA)
bigbird_tokenizer = AutoTokenizer.from_pretrained(KOBIGBIRD)
klue_tokenizer = AutoTokenizer.from_pretrained(KOROBERTA)
bert_tokenizer = AutoTokenizer.from_pretrained(KOBERT)

The repository for monologg/kobert contains custom code which must be executed to correctly load the model. You can inspect the repository content at https://hf.co/monologg/kobert.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N]  y


### Dataset

In [7]:
class CLIPDataset(torch.utils.data.Dataset):
    def __init__(self, data, label, electra_tokenizer, bigbird_tokenizer, klue_tokenizer, bert_tokenizer):
        self.data = data
        self.label = label
        self.electra_tokenizer = electra_tokenizer
        self.bigbird_tokenizer = bigbird_tokenizer
        self.klue_tokenizer = klue_tokenizer
        self.bert_tokenizer = bert_tokenizer

    def __getitem__(self, idx):
        text = self.data[idx]
        electra_tokens = self.electra_tokenizer(text, 
                              #  return_tensors="pt",  # pytorch.Tensor로 리턴
                                max_length=512, 
                                padding="max_length",  
                                truncation=True,  # max_length 넘어가면 버림)
                               )
        bigbird_tokens = self.bigbird_tokenizer(text, 
                              #  return_tensors="pt",  # pytorch.Tensor로 리턴
                                max_length=512, 
                                padding="max_length",  
                                truncation=True,  # max_length 넘어가면 버림)
                               )
        klue_tokens = self.klue_tokenizer(text, 
                              #  return_tensors="pt",  # pytorch.Tensor로 리턴
                                max_length=512, 
                                padding="max_length",  
                                truncation=True,  # max_length 넘어가면 버림)
                               )
        bert_tokens = self.bert_tokenizer(text, 
                              #  return_tensors="pt",  # pytorch.Tensor로 리턴
                                max_length=512, 
                                padding="max_length",  
                                truncation=True,  # max_length 넘어가면 버림)
                               )
        
        for key in list(bigbird_tokens.keys()):
            bigbird_tokens[key+"_bird"] = bigbird_tokens.pop(key)

        for key in list(klue_tokens.keys()):
            klue_tokens[key+"_klue"] = klue_tokens.pop(key)

        for key in list(bert_tokens.keys()):
            bert_tokens[key+"_bert"] = bert_tokens.pop(key)

        electra_tokens.update(bigbird_tokens)
        electra_tokens.update(klue_tokens)
        electra_tokens.update(bert_tokens)

        electra_tokens['label'] = [float(i) for i in self.label[idx][1:-1].split(',')]
        
     #   item = {key: torch.tensor(values) for key, values in electra_tokens.items()}

        return electra_tokens


    def __len__(self):
        return len(self.data)

In [8]:
import pandas as pd

keti_train = pd.read_csv("/mnt/HDD8TB/Data/PAI/KETI_DATASET/KETI_train_from_ver0.4.csv",
                 index_col=0
                 )
keti_val = pd.read_csv("/mnt/HDD8TB/Data/PAI/KETI_DATASET/KETI_val_from_ver0.4.csv",
                 index_col=0
                 )
keti_test = pd.read_csv("/mnt/HDD8TB/Data/PAI/KETI_DATASET/KETI_test_from_ver0.4.csv",
                 index_col=0
                 )
#data = data.sample(frac=1).reset_index(drop=True) 

In [9]:
keti_train['len'] = keti_train['transcription'].str.split().apply(len)
keti_val['len'] = keti_val['transcription'].str.split().apply(len)
keti_test['len'] = keti_test['transcription'].str.split().apply(len)

In [10]:
keti_test.len.describe()

count    631.000000
mean      12.036450
std       11.124131
min        1.000000
25%        6.000000
50%        8.000000
75%       14.000000
max       77.000000
Name: len, dtype: float64

In [11]:
train_dataset = CLIPDataset(data = keti_train['transcription'].to_list(),
                           label = keti_train['OCEAN'].to_list(),
                            electra_tokenizer = electra_tokenizer,
                            bigbird_tokenizer = bigbird_tokenizer,
                            klue_tokenizer = klue_tokenizer,
                            bert_tokenizer = bert_tokenizer
                           )

val_dataset = CLIPDataset(data = keti_val['transcription'].to_list(),
                           label = keti_val['OCEAN'].to_list(),
                          electra_tokenizer = electra_tokenizer,
                          bigbird_tokenizer = bigbird_tokenizer,
                          klue_tokenizer = klue_tokenizer,
                          bert_tokenizer = bert_tokenizer
                         )

test_dataset = CLIPDataset(data = keti_test['transcription'].to_list(),
                           label = keti_test['OCEAN'].to_list(),
                           electra_tokenizer = electra_tokenizer,
                           bigbird_tokenizer = bigbird_tokenizer,
                           klue_tokenizer = klue_tokenizer,
                           bert_tokenizer = bert_tokenizer
                          )

In [12]:
print(train_dataset.__len__())
print(train_dataset.__getitem__(970))

4431
{'input_ids': [2, 6258, 4219, 2279, 4116, 4116, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

### Model

In [11]:
model = AutoModel.from_pretrained(KOELECTRA)
model2 = AutoModel.from_pretrained("klue/bert-base")
model3 = AutoModel.from_pretrained(KOBIGBIRD)
model4 = AutoModel.from_pretrained(KOROBERTA_checkpoint)

print("koelectra size : ", model.num_parameters())
print("kobert size : ", model2.num_parameters())
print("kobigbird size : ", model3.num_parameters())
print("koroberta size : ", model4.num_parameters())

  return self.fget.__get__(instance, owner)()


koelectra size :  112330752
kobert size :  110617344
kobigbird size :  113753856
koroberta size :  110618112


In [12]:
class EnsembleModel(nn.Module):
    def __init__(self, num_labels, projection_dim=768, d_out=0.1, ):
        super().__init__()
        self.projection_dim = projection_dim
        self.num_labels = num_labels
        
        self.koelectra = AutoModel.from_pretrained(KOELECTRA).eval()
        self.kobigbird = AutoModel.from_pretrained(KOBIGBIRD).eval()
        self.kobert = AutoModel.from_pretrained(KOBERT).eval()
        self.koroberta = AutoModel.from_pretrained(KOROBERTA_checkpoint).eval()  # 1 x (768 *4) <행렬곱> (768*4) x 5 == 1 x 5
        
        self.dense = nn.Linear(self.kobert.config.hidden_size * 4, self.projection_dim)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(d_out)
        self.out_proj = nn.Linear(self.projection_dim, self.num_labels)

    def forward(
        self,
        input_ids = None, attention_mask = None, token_type_ids = None,
        input_ids_bird = None, attention_mask_bird = None, token_type_ids_bird = None,
        input_ids_klue = None, attention_mask_klue = None, token_type_ids_klue = None,
        input_ids_bert = None, attention_mask_bert = None, token_type_ids_bert = None,
        position_ids = None,
        head_mask = None,
        inputs_embeds = None,
        labels = None,
        output_attentions = None,
        output_hidden_states = None,
        return_dict = None,
    ):
        
        with torch.no_grad():
            electra_outputs = self.koelectra(
                input_ids = input_ids,
                attention_mask = attention_mask,
                token_type_ids = token_type_ids,
                output_attentions = output_attentions,
                output_hidden_states = output_hidden_states,
                return_dict=return_dict,
            )

            bigbird_outputs = self.kobigbird(
                input_ids = input_ids_bird,
                attention_mask = attention_mask_bird,
                token_type_ids = token_type_ids_bird,
                output_attentions = output_attentions,
                output_hidden_states = output_hidden_states,
                return_dict = return_dict,
            )

            bert_outputs = self.kobert(
                input_ids = input_ids_bert,
                attention_mask = attention_mask_bert,
                token_type_ids = token_type_ids_bert,
                output_attentions = output_attentions,
                output_hidden_states = output_hidden_states,
                return_dict = return_dict,
            )

            roberta_outputs = self.koroberta(
                input_ids = input_ids_klue,
                attention_mask = attention_mask_klue,
                token_type_ids = token_type_ids_klue,
                output_attentions = output_attentions,
                output_hidden_states = output_hidden_states,
                return_dict = return_dict,
            )
        
        # get pooled output([CLS] Token)
        electra_embeds = electra_outputs[0][:,0,:]
        bigbird_embeds = bigbird_outputs[0][:,0,:]
        bert_embeds = bert_outputs[0][:,0,:]
        roberta_embeds = roberta_outputs[0][:,0,:]

        concated_input = torch.concat([electra_embeds, bigbird_embeds, bert_embeds, roberta_embeds], dim=1)  # [Batch_size, hidden * 4 (3072)] 

     #   avg_input = (electra_embeds + bigbird_embeds + bert_embeds + roberta_embeds) / 4

        x = self.dense(concated_input)
        x = self.act(x)  # GELU
        x = self.dropout(x)
        logits = self.out_proj(x)  # [Batch_Size x 5]

        loss = None

        if labels != None:
            loss_fct = nn.L1Loss()
    
           # loss_fct = loss_fct(logits_O.view(-1, self.num_labels), labels[0])
            loss = loss_fct(logits.squeeze(), labels.squeeze())

        return loss, logits

In [13]:
model = EnsembleModel(num_labels = 5)

In [14]:
sum(p.numel() for p in model.parameters() if p.requires_grad)

431253509

In [15]:
from torchinfo import summary

summary(model)

Layer (type:depth-idx)                                            Param #
EnsembleModel                                                     --
├─ElectraModel: 1-1                                               --
│    └─ElectraEmbeddings: 2-1                                     --
│    │    └─Embedding: 3-1                                        26,880,000
│    │    └─Embedding: 3-2                                        393,216
│    │    └─Embedding: 3-3                                        1,536
│    │    └─LayerNorm: 3-4                                        1,536
│    │    └─Dropout: 3-5                                          --
│    └─ElectraEncoder: 2-2                                        --
│    │    └─ModuleList: 3-6                                       85,054,464
├─BigBirdModel: 1-2                                               --
│    └─BigBirdEmbeddings: 2-3                                     --
│    │    └─Embedding: 3-7                                        24,96

## Train

In [16]:
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import multilabel_confusion_matrix
from transformers import EvalPrediction
from sklearn.metrics import mean_absolute_error
import torch

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions
    
    MAE = 1 - mean_absolute_error(labels, preds)
    
    
    return {
        '1 - MAE' : MAE
    }

In [17]:
training_args = TrainingArguments(
    output_dir="/mnt/HDD8TB/PersonalityAI/FI_1head/keti_koEnsemble_model_usingrobertacheck+kobert_concat_nosafetensor2",
    logging_dir= "/mnt/HDD8TB/PersonalityAI/FI_1head/keti_koEnsemble_model_usingrobertacheck+kobert_concat_nosafetensor2_log",
    num_train_epochs=10,
    learning_rate = 3e-4,  # Best : 3e-4
   # max_steps=1000,
    per_device_train_batch_size=32,
#    gradient_accumulation_steps = 16,
    per_device_eval_batch_size = 32,
#    eval_accumulation_steps = 32,
    logging_strategy = "epoch",
    save_strategy = "epoch",
    lr_scheduler_type = "linear",
    dataloader_num_workers = 12,
#    warmup_ratio = 0.1,
#    weight_decay=0.01,
    evaluation_strategy='epoch',
    save_safetensors =False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=DefaultDataCollator(return_tensors = "pt"),
    compute_metrics=compute_metrics,
)



In [None]:
#trainer.add_callback(CustomCallback(trainer))
trainer.train()

### Load

In [18]:
test_ensemble = EnsembleModel(num_labels = 5)

state_dict = torch.load("/mnt/HDD8TB/PersonalityAI/FI_1head/keti_koEnsemble_model_usingrobertacheck+kobert_concat_nosafetensor2/checkpoint-70/pytorch_model.bin")

test_ensemble.load_state_dict(state_dict)

<All keys matched successfully>

In [19]:
jf_text1 = "한밤중에 깨어나면 베개에 바퀴벌레가 '많아' 있고 이불이 매우 불편합니다! \
집에 돌아와서 짐을 풀어보니 두 개를 가지고 왔다는 걸 발견했어요! 역겨운! ! !"

class testDataset(torch.utils.data.Dataset):
    def __init__(self, data, electra_tokenizer, bigbird_tokenizer, klue_tokenizer, bert_tokenizer, label=None):
        self.data = data
        self.label = label
        self.electra_tokenizer = electra_tokenizer
        self.bigbird_tokenizer = bigbird_tokenizer
        self.klue_tokenizer = klue_tokenizer
        self.bert_tokenizer = bert_tokenizer

    def __getitem__(self, idx):
        text = self.data[idx]
        electra_tokens = self.electra_tokenizer(text, 
                                return_tensors="pt",  # pytorch.Tensor로 리턴
                                max_length=512, 
                                padding="max_length",  
                                truncation=True,  # max_length 넘어가면 버림)
                               )
        bigbird_tokens = self.bigbird_tokenizer(text, 
                                return_tensors="pt",  # pytorch.Tensor로 리턴
                                max_length=512, 
                                padding="max_length",  
                                truncation=True,  # max_length 넘어가면 버림)
                               )
        klue_tokens = self.klue_tokenizer(text, 
                                return_tensors="pt",  # pytorch.Tensor로 리턴
                                max_length=512, 
                                padding="max_length",  
                                truncation=True,  # max_length 넘어가면 버림)
                               )
        bert_tokens = self.bert_tokenizer(text, 
                                return_tensors="pt",  # pytorch.Tensor로 리턴
                                max_length=512, 
                                padding="max_length",  
                                truncation=True,  # max_length 넘어가면 버림)
                               )
        
        for key in list(bigbird_tokens.keys()):
            bigbird_tokens[key+"_bird"] = bigbird_tokens.pop(key)

        for key in list(klue_tokens.keys()):
            klue_tokens[key+"_klue"] = klue_tokens.pop(key)

        for key in list(bert_tokens.keys()):
            bert_tokens[key+"_bert"] = bert_tokens.pop(key)

        electra_tokens.update(bigbird_tokens)
        electra_tokens.update(klue_tokens)
        electra_tokens.update(bert_tokens)
        
     #   item = {key: torch.tensor(values) for key, values in electra_tokens.items()}

        return electra_tokens


    def __len__(self):
        return len(self.data)

tokens = testDataset(data = jf_text1,
                            electra_tokenizer = electra_tokenizer,
                            bigbird_tokenizer = bigbird_tokenizer,
                            klue_tokenizer = klue_tokenizer,
                            bert_tokenizer = bert_tokenizer,
                           )
tokens = tokens.__getitem__(0)

_, result = test_ensemble(
        input_ids = tokens['input_ids'],
        attention_mask = tokens['attention_mask'],
        token_type_ids = tokens['token_type_ids'],
        input_ids_bird = tokens['input_ids_bird'],
        attention_mask_bird = tokens['attention_mask_bird'],
        token_type_ids_bird = tokens['token_type_ids_bird'],
        input_ids_klue = tokens['input_ids_klue'],
        attention_mask_klue = tokens['attention_mask_klue'],
        token_type_ids_klue = tokens['token_type_ids_klue'],
        input_ids_bert = tokens['input_ids_bert'],
        attention_mask_bert = tokens['attention_mask_bert'],
        token_type_ids_bert = tokens['token_type_ids_bert'],
            )

print(result)

tensor([[0.5553, 0.6893, 0.7340, 0.5556, 0.4327]], grad_fn=<AddmmBackward0>)


In [20]:
result[0][0].item()

0.5552600622177124

### Test

In [21]:
def test_compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions
    
    labels = labels.T
    preds = preds.T
    
    MAE = []
    
    for i in range(5):
        MAE.append(1 - mean_absolute_error(labels[i], preds[i]))
    
    
    return {
        'Avg' : np.mean(MAE),
        '1 - MAE' : MAE
    }

In [24]:
from safetensors.torch import load_model

test_model = EnsembleModel(num_labels = 5)

######################################### Check your model path Again#####################################
load_model(test_model, "/mnt/HDD8TB/PersonalityAI/FI_1head/keti_koEnsemble_model_usingrobertacheck+kobert_concat/checkpoint-70/model.safetensors")

(set(), [])

In [22]:
eval_trainer = Trainer(
    model=test_ensemble,
    args=training_args,
    data_collator=DefaultDataCollator(return_tensors = "pt"),
    compute_metrics=test_compute_metrics,
)

In [23]:
eval_trainer.evaluate(test_dataset)



{'eval_loss': 0.1510709971189499,
 'eval_Avg': 0.8489405632019043,
 'eval_1 - MAE': [0.8426289558410645,
  0.836871549487114,
  0.8657149076461792,
  0.8616010397672653,
  0.8378863632678986],
 'eval_runtime': 12.0291,
 'eval_samples_per_second': 52.456,
 'eval_steps_per_second': 0.831}

### Test another

In [2]:
KOELECTRA = 'monologg/koelectra-base-v3-discriminator'
KOBIGBIRD = 'monologg/kobigbird-bert-base'
KOBERT = 'monologg/kobert'
KOROBERTA = 'klue/roberta-base'
KOROBERTA_checkpoint = '/mnt/HDD8TB/PersonalityAI/koRoBERTa_base-checkpoint16000'


class CLIPDataset(torch.utils.data.Dataset):
    def __init__(self, data, electra_tokenizer, bigbird_tokenizer, klue_tokenizer, bert_tokenizer, label=None):
        self.data = data
        self.label = label
        self.electra_tokenizer = electra_tokenizer
        self.bigbird_tokenizer = bigbird_tokenizer
        self.klue_tokenizer = klue_tokenizer
        self.bert_tokenizer = bert_tokenizer

    def __getitem__(self, idx):
        text = self.data
        electra_tokens = self.electra_tokenizer(text, 
                                return_tensors="pt",  # pytorch.Tensor로 리턴
                                max_length=512, 
                                padding="max_length",  
                                truncation=True,  # max_length 넘어가면 버림)
                               )
        bigbird_tokens = self.bigbird_tokenizer(text, 
                                return_tensors="pt",  # pytorch.Tensor로 리턴
                                max_length=512, 
                                padding="max_length",  
                                truncation=True,  # max_length 넘어가면 버림)
                               )
        klue_tokens = self.klue_tokenizer(text, 
                                return_tensors="pt",  # pytorch.Tensor로 리턴
                                max_length=512, 
                                padding="max_length",  
                                truncation=True,  # max_length 넘어가면 버림)
                               )
        bert_tokens = self.bert_tokenizer(text, 
                                return_tensors="pt",  # pytorch.Tensor로 리턴
                                max_length=512, 
                                padding="max_length",  
                                truncation=True,  # max_length 넘어가면 버림)
                               )
        
        for key in list(bigbird_tokens.keys()):
            bigbird_tokens[key+"_bird"] = bigbird_tokens.pop(key)

        for key in list(klue_tokens.keys()):
            klue_tokens[key+"_klue"] = klue_tokens.pop(key)

        for key in list(bert_tokens.keys()):
            bert_tokens[key+"_bert"] = bert_tokens.pop(key)

        electra_tokens.update(bigbird_tokens)
        electra_tokens.update(klue_tokens)
        electra_tokens.update(bert_tokens)
        
     #   item = {key: torch.tensor(values) for key, values in electra_tokens.items()}

        return electra_tokens


    def __len__(self):
        return len(self.data)

class EnsembleModel(nn.Module):
    def __init__(self, num_labels, projection_dim=768, d_out=0.1, ):
        super().__init__()
        self.projection_dim = projection_dim
        self.num_labels = num_labels
        
        self.koelectra = AutoModel.from_pretrained(KOELECTRA).eval()
        self.kobigbird = AutoModel.from_pretrained(KOBIGBIRD).eval()
        self.kobert = AutoModel.from_pretrained(KOBERT).eval()
        self.koroberta = AutoModel.from_pretrained(KOROBERTA_checkpoint).eval()  # 1 x (768 *4) <행렬곱> (768*4) x 5 == 1 x 5
        
        self.dense = nn.Linear(self.kobert.config.hidden_size * 4, self.projection_dim)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(d_out)
        self.out_proj = nn.Linear(self.projection_dim, self.num_labels)

    def forward(
        self,
        input_ids = None, attention_mask = None, token_type_ids = None,
        input_ids_bird = None, attention_mask_bird = None, token_type_ids_bird = None,
        input_ids_klue = None, attention_mask_klue = None, token_type_ids_klue = None,
        input_ids_bert = None, attention_mask_bert = None, token_type_ids_bert = None,
        position_ids = None, position_ids_bird = None, position_ids_klue = None, position_ids_bert = None,
        head_mask = None,
        inputs_embeds = None,
        labels = None,
        output_attentions = None,
        output_hidden_states = None,
        return_dict = None,
    ):
        
        with torch.no_grad():
            electra_outputs = self.koelectra(
                input_ids = input_ids,
                attention_mask = attention_mask,
                token_type_ids = token_type_ids,
                output_attentions = output_attentions,
                output_hidden_states = output_hidden_states,
                return_dict=return_dict,
            )

            bigbird_outputs = self.kobigbird(
                input_ids = input_ids_bird,
                attention_mask = attention_mask_bird,
                token_type_ids = token_type_ids_bird,
                output_attentions = output_attentions,
                output_hidden_states = output_hidden_states,
                return_dict = return_dict,
            )

            bert_outputs = self.kobert(
                input_ids = input_ids_bert,
                attention_mask = attention_mask_bert,
                token_type_ids = token_type_ids_bert,
                output_attentions = output_attentions,
                output_hidden_states = output_hidden_states,
                return_dict = return_dict,
            )

            roberta_outputs = self.koroberta(
                input_ids = input_ids_klue,
                attention_mask = attention_mask_klue,
                token_type_ids = token_type_ids_klue,
                output_attentions = output_attentions,
                output_hidden_states = output_hidden_states,
                return_dict = return_dict,
            )

        print(electra_outputs[0].shape)
        
        # get pooled output([CLS] Token)
        electra_embeds = electra_outputs[0][:,0,:]
        bigbird_embeds = bigbird_outputs[0][:,0,:]
        bert_embeds = bert_outputs[0][:,0,:]
        roberta_embeds = roberta_outputs[0][:,0,:]

        print("elec shape: ", electra_embeds.shape)
        print("bigbird shape: ", bigbird_embeds.shape)
        print("bert shape: ", bert_embeds.shape)
        print("roberta shape: ", roberta_embeds.shape)

        concated_input = torch.cat([electra_embeds, bigbird_embeds, bert_embeds, roberta_embeds], dim=1)  # [Batch_size, hidden * 4 (3072)] 

     #   avg_input = (electra_embeds + bigbird_embeds + bert_embeds + roberta_embeds) / 4

        x = self.dense(concated_input)
        x = self.act(x)  # GELU
        x = self.dropout(x)
        logits = self.out_proj(x)  # [Batch_Size x 5]

        loss = None
        
        if labels != None:
            loss_fct = nn.L1Loss()

           # loss_fct = loss_fct(logits_O.view(-1, self.num_labels), labels[0])
            loss = loss_fct(logits.squeeze(), labels.squeeze())

        return loss, logits
    
electra_tokenizer = AutoTokenizer.from_pretrained(KOELECTRA)
bigbird_tokenizer = AutoTokenizer.from_pretrained(KOBIGBIRD)
klue_tokenizer = AutoTokenizer.from_pretrained(KOROBERTA)
bert_tokenizer = AutoTokenizer.from_pretrained(KOBERT, trust_remote_code=True)

test_ensemble = EnsembleModel(num_labels = 5)

state_dict = torch.load("/mnt/HDD8TB/PersonalityAI/FI_1head/keti_koEnsemble_model_usingrobertacheck+kobert_concat_nosafetensor/checkpoint-70/pytorch_model.bin")

test_ensemble.load_state_dict(state_dict, strict=False)

def run(text):
    jf_text1 = text
    
    tokens = CLIPDataset(data = jf_text1,
                        electra_tokenizer = electra_tokenizer,
                        bigbird_tokenizer = bigbird_tokenizer,
                        klue_tokenizer = klue_tokenizer,
                        bert_tokenizer = bert_tokenizer
                       )
    tokens = tokens.__getitem__(0)
    print(tokens)
    print(electra_tokenizer.decode(tokens['input_ids'][0]))
    
    test_ensemble.eval()
    
    _, result = test_ensemble(
            input_ids = tokens['input_ids'],
            attention_mask = tokens['attention_mask'],
            token_type_ids = tokens['token_type_ids'],
            input_ids_bird = tokens['input_ids_bird'],
            attention_mask_bird = tokens['attention_mask_bird'],
            token_type_ids_bird = tokens['token_type_ids_bird'],
            input_ids_klue = tokens['input_ids_klue'],
            attention_mask_klue = tokens['attention_mask_klue'],
            token_type_ids_klue = tokens['token_type_ids_klue'],
            input_ids_bert = tokens['input_ids_bert'],
            attention_mask_bert = tokens['attention_mask_bert'],
            token_type_ids_bert = tokens['token_type_ids_bert'],
                )
    
    print(result)

  return self.fget.__get__(instance, owner)()
