In [None]:
# Pinned dependencies for this notebook (installed via shell escapes).
# NOTE(review): torchtext 0.15.x appears to be the release line paired with
# torch 2.0.x (torchtext 0.16.0 is the one paired with torch 2.1.0) — confirm
# that these two pins are actually compatible before relying on them.
!pip install torchtext==0.15.1
!pip install torch==2.1.0
!pip install transformers==4.27.1
!pip install datasets==2.17.0

In [None]:
import torch
import torch.nn as nn
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset

# Load the SQuAD training split, cut it into 40 shards and keep only shard 0
# (presumably to keep the experiment small). The bare expression on the last
# line displays the resulting dataset summary.
qa_dataset = load_dataset('squad', split='train').shard(num_shards=40, index=0)
qa_dataset

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading readme:   0%|          | 0.00/7.83k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/14.5M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.82M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/87599 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10570 [00:00<?, ? examples/s]

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 2190
})

In [None]:
import re

def text_normalize(text):
    """Blank out punctuation-like characters in ``text``.

    Every character that is neither a word character (``\\w``) nor whitespace
    (``\\s``) is replaced by a single space; everything else is kept as is.
    The string length is therefore unchanged.
    """
    non_word_or_space = re.compile(r'[^\w\s]')
    return non_word_or_space.sub(' ', text)

# Tokenizer callable obtained from torchtext's `get_tokenizer` with the
# 'basic_english' preset. Elsewhere in this notebook it is called with a string
# and its result is iterated / measured with len() as a sequence of tokens.
tokenizer = get_tokenizer('basic_english')

# Generator feeding the vocabulary builder: one token list per dataset record
def yield_tokens(data):
    """Yield the token list of each record's combined context/question text.

    Each element of ``data`` is indexed with the keys ``'context'`` and
    ``'question'``. Both fields go through ``text_normalize`` and are joined as
    ``<cls> context <sep> question`` before being handed to ``tokenizer``.
    """
    for record in data:
        context_part = text_normalize(record['context'])
        question_part = text_normalize(record['question'])
        yield tokenizer('<cls> ' + context_part + ' <sep> ' + question_part)

# Create vocabulary from the token lists produced by `yield_tokens`, with six
# special tokens registered explicitly.
vocab = build_vocab_from_iterator(
    yield_tokens(qa_dataset),
    specials=['<unk>', '<pad>', '<bos>', '<eos>', '<sep>', '<cls>']
)
# Tokens that are not in the vocabulary resolve to the index of '<unk>'.
vocab.set_default_index(vocab['<unk>'])
# Displays the complete token -> index mapping as the cell output.
# NOTE(review): this dumps the whole vocabulary; consider showing only
# `len(vocab)` or a small sample to keep the notebook output readable.
vocab.get_stoi()

