In [1]:
import requests
import json
import torch
import torch.nn as nn
import os
from tqdm import tqdm

In [2]:
!pip install transformers

Defaulting to user installation because normal site-packages is not writeable


In [3]:
with open('spoken_train-v1.1.json', 'rb') as f:
  squad = json.load(f)

In [4]:
squad['data'][0].keys()

dict_keys(['title', 'paragraphs'])

In [5]:
gr = -1
for idx in range(len(squad['data'])):
    print(squad['data'][idx]['title'])
    if squad['data'][idx]['title'] == 'Greece':
        gr = idx
        print(gr)
        break

University_of_Notre_Dame
Beyoncé
Montana
Genocide
Antibiotics
Frédéric_Chopin
Sino-Tibetan_relations_during_the_Ming_dynasty
IPod
The_Legend_of_Zelda:_Twilight_Princess
Spectre_(2015_film)
2008_Sichuan_earthquake
New_York_City
To_Kill_a_Mockingbird
Solar_energy
Tajikistan
Anthropology
Portugal
Kanye_West
Buddhism
American_Idol
Dog
2008_Summer_Olympics_torch_relay
Alfred_North_Whitehead
Financial_crisis_of_2007%E2%80%9308
Saint_Barth%C3%A9lemy
Genome
Comprehensive_school
Republic_of_the_Congo
Prime_minister
Institute_of_technology
Wayback_Machine
Dutch_Republic
Symbiosis
Canadian_Armed_Forces
Cardinal_(Catholicism)
Iranian_languages
Lighting
Separation_of_powers_under_the_United_States_Constitution
Architecture
Human_Development_Index
Southern_Europe
BBC_Television
Arnold_Schwarzenegger
Plymouth
Heresy
Warsaw_Pact
Materialism
Space_Race
Pub
Christian
Sony_Music_Entertainment
Oklahoma_City
Hunter-gatherer
United_Nations_Population_Fund
Russian_Soviet_Federative_Socialist_Republic
Univers

In [6]:
squad['data'][186]['paragraphs'][0]['context']

'napoleon bonaparte nato mean l l j e n french nap led the napa tea or nepali on deep wanna parte the fifteenth of august seventeen sixty nine to the fifth of may eighteen twenty one was a french military and political leader who rose to prominence during the french revolution and led several successful campaigns during the revolutionary war sir. as napoleon i he was emperor of the french from eighty know for intel eighteen fourteen and again in eighteen fifteen. napoleon dominated european in global affairs for more than a decade while leaving france against a series of coalitions in the napoleonic wars. he won most of these wars and the vast majority of his battles building a large empire that ruled over continental europe before its final collapse in eighteen fifteen. often considered one of the greatest commanders in history his wars in campaigns are studied at military schools worldwide. he also remains one of the most celebrated and controversial political figures in western hist

In [7]:
def read_data(path):  
  # load the json file
  with open(path, 'rb') as f:
    squad = json.load(f)

  contexts, questions, answers = [], [], []
  contexts.extend(passage['context'] for group in squad['data'] for passage in group['paragraphs'] for qa in passage['qas'] for _ in qa['answers'])
  questions.extend(qa['question'] for group in squad['data'] for passage in group['paragraphs'] for qa in passage['qas'] for _ in qa['answers'])
  answers.extend(answer for group in squad['data'] for passage in group['paragraphs'] for qa in passage['qas'] for answer in qa['answers'])

  return contexts, questions, answers

In [8]:
train_contexts, train_questions, train_answers = read_data('spoken_train-v1.1.json')
valid_contexts, valid_questions, valid_answers = read_data('spoken_test-v1.1.json')
     

In [9]:
print(f'There are {len(train_questions)} questions')
print(train_questions[-10000])
print(train_answers[-10000])

There are 37111 questions
What country borders south Estonia?
{'answer_start': 262, 'text': 'latvia'}


