In [1]:
import pandas as pd
# Read the CSV and keep only the rows whose score is above 5000.
df = pd.read_csv('may15nov17_above130_less100.csv')
trunc = df.loc[df['score'] > 5000]

In [2]:
class charVocabulary(object):
    """Two-way mapping between characters (tokens) and integer indices.

    Five special tokens are always registered on construction, in this order:
    ``<mask>``, ``<begin>``, ``<end>``, ``<unk>`` and the space character.
    When the vocabulary starts out empty they therefore receive the indices
    0, 1, 2, 3 and 4.

    Args:
        token_to_idx: optional existing ``{token: index}`` mapping to start
            from.  It is copied, so the caller's dict is never modified.
    """

    def __init__(self, token_to_idx=None):
        # Copy the mapping: registering the special tokens below must not leak
        # into a dict owned by the caller.
        self.token_to_idx = dict(token_to_idx) if token_to_idx is not None else {}
        self.idx_to_token = {idx: token
                             for token, idx in self.token_to_idx.items()}

        self.mask_token = '<mask>'
        self.begin_token = '<begin>'
        self.end_token = '<end>'
        self.unk_token = '<unk>'
        self.space_token = ' '

        self.mask_idx = self.add_token(self.mask_token)
        self.begin_idx = self.add_token(self.begin_token)
        self.end_idx = self.add_token(self.end_token)
        self.unk_idx = self.add_token(self.unk_token)
        self.space_idx = self.add_token(self.space_token)

    def add_token(self, token):
        """Register ``token`` if it is new and return its index."""
        if token in self.token_to_idx:
            return self.token_to_idx[token]
        # New tokens get the next free index (the current vocabulary size).
        index = len(self.token_to_idx)
        self.token_to_idx[token] = index
        self.idx_to_token[index] = token
        return index

    def __len__(self):
        assert len(self.token_to_idx) == len(self.idx_to_token)
        return len(self.token_to_idx)

    def lookup_token(self, token):
        """Return the index of ``token``, or ``unk_idx`` if it was never added.

        Falling back to ``<unk>`` keeps vectorisation from raising ``KeyError``
        on characters the vocabulary has not seen (for example characters past
        the truncation point used by ``add_series``).
        """
        return self.token_to_idx.get(token, self.unk_idx)

    def lookup_idx(self, i):
        """Return the token stored at index ``i`` (``KeyError`` if unknown)."""
        return self.idx_to_token[i]

    def add_txt(self, path):
        """Add every character of the text file at ``path`` except newlines."""
        with open(path, 'r') as f:
            fulltext = f.read()
        for c in fulltext:
            if c != '\n':
                self.add_token(c)
        return None

    def add_series(self, df, max_chars=300):
        """Add the characters of every string in the iterable ``df``.

        Args:
            df: iterable of strings (e.g. a pandas Series of titles).
            max_chars: only the first ``max_chars`` characters of each string
                are scanned; pass ``None`` to scan whole strings.
        """
        for sentence in df:
            for char in sentence[:max_chars]:
                self.add_token(char)
        return None

In [3]:
# Build the character vocabulary from the titles of the filtered posts.
vocab = charVocabulary()
vocab.add_series(trunc.title)

In [4]:
import numpy as np
class charVectorizer(object):
    """Converts a string into aligned (input, target) index arrays.

    The vocabulary object must expose ``begin_idx``, ``end_idx``, ``mask_idx``
    and ``lookup_token``.
    """

    def __init__(self,vocab):
        self.vocab = vocab

    def vectorize(self, sent, max_len=-1):
        """Encode ``sent`` for next-character prediction.

        The sentence is wrapped as ``<begin> chars... <end>``.  ``x`` holds
        every index except the last, ``y`` holds every index except the first,
        so ``y[i]`` is the token that follows ``x[i]``.  Both arrays are padded
        on the right with ``mask_idx``.

        Args:
            sent: iterable of tokens (a string is iterated per character).
            max_len: requested output length; the arrays are at least as long
                as the wrapped sentence even if ``max_len`` is smaller.

        Returns:
            ``(x, y)``: two ``int64`` arrays of length
            ``max(len(sent) + 2, max_len)``.
        """
        vocab = self.vocab
        token_ids = [vocab.begin_idx, *map(vocab.lookup_token, sent), vocab.end_idx]

        n_pairs = len(token_ids) - 1
        out_len = max(len(token_ids), max_len)

        # Start from all-padding buffers and overwrite the real prefix.
        x = np.full(out_len, vocab.mask_idx, dtype=np.int64)
        y = np.full(out_len, vocab.mask_idx, dtype=np.int64)
        x[:n_pairs] = token_ids[:-1]
        y[:n_pairs] = token_ids[1:]

        return x, y

