In [None]:
# Install the notebook's third-party dependencies into the kernel environment (no versions are pinned).
%pip install transformers sentencepiece datasets asian-bart wandb

In [None]:
import os

# Mount Google Drive only when /content does not already list a "drive" entry,
# then make the project folder the working directory.
drive_already_listed = "drive" in os.listdir("/content")
if not drive_already_listed:
    from google.colab import drive
    drive.mount('/content/drive')

os.chdir("/content/drive/MyDrive/NLP_Project_3")

In [None]:
!wandb login

fine_tuned_model_name = "write_your_model_name"

import wandb
wandb.init(project = "Goorm_3rd_project", entity = "2nd_group", name = fine_tuned_model_name)

In [None]:
# --- Imports: standard library, third-party, then project-local ---
import json
import random
from collections import defaultdict, Counter, deque

import datasets
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from transformers import MBartForConditionalGeneration, MBartTokenizer, DataCollatorForSeq2Seq, AutoTokenizer, get_cosine_schedule_with_warmup
from transformers.models.bart.modeling_bart import shift_tokens_right

from asian_bart import AsianBartTokenizer, AsianBartForConditionalGeneration

# Later cells call utils.load_parallel, utils.get_dataset and utils.EarlyStopping,
# so the module has to be in scope for a fresh-kernel run.
# NOTE(review): assumes utils.py is importable from the working directory — confirm.
import utils

# --- Configuration ---
SEED = 20220819
BACKBONE = "hyunwoongko/asian-bart-ecjk"

# --- Reproducibility: deterministic cuDNN kernels and one seed for every RNG in use ---
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
random.seed(SEED)

# --- Device: first GPU when CUDA is available, otherwise CPU ---
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if "cuda" in DEVICE.type :
    torch.cuda.set_device(DEVICE)
print(DEVICE)

# --- Model: pretrained backbone, put in training mode and moved to DEVICE ---
model = AsianBartForConditionalGeneration.from_pretrained(BACKBONE)
model.train()
model = model.to(DEVICE)

# --- Tokenizer: configured with English as source and Korean as target language ---
tokenizer = AutoTokenizer.from_pretrained(BACKBONE, src_lang="en_XX", tgt_lang="ko_KR")

# --- Report the number of trainable parameters ---
# Tensor.numel() gives the element count directly, without a NumPy round-trip per tensor.
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum(p.numel() for p in model_parameters)
print("# of params in model :", params)

In [None]:
# Load the parallel corpus from ../RawData through the project helper.
# Later cells select rows by a "type" column ("train" / "valid") using .loc,
# so the result is presumably a pandas DataFrame — confirm in utils.load_parallel.
original_train = utils.load_parallel("../RawData")

In [None]:
batch_size = 16
collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, return_tensors="pt")

# Rows are assigned to a split by the value of their "type" column.
train_rows = original_train.loc[original_train.loc[:, "type"] == "train", :]
train_data = utils.get_dataset(train_rows, tokenizer, collator, batch_size, True, "en", "ko")

# The validation split is batched at twice the training batch size.
valid_rows = original_train.loc[original_train.loc[:, "type"] == "valid", :]
valid_data = utils.get_dataset(valid_rows, tokenizer, collator, batch_size * 2, True, "en", "ko")

In [None]:
# Optimisation hyperparameters.
epochs = 3
learning_rate = 1e-4

# AdamW over every model parameter; the cosine warm-up schedule below is currently disabled.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    eps=1e-6,
    weight_decay=0.02,
)
# lr_scheduler = get_cosine_schedule_with_warmup(optimizer = optimizer,
#                                                num_warmup_steps = int(len(train_data) * epochs * 0.02),
#                                                num_training_steps = len(train_data) * epochs)

In [None]:
# Attach this run's hyperparameters to the active W&B run.
wandb_config = dict(
    learning_rate=learning_rate,
    batch_size=batch_size,
    backbone=BACKBONE,
    epochs=epochs,
)
wandb.config.update(wandb_config)

In [None]:
# Fine-tune under CUDA automatic mixed precision. Validation runs every
# `valid_check_period` optimizer steps, and training halts as soon as the
# early-stopping helper sets its `early_stop` flag.
scaler = torch.cuda.amp.GradScaler()
wandb.watch(model, log="all", log_freq=500)

valid_check_period = 5000
early_stopping = utils.EarlyStopping(path="../Model/parallel_corpus_eng2kor_checkpoint", patience=1, verbose=True)
halt = False


def _mean_validation_loss(model, loader):
    """Return the mean of the per-batch losses of `model` over `loader`.

    The model is switched to eval mode and left there; the caller switches it
    back to train mode when training continues.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for val_batch in loader:
            val_batch = {k: v.to(DEVICE) for k, v in val_batch.items()}
            with torch.cuda.amp.autocast():
                val_loss = model(**val_batch)["loss"]
            batch_losses.append(val_loss.item())

            # Drop the references before the next batch is moved to DEVICE.
            del val_batch, val_loss

    if not batch_losses:
        raise ValueError("validation data produced no batches")
    return sum(batch_losses) / len(batch_losses)


step = 0
for epoch in range(epochs):
    cum_loss = deque(maxlen=20)  # sliding window behind the smoothed loss shown and logged
    curr_loss = []               # every step loss of this epoch, for the end-of-epoch mean

    with tqdm(train_data, unit=" batch") as tepoch:
        # The description only depends on the epoch, so set it once.
        tepoch.set_description(f"Train Epoch {epoch}")
        model.train()

        for batch in tepoch:
            step += 1
            optimizer.zero_grad()

            batch = {k: v.to(DEVICE) for k, v in batch.items()}

            with torch.cuda.amp.autocast():
                outputs = model(**batch)
                loss = outputs["loss"]

            # Scale the loss before backward; the scaler then drives the
            # optimizer step and updates its own scale factor.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            # lr_scheduler.step()

            # Read the scalar once and reuse it for both accumulators.
            loss_value = loss.item()
            cum_loss.append(loss_value)
            curr_loss.append(loss_value)

            del batch, outputs, loss

            smoothed_loss = sum(cum_loss) / len(cum_loss)
            tepoch.set_postfix(loss=smoothed_loss)

            # param_groups exposes the current learning rate directly; no need
            # to build a full optimizer state_dict on every step.
            wandb.log({"train_loss": smoothed_loss,
                       "lr": optimizer.param_groups[0]["lr"],
                       "train_step": step})

            if step % valid_check_period == 0:
                valid_loss = _mean_validation_loss(model, valid_data)

                wandb.log({"valid_loss": valid_loss,
                           "valid_step": step // valid_check_period})

                early_stopping(valid_loss, model)

                if early_stopping.early_stop:
                    print("Early stopping")
                    halt = True
                    break

                model.train()

        # Guarded so that an empty training loader does not divide by zero.
        if curr_loss:
            print("Train loss : ", sum(curr_loss) / len(curr_loss))

    if halt:
        break

        

In [None]:
# Save the model in its current in-memory state to ../Model/large_batch_eng2kor.
# NOTE(review): if early stopping halted training, these are the weights from the last
# training step, not necessarily the best validation checkpoint — presumably
# utils.EarlyStopping stores that one at its own `path`; confirm which should be shipped.
model.save_pretrained("../Model/large_batch_eng2kor")