In [None]:
!wget https://objectstore.e2enetworks.net/ai4b-public-nlu-nlg/indic-corp-frozen-for-the-paper-oct-2022/mr.txt

In [1]:
# This essentially does a bigram but at a word level. But it takes word level as too literally.

In [2]:
# Load the dataset

k = 50000
input_file_path = './data/mr.txt'
output_file_path = f"./data/mr_{k}.txt"

# Function to read the first k lines from the input file and write them to the output file
def read_and_write_first_k_lines(input_file, output_file, num_lines=1000):
    try:
        with open(input_file, 'r') as infile, open(output_file, 'w') as outfile:
            for i in range(num_lines):
                line = infile.readline()
                if not line:  # End of file reached before 1000 lines
                    break
                outfile.write(line)
        print(f"Successfully wrote the first {num_lines} lines to {output_file}.")
    except Exception as e:
        print(f"An error occurred: {e}")

# Call the function
read_and_write_first_k_lines(input_file_path, output_file_path, k)

data_file = output_file_path
with open(data_file, 'r') as file:
    lines = file.readlines()

Successfully wrote the first 50000 lines to ./data/mr_50000.txt.


In [19]:
import re
import string

# Create consecutive pairs
word_pairs = []
freq_dict = {}
all_words = []
for line in lines:
    if line.strip() != "":
        words = line.split(" ")
        # Remove blank words
        words = [word.strip() for word in words if word.strip() != "" ]
        all_words.extend(words)
        # Start sos and eos character
        words_augmented  = ['sos'] + words + ['eos']
        for word1, word2 in zip(words_augmented, words_augmented[1:]):
            # Remove punctuations
            word1 = word1.translate(str.maketrans('', '', string.punctuation))
            word2 = word2.translate(str.maketrans('', '', string.punctuation))
            freq_dict[(word1, word2)] = freq_dict.get((word1, word2), 0) + 1
            word_pairs.extend([(word1, word2)])

# all_words.extend(['eos', 'sos'])
# Add bigram for eos end with eos: Not necessary, just for consistency
# bigrams[('eos', 'eos')] = 1
print(word_pairs[0])
print(word_pairs[1])
print(word_pairs[2])

('sos', 'ऊती')
('ऊती', 'संवर्धन')
('संवर्धन', 'तंत्राचे')


In [20]:
import torch

all_words = []
_ = [all_words.extend([word1, word2]) for word1, word2 in word_pairs]
distinct_words = list(set(all_words))
print(len(word_pairs))
print(f"{len(all_words)=}")
print(f"{len(distinct_words)=}")

1066668
len(all_words)=2133336
len(distinct_words)=140310


In [21]:
word_to_i = {word: i for i, word in enumerate(distinct_words)}
i_to_word = {i: word for word, i in word_to_i.items()}

In [22]:
# The Number of words is too damn high. Let's filter out the ones which occurs only once

word_pairs_indexed = [(word_to_i[word1], word_to_i[word2]) for word1, word2 in word_pairs]
high_freq_pairs = [x for x, freq in freq_dict.items() if freq > 20]
# Hack to select from large word pairs, store sum of their indexes. This will give some extra, but that's okay
high_freq_pairs_key = [word_to_i[word1] + word_to_i[word2] for word1, word2 in high_freq_pairs]
high_freq_pairs_key = set(high_freq_pairs_key)
# print(len(high_freq_pairs))
# print(high_freq_pairs_key[0])
# This seems workable now
freq_word_pairs = []
for word1, word2 in word_pairs:
    if (word_to_i[word1] + word_to_i[word2]) in high_freq_pairs_key:
        freq_word_pairs.extend([(word1, word2)])
# _ = [freq_word_pairs.extend([word1, word2]) for word1, word2 in word_pairs if (word1, word2) in high_freq_pairs]
print(f"{len(all_words)=}")
print(f"{len(distinct_words)=}")
all_words = []
_ = [all_words.extend([word1, word2]) for word1, word2 in freq_word_pairs]
distinct_words = list(set(all_words))
print(f"{len(all_words)=}")
print(f"{len(distinct_words)=}")

