# Import

In [1]:
import numpy as np
import random
import torch
from x_transformers import XTransformer
import torch
from torch.utils.tensorboard import SummaryWriter

In [2]:
# %load_ext tensorboard
# %tensorboard --logdir log

In [3]:
# !rm -r log_comp

In [4]:
writer = SummaryWriter(log_dir='log_comp')

def train_validate_model(model, batch_generator, optimizer, model_name, num_batches=1e4, eval_size=5):

    for i in range(num_batches):

        model.train()
        
        src, tgt, src_mask, tgt_mask = next(batch_generator)
        loss = model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
        loss.backward()

        loss_value = loss.item()
        print(f'{i}: {loss_value}')
        
        writer.add_scalars("/train/loss", {model_name: loss_value}, i)
        if loss_value < 1e-3:
            break

        optim.step()
        optim.zero_grad()

        if i != 0 and i % GENERATE_EVERY == 0:
            model.eval()

            src, tgt, src_mask, _ = next(batch_generator)
            src, tgt = src[:eval_size], tgt[:eval_size, 1:]
            start_tokens = tgt[:1, :1]
            sm = src_mask[:1]

            num_correct = 0
            total_batch_len = 0
            for s, t in zip(src, tgt):
                sample = model.generate(s[None], start_tokens, DEC_SEQ_LEN, src_mask=sm)
                num_correct += torch.abs(((t == sample) & (t != 0)).float()).sum()
                # print('t: ', t)
                # print('sample: ', sample)
                # print('correct mask = ', (t == sample) & (t != 0))
                total_batch_len += (t != 0).float().sum()

            accuracy = num_correct / total_batch_len
            writer.add_scalars("/val/accuracy", {model_name: accuracy}, i)

            print(f"input:  ", s)
            print(f"predicted output:  ", sample)
            print(f"correct output:  ", t)
            print(f"accuracy: {accuracy}")

    writer.flush()
    # writer.close()

# 1. Copy

In [5]:
NUM_BATCHES = int(1e5)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
GENERATE_EVERY = 100
NUM_TOKENS = 16 + 2
ENC_SEQ_LEN = 32//2
DEC_SEQ_LEN = 64//2

In [6]:
class copy_generator:
    def __init__(self):
        self.mode = 'train'
        # data = pd.read_csv('data/copy.csv')
        # self.data_size = data.shape[0]
        # self.X, self.y = data.X.values, data.y.values

        self.src_mask = torch.ones(BATCH_SIZE, ENC_SEQ_LEN).bool()#.cuda()
        self.tgt_mask = torch.ones(BATCH_SIZE, DEC_SEQ_LEN+1).bool()#.cuda()

        # self.sequence_length = ENC_SEQ_LEN

    def train(self):
        self.mode = 'train'

    def eval(self):
        self.mode = 'eval'
    
    def __next__(self):
        if self.mode == 'train':
            X = np.zeros([BATCH_SIZE, ENC_SEQ_LEN]).astype(int)
            y = np.zeros([BATCH_SIZE, DEC_SEQ_LEN+1]).astype(int)
            y[:, 0] = 1
            for i in range(BATCH_SIZE):
                sequence_length = np.random.randint(1, ENC_SEQ_LEN)
                random_sequence = np.random.randint(2, NUM_TOKENS, sequence_length)
                
                X[i, :sequence_length] = random_sequence
                y[i, 1: 2 * sequence_length + 1] = np.concatenate([random_sequence] * 2)

            return torch.tensor(X), torch.tensor(y), self.src_mask, self.tgt_mask

        else:
            return


In [7]:
def generate_validation(generator, task_name):
    val_X, val_y = [], []
    for _ in range(4):
        X, y, _, _ = next(generator)
        val_X.append(X)
        val_y.append(y)

    val_X = torch.vstack(val_X)
    val_y = torch.vstack(val_y)

    np.save('data/{}_X.npy'.format(task_name), val_X)
    np.save('data/{}_y.npy'.format(task_name), val_y)

# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! separately
cg = copy_generator()
generate_validation(cg, 'copy')

# rg = reverse_generator()
# generate_validation(rg, 'reverse')

# retr_gen = retrieval_generator()
# generate_validation(retr_gen, 'retrieval')

# listops_gen = listops_generator()
# generate_validation(listops_gen, 'listops')

In [50]:

