In [None]:
!pip install torchtext==0.15.1
!pip install torch==2.1.0
!pip install transformers==4.27.1
!pip install datasets==2.17.0

In [None]:
import torch
import torch.nn as nn
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

# Toy extractive-QA data: every example is a dict with 'context', 'question' and
# 'answer' keys, and each answer string appears verbatim inside its context.
qa_dataset = [
    {
        'context': 'My name is AIVN and I am from Vietnam.',
        'question': 'What is my name?',
        'answer': 'AIVN'
    },
    {
        'context': 'I love painting and my favorite artist is Vincent Van Gogh.',
        'question': 'What is my favorite activity?',
        'answer': 'painting'
    },
    {
        'context': 'I am studying computer science at the University of Tokyo.',
        'question': 'What am I studying?',
        'answer': 'computer science'
    },
    {
        'context': 'My favorite book is "To Kill a Mockingbird" by Harper Lee.',
        'question': 'What is my favorite book?',
        'answer': '"To Kill a Mockingbird"'
    },
    {
        'context': 'I have a pet dog named Max who loves to play fetch.',
        'question': 'What is the name of my pet?',
        'answer': 'Max'
    },
    {
        'context': 'I was born in Paris, but now I live in New York City.',
        'question': 'Where do I live now?',
        'answer': 'New York City'
    }
]

# Number of examples; the bare expression below displays it as the cell output.
data_size = len(qa_dataset)
data_size

6

In [None]:
# torchtext's 'basic_english' tokenizer: lower-cases text and splits off
# punctuation such as ',', '.' and '?' as separate tokens.
tokenizer = get_tokenizer('basic_english')

def yield_tokens(data):
    """Lazily produce one token list per example in ``data``.

    Each example contributes its context and question joined as
    ``'<cls> ' + context + ' <sep> ' + question``.
    """
    yield from (
        tokenizer('<cls> ' + example['context'] + ' <sep> ' + example['question'])
        for example in data
    )

# Vocabulary over every example's tokens. The special tokens receive the lowest
# ids in the order listed; any unseen token falls back to '<unk>'.
vocab = build_vocab_from_iterator(
    yield_tokens(qa_dataset),
    specials=['<unk>', '<pad>', '<bos>', '<eos>', '<sep>', '<cls>'],
)
vocab.set_default_index(vocab['<unk>'])
vocab.get_stoi()

{'vincent': 59,
 'to': 25,
 ',': 26,
 'pet': 22,
 'who': 62,
 'gogh': 40,
 'the': 24,
 'fetch': 38,
 'play': 53,
 'van': 57,
 'now': 20,
 'was': 60,
 'a': 15,
 'name': 14,
 'am': 13,
 'named': 49,
 'aivn': 28,
 'i': 6,
 'studying': 23,
 'and': 16,
 'where': 61,
 '<unk>': 0,
 'favorite': 12,
 'by': 33,
 'artist': 29,
 'live': 19,
 '<eos>': 3,
 'harper': 41,
 'dog': 37,
 'loves': 46,
 '.': 9,
 'born': 31,
 '<pad>': 1,
 'computer': 35,
 '<cls>': 5,
 'is': 7,
 'my': 8,
 'book': 17,
 'science': 54,
 'of': 21,
 '<bos>': 2,
 '<sep>': 4,
 'what': 11,
 'at': 30,
 'but': 32,
 'in': 18,
 'from': 39,
 'tokyo': 55,
 'city': 34,
 'have': 42,
 'kill': 43,
 'lee': 44,
 'love': 45,
 '?': 10,
 'do': 36,
 'max': 47,
 'mockingbird': 48,
 'york': 63,
 'new': 50,
 'painting': 51,
 'paris': 52,
 'university': 56,
 'activity': 27,
 'vietnam': 58}

In [None]:
# Fixed length of every encoded '<cls> question <sep> context' sequence;
# longer sequences are truncated, shorter ones padded.
MAX_SEQ_LEN = 22
# Vocabulary id of the padding token used to fill short sequences.
PAD_IDX = vocab['<pad>']

def pad_and_truncate(input_ids, max_seq_len, pad_idx=None):
    """Return ``input_ids`` cut or right-padded to exactly ``max_seq_len`` ids.

    Args:
        input_ids: Sequence of token ids. It is never modified.
        max_seq_len: Target length.
        pad_idx: Id appended as padding; defaults to the module-level ``PAD_IDX``.

    Returns:
        A new list of length ``max_seq_len``.
    """
    if pad_idx is None:
        pad_idx = PAD_IDX
    # Build a fresh list instead of extending in place, so the caller's list
    # is left untouched.
    truncated = list(input_ids[:max_seq_len])
    return truncated + [pad_idx] * (max_seq_len - len(truncated))