{'하': 26911,
 '큰': 26910,
 '소': 26905,
 '사람': 26904,
 '나라이름': 26901,
 '魄': 26900,
 '魂': 26899,
 '金陵邑': 26897,
 '越城': 26892,
 '豊臣秀吉': 26891,
 '義皇帝': 26889,
 '禅': 26885,
 '现代汉语通用字表': 26884,
 '浙江': 26880,
 '法王': 26879,
 '汉语水平考试': 26877,
 '水': 26875,
 '校尉': 26874,
 '李閏': 26873,
 '方块字': 26870,
 '平安': 26868,
 '小': 26867,
 '唐入り': 26863,
 '北軍': 26861,
 '凹田': 26860,
 '冶城': 26859,
 '下': 26856,
 'トワイライトプリンセス': 26854,
 'ゼルダの伝説': 26851,
 'ἥλιος': 26850,
 'โรงเร': 26845,
 'สลาม': 26844,
 'ยนศาสนาอ': 26843,
 'ब': 26842,
 'نصراني': 26840,
 'نصارى': 26839,
 'العربية': 26835,
 'языка': 26831,
 'хийума': 26829,
 'стиль': 26827,
 'русского': 26826,
 'рмонтов': 26825,
 'пу': 26824,
 'никола': 26823,
 'насекомое': 26822,
 'н': 26821,
 'мок': 26820,
 'михаи': 26819,
 'кий': 26816,
 'замо': 26812,
 'дов': 26810,
 'голь': 26807,
 'высо': 26805,
 'χαλκός': 26800,
 'φαναῖος': 26798,
 'τεχνικά': 26796,
 'πικρό': 26795,
 'πάνορμος': 26794,
 'λύκη': 26793,
 'λύκειος': 26792,
 'λύκειο': 26791,
 'λόγος': 26790,
 'λυκ

In [None]:
# Fixed length that every encoded example is padded or truncated to.
MAX_SEQ_LEN = 512
# Vocabulary index of the '<pad>' token; used as the fill value when padding.
PAD_IDX = vocab['<pad>']

def pad_and_truncate(input_ids, max_seq_len, pad_idx=None):
    """Return a new list holding ``input_ids`` cut or padded to ``max_seq_len``.

    The input sequence is never modified: longer inputs are sliced, shorter
    ones are copied and right-padded.

    Args:
        input_ids: Sequence of token ids.
        max_seq_len: Exact length of the returned list.
        pad_idx: Value appended when padding is needed. Defaults to the
            module-level ``PAD_IDX`` (looked up only if padding is required).

    Returns:
        A list of length ``max_seq_len``.
    """
    # Slicing always produces a fresh sequence, so the caller's list is safe.
    fitted = list(input_ids[:max_seq_len])

    missing = max_seq_len - len(fitted)
    if missing > 0:
        if pad_idx is None:
            pad_idx = PAD_IDX
        fitted.extend([pad_idx] * missing)

    return fitted

def vectorize(question, context, answer):
    """Encode one QA example into fixed-length token ids plus answer span labels.

    The model input is ``<cls> question <sep> context`` (both texts passed
    through ``text_normalize``), tokenized, mapped to vocabulary ids and
    padded/truncated to ``MAX_SEQ_LEN``.

    The answer span is located by searching the context part of the encoded
    sequence (everything after the first ``<sep>``) for the complete run of
    answer token ids. Matching the whole span, rather than only its first
    token, avoids labelling an unrelated earlier occurrence of a common word,
    and guarantees that the end position lies inside the sequence.

    Args:
        question: Question text.
        context: Passage text the answer is expected to occur in.
        answer: Answer text.

    Returns:
        Tuple ``(input_ids, start_positions, end_positions)`` of ``torch.long``
        tensors. ``input_ids`` has length ``MAX_SEQ_LEN``; the other two are
        scalars holding inclusive token positions. Both positions are 0 (the
        ``<cls>`` slot) when the answer cannot be located, e.g. because it was
        truncated away or tokenizes differently inside the context.
    """
    input_text = '<cls> ' + text_normalize(question) + ' <sep> ' + text_normalize(context)
    input_ids = [vocab[token] for token in tokenizer(input_text)]
    input_ids = pad_and_truncate(input_ids, MAX_SEQ_LEN)

    answer_ids = [vocab[token] for token in tokenizer(text_normalize(answer))]

    # Fallback label when no span is found.
    start_positions = 0
    end_positions = 0

    sep_idx = vocab['<sep>']
    if answer_ids and sep_idx in input_ids:
        span_len = len(answer_ids)
        # Only look at the context; the question may repeat answer words.
        first_context_pos = input_ids.index(sep_idx) + 1
        last_candidate = len(input_ids) - span_len
        for candidate in range(first_context_pos, last_candidate + 1):
            if input_ids[candidate:candidate + span_len] == answer_ids:
                start_positions = candidate
                end_positions = candidate + span_len - 1
                break

    input_ids = torch.tensor(input_ids, dtype=torch.long)
    start_positions = torch.tensor(start_positions, dtype=torch.long)
    end_positions = torch.tensor(end_positions, dtype=torch.long)

    return input_ids, start_positions, end_positions

In [None]:
class QADataset(Dataset):
    """Map-style dataset that encodes raw QA records on access.

    ``data`` is any indexable, sized collection whose items provide the keys
    ``'question'``, ``'context'`` and ``'answers'`` (the latter with a
    ``'text'`` list). Encoding is delegated to ``vectorize``.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        """Number of records in the wrapped collection."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_ids, start_positions, end_positions)`` for record ``idx``.

        Only the first entry of the record's answer texts is used.
        """
        record = self.data[idx]
        encoded_ids, answer_start, answer_end = vectorize(
            record['question'],
            record['context'],
            record['answers']['text'][0],
        )
        return encoded_ids, answer_start, answer_end

In [None]:
def decode(input_ids):
    """Turn a sequence of token ids back into a space-separated string.

    Each id is resolved with ``vocab.lookup_token`` and the resulting tokens
    are joined with single spaces.
    """
    tokens = map(vocab.lookup_token, input_ids)
    return ' '.join(tokens)

In [None]:
# Sanity-check the encoding on a single example: the loop body runs for the
# first item of `qa_dataset` and then exits through the `break` at the end.
for item in qa_dataset:
    question_text = item['question']
    context_text = item['context']
    # Only the first listed answer text is used.
    answer_text = item['answers']['text'][0]

    input_ids, start_positions, end_positions = vectorize(question_text, context_text, answer_text)
    print(input_ids)
    # Decode the full encoded sequence, then slice out the labelled answer
    # span (the end position is inclusive, hence the +1).
    text = decode(input_ids)
    answer_span = input_ids[start_positions:end_positions+1]

    print(text)
    print(decode(answer_span))

    break

tensor([    5,    10,  1256,    48,     6,  1420,   587,  8195,  1478,     9,
         4080,     9,  9276,   244,     4, 10672,     6,   188,    34,    11,
          767,  1326,  5138,     6,   371,   380,    17,  1391,  4218,    13,
           11,  1993,  3680,     7,     6,  1420,   587,  1186,     9,   850,
            7,     6,   371,   380,     8,  6072,    26,    13,    11,  1279,
         3680,     7,  1105,    18,  1434, 26013,    18,     6,  3309, 26127,
         1169,  2109, 22332,   614,    10,     6,   371,   380,    13,     6,
         4142,     7,     6,  2401,  2348,  1186,  1024,     6,  4142,    13,
            6, 12012,    11, 12672,   222,     7,  3966,     8, 23630,    26,
           13,    11, 13475,     7,     6, 12012,    27,  9276,   244,    80,
            6,  1420,   587, 23759,  1376,    10,   645, 16126, 24746,     9,
         4080,    27,     6,   169,     7,     6,   371,  1701,     8,     9,
           11,   584,   282,    19,  5955,   116,   118,  4429, 

In [None]:
# Number of examples per training batch.
BATCH_SIZE = 256
# Wrap the sharded SQuAD subset so that items are encoded on access, and serve
# them in shuffled batches.
train_dataset = QADataset(qa_dataset)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
import math
import torch.nn as nn
import torch.optim as optim

class TransformerBlock(nn.Module):
    """Single post-norm Transformer encoder block: attention, then feed-forward.

    Both sub-layers are wrapped in a residual connection followed by layer
    normalisation, so the output has the same shape as ``query``.

    Args:
        embed_dim: Width of the input and output features.
        num_heads: Number of attention heads (must divide ``embed_dim``).
        ff_dim: Hidden width of the position-wise feed-forward network. It may
            differ from ``embed_dim`` because the network projects back to
            ``embed_dim`` before the residual addition.
        batch_first: When True (default) inputs are ``(batch, seq, embed_dim)``,
            which is the layout produced by ``nn.Embedding`` on ``(batch, seq)``
            ids. When False inputs are ``(seq, batch, embed_dim)``.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, batch_first=True):
        super().__init__()
        # nn.MultiheadAttention defaults to sequence-first input; without
        # batch_first=True a (batch, seq, dim) tensor would be attended over
        # the batch axis, mixing unrelated samples.
        self.attn = nn.MultiheadAttention(embed_dim=embed_dim,
                                          num_heads=num_heads,
                                          batch_first=batch_first)
        # Expand to ff_dim, apply a non-linearity, project back to embed_dim so
        # the residual addition below is valid for any ff_dim.
        self.ffn = nn.Sequential(
            nn.Linear(in_features=embed_dim, out_features=ff_dim),
            nn.ReLU(),
            nn.Linear(in_features=ff_dim, out_features=embed_dim),
        )
        self.layernorm_1 = nn.LayerNorm(normalized_shape=embed_dim)
        self.layernorm_2 = nn.LayerNorm(normalized_shape=embed_dim)

    def forward(self, query, key, value):
        """Apply attention and feed-forward sub-layers; returns a tensor shaped like ``query``."""
        attn_output, _ = self.attn(query, key, value)
        out_1 = self.layernorm_1(query + attn_output)
        ffn_output = self.ffn(out_1)
        x = self.layernorm_2(out_1 + ffn_output)

        return x

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    Args:
        d_model: Feature width of the embeddings (odd values are supported).
        max_len: Largest sequence length for which encodings are precomputed.
        batch_first: When True (default) ``forward`` expects
            ``(batch, seq, d_model)`` and indexes positions along dim 1. When
            False it expects ``(seq, batch, d_model)`` and indexes along dim 0.
    """

    def __init__(self, d_model, max_len=5000, batch_first=True):
        super(PositionalEncoding, self).__init__()
        self.batch_first = batch_first
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Stored as (max_len, 1, d_model); a buffer follows the module across
        # devices without being a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``x`` plus the encoding of each sequence position."""
        if self.batch_first:
            # (seq, 1, d_model) -> (1, seq, d_model), broadcast over the batch.
            return x + self.pe[:x.size(1)].transpose(0, 1)

        return x + self.pe[:x.size(0)]

