In [1]:
#!pip install --user tensorboardX

In [2]:
# code based off of 
# https://github.com/mandubian/pytorch_math_dataset and
# https://github.com/lucidrains/reformer-pytorch

import math
import numpy as np
import torch
from torch.utils import data
import torch.optim as optim
import tqdm as tqdm
import random
from datetime import datetime
from apex import amp
import pickle


import mandubian.math_dataset
from mandubian.math_dataset import MathDatasetManager
from mandubian.transformer import Constants

# from transformer.Models import Transformer
from mandubian.math_dataset import (
    random_split_dataset,
    question_answer_to_mask_batch_collate_fn
)
from mandubian.math_dataset import np_encode_string, np_decode_string
import mandubian.model_process
import mandubian.utils
from mandubian.tensorboard_utils import Tensorboard
from mandubian.tensorboard_utils import tensorboard_event_accumulator

import mandubian.checkpoints

import matplotlib
import numpy as np
import matplotlib.pyplot as plt
%matplotlib notebook

from datetime import datetime

print("Torch Version", torch.__version__)

%load_ext autoreload
%autoreload 2

Using backend: pytorch


Torch Version 1.5.0


# Check hardware

In [3]:
seed = 1
print(torch.cuda.device_count(), "detected CUDA devices")
cuda_device = torch.cuda.current_device()
print("Using CUDA device: ", cuda_device)
print(torch.cuda.get_device_name(cuda_device))

torch.manual_seed(seed)
device = torch.device("cuda")
print("device", device)

1 detected CUDA devices
Using CUDA device:  0
GeForce RTX 2080
device cuda


# Reformer library

In [4]:
from lucidrains_reformer.reformer_pytorch import ReformerLM, Autopadder, Recorder
from lucidrains_reformer.reformer_pytorch import ReformerEncDec
from lucidrains_reformer.reformer_pytorch.generative_tools import TrainingWrapper

# Initialize Math Dataset Manager

In [5]:
mdsmgr = MathDatasetManager(
  "/home/jonathan/Repos/final_year_at_ic/awesome_project/mathematics_dataset-v1.0/"
)
# Examine dataset structure
print("mdsmgr structure", dir(mdsmgr))

initialized MultiFilesMathDataset with categories ['algebra', 'numbers', 'polynomials', 'comparison', 'arithmetic', 'measurement', 'probability', 'calculus'] and types ['train-easy', 'train-medium', 'train-hard', 'interpolate', 'extrapolate']
mdsmgr structure ['__add__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_build_datasets_from_category', 'build_dataset_from_categories', 'build_dataset_from_category', 'build_dataset_from_module', 'build_dataset_from_modules', 'dfs', 'dirs', 'get_categories', 'get_modules_for_category', 'get_types', 'root_dir']


In [6]:
# print(MathDatasetManager.__dir__
mdsmgr._build_datasets_from_category

<bound method MathDatasetManager._build_datasets_from_category of <mandubian.math_dataset.MathDatasetManager object at 0x7efb71344cc0>>

### Check availables types, problem categories and problem subcategories

In [7]:
print("types", list(mdsmgr.get_types()))
print("categories", list(mdsmgr.get_categories()))
print("modules of arithmetic", mdsmgr.get_modules_for_category('arithmetic'))


types ['train-easy', 'train-medium', 'train-hard', 'interpolate', 'extrapolate']
categories ['algebra', 'numbers', 'polynomials', 'comparison', 'arithmetic', 'measurement', 'probability', 'calculus']
modules of arithmetic dict_keys(['div', 'nearest_integer_root', 'mul_div_multiple', 'mul', 'add_or_sub', 'add_sub_multiple', 'mixed', 'add_or_sub_in_base', 'simplify_surd', 'add_or_sub_big', 'add_sub_multiple_longer', 'mixed_longer', 'div_big', 'mul_div_multiple_longer', 'mul_big'])


### Ways to manipulate dataset

In [8]:
# # Build Dataset from a single module in a category
ds = mdsmgr.build_dataset_from_module('arithmetic', 'add_or_sub', 'train-easy')
print("size", len(ds))

