<a href="https://colab.research.google.com/github/deguc/Shannon/blob/main/004_CBOW.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

def onehot(x,k):
    """Return one-hot rows of width ``k`` for the integer index/indices ``x``.

    The rows are picked out of a ``k x k`` float identity matrix, so the
    result has shape ``np.shape(x) + (k,)``.
    """
    table = np.eye(k)
    return table[x]


def dataset(token,vocab_size):
    """Build CBOW training pairs from a token-id sequence.

    For every position ``i`` with both neighbours present, the context is
    ``(token[i], token[i+2])`` and the target is the centre ``token[i+1]``.

    Returns:
        contexts: int array of shape (len(token) - 2, 2).
        targets:  one-hot float array of shape (len(token) - 2, vocab_size).
    """
    count = len(token) - 2
    contexts = [(token[k], token[k + 2]) for k in range(count)]
    centres = [token[k + 1] for k in range(count)]

    # One-hot encode the centres by indexing rows of an identity matrix.
    targets = np.identity(vocab_size)[np.hstack(centres)]
    return np.vstack(contexts), targets


class DataLoader:
    """Mini-batch iterator over an ``(x, y)`` pair of arrays.

    Iterating yields ``(x_batch, y_batch)`` tuples of ``batch_size`` rows.
    The rows of ``x`` and ``y`` are reshuffled together at the start of every
    pass. A trailing partial batch is dropped, so one pass yields exactly
    ``len(self)`` batches.
    """

    def __init__(self,dataset,batch_size=10):

        self.x,self.y = dataset
        self.batch_size = batch_size
        self.data_size = self.x.shape[0]
        self.cnt = 0  # start row of the next batch; 0 means "new epoch"

    def shuffle(self):
        """Apply one random permutation to both ``x`` and ``y``."""

        idx = np.random.permutation(self.data_size)
        self.x = self.x[idx]
        self.y = self.y[idx]

    def __iter__(self):
        return self

    def __len__(self):
        """Number of full batches per epoch."""
        return self.data_size // self.batch_size

    def __getitem__(self,idx):
        """Return batch ``idx`` of the current row order (no shuffling).

        Raises:
            IndexError: unless ``0 <= idx < len(self)``.
        """
        # Bound by the batch count, not the row count. The previous check
        # (idx > data_size) accepted indices past the last full batch and
        # silently returned truncated or empty batches.
        if idx < 0 or idx >= len(self):
            raise IndexError('out of index')

        i = idx*self.batch_size
        j = i + self.batch_size

        return self.x[i:j],self.y[i:j]

    def __next__(self):
        """Return the next batch; reshuffle at the start of each epoch."""

        if self.cnt == 0:
            self.shuffle()

        i = self.cnt
        j = self.cnt + self.batch_size

        if j > self.data_size:
            # End of epoch: reset so the next pass reshuffles.
            self.cnt = 0
            raise StopIteration

        else:
            self.cnt = j
            return self.x[i:j],self.y[i:j]


def zeros_ps(ps):
    """Return a list of fresh zero arrays matching each array in ``ps``.

    Each result has the same shape and dtype as its source; typically used
    to allocate gradient or optimiser-state buffers for a parameter list.
    """
    return [np.zeros_like(param) for param in ps]


class Module:
    """Base class for layers: holds parameter/gradient lists and layer state."""

    def __init__(self):
        self.train_flag = False  # toggled by the model's train()/eval()
        self.inputs = None       # cached forward input, for backward
        self.mask = None
        self.ps = []             # parameter arrays
        self.gs = []             # gradient arrays, parallel to ps


class Affine(Module):
    """Bias-free fully connected layer computing ``x @ W``."""

    def __init__(self,d_in,d_out):
        super().__init__()

        # N(0, 1) samples divided by sqrt(d_in / 2) have std sqrt(2 / d_in),
        # i.e. He-style initialisation.
        denom = np.sqrt(d_in/2)
        weight = np.random.randn(d_in,d_out)/denom
        self.ps = [weight]
        self.gs = zeros_ps(self.ps)

    def __call__(self,x):
        """Forward pass; caches ``x`` for the backward pass."""
        weight = self.ps[0]
        self.inputs = x
        return x @ weight

    def backward(self,dout):
        """Store dL/dW in the gradient buffer and return dL/dx."""
        weight = self.ps[0]
        grad_w = self.gs[0]
        grad_w[...] = self.inputs.T @ dout
        return dout @ weight.T