In [5]:
# Vectorizer bound to the vocabulary built above.
vectorizer = charVectorizer(vocab=vocab)

In [18]:
# Demo: 14 characters plus <begin>/<end>, right-padded with the mask index to length 17.
vectorizer.vectorize('i want bananas', max_len=17)

(array([ 1,  7,  4, 12, 17, 22, 13,  4, 35, 17, 22, 17, 22, 17, 18,  0,  0]),
 array([ 7,  4, 12, 17, 22, 13,  4, 35, 17, 22, 17, 22, 17, 18,  2,  0,  0]))

In [7]:
# Check that a short sentence is padded out to the requested max_len.
x,_ = vectorizer.vectorize('i like', max_len=30)
x.shape

(30,)

In [46]:
from torch.utils.data import Dataset, DataLoader
class charDataset(Dataset):
    """Dataset of (input, target) index arrays, one pair per post.

    Args:
        vectorizer: object with ``vectorize(sent=..., max_len=...)`` returning
            an ``(x, y)`` pair of equal-length arrays.
        posts: pandas Series of strings (accessed with ``.iloc``).

    Attributes:
        max_len: common length of every returned array: the longest post plus
            a fixed margin of ``_EXTRA_PAD`` positions.
    """

    # Head-room added on top of the longest post; it covers the <begin>/<end>
    # markers that the vectorizer adds and leaves some padding.
    _EXTRA_PAD = 20

    def __init__(self,vectorizer,posts):
        self.posts = posts
        self.vectorizer = vectorizer

        # default=0 keeps an empty series from raising (posts.iloc[0] would).
        longest = max((len(sentence) for sentence in posts), default=0)
        self.max_len = longest + self._EXTRA_PAD

    def __len__(self):
        return len(self.posts)

    def __getitem__(self,i):
        """Return the ``(x, y)`` arrays for the post at position ``i``."""
        sent = self.posts.iloc[i]
        x,y = self.vectorizer.vectorize(sent=sent, max_len=self.max_len)
        # Every sample must have the same length so batches can be stacked.
        assert x.shape == y.shape
        assert x.shape[0] == self.max_len
        return x,y

In [47]:
class fakeDS(Dataset):
    """Toy dataset for debugging: 512 copies of the sentence ``'hello.'``.

    Args:
        vectorizer: object with ``vectorize(sent=..., max_len=...)``.

    Attributes:
        max_len: padded length requested from the vectorizer (8 by default).
    """

    def __init__(self,vectorizer):
        self.vectorizer = vectorizer
        self.max_len = 8

    def __len__(self):
        return 512

    def __getitem__(self,i):
        # Use the attribute rather than a second hard-coded 8 so the two can
        # never drift apart.
        x,y = self.vectorizer.vectorize(sent='hello.', max_len=self.max_len)
        return x,y

In [48]:
# ds = fakeDS(vectorizer)
# Real data: the titles of the filtered posts, served in shuffled batches of 32.
posts = trunc.title
ds = charDataset(vectorizer=vectorizer,posts=posts)
dl = DataLoader(ds, batch_size=32, shuffle=True)

