In [2]:
import numpy as np
import random
import time
import torch
from x_transformers.x_transformers import XTransformer
import torch

from run_experiment import *
from generate_data import *

## Variables

In [1]:
from sklearn.model_selection import ParameterGrid

# Label forwarded to test_model(tag=TAG) for this batch of runs.
TAG = '8+30_close2paper'

# --- Dataset ---------------------------------------------------------------
# Prefix of the stored dataset; the data loaders append '_train' / '_val' / '_test'.
# Defined exactly once here so that the grid and the run cells cannot disagree.
TASK_NAME = 'retrieval_15'
TRAIN_SIZE = 100_000
VAL_SIZE = 10_000
TEST_SIZE = 20_000
# Number of independent training runs per hyper-parameter combination.
NUM_INITS = 3

# --- Optimisation ----------------------------------------------------------
NUM_BATCHES = int(1.5e5)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
# Passed to train_validate_model(generate_every=...); one tenth of the batch budget.
GENERATE_EVERY = NUM_BATCHES // 10

# --- Vocabulary / sequence geometry ------------------------------------------
# presumably 26 letters + 10 digits + 1 extra token -- TODO confirm against generate_data
ENC_NUM_TOKENS = 26+10+1
# presumably 10 digit classes + token id 10, which retrieval_generator writes into the
# first target position -- TODO confirm
DEC_NUM_TOKENS = 10+1
# NOTE(review): the commented-out generation code for 'retrieval_15' sets ENC_SEQ_LEN = 31
# and the logged inputs contain 31 tokens, while ENC_SEQ_LEN / INPUT_LEN are 9 here.
# Confirm that 9 is what data_loader should receive for this task.
ENC_SEQ_LEN = 9
DEC_SEQ_LEN = 1

INPUT_LEN = 9

# Hyper-parameter grid. The combined 'depth,heads' entry keeps depth and head count
# paired; the run cells split it into enc_/dec_ depth and heads before building the model.
# 3 dims x 2 (depth, heads) pairs x 3 encoder lengths x 6 memory sizes = 108 combinations.
model_parameters = ParameterGrid({
    'dim': [20, 50, 100],
    'tie_token_embeds': [True],
    'return_tgt_loss': [True],
    'enc_num_tokens': [ENC_NUM_TOKENS],
    'depth,heads': [(1,1), (2,4)],
    'enc_max_seq_len': [5, 10, 15],
    'dec_num_tokens': [DEC_NUM_TOKENS],
    'dec_max_seq_len': [DEC_SEQ_LEN],
    'enc_num_memory_tokens': [0, 2, 4, 8, 16, 32],
})

print('Total runs: ', NUM_INITS * len(model_parameters))

Total runs:  324


In [9]:
# for i, p in enumerate(model_parameters):
#     print(i, p)

#### Generate data

In [3]:
class retrieval_generator:
    """Endless batch generator for the retrieval task.

    Every ``next()`` call builds one batch by calling ``create_sequence``
    ``BATCH_SIZE`` times. Batch and sequence sizes are read from the
    module-level ``BATCH_SIZE``, ``ENC_SEQ_LEN`` and ``DEC_SEQ_LEN``.

    Args:
        K: forwarded unchanged to ``create_sequence(K=...)``.
    """

    def __init__(self, K=4):
        # All-True masks, built once and returned with every batch. Their shapes are
        # fixed from the globals at construction time, whereas __next__ re-reads the
        # globals on every call -- construct the generator after setting ENC_SEQ_LEN.
        self.src_mask = torch.ones(BATCH_SIZE, ENC_SEQ_LEN).bool()
        self.tgt_mask = torch.ones(BATCH_SIZE, DEC_SEQ_LEN+1).bool()
        self.K = K

    def __iter__(self):
        """Return ``self`` so the object supports ``iter()`` and ``for`` loops."""
        return self

    def __next__(self):
        """Build and return one batch.

        Returns:
            ``(X, y, src_mask, tgt_mask)`` where ``X`` is an integer tensor of shape
            ``(BATCH_SIZE, ENC_SEQ_LEN)``, ``y`` is an integer tensor of shape
            ``(BATCH_SIZE, DEC_SEQ_LEN + 1)`` and the masks are the all-True boolean
            tensors created in ``__init__``.
        """
        X = np.zeros([BATCH_SIZE, ENC_SEQ_LEN]).astype(int)
        y = np.zeros([BATCH_SIZE, DEC_SEQ_LEN+1]).astype(int)
        # First target position always holds token id 10
        # (presumably the decoder start token -- TODO confirm).
        y[:, 0] = 10
        for i in range(BATCH_SIZE):
            # create_sequence returns (source tokens, target tokens) for one example.
            X[i], y[i, 1:] = create_sequence(one_hot=False, K=self.K)

        return torch.tensor(X), torch.tensor(y), self.src_mask, self.tgt_mask


