In [None]:
import torch
from torch.utils.data import DataLoader, RandomSampler, Dataset
from torch import nn

from typing import List
from transformers import ElectraModel, ElectraTokenizer, get_linear_schedule_with_warmup
from tqdm.auto import tqdm

import os
import glob
import pandas as pd
import numpy as np
import random

In [None]:
# Seed every random number generator used in this notebook (Python's
# `random`, NumPy and PyTorch) with one shared value so runs are repeatable.
seed_val = 42
for seed_rng in (random.seed, np.random.seed):
    seed_rng(seed_val)
torch.manual_seed(seed_val)

In [None]:
# The encoder weights and the tokenizer are loaded from the same pretrained
# identifier, so it is spelled out once and reused for both.
pretrained_name = "monologg/koelectra-base-v3-discriminator"
tokenizer = ElectraTokenizer.from_pretrained(pretrained_name)
bert_model = ElectraModel.from_pretrained(pretrained_name)

In [None]:
class SentimentDataset(Dataset):
    """Map-style dataset that tokenizes raw text lazily, one item at a time.

    Args:
        texts: Strings to encode; item ``i`` is built from ``texts[i]``.
        tokenizer: Callable invoked as ``tokenizer(text, padding="max_length",
            truncation=True, max_length=..., return_tensors="pt")``. Its result
            must support item access and assignment for the keys
            ``"input_ids"``, ``"attention_mask"`` and ``"token_type_ids"``.
        max_length: Length passed to the tokenizer for padding/truncation.
            Defaults to 512, the value that was previously hard-coded.
    """

    # Entries of the tokenizer output whose leading batch axis is removed.
    _TENSOR_KEYS = ("input_ids", "attention_mask", "token_type_ids")

    def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
        # Store everything __getitem__ needs; no tokenization happens here.
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __getitem__(self, index: int):
        """Tokenize ``texts[index]`` and return the encoder inputs.

        Returns:
            The object produced by the tokenizer, with the three tensor
            entries stripped of their leading size-1 axis.
        """
        text = self.texts[index]
        bert_inputs = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # The tokenizer output for a single text is expected to carry a leading
        # batch axis of size 1. Remove only that axis: a bare squeeze() would
        # also collapse the sequence axis whenever it happens to have length 1.
        for key in self._TENSOR_KEYS:
            bert_inputs[key] = bert_inputs[key].squeeze(0)

        return bert_inputs

    def __len__(self) -> int:
        """Number of texts in the dataset."""
        return len(self.texts)

In [None]:
class SentimentClassificationModel(nn.Module):
    """Encoder followed by one linear layer that produces 2 logits per input.

    Args:
        bert: Encoder module. It is called with ``input_ids``,
            ``attention_mask`` and ``token_type_ids`` keyword arguments, and
            element 0 of its output is indexed as ``[:, 0]`` (assumed to be a
            ``(batch, sequence, hidden)`` tensor — confirm for other encoders).
    """

    def __init__(self, bert):
        super().__init__()
        self.bert = bert
        # Size the head from the encoder's config when it exposes one; fall
        # back to 768, the width that was previously hard-coded.
        hidden_size = getattr(getattr(bert, "config", None), "hidden_size", 768)
        self.classification_layer = nn.Linear(hidden_size, 2)

    def forward(self, batch_data):
        """Return the classification logits for one batch.

        Args:
            batch_data: Mapping with ``"input_ids"``, ``"attention_mask"`` and
                ``"token_type_ids"`` entries.

        Returns:
            Output of the linear head applied to position 0 of the encoder's
            first output.
        """
        # 1. Run the encoder. Modules are invoked through __call__ rather than
        #    .forward() so that any registered hooks are executed.
        bert_output = self.bert(
            input_ids=batch_data["input_ids"],
            attention_mask=batch_data["attention_mask"],
            token_type_ids=batch_data["token_type_ids"],
        )
        first_position_state = bert_output[0][:, 0]

        # 2. Project the selected hidden state down to the 2 output logits.
        return self.classification_layer(first_position_state)

