In [None]:
import json

# Paths of the train and dev JSON files.
train_file = "/kaggle/input/vietnamese-squad/train-v2.0-translated.json"
dev_file = "/kaggle/input/vietnamese-squad/dev-v2.0-translated.json"

# Read each file as UTF-8 text and parse it into Python objects.
with open(train_file, mode="r", encoding="utf-8") as f:
    train_data = json.loads(f.read())

with open(dev_file, mode="r", encoding="utf-8") as f:
    dev_data = json.loads(f.read())

In [None]:
# Display the first raw record to inspect its layout.
train_data[0]

In [None]:
def change_format(data):
    """Convert raw records into dicts keyed by "context", "question" and "ans".

    Args:
        data: Iterable of records. A record is either indexable by position,
            with index 0 = context, 1 = question, 2 = answer, or a dict that
            already carries the three target keys.

    Returns:
        A new list with one ``{"context", "question", "ans"}`` dict per record.

    Records that are already in the target format are copied through unchanged.
    This keeps the conversion idempotent: the notebook rebinds ``train_data`` to
    the converted list, so re-running the cell used to fail when the converted
    dicts were indexed by position.
    """
    keys = ("context", "question", "ans")
    json_data = []
    for d in data:
        if isinstance(d, dict) and all(key in d for key in keys):
            # Already converted (e.g. the cell was run twice): copy as-is.
            sample = {key: d[key] for key in keys}
        else:
            sample = {"context": d[0],
                      "question": d[1],
                      "ans": d[2]}

        json_data.append(sample)
    return json_data

# Rebind both splits to their converted form (lists of context/question/ans dicts).
train_data = change_format(train_data)
dev_data = change_format(dev_data)

In [None]:
# Display the first record again, now that train_data has been passed through change_format.
train_data[0]

In [None]:
import torch
def create_segment_id(tokenized_input):
    """Build a 0/1 segment-id list for an encoded pair of sequences.

    Everything up to and including the first `[SEP]` token found in
    ``tokenized_input['input_ids']`` is labelled 0 (segment A); every
    position after it is labelled 1 (segment B).

    Returns:
        A list of ints with the same length as ``tokenized_input['input_ids']``.
    """
    token_ids = tokenized_input['input_ids']

    # Segment A ends right after the first [SEP], so its size is that index + 1.
    boundary = token_ids.index(tokenizer.sep_token_id) + 1

    segment_ids = [0 if position < boundary else 1 for position in range(len(token_ids))]

    # Sanity check: exactly one segment id per input token.
    assert len(segment_ids) == len(token_ids)

    return segment_ids

In [None]:
import re
def find_start_end_char(ques, answer, context):
    """Locate the first case-insensitive occurrence of ``answer`` in ``context``.

    ``answer`` is matched literally (regex metacharacters are escaped).
    ``ques`` is accepted for interface compatibility but is not used.

    Returns:
        ``(start_char, end_char)`` with ``end_char`` exclusive, or ``(-1, -1)``
        when ``answer`` does not occur in ``context``.
    """
    literal = re.compile(re.escape(answer), flags=re.IGNORECASE)
    hit = literal.search(context)
    return (-1, -1) if hit is None else hit.span()

In [None]:
import os
import wandb

# Read the W&B API key from the key file and expose it through the environment.
# The key is deliberately never printed: cell outputs are stored with the notebook,
# so echoing it would leak the secret. (The previous print also called f.read() a
# second time, which returns an empty string once the file has been consumed.)
with open('/kaggle/input/wandb-key/wandb_key.txt', 'r') as f:
    os.environ['WANDB_API_KEY'] = f.read().strip()

# Initialize wandb without interactive login
wandb.login(key=os.environ['WANDB_API_KEY'])

In [None]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
from transformers import BertForQuestionAnswering

# Load the tokenizer and a question-answering model from the same
# "bert-base-multilingual-cased" checkpoint; both are used as globals by the cells below.
# NOTE(review): preprocess_function requests offset mappings, which presumably needs the
# fast tokenizer that AutoTokenizer returns by default — confirm.
tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
model = BertForQuestionAnswering.from_pretrained("bert-base-multilingual-cased")

def preprocess_function(example, max_length=128):
    """Tokenize one question/context pair and attach answer-span token labels.

    Args:
        example: Mapping with "question", "context" and "ans" string entries.
        max_length: Length every encoding is truncated/padded to (default 128).

    Returns:
        The tokenizer output with two extra entries, "start_positions" and
        "end_positions":
          * (-1, -1) when the answer text is not found in the context;
          * (0, 0) when it is found but cannot be mapped onto context tokens
            (for instance because truncation removed it);
          * otherwise the indices of the first and last answer tokens.
        "offset_mapping" is always removed, so every example exposes the same keys.
    """
    inputs = tokenizer(
        example["question"],
        example["context"],
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_offsets_mapping=True
    )

    # Remove offset_mapping up front so that both branches below return the same keys.
    offsets = inputs.pop("offset_mapping")
    # One entry per token: None for special/padding tokens, 0 for the question,
    # 1 for the context. Offsets are relative to each sequence, so question and
    # special tokens must be skipped or they can collide with answer offsets.
    sequence_ids = inputs.sequence_ids()

    start_char, end_char = find_start_end_char(example['question'], example["ans"], example["context"])
    if start_char == -1 or end_char == -1:
        inputs["start_positions"], inputs["end_positions"] = -1, -1
        return inputs

    # Convert character positions to token positions, looking at context tokens only.
    start_token, end_token = None, None
    for idx, (tok_start, tok_end) in enumerate(offsets):
        if sequence_ids[idx] != 1:
            continue
        # Containment rather than equality: the answer may begin or end mid-token.
        if start_token is None and tok_start <= start_char < tok_end:
            start_token = idx
        if tok_start < end_char <= tok_end:
            end_token = idx
            break

    if start_token is None or end_token is None:
        # Answer is not (fully) inside the encoded window: label position 0 for both
        # ends instead of leaving an inconsistent start/end pair.
        start_token, end_token = 0, 0

    inputs["start_positions"], inputs["end_positions"] = start_token, end_token
    return inputs