# generator = retrieval_generator(4)
# generate_data(generator, task_name='retrieval_4', train_size=TRAIN_SIZE, test_size=TEST_SIZE, val_size=VAL_SIZE)
# ENC_SEQ_LEN = 31
# generator = retrieval_generator(15)
# generate_data(generator, task_name='retrieval_15', train_size=TRAIN_SIZE, test_size=TEST_SIZE, val_size=VAL_SIZE)

In [11]:
# s,t, _, _ = next(generator)
# s[0], t[0]

### Run

In [None]:
# Loaders for the stored '<TASK_NAME>_train/_val/_test' splits. Validation and test
# use the whole split size as the batch size.
gen_train = data_loader(
    task_name=f'{TASK_NAME}_train',
    batch_size=BATCH_SIZE,
    enc_seq_len=INPUT_LEN,
    dec_seq_len=DEC_SEQ_LEN,
)
gen_val = data_loader(
    task_name=f'{TASK_NAME}_val',
    batch_size=VAL_SIZE,
    enc_seq_len=INPUT_LEN,
    dec_seq_len=DEC_SEQ_LEN,
)
gen_test = data_loader(
    task_name=f'{TASK_NAME}_test',
    batch_size=TEST_SIZE,
    enc_seq_len=INPUT_LEN,
    dec_seq_len=DEC_SEQ_LEN,
)


# `t` marks the start of the current run; it is reset after every configuration,
# so the printed time is the duration of that single train + test run.
t = time.time()
with torch.cuda.device(0):
    for init_num in range(NUM_INITS):
        print('\n\n\nInit number ', init_num)
        for i, param in enumerate(list(model_parameters)):
            print(param)

            # Replace the combined 'depth,heads' grid entry with the four keyword
            # arguments passed to XTransformer; encoder and decoder share the values.
            _depth, _heads = param.pop('depth,heads')
            param['enc_depth'], param['enc_heads'] = _depth, _heads
            param['dec_depth'], param['dec_heads'] = _depth, _heads

            # Progress through the grid for the current init, in percent.
            print(i / len(model_parameters) * 100, '%')

            model = XTransformer(**param).cuda()
            model_name = f"{TASK_NAME}_dim{param['dim']}d{param['enc_depth']}h{param['enc_heads']}M{param['enc_num_memory_tokens']}l{param['enc_max_seq_len']}_v{init_num}"
            optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

            train_validate_model(
                model,
                train_generator=gen_train,
                val_generator=gen_val,
                optim=optim,
                model_name=model_name,
                dec_seq_len=DEC_SEQ_LEN,
                num_batches=NUM_BATCHES,
                generate_every=GENERATE_EVERY,
            )
            test_model(
                model,
                gen_test,
                model_name,
                param,
                TASK_NAME,
                tag=TAG,
                dec_seq_len=param['dec_max_seq_len'],
            )

            print('Total time: ', time.time() - t)
            t = time.time()




Init number  0
{'dec_max_seq_len': 1, 'dec_num_tokens': 11, 'depth,heads': (1, 1), 'dim': 20, 'enc_max_seq_len': 5, 'enc_num_memory_tokens': 0, 'enc_num_tokens': 37, 'return_tgt_loss': True, 'tie_token_embeds': True}
0.0 %
input:   tensor([11,  8, 17,  1, 33,  5, 36,  5, 22,  6, 34,  1, 12,  8, 16,  5, 25,  9,
        19,  3, 30,  4, 18,  4, 23,  7, 21,  3, 13,  4, 18], device='cuda:0')
predicted output:   tensor([[2]], device='cuda:0')
correct output:   tensor([4], device='cuda:0')
accuracy: 0.2086850255727768
input:   tensor([11,  8, 17,  1, 33,  5, 36,  5, 22,  6, 34,  1, 12,  8, 16,  5, 25,  9,
        19,  3, 30,  4, 18,  4, 23,  7, 21,  3, 13,  4, 18], device='cuda:0')
predicted output:   tensor([[7]], device='cuda:0')
correct output:   tensor([4], device='cuda:0')
accuracy: 0.23611728847026825
input:   tensor([11,  8, 17,  1, 33,  5, 36,  5, 22,  6, 34,  1, 12,  8, 16,  5, 25,  9,
        19,  3, 30,  4, 18,  4, 23,  7, 21,  3, 13,  4, 18], device='cuda:0')