len(all_words)=2133336
len(distinct_words)=140310
len(all_words)=289592
len(distinct_words)=9531


In [23]:
word_to_i = {word: i for i, word in enumerate(distinct_words)}
i_to_word = {i: word for word, i in word_to_i.items()}

In [24]:
for i in freq_word_pairs[:5]:
    print(i)

('आहेत', 'या')
('तयार', 'करणे')
('झाल्या', 'आहेत')
('जाते', 'eos')
('येणार', 'आहेत')


In [25]:
print(len(word_to_i))

9531


In [92]:
# Network would be
# Create a dataset => xs => word^i amd ys => word(i+1)
xs = []
ys = []
for word1, word2 in freq_word_pairs:
    xs.append(word_to_i[word1])
    ys.append(word_to_i[word2])
xs = torch.tensor(xs)
ys = torch.tensor(ys)
print(f"{xs[0]=}")
print(f"{ys[0]=}")
embe_dims = len(distinct_words)
print(f"{embe_dims=}")

xs[0]=tensor(2039)
ys[0]=tensor(1587)
embe_dims=9531


In [82]:
print(xs.shape)

torch.Size([144796])


In [93]:
import torch.nn.functional as F

w = torch.randn((embe_dims, embe_dims), requires_grad=True)
x_enc = F.one_hot(xs, num_classes=embe_dims).float()
print(x_enc.shape)
print(x_enc[0][2039])

torch.Size([144796, 9531])
tensor(1.)


In [84]:
print(x_enc.shape)

torch.Size([144796, 9531])


In [94]:
logits = x_enc @ w
counts = logits.exp()
prob = counts / counts.sum(dim=1,keepdims=True)
print(prob.shape)
print(prob[0].shape)

torch.Size([144796, 9531])
torch.Size([9531])


In [48]:
print(prob[0].sum())

tensor(1.0000)


In [95]:
for i, prediction in enumerate(prob[:10]):
    # get the actual word
    inp_word_index = xs[i].item()
    inp_word = i_to_word[inp_word_index]
    acc_word_index = ys[i].item()
    acc_word = i_to_word[acc_word_index]
    pred_word_prob = prob[i][acc_word_index]
    print(f"{inp_word=}")
    print(f"{acc_word=}")
    print(f"{pred_word_prob=}")
    print(">>>>>>>>>>>>>>>>>>>>>>")

