<a href="https://colab.research.google.com/github/ferjorosa/learn-pytorch/blob/main/Examples/cbow_human_numbers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Objective of this notebook:

* To implement a simple CBOW model and compare its results on the "human numbers" dataset with those of the LSTM and GRU models from chapter 12 of the fastai book.

* To better understand the output of `nn.Embedding` when several word indices are passed at once.

With a context size of 3, a batch size of 64, an embedding dimension of 64, and a vocabulary of 30 tokens, the tensors in the model's forward pass have the following shapes:

```python
> inputs.shape
torch.Size([64, 3])
> x.shape
torch.Size([64, 3, 64])
> y.shape
torch.Size([64, 64])
> out.shape
torch.Size([64, 30])
```
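These shapes can be reproduced outside the notebook with plain PyTorch. The following is a minimal sketch, assuming a vocabulary of 30 tokens (as implied by `out.shape`) and using random indices in place of real data; the names `emb` and `lin` are illustrative and are not the ones used in the model below.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

vocab_size, emb_dim = 30, 64    # assumed vocabulary size; embedding dimension from the text
batch_size, context = 64, 3

emb = nn.Embedding(vocab_size, emb_dim)   # lookup table with one row per token
lin = nn.Linear(emb_dim, vocab_size)      # maps an averaged embedding to vocabulary scores

# Random token indices standing in for a real batch
inputs = torch.randint(0, vocab_size, (batch_size, context))

x = emb(inputs)      # one vector per index:      [64, 3, 64]
y = x.mean(dim=1)    # average over the context:  [64, 64]
out = lin(y)         # one score per vocab entry: [64, 30]

# The embedding lookup is plain row indexing into the weight matrix
assert torch.equal(x[0, 1], emb.weight[inputs[0, 1]])

print(inputs.shape, x.shape, y.shape, out.shape)
```

The embedding layer simply appends a trailing dimension of size `emb_dim` to whatever index tensor it receives, which is why a `[64, 3]` input becomes `[64, 3, 64]`.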

In [1]:
#hide (Google Colab)
!pip install fastai --upgrade -q
import fastai
print(fastai.__version__)

!pip install -Uqq fastbook
import fastbook
fastbook.setup_book()


2.5.3
Mounted at /content/gdrive


In [2]:
# hide (debugging)
!pip install -Uqq ipdb
import ipdb
%pdb on

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
jupyter-console 5.2.0 requires prompt-toolkit<2.0.0,>=1.0.0, but you have prompt-toolkit 3.0.28 which is incompatible.
google-colab 1.0.0 requires ipython~=5.5.0, but you have ipython 7.31.1 which is incompatible.
Automatic pdb calling has been turned ON


In [3]:
import torch.nn as nn
from fastbook import *
from fastai.text.all import *

In [4]:
path = untar_data(URLs.HUMAN_NUMBERS)

Path.BASE_PATH = path

In [5]:
lines = L()
with open(path/'train.txt') as f: lines += L(*f.readlines())
with open(path/'valid.txt') as f: lines += L(*f.readlines())
lines

(#9998) ['one \n','two \n','three \n','four \n','five \n','six \n','seven \n','eight \n','nine \n','ten \n'...]

In [6]:
text = ' . '.join([l.strip() for l in lines])
tokens = text.split(' ')
vocab = L(*tokens).unique()
word2idx = {w:i for i,w in enumerate(vocab)}
nums = L(word2idx[i] for i in tokens)

In [7]:
#seqs_raw = L((tokens[i:i+3], tokens[i+3]) for i in range(0,len(tokens)-4,3)) # raw form

seqs = L((tensor(nums[i:i+3]), nums[i+3]) for i in range(0,len(nums)-4,3)) # coded-number form
seqs

(#21031) [(tensor([0, 1, 2]), 1),(tensor([1, 3, 1]), 4),(tensor([4, 1, 5]), 1),(tensor([1, 6, 1]), 7),(tensor([7, 1, 8]), 1),(tensor([1, 9, 1]), 10),(tensor([10,  1, 11]), 1),(tensor([ 1, 12,  1]), 13),(tensor([13,  1, 14]), 1),(tensor([ 1, 15,  1]), 16)...]

**Note:** `seqs_raw` cannot be fed to the model directly, because the model expects tensors and **tensors can only hold numeric data**. Each token is therefore replaced by its index in `vocab`.
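The difference between the raw and coded forms is easier to see on a toy example. The following is a minimal sketch in plain Python (lists instead of fastai's `L` and tensors), using a made-up five-number token stream; the variable `ctx` is introduced here for illustration only.

```python
# Toy token stream in the same style as the notebook: numbers separated by '.'
tokens = "one . two . three . four . five".split(" ")

# Vocabulary in order of first appearance, then a word -> index lookup
vocab = list(dict.fromkeys(tokens))
word2idx = {w: i for i, w in enumerate(vocab)}
nums = [word2idx[t] for t in tokens]

ctx = 3  # context size (hypothetical name; the notebook hard-codes 3)

# Non-overlapping windows: ctx tokens of context, the following token as target
seqs_raw = [(tokens[i:i + ctx], tokens[i + ctx])
            for i in range(0, len(tokens) - ctx - 1, ctx)]
seqs = [(nums[i:i + ctx], nums[i + ctx])
        for i in range(0, len(nums) - ctx - 1, ctx)]

for raw, coded in zip(seqs_raw, seqs):
    print(raw, "->", coded)
```

Both lists describe the same (context, target) pairs; only the coded form can be turned into tensors and passed to an embedding layer.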

In [8]:
bs = 64
cut = int(len(seqs) * 0.8)
dls = DataLoaders.from_dsets(seqs[:cut], seqs[cut:], bs=bs, shuffle=False) # train, validation

In [9]:
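class CBOW(Module):
  "CBOW-style model: average the context embeddings, then score every vocabulary entry."

  def __init__(self, vsz, nh):
    self.i_h = nn.Embedding(vsz, nh)  # token index -> embedding vector
    self.h_o = nn.Linear(nh, vsz)     # averaged embedding -> vocabulary scores

  def forward(self, inputs):
    x = self.i_h(inputs)              # [bs, context, nh]
    y = torch.mean(x, dim=1)          # average over the context: [bs, nh]
    out = self.h_o(y)                 # [bs, vsz]
    #ipdb.set_trace()
    return out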

In [10]:
learn = Learner(dls, CBOW(len(vocab), 64), loss_func=F.cross_entropy, 
                metrics=accuracy)
learn.fit_one_cycle(4, 1e-3)

epoch,train_loss,valid_loss,accuracy,time
0,2.732165,2.553152,0.431662,00:02
1,2.13393,2.170956,0.435465,00:02
2,1.984797,2.079117,0.435465,00:04
3,1.940968,2.074865,0.434039,00:04