In [None]:
from tqdm import tqdm 
from datasets import Dataset

def tokenize_data(squad_formatted_data):
    """Encode every example and replace its token_type_ids with custom segment ids.

    Each item goes through ``preprocess_function``; the result's
    "token_type_ids" entry is then overwritten with the output of
    ``create_segment_id``. A progress bar is shown while iterating.

    Returns:
        A list with one encoded example per input item, in input order.
    """
    def _encode(item):
        # Tokenize first, then derive the segment ids from the encoded input_ids.
        encoded = preprocess_function(item)
        encoded["token_type_ids"] = create_segment_id(encoded)
        return encoded

    return [_encode(item) for item in tqdm(squad_formatted_data)]

# Encode both splits; tokenize_data also overwrites token_type_ids with the custom segment ids.
tokenized_train_data = tokenize_data(train_data)
tokenized_dev_data = tokenize_data(dev_data)

# Wrap the lists of encoded examples as datasets.Dataset objects.
# Note that the same names are rebound from list to Dataset here.
tokenized_train_data = Dataset.from_list(tokenized_train_data)
tokenized_dev_data = Dataset.from_list(tokenized_dev_data)

## Train model

In [None]:
from transformers import DataCollatorWithPadding, Trainer, TrainingArguments, BertForQuestionAnswering


# Collate batches with the same tokenizer that preprocess_function used.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",           # Evaluate at the end of each epoch
    save_strategy="epoch",                 # Checkpoint at the end of each epoch (same schedule as evaluation)
    load_best_model_at_end=True,           # Load the best model at the end of training
    metric_for_best_model="eval_loss",     # Metric to determine the best model (can adjust as needed)
    greater_is_better=False,               # Lower eval_loss is better
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=25,
    weight_decay=0.01,
    fp16=True,                             # Use mixed precision for memory efficiency
    gradient_accumulation_steps=4,         # Effective train batch per device: 32 * 4 = 128
    save_total_limit=1                     # Cap retained checkpoints; NOTE(review): with load_best_model_at_end the latest may be kept alongside the best — confirm
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_data,
    eval_dataset=tokenized_dev_data,  # the dev split serves as the validation set
    data_collator=data_collator,
)

# Train the model
trainer.train()


## Eval model

In [None]:
# Evaluate with the Trainer (it was given tokenized_dev_data as eval_dataset) and print the metrics.
eval_results = trainer.evaluate()
print(f"Evaluation results: {eval_results}")

# Save the model and tokenizer
# NOTE(review): the directory name says "xlm-roberta", but the checkpoint loaded in this
# notebook is "bert-base-multilingual-cased" — consider renaming once anything that
# reads this path has been checked.
model.save_pretrained("./fine_tuned_xlm-roberta")
tokenizer.save_pretrained("./fine_tuned_xlm-roberta")


## Run inference

In [None]:
# Pick CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to that device; answer_question moves its inputs to the same one.
model.to(device)

def answer_question(context, question, max_length=None):
    """Predict the answer span for ``question`` inside ``context``.

    Args:
        context: Passage to search for the answer.
        question: Question text.
        max_length: Optional cap on the encoded length. With ``None`` the
            tokenizer falls back to the model's own maximum length.

    Returns:
        The decoded answer text with special tokens removed, or ``""`` when the
        predicted end comes before the predicted start or the span contains
        only special tokens (for example a predicted [CLS]).
    """
    # Make sure dropout is disabled so predictions are deterministic, regardless
    # of which mode an earlier cell left the model in.
    model.eval()

    # Truncate so long contexts cannot exceed the model's position limit.
    inputs = tokenizer(
        question,
        context,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Batch size is 1, so the flat argmax is the token index within the sequence.
    start_idx = int(torch.argmax(outputs.start_logits))
    end_idx = int(torch.argmax(outputs.end_logits)) + 1  # exclusive slice end

    if end_idx <= start_idx:
        return ""

    answer = tokenizer.decode(
        inputs["input_ids"][0][start_idx:end_idx],
        skip_special_tokens=True,
    )
    return answer

In [None]:
# Spot-check: print the reference answer next to the model's prediction
# for the first five dev examples.
for example in dev_data[:5]:
    print("Answer: ", example['ans'])
    predicted = answer_question(example['context'], example['question'])
    print("Predict answer: ", predicted)