In [1]:
# !git clone https://github.com/booydar/algotrade
# !pip install einops entmax

In [2]:
# from google.colab import drive
# drive.mount('/content/drive')

In [3]:
# !mkdir logs
# !mkdir checkpoints

In [4]:
import torch
import torch.nn as nn
import numpy as np
import random
import time
import sys
from sklearn.model_selection import train_test_split

# sys.path.append('algotrade/NN')
# from algotrade.NN.x_transformers.x_transformers import *
# from algotrade.NN.run_experiment import *
# from algotrade.NN.generate_data import *

sys.path.append('NN')
from x_transformers.NN.x_transformers import *
from x_transformers.NN.x_transformers.x_transformers import *
from x_transformers.NN.run_experiment import *
# from x_transformers.NN.generate_data import *

## Variables

In [5]:
from sklearn.model_selection import ParameterGrid

TAG = 'test'

TASK_NAME = 'price'
TRAIN_SIZE = 100_000
VAL_SIZE = 2_000
TEST_SIZE = 10_000
NUM_INITS = 4


NUM_BATCHES = int(4e5)
BATCH_SIZE = 128
GENERATE_EVERY  = 10000
NUM_TOKENS = 10 + 2
ENC_SEQ_LEN = 24
DEC_SEQ_LEN = 48

INPUT_LEN = 24

#### Load data

In [24]:
class data_loader:
    def __init__(self, mode, path='data', tgt_len=24, batch_size=32, tgt_dim=2, device='cpu'):
        X, y = np.load(f'{path}/X_{mode}.npy'), np.load(f'{path}/y_{mode}.npy')        
        X = torch.tensor(X)

        slices_x = [X[i:tgt_len + i] for i in range(X.shape[0] - tgt_len)]
        src = torch.stack(slices_x)
        tgt = y[tgt_len-1:-1]
        
        if tgt_dim is not None:
            tgt = tgt[:, [0, tgt_dim]]
        
        perm_ind = torch.randperm(src.shape[0])
        src, tgt = src[perm_ind], tgt[perm_ind]
        self.src, self.tgt = torch.tensor(src).float(), torch.tensor(tgt).float()

        self.data_size = self.src.shape[0]
        self.data_ptr = 0

        self.batch_size = batch_size
        self.device = device

    def __next__(self):
        if self.data_ptr + self.batch_size > self.data_size:
            self.data_ptr = 0

        src = self.src[self.data_ptr: self.data_ptr + self.batch_size].to(device=self.device)
        tgt = self.tgt[self.data_ptr: self.data_ptr + self.batch_size].to(device=self.device)
        
        src_mask = tgt_mask = None
            
        self.data_ptr = (self.data_ptr + self.batch_size) % self.data_size

        return src, tgt, src_mask, tgt_mask

### Run

In [25]:

gen_train = data_loader(path=f'data/BTCUSD', mode='train', batch_size=BATCH_SIZE, device='cuda')
gen_val = data_loader(path=f'data/BTCUSD', mode='val', batch_size=BATCH_SIZE, device='cuda')
gen_test = data_loader(path=f'data/BTCUSD', mode='test', batch_size=BATCH_SIZE, device='cuda')

  self.src, self.tgt = torch.tensor(src).float(), torch.tensor(tgt).float()


In [30]:
class CXTransformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        tie_token_emb = False,
        **kwargs
    ):
        super().__init__()
        enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
        dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)
        
        assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of either encoder or decoder must be set with `dim` keyword'
        enc_transformer_kwargs = pick_and_pop(['max_seq_len', 'dim_in', 'use_pos_emb'], enc_kwargs)
        # enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop('num_memory_tokens', None)

        dec_transformer_kwargs = pick_and_pop(['max_seq_len', 'dim_in', 'dim_out'], dec_kwargs)

        self.encoder = ContinuousTransformerWrapper(
            **enc_transformer_kwargs,
            attn_layers = Encoder(dim = dim, **enc_kwargs)
        )

        self.decoder = ContinuousTransformerWrapper(
            **dec_transformer_kwargs,
            attn_layers = Decoder(dim = dim, cross_attend = True, **dec_kwargs)
        )

        if tie_token_emb:
            self.decoder.token_emb = self.encoder.token_emb

        self.decoder = AutoregressiveWrapper(self.decoder)

    @torch.no_grad()
    def generate(self, seq_in, seq_out_start, seq_len, src_mask = None, **kwargs):
        encodings = self.encoder(seq_in, return_embeddings = True, mask = src_mask)
        return self.decoder.generate(seq_out_start, seq_len, context = encodings, context_mask = src_mask, **kwargs)

    def forward(self, src, tgt, src_mask = None, tgt_mask = None):
        enc = model.encoder(src, mask = src_mask, return_embeddings = True)
    
        gen_token = -10_000 * torch.ones_like(src[:, :1, :])

        out = model.decoder.net(gen_token, context=enc)
        xo = tgt[:, 1:]
        loss = F.mse_loss(out.transpose(1, 2)[:, 0], xo)
        return loss

