## Install 

In [64]:
# Install Hugging Face Transformers, which supplies the AutoModel*/AutoTokenizer classes imported below.
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [65]:
# Install Hugging Face Datasets, which supplies the Dataset class imported below.
!pip install datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [66]:
# Install (or upgrade, -U) the Weights & Biases client quietly (-q); it is used below to download model artifacts.
!pip install wandb -qU

In [67]:
import torch
import torch.optim as optim

import numpy as np
from tqdm import tqdm as tqdm
from datasets import Dataset

from transformers import(
    AutoModel,
    AutoModelForSequenceClassification,
    AutoModelForCausalLM,
    AutoTokenizer,
    DefaultDataCollator
)

In [68]:
import wandb
# Authenticate this session with Weights & Biases; later cells call wandb.init() and run.use_artifact().
wandb.login()



True

## Load Model

### KcELECTRA masker

In [69]:
# Hub checkpoint used for both the base classifier architecture and its tokenizer.
model_name = 'beomi/KcELECTRA-base-v2022'
# Weights & Biases artifact reference passed to run.use_artifact() to fetch the fine-tuned weights.
model_version = 'groom2team/pj3_classifier_data_add/pytorch_finetuned:v2'
# Padding/truncation length handed to the classifier tokenizer in tokenizeWithoutLabel.
max_length = 128

In [70]:
import wandb
# Fetch the fine-tuned classifier checkpoint from Weights & Biases.
run = wandb.init()
artifact1 = run.use_artifact(model_version, type='model')
artifact_dir1 = artifact1.download()

# Build the base architecture first, then overwrite its weights with the
# fine-tuned state dict from the downloaded artifact.
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# map_location='cpu' makes loading independent of the device the checkpoint was
# saved from (a GPU-saved file would otherwise fail on a CPU-only machine).
# NOTE(review): torch.load unpickles the file -- only load artifacts you trust.
model.load_state_dict(
    torch.load(artifact_dir1+'/best_model_at_end/pytorch_model.bin', map_location='cpu')
)

tokenizer = AutoTokenizer.from_pretrained(model_name)

