In [None]:
# Install the third-party packages this notebook needs into the kernel's environment.
# NOTE(review): versions are not pinned, so a fresh run may resolve different releases
# than the ones that produced the recorded outputs — consider pinning.
%pip install transformers sentencepiece datasets asian-bart wandb

In [None]:
import os

# Mount Google Drive once: skip the mount when a "drive" entry is already
# present under /content (e.g. when this cell is re-run in the same session).
drive_already_present = "drive" in os.listdir("/content")
if not drive_already_present :
    from google.colab import drive
    drive.mount('/content/drive')

# All relative paths used later in the notebook are resolved from this folder.
os.chdir("/content/drive/MyDrive/NLP_Project_3")

In [None]:
# Authenticate the Weights & Biases CLI (runs as a shell command).
!wandb login

# Name of this fine-tuning run; passed to wandb.init below as the run name.
fine_tuned_model_name = "tagged_back_translation_kor2eng"

import wandb
# Start the W&B run under the given project/entity so later wandb.log calls attach to it.
wandb.init(project = "Goorm_3rd_project", entity = "2nd_group", name = fine_tuned_model_name)

In [None]:
import pandas as pd
import numpy as np
import torch
import json
import datasets
import random

from transformers import MBartForConditionalGeneration, MBartTokenizer, DataCollatorForSeq2Seq, AutoTokenizer, get_cosine_schedule_with_warmup
from sklearn.model_selection import train_test_split
from collections import defaultdict, Counter, deque
from tqdm import tqdm
from torchtext.data.metrics import bleu_score

from asian_bart import AsianBartTokenizer, AsianBartForConditionalGeneration
from transformers.models.bart.modeling_bart import shift_tokens_right

SEED = 20220819
BACKBONE = "hyunwoongko/asian-bart-ecjk"

# Reproducibility: ask cuDNN for deterministic behaviour and seed every RNG
# the notebook touches (torch CPU/CUDA, numpy, and Python's `random`).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Use the first GPU when one is available, otherwise fall back to CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if "cuda" in DEVICE.type :
    torch.cuda.set_device(DEVICE)
print(DEVICE)

# Start from the previously fine-tuned kor->eng checkpoint stored on disk.
model = AsianBartForConditionalGeneration.from_pretrained("../Model/large_batch_kor2eng")
model.train()
model = model.to(DEVICE)

# The tokenizer comes from the backbone repository, configured for Korean -> English.
tokenizer = AutoTokenizer.from_pretrained(BACKBONE, src_lang="ko_KR", tgt_lang="en_XX")

# Register the "<tag>" marker that is prepended to back-translated sources, and
# grow the embedding matrix so the new token id has a row.
tokenizer.add_tokens("<tag>")
model.resize_token_embeddings(len(tokenizer))

# Count trainable parameters only AFTER the embedding resize so the reported
# number matches the model that is actually trained (it previously excluded
# the embedding row added for "<tag>").
params = sum(np.prod(p.size()) for p in model.parameters() if p.requires_grad)
print("# of params in model :", params)

In [None]:
# `utils` is used here (load_parallel) and by the training cell (EarlyStopping)
# but is never imported in this notebook, so a fresh kernel raised NameError.
# NOTE(review): presumably a project-local utils.py importable from the working
# directory — confirm its location.
import utils

# Load the original parallel corpus as a DataFrame-like object (it is later
# sampled and concatenated with pandas).
# NOTE(review): other cells read from "../RawData" and "../Model"; confirm that
# "./RawData" is the intended directory here.
original_train = utils.load_parallel("./RawData")

In [None]:
# Back-translated (eng -> kor) sentences: the generated Korean text lives in the
# "inferenced" column and becomes the tagged source side of new training pairs.
translation = pd.read_csv("../RawData/generated_kor_data_from_eng.csv")
translate_data = (
    translation
    .rename(columns = {"inferenced" : "ko"})
    .assign(
        type = "new_word",
        # Prefix every synthetic source sentence with the "<tag>" marker.
        ko = lambda frame : "<tag>" + frame["ko"],
        domain = "null",
    )
)

# A subsample of the original corpus with as many rows as the synthetic data.
# NOTE(review): `sampled_train` is not used below — `total_train` is built from the
# full `original_train`. Confirm whether a 1:1 mix with `sampled_train` was intended.
n_synthetic = len(translate_data)
sampled_train = (
    original_train
    .sample(n = n_synthetic, random_state = SEED, replace = False)
    .reset_index(drop = True)
)