In [31]:
LEARNING_RATE = 0.0001

model_parameters = ParameterGrid({'dim': [128],
    'tie_token_embeds': [True],
    'return_tgt_loss': [True],
    'enc_depth': [2],
    'enc_heads': [4],
    'dec_depth': [2],
    'dec_heads': [4],
    'enc_max_seq_len': [24],
    'dec_max_seq_len': [1],
    'enc_num_memory_tokens': [0],
    'enc_dim_in': [16],
    'dec_dim_in': [16],
    'enc_dim_out': [1],
    'dec_dim_out': [1],
    'enc_emb_dim': [128],
    'enc_emb_dropout': [0.],
    'enc_use_pos_emb': [False]
})

param = list(model_parameters)[0]


In [53]:
WINDOW_SIZE = 4
PATIENCE = 10
def train_validate_model(model, train_generator, val_generator, optim, model_name, config, generate_every=1e2, num_batches=1e3, verbose=True, overfit_stop=True, print_file=None, tag='', log_path='logs/', head_start=15):
    
    fix_seeds()
    t0 = time.time()
    
    log_dir = log_path + model_name.split('_')[0]
    writer = SummaryWriter(log_dir=log_dir)
    if print_file is None:
        print_file = f"{log_dir}/{model_name}_cout_log.txt"

    validation_scores = []
    for i in range(num_batches):

        model.train()
        
        src, tgt, src_mask, tgt_mask = next(train_generator)
        loss = model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
        loss.backward()

        loss_value = loss.item()        
        writer.add_scalars("/train/loss", {model_name: loss_value}, i)
#         if loss_value < 1e-10:
#             break

        optim.step()
        optim.zero_grad()

        if i != 0 and i % generate_every == 0:
            model.eval()
            
            with torch.no_grad():
                src, tgt, src_mask, tgt_mask = next(val_generator)
                
                enc = model.encoder(src, mask = src_mask, return_embeddings = True)
    
                gen_token = -10_000 * torch.ones_like(src[:, :1, :])

                out = model.decoder.net(gen_token, context=enc)
                xo = tgt[:, 1:]
                val_loss = F.mse_loss(out.transpose(1, 2)[:, 0], xo)
                val_loss_value = val_loss.item()

            writer.add_scalars("/val/loss", {model_name: val_loss_value}, i)

            validation_scores.append(val_loss_value) 
    
            if verbose:
                with open(print_file, 'a') as f:
                    f.write(f"\n\ninput:  {src[0]}")
                    f.write(f"\npredicted output:  {out[0]}")
                    f.write(f"\ncorrect output:  {xo[0]}")
                    f.write(f"\ntime: {round(time.time() - t0)}")
                    t0 = time.time()
            
            # save checkpoint
            if max(validation_scores) == validation_scores[-1]:
                os.system(f'mkdir {log_path}checkpoints')
                os.system(f'mkdir {log_path}checkpoints/{model_name.split("_")[0]}')
                os.system(f'mkdir {log_path}checkpoints/{model_name.split("_")[0]}/validation')
                save_path = f'{log_path}checkpoints/{model_name.split("_")[0]}/validation/{model_name}_{tag}_maxval.pt'
                save_checkpoint(save_path, model, optim, i, config)
                
            if i // generate_every < head_start:
                continue
                
            # early stopping
            smoothed_val_scores = [np.mean(validation_scores[i-WINDOW_SIZE+1:i]) for i in range(WINDOW_SIZE-1, len(validation_scores))]
            
            if overfit_stop and max(smoothed_val_scores) > max(smoothed_val_scores[-PATIENCE:]):
                break
                
    # save checkpoint
    save_path = f'{log_path}checkpoints/{model_name.split("_")[0]}/{model_name}_{tag}.pt'
    os.system(f'mkdir {log_path}checkpoints/{model_name.split("_")[0]}')
    save_checkpoint(save_path, model, optim, i, config)

    writer.flush()