In [49]:
# Sanity pass: vectorise every post once and report the ones that fail.
for i in range(len(ds)):
    try:
        ds[i]
    except Exception as err:
        # Catch Exception rather than everything so Ctrl-C still interrupts,
        # and show why the item failed instead of only its position.
        print(i, repr(err))

x,y = ds[0]
l = x.shape[0]
print(x.shape)
print(y.shape)

(120,)
(120,)


In [53]:
import torch
import torch.nn.functional as F
from torch import nn
# Keyword arguments shared by the model constructors defined below.
params = {
    'max_len': ds.max_len,
    'num_emb': len(vocab),
    'emb_dim': 128,
    'mask_id': vocab.mask_idx,
}

In [54]:
class mini_transformer(nn.Module):
    """Single-layer, single-head causal self-attention language model.

    Token embeddings and learned position embeddings are summed, passed
    through one masked dot-product attention step and projected to vocabulary
    logits.  There is no residual connection, feed-forward block,
    normalisation or score scaling.

    Args:
        num_emb: vocabulary size (rows of the token embedding and width of
            the output logits).
        emb_dim: embedding / attention width.
        max_len: number of learned positions; input sequences must not be
            longer than this.
        mask_id: index of the padding token.  Accepted so that all models
            share one constructor signature, but currently unused
            (``padding_idx`` is disabled).
    """

    def __init__(self,num_emb,emb_dim,max_len,mask_id):
        super(mini_transformer,self).__init__()

        self.max_len = max_len

        self.emb = nn.Embedding(num_embeddings=num_emb, embedding_dim=emb_dim)#, padding_idx=mask_id)
        self.pos_emb = nn.Embedding(num_embeddings=max_len,embedding_dim=emb_dim)
        self.query = nn.Linear(in_features=emb_dim, out_features=emb_dim)
        self.key = nn.Linear(in_features=emb_dim, out_features=emb_dim)
        self.value = nn.Linear(in_features=emb_dim, out_features=emb_dim)
        self.fc = nn.Linear(in_features=emb_dim, out_features=num_emb)

    def forward(self,x_in,verbose=False):
        """Compute next-token logits.

        Args:
            x_in: integer tensor of token indices, shape ``(batch, seq)``.
            verbose: when true, print the shape of each intermediate tensor.

        Returns:
            Float tensor of shape ``(batch, seq, num_emb)``; position ``i``
            only depends on input positions ``0..i``.
        """
        x = self.emb(x_in)
        b,s,d = x.size()

        # Build the position indices on the same device as the input so the
        # model also works after being moved off the CPU.
        positions = torch.arange(s, device=x_in.device)
        # (s, d) position embeddings broadcast over the batch dimension.
        x = x + self.pos_emb(positions)

        if verbose:
            print(x_in.shape)
        # each row is a vector of size emb_dim
        if verbose:
            print(x.shape)

        # Each projection must use its own layer; using self.query for all
        # three leaves self.key / self.value untrained and makes q == k == v.
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        raw_weights = torch.bmm(q, k.transpose(1,2))
        if verbose:
            print(raw_weights.shape)

        # Causal mask: position i may not attend to positions j > i.
        # masked_fill is out-of-place, which is safer for autograd than
        # writing -inf into the score tensor directly.
        future = torch.triu(
            torch.ones(s, s, device=raw_weights.device), diagonal=1
        ).bool()
        raw_weights = raw_weights.masked_fill(future, float('-inf'))
        weights = F.softmax(raw_weights, dim=2)
        x1 = torch.bmm(weights, v)
        if verbose:
            print(x1.shape)

        # nn.Linear acts on the last dimension, so no reshape is required.
        out = self.fc(x1)
        if verbose:
            print(out.shape)
        return out