inp_word='आहेत'
acc_word='या'
pred_word_prob=tensor(7.9002e-05, grad_fn=<SelectBackward0>)
>>>>>>>>>>>>>>>>>>>>>>
inp_word='तयार'
acc_word='करणे'
pred_word_prob=tensor(2.4228e-05, grad_fn=<SelectBackward0>)
>>>>>>>>>>>>>>>>>>>>>>
inp_word='झाल्या'
acc_word='आहेत'
pred_word_prob=tensor(2.5861e-05, grad_fn=<SelectBackward0>)
>>>>>>>>>>>>>>>>>>>>>>
inp_word='जाते'
acc_word='eos'
pred_word_prob=tensor(7.8676e-05, grad_fn=<SelectBackward0>)
>>>>>>>>>>>>>>>>>>>>>>
inp_word='येणार'
acc_word='आहेत'
pred_word_prob=tensor(8.3468e-06, grad_fn=<SelectBackward0>)
>>>>>>>>>>>>>>>>>>>>>>
inp_word='आहेत'
acc_word='मात्र'
pred_word_prob=tensor(0.0002, grad_fn=<SelectBackward0>)
>>>>>>>>>>>>>>>>>>>>>>
inp_word='केले'
acc_word='आहे'
pred_word_prob=tensor(0.0002, grad_fn=<SelectBackward0>)
>>>>>>>>>>>>>>>>>>>>>>
inp_word='आहे'
acc_word='त्यामुळे'
pred_word_prob=tensor(1.6232e-05, grad_fn=<SelectBackward0>)
>>>>>>>>>>>>>>>>>>>>>>
inp_word='केली'
acc_word='आहे'
pred_word_prob=tensor(0.0001, grad_fn=<SelectB

In [97]:
w = torch.randn((embe_dims, embe_dims), requires_grad=True)
x_enc = F.one_hot(xs, num_classes=embe_dims).float()

In [98]:
num_epoch = 2
for i in range(num_epoch):
    print(f"Executing epoch: {i}")
    logits = x_enc @ w
    counts = logits.exp()
    # Probabilities of the next character
    prob = counts / counts.sum(dim=1,keepdims=True) 
    loss = -prob[torch.arange(num), ys].log().mean()
    print(loss.item())

    # Backward pass
    w.grad = None
    loss.backward()

    w.data += 100 * w.grad

Executing epoch: 0
9.567719459533691
Executing epoch: 1
9.81497859954834


In [99]:
# Let's try to get this to predict the words
for _ in range(5):
    # Start with 'sos'
    last_token = 'sos'
    num_tokens = 0
    word = ''
    while num_tokens < 30:
        # Feed forward the word
        x = word_to_i[last_token]
        x = torch.tensor([x])
        x_enc = F.one_hot(x, num_classes=embe_dims).float()
        # print(x_enc.shape)
        logits = x_enc @ w
        counts = logits.exp()
        # print(counts.shape)
        # Probabilities of the next character
        prob = counts / counts.sum(dim=1,keepdims=True)
        max_values, max_indices = torch.max(prob, dim=1)
        # print(max_values, max_indices)
        last_token = i_to_word[max_indices[0].item()]
        word += ' '
        word += last_token
        # print(prob)
        # if last_token == 'eos':
        #     break
        # word += ' ' 
        # word += last_token
        num_tokens += 1

    print(word, num_tokens)

 निवडणूक नेतृत्वाखाली पूर्ण बीकॉम जनावरांचे बर पोहोचविण्यासाठी खावा’ योजनेसाठी संप देवदेवतांचे सदस्यांची व्यक्तीस कॅम्पसमध्येच घटनादुरुस्तीने तलावांची वेदान्ताशिवाय निवृत्त बुवांनी ‘मानिनी मतदारसंघात पायरोगनियस काळू सामन्याचे अयोध्येत त्यांचे एकमार्गी च वाघ चालक 30
 निवडणूक नेतृत्वाखाली पूर्ण बीकॉम जनावरांचे बर पोहोचविण्यासाठी खावा’ योजनेसाठी संप देवदेवतांचे सदस्यांची व्यक्तीस कॅम्पसमध्येच घटनादुरुस्तीने तलावांची वेदान्ताशिवाय निवृत्त बुवांनी ‘मानिनी मतदारसंघात पायरोगनियस काळू सामन्याचे अयोध्येत त्यांचे एकमार्गी च वाघ चालक 30
 निवडणूक नेतृत्वाखाली पूर्ण बीकॉम जनावरांचे बर पोहोचविण्यासाठी खावा’ योजनेसाठी संप देवदेवतांचे सदस्यांची व्यक्तीस कॅम्पसमध्येच घटनादुरुस्तीने तलावांची वेदान्ताशिवाय निवृत्त बुवांनी ‘मानिनी मतदारसंघात पायरोगनियस काळू सामन्याचे अयोध्येत त्यांचे एकमार्गी च वाघ चालक 30
 निवडणूक नेतृत्वाखाली पूर्ण बीकॉम जनावरांचे बर पोहोचविण्यासाठी खावा’ योजनेसाठी संप देवदेवतांचे सदस्यांची व्यक्तीस कॅम्पसमध्येच घटनादुरुस्तीने तलावांची वेदान्ताशिवाय निवृत्त बुवांनी ‘मानिनी मतदारसंघात पा

In [75]:
print(word_to_i['sos'])
print(word_to_i['eos'])

1220
4060


In [1]:
import torch

matrix = torch.randn(1, 2, 3)
matrix

tensor([[[-0.2439, -0.6762, -1.1942],
         [ 0.7649, -0.0185,  0.5003]]])

In [3]:
(matrix + matrix).shape

torch.Size([1, 2, 3])