In [None]:
# FINETUNING 
# Creates a linear scheduler for the optimizer/ Doc stride/ Gradient Accumulation

In [2]:
import torch
import json
from transformers import DistilBertTokenizerFast, DistilBertForQuestionAnswering, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from torch.optim import AdamW
import jiwer

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def load_data_file(path):
    with open(path, 'r') as f:
        raw_data = json.load(f)
    contexts = []
    questions = []
    answers = []

    for group in raw_data['data']:
        for paragraph in group['paragraphs']:
            context = paragraph['context']
            for qa in paragraph['qas']:
                question = qa['question']
                for answer in qa['answers']:
                    contexts.append(context.lower())
                    questions.append(question.lower())
                    answers.append(answer)

    return contexts, questions, answers

def add_answer_end_positions(answers, contexts):
    for answer, context in zip(answers, contexts):
        answer['text'] = answer['text'].lower()
        answer['answer_end'] = answer['answer_start'] + len(answer['text'])

def split_context(context, max_length, doc_stride):
    # Split context into overlapping windows
    tokenized_context = distilbert_tokenizer(context, max_length=max_length, padding='max_length', truncation=True, return_tensors='pt')
    tokens = distilbert_tokenizer.convert_ids_to_tokens(tokenized_context['input_ids'][0])
    windows = []
    for i in range(0, len(tokens), doc_stride):
        window = {'input_ids': tokenized_context['input_ids'][0][i:i+max_length],
                  'attention_mask': tokenized_context['attention_mask'][0][i:i+max_length]}
        windows.append(window)
    return windows

class QADataset(Dataset):
    def __init__(self, encodings, answers):
        self.encodings = encodings
        self.answers = answers

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['start_positions'] = torch.tensor(self.answers[idx]['answer_start'])
        item['end_positions'] = torch.tensor(self.answers[idx]['answer_end'])
        return item

    def __len__(self):
        return len(self.encodings.input_ids)

train_contexts, train_questions, train_answers = load_data_file('spoken_train-v1.1.json')
valid_contexts, valid_questions, valid_answers = load_data_file('spoken_test-v1.1.json')

add_answer_end_positions(train_answers, train_contexts)
add_answer_end_positions(valid_answers, valid_contexts)

MAX_LENGTH = 512
MODEL_PATH = "distilbert-base-uncased"
doc_stride = 128

distilbert_tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_PATH)

train_encodings_fast = distilbert_tokenizer(train_questions, train_contexts, max_length=MAX_LENGTH, padding='max_length', truncation=True)
valid_encodings_fast = distilbert_tokenizer(valid_questions, valid_contexts, max_length=MAX_LENGTH, padding='max_length', truncation=True)

train_dataset = QADataset(train_encodings_fast, train_answers)
valid_dataset = QADataset(valid_encodings_fast, valid_answers)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=1)

distilbert_model = DistilBertForQuestionAnswering.from_pretrained(MODEL_PATH).to(device)

optimizer = AdamW(distilbert_model.parameters(), lr=5e-5)

EPOCHS = 5  
total_steps = len(train_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

def train_epoch(model, dataloader, optimizer, scheduler):
    model.train()
    total_loss = 0.0
    for batch in tqdm(train_loader, desc='Training'):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)
        
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions)
        loss = outputs.loss
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        scheduler.step()

    return total_loss / len(dataloader)

def evaluate_model(model, dataloader):
    model.eval()
    wer_list = []
    for batch in tqdm(dataloader, desc='Evaluating'):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_true = batch['start_positions'].to(device)
        end_true = batch['end_positions'].to(device)
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        
        pred_start_logits = outputs.start_logits
        pred_end_logits = outputs.end_logits
        
        for i in range(len(input_ids)):
            context = valid_contexts[i]
            windows = split_context(context, MAX_LENGTH, doc_stride)
            pred_start_idx = torch.argmax(pred_start_logits[i])
            pred_end_idx = torch.argmax(pred_end_logits[i])
            
            # Merge predictions from overlapping windows
        for window in windows:
            if window['input_ids'].nonzero().numel() > 0:  # Check if there are non-zero elements
                window_start = window['input_ids'].nonzero().min().item()
                window_end = len(window['input_ids']) - window['input_ids'].flip(dims=[0]).nonzero().min().item()
                if pred_start_idx >= window_start and pred_end_idx < window_end:
                    pred_start_idx += window_start
                    pred_end_idx += window_start
                    break

            pred_answer = distilbert_tokenizer.decode(input_ids[i][pred_start_idx:pred_end_idx+1])
            true_answer = distilbert_tokenizer.decode(input_ids[i][start_true[i]:end_true[i]+1])
            if true_answer.strip():  # Check if true_answer is not empty
                wer = jiwer.wer(true_answer, pred_answer)
                wer_list.append(wer)
    return sum(wer_list) / len(wer_list) if wer_list else 0.0

for epoch in range(EPOCHS):
    train_loss = train_epoch(distilbert_model, train_loader, optimizer, scheduler)
    wer_score = evaluate_model(distilbert_model, valid_loader)
    print(f"Epoch {epoch+1}/{EPOCHS}, Train Loss: {train_loss}, WER Score: {wer_score}")


Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Training: 100%|██████████| 2320/2320 [09:18<00:00,  4.16it/s]
Evaluating: 100%|██████████| 15875/15875 [02:27<00:00, 107.40it/s]


Epoch 1/5, Train Loss: 5.897547473167551, WER Score: 10.91984426194956


Training: 100%|██████████| 2320/2320 [09:18<00:00,  4.15it/s]
Evaluating: 100%|██████████| 15875/15875 [02:27<00:00, 107.50it/s]


Epoch 2/5, Train Loss: 5.33816732949224, WER Score: 10.128574943269985


Training: 100%|██████████| 2320/2320 [09:18<00:00,  4.15it/s]
Evaluating: 100%|██████████| 15875/15875 [02:21<00:00, 112.39it/s]


Epoch 3/5, Train Loss: 4.769730858453389, WER Score: 4.724774392011088


Training: 100%|██████████| 2320/2320 [09:18<00:00,  4.15it/s]
Evaluating: 100%|██████████| 15875/15875 [02:18<00:00, 114.66it/s]


Epoch 4/5, Train Loss: 4.291139659388312, WER Score: 3.218864884354455


Training: 100%|██████████| 2320/2320 [09:18<00:00,  4.15it/s]
Evaluating: 100%|██████████| 15875/15875 [02:18<00:00, 114.76it/s]

Epoch 5/5, Train Loss: 3.954296954233071, WER Score: 2.373607447415738