# src, tgt, src_mask, tgt_mask = next(cg)
# src, tgt, src_mask, tgt_mask 

In [15]:
# copy_base_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)

In [16]:
# src, tgt, src_mask, _ = next(cg)
# src, tgt = src[:1], tgt[:1]
# start_tokens = tgt[:1, :1]
# sm = src_mask[:1]

# copy_base_model.generate(src, start_tokens, DEC_SEQ_LEN, src_mask=sm).shape

In [17]:
model_parameters = {'dim': 512,
    'tie_token_embeds': True,
    'return_tgt_loss': True,
    'enc_num_tokens': NUM_TOKENS,
    'enc_depth': 3,
    'enc_heads': 8,
    'enc_max_seq_len': ENC_SEQ_LEN,
    'dec_num_tokens': NUM_TOKENS,
    'dec_depth': 3,
    'dec_heads': 8}

## Base transformer

In [18]:
copy_base_model = XTransformer(**model_parameters, dec_max_seq_len=DEC_SEQ_LEN).cuda()

optim = torch.optim.Adam(copy_base_model.parameters(), lr=LEARNING_RATE)
cgen = copy_generator()

In [19]:
train_validate_model(copy_base_model, cgen, optim, 'copy_base', num_batches=2500)

0.021146396175026894
718: 0.015517136082053185
719: 0.01580696366727352
720: 0.017386652529239655
721: 0.011327404528856277
722: 0.014203464612364769
723: 0.017117705196142197
724: 0.01402110792696476
725: 0.03703693300485611
726: 0.01818186603486538
727: 0.01636953093111515
728: 0.01846953108906746
729: 0.016642559319734573
730: 0.02013467624783516
731: 0.046334158629179
732: 0.03526313975453377
733: 0.031095068901777267
734: 0.03505806252360344
735: 0.021103542298078537
736: 0.02688232623040676
737: 0.022969337180256844
738: 0.031837206333875656
739: 0.01558721624314785
740: 0.03346063196659088
741: 0.0321040041744709
742: 0.028909262269735336
743: 0.03525696322321892
744: 0.05901722237467766
745: 0.026196425780653954
746: 0.03213098272681236
747: 0.020130060613155365
748: 0.023255322128534317
749: 0.021327229216694832
750: 0.019719231873750687
751: 0.05405283719301224
752: 0.02156735211610794
753: 0.019569262862205505
754: 0.014875595457851887
755: 0.029129330068826675
756: 0.012056

KeyboardInterrupt: 

## Memory transformer

In [55]:
copy_memory_model = XTransformer(**model_parameters, dec_max_seq_len=DEC_SEQ_LEN, num_memory_tokens=16).cuda()

optim = torch.optim.Adam(copy_memory_model.parameters(), lr=LEARNING_RATE)
cgen = copy_generator()

In [56]:
train_validate_model(copy_memory_model, cgen, optim, 'copy_memory', num_batches=2500)

0: 2.8105568885803223
1: 5.373259544372559
2: 4.240912437438965
input:   tensor([ 4, 13,  8, 13, 12, 16,  3, 13, 16,  9, 15, 12,  0,  0,  0,  0])
predicted output:   tensor([[15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
         15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15]])
correct output:   tensor([ 1,  4, 13,  8, 13, 12, 16,  3, 13, 16,  9, 15, 12,  4, 13,  8, 13, 12,
        16,  3, 13, 16,  9, 15, 12,  0,  0,  0,  0,  0,  0,  0,  0])
accuracy: 0.014598540030419827
3: 3.4374091625213623
4: 3.011038064956665
input:   tensor([12,  5,  3,  5,  4, 16, 17, 13,  8,  5, 11,  5,  0,  0,  0,  0])
predicted output:   tensor([[5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
         5, 5, 5, 5, 5, 5, 5, 5, 5]])
correct output:   tensor([ 1, 12,  5, 12,  5,  4, 16, 17, 13,  8,  5, 11,  5, 16, 14,  3,  5,  4,
        16, 17, 13,  8,  5, 11,  5,  0,  0,  0,  0,  0,  0,  0,  0])
accuracy: 0.0476190485060215
5: 2.807417154312134


KeyboardInterrupt: 

# 2. Reverse

In [9]:
NUM_BATCHES = int(1e5)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
GENERATE_EVERY  = 3
NUM_TOKENS = 16 + 2
ENC_SEQ_LEN = 32
DEC_SEQ_LEN = 32

