# Model building

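Now we start designing models. This notebook builds the token-class and successor tables that will later be used to constrain what the transformer is allowed to generate.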

In [1]:
import math
import json
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

from ataarangi.utils import split_chunks
from ataarangi.data import encode_world_state, TextTokenizer, WorldStateTokenizer, SequenceTokenizer, RākauDataset

In [2]:
# Vocabulary files shared by the tokenizers below.
WORLD_STATE_VOCAB_PATH = '../data/worldstate_tokens.txt'
TEXT_VOCAB_PATH = '../data/tokens.txt'

# The sequence tokenizer is built from both vocabularies; the other two are
# each built from a single vocabulary file.
tokenizer = SequenceTokenizer(WORLD_STATE_VOCAB_PATH, TEXT_VOCAB_PATH)
world_state_tokenizer = WorldStateTokenizer(WORLD_STATE_VOCAB_PATH)
text_tokenizer = TextTokenizer(TEXT_VOCAB_PATH)

In [3]:
# Read the examples, decode the JSON-encoded `rākau` column into Python
# objects, and keep only rows with num_rākau <= 10 (index renumbered from 0).
rākau_data = (
    pd.read_csv('../data/rākau_data.csv')
    .assign(rākau=lambda frame: frame['rākau'].apply(json.loads))
    .loc[lambda frame: frame['num_rākau'] <= 10]
    .reset_index(drop=True)
)

In [4]:
# Display the examples ordered by number of rākau, largest first (not stored).
rākau_data.sort_values('num_rākau', ascending=False)

Unnamed: 0,id,entropy,num_rākau,rākau,description
653,eb825132-1e22-44ff-8b94-c62644d80390,8.090296,10,"[{'color': 'white', 'height': 7, 'location': 1...",te rākau mā me te rākau māwhero nui rawa me te...
323,ab4d2e1e-880e-4664-8ce1-01593cedbf76,7.690296,10,"[{'color': 'pink', 'height': 2, 'location': 1,...",ngā rākau kikorangi me ngā rākau whero me ngā ...
337,9ef16aba-598c-4c55-bf0b-a0d1fdfc56da,8.490296,10,"[{'color': 'pink', 'height': 3, 'location': 1,...",te rākau kikorangi me te rākau mā iti
972,8ee330e2-3f6a-42bd-b28b-68a8bc9e757f,7.814807,10,"[{'color': 'white', 'height': 10, 'location': ...",te rākau mā me ngā rākau parauri nui rawa e ru...
647,9cc7f006-0cac-4ef7-9769-810204eae383,7.890296,10,"[{'color': 'black', 'height': 10, 'location': ...",ngā rākau katoa hāunga te rākau mā iti rawa
...,...,...,...,...,...
420,9fc48bf4-9c4c-4726-8829-113665d0558c,3.000000,2,"[{'color': 'brown', 'height': 2, 'location': 1...",te rākau pango
419,c762566a-b4e1-4191-bf2b-6d7ab7e77f82,3.000000,2,"[{'color': 'pink', 'height': 3, 'location': 1,...",ngā rākau
418,56ee6946-e664-4b23-bd1c-824d5ce9436e,2.000000,2,"[{'color': 'brown', 'height': 4, 'location': 1...",te rākau pango
417,edfa86f8-1ec4-4470-bfc2-03684bd2d516,3.000000,2,"[{'color': 'yellow', 'height': 7, 'location': ...",te rākau māwhero


In [5]:
# Add id-sequence columns: `input` from the world state (`rākau`) and
# `target` from the text `description`.
world_states = rākau_data['rākau']
descriptions = rākau_data['description']
rākau_data['input'] = world_states.apply(world_state_tokenizer.tokenize)
rākau_data['target'] = descriptions.apply(text_tokenizer.tokenize)

In [6]:
# Display the frame with the new `input` / `target` id columns.
rākau_data

Unnamed: 0,id,entropy,num_rākau,rākau,description,input,target
0,800d23b5-574c-46d3-94cf-1066082e9d7d,3.0,2,"[{'color': 'blue', 'height': 4, 'location': 1,...",te rākau mā,"[1, 3, 5, 15, 2, 9, 19, 22]","[24, 23, 27, 53]"
1,659289fa-7473-4617-a8f0-61324cc0e3b1,2.0,2,"[{'color': 'blue', 'height': 1, 'location': 1,...",ngā rākau,"[1, 2, 5, 12, 2, 6, 12, 22]","[25, 23, 53]"
2,7629eca4-df36-4921-973b-0ca7859c5018,3.0,2,"[{'color': 'blue', 'height': 1, 'location': 1,...",te rākau iti,"[1, 2, 5, 12, 3, 6, 20, 22]","[24, 23, 35, 53]"
3,89413927-94ca-4078-8198-85b957338400,3.0,2,"[{'color': 'white', 'height': 8, 'location': 1...",te rākau mā,"[1, 2, 9, 19, 3, 6, 13, 22]","[24, 23, 27, 53]"
4,eb110390-7be8-4d7a-93e0-211c33c033f1,3.0,2,"[{'color': 'black', 'height': 10, 'location': ...",ngā rākau,"[1, 2, 8, 21, 2, 9, 19, 22]","[25, 23, 53]"
...,...,...,...,...,...,...,...
1007,3d543250-f7c8-47c4-9491-ab29d5c8c8fd,3.0,2,"[{'color': 'black', 'height': 9, 'location': 1...",te rākau kikorangi,"[1, 3, 8, 20, 2, 5, 19, 22]","[24, 23, 30, 53]"
1008,b6fcfe8b-5604-47d5-8e1a-1a2830219683,3.0,2,"[{'color': 'white', 'height': 4, 'location': 1...",te rākau mā,"[1, 2, 9, 15, 3, 8, 21, 22]","[24, 23, 27, 53]"
1009,ecc53392-b801-4159-8fe3-fe9475e50672,3.0,2,"[{'color': 'red', 'height': 9, 'location': 1, ...",te rākau whero,"[1, 2, 4, 20, 3, 8, 13, 22]","[24, 23, 33, 53]"
1010,bd1d8542-5e5f-4d03-9bf3-8b3d889fe737,3.0,2,"[{'color': 'red', 'height': 10, 'location': 1,...",ngā rākau,"[1, 2, 4, 21, 2, 7, 14, 22]","[25, 23, 53]"


In [39]:
# Sanity check: list descriptions matching the regex 'taha [^m]', i.e. "taha"
# followed by a space and any character other than "m".
taha_not_followed_by_m = rākau_data['description'].str.contains('taha [^m]')
rākau_data.loc[taha_not_followed_by_m, 'description'].values

array([], dtype=object)

### Creating an n-gram model

We're going to make a simple n-gram model with a context of one token (each token is predicted from the single token before it, i.e. a bigram table) so we can use it to drive the transformer during training. The idea is that if we constrain the model so that it can only generate tokens that fit the patterns seen in the example data, then perhaps it will learn more efficiently.

In [8]:
# Wrap every description in [CLS] ... [END] markers so that the start and end
# of a description show up as tokens in the counts below.
rākau_data['tokens'] = rākau_data['description'].apply(
    lambda text: ' '.join(('[CLS]', text, '[END]'))
)

# Frequency of every whitespace-separated token across all descriptions.
word_counts = rākau_data['tokens'].str.split().explode().value_counts()
vocab_data = word_counts.to_frame().reset_index()

# Hand-assigned class for every word in the text vocabulary.
vocab_type_dict = {
    'colour': [
        'mā', 'kōwhai', 'kikorangi', 'kākāriki',
        'pango', 'māwhero', 'whero', 'parauri',
    ],
    'det': ['te', 'ngā'],
    'conjunction': ['me'],
    'noun': ['rākau'],
    'size': ['iti', 'nui'],
    'except': ['hāunga'],
    'locative': ['kei'],
    'position': ['taha'],
    'side': ['mauī', 'matau'],
    'number': ['rua', 'toru'],
    'all': ['katoa'],
    'furthest': ['tawhiti_rawa'],
    'most': ['rawa'],
    'to': ['ki'],
    'in': ['i'],
    'e': ['e'],
    'preposition': ['waenganui'],
    'particle': ['mai'],
    'ordinal': ['tuarua'],
    '[END]': ['[END]'],
    '[CLS]': ['[CLS]'],
}

# Invert the mapping: word -> class.
word_to_type_dict = {}
for word_class, members in vocab_type_dict.items():
    for member in members:
        word_to_type_dict[member] = word_class

# Words without a hand-assigned class get None.
vocab_data['class'] = vocab_data['tokens'].apply(lambda kupu: word_to_type_dict.get(kupu))
vocab_data = vocab_data.rename(columns={'tokens': 'kupu'})
vocab_data = vocab_data[['kupu', 'class', 'count']]
vocab_data

Unnamed: 0,kupu,class,count
0,rākau,noun,1467
1,te,det,1018
2,[CLS],[CLS],1012
3,[END],[END],1012
4,ngā,det,505
5,me,conjunction,392
6,mā,colour,136
7,kōwhai,colour,133
8,kākāriki,colour,132
9,kikorangi,colour,132


Now I have the data that I need to make the ngram model, including the class labels for each word (source + target). I can convert this into a matrix that will make the token generation better, but I should also do the same for the world state tokens as well.

It's possible we might want to implement other rules, such as to stop needless repetition, for example.

In [9]:
def get_grams(tokens, n=1):
    """Pair every run of ``n`` consecutive tokens with the token that follows it.

    Args:
        tokens: Sequence of string tokens.
        n: Number of context tokens in each pair (default 1).

    Returns:
        A list of single-entry dicts ``{context: next_token}``, one per
        position, where ``context`` is the ``n`` context tokens joined by
        single spaces.
    """
    pairs = []
    for start in range(len(tokens) - n):
        window = tokens[start:start + n + 1]
        context = ' '.join(window[:-1])
        pairs.append({context: window[-1]})
    return pairs
    
# Count how often each {source: target} pair occurs across all descriptions.
# NOTE(review): value_counts is being applied to dict objects here; this
# relies on pandas accepting unhashable values - confirm on the pinned version.
gram_counts = (
    rākau_data['tokens']
    .str.split()
    .apply(get_grams)
    .explode()
    .value_counts()
    .reset_index()
)

# Each dict holds exactly one entry: unpack its key and value into columns.
gram_counts['source'] = gram_counts['tokens'].apply(lambda gram: list(gram.keys())[0])
gram_counts['target'] = gram_counts['tokens'].apply(lambda gram: list(gram.values())[0])
ngram_data = gram_counts[['source', 'target', 'count']]

ngram_data

Unnamed: 0,source,target,count
0,te,rākau,962
1,[CLS],te,540
2,ngā,rākau,505
3,[CLS],ngā,472
4,rākau,[END],361
...,...,...,...
114,taha,mā,1
115,kikorangi,ki,1
116,kākāriki,e,1
117,iti,e,1


## World state token rules

I need to get the world state token rules as well. The world state tokens are set up in a particular way, I might as well tell the model what the rules are for this step as well as for the text data.

In [10]:
# Show the first example's world state as readable token names.
first_world_state_ids = world_state_tokenizer.tokenize(rākau_data['rākau'][0])
' '.join(world_state_tokenizer.id_map[token_id] for token_id in first_world_state_ids)

'[SOS] [NOT_SELECTED] [COLOUR_BLUE] [HEIGHT_4] [SELECTED] [COLOUR_WHITE] [HEIGHT_8] [CLS]'

So the rules are:
- Start with `[SOS]`
- Then either `[SELECTED]` or `[NOT_SELECTED]`
- Then the colour
- Then the height
- Then either another `[SELECTED]`/`[NOT_SELECTED]` (starting the cycle again) or `[CLS]`, which ends the world state and starts the text sequence

In [11]:
# Classes of the world-state tokens, mirroring the rules listed above.
token_class_dict = {
    '[PAD]': ['[PAD]'],
    '[SOS]': ['[SOS]'],
    '[SELECTION]': ['[SELECTED]', '[NOT_SELECTED]'],
    '[COLOUR]': [
        '[COLOUR_RED]',
        '[COLOUR_BLUE]',
        '[COLOUR_GREEN]',
        '[COLOUR_YELLOW]',
        '[COLOUR_BLACK]',
        '[COLOUR_WHITE]',
        '[COLOUR_BROWN]',
        '[COLOUR_PINK]',
    ],
    '[HEIGHT]': [f'[HEIGHT_{height}]' for height in range(1, 11)],
    '[CLS]': ['[CLS]'],
}

# Invert to token -> class, then add the text-word classes so a single lookup
# covers both world-state tokens and words.
token_to_class_dict = {}
for token_class, members in token_class_dict.items():
    token_to_class_dict.update(dict.fromkeys(members, token_class))
token_to_class_dict.update(word_to_type_dict)
token_to_class_dict

{'[SOS]': '[SOS]',
 '[SELECTED]': '[SELECTION]',
 '[NOT_SELECTED]': '[SELECTION]',
 '[COLOUR_RED]': '[COLOUR]',
 '[COLOUR_BLUE]': '[COLOUR]',
 '[COLOUR_GREEN]': '[COLOUR]',
 '[COLOUR_YELLOW]': '[COLOUR]',
 '[COLOUR_BLACK]': '[COLOUR]',
 '[COLOUR_WHITE]': '[COLOUR]',
 '[COLOUR_BROWN]': '[COLOUR]',
 '[COLOUR_PINK]': '[COLOUR]',
 '[HEIGHT_1]': '[HEIGHT]',
 '[HEIGHT_2]': '[HEIGHT]',
 '[HEIGHT_3]': '[HEIGHT]',
 '[HEIGHT_4]': '[HEIGHT]',
 '[HEIGHT_5]': '[HEIGHT]',
 '[HEIGHT_6]': '[HEIGHT]',
 '[HEIGHT_7]': '[HEIGHT]',
 '[HEIGHT_8]': '[HEIGHT]',
 '[HEIGHT_9]': '[HEIGHT]',
 '[HEIGHT_10]': '[HEIGHT]',
 '[CLS]': '[CLS]',
 'mā': 'colour',
 'kōwhai': 'colour',
 'kikorangi': 'colour',
 'kākāriki': 'colour',
 'pango': 'colour',
 'māwhero': 'colour',
 'whero': 'colour',
 'parauri': 'colour',
 'te': 'det',
 'ngā': 'det',
 'me': 'conjunction',
 'rākau': 'noun',
 'iti': 'size',
 'nui': 'size',
 'hāunga': 'except',
 'kei': 'locative',
 'taha': 'position',
 'mauī': 'side',
 'matau': 'side',
 'rua': 'number

In [33]:
# Save the token -> class lookup as JSON. The `with` block closes the file
# handle deterministically instead of leaving it to garbage collection.
with open('../data/token_to_class.json', 'w') as class_file:
    json.dump(token_to_class_dict, class_file, indent=4)

In [12]:
# World-state side: tokenize every example, map ids back to token names, and
# count how often each token occurs.
world_state_tokens = (
    rākau_data['rākau']
    .apply(world_state_tokenizer.tokenize)
    .apply(lambda ids: [world_state_tokenizer.id_map[token_id] for token_id in ids])
)
world_state_counts = (
    world_state_tokens
    .explode()
    .value_counts()
    .to_frame()
    .reset_index()
    .rename(columns={'rākau': 'token'})
)
world_state_counts['class'] = world_state_counts['token'].apply(
    lambda token: token_to_class_dict[token]
)
world_state_counts = world_state_counts[['token', 'class', 'count']]

# Tag each table with where its tokens come from before stacking them.
world_state_counts['data_type'] = 'world_state'
vocab_data['data_type'] = 'text'

text_counts = vocab_data.rename(columns={'kupu': 'token'})
token_data = pd.concat([world_state_counts, text_counts]).reset_index(drop=True)
token_data = token_data[['token', 'data_type', 'class', 'count']]
token_data

Unnamed: 0,token,data_type,class,count
0,[SELECTED],world_state,[SELECTION],2512
1,[NOT_SELECTED],world_state,[SELECTION],1381
2,[SOS],world_state,[SOS],1012
3,[CLS],world_state,[CLS],1012
4,[COLOUR_WHITE],world_state,[COLOUR],532
5,[COLOUR_BLACK],world_state,[COLOUR],501
6,[COLOUR_RED],world_state,[COLOUR],491
7,[COLOUR_GREEN],world_state,[COLOUR],483
8,[COLOUR_PINK],world_state,[COLOUR],479
9,[COLOUR_YELLOW],world_state,[COLOUR],475


In [13]:
# Write the combined token table to disk without the row index.
token_data.to_csv('../data/token_data.csv', index=False)

In [14]:
# Attach the class of each side of every pair; an unknown token raises KeyError.
for role in ('source', 'target'):
    ngram_data[f'{role}_class'] = ngram_data[role].apply(
        lambda token: token_to_class_dict[token]
    )
ngram_data = ngram_data[['source', 'target', 'source_class', 'target_class', 'count']]
ngram_data

Unnamed: 0,source,target,source_class,target_class,count
0,te,rākau,det,noun,962
1,[CLS],te,[CLS],det,540
2,ngā,rākau,det,noun,505
3,[CLS],ngā,[CLS],det,472
4,rākau,[END],noun,[END],361
...,...,...,...,...,...
114,taha,mā,position,colour,1
115,kikorangi,ki,colour,to,1
116,kākāriki,e,colour,e,1
117,iti,e,size,e,1


In [15]:
# Display the combined token -> class lookup again for reference.
token_to_class_dict

{'[SOS]': '[SOS]',
 '[SELECTED]': '[SELECTION]',
 '[NOT_SELECTED]': '[SELECTION]',
 '[COLOUR_RED]': '[COLOUR]',
 '[COLOUR_BLUE]': '[COLOUR]',
 '[COLOUR_GREEN]': '[COLOUR]',
 '[COLOUR_YELLOW]': '[COLOUR]',
 '[COLOUR_BLACK]': '[COLOUR]',
 '[COLOUR_WHITE]': '[COLOUR]',
 '[COLOUR_BROWN]': '[COLOUR]',
 '[COLOUR_PINK]': '[COLOUR]',
 '[HEIGHT_1]': '[HEIGHT]',
 '[HEIGHT_2]': '[HEIGHT]',
 '[HEIGHT_3]': '[HEIGHT]',
 '[HEIGHT_4]': '[HEIGHT]',
 '[HEIGHT_5]': '[HEIGHT]',
 '[HEIGHT_6]': '[HEIGHT]',
 '[HEIGHT_7]': '[HEIGHT]',
 '[HEIGHT_8]': '[HEIGHT]',
 '[HEIGHT_9]': '[HEIGHT]',
 '[HEIGHT_10]': '[HEIGHT]',
 '[CLS]': '[CLS]',
 'mā': 'colour',
 'kōwhai': 'colour',
 'kikorangi': 'colour',
 'kākāriki': 'colour',
 'pango': 'colour',
 'māwhero': 'colour',
 'whero': 'colour',
 'parauri': 'colour',
 'te': 'det',
 'ngā': 'det',
 'me': 'conjunction',
 'rākau': 'noun',
 'iti': 'size',
 'nui': 'size',
 'hāunga': 'except',
 'kei': 'locative',
 'taha': 'position',
 'mauī': 'side',
 'matau': 'side',
 'rua': 'number

In [16]:
# Reduce the token-level pairs to the distinct (source_class, target_class)
# pairs seen in the data. The counts are not needed here, so the pairs are
# de-duplicated directly rather than summed and discarded; sorting keeps each
# successor list in alphabetical order.
class_pairs = (
    ngram_data[['source_class', 'target_class']]
    .dropna()
    .drop_duplicates()
    .sort_values(['source_class', 'target_class'])
)

# One row per source class, holding the list of classes that may follow it.
class_gram_data = (
    class_pairs
    .groupby('source_class')['target_class']
    .apply(list)
    .to_frame()
    .reset_index()
)

class_gram_dict = dict(zip(class_gram_data['source_class'], class_gram_data['target_class']))

# Hand-written successor rules for the world-state classes and [PAD]; these
# also replace any data-derived entry with the same key (e.g. '[END]').
class_gram_dict.update({
    '[PAD]': ['[PAD]'],
    '[SOS]': ['[SELECTION]'],
    '[SELECTION]': ['[COLOUR]'],
    '[COLOUR]': ['[HEIGHT]'],
    '[HEIGHT]': ['[CLS]', '[SELECTION]'],
    '[END]': ['[SOS]']
})
class_gram_dict

{'[CLS]': ['det'],
 'all': ['except'],
 'colour': ['[END]',
  'all',
  'conjunction',
  'e',
  'furthest',
  'in',
  'locative',
  'ordinal',
  'size',
  'to'],
 'conjunction': ['det'],
 'det': ['noun', 'position', 'side'],
 'e': ['number'],
 'except': ['det'],
 'furthest': ['to'],
 'in': ['det', 'preposition'],
 'locative': ['det', 'preposition'],
 'most': ['[END]', 'conjunction', 'e'],
 'noun': ['[END]', 'all', 'colour', 'e', 'except', 'furthest', 'size'],
 'number': ['[END]', 'conjunction', 'furthest', 'locative'],
 'ordinal': ['particle'],
 'particle': ['in'],
 'position': ['colour', 'side'],
 'preposition': ['[END]', 'conjunction'],
 'side': ['[END]', 'conjunction'],
 'size': ['[END]', 'conjunction', 'e', 'locative', 'most'],
 'to': ['det'],
 '[SOS]': ['[SELECTION]'],
 '[SELECTION]': ['[COLOUR]'],
 '[COLOUR]': ['[HEIGHT]'],
 '[HEIGHT]': ['[CLS]', '[SELECTION]'],
 '[END]': ['[SOS]']}

In [34]:
# Save the class -> allowed successor classes table as JSON. The `with` block
# closes the file handle deterministically instead of leaving it to garbage
# collection.
with open('../data/class_successors.json', 'w') as successors_file:
    json.dump(class_gram_dict, successors_file, indent=4)

In [17]:
# Round-trip the first example through the sequence tokenizer to check that
# encode -> decode gives a readable sequence.
# NOTE(review): the description is passed here as one string, while the cell
# that builds the `tokens` id column passes `description.split()` - confirm
# the tokenizer treats both forms the same.
rākau = rākau_data['rākau'][0]
words = rākau_data['description'][0]
round_trip_ids = tokenizer.tokenize(rākau + [words])
tokenizer.decode(round_trip_ids)

'[SOS] [NOT_SELECTED] [COLOUR_BLUE] [HEIGHT_4] [SELECTED] [COLOUR_WHITE] [HEIGHT_8] [CLS] te rākau mā [EOS]'

In [18]:
def _tokenize_row(row):
    """Encode one example (world state followed by its description words) as ids."""
    return tokenizer.tokenize(row.rākau + row.description.split())


# This replaces the string-valued `tokens` column created for the n-gram
# counts, so those earlier cells cannot be re-run after this one without
# re-creating that column first.
rākau_data['tokens'] = rākau_data.apply(_tokenize_row, axis=1)
rākau_data

Unnamed: 0,id,entropy,num_rākau,rākau,description,input,target,tokens
0,800d23b5-574c-46d3-94cf-1066082e9d7d,3.0,2,"[{'color': 'blue', 'height': 4, 'location': 1,...",te rākau mā,"[1, 3, 5, 15, 2, 9, 19, 22]","[24, 23, 27, 53]","[1, 3, 5, 15, 2, 9, 19, 22, 24, 23, 27, 54]"
1,659289fa-7473-4617-a8f0-61324cc0e3b1,2.0,2,"[{'color': 'blue', 'height': 1, 'location': 1,...",ngā rākau,"[1, 2, 5, 12, 2, 6, 12, 22]","[25, 23, 53]","[1, 2, 5, 12, 2, 6, 12, 22, 25, 23, 54]"
2,7629eca4-df36-4921-973b-0ca7859c5018,3.0,2,"[{'color': 'blue', 'height': 1, 'location': 1,...",te rākau iti,"[1, 2, 5, 12, 3, 6, 20, 22]","[24, 23, 35, 53]","[1, 2, 5, 12, 3, 6, 20, 22, 24, 23, 35, 54]"
3,89413927-94ca-4078-8198-85b957338400,3.0,2,"[{'color': 'white', 'height': 8, 'location': 1...",te rākau mā,"[1, 2, 9, 19, 3, 6, 13, 22]","[24, 23, 27, 53]","[1, 2, 9, 19, 3, 6, 13, 22, 24, 23, 27, 54]"
4,eb110390-7be8-4d7a-93e0-211c33c033f1,3.0,2,"[{'color': 'black', 'height': 10, 'location': ...",ngā rākau,"[1, 2, 8, 21, 2, 9, 19, 22]","[25, 23, 53]","[1, 2, 8, 21, 2, 9, 19, 22, 25, 23, 54]"
...,...,...,...,...,...,...,...,...
1007,3d543250-f7c8-47c4-9491-ab29d5c8c8fd,3.0,2,"[{'color': 'black', 'height': 9, 'location': 1...",te rākau kikorangi,"[1, 3, 8, 20, 2, 5, 19, 22]","[24, 23, 30, 53]","[1, 3, 8, 20, 2, 5, 19, 22, 24, 23, 30, 54]"
1008,b6fcfe8b-5604-47d5-8e1a-1a2830219683,3.0,2,"[{'color': 'white', 'height': 4, 'location': 1...",te rākau mā,"[1, 2, 9, 15, 3, 8, 21, 22]","[24, 23, 27, 53]","[1, 2, 9, 15, 3, 8, 21, 22, 24, 23, 27, 54]"
1009,ecc53392-b801-4159-8fe3-fe9475e50672,3.0,2,"[{'color': 'red', 'height': 9, 'location': 1, ...",te rākau whero,"[1, 2, 4, 20, 3, 8, 13, 22]","[24, 23, 33, 53]","[1, 2, 4, 20, 3, 8, 13, 22, 24, 23, 33, 54]"
1010,bd1d8542-5e5f-4d03-9bf3-8b3d889fe737,3.0,2,"[{'color': 'red', 'height': 10, 'location': 1,...",ngā rākau,"[1, 2, 4, 21, 2, 7, 14, 22]","[25, 23, 53]","[1, 2, 4, 21, 2, 7, 14, 22, 25, 23, 54]"


## Making more token maps

Next I need to make a map that goes from source token to the classes of the valid target tokens for that source.

In [19]:
# For every token, look up its class and then the classes allowed to follow
# that class.
token_to_target_classes = {}
for token, token_class in token_to_class_dict.items():
    token_to_target_classes[token] = class_gram_dict[token_class]
token_to_target_classes

{'[SOS]': ['[SELECTION]'],
 '[SELECTED]': ['[COLOUR]'],
 '[NOT_SELECTED]': ['[COLOUR]'],
 '[COLOUR_RED]': ['[HEIGHT]'],
 '[COLOUR_BLUE]': ['[HEIGHT]'],
 '[COLOUR_GREEN]': ['[HEIGHT]'],
 '[COLOUR_YELLOW]': ['[HEIGHT]'],
 '[COLOUR_BLACK]': ['[HEIGHT]'],
 '[COLOUR_WHITE]': ['[HEIGHT]'],
 '[COLOUR_BROWN]': ['[HEIGHT]'],
 '[COLOUR_PINK]': ['[HEIGHT]'],
 '[HEIGHT_1]': ['[CLS]', '[SELECTION]'],
 '[HEIGHT_2]': ['[CLS]', '[SELECTION]'],
 '[HEIGHT_3]': ['[CLS]', '[SELECTION]'],
 '[HEIGHT_4]': ['[CLS]', '[SELECTION]'],
 '[HEIGHT_5]': ['[CLS]', '[SELECTION]'],
 '[HEIGHT_6]': ['[CLS]', '[SELECTION]'],
 '[HEIGHT_7]': ['[CLS]', '[SELECTION]'],
 '[HEIGHT_8]': ['[CLS]', '[SELECTION]'],
 '[HEIGHT_9]': ['[CLS]', '[SELECTION]'],
 '[HEIGHT_10]': ['[CLS]', '[SELECTION]'],
 '[CLS]': ['det'],
 'mā': ['[END]',
  'all',
  'conjunction',
  'e',
  'furthest',
  'in',
  'locative',
  'ordinal',
  'size',
  'to'],
 'kōwhai': ['[END]',
  'all',
  'conjunction',
  'e',
  'furthest',
  'in',
  'locative',
  'ordinal',

In principle this is the information that I need to bake into the loss function. Before doing that, I should inspect the loss function's inputs so that I can work out exactly what form this matrix needs to take in order to work properly.