In [None]:
def predict_model(pred_dataloader):
    """Run the notebook-level ``model`` over a dataloader and collect its outputs.

    Relies on the module-level ``model`` and ``device`` variables, which must
    be defined before this function is called.

    Args:
        pred_dataloader: Iterable of batches; each batch is a mapping whose
            values expose ``.to(device)``.

    Returns:
        numpy.ndarray: The per-batch model outputs concatenated along axis 0,
        in the order the dataloader produced them.

    Raises:
        ValueError: Raised by ``np.concatenate`` if the dataloader yields no
            batches.
    """
    # eval() puts the model in inference mode (e.g. for dropout layers). It
    # does not freeze weights or disable gradients; torch.no_grad() below is
    # what turns gradient tracking off.
    model.eval()

    predictions = []

    for batch_data in tqdm(pred_dataloader):
        batch_data = {key: value.to(device) for key, value in batch_data.items()}
        with torch.no_grad():
            # Call the module itself rather than .forward() so that any
            # registered forward hooks are executed.
            classification_output = model(batch_data)

        logits = classification_output.detach().cpu().numpy()
        predictions.append(logits)

    predictions = np.concatenate(predictions, axis=0)

    return predictions

In [None]:
# Use the GPU when CUDA is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# NOTE(review): depending on the PyTorch version, torch.load may unpickle
# arbitrary objects, so only point this at a trusted checkpoint file.
checkpoint_state = torch.load('checkpoints/checkpoint_cat3.pt', map_location=device)

# Wrap the pretrained encoder, restore the saved weights, then move the
# whole model to the selected device.
model = SentimentClassificationModel(bert_model)
model.load_state_dict(checkpoint_state)
model.to(device)

In [None]:
# Batch size handed to the prediction DataLoader further down.
batch_size = 32

In [None]:
def split_into_windows(text, window=512, stride=256):
    """Cut ``text`` into overlapping character windows.

    Windows start every ``stride`` characters and are at most ``window``
    characters long. At least one window is always returned, so a text no
    longer than the overlap (including the empty string) still gets scored
    instead of silently producing no model input.
    """
    overlap = window - stride
    start_bound = max(len(text) - overlap, 1)
    return [text[start:start + window] for start in range(0, start_bound, stride)]


def mean_logit_labels(predictions, window_counts):
    """Average each article's window logits and return one arg-max label per article.

    ``predictions`` holds the window logits of all articles back to back;
    ``window_counts[k]`` says how many consecutive rows belong to article ``k``.
    """
    labels = []
    offset = 0
    for count in window_counts:
        mean_logit = predictions[offset:offset + count].mean(axis=0)
        labels.append(np.argmax(mean_logit, axis=-1))
        offset += count
    return labels


# glob returns matches in arbitrary order; sort them so that position i is
# stable across runs and machines.
# NOTE(review): assumes the sorted directory names line up with the order of
# `g` below — confirm against the actual folder names.
dir_name = sorted(glob.glob('../xlsx_data/#4_result/*'))
g = ["남성", "남자+-남성", "여성", "여자+-여성"]

for i, d in enumerate(dir_name):
    for f in tqdm(glob.glob(d+"/*.xlsx")):
        # Build the output location once. The file name keeps the last three
        # "_"-separated parts of the input *file name* (not of the whole path,
        # which could drag directory components into the name).
        out_dir = os.path.join(os.path.dirname(os.getcwd()), 'xlsx_data', '#4_result', '#4_'+g[i])
        out_path = os.path.join(out_dir, '#4_'+"_".join(os.path.basename(f).split("_")[-3:]))
        os.makedirs(out_dir, exist_ok=True)

        df = pd.read_excel(f, index_col=0)
        if len(df) == 0:
            # Nothing to score: write the empty sheet through unchanged.
            df.to_excel(out_path)
            continue

        # Fill missing cells *before* concatenating, so a missing title (or
        # text) no longer wipes out the other half of the article.
        df = df.fillna('')
        df['total'] = df['title'].astype(str).str.strip() + ".\n" + df['text'].astype(str).str.strip()

        # One list of overlapping windows per article, flattened for the
        # model, plus the per-article window counts needed to regroup them.
        total_texts = [split_into_windows(article) for article in df['total'].tolist()]
        all_texts = [text for paragraph in total_texts for text in paragraph]
        window_counts = [len(paragraph) for paragraph in total_texts]

        pred_dataset = SentimentDataset(all_texts, tokenizer)
        pred_dataloader = DataLoader(pred_dataset, batch_size=batch_size)
        predictions = predict_model(pred_dataloader)

        df['sentiment'] = mean_logit_labels(predictions, window_counts)
        df = df.drop(['total'], axis=1)
        df.to_excel(out_path)