# Final training pool: all original pairs followed by the tagged synthetic pairs.
total_train = pd.concat([original_train, translate_data], ignore_index = True)

In [None]:
def tokenizing(inputs, tokenizer, training, src = "ko", tgt = "en"):
    """Tokenize one batch of parallel sentences.

    Args:
        inputs: Mapping from column name to a list of sentences (a batch as
            passed by `datasets.Dataset.map(..., batched=True)`).
        tokenizer: Callable tokenizer that also provides `as_target_tokenizer()`.
        training: When truthy, the target column is tokenized as well and its
            `input_ids` are stored under the "labels" key.
        src: Name of the source-text column (default "ko").
        tgt: Name of the target-text column (default "en"); only read when
            `training` is truthy.

    Returns:
        The tokenizer output for the source column, with "labels" added when
        `training` is truthy.
    """
    model_inputs = tokenizer(inputs[src])
    if training :
        # Targets are encoded inside the tokenizer's target-side context.
        with tokenizer.as_target_tokenizer() :
            model_inputs["labels"] = tokenizer(inputs[tgt])["input_ids"]
    return model_inputs

def get_dataset(inputs, tokenizer, collator, batch_size, training, src = "ko", tgt = "en") :
    """Turn a pandas DataFrame of sentence pairs into a torch DataLoader.

    Args:
        inputs: DataFrame holding at least the `src` column (and `tgt` when
            `training` is truthy).
        tokenizer: Tokenizer forwarded to `tokenizing`.
        collator: Collate function used by the DataLoader to build batches.
        batch_size: Number of examples per batch.
        training: When truthy, labels are produced and the loader shuffles.
        src: Source-text column name (default "ko").
        tgt: Target-text column name (default "en").

    Returns:
        A `torch.utils.data.DataLoader` over the tokenized dataset.

    The `src`/`tgt` parameters are optional so existing five-argument calls
    keep working, while calls that pass the column names explicitly (as the
    dataset-building cell does) no longer raise TypeError.
    """
    inputs = datasets.Dataset.from_pandas(inputs)
    tokenized_inputs = inputs.map(tokenizing,
                                  batched = True,
                                  fn_kwargs = {"training" : training,
                                               "tokenizer" : tokenizer,
                                               "src" : src,
                                               "tgt" : tgt})

    # Expose only the tensors the model consumes (plus labels when training).
    if training :
        columns = tokenizer.model_input_names + ["labels"]
    else :
        columns = tokenizer.model_input_names

    tokenized_inputs.set_format("torch", columns = columns)
    train_dataset = torch.utils.data.DataLoader(tokenized_inputs,
                                                batch_size = batch_size,
                                                shuffle = training,
                                                collate_fn = collator)
    return train_dataset

In [None]:
batch_size = 16
# Collator that builds padded seq2seq batches as PyTorch tensors.
collator = DataCollatorForSeq2Seq(tokenizer = tokenizer, model = model, return_tensors = "pt")

# Hold out 30% of the combined (original + tagged synthetic) data for validation.
train_pd, valid_pd = train_test_split(total_train, random_state = SEED, test_size = .3)

# Column names are not passed: the tokenizing step already reads "ko" as the
# source and "en" as the target. Passing two extra positional arguments to the
# five-parameter get_dataset raised TypeError.
# training=True is also used for validation so that labels are produced and the
# model returns a loss; note that this flag also makes the validation loader shuffle.
train_data = get_dataset(train_pd, tokenizer, collator, batch_size, True)
valid_data = get_dataset(valid_pd, tokenizer, collator, batch_size * 2, True)

  0%|          | 0/610 [00:00<?, ?ba/s]

  0%|          | 0/262 [00:00<?, ?ba/s]

In [None]:
# Optimisation hyper-parameters for this fine-tuning run.
epochs = 3
learning_rate = 1e-6

# AdamW over every model parameter, with decoupled weight decay.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr = learning_rate,
    weight_decay = 0.02,
    eps = 1e-6,
)

In [None]:
# Record the run's hyper-parameters in the W&B run config.
wandb_config = dict(
    learning_rate = learning_rate,
    batch_size = batch_size,
    backbone = BACKBONE,
    epochs = epochs,
)
wandb.config.update(wandb_config)

