In [1]:
import numpy as np
import random
import time
import torch
from x_transformers.x_transformers import XTransformer
import torch

from run_experiment import *
from generate_data import *

## Variables

In [10]:
from sklearn.model_selection import ParameterGrid

TAG = 'improve_score_further'

TASK_NAME = 'reverse'
TRAIN_SIZE = 30_000
VAL_SIZE = 2_000
TEST_SIZE = 5_000
NUM_INITS = 3


NUM_BATCHES = int(1e5)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
GENERATE_EVERY  = 3000
NUM_TOKENS = 16 + 2
ENC_SEQ_LEN = 32
DEC_SEQ_LEN = 32

INPUT_LEN = 32


model_parameters = ParameterGrid({'dim': [128],
    'tie_token_embeds': [True],
    'return_tgt_loss': [True],
    'enc_num_tokens': [NUM_TOKENS],
    'depth,heads': [(2,4), (1,1)],
    'enc_max_seq_len': [32],
    'dec_num_tokens': [NUM_TOKENS],
    'dec_max_seq_len': [DEC_SEQ_LEN],
    'enc_num_memory_tokens': [0]})

print('Total runs: ', NUM_INITS * len(model_parameters))

Total runs:  6


#### Generate data

In [3]:
# class reverse_generator:
#     def __init__(self):
#         self.src_mask = torch.ones(BATCH_SIZE, ENC_SEQ_LEN).bool().cuda()
#         self.tgt_mask = torch.ones(BATCH_SIZE, DEC_SEQ_LEN+1).bool().cuda()
    
#     def __next__(self):
#         X = np.zeros([BATCH_SIZE, ENC_SEQ_LEN]).astype(int)
#         y = np.zeros([BATCH_SIZE, DEC_SEQ_LEN+1]).astype(int)
#         y[:, 0] = 1
#         for i in range(BATCH_SIZE):
#             sequence_length = np.random.randint(1, ENC_SEQ_LEN)
#             random_sequence = np.random.randint(2, NUM_TOKENS, sequence_length)
            
#             X[i, :sequence_length] = random_sequence
#             y[i, 1:sequence_length + 1] = random_sequence[::-1]


#         return torch.tensor(X), torch.tensor(y), self.src_mask, self.tgt_mask   
    
# generator = reverse_generator()
# generate_data(generator, task_name=TASK_NAME, train_size=TRAIN_SIZE, test_size=TEST_SIZE, val_size=VAL_SIZE)  

#### Gridsearch params

In [8]:
optimizer = torch.optim.Adam

optim_params = ParameterGrid({
    'lr': [0.0002, 0.00012, 0.000081, 0.0001],
})

print(len(optim_params))

4


In [9]:
gen_train = data_loader(task_name=f'{TASK_NAME}_train', batch_size=BATCH_SIZE, enc_seq_len=INPUT_LEN, dec_seq_len=DEC_SEQ_LEN)
gen_val = data_loader(task_name=f'{TASK_NAME}_val', batch_size=VAL_SIZE, enc_seq_len=INPUT_LEN, dec_seq_len=DEC_SEQ_LEN)
gen_test = data_loader(task_name=f'{TASK_NAME}_test', batch_size=TEST_SIZE, enc_seq_len=INPUT_LEN, dec_seq_len=DEC_SEQ_LEN)


print_file = f'logs/{TASK_NAME}_{TAG}_cout_logs2.txt'
t = time.time()

param = list(model_parameters)[0]
param['enc_depth'], param['enc_heads'] = param['depth,heads']
param['dec_depth'], param['dec_heads'] = param['depth,heads']
param.pop('depth,heads')

with torch.cuda.device(1):
    for i, optim_param in enumerate(list(optim_params)[1:]):
        with open(print_file, 'a') as f:
            f.write('\n\n' + str(optim_param)+'\n')
        
        for init_num in range(1):
            model = XTransformer(**param).cuda()

            model_name = f"{TASK_NAME}{INPUT_LEN}_dim{param['dim']}d{param['enc_depth']}h{param['enc_heads']}M{param['enc_num_memory_tokens']}l{param['enc_max_seq_len']}_v{init_num}_{optim_param}"

            optim = optimizer(model.parameters(), **optim_param)
            train_validate_model(model, 
                                train_generator=gen_train, 
                                val_generator=gen_val, 
                                optim=optim, 
                                model_name=model_name, 
                                dec_seq_len=DEC_SEQ_LEN,
                                num_batches=NUM_BATCHES,
                                generate_every=GENERATE_EVERY,
                                print_file=print_file)
            test_model(model, gen_test, model_name, param, TASK_NAME, tag=str(optim_param), dec_seq_len=param['dec_max_seq_len'])
            with open(print_file, 'a') as f:
                f.write(f'\nTotal time: {time.time() - t}\n')
            t = time.time()

