In [24]:
# A simple predict the next character by looking at last 1 character.
# Unlike Andrej's video we don't predict words, we predict sentences

In [1]:
!wget https://objectstore.e2enetworks.net/ai4b-public-nlu-nlg/indic-corp-frozen-for-the-paper-oct-2022/mr.txt

--2024-07-08 16:08:50--  https://objectstore.e2enetworks.net/ai4b-public-nlu-nlg/indic-corp-frozen-for-the-paper-oct-2022/mr.txt
Resolving objectstore.e2enetworks.net (objectstore.e2enetworks.net)... 164.52.210.97, 101.53.152.33, 164.52.206.154, ...
Connecting to objectstore.e2enetworks.net (objectstore.e2enetworks.net)|164.52.210.97|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 14553660884 (14G) [text/plain]
Saving to: ‘mr.txt’

mr.txt                0%[                    ]       0  --.-KB/s               ^C


In [6]:
# Load the dataset

k = 50000
input_file_path = './data/mr.txt'
output_file_path = f"./data/mr_{k}.txt"

# Function to read the first k lines from the input file and write them to the output file
def read_and_write_first_k_lines(input_file, output_file, num_lines=1000):
    try:
        with open(input_file, 'r') as infile, open(output_file, 'w') as outfile:
            for i in range(num_lines):
                line = infile.readline()
                if not line:  # End of file reached before 1000 lines
                    break
                outfile.write(line)
        print(f"Successfully wrote the first {num_lines} lines to {output_file}.")
    except Exception as e:
        print(f"An error occurred: {e}")

# Call the function
read_and_write_first_k_lines(input_file_path, output_file_path, k)

data_file = output_file_path
with open(data_file, 'r') as file:
    lines = file.readlines()

Successfully wrote the first 50000 lines to ./data/mr_50000.txt.


In [35]:
# Let's build a vocabulary
# Note that I'm still using <sos> and <eos> to mark start and end of sentence. Our vocab us character level,
# but just as exception even though '<sos>' and '<eos>' are not a char, we are still treating them as one and mapping 
# them to a single digit. Weird but wht not
vocab = set()
xs = []
ys = []
for line in lines:
    if line.strip() != "":
        line = line.strip()
        xs.append('<sos>')
        ys.append(line[0]) # or ys.append(line[:1])

        for ch1, ch2 in zip(line, line[1:]):
            vocab.add(ch1)
            xs.append(ch1)
            ys.append(ch2)

        # Last char in the line isn't added to vocab yet
        vocab.add(line[-1:])
        xs.append(line[-1:])
        ys.append('<eos>')

vocab.add('<sos>')
vocab.add('<eos>')
vocab = list(set(vocab))
vocab_size = len(vocab)
print(vocab_size)

337


In [36]:
word_to_i = {word: i for i, word in enumerate(vocab)}
i_to_word = {i: word for word, i in word_to_i.items()}

In [37]:
print(word_to_i['.'])
print(word_to_i['~'])
print(word_to_i['e'])
print(word_to_i[' '])

289
132
93
288


In [38]:
print(xs[:20], ys[:20])

['<sos>', 'ऊ', 'त', 'ी', ' ', 'स', 'ं', 'व', 'र', '्', 'ध', 'न', ' ', 'त', 'ं', 'त', '्', 'र', 'ा', 'च'] ['ऊ', 'त', 'ी', ' ', 'स', 'ं', 'व', 'र', '्', 'ध', 'न', ' ', 'त', 'ं', 'त', '्', 'र', 'ा', 'च', 'े']


In [76]:
import torch

xs_enc = []
ys_enc = []
for inp, out in zip(xs, ys):
    temp = []
    xs_enc.append(word_to_i[out])
    ys_enc.append(word_to_i[out])
xs_enc = torch.tensor(xs_enc)
ys_enc = torch.tensor(ys_enc)
print(xs_enc.shape)
print(ys_enc.shape)
num_samples = len(xs)

torch.Size([7182887])
torch.Size([7182887])