def test_model(model, test_generator, model_name, param, task_name, tag, num_batches=50, log_path='logs/_test_results.csv'):
    fix_seeds()
    model.eval()

    loss_values = []
    with torch.no_grad():
        for bn in range(num_batches):
            src, tgt, src_mask, tgt_mask = next(test_generator)
            
            enc = model.encoder(src, mask = src_mask, return_embeddings = True)

            gen_token = -10_000 * torch.ones_like(src[:, :1, :])

            out = model.decoder.net(gen_token, context=enc)
            xo = tgt[:, 1:]
            loss = F.mse_loss(out.transpose(1, 2)[:, 0], xo)
            loss_values.append(loss.cpu().item())

    param['tag'] = tag
    param['task_name'] = task_name
    param['model_name'] = model_name
    param['loss'] = np.mean(loss_values)

    if os.path.exists(log_path):
        df = pd.read_csv(log_path)
        df = df.append(param, ignore_index=True)
    else: 
        df = pd.DataFrame([param])
    df.to_csv(log_path, index=False)

In [54]:
GENERATE_EVERY = 100
NUM_BATCHES = 1000

In [55]:
drive_path = 'stocks_logs/'
print_file = f'{drive_path}{TAG}_logs.txt'
t = time.time()
for init_num in range(1):
    with open(print_file, 'a') as f:
        f.write('\n\nInit number ' + str(init_num)+'\n')
    for i, param in enumerate(list(model_parameters)):
        with open(print_file, 'a') as f:
            f.write('\n\n' + str(param)+'\n')
        # param['enc_depth'], param['enc_heads'] = param['depth,heads']
        # param['dec_depth'], param['dec_heads'] = param['depth,heads']
        # param.pop('depth,heads')

        with open(print_file, 'a') as f:
            f.write(f'{i / len(model_parameters) * 100}%')
        model = CXTransformer(**param).cuda()

        model_name = f"{TASK_NAME}{INPUT_LEN}_dim{param['dim']}d{param['enc_depth']}h{param['enc_heads']}M{param['enc_num_memory_tokens']}l{param['enc_max_seq_len']}_{TAG}_v{init_num}"

        optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
        train_validate_model(model, 
                        train_generator=gen_train, 
                        val_generator=gen_val, 
                        optim=optim, 
                        model_name=model_name, 
                        config=param,
                        num_batches=NUM_BATCHES,
                        generate_every=GENERATE_EVERY,
                        print_file=print_file,
                        tag=TAG,
                        overfit_stop=False)
        test_model(model, gen_test, model_name, param, TASK_NAME, tag=TAG, log_path=drive_path+'test_results.csv')
        with open(print_file, 'a') as f:
            f.write(f'\nTotal time: {time.time() - t}\n')
        t = time.time()

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x1 and 16x128)

In [None]:
model = CXTransformer(**param)

src, tgt, _, _ = next(gen_train)

In [11]:
model(src, tgt)

  loss = F.mse_loss(out.transpose(1, 2)[:, 0], xo)


tensor(37729472., grad_fn=<MseLossBackward0>)

In [12]:
1/0

ZeroDivisionError: division by zero

In [None]:
src.shape, tgt.shape

(torch.Size([128, 24, 16]), torch.Size([128, 2]))

In [None]:
src_mask = tgt_mask = None
context_mask = None

enc = model.encoder(src, mask = src_mask, return_embeddings = True)
    
gen_token = -10_000 * torch.ones_like(src[:, :1, :])

# out = model.decoder(gen_token, context = enc, mask = tgt_mask, context_mask = context_mask)

out = model.decoder.net(gen_token, context=enc)
xo = tgt[:, 1:]#.float()
# out = out.float()
# loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = model.decoder.ignore_index)
loss = F.mse_loss(out.transpose(1, 2)[:, 0], xo)

torch.Size([128, 24, 16])
torch.Size([128, 24, 128])
torch.Size([128, 24, 128])
torch.Size([128, 24, 128])
torch.Size([128, 1, 16])
torch.Size([128, 1, 128])
torch.Size([128, 1, 128])
torch.Size([128, 1, 128])


In [None]:
enc.shape, out.shape, tgt[:, 1:].shape

(torch.Size([128, 24, 128]), torch.Size([128, 1, 1]), torch.Size([128, 1]))

In [None]:
tgt[:, 1]