### Run

In [None]:
gen_train = data_loader(task_name=f'{TASK_NAME}_train', batch_size=BATCH_SIZE, enc_seq_len=INPUT_LEN, dec_seq_len=DEC_SEQ_LEN)
gen_val = data_loader(task_name=f'{TASK_NAME}_val', batch_size=VAL_SIZE, enc_seq_len=INPUT_LEN, dec_seq_len=DEC_SEQ_LEN)
gen_test = data_loader(task_name=f'{TASK_NAME}_test', batch_size=TEST_SIZE, enc_seq_len=INPUT_LEN, dec_seq_len=DEC_SEQ_LEN)


t = time.time()
with torch.cuda.device(0):
    for i, param in enumerate(list(model_parameters)):
        print(param)
        param['enc_depth'], param['enc_heads'] = param['depth,heads']
        param['dec_depth'], param['dec_heads'] = param['depth,heads']
        param.pop('depth,heads')
        
        print(i / len(model_parameters) * 100, '%')
        for init_num in range(NUM_INITS):
            print('\n\n\nInit number ', init_num)
            model = XTransformer(**param).cuda()

            model_name = f"{TASK_NAME}{INPUT_LEN}_dim{param['dim']}d{param['enc_depth']}h{param['enc_heads']}M{param['enc_num_memory_tokens']}l{param['enc_max_seq_len']}_v{init_num}"

            optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
            train_validate_model(model, 
                                train_generator=gen_train, 
                                val_generator=gen_val, 
                                optim=optim, 
                                model_name=model_name, 
                                dec_seq_len=DEC_SEQ_LEN,
                                num_batches=NUM_BATCHES,
                                generate_every=GENERATE_EVERY)
            test_model(model, gen_test, model_name, param, TASK_NAME, tag=TAG, dec_seq_len=param['dec_max_seq_len'])
            print('Total time: ', time.time() - t)
            t = time.time()

{'dec_max_seq_len': 64, 'dec_num_tokens': 18, 'depth,heads': (1, 1), 'dim': 128, 'enc_max_seq_len': 64, 'enc_num_memory_tokens': 0, 'enc_num_tokens': 18, 'return_tgt_loss': True, 'tie_token_embeds': True}
0.0 %



Init number  0
input:   tensor([15,  8, 12,  6,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0], device='cuda:0')
predicted output:   tensor([[ 6,  8, 15,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0]], device='cuda:0')
correct output:   tensor([ 6, 12,  8, 15,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,

stopped on: 
{'dec_max_seq_len': 16, 'dec_num_tokens': 18, 'depth,heads': (2, 4), 'dim': 64, 'enc_max_seq_len': 8, 'enc_num_memory_tokens': 16, 'enc_num_tokens': 18, 'return_tgt_loss': True, 'tie_token_embeds': True}

### Test!

In [6]:
init_num = 0

gen_train = data_loader(task_name=f'{TASK_NAME}_train', batch_size=BATCH_SIZE, enc_seq_len=ENC_SEQ_LEN, dec_seq_len=DEC_SEQ_LEN)
gen_val = data_loader(task_name=f'{TASK_NAME}_val', batch_size=VAL_SIZE, enc_seq_len=ENC_SEQ_LEN, dec_seq_len=DEC_SEQ_LEN)
gen_test = data_loader(task_name=f'{TASK_NAME}_test', batch_size=TEST_SIZE, enc_seq_len=ENC_SEQ_LEN, dec_seq_len=DEC_SEQ_LEN)


param = list(model_parameters)[5]
print(param)
param['enc_depth'], param['enc_heads'] = param['depth,heads']
param['dec_depth'], param['dec_heads'] = param['depth,heads']
param.pop('depth,heads')

model = XTransformer(**param).cuda()

model_name = f"{TASK_NAME}_dim{param['dim']}d{param['enc_depth']}h{param['enc_heads']}M{param['enc_num_memory_tokens']}l{param['enc_max_seq_len']}_v{init_num}"

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

src, tgt, _, _ = next(gen_train)

print(model.encoder.max_seq_len, model.encoder.num_memory_tokens)
model.encoder(torch.cat((src, src)), return_embeddings=True).shape

{'dec_max_seq_len': 16, 'dec_num_tokens': 18, 'depth,heads': (1, 1), 'dim': 64, 'enc_max_seq_len': 8, 'enc_num_memory_tokens': 4, 'enc_num_tokens': 18, 'return_tgt_loss': True, 'tie_token_embeds': True}
8 4


torch.Size([64, 12, 64])