In [10]:
def reverse_generator(variable_length=True):
    src_mask = torch.ones(BATCH_SIZE, ENC_SEQ_LEN).bool()#.cuda()
    tgt_mask = torch.ones(BATCH_SIZE, DEC_SEQ_LEN+1).bool()#.cuda()

    sequence_length = ENC_SEQ_LEN
    while(True):        
        X = np.zeros([BATCH_SIZE, ENC_SEQ_LEN]).astype(int)
        y = np.zeros([BATCH_SIZE, DEC_SEQ_LEN+1]).astype(int)
        y[:, 0] = 1
        for i in range(BATCH_SIZE):
            if variable_length:
                sequence_length = np.random.randint(1, ENC_SEQ_LEN)
                random_sequence = np.random.randint(2, NUM_TOKENS, sequence_length)
                
                X[i, :sequence_length] = random_sequence
                y[i, 1:sequence_length + 1] = random_sequence[::-1]

        yield torch.tensor(X), \
                torch.tensor(y),\
                 src_mask, tgt_mask

In [11]:
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! separately
# cg = copy_generator()
# generate_validation(cg, 'copy')

rg = reverse_generator()
generate_validation(rg, 'reverse')

# retr_gen = retrieval_generator()
# generate_validation(retr_gen, 'retrieval')

# listops_gen = listops_generator()
# generate_validation(listops_gen, 'listops')

In [86]:
# rg = reverse_generator()
# X, y, sm, tm = next(rg)
# X, y, sm, tm

In [87]:
model_parameters = {'dim': 512,
    'tie_token_embeds': True,
    'return_tgt_loss': True,
    'enc_num_tokens': NUM_TOKENS,
    'enc_depth': 3,
    'enc_heads': 8,
    'enc_max_seq_len': ENC_SEQ_LEN,
    'dec_num_tokens': NUM_TOKENS,
    'dec_depth': 3,
    'dec_heads': 8}

## Base transformer

In [88]:
reverse_base_model = XTransformer(**model_parameters, dec_max_seq_len=DEC_SEQ_LEN)#.cuda()

optim = torch.optim.Adam(reverse_base_model.parameters(), lr=LEARNING_RATE)
rgen = reverse_generator()

In [89]:
train_validate_model(reverse_base_model, rgen, optim, 'reverse_base', num_batches=601)

0: 3.000492572784424
1: 4.143182754516602
2: 6.0525641441345215
3: 3.476656675338745
input:   tensor([12, 17,  9, 11, 12,  4,  8,  8, 14,  6,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0])
predicted output:   tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]])
correct output:   tensor([ 6, 14,  8,  8,  4, 12, 11,  9, 17, 12,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0])
accuracy: 0.0
4: 1.950818657875061
5: 1.8770551681518555
6: 2.4338696002960205
input:   tensor([13,  3, 14,  6,  4, 14, 14,  5, 11,  3, 17, 13,  4,  4, 17, 11,  9, 17,
         2,  8, 16,  8,  8,  6, 16, 16,  9,  0,  0,  0,  0,  0])
predicted output:   tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]])
correct output:   tensor([ 9, 16, 16,  6,  8,  8, 16,  8,  2, 17,  9, 11, 17,  4,  4, 13, 17,

KeyboardInterrupt: 

## Memory transformer

In [75]:
reverse_memory_model = XTransformer(**model_parameters, dec_max_seq_len=DEC_SEQ_LEN, num_memory_tokens=16)#.cuda()

optim = torch.optim.Adam(reverse_memory_model.parameters(), lr=LEARNING_RATE)
rgen = reverse_generator()

In [76]:
train_validate_model(reverse_memory_model, rgen, optim, 'reverse_memory', num_batches=601)

0: 3.0095040798187256
1: 5.572384834289551
2: 3.79541277885437
3: 2.949702262878418
input:   tensor([ 8, 11, 11, 16,  9,  5, 11,  3, 13, 17, 13,  9, 10,  6,  9,  2,  7,  5,
        15, 17, 11,  6, 16, 17,  9, 15,  3,  7,  3,  5, 10,  0])
predicted output:   tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0]])
correct output:   tensor([ 1, 10,  5,  3,  7,  3, 15,  9, 17, 16,  6, 11, 17, 15,  5,  7,  2,  9,
         6, 10,  9, 13, 17, 13,  3, 11,  5,  9, 16, 11, 11,  8,  0])
