# In [None]:  (Jupyter cell marker left over from notebook export)
import os
import random
import numpy as np
import argparse
import warnings
import torch
from torch.utils.data import random_split
import transformers
from transformers import AutoModel, AutoTokenizer

# private modules
from dataset import IMDBDataset, SNLIDataset, AGNewsDataset
from modules import BERT
from modules import TopAdapter
from modules import LayerWiseAdapter
import functions
# Process-wide runtime configuration, applied as soon as this module loads.
# Drop every Python warning and keep only error-level transformers logs.
warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()
# NOTE(review): per the PyTorch CUDA notes this makes CUDA kernel launches
# synchronous (a debugging aid that slows GPU execution) — confirm it is meant
# to stay enabled for regular training runs.
os.environ.update(CUDA_LAUNCH_BLOCKING="1")

"""
HOW TO EXECUTE :
    python main.py --dataset [dataset name] --top --layer --num_adapters 12 --full_finetuning

    explanation : 
        --dataset [dataset name] : Specify the dataset (required)
            - e.g. python main.py --dataset imdb --top --layer
            - dataset list : ['imdb', 'snli', 'agnews']
        --top : Add a top adapter to the model (same as '-t')
        --layer : Add layer-wise adapters to the model (same as '-l')
        --num_adapters : number of adapters (default : 12)
            - e.g. pass 6 (half of the default) to use fewer adapters
        --full_finetuning : use full fine-tuning instead of adapter-based tuning (flag; off by default)
"""


# HYPER PARAMETERS:
#     These values are adjustable based on your preferences and requirements;
#     tune them freely for your own experiments.
RANDOM_SEED: int = 42          # seed passed to functions.seed_everything
BATCH_SIZE: int = 16           # batch size shared by every DataLoader
VALID_RATIO: float = 0.1       # fraction of the training set held out for validation
EPOCHS: int = 10               # passed as `epochs` to functions.train
LEARNING_RATE: float = 5e-6    # passed as `learning_rate` to functions.train


# Number of target classes for each supported dataset. The keys are also the
# accepted values of --dataset.
DATASET_NUM_CLASSES = {'imdb': 2, 'snli': 3, 'agnews': 4}


def parse_args(argv=None):
    """Parse the command-line options of a training run.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``dataset``, ``top``, ``layer``,
        ``num_adapters`` and ``full_finetuning`` attributes.

    Raises:
        SystemExit: if ``--dataset`` is missing or not a supported name
            (argparse prints the error and exits).
    """
    parser = argparse.ArgumentParser()
    # ``required`` + ``choices``: without them a missing/unknown dataset would
    # fall through the dataset lookup and crash later on an unbound
    # ``num_classes`` instead of producing a clear usage error.
    parser.add_argument('-d', '--dataset', dest='dataset', action='store',
                        required=True, choices=list(DATASET_NUM_CLASSES))
    parser.add_argument('-t', '--top', dest='top', action='store_true')
    parser.add_argument('-l', '--layer', dest='layer', action='store_true')
    parser.add_argument('-n', '--num_adapters', dest='num_adapters',
                        action='store', type=int, default=12)
    parser.add_argument('-f', '--full_finetuning', dest='full_finetuning',
                        action='store_true')
    return parser.parse_args(argv)


def load_datasets(name, tokenizer):
    """Build the train and test datasets for the dataset called *name*.

    Args:
        name: One of the keys of ``DATASET_NUM_CLASSES``.
        tokenizer: Tokenizer handed to the dataset constructors.

    Returns:
        Tuple ``(train_dataset, test_dataset, num_classes)``.
    """
    dataset_cls = {
        'imdb': IMDBDataset,
        'snli': SNLIDataset,
        'agnews': AGNewsDataset,
    }[name]
    train_dataset = dataset_cls(tokenizer=tokenizer, mode='train')
    test_dataset = dataset_cls(tokenizer=tokenizer, mode='test')
    return train_dataset, test_dataset, DATASET_NUM_CLASSES[name]


def main(argv=None):
    """Train a BERT classifier with optional adapters and print test accuracy.

    Args:
        argv: Optional argument list forwarded to ``parse_args``. ``None``
            (the default) parses ``sys.argv``.
    """
    # set seed
    functions.seed_everything(RANDOM_SEED)

    args = parse_args(argv)

    # load pretrained tokenizer and the requested dataset
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    train_dataset, test_dataset, num_classes = load_datasets(args.dataset, tokenizer)

    # optional adapters (top first, then layer-wise) and the model itself
    top_adapter = TopAdapter(num_classes=num_classes) if args.top else None
    layer_wise_adapters = (
        LayerWiseAdapter(num_adapters=args.num_adapters) if args.layer else None
    )
    model = BERT(
        top_adapter=top_adapter,
        layer_wise_adapters=layer_wise_adapters,
        num_classes=num_classes,
        num_adapters=args.num_adapters,
        full_finetuning=args.full_finetuning,
    )
    criterion = torch.nn.CrossEntropyLoss()

    # train-valid split
    train_size = int((1 - VALID_RATIO) * len(train_dataset))
    val_size = len(train_dataset) - train_size
    # Split with a dedicated, freshly seeded generator. The model/adapter
    # constructors above may draw from the global torch RNG (e.g. for weight
    # init) by an amount that depends on --top/--layer. Without a separate
    # generator, each adapter configuration could be validated on a
    # different split.
    split_generator = torch.Generator().manual_seed(RANDOM_SEED)
    train_dataset, val_dataset = random_split(
        train_dataset, [train_size, val_size], generator=split_generator
    )

    # dataloaders
    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
    )
    # NOTE(review): drop_last=True leaves up to BATCH_SIZE - 1 validation
    # samples out of the validation metrics. It is kept in case functions.train
    # assumes full batches — confirm before changing.
    val_dataloader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        drop_last=True,
    )
    test_dataloader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        drop_last=False,
    )

    # train & validation
    best_model, train_losses, train_accs, val_losses, val_accs = functions.train(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        criterion=criterion,
        epochs=EPOCHS,
        learning_rate=LEARNING_RATE,
    )

    # visualize
    functions.result(train_losses, train_accs, val_losses, val_accs)

    # evaluate the model returned by functions.train on the test set
    accuracy = functions.inference(best_model, test_dataloader)
    print(f"inference acc: {np.round(accuracy, 4)}")

    
    

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