In [10]:
def add_end_idx(answers, contexts):
  for i in range(len(answers)):
    gold_text = answers[i]['text']
    start_idx = answers[i]['answer_start']
    end_idx = start_idx + len(gold_text)

    # Check the exact match first, then adjust if necessary
    if contexts[i][start_idx:end_idx] == gold_text:
      answers[i]['answer_end'] = end_idx
    else:
      # Try shifting the start and end indices until the match is found
      for shift in range(1, 3):
        if contexts[i][start_idx - shift:end_idx - shift] == gold_text:
          answers[i]['answer_start'] = start_idx - shift
          answers[i]['answer_end'] = end_idx - shift
          break  # Stop once the correct indices are found

add_end_idx(train_answers, train_contexts)
add_end_idx(valid_answers, valid_contexts)


In [12]:
print(train_questions[-10000])
print(train_answers[-10000])

What country borders south Estonia?
{'answer_start': 262, 'text': 'latvia', 'answer_end': 268}


In [14]:
import transformers
from transformers import BertModel, BertTokenizerFast, AdamW
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ExponentialLR
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

num_questions, num_posible, num_imposible = 0, 0, 0

def get_data(path):
    with open(path, 'rb') as f:
        raw_data = json.load(f)
    contexts, questions, answers = [], [], []
    num_q, num_pos, num_imp = 0, 0, 0

    for group in raw_data['data']:
        for paragraph in group['paragraphs']:
            context = paragraph['context']
            for qa in paragraph['qas']:
                num_q += 1
                for answer in qa['answers']:
                    contexts.append(context.lower())
                    questions.append(qa['question'].lower())
                    answers.append(answer)
    return num_q, num_pos, num_imp, contexts, questions, answers

num_q, num_pos, num_imp, train_contexts, train_questions, train_answers = get_data('spoken_train-v1.1.json')
num_questions, num_posible, num_imposible = num_q, num_pos, num_imp
num_q, num_pos, num_imp, valid_contexts, valid_questions, valid_answers = get_data('spoken_test-v1.1.json')

def add_answer_end(answers, contexts):
    for answer, context in zip(answers, contexts):
        answer['text'] = answer['text'].lower()
        answer['answer_end'] = answer['answer_start'] + len(answer['text'])

add_answer_end(train_answers, train_contexts)
add_answer_end(valid_answers, valid_contexts)

MAX_LENGTH = 512
MODEL_PATH = "bert-base-uncased"
doc_stride = 128
tokenizerFast = BertTokenizerFast.from_pretrained(MODEL_PATH)
pad_on_right = tokenizerFast.padding_side == "right"