def vectorize(question, context, answer):
    """Encode one QA example into model input ids and answer-span labels.

    The input sequence is ``<cls> question <sep> context``. It is tokenized with
    the module-level ``tokenizer``, mapped through ``vocab`` and
    padded/truncated to ``MAX_SEQ_LEN``.

    Args:
        question: Question text.
        context: Passage that contains the answer.
        answer: Answer text; its tokens must appear contiguously in ``context``.

    Returns:
        ``(input_ids, start_positions, end_positions)`` as ``torch.long`` tensors.
        ``input_ids`` has ``MAX_SEQ_LEN`` entries. The two positions are 0-dim
        tensors giving the inclusive answer span inside ``input_ids``.

    Raises:
        ValueError: if the answer has no tokens, if truncation removed the whole
            context, or if the full answer span is not found in the (possibly
            truncated) context.
    """
    input_text = '<cls> ' + question + ' <sep> ' + context
    input_ids = [vocab[token] for token in tokenizer(input_text)]
    input_ids = pad_and_truncate(input_ids, MAX_SEQ_LEN)

    answer_ids = [vocab[token] for token in tokenizer(answer)]
    if not answer_ids:
        raise ValueError(f'Answer {answer!r} contains no tokens')

    sep_id = vocab['<sep>']
    if sep_id not in input_ids:
        raise ValueError(f'Question {question!r} fills MAX_SEQ_LEN; no context left')
    # Search only after <sep>: a match inside the question would be a wrong label.
    context_start = input_ids.index(sep_id) + 1

    # Match the whole answer span, not just its first token, so the end label is
    # guaranteed to lie inside the sequence.
    span_len = len(answer_ids)
    start_positions = next(
        (i for i in range(context_start, len(input_ids) - span_len + 1)
         if input_ids[i:i + span_len] == answer_ids),
        None,
    )
    if start_positions is None:
        raise ValueError(
            f'Answer {answer!r} not found in the (possibly truncated) context {context!r}'
        )
    end_positions = start_positions + span_len - 1

    input_ids = torch.tensor(input_ids, dtype=torch.long)
    start_positions = torch.tensor(start_positions, dtype=torch.long)
    end_positions = torch.tensor(end_positions, dtype=torch.long)

    return input_ids, start_positions, end_positions

In [None]:
class QADataset(Dataset):
    """Map-style dataset over a list of QA example dicts.

    Each example must provide 'question', 'context' and 'answer' keys. Indexing
    returns the ``(input_ids, start_positions, end_positions)`` triple that
    ``vectorize`` builds for that example.
    """

    def __init__(self, data):
        # Raw example dicts; they are vectorized on access in __getitem__.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]
        ids, start, end = vectorize(
            example['question'], example['context'], example['answer']
        )
        return ids, start, end

In [None]:
def decode(input_ids):
    """Turn a sequence of token ids back into a space-separated token string."""
    tokens = map(vocab.lookup_token, input_ids)
    return ' '.join(tokens)

In [None]:
# Sanity check: decode every vectorized example and its labelled answer span,
# and assert that the span really is the answer.
for item in qa_dataset:
    question_text = item['question']
    context_text = item['context']
    answer_text = item['answer']
    input_ids, start_positions, end_positions = vectorize(question_text, context_text, answer_text)
    print(input_ids)
    text = decode(input_ids)
    answer_span = input_ids[start_positions:end_positions + 1]
    decoded_answer = decode(answer_span)

    print(text)
    print(decoded_answer)
    # Enforce the invariant instead of relying on reading the printout: the
    # label must decode to the tokenized answer text.
    expected_answer = ' '.join(tokenizer(answer_text))
    assert decoded_answer == expected_answer, (
        f'Bad span for {answer_text!r}: got {decoded_answer!r}'
    )

tensor([ 5, 11,  7,  8, 14, 10,  4,  8, 14,  7, 28, 16,  6, 13, 39, 58,  9,  1,
         1,  1,  1,  1])
<cls> what is my name ? <sep> my name is aivn and i am from vietnam . <pad> <pad> <pad> <pad> <pad>
aivn
tensor([ 5, 11,  7,  8, 12, 27, 10,  4,  6, 45, 51, 16,  8, 12, 29,  7, 59, 57,
        40,  9,  1,  1])
<cls> what is my favorite activity ? <sep> i love painting and my favorite artist is vincent van gogh . <pad> <pad>
painting
tensor([ 5, 11, 13,  6, 23, 10,  4,  6, 13, 23, 35, 54, 30, 24, 56, 21, 55,  9,
         1,  1,  1,  1])