accuracy: 0.006802720949053764
4: 2.800572395324707
5: 2.860241174697876
6: 2.861840009689331
input:   tensor([ 5, 10,  5, 11, 16,  9,  7, 16, 15,  2, 10, 10, 10,  2, 16, 15,  6,  2,
        11, 13,  3,  2, 11,  6, 11, 11, 14,  8,  7,  5, 10,  0])
predicted output:   tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0]])
correct output:   tensor([ 1, 11,  5, 10,  5,  2, 10, 10, 10,  2, 15,

KeyboardInterrupt: 

# Associative retrieval

In [12]:
def get_three_letters():
    return np.random.choice(range(0,26), 3, replace=False)

def get_three_numbers():
    return np.random.choice(range(26, 26+10), 3, replace=False)

def create_sequence(one_hot=True):
    letters = get_three_letters()
    numbers = get_three_numbers()
    X = np.zeros((9))
    y = np.zeros((1))
    for i in range(0, 5, 2):
        X[i] = letters[i//2]
        X[i+1] = numbers[i//2]

    # append ??
    X[6] = 26+10
    X[7] = 26+10

    # last key and respective value (y)
    index = np.random.choice(range(0,3), 1, replace=False)
    X[8] = letters[index]
    y = numbers[index]

    if one_hot:
        # one hot encode X and y
        X_one_hot = np.eye(26+10+1)[np.array(X).astype('int')]
        y_one_hot = np.eye(26+10+1)[y][0]

        return X_one_hot, y_one_hot
    else:
        return X, y

def ordinal_to_alpha(sequence):
    """
    Convert from ordinal to alpha-numeric representations.
    Just for funsies :)
    """
    corpus = ['a','b','c','d','e','f','g','h','i','j','k','l',
              'm','n','o','p','q','r','s','t','u','v','w','x','y','z',
               0, 1, 2, 3, 4, 5, 6, 7, 8, 9, '?']

    conversion = ""
    for item in sequence:
        conversion += str(corpus[item.argmax()])
    return conversion

In [13]:
def retrieval_generator():
    X = np.zeros([BATCH_SIZE, 9]).astype(int)
    y = np.zeros([BATCH_SIZE, 2]).astype(int)
    y[:, 0] = 1

    src_mask = torch.ones(BATCH_SIZE, 9).bool()#.cuda()
    tgt_mask = torch.ones(BATCH_SIZE, 2).bool()#.cuda()

    while(True):        
        for i in range(BATCH_SIZE):
            X[i], y[i, 1:] = create_sequence(one_hot=False)

        yield torch.tensor(X), torch.tensor(y), src_mask, tgt_mask

In [14]:
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! separately
# cg = copy_generator()
# generate_validation(cg, 'copy')

# rg = reverse_generator()
# generate_validation(rg, 'reverse')

retr_gen = retrieval_generator()
generate_validation(retr_gen, 'retrieval')

# listops_gen = listops_generator()
# generate_validation(listops_gen, 'listops')

In [63]:
NUM_BATCHES = int(1e5)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
GENERATE_EVERY  = 3
NUM_TOKENS = 26+10+1
ENC_SEQ_LEN = 9
DEC_SEQ_LEN = 1

In [64]:
model_parameters = {'dim': 512,
    'tie_token_embeds': True,
    'return_tgt_loss': True,
    'enc_num_tokens': NUM_TOKENS,
    'enc_depth': 3,
    'enc_heads': 8,
    'enc_max_seq_len': ENC_SEQ_LEN,
    'dec_num_tokens': NUM_TOKENS,
    'dec_depth': 3,
    'dec_heads': 8}

## Base transformer

In [71]:
# retr_gen = retrieval_generator()
# src, tgt, sm, tm = next(retr_gen)
# src.shape, tgt.shape, sm.shape, tm.shape

In [72]:
# retr_gen = retrieval_generator()
# src, tgt, sm, tm = next(retr_gen)
# src.shape, tgt.shape, sm.shape, tm.shape
# tgt

In [73]:
retrieval_base_model = XTransformer(**model_parameters, dec_max_seq_len=DEC_SEQ_LEN)#.cuda()

optim = torch.optim.Adam(retrieval_base_model.parameters(), lr=LEARNING_RATE)
retr_gen = retrieval_generator()

In [74]:
train_validate_model(retrieval_base_model, retr_gen, optim, 'retrieval_base', num_batches=1000)

0: 3.6622865200042725
1: 2.652991771697998
2: 2.775786876678467
3: 3.0393452644348145
t:  tensor([28])
sample:  tensor([[27]])
correct mask =  tensor([[False]])
t:  tensor([33])
sample:  tensor([[31]])
correct mask =  tensor([[False]])
t:  tensor([32])
sample:  tensor([[30]])
correct mask =  tensor([[False]])
t:  tensor([33])
sample:  tensor([[30]])
correct mask =  tensor([[False]])
t:  tensor([32])
sample:  tensor([[30]])
correct mask =  tensor([[False]])
input:   tensor([14, 34, 20, 32, 16, 31, 36, 36, 20])
predicted output:   tensor([[30]])
correct output:   tensor([32])
accuracy: 0.0
4: 2.3778741359710693
5: 2.5743558406829834
6: 2.4867851734161377
t:  tensor([27])
sample:  tensor([[34]])
correct mask =  tensor([[False]])
t:  tensor([33])
sample:  tensor([[35]])
correct mask =  tensor([[False]])
t:  tensor([32])
sample:  tensor([[35]])
correct mask =  tensor([[False]])
t:  tensor([27])
sample:  tensor([[34]])
correct mask =  tensor([[False]])
t:  tensor([28])
sample:  tensor([[35]]

KeyboardInterrupt: 

## Memory transformer

In [75]:
retrieval_memory_model = XTransformer(**model_parameters, dec_max_seq_len=DEC_SEQ_LEN, num_memory_tokens=4)#.cuda()

optim = torch.optim.Adam(retrieval_memory_model.parameters(), lr=LEARNING_RATE)
retr_gen = retrieval_generator()

In [76]:
train_validate_model(retrieval_memory_model, retr_gen, optim, 'retrieval_memory', num_batches=1000, eval_size=20)

0: 3.7840535640716553
1: 3.407494306564331
2: 2.976376533508301
3: 3.039721727371216
t:  tensor([33])
sample:  tensor([[34]])
correct mask =  tensor([[False]])
t:  tensor([35])
sample:  tensor([[29]])
correct mask =  tensor([[False]])
t:  tensor([28])
sample:  tensor([[29]])
correct mask =  tensor([[False]])
t:  tensor([29])
sample:  tensor([[32]])
correct mask =  tensor([[False]])
t:  tensor([27])
sample:  tensor([[32]])
correct mask =  tensor([[False]])
t:  tensor([34])
sample:  tensor([[34]])
correct mask =  tensor([[True]])
t:  tensor([33])
sample:  tensor([[29]])
correct mask =  tensor([[False]])
t:  tensor([26])
sample:  tensor([[34]])
correct mask =  tensor([[False]])
t:  tensor([29])
sample:  tensor([[34]])
correct mask =  tensor([[False]])
t:  tensor([27])
sample:  tensor([[29]])
correct mask =  tensor([[False]])
t:  tensor([35])
sample:  tensor([[32]])
correct mask =  tensor([[False]])
t:  tensor([35])
sample:  tensor([[34]])
correct mask =  tensor([[False]])
t:  tensor([32])

KeyboardInterrupt: 

# ListOPS

In [91]:
4 + 4*3 + 3 * 3 * 4 + 3 * 3 * 3 * 4

160

In [15]:
MIN = "[MIN"
MAX = "[MAX"
MED = "[MED"
SUM_MOD = "[SM"
END = "]"

OPERATORS = [MIN, MAX, MED, SUM_MOD]
VALUES = range(10)

VALUE_P = 0.25
MAX_ARGS = 3
MAX_DEPTH = 10

In [16]:
def generate_tree(depth):
    if depth < MAX_DEPTH:
        r = random.random()
    else:
        r = 1

    if r > VALUE_P:
        value = random.choice(VALUES)
        return value
    else:
        num_values = random.randint(2, MAX_ARGS)
        values = []
        for _ in range(num_values):
            values.append(generate_tree(depth + 1))

        op = random.choice(OPERATORS)
        t = (op, values[0])
        for value in values[1:]:
            t = (t, value)
        t = (t, END)
    return t


def to_string(t, parens=False):
    if isinstance(t, str):
        return t
    elif isinstance(t, int):
        return str(t)
    else:
        if parens:
            return '( ' + to_string(t[0], parens) + ' ' + to_string(t[1], parens) + ' )'
        else:
            return to_string(t[0], parens) + ' ' + to_string(t[1], parens)# + ' '


op2token = dict(zip(list(np.arange(10).astype(str)) + OPERATORS + [END], range(2, 15+2)))

def to_tokens(t):
    string = to_string(t)
    tokens = list(map(lambda x: op2token[x], string.split(' ')))
    return tokens


def to_value(t):
    if not isinstance(t, tuple):
        return t
    l = to_value(t[0])
    r = to_value(t[1])
    if l in OPERATORS:  # Create an unsaturated function.
        return (l, [r])
    elif r == END:  # l must be an unsaturated function.
        if l[0] == MIN:
            return min(l[1])
        elif l[0] == MAX:
            return max(l[1])
        # elif l[0] == FIRST:
        #     return l[1][0]
        # elif l[0] == LAST:
        #     return l[1][-1]
        elif l[0] == MED:
            return int(np.median(l[1]))
        elif l[0] == SUM_MOD:
            return (np.sum(l[1]) % 10)
    elif isinstance(l, tuple):  # We've hit an unsaturated function and an argument.
        return (l[0], l[1] + [r])

In [17]:
NUM_BATCHES = int(1e5)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
GENERATE_EVERY  = 3
NUM_TOKENS = 10+5+2
ENC_SEQ_LEN = 150
DEC_SEQ_LEN = 1

In [18]:
def listops_generator(max_depth=MAX_DEPTH):

    src_mask = torch.ones(BATCH_SIZE, ENC_SEQ_LEN).bool()#.cuda()
    tgt_mask = torch.ones(BATCH_SIZE, 2).bool()#.cuda()
    
    while(True):
        X = np.zeros([BATCH_SIZE, ENC_SEQ_LEN]).astype(int)
        y = np.ones([BATCH_SIZE, 2]).astype(int) * 2
        # y[:, 0] = 1

        for i in range(BATCH_SIZE):
            t = generate_tree(max_depth)
            tokens, value = to_tokens(t), to_value(t) 
            X[i, 0:len(tokens)], y[i, 1:] = tokens, value+2
            del t

        src, tgt = torch.tensor(X), torch.tensor(y)#.cuda()
        yield src, tgt, src_mask, tgt_mask

In [19]:
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! separately
# cg = copy_generator()
# generate_validation(cg, 'copy')

# rg = reverse_generator()
# generate_validation(rg, 'reverse')

# retr_gen = retrieval_generator()
# generate_validation(retr_gen, 'retrieval')

listops_gen = listops_generator()
generate_validation(listops_gen, 'listops')

In [70]:
# lg = listops_generator(4)
# next(lg)

In [71]:
model_parameters = {'dim': 512,
    'tie_token_embeds': True,
    'return_tgt_loss': True,
    'enc_num_tokens': NUM_TOKENS,
    'enc_depth': 3,
    'enc_heads': 8,
    'enc_max_seq_len': ENC_SEQ_LEN,
    'dec_num_tokens': NUM_TOKENS,
    'dec_depth': 3,
    'dec_heads': 8}

## Base transformer

In [100]:
# del listops_base_model
# del tokens 
torch.cuda.empty_cache()

In [106]:
listops_base_model = XTransformer(**model_parameters, dec_max_seq_len=2)#.cuda()

optim = torch.optim.Adam(listops_base_model.parameters(), lr=LEARNING_RATE)
listops_gen = listops_generator(max_depth=4)


In [107]:
train_validate_model(listops_base_model, listops_gen, optim, 'listops_base', num_batches=150)

 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0])
predicted output:   tensor([[8]])
correct output:   tensor([8])
accuracy: 0.20000000298023224
70: 0.8611487150192261
71: 0.9114288091659546
72: 0.5899685621261597
input:   tensor([11,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  

## Memory transformer

In [9]:
listops_memory_model = XTransformer(**model_parameters, dec_max_seq_len=2, num_memory_tokens=32).cuda()

optim = torch.optim.Adam(listops_memory_model.parameters(), lr=LEARNING_RATE)
listops_gen = listops_generator(max_depth=4)

,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0]], device='cuda:0')
predicted output:   tensor([[10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
         10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
         10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
         10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
         10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
         10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,

In [None]:
train_validate_model(listops_memory_model, listops_gen, optim, 'listops_memory', num_batches=1500)