class CBOW:
    """Continuous bag-of-words model with one context token on each side.

    The embeddings of the two context tokens ``x[:, 0]`` and ``x[:, 1]`` are
    averaged and projected to vocabulary scores by a bias-free Affine layer.
    Both embedding lookups share the same weight matrix ``W``.
    """

    def __init__(self,vocab_size,d_emb):

        W = np.random.randn(vocab_size,d_emb)
        self.emb1 = Embedding(W)
        self.emb2 = Embedding(W)
        self.aff = Affine(d_emb,vocab_size)

        self.layers = [self.emb1,self.emb2,self.aff]
        # ps = [params, grads]: parallel lists consumed by the optimiser.
        # A parameter shared by several layers (W above) is registered only
        # once. The other layers' gradients for it are summed into the
        # registered gradient in backward(). Registering W twice would give
        # it two independent optimiser states and apply weight decay twice.
        self.ps = [[],[]]
        self._shared = []  # (registered grad, duplicate grad) pairs

        for l in self.layers:
            for p,g in zip(l.ps,l.gs):
                for q,gq in zip(*self.ps):
                    if q is p:
                        self._shared.append((gq,g))
                        break
                else:
                    self.ps[0].append(p)
                    self.ps[1].append(g)

    def __call__(self,x):
        """Return unnormalised vocabulary scores for context pairs ``x``."""

        out1 = self.emb1(x[:,0])
        out2 = self.emb2(x[:,1])

        out = (out1+out2)*0.5

        return self.aff(out)


    def backward(self,dout):
        """Backpropagate ``dout`` (gradient w.r.t. the scores) to all params."""

        dout = self.aff.backward(dout)
        dout *= 0.5  # derivative of the 0.5 * (out1 + out2) average
        self.emb1.backward(dout)
        self.emb2.backward(dout)
        # Fold the gradients of shared parameters into the registered copy.
        for g_main,g_dup in self._shared:
            g_main += g_dup

    def pred(self,x):
        """Return the highest-scoring token id for each context pair."""
        return np.argmax(self(x),axis=1)

    def train(self):
        for l in self.layers:
            l.train_flag = True

    def eval(self):
        for l in self.layers:
            l.train_flag = False


def cross_entropy(y,t):
    """Cross-entropy between probabilities ``y`` and targets ``t``.

    Summed over every element (batch included), not averaged.
    """
    # The 1e-6 offset keeps log() finite when a probability is exactly 0.
    log_probs = np.log(y + 1e-6)
    return -(t * log_probs).sum()


def softmax(x):
    """Softmax along the last axis.

    The row maximum is subtracted first so ``exp`` cannot overflow.
    """
    shifted = x - np.max(x,axis=-1,keepdims=True)
    expd = np.exp(shifted)
    return expd / expd.sum(axis=-1,keepdims=True)


class Loss:
    """Output activation + loss that also stores the gradient for backprop.

    Args:
        model: object exposing ``backward(dout)``.
        clf:   activation applied to the raw scores (default ``softmax``).
        loss:  ``loss(probs, targets) -> float`` (default ``cross_entropy``).
    """

    def __init__(self,model,clf=softmax,loss=cross_entropy):

        self.model = model
        self.clf = clf
        self.loss = loss  # previously ignored: cross_entropy was always used
        self.dout = None

    def __call__(self,y,t):
        """Return the loss of scores ``y`` against targets ``t``."""

        out = self.clf(y)
        # ``out - t`` is the gradient of summed softmax cross-entropy with
        # respect to the scores. It is only correct for that pairing, and
        # it assumes each row of ``t`` sums to 1 (e.g. one-hot).
        self.dout = out - t

        return self.loss(out,t)

    def backward(self):
        """Send the gradient from the last call back through the model."""

        self.model.backward(self.dout)