predicted output: 

input:   tensor([11,  8, 17,  1, 33,  5, 36,  5, 22,  6, 34,  1, 12,  8, 16,  5, 25,  9,
        19,  3, 30,  4, 18,  4, 23,  7, 21,  3, 13,  4, 18], device='cuda:0')
predicted output:   tensor([[1]], device='cuda:0')
correct output:   tensor([4], device='cuda:0')
accuracy: 0.2374500185251236
Total time:  5575.574679613113
{'dec_max_seq_len': 1, 'dec_num_tokens': 11, 'depth,heads': (1, 1), 'dim': 20, 'enc_max_seq_len': 5, 'enc_num_memory_tokens': 8, 'enc_num_tokens': 37, 'return_tgt_loss': True, 'tie_token_embeds': True}
2.7777777777777777 %
input:   tensor([11,  8, 17,  1, 33,  5, 36,  5, 22,  6, 34,  1, 12,  8, 16,  5, 25,  9,
        19,  3, 30,  4, 18,  4, 23,  7, 21,  3, 13,  4, 18], device='cuda:0')
predicted output:   tensor([[4]], device='cuda:0')
correct output:   tensor([4], device='cuda:0')
accuracy: 0.22734341025352478
input:   tensor([11,  8, 17,  1, 33,  5, 36,  5, 22,  6, 34,  1, 12,  8, 16,  5, 25,  9,
        19,  3, 30,  4, 18,  4, 23,  7, 21,  3, 13,  4, 18], device=

### Test!

In [8]:
# NOTE(review): same value as the TASK_NAME already assigned in the Variables cell;
# presumably repeated so this section can be run on its own -- confirm it is still needed.
TASK_NAME = 'retrieval_15'



In [9]:
# Smoke test: build one model from the grid and push a single training batch
# through its encoder to inspect the output shape.
init_num = 0

gen_train = data_loader(
    task_name=f'{TASK_NAME}_train',
    batch_size=BATCH_SIZE,
    enc_seq_len=ENC_SEQ_LEN,
    dec_seq_len=DEC_SEQ_LEN,
)
gen_val = data_loader(
    task_name=f'{TASK_NAME}_val',
    batch_size=VAL_SIZE,
    enc_seq_len=ENC_SEQ_LEN,
    dec_seq_len=DEC_SEQ_LEN,
)
gen_test = data_loader(
    task_name=f'{TASK_NAME}_test',
    batch_size=TEST_SIZE,
    enc_seq_len=ENC_SEQ_LEN,
    dec_seq_len=DEC_SEQ_LEN,
)


# Grid entry number 5 (its full contents are printed on the next line).
param = list(model_parameters)[5]
print(param)

# Replace the combined 'depth,heads' entry with the four keyword arguments passed
# to XTransformer; encoder and decoder share the same depth and head count.
_depth, _heads = param.pop('depth,heads')
param['enc_depth'], param['enc_heads'] = _depth, _heads
param['dec_depth'], param['dec_heads'] = _depth, _heads

model = XTransformer(**param).cuda()
model_name = f"{TASK_NAME}_dim{param['dim']}d{param['enc_depth']}h{param['enc_heads']}M{param['enc_num_memory_tokens']}l{param['enc_max_seq_len']}_v{init_num}"
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# One training batch; the two masks returned by the loader are not needed here.
src, tgt, _, _ = next(gen_train)

print(model.encoder.max_seq_len, model.encoder.num_memory_tokens)
# Encode the batch concatenated with itself and display the embedding shape.
model.encoder(torch.cat((src, src)), return_embeddings=True).shape

{'dec_max_seq_len': 1, 'dec_num_tokens': 11, 'depth,heads': (1, 1), 'dim': 20, 'enc_max_seq_len': 5, 'enc_num_memory_tokens': 32, 'enc_num_tokens': 37, 'return_tgt_loss': True, 'tie_token_embeds': True}
5 32


torch.Size([64, 37, 20])

In [10]:
# Display the first source sequence and its target from the batch drawn in the previous cell.
src[0], tgt[0]

(tensor([36,  9, 23,  9, 18,  7, 21,  4, 29,  3, 14,  6, 31,  4, 33,  2, 12,  0,
         35,  4, 20,  0, 28,  9, 15,  6, 16,  5, 17,  7, 28], device='cuda:0'),
 tensor([10,  9], device='cuda:0'))