class QAModel(nn.Module):
    """Extractive QA model: embedding -> positional encoding -> one transformer block -> span heads.

    Args:
        vocab_size: Number of rows in the token embedding table.
        embedding_dim: Width of token embeddings and of the transformer output.
        n_heads: Number of attention heads in the transformer block.
        ff_dim: Feed-forward width passed to ``TransformerBlock``.
        max_len: Accepted for interface compatibility.
            NOTE(review): currently unused; the positional encoder is built
            with its own default maximum length.
    """

    def __init__(self, vocab_size, embedding_dim, n_heads, ff_dim, max_len):
        super(QAModel, self).__init__()
        self.input_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_encoder = PositionalEncoding(embedding_dim)
        self.transformer = TransformerBlock(embedding_dim, n_heads, ff_dim)

        # The transformer block ends in LayerNorm(embedding_dim), so its output
        # is embedding_dim wide; the heads must be sized by that, not ff_dim.
        self.start_linear = nn.Linear(embedding_dim, 1)
        self.end_linear = nn.Linear(embedding_dim, 1)

    def forward(self, text):
        """Score every token position as answer start and answer end.

        Args:
            text: Integer token ids; the notebook passes ``(batch, seq)``.

        Returns:
            ``(start_logits, end_logits)``, each with the trailing singleton
            dimension of the head output squeezed away.
        """
        input_embedded = self.input_embedding(text)
        input_embedded = self.pos_encoder(input_embedded)
        # Self-attention: the same tensor serves as query, key and value.
        transformer_out = self.transformer(input_embedded, input_embedded, input_embedded)
        start_logits = self.start_linear(transformer_out).squeeze(-1)
        end_logits = self.end_linear(transformer_out).squeeze(-1)

        return start_logits, end_logits