train_contexts_trunc = []
for i in range(len(train_contexts)):
    if len(train_contexts[i]) > MAX_LENGTH:
        answer_start = train_answers[i]['answer_start']
        answer_end = train_answers[i]['answer_start'] + len(train_answers[i]['text'])
        mid = (answer_start + answer_end) // 2
        para_start = max(0, min(mid - MAX_LENGTH // 2, len(train_contexts[i]) - MAX_LENGTH))
        para_end = para_start + MAX_LENGTH
        train_contexts_trunc.append(train_contexts[i][para_start:para_end])
        train_answers[i]['answer_start'] = int((MAX_LENGTH / 2) - len(train_answers[i]['text']) // 2)
    else:
        train_contexts_trunc.append(train_contexts[i])

train_encodings_fast = tokenizerFast(
    train_questions,
    train_contexts_trunc,
    max_length=MAX_LENGTH,
    truncation=True,
    stride=doc_stride,
    padding=True
)

valid_encodings_fast = tokenizerFast(
    valid_questions,
    valid_contexts,
    max_length=MAX_LENGTH,
    truncation=True,
    stride=doc_stride,
    padding=True
)

def ret_Answer_start_and_end(encodings, answers, idx):
    ret_start, ret_end = 0, 0
    answer_encoding_fast = tokenizerFast(answers[idx]['text'], max_length=MAX_LENGTH, truncation=True, padding=True)
    input_ids = encodings['input_ids'][idx]

    for a in range(len(input_ids) - len(answer_encoding_fast['input_ids'])):
        if input_ids[a + 1: a + 1 + len(answer_encoding_fast['input_ids']) - 2] == answer_encoding_fast['input_ids'][1:-1]:
            ret_start = a + 1
            ret_end = ret_start + len(answer_encoding_fast['input_ids']) - 2
            break
    return ret_start, ret_end

def update_positions(encodings, answers):
    start_positions, end_positions = [], []
    ctr = 0
    for h in range(len(encodings['input_ids'])):
        s, e = ret_Answer_start_and_end(encodings, answers, h)
        start_positions.append(s)
        end_positions.append(e)
        if s == 0:
            ctr += 1
    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})
    print(ctr)

update_positions(train_encodings_fast, train_answers)
update_positions(valid_encodings_fast, valid_answers)

class InputDataset(Dataset):
    def __init__(self, encodings):
        self.encodings = encodings
    def __getitem__(self, i):
        return {key: torch.tensor(val[i]) for key, val in self.encodings.items()}
    def __len__(self):
        return len(self.encodings['input_ids'])

train_dataset = InputDataset(train_encodings_fast)
valid_dataset = InputDataset(valid_encodings_fast)
train_data_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
valid_data_loader = DataLoader(valid_dataset, batch_size=1)

bert_model = BertModel.from_pretrained(MODEL_PATH)

class QAModel(nn.Module):
    def __init__(self):
        super(QAModel, self).__init__()
        self.bert = bert_model
        self.drop_out = nn.Dropout(0.1)
        self.linear_relu_stack = nn.Sequential(
            self.drop_out,
            nn.Linear(768 * 2, 768 * 2),
            nn.LeakyReLU(),
            nn.Linear(768 * 2, 2)
        )
        
    def forward(self, input_ids, attention_mask, token_type_ids):
        model_output = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, output_hidden_states=True)
        hidden_states = torch.cat((model_output.hidden_states[-1], model_output.hidden_states[-3]), dim=-1)
        logits = self.linear_relu_stack(hidden_states)
        
        start_logits, end_logits = logits.split(1, dim=-1)
        return start_logits.squeeze(-1), end_logits.squeeze(-1)

model = QAModel()

def loss_fn(start_logits, end_logits, start_positions, end_positions):
    loss_fct = nn.CrossEntropyLoss()
    start_loss = loss_fct(start_logits, start_positions)
    end_loss = loss_fct(end_logits, end_positions)
    return (start_loss + end_loss) / 2

def focal_loss_fn(start_logits, end_logits, start_positions, end_positions, gamma=1):
    smax = nn.Softmax(dim=1)
    inv_probs_start = 1 - smax(start_logits)
    inv_probs_end = 1 - smax(end_logits)

    lsmax = nn.LogSoftmax(dim=1)
    log_probs_start, log_probs_end = lsmax(start_logits), lsmax(end_logits)
    
    nll = nn.NLLLoss()
    fl_start = nll(inv_probs_start.pow(gamma) * log_probs_start, start_positions)
    fl_end = nll(inv_probs_end.pow(gamma) * log_probs_end, end_positions)
    
    return (fl_start + fl_end) / 2

optim = AdamW(model.parameters(), lr=2e-5, weight_decay=2e-2)
EPOCHS = 3
scheduler = transformers.get_linear_schedule_with_warmup(optim, num_warmup_steps=0, num_training_steps=len(train_dataset) * EPOCHS)

total_acc, total_loss, f1_scores = [], [], []

def train_epoch(model, dataloader, epoch):
    model.train()
    losses, acc = [], []
    batch_tracker = 0

    for batch in tqdm(dataloader, desc='Running Epoch'):
        optim.zero_grad()
        input_ids, attention_mask, token_type_ids = batch['input_ids'].to(device), batch['attention_mask'].to(device), batch['token_type_ids'].to(device)
        start_positions, end_positions = batch['start_positions'].to(device), batch['end_positions'].to(device)

        out_start, out_end = model(input_ids, attention_mask, token_type_ids)
        loss = focal_loss_fn(out_start, out_end, start_positions, end_positions)
        losses.append(loss.item())
        loss.backward()
        optim.step()
        
        start_pred, end_pred = torch.argmax(out_start, dim=1), torch.argmax(out_end, dim=1)
        acc.extend([(start_pred == start_positions).float().mean().item(), (end_pred == end_positions).float().mean().item()])
        
        batch_tracker += 1
        if batch_tracker == 250 and epoch == 1:
            total_acc.append(sum(acc))
            total_loss.append(sum(losses) / len(losses))
            batch_tracker = 0
    scheduler.step()
    return sum(acc) / len(acc), sum(losses) / len(losses)

