In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, default_data_collator, get_linear_schedule_with_warmup
from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, PrefixTuningConfig, TaskType
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
import os
from datasets import ClassLabel, Sequence
# --- Runtime configuration -------------------------------------------------
# Disable HF tokenizers' internal parallelism so DataLoader workers/forks do not
# trigger its fork warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Pin the run to physical GPU 3. This only takes effect if it is set before
# torch initialises CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# Fall back to CPU instead of failing later with an opaque CUDA error when no
# GPU is visible.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch so prefix initialisation and DataLoader shuffling are reproducible.
seed = 42
torch.manual_seed(seed)

# Alternative backbone that was tried before: "t5-large" (model and tokenizer).
model_name_or_path = "facebook/bart-large-mnli"
tokenizer_name_or_path = "facebook/bart-large-mnli"

text_column = "sentence"     # CSV column holding the input text
label_column = "text_label"  # CSV column holding the target label text
max_length = 256             # inputs are padded/truncated to this many tokens
lr = 1e-2
num_epochs = 50
batch_size = 16



In [None]:
from datasets import load_dataset, load_from_disk
import pandas as pd

# Raw CSV splits. File names are kept exactly as they exist on disk.
dataset = load_dataset(
    "csv",
    data_files={
        "train": "dataset/augemented_training_combine_potong.csv",
        "test": "dataset/test_filter_data_potong.csv",
    },
)

# Fail fast if the configured text/label columns are missing from a split, rather
# than with a KeyError deep inside the tokenization map.
for split_name, split_ds in dataset.items():
    missing_columns = {text_column, label_column} - set(split_ds.column_names)
    if missing_columns:
        raise ValueError(f"{split_name} split is missing columns: {sorted(missing_columns)}")

In [None]:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)


def preprocess_function(examples, label_max_length=5):
    """Tokenize a batch of raw examples into seq2seq model inputs.

    Args:
        examples: batched ``Dataset.map`` payload (column name -> list of values).
            It must contain ``text_column`` and ``label_column``.
        label_max_length: token budget for each target sequence, special tokens
            included. Longer labels are truncated, so raise this if any label
            text needs more tokens.

    Returns:
        The tokenizer's encoding of the inputs, padded/truncated to
        ``max_length``, plus ``labels``. ``labels`` holds the target token ids,
        padded to ``label_max_length``, with pad positions replaced by -100.
    """
    inputs = examples[text_column]
    targets = examples[label_column]
    # Plain Python lists are enough: Dataset.map stores results in Arrow, so
    # building torch tensors here (return_tensors="pt") would just be converted back.
    model_inputs = tokenizer(inputs, max_length=max_length, padding="max_length", truncation=True)
    label_ids = tokenizer(
        targets, max_length=label_max_length, padding="max_length", truncation=True
    )["input_ids"]
    pad_id = tokenizer.pad_token_id
    # -100 is the label value HF seq2seq losses ignore, so padding does not count.
    model_inputs["labels"] = [
        [-100 if token_id == pad_id else token_id for token_id in sequence]
        for sequence in label_ids
    ]
    return model_inputs

In [None]:
# Tokenize both splits with preprocess_function. Every raw column (listed from the
# train split) is dropped, leaving only tokenizer outputs and `labels`. The cache is
# bypassed, so the map is always recomputed.
processed_datasets = dataset.map(
    function=preprocess_function,
    remove_columns=dataset["train"].column_names,
    batched=True,
    num_proc=1,
    load_from_cache_file=False,
    desc="Running tokenizer on dataset",
)

In [None]:
# Wrap the pretrained seq2seq model with PEFT prefix tuning (20 virtual tokens,
# training mode), then report how many parameters are trainable.
peft_config = PrefixTuningConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    num_virtual_tokens=20,
    inference_mode=False,
)
model = get_peft_model(
    AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path),
    peft_config,
)
model.print_trainable_parameters()