# Model parameters
EMBEDDING_DIM = 64
FF_DIM = 64
N_HEADS = 8
VOCAB_SIZE = len(vocab)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = QAModel(VOCAB_SIZE, EMBEDDING_DIM, N_HEADS, FF_DIM, MAX_SEQ_LEN).to(device)

# Smoke test: push one random (1, 10) batch of token ids through the model and
# check the logits shape. The tensor is deliberately not called `input`, which
# would shadow the Python builtin for the rest of the kernel session.
dummy_input = torch.randint(0, 10, size=(1, 10)).to(device)
print(dummy_input.shape)
model.eval()
with torch.no_grad():
    start_logits, end_logits = model(dummy_input)

print(start_logits.shape)

torch.Size([1, 10])
torch.Size([1, 10])


In [None]:
# Learning rate handed to the optimizer.
LR = 1e-2
# Adam over all model parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
# Cross-entropy loss; the training loop applies it to the start logits and to
# the end logits separately, with the labelled positions as targets.
criterion = nn.CrossEntropyLoss()

In [None]:
from tqdm import tqdm
# Number of full passes over `train_loader`.
EPOCHS = 150

# Switch the model to training mode once, before the first epoch.
model.train()
for epoch in tqdm(range(EPOCHS)):
    # Per-batch losses of this epoch; averaged after the inner loop.
    train_losses = []
    for idx, (input_ids, start_positions, end_positions) in enumerate(train_loader):
        # Move the batch to the same device as the model.
        input_ids = input_ids.to(device)
        start_positions = start_positions.to(device)
        end_positions = end_positions.to(device)
        optimizer.zero_grad()
        start_logits, end_logits = model(input_ids)
        # One loss per head: logits over positions vs. the labelled position.
        start_loss = criterion(start_logits, start_positions)
        end_loss = criterion(end_logits, end_positions)
        # The optimised objective is the mean of the two head losses.
        total_loss = (start_loss + end_loss) / 2
        total_loss.backward()
        optimizer.step()
        train_losses.append(total_loss.item())
    # Mean batch loss of the epoch (batches are weighted equally).
    train_loss = sum(train_losses) / len(train_losses)
    print(f'EPOCH {epoch + 1}\tTraining Loss: {train_loss}')

  1%|          | 1/150 [00:02<05:39,  2.28s/it]

EPOCH 1	Training Loss: 5.084600501590305


  1%|▏         | 2/150 [00:04<05:01,  2.03s/it]

EPOCH 2	Training Loss: 4.619819164276123


  2%|▏         | 3/150 [00:05<04:45,  1.94s/it]

EPOCH 3	Training Loss: 4.3725600772433815


  3%|▎         | 4/150 [00:08<04:59,  2.05s/it]

EPOCH 4	Training Loss: 4.090801927778456


  3%|▎         | 5/150 [00:09<04:42,  1.95s/it]

EPOCH 5	Training Loss: 3.7885384294721813


  4%|▍         | 6/150 [00:11<04:31,  1.89s/it]

EPOCH 6	Training Loss: 3.495330492655436


  5%|▍         | 7/150 [00:13<04:24,  1.85s/it]

EPOCH 7	Training Loss: 3.2894284990098743


  5%|▌         | 8/150 [00:15<04:20,  1.83s/it]

EPOCH 8	Training Loss: 3.130997313393487


  6%|▌         | 9/150 [00:17<04:17,  1.83s/it]

EPOCH 9	Training Loss: 2.976626475652059


  7%|▋         | 10/150 [00:19<04:26,  1.90s/it]

EPOCH 10	Training Loss: 2.81812940703498


  7%|▋         | 11/150 [00:21<04:25,  1.91s/it]

EPOCH 11	Training Loss: 2.7312698894076877


  8%|▊         | 12/150 [00:22<04:17,  1.86s/it]

EPOCH 12	Training Loss: 2.6438990698920355


  9%|▊         | 13/150 [00:24<04:12,  1.84s/it]

EPOCH 13	Training Loss: 2.5895522700415716


  9%|▉         | 14/150 [00:26<04:08,  1.83s/it]

EPOCH 14	Training Loss: 2.520303964614868


 10%|█         | 15/150 [00:28<04:04,  1.81s/it]

EPOCH 15	Training Loss: 2.4587179289923773


 11%|█         | 16/150 [00:30<04:04,  1.83s/it]

EPOCH 16	Training Loss: 2.4276238282521567


 11%|█▏        | 17/150 [00:32<04:16,  1.93s/it]

EPOCH 17	Training Loss: 2.345514270994398


 12%|█▏        | 18/150 [00:34<04:09,  1.89s/it]

EPOCH 18	Training Loss: 2.325292189915975


 13%|█▎        | 19/150 [00:35<04:04,  1.87s/it]

EPOCH 19	Training Loss: 2.287243790096707


 13%|█▎        | 20/150 [00:37<03:58,  1.83s/it]

EPOCH 20	Training Loss: 2.250607172648112


 14%|█▍        | 21/150 [00:39<03:54,  1.82s/it]

EPOCH 21	Training Loss: 2.2156767580244274


 15%|█▍        | 22/150 [00:41<03:52,  1.81s/it]