In [55]:
class silly(nn.Module):
    """Context-free baseline: embedding -> linear -> ReLU -> linear.

    Every position is processed independently of the others.  ``max_len`` and
    ``mask_id`` are accepted but not used.
    """

    def __init__(self,num_emb,emb_dim,max_len,mask_id):
        super().__init__()

        self.emb = nn.Embedding(num_emb, emb_dim)
        self.lin1 = nn.Linear(emb_dim, emb_dim)
        self.lu = nn.ReLU()
        self.fc = nn.Linear(emb_dim, num_emb)

    def forward(self,x_in,verbose=False):
        """Map ``(batch, seq)`` token indices to ``(batch, seq, num_emb)`` logits."""
        embedded = self.emb(x_in)
        batch, seq_len, _ = embedded.size()

        # Collapse batch and sequence into one axis for the per-token MLP.
        flat = embedded.contiguous().view(batch * seq_len, -1)
        logits = self.fc(self.lu(self.lin1(flat)))

        return logits.view(batch, seq_len, -1)

In [57]:
def decode_seq(vocab,vectors):
    """Sample one token per position from a batch of logits and join them.

    Only the first sequence of the batch is decoded.

    Args:
        vocab: vocabulary providing ``__len__`` and ``lookup_idx``.
        vectors: logits tensor of shape ``(batch, seq, len(vocab))``.

    Returns:
        The sampled tokens concatenated into a single string.
    """
    _, seq_len, vocab_size = vectors.size()
    assert vocab_size == len(vocab)

    probs = F.softmax(vectors[0], dim=1)

    pieces = []
    # Rows are sampled one at a time, in order (argmax would be the greedy
    # alternative to multinomial sampling).
    for row in probs:
        sampled = torch.multinomial(row, num_samples=1).item()
        pieces.append(vocab.lookup_idx(sampled))

    return ''.join(pieces)

In [59]:
# Instantiate the attention model with the shared hyper-parameters.
model = mini_transformer(**params)

In [60]:
optimizer = torch.optim.Adam(model.parameters())
num_epochs = 2
device = 'cpu'
from tqdm import tqdm
bestloss = float('inf')
# Moving the model once is enough; it does not need repeating every epoch.
model.to(device)
for epoch in range(num_epochs):
    ### train ----
    model.train()
    for x, y in dl:
        # Tensor.to returns a new tensor, so the result has to be re-bound
        # (a bare `x.to(device)` leaves x where it was).
        x = x.to(device)
        y = y.to(device)

        y_pred = model(x)
        b,s,d = y_pred.shape
        # Flatten to (batch*seq, vocab) logits against (batch*seq,) targets.
        y_pred_to_loss = y_pred.reshape(b*s,d)
        y_to_loss = y.reshape(-1)

        optimizer.zero_grad()
        # Padding positions are included in the loss (ignore_index is not set).
        loss = F.cross_entropy(y_pred_to_loss, y_to_loss)#, ignore_index=mask_id)
        loss.backward()
        optimizer.step()

        # Only report a batch when it sets a new best (lowest) loss.
        if loss.item() < bestloss:
            bestloss = loss.item()
            print(loss.item(), f"epoch {epoch+1}")   

4.717811107635498 epoch 1
4.483091831207275 epoch 1
4.156666278839111 epoch 1
3.993055582046509 epoch 1
3.7387778759002686 epoch 1
3.417160987854004 epoch 1
3.297903299331665 epoch 1
3.0389068126678467 epoch 1
2.9301400184631348 epoch 1
2.8643951416015625 epoch 1
2.8088467121124268 epoch 1
2.4485177993774414 epoch 1
2.439161777496338 epoch 1
2.363636016845703 epoch 1
2.347338914871216 epoch 1
2.2844488620758057 epoch 1
2.101871967315674 epoch 1
2.0506784915924072 epoch 1
1.9568417072296143 epoch 1
1.9176318645477295 epoch 1
1.8144317865371704 epoch 2
1.8043575286865234 epoch 2
1.70589017868042 epoch 2
1.6552937030792236 epoch 2
1.5934199094772339 epoch 2
1.437622308731079 epoch 2


In [61]:
# Pull a single (shuffled) batch from the loader for inspection.
x,y = next(iter(dl))

In [62]:
# The decoder defined in this notebook is decode_seq; it samples one character
# per position from the logits of the first sequence in the batch.
decode_seq(vocab,model(x))