[34m[1mwandb[0m: Downloading large artifact pytorch_finetuned:v2, 489.19MB. 7 files... 
[34m[1mwandb[0m:   7 of 7 files downloaded.  
Done. 0:0:0.1
Some weights of the model checkpoint at beomi/KcELECTRA-base-v2022 were not used when initializing ElectraForSequenceClassification: ['discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.bias']
- This IS expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ElectraForS

In [71]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing at model.to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

### kogpt2 generator

In [72]:
# Hub checkpoint used for both the base causal-LM architecture and its tokenizer.
eval_model_name = 'skt/kogpt2-base-v2'
# Weights & Biases artifact reference passed to run.use_artifact() to fetch the fine-tuned generator weights.
eval_model_version = 'groom2team/pj3_gen_gpt2/pytorch_finetuned:v0'

In [73]:
# Fetch the fine-tuned generator checkpoint from Weights & Biases.
run = wandb.init()
artifact2 = run.use_artifact(eval_model_version, type='model')
artifact_dir2 = artifact2.download()

# Build the base causal-LM architecture, then overwrite its weights with the
# fine-tuned state dict from the downloaded artifact.
eval_model = AutoModelForCausalLM.from_pretrained(eval_model_name)
# map_location='cpu' makes loading independent of the device the checkpoint was
# saved from (a GPU-saved file would otherwise fail on a CPU-only machine).
# NOTE(review): torch.load unpickles the file -- only load artifacts you trust.
eval_model.load_state_dict(
    torch.load(artifact_dir2+'/best_model_at_end/pytorch_model.bin', map_location='cpu')
)

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

[34m[1mwandb[0m: Downloading large artifact pytorch_finetuned:v0, 492.98MB. 8 files... 
[34m[1mwandb[0m:   8 of 8 files downloaded.  
Done. 0:0:0.1


<All keys matched successfully>

In [74]:
# Tokenizer for the generator checkpoint, with every special token declared explicitly.
gpttokenizer = AutoTokenizer.from_pretrained(
    eval_model_name,
    bos_token='</s>',
    eos_token='</s>',
    unk_token='<unk>',
    pad_token='<pad>',
    mask_token='<mask>',
)

## masking method

In [81]:
def tokenizeWithoutLabel(data):
    """Tokenize candidate sentences for the classifier (no labels attached).

    Uses the module-level ``tokenizer`` and ``max_length``.

    Args:
        data: Mapping with a ``'texts'`` entry holding the text(s) to encode;
            this is the batch dict supplied by ``Dataset.map(batched=True)``.

    Returns:
        The tokenizer output, padded to ``max_length`` and truncated to it.
    """
    tokenized_datas = tokenizer(
        data['texts'],
        max_length=max_length,
        padding="max_length",
        # Only a single sequence is passed, so "only_second" (which targets the
        # second text of a pair) had nothing to shorten and over-long inputs
        # were not cut to max_length. truncation=True trims the one sequence
        # that is present.
        truncation=True
    )
    return tokenized_datas

In [82]:
def del_list_index(tokens, indexs):
    """Return a copy of ``tokens`` with the positions listed in ``indexs`` removed.

    The input list is never modified. Positions whose element equals the string
    ``'[UNK]'`` are kept rather than deleted.

    Args:
        tokens: List of tokens to copy and prune.
        indexs: Iterable of integer positions to remove, in any order;
            duplicates are ignored.

    Returns:
        A new list without the requested positions.
    """
    # NOTE(review): the '[UNK]' guard compares against a string, so it can only
    # match when ``tokens`` holds token strings. delete_style_token passes the
    # result of tokenizer.encode(), presumably integer ids -- confirm whether
    # the id of the unknown token should be protected instead.
    out = tokens[:]
    # Delete from the highest position downwards so an earlier deletion never
    # shifts a position that is still pending. Sorting and deduplicating makes
    # this hold for any input order, not only for ascending tuples.
    for index in sorted(set(indexs), reverse=True):
        if out[index] == '[UNK]':
            continue
        del out[index]
    return out

In [83]:
from datasets.utils import disable_progress_bar
# Turn off the datasets progress bars; delete_style_token calls Dataset.map() once per input sentence.
disable_progress_bar()

In [84]:
from itertools import combinations
def delete_style_token(text, batch_size, max_delete=3, min_prob=0.7, min_diff=0.2):
    """Find the token deletion that most lowers the classifier's class-1 score.

    Every way of removing 1..``max_delete`` tokens from ``text`` is decoded back
    to a string and scored with the module-level ``model``. The candidate whose
    class-1 probability falls furthest below that of the untouched text wins.

    Args:
        text: Sentence to analyse.
        batch_size: Number of candidate sentences scored per forward pass.
        max_delete: Largest number of tokens removed in a single candidate.
        min_prob: Class-1 probability the untouched text must reach; below it
            no candidate is evaluated.
        min_diff: Probability drop the best candidate must reach to be returned.

    Returns:
        ``{'masked': <best candidate text>, 'original': text}``, or ``None``
        when the original text scores below ``min_prob`` or no candidate lowers
        the score by at least ``min_diff``.
    """
    torch.cuda.empty_cache()
    # Candidate 0 is the untouched text; its score is the baseline.
    texts = [{'ids': [0], 'texts': text}]
    tokens = tokenizer.encode(text)
    # Skip the first and last encoded positions (presumably the special tokens
    # the tokenizer adds around the sentence -- they are also stripped with
    # [1:-1] before decoding).
    token_indexs = range(1, len(tokens) - 1)
    # The number of candidates grows combinatorially with the sentence length.
    # combinations() yields nothing when n exceeds the available positions, so
    # no separate length check is needed.
    for n in range(1, max_delete + 1):
        for indexs in combinations(token_indexs, n):
            texts.append({'ids' : indexs, 'texts' : tokenizer.decode(del_list_index(tokens, indexs)[1:-1])})
    line_data = Dataset.from_list(texts)
    line_tokenized_datasets = line_data.map(tokenizeWithoutLabel, batched=True, remove_columns=line_data.column_names, )
    data_collator = DefaultDataCollator(return_tensors="pt")
    # shuffle=False keeps loader order aligned with the ``texts`` list.
    line_loader = torch.utils.data.DataLoader(line_tokenized_datasets, batch_size=batch_size,
                                            shuffle=False, collate_fn=data_collator,
                                            num_workers=0)
    max_diff = 0.0
    best_text = None
    vanil_prob = None
    offset = 0  # index in ``texts`` of the first row of the current batch
    # Pure inference: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        for id_inputs in line_loader:
            inputs = {
                key: id_inputs[key].to(device)
                for key in ('input_ids', 'token_type_ids', 'attention_mask')
            }
            logits = model(**inputs).logits
            # One transfer per batch: class-1 probabilities as plain floats.
            class1_probs = logits.softmax(dim=-1)[:, 1].tolist()
            for i, prob in enumerate(class1_probs):
                position = offset + i
                if position == 0:
                    # Baseline row: give up early when the original text does
                    # not score high enough to be worth editing.
                    if prob < min_prob:
                        return None
                    vanil_prob = prob
                    continue
                diff = vanil_prob - prob
                if diff > max_diff:
                    max_diff = diff
                    best_text = texts[position]['texts']
            offset += len(class1_probs)
    if best_text is not None and max_diff >= min_diff:
        return {'masked': best_text, 'original': text}
    return None
            

## generate method

In [85]:
def tokenizeMasked(data, batch_size=32):
    """Mask the style tokens of a sentence and encode the generator prompt.

    Args:
        data: Raw input sentence.
        batch_size: Batch size forwarded to ``delete_style_token``.

    Returns:
        The ``gpttokenizer`` encoding (PyTorch tensors) of the prompt
        ``"<unused0> <unused1> {masked text} <unused2>"``.

    Raises:
        ValueError: If ``delete_style_token`` finds nothing to mask and
            returns ``None``.
    """
    text = delete_style_token(data, batch_size)
    if text is None:
        # Previously this surfaced as an opaque
        # "'NoneType' object is not subscriptable" on text['masked'].
        raise ValueError(
            "No style tokens could be masked for the given text "
            "(delete_style_token returned None)."
        )
    tokenized_datas = gpttokenizer(
        f"<unused0> <unused1> {text['masked']} <unused2>",
        return_tensors="pt"
    )
    return tokenized_datas

In [86]:
import re
def make_moral_text(immoral_text):
    """Rewrite a sentence with the generator model.

    The sentence is masked and turned into a prompt by ``tokenizeMasked``,
    ``eval_model`` generates a continuation, and the text between the
    ``...2> `` marker and the next `` <u...`` token is extracted from the
    decoded output.

    Args:
        immoral_text: Input sentence to rewrite.

    Returns:
        The extracted generated text.

    Raises:
        ValueError: If the decoded output does not contain the expected
            marker pattern.
    """
    # Named model_inputs rather than ``input`` to avoid shadowing the builtin.
    model_inputs = tokenizeMasked(immoral_text)
    gen_ids = eval_model.generate(**model_inputs,
                            max_length=256,
                            pad_token_id=gpttokenizer.pad_token_id,
                            eos_token_id=gpttokenizer.eos_token_id,
                            bos_token_id=gpttokenizer.bos_token_id)
    output = gpttokenizer.decode(gen_ids[0])
    # Raw string: same pattern as before, without the invalid "\>" / "\<" / "\s"
    # escape sequences that a normal string literal warns about.
    pred = re.search(r'2>\s(.+?)\s<u', output)
    if pred is None:
        # Previously this crashed with AttributeError on pred.group(1).
        raise ValueError(f"Could not locate the generated text in model output: {output!r}")
    return pred.group(1)

# Core Method

In [87]:
# Run the full mask-then-generate pipeline on one sample sentence.
print(make_moral_text('악플다는 찐따들 불쌍하다ㅋㅋㅋ'))

악플이 너무 많다