EPOCH 22	Training Loss: 2.196060578028361


 15%|█▌        | 23/150 [00:43<04:06,  1.94s/it]

EPOCH 23	Training Loss: 2.165381736225552


 16%|█▌        | 24/150 [00:45<03:58,  1.89s/it]

EPOCH 24	Training Loss: 2.1360323561562433


 17%|█▋        | 25/150 [00:47<03:53,  1.87s/it]

EPOCH 25	Training Loss: 2.1144505871666803


 17%|█▋        | 26/150 [00:48<03:49,  1.85s/it]

EPOCH 26	Training Loss: 2.083683000670539


 18%|█▊        | 27/150 [00:50<03:44,  1.82s/it]

EPOCH 27	Training Loss: 2.0773207479053073


 19%|█▊        | 28/150 [00:52<03:40,  1.81s/it]

EPOCH 28	Training Loss: 2.0411848492092557


 19%|█▉        | 29/150 [00:54<03:47,  1.88s/it]

EPOCH 29	Training Loss: 2.017866757180956


 20%|██        | 30/150 [00:56<03:50,  1.92s/it]

EPOCH 30	Training Loss: 1.9945792224672105


 21%|██        | 31/150 [00:58<03:44,  1.88s/it]

EPOCH 31	Training Loss: 1.9633305072784424


 21%|██▏       | 32/150 [00:59<03:38,  1.85s/it]

EPOCH 32	Training Loss: 1.9390025403764513


 22%|██▏       | 33/150 [01:01<03:34,  1.83s/it]

EPOCH 33	Training Loss: 1.9441976282331679


 23%|██▎       | 34/150 [01:03<03:31,  1.82s/it]

EPOCH 34	Training Loss: 1.9076038996378581


 23%|██▎       | 35/150 [01:05<03:30,  1.83s/it]

EPOCH 35	Training Loss: 1.866251164012485


 24%|██▍       | 36/150 [01:08<03:56,  2.07s/it]

EPOCH 36	Training Loss: 1.8830842574437459


 25%|██▍       | 37/150 [01:10<03:49,  2.03s/it]

EPOCH 37	Training Loss: 1.8686563703748915


 25%|██▌       | 38/150 [01:11<03:39,  1.96s/it]

EPOCH 38	Training Loss: 1.8196662134594388


 26%|██▌       | 39/150 [01:13<03:33,  1.92s/it]

EPOCH 39	Training Loss: 1.8365361955430772


 27%|██▋       | 40/150 [01:15<03:27,  1.88s/it]

EPOCH 40	Training Loss: 1.7979642020331488


 27%|██▋       | 41/150 [01:17<03:25,  1.88s/it]

EPOCH 41	Training Loss: 1.7866535186767578


 28%|██▊       | 42/150 [01:19<03:33,  1.98s/it]

EPOCH 42	Training Loss: 1.7713714175754123


 29%|██▊       | 43/150 [01:21<03:26,  1.93s/it]

EPOCH 43	Training Loss: 1.7144736051559448


 29%|██▉       | 44/150 [01:23<03:20,  1.89s/it]

EPOCH 44	Training Loss: 1.7190001143349543


 30%|███       | 45/150 [01:24<03:15,  1.86s/it]

EPOCH 45	Training Loss: 1.7166011995739408


 31%|███       | 46/150 [01:26<03:11,  1.84s/it]

EPOCH 46	Training Loss: 1.680895487467448


 31%|███▏      | 47/150 [01:28<03:07,  1.82s/it]

EPOCH 47	Training Loss: 1.6962439616521199


 32%|███▏      | 48/150 [01:30<03:16,  1.93s/it]

EPOCH 48	Training Loss: 1.6475027667151556


 33%|███▎      | 49/150 [01:32<03:12,  1.91s/it]

EPOCH 49	Training Loss: 1.5961262120140924


 33%|███▎      | 50/150 [01:34<03:07,  1.88s/it]

EPOCH 50	Training Loss: 1.6014500988854303


 34%|███▍      | 51/150 [01:36<03:04,  1.86s/it]

EPOCH 51	Training Loss: 1.585058675871955


 35%|███▍      | 52/150 [01:37<02:59,  1.83s/it]

EPOCH 52	Training Loss: 1.5958850516213312


 35%|███▌      | 53/150 [01:39<02:56,  1.82s/it]

EPOCH 53	Training Loss: 1.5361384153366089


 36%|███▌      | 54/150 [01:41<02:59,  1.87s/it]

EPOCH 54	Training Loss: 1.5345568259557087


 37%|███▋      | 55/150 [01:43<03:02,  1.93s/it]

EPOCH 55	Training Loss: 1.5266283353169758


 37%|███▋      | 56/150 [01:45<02:57,  1.89s/it]

EPOCH 56	Training Loss: 1.5213567283418443


 38%|███▊      | 57/150 [01:47<02:54,  1.88s/it]

EPOCH 57	Training Loss: 1.4893718825446234


 39%|███▊      | 58/150 [01:49<02:49,  1.85s/it]

EPOCH 58	Training Loss: 1.4576346344417996


 39%|███▉      | 59/150 [01:50<02:46,  1.83s/it]

