In [None]:
# PRE PROCESSED MODEL

In [1]:

import torch
import json
from transformers import DistilBertTokenizerFast, DistilBertForQuestionAnswering
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from torch.optim import AdamW
import jiwer

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def load_data_file(path):
    with open(path, 'rb') as f:
        raw_data = json.load(f)
    contexts = []
    questions = []
    answers = []

    for group in raw_data['data']:
        for paragraph in group['paragraphs']:
            context = paragraph['context']
            for qa in paragraph['qas']:
                question = qa['question']
                for answer in qa['answers']:
                    contexts.append(context.lower())
                    questions.append(question.lower())
                    answers.append(answer)

    return contexts, questions, answers

def add_answer_end_positions(answers, contexts):
    for answer, context in zip(answers, contexts):
        answer['text'] = answer['text'].lower()
        answer['answer_end'] = answer['answer_start'] + len(answer['text'])

def preprocess_data(contexts, questions, answers, tokenizer, max_length):
    encodings = tokenizer(questions, contexts, max_length=max_length, padding='max_length', truncation=True)
    start_positions = []
    end_positions = []

    for i, (answer, context) in enumerate(zip(answers, contexts)):
        context_ids = encodings['input_ids'][i]
        answer_start = answer['answer_start']
        answer_end = answer['answer_end']

        # Convert answer's start/end positions in paragraph text to start/end positions in tokenized paragraph
        answer_start_token = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(context[:answer_start]))
        answer_end_token = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(context[:answer_end]))

        # Ensure the answer is within the tokenized context
        if len(answer_start_token) == 0 or len(answer_end_token) == 0:
            start_positions.append(0)
            end_positions.append(0)
            continue

        # Calculate the middle of the answer span
        mid = (answer_start_token[0] + answer_end_token[0]) // 2

        # A single window is obtained by slicing the portion of paragraph containing the answer
        paragraph_start = max(0, min(mid - max_length // 2, len(context_ids) - max_length))
        paragraph_end = paragraph_start + max_length

        start_positions.append(answer_start - paragraph_start)
        end_positions.append(answer_end - paragraph_start)

    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})
    return encodings

class QADataset(Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

train_contexts, train_questions, train_answers = load_data_file('spoken_train-v1.1.json')
valid_contexts, valid_questions, valid_answers = load_data_file('spoken_test-v1.1.json')

add_answer_end_positions(train_answers, train_contexts)
add_answer_end_positions(valid_answers, valid_contexts)

MAX_LENGTH = 512
MODEL_PATH = "distilbert-base-uncased"

tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_PATH)

train_encodings = preprocess_data(train_contexts, train_questions, train_answers, tokenizer, MAX_LENGTH)
valid_encodings = preprocess_data(valid_contexts, valid_questions, valid_answers, tokenizer, MAX_LENGTH)

train_dataset = QADataset(train_encodings)
valid_dataset = QADataset(valid_encodings)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=1)

distilbert_model = DistilBertForQuestionAnswering.from_pretrained(MODEL_PATH).to(device)

optimizer = AdamW(distilbert_model.parameters(), lr=5e-5)

def train_epoch(model, dataloader, optimizer):
    model.train()
    total_loss = 0.0
    for batch in tqdm(dataloader, desc='Training'):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)
        
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, 
                        start_positions=start_positions, end_positions=end_positions)
        loss = outputs.loss
        total_loss += loss.item()
        loss.backward()
        optimizer.step()

    return total_loss / len(dataloader)

def evaluate_model(model, dataloader):
    model.eval()
    wer_list = []
    for batch in tqdm(dataloader, desc='Evaluating'):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_true = batch['start_positions'].to(device)
        end_true = batch['end_positions'].to(device)
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        start_pred = torch.argmax(outputs.start_logits, dim=1)
        end_pred = torch.argmax(outputs.end_logits, dim=1)
        for i in range(len(start_true)):
            pred_answer = tokenizer.decode(input_ids[i][start_pred[i]:end_pred[i]+1])
            true_answer = tokenizer.decode(input_ids[i][start_true[i]:end_true[i]+1])
            if true_answer.strip():  # Check if true_answer is not empty
                wer = jiwer.wer(true_answer, pred_answer)
                wer_list.append(wer)
    return sum(wer_list) / len(wer_list) if wer_list else 0.0


EPOCHS = 5
for epoch in range(EPOCHS):
    train_loss = train_epoch(distilbert_model, train_loader, optimizer)
    wer_score = evaluate_model(distilbert_model, valid_loader)
    print(f"Epoch {epoch+1}/{EPOCHS}, Train Loss: {train_loss}, WER Score: {wer_score}")


Token indices sequence length is longer than the specified maximum sequence length for this model (593 > 512). Running this sequence through the model will result in indexing errors
Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Training: 100%|██████████| 2320/2320 [09:24<00:00,  4.11it/s]
Evaluating: 100%|██████████| 15875/15875 [01:56<00:00, 136.33it/s]


Epoch 1/5, Train Loss: 6.0139772275398515, WER Score: 1.4500994916772374


Training: 100%|██████████| 2320/2320 [09:24<00:00,  4.11it/s]
Evaluating: 100%|██████████| 15875/15875 [01:56<00:00, 136.51it/s]


Epoch 2/5, Train Loss: 5.612893968512272, WER Score: 1.2757634766585904


Training: 100%|██████████| 2320/2320 [09:24<00:00,  4.11it/s]
Evaluating: 100%|██████████| 15875/15875 [01:56<00:00, 136.06it/s]


Epoch 3/5, Train Loss: 5.28202116623007, WER Score: 1.465233573342405


Training: 100%|██████████| 2320/2320 [09:24<00:00,  4.11it/s]
Evaluating: 100%|██████████| 15875/15875 [01:56<00:00, 135.96it/s]


Epoch 4/5, Train Loss: 4.959557330505601, WER Score: 1.843065890566444


Training: 100%|██████████| 2320/2320 [09:24<00:00,  4.11it/s]
Evaluating: 100%|██████████| 15875/15875 [01:55<00:00, 137.07it/s]

Epoch 5/5, Train Loss: 4.650380419965448, WER Score: 1.9395529243622893