'Idtk stsSu irohuh  tifcsbseyg😂bsny eeo ebci it  d e,tgus<unk>nreaarirotkolNoocgtla d#vohde h<end><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask>'

In [75]:
# NOTE(review): stale cell. Its execution count (75) is lower than that of the
# gen_samp definition further down (In [155]), whose signature is
# gen_samp(model, vocab, ...); the recorded output appears to come from an
# earlier version of the function. Re-run or delete after a kernel restart.
gen_samp(model,vectorizer)

torch.Size([1, 2])
W<mask>


In [141]:
def decode_idx(ind):
    """Turn a list of token indices into a string, stopping at the first padding.

    Uses the notebook-level ``vocab`` for the index-to-token lookup.

    Args:
        ind: iterable of integer token indices.

    Returns:
        The concatenated tokens that come before the first mask (padding)
        index; an empty string if ``ind`` is empty or starts with padding.
    """
    pieces = []
    for idx in ind:
        # Compare against the vocabulary's mask index instead of a literal 0,
        # so this keeps working if the mask token is not stored at index 0.
        if idx == vocab.mask_idx:
            break
        pieces.append(vocab.lookup_idx(idx))
    return ''.join(pieces)

In [155]:
def gen_samp(model,vocab,sample_size=30,prompt=""):
    """Generate text from ``model`` one character at a time.

    The input buffer is ``<begin>`` + prompt + mask padding, always exactly
    ``model.max_len`` long.  At step ``i`` the model's output at position
    ``i`` is turned into a distribution and the sampled token is written to
    slot ``i + 1``.

    Args:
        model: callable mapping a ``(1, max_len)`` index tensor to
            ``(1, max_len, len(vocab))`` logits, with a ``max_len`` attribute.
        vocab: vocabulary providing ``begin_idx``, ``mask_idx``,
            ``lookup_token`` and ``__len__``.
        sample_size: generation stops once this many slots after ``<begin>``
            are filled (prompt included); capped at ``model.max_len - 1`` so
            the buffer is never overrun.
        prompt: text the sample has to start with.

    Returns:
        Whatever ``decode_idx`` produces for the filled buffer.
    """
    ind = [vocab.begin_idx]
    ind.extend(vocab.lookup_token(char) for char in prompt)
    # Pad with mask tokens up to the fixed length (no-op if already too long).
    ind.extend([vocab.mask_idx] * (model.max_len - len(ind)))
    assert model.max_len == len(ind), "prompt does not fit into model.max_len"

    # Slot i + 1 is written at step i, so i must stay below max_len - 1.
    last_step = min(sample_size, model.max_len - 1)

    # Inference only: there is no need to record gradients.
    with torch.no_grad():
        for i in range(len(prompt), last_step):
            x = torch.tensor(ind).unsqueeze(dim=0)
            pred = model(x)

            b,s,d = pred.size()
            assert d == len(vocab)
            # The output at position i is the prediction for the token at i + 1.
            prob = F.softmax(pred[0,i,:], dim=0)
            win = torch.multinomial(prob, num_samples=1)
            ind[i+1] = win.item()

    return decode_idx(ind)

In [156]:
# Sample a continuation of the prompt 'vai' from the trained model.
print(gen_samp(model,vocab,prompt='vai'))

<begin>vaillor b0 trac wof atal i ofe


In [104]:
bos = [vocab.begin_idx]
# ind = bos.extend([vocab.lookup_token(char) for char in prompt])
# NOTE: list.extend works in place and returns None, so the assignment in the
# commented-out line above would bind `ind` to None rather than to the list.
print(bos)


[1]


In [105]:
# Scratch: two small lists for checking how list.extend behaves.
a = [1,2]
b = [3,4]

In [106]:
# extend appends b's items to a in place; the call itself evaluates to None.
a.extend(b)

In [107]:
# a now holds its original items followed by those of b.
print(a)

[1, 2, 3, 4]