EPOCH 59	Training Loss: 1.4395011530982122


 40%|████      | 60/150 [01:52<02:44,  1.83s/it]

EPOCH 60	Training Loss: 1.4310261408487956


 41%|████      | 61/150 [01:55<02:53,  1.95s/it]

EPOCH 61	Training Loss: 1.4050494035085042


 41%|████▏     | 62/150 [01:56<02:46,  1.90s/it]

EPOCH 62	Training Loss: 1.413007656733195


 42%|████▏     | 63/150 [01:58<02:42,  1.87s/it]

EPOCH 63	Training Loss: 1.3589330116907756


 43%|████▎     | 64/150 [02:00<02:39,  1.85s/it]

EPOCH 64	Training Loss: 1.3070355256398518


 43%|████▎     | 65/150 [02:02<02:35,  1.83s/it]

EPOCH 65	Training Loss: 1.3309393723805745


 44%|████▍     | 66/150 [02:03<02:33,  1.82s/it]

EPOCH 66	Training Loss: 1.2884756061765883


 45%|████▍     | 67/150 [02:06<02:40,  1.93s/it]

EPOCH 67	Training Loss: 1.308069109916687


 45%|████▌     | 68/150 [02:08<02:37,  1.93s/it]

EPOCH 68	Training Loss: 1.2729564640257094


 46%|████▌     | 69/150 [02:10<02:42,  2.00s/it]

EPOCH 69	Training Loss: 1.2710354460610285


 47%|████▋     | 70/150 [02:12<02:35,  1.95s/it]

EPOCH 70	Training Loss: 1.2479805019166734


 47%|████▋     | 71/150 [02:13<02:30,  1.91s/it]

EPOCH 71	Training Loss: 1.2255434062745836


 48%|████▊     | 72/150 [02:15<02:26,  1.88s/it]

EPOCH 72	Training Loss: 1.2208065059449937


 49%|████▊     | 73/150 [02:17<02:31,  1.97s/it]

EPOCH 73	Training Loss: 1.211789197391934


 49%|████▉     | 74/150 [02:19<02:28,  1.96s/it]

EPOCH 74	Training Loss: 1.1810348563724093


 50%|█████     | 75/150 [02:21<02:23,  1.91s/it]

EPOCH 75	Training Loss: 1.1550989813274808


 51%|█████     | 76/150 [02:23<02:19,  1.88s/it]

EPOCH 76	Training Loss: 1.1624592145284016


 51%|█████▏    | 77/150 [02:25<02:15,  1.86s/it]

EPOCH 77	Training Loss: 1.1415471368365817


 52%|█████▏    | 78/150 [02:27<02:12,  1.85s/it]

EPOCH 78	Training Loss: 1.1466715865665011


 53%|█████▎    | 79/150 [02:29<02:13,  1.88s/it]

EPOCH 79	Training Loss: 1.1157784660657246


 53%|█████▎    | 80/150 [02:31<02:15,  1.93s/it]

EPOCH 80	Training Loss: 1.1002569728427463


 54%|█████▍    | 81/150 [02:33<02:14,  1.95s/it]

EPOCH 81	Training Loss: 1.0946767727533977


 55%|█████▍    | 82/150 [02:34<02:09,  1.91s/it]

EPOCH 82	Training Loss: 1.0642831789122686


 55%|█████▌    | 83/150 [02:36<02:06,  1.88s/it]

EPOCH 83	Training Loss: 1.056271453698476


 56%|█████▌    | 84/150 [02:38<02:02,  1.86s/it]

EPOCH 84	Training Loss: 1.066340532567766


 57%|█████▋    | 85/150 [02:40<02:00,  1.85s/it]

EPOCH 85	Training Loss: 1.0618784957461886


 57%|█████▋    | 86/150 [02:42<02:05,  1.96s/it]

EPOCH 86	Training Loss: 1.0352265172534518


 58%|█████▊    | 87/150 [02:44<02:00,  1.92s/it]

EPOCH 87	Training Loss: 1.0303849180539448


 59%|█████▊    | 88/150 [02:46<01:56,  1.88s/it]

EPOCH 88	Training Loss: 1.016437702708774


 59%|█████▉    | 89/150 [02:48<01:54,  1.88s/it]

EPOCH 89	Training Loss: 0.9846938649813334


 60%|██████    | 90/150 [02:49<01:51,  1.86s/it]

EPOCH 90	Training Loss: 0.9588468472162882


 61%|██████    | 91/150 [02:51<01:48,  1.84s/it]

EPOCH 91	Training Loss: 0.988774471812778


 61%|██████▏   | 92/150 [02:53<01:54,  1.97s/it]

EPOCH 92	Training Loss: 0.9639181163575914


 62%|██████▏   | 93/150 [02:55<01:49,  1.92s/it]

EPOCH 93	Training Loss: 0.9536265995767381


 63%|██████▎   | 94/150 [02:57<01:45,  1.89s/it]

EPOCH 94	Training Loss: 0.9313975241449144


 63%|██████▎   | 95/150 [02:59<01:42,  1.86s/it]

EPOCH 95	Training Loss: 0.9257659051153395


 64%|██████▍   | 96/150 [03:01<01:39,  1.84s/it]