# # Build Dataset from a single module in a category with limited number of elements
# ds = mdsmgr.build_dataset_from_module('arithmetic', 'add_or_sub', 'train-easy', max_elements=1000)
# print("size", len(ds))

# # Build Dataset from several modules in a category
# ds = mdsmgr.build_dataset_from_modules('arithmetic', ['add_or_sub', 'add_sub_multiple'], 'train-easy')
# print("size", len(ds))

# # Build Dataset from all modules in a category
# ds = mdsmgr.build_dataset_from_category('arithmetic', 'train-easy')
# ds = mdsmgr.build_dataset_from_category('arithmetic', 'interpolate')
# print("size", len(ds))

# # Build Dataset from all modules in several categories
# ds = mdsmgr.build_dataset_from_categories(['arithmetic', 'polynomials'], 'train-easy')
# print("size", len(ds))

# # 

size 666666


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  return super(DataFrame, self).rename(**kwargs)


In [9]:
# Pytorch initialization

# Start baseline

In [10]:
exp_name = "overfit_one_batch_142_V2"
now = datetime.now()
unique_id = now.strftime("%m-%d-%Y_%H-%M-%S")
base_dir = "/home/jonathan/Repos/final_year_at_ic/awesome_project/code/tests/"

## Constants

In [11]:
from mandubian.math_dataset import (
    VOCAB_SZ, MAX_QUESTION_SZ, MAX_ANSWER_SZ
)

NUM_CPU_THREADS = 12
BATCH_SIZE = 128
NUM_BATCHES = int(1e5)
BATCH_SIZE = 32
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY  = 20
GENERATE_EVERY  = 60
GENERATE_LENGTH = 32

# hyperparameters need updates

Q_SEQ_LEN = 256
A_SEQ_LEN = 30 # unused due to requirements of axial_positon_shape
NUM_TOKENS = VOCAB_SZ + 1
D_MODEL = 512
EMB_DIM = D_MODEL
NUM_HEADS = 8
QKV_DIM = D_MODEL / NUM_HEADS
NUM_LAYERS = 6
D_FF = 2048


### Get training and test data

In [12]:
# training data
# training_data = mdsmgr.build_dataset_from_category('arithmetic','train-easy') # for now
training_data = mdsmgr.build_dataset_from_modules('arithmetic', ['add_or_sub', 'add_sub_multiple'], 'train-easy', max_elements = 142)

# testing data
# testing_data_interpolate = mdsmgr.build_dataset_from_category('arithmetic','interpolate')
# testing_data_extrapolate = mdsmgr.build_dataset_from_category('arithmetic','extrapolate')

testing_data_interpolate = mdsmgr.build_dataset_from_modules('arithmetic', ['add_or_sub', 'add_sub_multiple'], 'interpolate', max_elements = 1024)
# testing_data_extrapolate = mdsmgr.build_dataset_from_modules('arithmetic', ['add_or_sub', 'add_sub_multiple'], 'extrapolate')


In [13]:
# from lucidrains_reformer.examples.enwik8_simple.train
# helpers

def cycle(loader):
    while True:
        for data in loader:
            yield data

def decode_token(token):
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

def get_non_pad_mask(seq):
    # returns true when token is not PAD and false otherwise
    assert seq.dim() == 2
    return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1)

# get data splits
train_ds, val_ds = mandubian.math_dataset.random_split_dataset(training_data,split_rate=0.9)

# get pytorch dataloaders
# Questions are padded in question_answer_to_position_batch_collate_fn
train_loader = data.DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_CPU_THREADS,
    collate_fn=question_answer_to_mask_batch_collate_fn)
train_loader = cycle(train_loader)

val_loader = data.DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_CPU_THREADS,
    collate_fn=question_answer_to_mask_batch_collate_fn)
val_loader = cycle(val_loader)

# for viewing output sequences
gen_loader = data.DataLoader(
    val_ds, batch_size=1, shuffle=False, num_workers=NUM_CPU_THREADS,
    collate_fn=question_answer_to_mask_batch_collate_fn)
gen_loader = cycle(gen_loader)

interpolate_loader = data.DataLoader(
    testing_data_interpolate, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_CPU_THREADS,
    collate_fn=question_answer_to_mask_batch_collate_fn)