In [None]:
# Mixed-precision fine-tuning loop with periodic validation and early stopping.
scaler = torch.cuda.amp.GradScaler()
wandb.watch(model, log = "all", log_freq = 500)

# Validation runs every `valid_check_period` optimizer steps (counted across epochs).
valid_check_period = 5000
early_stopping = utils.EarlyStopping(path = "../Model/tagged_back_translation_kor2eng_checkpoint", patience = 1, verbose = True)
halt = False

step = 0
for epoch in range(epochs) :
    cum_loss = deque(maxlen = 20)   # rolling window shown on the progress bar / logged to W&B
    curr_loss = []                  # every batch loss of this epoch, for the epoch summary

    with tqdm(train_data, unit = " batch") as tepoch :
        # The description is constant within an epoch, so set it once.
        tepoch.set_description(f"Train Epoch {epoch}")
        model.train()

        for batch in tepoch :
            step += 1
            optimizer.zero_grad()

            batch = {k : v.to(DEVICE) for k, v in batch.items()}

            with torch.cuda.amp.autocast() :
                outputs = model(**batch)
                loss = outputs["loss"]

            # Scale the loss for the backward pass; the scaler then steps the
            # optimizer and updates its own scale factor.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Read the scalar once (each .item() call synchronises with the device).
            loss_value = loss.item()
            cum_loss.append(loss_value)
            curr_loss.append(loss_value)

            # Drop references to this step's tensors before the next batch is built.
            del batch, outputs, loss

            rolling_loss = sum(cum_loss) / len(cum_loss)
            tepoch.set_postfix(loss = rolling_loss)

            # Read the learning rate straight from the optimizer's param groups
            # instead of rebuilding the whole optimizer state dict on every step.
            wandb.log({"train_loss" : rolling_loss,
                       "lr" : optimizer.param_groups[0]['lr'],
                       "train_step" : step})

            if not step % valid_check_period :
                model.eval()
                val_losses = []
                with torch.no_grad() :
                    for val_batch in valid_data :
                        val_batch = {k : v.to(DEVICE) for k, v in val_batch.items()}
                        with torch.cuda.amp.autocast() :
                            val_outputs = model(**val_batch)
                            val_loss = val_outputs["loss"]
                        val_losses.append(val_loss.item())

                        del val_batch, val_outputs, val_loss

                # Mean of the per-batch validation losses.
                valid_loss = sum(val_losses) / len(val_losses)
                wandb.log({"valid_loss" : valid_loss,
                           "valid_step" : step // valid_check_period})

                # NOTE(review): EarlyStopping is defined in utils; judging by the
                # recorded output it saves the model when the loss improves — confirm.
                early_stopping(valid_loss, model)

                if early_stopping.early_stop:
                    print("Early stopping")
                    halt = True
                    break
                else :
                    model.train()

        print("Train loss : ", sum(curr_loss) / len(curr_loss))

    # Propagate an early stop out of the epoch loop as well.
    if halt :
        break

Train Epoch 0:  13%|█▎        | 4999/38077 [18:16<1:56:57,  4.71 batch/s, loss=1.2] 

Validation loss decreased (inf --> 1.174058).  Saving model ...


Train Epoch 0:  26%|██▋       | 9999/38077 [45:28<1:40:35,  4.65 batch/s, loss=1.26]

Validation loss decreased (1.174058 --> 1.143443).  Saving model ...


Train Epoch 0:  39%|███▉      | 14999/38077 [1:11:48<1:18:04,  4.93 batch/s, loss=1.21]

Validation loss decreased (1.143443 --> 1.125325).  Saving model ...


Train Epoch 0:  53%|█████▎    | 19999/38077 [1:38:46<1:01:45,  4.88 batch/s, loss=1.1]

Validation loss decreased (1.125325 --> 1.112009).  Saving model ...


Train Epoch 0:  66%|██████▌   | 24999/38077 [2:05:45<46:55,  4.65 batch/s, loss=1.21]

Validation loss decreased (1.112009 --> 1.102763).  Saving model ...


Train Epoch 0:  79%|███████▉  | 29999/38077 [2:32:39<27:49,  4.84 batch/s, loss=1.08]

Validation loss decreased (1.102763 --> 1.094765).  Saving model ...


Train Epoch 0:  92%|█████████▏| 34999/38077 [2:59:31<10:43,  4.78 batch/s, loss=1.25]

Validation loss decreased (1.094765 --> 1.087760).  Saving model ...


Train Epoch 0: 100%|██████████| 38077/38077 [3:19:06<00:00,  3.19 batch/s, loss=1.14]


Train loss :  1.190682259917926


Train Epoch 1:   5%|▌         | 1922/38077 [06:48<1:59:48,  5.03 batch/s, loss=1.21]

Validation loss decreased (1.087760 --> 1.081810).  Saving model ...


Train Epoch 1:  18%|█▊        | 6922/38077 [33:10<1:44:37,  4.96 batch/s, loss=1.1] 

Validation loss decreased (1.081810 --> 1.076784).  Saving model ...


Train Epoch 1:  31%|███▏      | 11922/38077 [59:27<1:29:04,  4.89 batch/s, loss=1.09]

Validation loss decreased (1.076784 --> 1.072968).  Saving model ...


Train Epoch 1:  44%|████▍     | 16922/38077 [1:25:18<1:11:33,  4.93 batch/s, loss=1.14]

Validation loss decreased (1.072968 --> 1.069528).  Saving model ...


Train Epoch 1:  58%|█████▊    | 21922/38077 [1:51:27<55:46,  4.83 batch/s, loss=1.15]

Validation loss decreased (1.069528 --> 1.065706).  Saving model ...


Train Epoch 1:  71%|███████   | 26922/38077 [2:17:51<39:01,  4.76 batch/s, loss=1.22]

Validation loss decreased (1.065706 --> 1.062057).  Saving model ...


Train Epoch 1:  84%|████████▍ | 31922/38077 [2:44:15<21:20,  4.81 batch/s, loss=1.08]

Validation loss decreased (1.062057 --> 1.058424).  Saving model ...


Train Epoch 1:  97%|█████████▋| 36922/38077 [3:10:49<03:53,  4.94 batch/s, loss=1.22]

Validation loss decreased (1.058424 --> 1.056544).  Saving model ...


Train Epoch 1: 100%|██████████| 38077/38077 [3:23:47<00:00,  3.11 batch/s, loss=1.18]


Train loss :  1.1113039284899922


Train Epoch 2:  10%|█         | 3845/38077 [13:38<1:55:05,  4.96 batch/s, loss=1.24]

Validation loss decreased (1.056544 --> 1.054559).  Saving model ...


Train Epoch 2:  23%|██▎       | 8845/38077 [40:25<1:36:14,  5.06 batch/s, loss=1.16]

Validation loss decreased (1.054559 --> 1.053606).  Saving model ...


Train Epoch 2:  36%|███▋      | 13845/38077 [1:07:11<1:24:25,  4.78 batch/s, loss=1.1] 

Validation loss decreased (1.053606 --> 1.050312).  Saving model ...


Train Epoch 2:  49%|████▉     | 18845/38077 [1:34:06<1:05:49,  4.87 batch/s, loss=1.1] 

Validation loss decreased (1.050312 --> 1.050068).  Saving model ...


Train Epoch 2:  63%|██████▎   | 23845/38077 [2:01:04<50:31,  4.69 batch/s, loss=0.983]

Validation loss decreased (1.050068 --> 1.046355).  Saving model ...


Train Epoch 2:  76%|███████▌  | 28845/38077 [2:27:54<31:00,  4.96 batch/s, loss=1.27]

Validation loss decreased (1.046355 --> 1.045153).  Saving model ...


Train Epoch 2:  89%|████████▉ | 33845/38077 [2:54:50<14:55,  4.73 batch/s, loss=1.08]

Validation loss decreased (1.045153 --> 1.043925).  Saving model ...


Train Epoch 2: 100%|██████████| 38077/38077 [3:19:08<00:00,  3.19 batch/s, loss=0.994]

Train loss :  1.0736625434575775





In [None]:
# Persist the weights as they are at the end of training.
# NOTE(review): this is the final-step model, which is not necessarily the
# checkpoint kept by early stopping — confirm which one should be shipped.
model.save_pretrained("../Model/tagged_bt_kor2eng")
# The tokenizer was extended with "<tag>" and the model's embeddings were resized
# to match, so store the tokenizer next to the weights; otherwise the saved
# embedding size cannot be matched on reload without repeating this notebook.
tokenizer.save_pretrained("../Model/tagged_bt_kor2eng")