EPOCH 96	Training Loss: 0.9203978247112699


 65%|██████▍   | 97/150 [03:02<01:37,  1.83s/it]

EPOCH 97	Training Loss: 0.9302511612574259


 65%|██████▌   | 98/150 [03:05<01:39,  1.91s/it]

EPOCH 98	Training Loss: 0.9203450083732605


 66%|██████▌   | 99/150 [03:07<01:38,  1.94s/it]

EPOCH 99	Training Loss: 0.8865557577874925


 67%|██████▋   | 100/150 [03:08<01:34,  1.89s/it]

EPOCH 100	Training Loss: 0.8836608396636115


 67%|██████▋   | 101/150 [03:10<01:31,  1.86s/it]

EPOCH 101	Training Loss: 0.8680721190240648


 68%|██████▊   | 102/150 [03:12<01:28,  1.84s/it]

EPOCH 102	Training Loss: 0.8495398561159769


 69%|██████▊   | 103/150 [03:14<01:25,  1.83s/it]

EPOCH 103	Training Loss: 0.8457812401983473


 69%|██████▉   | 104/150 [03:16<01:25,  1.85s/it]

EPOCH 104	Training Loss: 0.832474496629503


 70%|███████   | 105/150 [03:18<01:27,  1.95s/it]

EPOCH 105	Training Loss: 0.8496011164453294


 71%|███████   | 106/150 [03:20<01:23,  1.91s/it]

EPOCH 106	Training Loss: 0.8456379837459989


 71%|███████▏  | 107/150 [03:21<01:20,  1.87s/it]

EPOCH 107	Training Loss: 0.7921909027629428


 72%|███████▏  | 108/150 [03:23<01:17,  1.85s/it]

EPOCH 108	Training Loss: 0.8092865347862244


 73%|███████▎  | 109/150 [03:25<01:15,  1.84s/it]

EPOCH 109	Training Loss: 0.7863647209273444


 73%|███████▎  | 110/150 [03:27<01:13,  1.83s/it]

EPOCH 110	Training Loss: 0.7913515567779541


 74%|███████▍  | 111/150 [03:29<01:16,  1.96s/it]

EPOCH 111	Training Loss: 0.8062997063000997


 75%|███████▍  | 112/150 [03:31<01:12,  1.91s/it]

EPOCH 112	Training Loss: 0.8470706144968668


 75%|███████▌  | 113/150 [03:33<01:09,  1.89s/it]

EPOCH 113	Training Loss: 0.7803338832325406


 76%|███████▌  | 114/150 [03:35<01:07,  1.86s/it]

EPOCH 114	Training Loss: 0.7861003875732422


 77%|███████▋  | 115/150 [03:36<01:05,  1.86s/it]

EPOCH 115	Training Loss: 0.7859143747223748


 77%|███████▋  | 116/150 [03:38<01:02,  1.84s/it]

EPOCH 116	Training Loss: 0.8040179544024997


 78%|███████▊  | 117/150 [03:40<01:03,  1.93s/it]

EPOCH 117	Training Loss: 0.7608887685669793


 79%|███████▊  | 118/150 [03:42<01:01,  1.93s/it]

EPOCH 118	Training Loss: 0.7506815923584832


 79%|███████▉  | 119/150 [03:44<00:58,  1.89s/it]

EPOCH 119	Training Loss: 0.7319479783376058


 80%|████████  | 120/150 [03:46<00:56,  1.88s/it]

EPOCH 120	Training Loss: 0.7519436412387424


 81%|████████  | 121/150 [03:48<00:54,  1.87s/it]

EPOCH 121	Training Loss: 0.7234826750225491


 81%|████████▏ | 122/150 [03:50<00:51,  1.86s/it]

EPOCH 122	Training Loss: 0.745041999551985


 82%|████████▏ | 123/150 [03:52<00:51,  1.89s/it]

EPOCH 123	Training Loss: 0.7578997479544746


 83%|████████▎ | 124/150 [03:54<00:50,  1.95s/it]

EPOCH 124	Training Loss: 0.7319431238704257


 83%|████████▎ | 125/150 [03:55<00:47,  1.90s/it]

EPOCH 125	Training Loss: 0.6957673364215426


 84%|████████▍ | 126/150 [03:57<00:44,  1.87s/it]

EPOCH 126	Training Loss: 0.7050213283962674


 85%|████████▍ | 127/150 [03:59<00:42,  1.85s/it]

EPOCH 127	Training Loss: 0.7103815277417501


 85%|████████▌ | 128/150 [04:01<00:40,  1.84s/it]

EPOCH 128	Training Loss: 0.7029147413041856


 86%|████████▌ | 129/150 [04:03<00:38,  1.82s/it]

EPOCH 129	Training Loss: 0.6500289506382413


 87%|████████▋ | 130/150 [04:05<00:39,  1.95s/it]

EPOCH 130	Training Loss: 0.6923292875289917


 87%|████████▋ | 131/150 [04:07<00:36,  1.91s/it]

EPOCH 131	Training Loss: 0.7124749024709066


 88%|████████▊ | 132/150 [04:08<00:33,  1.88s/it]