interpolate_loader = cycle(interpolate_loader)

# extrapolate_loader = data.DataLoader(
#     testing_data_extrapolate, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_CPU_THREADS,
#     collate_fn=question_answer_to_mask_batch_collate_fn)
# extrapolate_loader = cycle(extrapolate_loader)


### Model

In [14]:
# define model

enc_dec = ReformerEncDec(
    dim = D_MODEL,
    enc_num_tokens = NUM_TOKENS,
    enc_depth = NUM_LAYERS,
    enc_max_seq_len = Q_SEQ_LEN,
    dec_num_tokens = NUM_TOKENS,
    dec_depth = NUM_LAYERS,
    dec_max_seq_len = Q_SEQ_LEN,
    # heads = 8 by default
    axial_position_shape = (64, 16),  # the shape must multiply up to the max_seq_len (128 x 64 = 8192)
    axial_position_dims = (256,256),   # the dims must sum up to the model dimensions (512 + 512 = 1024)
    pad_value = Constants.PAD,
    ignore_index = Constants.PAD # see if this works. pad_value and ignore_index are probably different
).cuda()

# enc_dec = Recorder(enc_dec)
enc_dec.to(device)


ReformerEncDec(
  (enc): TrainingWrapper(
    (net): Autopadder(
      (net): ReformerLM(
        (token_emb): Embedding(96, 512, padding_idx=0)
        (to_model_dim): Identity()
        (pos_emb): AxialPositionalEncoding(
          (weights): ParameterList(
              (0): Parameter containing: [torch.cuda.FloatTensor of size 1x4x1x256 (GPU 0)]
              (1): Parameter containing: [torch.cuda.FloatTensor of size 1x1x64x256 (GPU 0)]
          )
        )
        (reformer): Reformer(
          (layers): ReversibleSequence(
            (blocks): ModuleList(
              (0): ReversibleBlock(
                (f): Deterministic(
                  (net): PreNorm(
                    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
                    (fn): LSHSelfAttention(
                      (toqk): Linear(in_features=512, out_features=512, bias=False)
                      (tov): Linear(in_features=512, out_features=512, bias=False)
                      (to_out)

## Optimizer learning rate scheduler, mixed precision setup

In [15]:
optimizer = optim.Adam(enc_dec.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.995), eps=1e-9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.3, patience=100, verbose=True)

# mixed precision
enc_dec, optimizer = amp.initialize(enc_dec, optimizer, opt_level='O2')

Selected optimization level O2:  FP16 training with FP32 batchnorm and FP32 master weights.

Defaults for this optimization level are:
enabled                : True
opt_level              : O2
cast_model_type        : torch.float16
patch_torch_functions  : False
keep_batchnorm_fp32    : True
master_weights         : True
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O2
cast_model_type        : torch.float16
patch_torch_functions  : False
keep_batchnorm_fp32    : True
master_weights         : True
loss_scale             : dynamic


# Train

In [16]:
# for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
i = 0

train_loss_list = []
val_loss_list = []

while True:
    
# for batch_idx, batch in enumerate(tqdm(train_loader, mininterval=2, leave=False)):
    batch_qs, batch_qs_mask, batch_as, batch_as_mask = map(lambda x: x.to(device), next(train_loader))
    # exclude the 0th element as it is BOS
    gold_as = batch_as[:, 1:]
    
    if (i % GENERATE_EVERY) - 1 == 0:
        enc_dec.eval()
        gen_qs, gen_qs_mask, gen_as, gen_as_mask = next(gen_loader)
    #         inp = random.choice(val_ds)[:-1]
        prime = np_decode_string(gen_qs.numpy())
        print('*' * 100, "\nQuestion: ", prime)
        print("Actual Answer: ", np_decode_string(gen_as.numpy()))
    #     print("Raw Answer: ", gen_as.numpy())
        gen_qs = gen_qs.to(device)
        gen_as = gen_as.to(device)
        gen_qs_mask = gen_qs_mask.to(device)
        sample = enc_dec.generate(gen_qs, gen_as, GENERATE_LENGTH, enc_input_mask = gen_qs_mask)
        sample = sample.cpu().numpy()
        output_str = np_decode_string(sample)
    #     print("Raw Prediction: ", sample)
        print("Decoded Prediction: ", output_str)
        np.savetxt(base_dir + "logs/" + exp_name + "_" + unique_id + "-train_loss.txt", train_loss_list)
        np.savetxt(base_dir + "logs/" + exp_name + "_" + unique_id + "-val_loss.txt", val_loss_list)
        
#         with open(base_dir + "logs/" + exp_name + "_" + unique_id + "-train_loss.txt", "w") as fp:
#             pickle.dumps(train_loss_list, fp)
#         with open(base_dir + "logs/" + exp_name + "_" + unique_id + "-val_loss.txt", "w") as fp:
#             pickle.dumps(val_loss_list, fp)
#         print("Logs saved to ", "logs/" + exp_name + "_" + unique_id + "-val_loss.txt")
            
            

    enc_dec.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        train_loss = enc_dec(batch_qs, batch_as, return_loss = True, enc_input_mask = batch_qs_mask)
        with amp.scale_loss(train_loss, optimizer) as scaled_loss:
            scaled_loss.backward()

    
#     if batch_idx % GRADIENT_ACCUMULATE_EVERY == 0:
    print("Step ", i, "\t", f'training loss: {train_loss.item()}', "\t", datetime.now().time() )
    train_loss_list.append((i, train_loss.item()))
    torch.nn.utils.clip_grad_norm_(enc_dec.parameters(), 0.1)
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step(train_loss)
    
    if i % VALIDATE_EVERY == 0:
        val_batch_qs, val_batch_qs_mask, val_batch_as, val_batch_as_mask = map(lambda x: x.to(device), next(val_loader))
        enc_dec.eval()
        with torch.no_grad():
            val_loss = enc_dec(val_batch_qs, val_batch_as, return_loss = True, enc_input_mask = val_batch_qs_mask)
            print(f'validation loss: {val_loss.item()}')
            val_loss_list.append((i, val_loss.item()))

    i += 1


Step  0 	 training loss: 4.875 	 12:41:58.103986
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
validation loss: 4.96484375
**************************************************************************************************** 
Question:  !What is the value of (10 - 5) + (-6 - -3 - 1)?"                                                                                                                                                                                                                
Actual Answer:  !1"                                                                                                                                                                                                                                                             
Decoded Prediction:  111.!.pNS<g 3+ 3...!N.TLp.n ..n.
Step  1 	 training loss: 4.9296875 	 12:42:12.328322
Step  2 	 training loss: 3.716796875 	 12:42:23.548709
Step  3 	 training loss: 3.33203125 	 

Step  117 	 training loss: 1.60546875 	 13:03:58.175013
Step  118 	 training loss: 1.4912109375 	 13:04:08.958444
Step  119 	 training loss: 1.5341796875 	 13:04:20.092188
Step  120 	 training loss: 1.5244140625 	 13:04:31.114749
validation loss: 1.8994140625
**************************************************************************************************** 
Question:  !Total of -1.7 and -8658."                                                                                                                                                                                                                                      
Actual Answer:  !-8659.7"                                                                                                                                                                                                                                                       
Decoded Prediction:  .7"1" "" 3"8""8"1" 8"""@""1"".4"
Step  121 	 training loss: 1.515625 	 13:04

Step  233 	 training loss: 1.138671875 	 13:25:13.349076
Step  234 	 training loss: 0.9921875 	 13:25:24.615732
Step  235 	 training loss: 1.0693359375 	 13:25:35.889804
Step  236 	 training loss: 0.91259765625 	 13:25:47.156579
Step  237 	 training loss: 1.078125 	 13:25:58.418637
Step  238 	 training loss: 1.0400390625 	 13:26:09.699172
Step  239 	 training loss: 0.97509765625 	 13:26:21.317379
Step  240 	 training loss: 0.99072265625 	 13:26:32.829511
validation loss: 2.080078125
**************************************************************************************************** 
Question:  !Evaluate (1 - 2) + -1 - (-2 - 1)."                                                                                                                                                                                                                             
Actual Answer:  !1"                                                                                                                           

Step  348 	 training loss: 0.442138671875 	 13:47:06.357415
Step  349 	 training loss: 0.48291015625 	 13:47:17.623646
Step  350 	 training loss: 0.4736328125 	 13:47:28.890769
Step  351 	 training loss: 0.517578125 	 13:47:40.509959
Step  352 	 training loss: 0.37451171875 	 13:47:51.998846
Step  353 	 training loss: 0.474609375 	 13:48:03.277224
Step  354 	 training loss: 0.374267578125 	 13:48:14.530751
Step  355 	 training loss: 0.423828125 	 13:48:25.815973
Step  356 	 training loss: 0.40478515625 	 13:48:37.101472
Step  357 	 training loss: 0.49609375 	 13:48:48.595111
Step  358 	 training loss: 0.4482421875 	 13:49:00.025945
Step  359 	 training loss: 0.494140625 	 13:49:11.772317
Step  360 	 training loss: 0.39892578125 	 13:49:23.289533
validation loss: 2.884765625
**************************************************************************************************** 
Question:  !Evaluate 0 - 9 - (-4 - -1)."                                                                         

validation loss: 3.7421875
Step  461 	 training loss: 0.11285400390625 	 14:08:35.597973
Step  462 	 training loss: 0.1070556640625 	 14:08:46.317569
Step  463 	 training loss: 0.09954833984375 	 14:08:57.373612
Step  464 	 training loss: 0.09674072265625 	 14:09:08.319906
Step  465 	 training loss: 0.0933837890625 	 14:09:19.042694
Step  466 	 training loss: 0.09600830078125 	 14:09:29.755829
Step  467 	 training loss: 0.10028076171875 	 14:09:40.471524
Step  468 	 training loss: 0.0972900390625 	 14:09:51.188799
Step  469 	 training loss: 0.0972900390625 	 14:10:01.903920
Step  470 	 training loss: 0.11077880859375 	 14:10:12.623850
Step  471 	 training loss: 0.08154296875 	 14:10:23.673003
Step  472 	 training loss: 0.0887451171875 	 14:10:34.624689
Step  473 	 training loss: 0.06329345703125 	 14:10:45.344661
Step  474 	 training loss: 0.0938720703125 	 14:10:56.069262
Step  475 	 training loss: 0.09051513671875 	 14:11:06.795319
Step  476 	 training loss: 0.077880859375 	 14:11:17

Step  567 	 training loss: 0.0236968994140625 	 14:28:24.769987
Step  568 	 training loss: 0.01898193359375 	 14:28:36.286266
Step  569 	 training loss: 0.0260772705078125 	 14:28:47.558884
Step  570 	 training loss: 0.0188751220703125 	 14:28:58.815536
Step  571 	 training loss: 0.0176544189453125 	 14:29:10.104330
Step  572 	 training loss: 0.0218658447265625 	 14:29:21.356177
Step  573 	 training loss: 0.0207672119140625 	 14:29:32.606718
Step  574 	 training loss: 0.022705078125 	 14:29:43.908837
Step  575 	 training loss: 0.019287109375 	 14:29:55.532253
Step  576 	 training loss: 0.0186614990234375 	 14:30:07.032802
Step  577 	 training loss: 0.0145416259765625 	 14:30:18.306217
Step  578 	 training loss: 0.0174102783203125 	 14:30:29.581847
Step  579 	 training loss: 0.02154541015625 	 14:30:40.839236
Step  580 	 training loss: 0.0182647705078125 	 14:30:52.114501
validation loss: 4.4140625
Step  581 	 training loss: 0.020477294921875 	 14:31:03.976047
Step  582 	 training loss:

Step  672 	 training loss: 0.00855255126953125 	 14:48:24.186055
Step  673 	 training loss: 0.01183319091796875 	 14:48:35.437621
Step  674 	 training loss: 0.0077972412109375 	 14:48:46.685115
Step  675 	 training loss: 0.00945281982421875 	 14:48:57.966348
Step  676 	 training loss: 0.00927734375 	 14:49:09.225238
Step  677 	 training loss: 0.00905609130859375 	 14:49:20.483629
Step  678 	 training loss: 0.01006317138671875 	 14:49:31.753107
Step  679 	 training loss: 0.01012420654296875 	 14:49:43.372958
Step  680 	 training loss: 0.01044464111328125 	 14:49:54.857704
validation loss: 4.92578125
Step  681 	 training loss: 0.01110076904296875 	 14:50:06.733393
Step  682 	 training loss: 0.00936126708984375 	 14:50:17.999037
Step  683 	 training loss: 0.00799560546875 	 14:50:29.264489
Step  684 	 training loss: 0.008880615234375 	 14:50:40.583978
Step  685 	 training loss: 0.00878143310546875 	 14:50:51.916399
Step  686 	 training loss: 0.0090789794921875 	 14:51:03.171873
Step  687 

Decoded Prediction:  """"."""""."""""""""""""".""""""
Step  781 	 training loss: 0.00504302978515625 	 15:09:09.261357
Step  782 	 training loss: 0.005558013916015625 	 15:09:20.527225
Step  783 	 training loss: 0.00508880615234375 	 15:09:32.170116
Step  784 	 training loss: 0.00562286376953125 	 15:09:43.725359
Step  785 	 training loss: 0.00466156005859375 	 15:09:54.996950
Step  786 	 training loss: 0.005573272705078125 	 15:10:06.267386
Step  787 	 training loss: 0.0056304931640625 	 15:10:17.536219
Step  788 	 training loss: 0.0036907196044921875 	 15:10:28.794582
Step  789 	 training loss: 0.003803253173828125 	 15:10:40.056140
Step  790 	 training loss: 0.0062408447265625 	 15:10:51.321479
Step  791 	 training loss: 0.004611968994140625 	 15:11:02.948583
Step  792 	 training loss: 0.00604248046875 	 15:11:14.454469
Step  793 	 training loss: 0.0046539306640625 	 15:11:25.721663
Step  794 	 training loss: 0.006927490234375 	 15:11:36.988877
Step  795 	 training loss: 0.004909515

Step  893 	 training loss: 0.0028667449951171875 	 15:30:03.397772
Step  894 	 training loss: 0.0024242401123046875 	 15:30:14.199768
Step  895 	 training loss: 0.0026397705078125 	 15:30:25.346225
Step  896 	 training loss: 0.003826141357421875 	 15:30:36.397935
Step  897 	 training loss: 0.003276824951171875 	 15:30:47.206092
Step  898 	 training loss: 0.0031337738037109375 	 15:30:58.006700
Step  899 	 training loss: 0.0038242340087890625 	 15:31:08.831239
Step  900 	 training loss: 0.002895355224609375 	 15:31:19.626965
validation loss: 5.76953125
**************************************************************************************************** 
Question:  !Evaluate -29 + 24 - (2 - 9)."                                                                                                                                                                                                                                  
Actual Answer:  !2"                                                     

Step  994 	 training loss: 0.0020122528076171875 	 15:48:29.049950
Step  995 	 training loss: 0.0021266937255859375 	 15:48:39.839285
Step  996 	 training loss: 0.0014696121215820312 	 15:48:50.614756
Step  997 	 training loss: 0.0013275146484375 	 15:49:01.404205
Step  998 	 training loss: 0.0013513565063476562 	 15:49:12.194648
Step  999 	 training loss: 0.0015239715576171875 	 15:49:23.316055
Step  1000 	 training loss: 0.0016222000122070312 	 15:49:34.346446
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
validation loss: 6.01953125
Step  1001 	 training loss: 0.00188446044921875 	 15:49:45.699851
Step  1002 	 training loss: 0.0015239715576171875 	 15:49:56.491734
Step  1003 	 training loss: 0.0020599365234375 	 15:50:07.279901
Step  1004 	 training loss: 0.00167083740234375 	 15:50:18.184677
Step  1005 	 training loss: 0.0017566680908203125 	 15:50:29.069014
Step  1006 	 training loss: 0.0016260147094726562 	 15:50:39.946385
Step  1007 	 training lo

KeyboardInterrupt: 

In [None]:
plt.plot([train_loss_list[1]])
plt.show