# In [None]:
import argparse
from argparse import ArgumentParser
from finetune_model import FintuneModel, LoggingCallback,set_seed
from news_dataset import NewsSummaryDataModule
import pytorch_lightning as pl
import easydict
from transformers import AutoTokenizer
from torchinfo import summary
import torch
import pandas as pd
from sklearn.model_selection import train_test_split


# Run configuration. EasyDict allows both attribute access (args.output_dir)
# and dict access; the whole object is handed to FintuneModel below, so keys
# that this script never reads may still be consumed there (verify in
# finetune_model before removing any).
# Each key appears exactly once: a duplicated key in a dict literal is silently
# overwritten by its last occurrence, so editing the "wrong" copy would be a no-op.
args = easydict.EasyDict({
    "max_input_length": 512,
    "max_output_length": 512,
    "num_train_epochs": 100,
    "output_dir": "t5_pretraining",  # ModelCheckpoint dirpath
    "train_batch_size": 2,
    "learning_rate": 1e-5,
    "model_name_or_path": "lcw99/t5-base-korean-text-summary",
    "tokenizer_name_or_path": "lcw99/t5-base-korean-text-summary",
    "freeze_encoder": False,
    "freeze_embeds": False,
    "weight_decay": 0.0,
    "adam_epsilon": 1e-8,
    "warmup_steps": 0,
    "eval_batch_size": 2,
    "gradient_accumulation_steps": 1,
    "n_gpu": 1,
    "resume_from_checkpoint": None,  # path to a .ckpt file to resume from, or None
    # "val_check_interval": 10,  # alternative to check_val_every_n_epoch
    "check_val_every_n_epoch": 2,
    # The following are not read by this script; presumably used by
    # FintuneModel -- TODO confirm.
    "n_val": 4,
    "val_percent_check": 5,
    "n_train": 50,
    "n_test": -1,
    "early_stop_callback": False,
    # Switches the Trainer to precision=16 (otherwise 32).
    "fp_16": False,
    # Apex optimisation level; not passed to the Trainer in this script
    # (amp_level is not used). See
    # https://nvidia.github.io/apex/amp.html#opt-levels-and-properties
    "opt_level": "O1",
    # Intended for gradient clipping (0.5 is a sensible default with 16-bit
    # training); not applied unless gradient_clip_val is enabled in train_params.
    "max_grad_norm": 0.5,
    "seed": 42,
})

## Checkpointing: write checkpoints into args.output_dir under a fixed file name.
## No `monitor` is given, so checkpoints are not ranked by a validation metric.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=args.output_dir,
    filename="modelcheckpoint",
)

## Keyword arguments for pl.Trainer.
## Resuming is not a Trainer option in current Lightning; request it with
## trainer.fit(..., ckpt_path=<checkpoint path>) instead.
train_params = dict(
    accumulate_grad_batches=args.gradient_accumulation_steps,
    accelerator="gpu",
    # `devices` replaces the removed `gpus=` argument; honour the configured count.
    devices=args.n_gpu,
    # Run evaluation under torch.no_grad() rather than torch.inference_mode().
    inference_mode=False,
    max_epochs=args.num_train_epochs,
    precision=16 if args.fp_16 else 32,
    # gradient_clip_val=args.max_grad_norm,  # enable to clip gradients (e.g. with fp_16)
    # val_check_interval=args.val_check_interval,  # alternative validation schedule
    check_val_every_n_epoch=args.check_val_every_n_epoch,
    callbacks=[LoggingCallback(), checkpoint_callback],
)

# Seed from the configuration so changing args.seed changes every random
# source consistently (previously the literal 42 was used here).
set_seed(args.seed)
model = FintuneModel(args)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path)

# CSV with the news articles and summaries. An optional "data_path" entry in
# args overrides the original machine-specific default location.
data_path = args.get(
    "data_path",
    "/Users/dongunyun/study/datascience/encoder_decoder/dataset/news_summary.csv",
)
df = pd.read_csv(data_path)

# Hold out 20% for test, then 20% of the remainder for validation,
# giving roughly 64% train / 16% val / 20% test.
train_df, test_df = train_test_split(df, test_size=0.2, random_state=args.seed)
train_df, val_df = train_test_split(train_df, test_size=0.2, random_state=args.seed)

# Renumber rows 0..n-1 after shuffling; presumably the dataset indexes rows by
# position -- NOTE(review): confirm in news_dataset.
train_df, val_df, test_df = (
    part.reset_index(drop=True) for part in (train_df, val_df, test_df)
)

news_summary_module = NewsSummaryDataModule(
    train_df,
    val_df,
    test_df,
    args.train_batch_size,
    tokenizer,
    args.max_input_length,
    args.max_output_length,
)
# Trainer.fit also invokes the datamodule's setup hook; this eager call is kept
# so the datasets are available for inspection in the notebook.
news_summary_module.setup()

trainer = pl.Trainer(**train_params)
# ckpt_path=None (the default config) trains from scratch; a checkpoint path in
# args.resume_from_checkpoint resumes model, optimizer and loop state from it.
trainer.fit(
    model,
    datamodule=news_summary_module,
    ckpt_path=args.resume_from_checkpoint,
)