EPOCH 132	Training Loss: 0.677945527765486


 89%|████████▊ | 133/150 [04:10<00:32,  1.91s/it]

EPOCH 133	Training Loss: 0.6842288242446052


 89%|████████▉ | 134/150 [04:12<00:30,  1.89s/it]

EPOCH 134	Training Loss: 0.6427081425984701


 90%|█████████ | 135/150 [04:14<00:27,  1.86s/it]

EPOCH 135	Training Loss: 0.6520717607604133


 91%|█████████ | 136/150 [04:16<00:27,  1.96s/it]

EPOCH 136	Training Loss: 0.6734937561882867


 91%|█████████▏| 137/150 [04:18<00:25,  1.97s/it]

EPOCH 137	Training Loss: 0.6753495666715834


 92%|█████████▏| 138/150 [04:20<00:24,  2.02s/it]

EPOCH 138	Training Loss: 0.6482662691010369


 93%|█████████▎| 139/150 [04:22<00:21,  1.96s/it]

EPOCH 139	Training Loss: 0.6228387819396125


 93%|█████████▎| 140/150 [04:24<00:19,  1.92s/it]

EPOCH 140	Training Loss: 0.6172140902943082


 94%|█████████▍| 141/150 [04:26<00:17,  1.89s/it]

EPOCH 141	Training Loss: 0.6352841787868075


 95%|█████████▍| 142/150 [04:28<00:15,  1.99s/it]

EPOCH 142	Training Loss: 0.6477331452899509


 95%|█████████▌| 143/150 [04:30<00:13,  1.94s/it]

EPOCH 143	Training Loss: 0.6149721344312032


 96%|█████████▌| 144/150 [04:32<00:11,  1.90s/it]

EPOCH 144	Training Loss: 0.6276714536878798


 97%|█████████▋| 145/150 [04:34<00:09,  1.87s/it]

EPOCH 145	Training Loss: 0.5764651364750333


 97%|█████████▋| 146/150 [04:35<00:07,  1.85s/it]

EPOCH 146	Training Loss: 0.6321766409609053


 98%|█████████▊| 147/150 [04:37<00:05,  1.85s/it]

EPOCH 147	Training Loss: 0.6218758424123129


 99%|█████████▊| 148/150 [04:39<00:03,  1.92s/it]

EPOCH 148	Training Loss: 0.6408179733488295


 99%|█████████▉| 149/150 [04:41<00:01,  1.93s/it]

EPOCH 149	Training Loss: 0.6240955922338698


100%|██████████| 150/150 [04:43<00:00,  1.89s/it]

EPOCH 150	Training Loss: 0.6417749457889133





In [None]:
# Qualitative check: predict the answer span for one dataset example.
model.eval()
with torch.no_grad():
    sample = qa_dataset[150]
    context, question, answer = sample['context'], sample['question'], sample['answers']['text'][0]
    # To try a hand-written example instead, override the three values, e.g.:
    #   context = 'Jane is a student and she is from Canada'
    #   question = 'What does Jane do?'
    #   answer = 'student'
    context = text_normalize(context)
    question = text_normalize(question)
    answer = text_normalize(answer)

    # Only the encoded ids are fed to the model; the gold positions returned
    # by vectorize are not needed for prediction and stay on the CPU.
    input_ids, start_positions, end_positions = vectorize(question, context, answer)
    input_ids = input_ids.unsqueeze(0).to(device)
    start_logits, end_logits = model(input_ids)

    # Tokenize once and reuse the results below.
    question_tokens = tokenizer(question)
    context_tokens = tokenizer(context)

    # The encoded input is `<cls> question <sep> context`, so context token i
    # sits at input position i + len(question_tokens) + 2.
    offset = len(question_tokens) + 2
    start_position = torch.argmax(start_logits, dim=1)[0].item() - offset
    end_position = torch.argmax(end_logits, dim=1)[0].item() - offset

    # Clamp the predictions into the context token range.
    start_position = max(start_position, 0)
    end_position = min(end_position, len(context_tokens) - 1)

    if end_position >= start_position:
        # Extract the predicted answer span
        predicted_answer = ' '.join(context_tokens[start_position:end_position + 1])
    else:
        # An end before the start is not a valid span.
        predicted_answer = ''

    print(f'Context: {context}')
    print(f'Question: {question}')
    print(f'Start position: {start_position}')
    print(f'End position: {end_position}')
    print(f'Answer span: {predicted_answer}')
    print(answer)

Context: The College Dropout was eventually issued by Roc A Fella in February 2004  shooting to number two on the Billboard 200 as his debut single   Through the Wire  peaked at number fifteen on the Billboard Hot 100 chart for five weeks   Slow Jamz   his second single featuring Twista and Jamie Foxx  became an even bigger success  it became the three musicians  first number one hit  The College Dropout received near universal critical acclaim from contemporary music critics  was voted the top album of the year by two major music publications  and has consistently been ranked among the great hip hop works and debut albums by artists   Jesus Walks   the album s fourth single  perhaps exposed West to a wider audience  the song s subject matter concerns faith and Christianity  The song nevertheless reached the top 20 of the Billboard pop charts  despite industry executives  predictions that a song containing such blatant declarations of faith would never make it to radio  The College Dro