<cls> what am i studying ? <sep> i am studying computer science at the university of tokyo . <pad> <pad> <pad> <pad>
computer science
tensor([ 5, 11,  7,  8, 12, 17, 10,  4,  8, 12, 17,  7, 25, 43, 15, 48, 33, 41,
        44,  9,  1,  1])
<cls> what is my favorite book ? <sep> my favorite book is to kill a mockingbird by harper lee . <pad> <pad>
to kill a mockingbird
tensor([ 5, 11,  7, 24, 14, 21,  8, 22, 10,  4,  6, 42, 15, 22, 37, 49, 47, 62,
     

In [None]:
train_dataset = QADataset(qa_dataset)
# Mini-batches of 2 examples, reshuffled at every epoch.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=2)

In [None]:
import math
import torch.nn as nn
import torch.optim as optim

class TransformerBlock(nn.Module):
    """One post-norm Transformer encoder layer over batch-first inputs.

    Multi-head attention and a two-layer position-wise feed-forward network,
    each wrapped in a residual connection followed by LayerNorm.

    Args:
        embed_dim: Model width (input and output feature size).
        num_heads: Number of attention heads; must divide ``embed_dim``.
        ff_dim: Hidden width of the feed-forward network.
    """

    def __init__(self, embed_dim, num_heads, ff_dim):
        super().__init__()
        # Inputs are (batch, seq, embed). Without batch_first=True the layer
        # assumes (seq, batch, embed), and tokens would attend across the batch
        # instead of along the sequence.
        self.attn = nn.MultiheadAttention(embed_dim=embed_dim,
                                          num_heads=num_heads,
                                          batch_first=True)
        # Expand to ff_dim, apply a non-linearity, and project back to embed_dim
        # so the residual add below works for any ff_dim.
        self.ffn = nn.Sequential(
            nn.Linear(in_features=embed_dim, out_features=ff_dim),
            nn.ReLU(),
            nn.Linear(in_features=ff_dim, out_features=embed_dim),
        )
        self.layernorm_1 = nn.LayerNorm(normalized_shape=embed_dim)
        self.layernorm_2 = nn.LayerNorm(normalized_shape=embed_dim)

    def forward(self, query, key, value):
        """Return the layer output, shaped like ``query`` (batch, seq, embed_dim)."""
        attn_output, _ = self.attn(query, key, value)
        out_1 = self.layernorm_1(query + attn_output)
        ffn_output = self.ffn(out_1)
        return self.layernorm_2(out_1 + ffn_output)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to ``x``.

    Args:
        d_model: Feature size of ``x`` (odd sizes are supported).
        max_len: Longest sequence that can be encoded.
        batch_first: If True, ``x`` is (batch, seq, d_model); otherwise it is
            (seq, batch, d_model), the default layout.
    """

    def __init__(self, d_model, max_len=5000, batch_first=False):
        super().__init__()
        self.batch_first = batch_first
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Stored as (max_len, 1, d_model) so it broadcasts over the batch axis of
        # seq-first input; batch-first input uses a transposed view in forward.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        if self.batch_first:
            # Select by sequence length (dim 1) and broadcast over the batch (dim 0).
            return x + self.pe[:x.size(1)].transpose(0, 1)
        return x + self.pe[:x.size(0), :]

class QAModel(nn.Module):
    """Extractive QA model that scores every token as answer start and end.

    Args:
        vocab_size: Number of token ids.
        embedding_dim: Embedding and Transformer width.
        n_heads: Number of attention heads.
        ff_dim: Feed-forward hidden width inside the Transformer block.
        max_len: Longest input sequence supported by the positional encoding.
    """

    def __init__(self, vocab_size, embedding_dim, n_heads, ff_dim, max_len):
        super().__init__()
        self.input_embedding = nn.Embedding(vocab_size, embedding_dim)
        # Token ids arrive as (batch, seq), so every stage runs batch-first.
        self.pos_encoder = PositionalEncoding(embedding_dim, max_len, batch_first=True)
        self.transformer = TransformerBlock(embedding_dim, n_heads, ff_dim)

        # The heads read the Transformer output, whose last dimension is embedding_dim.
        self.start_linear = nn.Linear(embedding_dim, 1)
        self.end_linear = nn.Linear(embedding_dim, 1)

    def forward(self, text):
        """Map token ids (batch, seq) to ``(start_logits, end_logits)``, each (batch, seq)."""
        input_embedded = self.pos_encoder(self.input_embedding(text))
        transformer_out = self.transformer(input_embedded, input_embedded, input_embedded)
        start_logits = self.start_linear(transformer_out).squeeze(-1)
        end_logits = self.end_linear(transformer_out).squeeze(-1)

        return start_logits, end_logits

# Model parameters
EMBEDDING_DIM = 128
FF_DIM = 128
N_HEADS = 1
VOCAB_SIZE = len(vocab)
SEED = 42

# Seed the global torch RNG before creating the model so weight initialisation
# is reproducible. The DataLoader's shuffle has no generator of its own and
# also draws from this RNG.
torch.manual_seed(SEED)
model = QAModel(VOCAB_SIZE, EMBEDDING_DIM, N_HEADS, FF_DIM, MAX_SEQ_LEN)

# Smoke test: random ids of shape (batch=1, seq=10) should give one start logit
# per position. Named dummy_input to avoid shadowing the builtin input().
dummy_input = torch.randint(0, 10, size=(1, 10))
print(dummy_input.shape)
model.eval()
with torch.no_grad():
    start_logits, end_logits = model(dummy_input)

print(start_logits.shape)

torch.Size([1, 10])
torch.Size([1, 10])


In [None]:
# Both pointer heads are trained as classification over sequence positions.
LR = 1e-3
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)

In [None]:
EPOCHS = 10

model.train()
for epoch in range(EPOCHS):
    epoch_loss = 0.0
    for input_ids, start_positions, end_positions in train_loader:
        optimizer.zero_grad()
        start_logits, end_logits = model(input_ids)
        # Each head is a classification over sequence positions; average them.
        start_loss = criterion(start_logits, start_positions)
        end_loss = criterion(end_logits, end_positions)
        total_loss = (start_loss + end_loss) / 2
        total_loss.backward()
        optimizer.step()
        epoch_loss += total_loss.item()
    # Log the mean batch loss once per epoch rather than one line per batch.
    mean_loss = epoch_loss / max(len(train_loader), 1)
    print(f'epoch {epoch + 1}/{EPOCHS}: loss {mean_loss:.4f}')

3.1906681060791016
3.0200304985046387
3.717452049255371
3.034095287322998
2.224654197692871
3.13753342628479
1.8548295497894287
2.875330924987793
1.7314162254333496
1.935828447341919
1.2891161441802979
1.709120750427246
1.2487807273864746
0.9418723583221436
1.5946886539459229
0.8978770971298218
0.737031877040863
1.4568142890930176
1.0518808364868164
0.5805083513259888
0.6011573672294617
0.37238582968711853
0.38469740748405457
0.30013278126716614
0.34407907724380493
0.47778111696243286
0.06372884660959244
0.14127442240715027
0.2023245245218277
0.1742343306541443


In [None]:
# Inference on one example: predict an answer span restricted to the context,
# then map it from sequence positions back to context-token positions.
model.eval()
with torch.no_grad():
    sample = qa_dataset[4]
    # Read fields by key; unpacking sample.values() would silently depend on
    # the dict's insertion order.
    context = sample['context']
    question = sample['question']
    answer = sample['answer']
    input_ids, start_positions, end_positions = vectorize(question, context, answer)
    input_ids = input_ids.unsqueeze(0)
    start_logits, end_logits = model(input_ids)

    context_tokens = tokenizer(context)
    # The sequence is '<cls> question <sep> context', so context token i sits
    # at sequence position offset + i.
    offset = len(tokenizer(question)) + 2
    # One past the last context position still present after truncation.
    context_stop = min(offset + len(context_tokens), input_ids.size(1))

    if context_stop > offset:
        # Only context positions are legal answers: mask the question, special
        # tokens and padding before taking the argmax.
        in_context = torch.zeros_like(start_logits, dtype=torch.bool)
        in_context[:, offset:context_stop] = True
        start_logits = start_logits.masked_fill(~in_context, float('-inf'))
        end_logits = end_logits.masked_fill(~in_context, float('-inf'))

        start_seq = int(start_logits[0].argmax())
        # Look for the end only at or after the start, so the span is never empty.
        end_seq = start_seq + int(end_logits[0, start_seq:].argmax())

        start_position = start_seq - offset
        end_position = end_seq - offset
        predicted_answer = ' '.join(context_tokens[start_position:end_position + 1])
    else:
        # No context survived truncation, so there is nothing to extract.
        start_position = end_position = -1
        predicted_answer = ''

    print(f'Context: {context}')
    print(f'Question: {question}')
    print(f'Start position: {start_position}')
    print(f'End position: {end_position}')
    print(f'Answer span: {predicted_answer}')

Context: I have a pet dog named Max who loves to play fetch.
Question: What is the name of my pet?
Start position: 6
End position: 6
Answer span: max