In [None]:
# Single train/test run. The DataLoaders must exist before the scheduler because
# the schedule length depends on the number of training batches. (Previously
# `train_dataloader` was only defined in a later cell, so a fresh kernel raised
# NameError here.)
train_dataloader = DataLoader(
    processed_datasets["train"],
    shuffle=True,
    collate_fn=default_data_collator,
    batch_size=batch_size,
)
eval_dataloader = DataLoader(
    processed_datasets["test"],
    collate_fn=default_data_collator,
    batch_size=batch_size,
)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
# Linear decay from `lr` over every optimisation step of the whole run, no warmup.
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) * num_epochs),
)

In [None]:
from sklearn.model_selection import KFold
import numpy as np
import copy

# --- K-fold cross-validation on the training split --------------------------
# `dataset` is a DatasetDict with "train" and "test" splits. Folds are drawn from
# the training split only, so the test split stays untouched.
# NOTE(review): the training CSV name suggests it contains augmented rows. If
# augmented copies of one source sentence land in different folds, validation
# scores will be optimistic; consider a grouped split if so.
k_folds = 5
kf = KFold(n_splits=k_folds, shuffle=True, random_state=42)

cv_source = dataset["train"]

# Final-epoch metrics for each fold.
cv_metrics = []


def _train_one_epoch(model, dataloader, optimizer, lr_scheduler):
    """Run one optimisation pass over `dataloader` and return the mean batch loss (tensor)."""
    model.train()
    total_loss = 0.0
    for batch in dataloader:
        batch = {key: value.to(device) for key, value in batch.items()}
        loss = model(**batch).loss
        # detach so the running sum does not keep each step's autograd graph alive
        total_loss += loss.detach().float()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
    return total_loss / len(dataloader)


def _evaluate(model, dataloader):
    """Return the mean batch loss (tensor) of `model` over `dataloader`, without gradients."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in dataloader:
            batch = {key: value.to(device) for key, value in batch.items()}
            total_loss += model(**batch).loss.float()
    return total_loss / len(dataloader)


for fold, (train_idx, val_idx) in enumerate(kf.split(np.arange(len(cv_source)))):
    print(f"Fold {fold+1}/{k_folds}")

    # Drop the previous fold's model/optimizer before loading a new one, to limit
    # peak memory.
    model = optimizer = lr_scheduler = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Split the training data into this fold's train and validation subsets.
    train_subset = cv_source.select(train_idx)
    val_subset = cv_source.select(val_idx)

    # Tokenize per fold and keep only model inputs (raw text columns removed).
    map_kwargs = dict(batched=True, remove_columns=cv_source.column_names, load_from_cache_file=False)
    train_dataset = train_subset.map(preprocess_function, desc=f"Tokenizing fold {fold+1} train", **map_kwargs)
    eval_dataset = val_subset.map(preprocess_function, desc=f"Tokenizing fold {fold+1} val", **map_kwargs)

    train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size)
    eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=batch_size)

    # Build a fresh base model and prefix per fold, so no fold inherits another fold's
    # weights. The PEFT config is deep-copied in case get_peft_model fills in fields on it.
    model = get_peft_model(
        AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path),
        copy.deepcopy(peft_config),
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=len(train_dataloader) * num_epochs,
    )

    for epoch in range(num_epochs):
        train_epoch_loss = _train_one_epoch(model, train_dataloader, optimizer, lr_scheduler)
        eval_epoch_loss = _evaluate(model, eval_dataloader)
        train_ppl = torch.exp(train_epoch_loss)
        eval_ppl = torch.exp(eval_epoch_loss)
        print(
            f"  epoch {epoch}: train_loss={train_epoch_loss.item():.4f} train_ppl={train_ppl.item():.4f} "
            f"eval_loss={eval_epoch_loss.item():.4f} eval_ppl={eval_ppl.item():.4f}"
        )

    # Record the final epoch's metrics for this fold.
    cv_metrics.append({
        'train_loss': train_epoch_loss.item(),
        'eval_loss': eval_epoch_loss.item(),
        'train_ppl': train_ppl.item(),
        'eval_ppl': eval_ppl.item()
    })

# Summary across folds.
cv_eval_losses = np.array([fold_metrics['eval_loss'] for fold_metrics in cv_metrics])
print(f"CV eval loss over {k_folds} folds: {cv_eval_losses.mean():.4f} ± {cv_eval_losses.std():.4f}")