class AdamW:
    """Adam optimiser with decoupled weight decay.

    Args:
        ps: ``[params, grads]``, two parallel lists of NumPy arrays. Params are
            updated in place; grads are read on every call.
        lr, beta1, beta2: step size and moment decay rates.
        weight_decay: decoupled decay coefficient (params shrink by
            ``lr * weight_decay * p`` each step).
    """

    def __init__(self,ps,lr=0.1,beta1=0.2,beta2=0.9,weight_decay=0.1):

        self.ps = ps
        self.cache = (lr,beta1,beta2,weight_decay)
        self.cnt = 0  # step counter, used for bias correction
        self.hs =[
            [np.zeros_like(p) for p in ps[0]],  # first moments
            [np.zeros_like(p) for p in ps[0]],  # second moments
        ]

    def __call__(self):
        """Apply one update step to every parameter in place."""

        eps = 1e-6
        params,grads = self.ps
        ms,vs = self.hs
        lr,b1,b2,wd = self.cache
        self.cnt += 1
        n = self.cnt

        for p,g,m,v in zip(params,grads,ms,vs):

            m[...] = b1*m + (1-b1)*g
            v[...] = b2*v + (1-b2)*g*g

            m_hat = m/(1-b1**n)
            v_hat = v/(1-b2**n)

            # Decoupled weight decay shrinks the parameter itself. The old
            # code used ``lr*w*g``, which was an extra SGD step, not decay.
            p -= lr*wd*p

            p -= lr*m_hat/(np.sqrt(v_hat)+eps)


def trainer(model,loss,optimizer,data,epochs=100):
    """Run ``epochs`` passes over ``data`` and return the loss history.

    Each batch is a forward pass, a loss evaluation, a backward pass via
    ``loss.backward()`` and one ``optimizer()`` step. Each history entry is
    the epoch's summed batch loss divided by ``len(data)``.
    """
    history = []

    for _ in range(epochs):
        model.train()
        total = 0

        for x,t in data:
            total += loss(model(x),t)
            loss.backward()
            optimizer()

        history.append(total/len(data))
        model.eval()

    return history


def disp_loss(loss):
    """Plot a per-epoch loss history, such as the list returned by trainer()."""

    plt.title('Loss Function')  # fixed typo: was 'Loss Fuction'
    plt.xlabel('epochs')
    plt.ylabel('cross entropy')
    plt.plot(loss)
    plt.show()


class Tokenizer:
    """Character-level tokenizer.

    Attributes:
        text:       the original input.
        vocab:      char -> id, ids assigned in order of first appearance.
        dic:        id -> char (inverse of ``vocab``).
        token:      the input encoded as a list of ids.
        vocab_size: number of distinct characters.
    """

    def __init__(self,text):

        self.text = text
        chars = list(text)
        # dict.fromkeys keeps first-occurrence order and drops repeats.
        self.vocab = {ch: i for i, ch in enumerate(dict.fromkeys(chars))}
        self.dic = {i: ch for ch, i in self.vocab.items()}
        self.token = [self.vocab[ch] for ch in chars]
        self.vocab_size = len(self.vocab)


class Embedding(Module):
    """Lookup-table layer: returns the rows of ``W`` for integer indices.

    ``W`` is stored by reference, so Embedding instances built from the same
    array share (and update) the same weights.
    """

    def __init__(self,W):
        # Initialise the Module attributes (inputs, train_flag, mask), which
        # were previously missing on Embedding instances.
        super().__init__()
        self.idx = None  # indices from the last forward pass
        self.ps = [W]
        self.gs = zeros_ps(self.ps)

    def __call__(self,idx):
        """Return ``W[idx]`` and remember ``idx`` for the backward pass."""

        W, = self.ps
        self.idx = idx

        return W[idx]

    def backward(self,dout):
        """Write dL/dW: scatter-add ``dout`` into the looked-up rows.

        ``np.add.at`` is used instead of ``dW[idx] += dout`` so that indices
        repeated within a batch all contribute.
        """

        dW, = self.gs
        dW[...] = 0

        np.add.at(dW,self.idx,dout)


np.set_printoptions(precision=2,suppress=True)

text = '素数は無限に存在する。素数とは自分自身と１でしか割り切れない数である。２は素数である。５７は３と１９で割り切れるので素数ではない。'

tk = Tokenizer(text)
token = tk.token
vocab_size = tk.vocab_size
# Bind the pairs to a new name. Rebinding ``dataset`` to its own result
# replaced the function, so re-running this cell raised
# "'tuple' object is not callable".
train_set = dataset(token,vocab_size)
data = DataLoader(train_set,batch_size=5)

model = CBOW(vocab_size,4*vocab_size)
loss = Loss(model)
optimizer = AdamW(model.ps,lr=0.01)
l = trainer(model,loss,optimizer,data,epochs=200)

# Sanity check on the first batch of the loader's current (last-shuffled)
# order. This is training data, not a held-out set. The three prints are
# the context ids, the true centre ids and the predicted centre ids.
test_x,test_y = data[0]
print(test_x)
print(np.argmax(test_y,axis=1))
pred = model.pred(test_x)
print(pred)