In [77]:
# Initialise weights

w = torch.randn((vocab_size, vocab_size), requires_grad=True)

In [86]:
import torch.nn.functional as F
# Forwards pass

xs_one_hot = F.one_hot(xs_enc, num_classes=vocab_size).float()
# print(xs_one_hot.shape) # This should be: 7182887 * 337
logits = xs_one_hot @ w
counts = logits.exp()
probs = counts / counts.sum(dim=1, keepdim=True)

In [87]:
# Get the loss
prob_assigned_to_correct_label = probs[torch.arange(num_samples), ys_enc]
loss = -1 * prob_assigned_to_correct_label.log().mean()
loss

tensor(6.8572, grad_fn=<MulBackward0>)

In [88]:
# Backword prop
w.grad = None
loss.backward()

In [85]:
# Update weights
w.data += 10.0 * w.grad

In [90]:
# Let's combine all
epoch = 10

# Initialise weight
w = torch.randn((vocab_size, vocab_size), requires_grad=True)
xs_one_hot = F.one_hot(xs_enc, num_classes=vocab_size).float()

for _ in range(epoch):
    # Forward pass
    logits = xs_one_hot @ w
    counts = logits.exp()
    probs = counts / counts.sum(dim=1, keepdim=True)
    # Get the loss
    prob_assigned_to_correct_label = probs[torch.arange(num_samples), ys_enc]
    loss = -1 * prob_assigned_to_correct_label.log().mean()
    print(loss)
    # Backword prop
    w.grad = None
    loss.backward()
    # Update weights
    w.data += -0.5 * w.grad

tensor(6.6308, grad_fn=<MulBackward0>)
tensor(6.6049, grad_fn=<MulBackward0>)
tensor(6.5790, grad_fn=<MulBackward0>)
tensor(6.5530, grad_fn=<MulBackward0>)
tensor(6.5271, grad_fn=<MulBackward0>)
tensor(6.5012, grad_fn=<MulBackward0>)
tensor(6.4753, grad_fn=<MulBackward0>)
tensor(6.4494, grad_fn=<MulBackward0>)
tensor(6.4234, grad_fn=<MulBackward0>)
tensor(6.3975, grad_fn=<MulBackward0>)


In [93]:
# Let's try to predict some sentences

for _ in range(10):
    # Start with 'sos'
    last_token = '<sos>'
    num_tokens = 0
    sentence = ''
    while num_tokens < 30:
        # Feed forward the word
        xs = word_to_i[last_token]
        xs_enc = torch.tensor([xs])
        xs_one_hot = F.one_hot(xs_enc, num_classes=vocab_size).float()
        logits = xs_one_hot @ w
        counts = logits.exp()
        prob = counts / counts.sum(dim=1, keepdims=True)
        max_values, max_indices = torch.max(prob, dim=1)
        last_token = i_to_word[max_indices[0].item()]
        sentence += last_token
        if last_token == '<eos>':
            break
        num_tokens += 1

    print(sentence, num_tokens)

ॲॕன0yуೀ�7�7�7�7�7�7�7�7�7�7�7� 30
ॲॕன0yуೀ�7�7�7�7�7�7�7�7�7�7�7� 30
ॲॕன0yуೀ�7�7�7�7�7�7�7�7�7�7�7� 30
ॲॕன0yуೀ�7�7�7�7�7�7�7�7�7�7�7� 30
ॲॕன0yуೀ�7�7�7�7�7�7�7�7�7�7�7� 30
ॲॕன0yуೀ�7�7�7�7�7�7�7�7�7�7�7� 30
ॲॕன0yуೀ�7�7�7�7�7�7�7�7�7�7�7� 30
ॲॕன0yуೀ�7�7�7�7�7�7�7�7�7�7�7� 30
ॲॕன0yуೀ�7�7�7�7�7�7�7�7�7�7�7� 30
ॲॕன0yуೀ�7�7�7�7�7�7�7�7�7�7�7� 30