def eval_model(model, dataloader):
    model.eval()
    acc, answer_list, f1_true, f1_pred = [], [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Running Evaluation'):
            input_ids, attention_mask, token_type_ids = batch['input_ids'].to(device), batch['attention_mask'].to(device), batch['token_type_ids'].to(device)
            start_true, end_true = batch['start_positions'].to(device), batch['end_positions'].to(device)

            out_start, out_end = model(input_ids, attention_mask, token_type_ids)
            start_pred, end_pred = torch.argmax(out_start, dim=1), torch.argmax(out_end, dim=1)
            f1_true.extend(start_true.cpu().numpy())
            f1_pred.extend(start_pred.cpu().numpy())
            answer = tokenizerFast.convert_tokens_to_string(tokenizerFast.convert_ids_to_tokens(input_ids[0][start_pred:end_pred]))
            tanswer = tokenizerFast.convert_tokens_to_string(tokenizerFast.convert_ids_to_tokens(input_ids[0][start_true[0]:end_true[0]]))
            answer_list.append([answer, tanswer])
            acc.extend([(start_pred == start_true).item(), (end_pred == end_true).item()])
    
    # Calculate F1 Score
    f1 = f1_score(f1_true, f1_pred, average='weighted')
    f1_scores.append(f1)
    print(f"F1 Score: {f1}")
    return answer_list

from evaluate import load
wer = load("wer")
model.to(device)
wer_list = []

for epoch in range(EPOCHS):
    train_acc, train_loss = train_epoch(model, train_data_loader, epoch + 1)
    print(f"Train Accuracy: {train_acc}      Train Loss: {train_loss}")
    answer_list = eval_model(model, valid_data_loader)
    
    pred_answers, true_answers = [], []
    for pred, true in answer_list:
        pred_answers.append(pred if pred else "$")
        true_answers.append(true if true else "$")
    
    wer_score = wer.compute(predictions=pred_answers, references=true_answers)
    print("epoch", epoch, ":", wer_score)
    wer_list.append(wer_score)

with open("base_model_wer.txt", 'w') as f:
    for s in wer_list:
        f.write(f"{s}\n")

print("WER List:", wer_list)
print("F1 Scores:", f1_scores)


cuda
609
261


Running Epoch: 100%|██████████| 4639/4639 [04:48<00:00, 16.06it/s]


Train Accuracy: 0.5751643673269675      Train Loss: 1.2458413775193586


Running Evaluation: 100%|██████████| 15875/15875 [02:48<00:00, 94.17it/s]


F1 Score: 0.5317225191231377
epoch 0 : 1.6405079857492548


Running Epoch: 100%|██████████| 4639/4639 [04:49<00:00, 16.02it/s]


Train Accuracy: 0.7376185600377023      Train Loss: 0.6111324815893449


Running Evaluation: 100%|██████████| 15875/15875 [02:49<00:00, 93.78it/s]


F1 Score: 0.5432685910882615
epoch 1 : 1.414313758755241


Running Epoch: 100%|██████████| 4639/4639 [04:49<00:00, 16.02it/s]


Train Accuracy: 0.8213535090745793      Train Loss: 0.3527273701356917


Running Evaluation: 100%|██████████| 15875/15875 [02:48<00:00, 94.19it/s]


F1 Score: 0.5364986629137747
epoch 2 : 1.5088824798235623
WER List: [1.6405079857492548, 1.414313758755241, 1.5088824798235623]
F1 Scores: [0.5317225191231377, 0.5432685910882615, 0.5364986629137747]
