In [None]:
# BASE MODEL 

In [6]:
import torch
import json
from transformers import DistilBertTokenizerFast, DistilBertForQuestionAnswering, AdamW
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import jiwer
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def load_data_file(path):
    try:
        with open(path, 'r', encoding='utf-8') as f:  
            raw_data = json.load(f)
    except FileNotFoundError:
        logger.error(f"File not found: {path}")
        raise

    contexts = []
    questions = []
    answers = []
    num_questions = 0
    num_possible = 0
    num_impossible = 0

    for group in raw_data['data']:
        for paragraph in group['paragraphs']:
            context = paragraph['context']
            for qa in paragraph['qas']:
                question = qa['question']
                num_questions += 1
                if 'is_impossible' in qa and qa['is_impossible']:
                    num_impossible += 1
                else:
                    num_possible += 1
                for answer in qa.get('answers', []):
                    contexts.append(context.lower())
                    questions.append(question.lower())
                    answers.append(answer)

    return num_questions, num_possible, num_impossible, contexts, questions, answers

try:
    num_train_questions, num_train_possible, num_train_impossible, train_contexts, train_questions, train_answers = load_data_file('spoken_train-v1.1.json')
    num_valid_questions, num_valid_possible, num_valid_impossible, valid_contexts, valid_questions, valid_answers = load_data_file('spoken_test-v1.1.json')
except Exception as e:
    logger.error(f"Error loading data: {e}")
    exit()

def add_answer_end_positions(answers, contexts):
    for answer, context in zip(answers, contexts):
        answer_text = answer.get('text', '').lower()
        answer_start = answer.get('answer_start', -1)
        answer['answer_end'] = answer_start + len(answer_text)

add_answer_end_positions(train_answers, train_contexts)
add_answer_end_positions(valid_answers, valid_contexts)

MAX_LENGTH = 512
MODEL_PATH = "distilbert-base-uncased"

tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_PATH)

train_encodings = tokenizer(train_questions, train_contexts, max_length=MAX_LENGTH, padding=True, truncation=True)
valid_encodings = tokenizer(valid_questions, valid_contexts, max_length=MAX_LENGTH, padding=True, truncation=True)

class QADataset(Dataset):
    def __init__(self, encodings, answers):
        self.encodings = encodings
        self.answers = answers

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['start_positions'] = torch.tensor(self.answers[idx].get('answer_start', -1))
        item['end_positions'] = torch.tensor(self.answers[idx].get('answer_end', -1))
        return item

    def __len__(self):
        return len(self.encodings.input_ids)

train_dataset = QADataset(train_encodings, train_answers)
valid_dataset = QADataset(valid_encodings, valid_answers)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=1)

qa_model = DistilBertForQuestionAnswering.from_pretrained(MODEL_PATH).to(device)

optimizer = torch.optim.AdamW(qa_model.parameters(), lr=5e-5)

def train_one_epoch(model, dataloader, optimizer):
    model.train()
    total_loss = 0.0
    for batch in tqdm(train_loader, desc='Training'):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)
        
        inputs_embeds = None  
        
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, 
                        start_positions=start_positions, end_positions=end_positions,
                        inputs_embeds=inputs_embeds)
        loss = outputs.loss
        total_loss += loss.item()
        loss.backward()
        optimizer.step()

    return total_loss / len(dataloader)

def evaluate_model(model, dataloader):
    model.eval()
    wer_list = []
    for batch in tqdm(dataloader, desc='Evaluating'):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_true = batch['start_positions'].to(device)
        end_true = batch['end_positions'].to(device)
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        start_pred = torch.argmax(outputs.start_logits, dim=1)
        end_pred = torch.argmax(outputs.end_logits, dim=1)
        for i in range(len(start_true)):
            pred_answer = tokenizer.decode(input_ids[i][start_pred[i]:end_pred[i]+1])
            true_answer = tokenizer.decode(input_ids[i][start_true[i]:end_true[i]+1])
            if true_answer.strip():  
                wer = jiwer.wer(true_answer, pred_answer)
                wer_list.append(wer)
    return sum(wer_list) / len(wer_list) if wer_list else 0.0

EPOCHS = 5
best_wer = float('inf')
patience = 3
counter = 0

for epoch in range(EPOCHS):
    train_loss = train_one_epoch(qa_model, train_loader, optimizer)
    wer_score = evaluate_model(qa_model, valid_loader)
    logger.info(f"Epoch {epoch+1}/{EPOCHS}, Train Loss: {train_loss}, WER Score: {wer_score}")

    if wer_score < best_wer:
        best_wer = wer_score
        counter = 0
    else:
        counter += 1
        if counter >= patience:
            logger.info("Early stopping triggered!")
            break


Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Training: 100%|██████████| 2320/2320 [09:18<00:00,  4.15it/s]
Evaluating: 100%|██████████| 15875/15875 [02:03<00:00, 128.85it/s]


Epoch 1/5, Train Loss: 5.872068548819114, WER Score: 14.545863414116578


Training: 100%|██████████| 2320/2320 [09:18<00:00,  4.15it/s]
Evaluating: 100%|██████████| 15875/15875 [01:56<00:00, 135.71it/s]


Epoch 2/5, Train Loss: 5.353944472814429, WER Score: 3.5266719288784265


Training: 100%|██████████| 2320/2320 [09:19<00:00,  4.15it/s]
Evaluating: 100%|██████████| 15875/15875 [01:54<00:00, 138.84it/s]


Epoch 3/5, Train Loss: 4.823115943423633, WER Score: 2.6733518694144633


Training: 100%|██████████| 2320/2320 [09:18<00:00,  4.15it/s]
Evaluating: 100%|██████████| 15875/15875 [01:57<00:00, 135.56it/s]


Epoch 4/5, Train Loss: 4.363082417331893, WER Score: 2.0523018580566452


Training: 100%|██████████| 2320/2320 [09:18<00:00,  4.15it/s]
Evaluating: 100%|██████████| 15875/15875 [01:57<00:00, 135.51it/s]

Epoch 5/5, Train Loss: 4.024793861755009, WER Score: 2.1047314573792426