tensor([ 7993.5698,   574.0500, 10375.4297,  9479.9502,   391.0100,  7323.9399,
          430.9000,   539.5100,   637.9800,   732.5400, 10895.3203,  3560.8101,
         7687.0601,   375.9500,   675.2700,   612.2000,  8319.1504,   616.8000,
          385.0000,  7152.7598,   671.0000,  8182.4702,  6539.1201,   734.1200,
         6546.3101,  7902.1602,  1083.6300, 10538.2900,  6159.8901,  5420.0000,
          447.6700,  8699.0400,  8321.9004,  8144.2500,  1000.9700,   405.8400,
         1094.8101,   408.6900,  9504.8496,  8734.3604,  6395.7798,   654.7000,
         7537.9800,  7794.1499,  2544.7600,  7282.2798,  1009.4400,  9120.5596,
          253.9200,  4008.1399,  7741.4302,  9690.2402,   452.9900,   597.9600,
         8986.9697, 17744.5000,  1183.3101,  3900.5601,   684.4200,  3585.6299,
         8807.0098,   388.6900,  1221.9900,   272.4500,  8759.0400,  3872.4099,
          381.7500,  1207.6300,   609.3900,   580.6700,  4188.0000,  7722.8198,
         4587.0000,  6744.0601,  9820.00

In [None]:
src[:, 0].shape

torch.Size([128, 16])

In [None]:
gen_token = -10_000 * torch.ones_like(src[:, :1, :])

In [None]:
src.shape, gen_token.shape

(torch.Size([128, 24, 16]), torch.Size([128, 1, 16]))

In [None]:
model.decoder.net.project_in

Linear(in_features=16, out_features=128, bias=True)

In [None]:
out = model.decoder.net(gen_token)

torch.Size([128, 1, 16])
torch.Size([128, 1, 128])
torch.Size([128, 1, 128])
torch.Size([128, 1, 128])


In [None]:
# xi = x[:, :-1]
# xo = x[:, 1:]

# out = model.net(gen_token)
xo = tgt[:, 1:]
loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = model.ignore_index)

AttributeError: 'CXTransformer' object has no attribute 'net'

In [None]:
enc.shape

torch.Size([128, 24, 128])

In [None]:
model(src, tgt)

torch.Size([128, 24, 128])
torch.Size([128, 24, 128])
torch.Size([128, 24, 128])
torch.Size([128, 1])


RuntimeError: output with shape [128, 1] doesn't match the broadcast shape [1, 128, 128]

In [None]:
drive_path = 'drive/MyDrive/stocks_logs/'
print_file = f'{drive_path}{TAG}_logs.txt'
t = time.time()
for init_num in range(NUM_INITS):
    with open(print_file, 'a') as f:
        f.write('\n\nInit number ' + str(init_num)+'\n')
    for i, param in enumerate(list(model_parameters)):
        with open(print_file, 'a') as f:
            f.write('\n\n' + str(param)+'\n')
        param['enc_depth'], param['enc_heads'] = param['depth,heads']
        param['dec_depth'], param['dec_heads'] = param['depth,heads']
        param.pop('depth,heads')

        with open(print_file, 'a') as f:
            f.write(f'{i / len(model_parameters) * 100}%')
        model = XTransformer(**param).cuda()

        model_name = f"{TASK_NAME}{INPUT_LEN}_dim{param['dim']}d{param['enc_depth']}h{param['enc_heads']}M{param['enc_num_memory_tokens']}l{param['enc_max_seq_len']}_{TAG}_v{init_num}"

        optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
        train_validate_model(model, 
                        train_generator=gen_train, 
                        val_generator=gen_val, 
                        optim=optim, 
                        model_name=model_name, 
                        config=param,
                        num_batches=NUM_BATCHES,
                        generate_every=GENERATE_EVERY,
                        print_file=print_file,
                        tag=TAG,
                        overfit_stop=False)
        test_model(model, gen_test, model_name, param, TASK_NAME, tag=TAG, log_path=drive_path+'test_results.csv')
        with open(print_file, 'a') as f:
            f.write(f'\nTotal time: {time.time() - t}\n')
        t = time.time()

RuntimeError: ignored

### Refit models

In [None]:
# import os

# def load_cpt(config, v, task_name, input_length):
#     for fns in os.walk('checkpoints'):
#         model_names = fns[2]
        
#     prefix = '{task_name}_dim{dim}d{d}h{h}M{M}l{l}'
#     name = prefix.format(task_name=task_name,
#                         dim=config['dim'],
#                         d=config['enc_depth'], h=config['enc_heads'], 
#                         M=config['enc_num_memory_tokens'], 
#                         l=input_length)

#     checkpoint_paths = ['checkpoints/' + n for n in model_names if name in n]
#     cpt = torch.load(checkpoint_paths[v])
#     bn, model_state, optim_state = cpt['batch_num'], cpt['state_dict'], cpt['optimizer']

#     model = XTransformer(**config).cuda()
#     model.load_state_dict(model_state)

#     optim = torch.optim.Adam(model.parameters(), lr=0.001)
#     optim.load_state_dict(optim_state)

#     return bn, model, optim


In [None]:
# TAG = 'refit_to_max'
# LEARNING_RATE = 0.001

# path = f"checkpoints/{TASK_NAME}{INPUT_LEN}/"

# for name in next(os.walk(path))[2]:
#     print(name)
#     if name == 'copy24_dim128d2h4M12l12_10tkn_len24_v2_10tkn_len24.pt':
#         continue
#     cpt = torch.load(path+name)
#     print(cpt['batch_num'])
#     delta_batches = NUM_BATCHES - cpt['batch_num'] - 1
#     if delta_batches < 1:
#         continue
    
#     split = name.split('_')
#     config = {'dec_max_seq_len': DEC_SEQ_LEN,
#          'dec_num_tokens': NUM_TOKENS,
#          'dim': int(split[1].split('dim')[1].split('d')[0]),
#          'enc_max_seq_len': int(split[1].split('M')[1].split('l')[1]),
#          'enc_num_memory_tokens': int(split[1].split('M')[1].split('l')[0]),
#          'enc_num_tokens': NUM_TOKENS,
#          'return_tgt_loss': True,
#          'tie_token_embeds': True,
#          'enc_depth': int(split[1][3:].split('d')[1].split('h')[0]),
#          'enc_heads': int(split[1][3:].split('d')[1].split('h')[1].split('M')[0]),
#          'dec_depth': int(split[1][3:].split('d')[1].split('h')[0]),
#          'dec_heads': int(split[1][3:].split('d')[1].split('h')[1].split('M')[0]),
#          'tag': TAG,
#          'task_name': TASK_NAME}
    
    
#     gen_train = data_loader(path=f'data{INPUT_LEN}', task_name=f'{TASK_NAME}_train', batch_size=BATCH_SIZE)
#     gen_val = data_loader(path=f'data{INPUT_LEN}', task_name=f'{TASK_NAME}_val', batch_size=VAL_SIZE)
#     gen_test = data_loader(path=f'data{INPUT_LEN}', task_name=f'{TASK_NAME}_test', batch_size=TEST_SIZE)


#     print_file = f'logs/{TASK_NAME}_{TAG}_memory_logs.txt'
#     t = time.time()
#     with torch.cuda.device(0):
#         with open(print_file, 'a') as f:
#             f.write('\n\n' + str(config)+'\n')
#             f.write(str(delta_batches) + ' batches to go.\n')

#         print('\n\n' + str(config)+'\n')
#         print(str(delta_batches) + ' batches to go.\n')
#         model_name = name
#         model = XTransformer(**config).cuda()
#         optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
        
#         model.load_state_dict(cpt['state_dict'])
#         optim.load_state_dict(cpt['optimizer'])

#         train_validate_model(model, 
#                             train_generator=gen_train, 
#                             val_generator=gen_val, 
#                             optim=optim, 
#                             model_name=model_name, 
#                             config=config,
#                             num_batches=delta_batches,
#                             generate_every=GENERATE_EVERY,
#                             print_file=print_file,
#                             tag=TAG,
#                             overfit_stop=False)
#         test_model(model, gen_test, model_name, config, TASK_NAME, tag=TAG)

#         with open(print_file, 'a') as f:
#             f.write(f'\nTotal time: {time.time() - t}\n')
#         t = time.time()

In [None]:
test_model(model, gen_test, model_name, config, TASK_NAME, tag=TAG)

In [None]:
# gen_train = data_loader(task_name=f'{TASK_NAME}_train', batch_size=BATCH_SIZE, enc_seq_len=INPUT_LEN, dec_seq_len=DEC_SEQ_LEN)
# gen_val = data_loader(task_name=f'{TASK_NAME}_val', batch_size=VAL_SIZE, enc_seq_len=INPUT_LEN, dec_seq_len=DEC_SEQ_LEN)
# gen_test = data_loader(task_name=f'{TASK_NAME}_test', batch_size=TEST_SIZE, enc_seq_len=INPUT_LEN, dec_seq_len=DEC_SEQ_LEN)


# print_file = f'logs/{TASK_NAME}_{TAG}_memory_logs.txt'
# t = time.time()
# with torch.cuda.device(0):
#     for init_num in range(NUM_INITS):
#         with open(print_file, 'a') as f:
#             f.write('\n\nInit number ' + str(init_num)+'\n')
#         for i, param in enumerate(list(model_parameters)):
#             with open(print_file, 'a') as f:
#                 f.write('\n\n' + str(param)+'\n')
#             param['enc_depth'], param['enc_heads'] = param['depth,heads']
#             param['dec_depth'], param['dec_heads'] = param['depth,heads']
#             param.pop('depth,heads')

#             with open(print_file, 'a') as f:
#                 f.write(f'{i / len(model_parameters) * 100}%')
#             model = XTransformer(**param).cuda()

#             model_name = f"{TASK_NAME}{INPUT_LEN}_dim{param['dim']}d{param['enc_depth']}h{param['enc_heads']}M{param['enc_num_memory_tokens']}l{param['enc_max_seq_len']}_v{init_num}"

#             optim = torch.optim.Adam(model.ффparameters(), lr=LEARNING_RATE)
            
#             bn, model, optim = load_cpt(param, v=init_num, task_name='copy55', input_length=param['enc_max_seq_len'])
#             with open(print_file, 'a') as f:
#                 f.write(f'BN: {bn}\n')
#             if bn < 130_000:
#                 train_validate_model(model, 
#                                     train_generator=gen_train, 
#                                     val_generator=gen_val, 
#                                     optim=optim, 
#                                     model_name=model_name, 
#                                     dec_seq_len=DEC_SEQ_LEN,
#                                     num_batches=NUM_BATCHES,
#                                     generate_every=GENERATE_EVERY,
#                                     print_file=print_file,
#                                     tag=TAG,
#                                     overfit_stop=False,
#                                     head_start=(130_000 - bn)/GENERATE_EVERY)
#                 test_model(model, gen_test, model_name, param, TASK_NAME, tag=TAG, dec_seq_len=param['dec_max_seq_len'])
#             with open(print_file, 'a') as f:
#                 f.write(f'\nTotal time: {time.time() - t}\n')
#             t = time.time()

In [None]:
from run_experiment import save_checkpoint

In [None]:
# save_path = f'checkpoints/{model_name}_b{i}_{TAG}_maxval.pt'
# save_cpt(save_path, model, optim)

# if i // generate_every < head_start:
#     continue

# # early stopping
# smoothed_val_scores = [np.mean(validation_scores[i-WINDOW_SIZE+1:i]) for i in range(WINDOW_SIZE-1, len(validation_scores))]

# if overfit_stop and max(smoothed_val_scores) > max(smoothed_val_scores[-PATIENCE:]):
#     break

### Test!

In [None]:
init_num = 0

gen_train = data_loader(task_name=f'{TASK_NAME}_train', batch_size=BATCH_SIZE, enc_seq_len=ENC_SEQ_LEN, dec_seq_len=DEC_SEQ_LEN)
gen_val = data_loader(task_name=f'{TASK_NAME}_val', batch_size=VAL_SIZE, enc_seq_len=ENC_SEQ_LEN, dec_seq_len=DEC_SEQ_LEN)
gen_test = data_loader(task_name=f'{TASK_NAME}_test', batch_size=TEST_SIZE, enc_seq_len=ENC_SEQ_LEN, dec_seq_len=DEC_SEQ_LEN)


param = list(model_parameters)[5]
print(param)
param['enc_depth'], param['enc_heads'] = param['depth,heads']
param['dec_depth'], param['dec_heads'] = param['depth,heads']
param.pop('depth,heads')

model = XTransformer(**param).cuda()

model_name = f"{TASK_NAME}_dim{param['dim']}d{param['enc_depth']}h{param['enc_heads']}M{param['enc_num_memory_tokens']}l{param['enc_max_seq_len']}_v{init_num}"

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

src, tgt, _, _ = next(gen_train)

print(model.encoder.max_seq_len, model.encoder.num_memory_tokens)
model.encoder(torch.cat((src, src)), return_embeddings=True).shape