In [1]:
import math
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data.dataset import Dataset
from tqdm import tqdm

seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data processing
###### Generate an addition dataset

In [2]:
PLUS_SIGN = 10
MUL_SIGN  = 11
MINUS_SIGN = 12
EQUAL_SIGN = 13
EOS = 14
BOS = 15
PAD = 16
UNK = 17

symbol_to_int_dict = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4,
                      "5": 5, "6": 6, "7": 7, "8": 8, "9": 9,
                      "+": PLUS_SIGN, "*": MUL_SIGN, "-": MINUS_SIGN,
                      "=": EQUAL_SIGN,  "<EOS>": EOS, "<BOS>": BOS,
                      "<pad>": PAD, "??": UNK
                      }

int_to_symbol_dict = {y:x for (x,y) in symbol_to_int_dict.items()}
vocab_size = len(symbol_to_int_dict)

def decode_equation(equation):
    '''convert an equation in list format to string format '''
    res = "".join([str(int_to_symbol_dict.get(x, UNK)) for x in equation.tolist()])
    return res.replace("<BOS>", "").replace("<EOS>", "")

def encode_equation(equation, max_ndigits, padQ=True):
    '''convert an equation(up to the equal sign in it) in string format to a list'''
    equal_size_loc = equation.index('=')
    plus_size_loc = equation.index('+')
    num1 = pad_number(equation[0:plus_size_loc], max_ndigits)
    num2 = pad_number(equation[plus_size_loc+1:equal_size_loc], max_ndigits)
    new_equation = num1 + "+" + num2 + "="
    return torch.tensor([BOS]+[symbol_to_int_dict.get(n, UNK) for n in new_equation]).to(DEVICE)


def pad_number(num, max_ndigits)->str:
    'pad numbers with zeros in front so that they have the same length max_ndigits'
    s = str(num)
    while len(s)<max_ndigits:
      s = "0"+s
    return s

def create_add_dataset(max_ndigits, dataset_size, padQ=True):
    ''' Function for creating an addition dataset.
    if padQ=True, pre-padding of 0s will be added on the numbers such that all the 
    numbers has the same length max_ndigits, for example, with max_ndigits=3,  
    32 will be represented 032.
    '''
    dataset_str = []
    for i in range(dataset_size):
        num1, num2 = np.random.randint(0, 10**max_ndigits, 2)
        ans = num1 + num2
        # If padQ=True, we pad all the numbers with '0' in front
        # such that they all have length max_ndigits
        if padQ:
            equation = pad_number(num1, max_ndigits) + '+' + pad_number(num2, max_ndigits) + "=" + pad_number(ans, max_ndigits)
        else:
            equation = str(num1) + '+' + str(num2) + "=" + str(ans)
        dataset_str.append(equation)

    dataset = [torch.tensor([BOS]+[symbol_to_int_dict.get(n, UNK) for n in x]+[EOS])
               for x in dataset_str]
    return dataset, dataset_str

print(create_add_dataset(2, 4, padQ=False))
print(create_add_dataset(2, 4, padQ=True))

([tensor([15,  5,  1, 10,  9,  2, 13,  1,  4,  3, 14]), tensor([15,  1,  4, 10,  7,  1, 13,  8,  5, 14]), tensor([15,  6,  0, 10,  2,  0, 13,  8,  0, 14]), tensor([15,  8,  2, 10,  8,  6, 13,  1,  6,  8, 14])], ['51+92=143', '14+71=85', '60+20=80', '82+86=168'])
([tensor([15,  7,  4, 10,  7,  4, 13,  1,  4,  8, 14]), tensor([15,  8,  7, 10,  9,  9, 13,  1,  8,  6, 14]), tensor([15,  2,  3, 10,  0,  2, 13,  2,  5, 14]), tensor([15,  2,  1, 10,  5,  2, 13,  7,  3, 14])], ['74+74=148', '87+99=186', '23+02=25', '21+52=73'])


# Create dataloders for the train, validation and test sets

In [3]:
class TranslationDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

batch_size = 256

def pad_sequence(batch):
    input_padded = torch.nn.utils.rnn.pad_sequence(batch,
                                batch_first=True, padding_value = PAD)
    return input_padded

@dataclass
class DataLoaders:
    max_ndigits: int
    dataset_size: int
    padQ: bool = True
    val_loader = None
    test_loader = None
    train_loader = None

    def split_data(self, split=[0.7, 0.1, 0.2]):
        # If split consists of floats whose sum is equal to 1 then we split the
        # dataset by the percentages given by split. If split contains integers, then
        # it is understood that the two first integers in split are the number of examples
        # in the validation and test sets.
        if isinstance(split[0], float):
            train_size  = round(self.dataset_size * split[0])
            val_size = round(self.dataset_size * split[1])
            test_size = self.dataset_size - train_size - val_size

        elif isinstance(split[0], int):
            val_size = split[0]
            test_size = split[1]
            train_size  = dataset_size - test_size - val_size


        dataset, _ = create_add_dataset(self.max_ndigits, self.dataset_size, padQ=self.padQ)
        train_set, val_set, test_set = torch.utils.data.random_split(dataset,
                                                             [train_size, val_size, test_size],
                                                    generator=torch.Generator().manual_seed(42) )

        self.train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                           shuffle=True, collate_fn = pad_sequence)
        self.test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,
                                           shuffle=True, collate_fn=pad_sequence)
        self.val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size,
                                           shuffle=True, collate_fn=pad_sequence)

# GPT model
Here is my implementation of the GPT model, including the multi-headed self-attention module.

In [4]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_embed, dropout=0.0):
        super(MultiHeadedAttention, self).__init__()
        assert d_embed % h == 0 # check the h number
        self.d_k = d_embed//h
        self.d_embed = d_embed
        self.h = h
        self.WQ = nn.Linear(d_embed, d_embed)
        self.WK = nn.Linear(d_embed, d_embed)
        self.WV = nn.Linear(d_embed, d_embed)
        self.linear = nn.Linear(d_embed, d_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_query, x_key, x_value, mask=None):
        nbatch = x_query.size(0) # get batch size
        # 1) Linear projections to get the multi-head query, key and value tensors
        # x_query, x_key, x_value dimension: nbatch * seq_len * d_embed
        # LHS query, key, value dimensions: nbatch * h * seq_len * d_k
        query = self.WQ(x_query).view(nbatch, -1, self.h, self.d_k).transpose(1,2)
        key   = self.WK(x_key).view(nbatch, -1, self.h, self.d_k).transpose(1,2)
        value = self.WV(x_value).view(nbatch, -1, self.h, self.d_k).transpose(1,2)
        # 2) Attention
        # scores has dimensions: nbatch * h * seq_len * seq_len
        scores = torch.matmul(query, key.transpose(-2, -1))/math.sqrt(self.d_k)
        # 3) Mask out padding tokens and future tokens
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        # p_atten dimensions: nbatch * h * seq_len * seq_len
        p_atten = torch.nn.functional.softmax(scores, dim=-1)
        p_atten = self.dropout(p_atten)
        # x dimensions: nbatch * h * seq_len * d_k
        x = torch.matmul(p_atten, value)
        # x now has dimensions:nbtach * seq_len * d_embed
        x = x.transpose(1, 2).contiguous().view(nbatch, -1, self.d_embed)
        return self.linear(x) # final linear layer


class ResidualConnection(nn.Module):
  '''residual connection: x + dropout(sublayer(layernorm(x))) '''
  def __init__(self, dim, dropout):
      super().__init__()
      self.drop = nn.Dropout(dropout)
      self.norm = nn.LayerNorm(dim)

  def forward(self, x, sublayer):
      return x + self.drop(sublayer(self.norm(x)))
    
class Decoder(nn.Module):
    '''Decoder = token embedding + positional embedding -> a stack of N DecoderBlock -> fully-connected layer'''
    def __init__(self, config):
        super().__init__()
        self.d_embed = config.d_embed
        self.tok_embed = nn.Embedding(config.decoder_vocab_size, config.d_embed)
        self.pos_embed = nn.Parameter(torch.zeros(1, config.max_seq_len, config.d_embed))
        self.dropout = nn.Dropout(config.dropout)
        self.decoder_blocks = nn.ModuleList([DecoderBlock(config) for _ in range(config.N_decoder)])
        self.norm = nn.LayerNorm(config.d_embed)
        self.linear = nn.Linear(config.d_embed, config.decoder_vocab_size)

    def future_mask(self, seq_len):
        '''mask out tokens at future positions'''
        mask = (torch.triu(torch.ones(seq_len, seq_len, requires_grad=False), diagonal=1)!=0).to(DEVICE)
        return mask.view(1, 1, seq_len, seq_len)

    def forward(self, input, pad_mask):
        seq_len = input.size(1)
        trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
        x = self.tok_embed(input) + self.pos_embed[:, :input.size(1), :]
        x = self.dropout(x)
        for layer in self.decoder_blocks:
            x = layer( x, trg_mask)
        x = self.norm(x)
        logits = self.linear(x)
        return logits


class DecoderBlock(nn.Module):
    ''' EncoderBlock: self-attention -> position-wise feed-forward (fully connected) layer'''
    def __init__(self, config):
        super().__init__()
        self.atten = MultiHeadedAttention(config.h, config.d_embed)
        self.feed_forward = nn.Sequential(
            nn.Linear(config.d_embed, config.d_ff),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_ff, config.d_embed)
        )
        self.residuals = nn.ModuleList([ResidualConnection(config.d_embed, config.dropout)
                                       for i in range(2)])

    def forward(self, decoder_layer_input, trg_mask):
        y = decoder_layer_input
        y = self.residuals[0](y, lambda y: self.atten(y, y, y, mask=trg_mask))
        return self.residuals[1](y, self.feed_forward)
    
    

# Helper Functions

In [5]:
@dataclass
class ModelConfig:
    decoder_vocab_size: int
    d_embed: int
    # d_ff is the dimension of the fully-connected  feed-forward layer
    d_ff: int
    # h is the number of attention head
    h: int
    N_decoder: int
    max_seq_len: int
    dropout: float


def make_GPT(config):
    model = Decoder(config).to(DEVICE)
    # initialize model parameters
    # it seems that this initialization is very important!
    for p in model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    return model

In [6]:
def make_batch_input(x):
        'function for generating model input, target and pad_mask from raw input x'
        input = x[:, :-1].to(DEVICE)
        equal_sign_loc = [(equation==EQUAL_SIGN).nonzero().item() for equation in x]
        # for the target, we mask out the tokens before the equal sign (including the equal sign)
        target = [torch.cat(
            (torch.tensor([PAD]*equal_sign_loc[i]), x[i][equal_sign_loc[i]+1:])) for i in range(len(x))]
        target = torch.cat(target, 0).contiguous().view(-1).to(DEVICE)
        pad_mask = (input == PAD).view(input.size(0), 1, 1, input.size(-1))
        return input, target, pad_mask

In [7]:
def train_epoch(model, dataloader):
    model.train()
    grad_norm_clip = 1.0
    losses, acc, count = [], 0, 0
    num_batches = len(dataloader)
    pbar = tqdm(enumerate(dataloader), total=num_batches)
    for idx, x  in  pbar:
        optimizer.zero_grad()
        input, target, pad_mask = make_batch_input(x)
        pred = model(input, pad_mask).to(DEVICE)
        pred = pred.view(-1, pred.size(-1))
        loss = loss_fn(pred, target).to(DEVICE)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm_clip)
        optimizer.step()
        scheduler.step()
        losses.append(loss.item())
        # report progress
        if idx>0 and idx%50 == 0:
            pbar.set_description(f"ep: {scheduler.last_epoch//num_batches}, train loss={loss.item():.3f},lr={scheduler.get_last_lr()[0]:.5f}")
    return np.mean(losses)

def train(model, dataloaders, epochs):
    global early_stop_count
    train_size = len(dataloaders.train_loader)*batch_size
    for ep in range(epochs):
        train_loss = train_epoch(model, dataloaders.train_loader)
        val_loss = validate(model, dataloaders.val_loader)
        print(f'ep {ep}: train_loss: {train_loss:.5f}, val_loss: {val_loss:.5f}')

    return train_loss, val_loss


def validate(model, dataloder):
    'function for computing the loss on the validation set'
    model.eval()
    losses = []
    with torch.no_grad():
        for i, x in enumerate(dataloder):
            input, target, pad_mask = make_batch_input(x)
            pred = model(input, pad_mask).to(DEVICE)
            pred = pred.view(-1, pred.size(-1))
            losses.append(loss_fn(pred, target).item())
    return np.mean(losses)

In [8]:
@torch.no_grad()
def compute_sum(model, x):
    'Function for computing the sum of two numbers.'
    for i in range(max_ndigits+2):
        pad_mask = (x == PAD).view(1, 1, 1, x.size(-1)).to(DEVICE)
        logits = model(x, pad_mask)
        last_output = logits.argmax(-1)[:,-1].view(1,1)
        x = torch.cat((x, last_output), 1).to(DEVICE)
        if last_output.item() == EOS:
            break
    return x[0]

def evaluate(model, dataloader, num_batch=None):
    '''Function for evaluation the model.
    This function take equations, and truncate them up to the equal-sign, and feed them to the
    model to get the predictions, compare them with the correct answers, and output the accuracy.
    '''
    model.eval()
    acc, count = 0, 0
    num_wrong_to_display = 5
    for idx, x in enumerate(dataloader):
        for equation in x:
            loc_equal_sign = equation.tolist().index(EQUAL_SIGN)
            loc_EOS = equation.tolist().index(EOS)
            input = equation[0:loc_equal_sign+1].view(1, -1).to(DEVICE)
            ans = equation[:loc_EOS+1].tolist()
            ans_pred = compute_sum(model, input)
            count += 1

            if ans == ans_pred.tolist():
                acc +=1
            else:
                if num_wrong_to_display > 0:
                    print(f'correct equation: {decode_equation(equation).replace("<pad>","")}')
                    print(f'predicted:        {decode_equation(ans_pred)}')
                    num_wrong_to_display -= 1
        if num_batch and idx>num_batch:
            break
    return acc/count

def what_is(question:str)->str:
    'function for computing the sum of two numbers with input in literal string format'
    pred = compute_sum(model, encode_equation(question, max_ndigits).view(1,-1))
    pred = decode_equation(pred)
    pred = pred[pred.index("=")+1:]
    return question+pred

# 2 Digit Addition

In [11]:
max_ndigits = 2
# max_len is determined by 1+ max_ndigits + 1 + max_ndigits + 1 + max_ndigits + 1 + 1
# where the 1s represent BOS, Plus sign, Equal sign, the extra digit in the sum, EOS, respectively.
max_len = 3*max_ndigits + 6
config = ModelConfig(decoder_vocab_size= vocab_size,
                     d_embed=128,
                     d_ff=256,
                     h=4,
                     N_decoder=2,
                     max_seq_len= max_len,
                     dropout=0.1)
dataset_size = 10000
data_loaders = DataLoaders(max_ndigits, dataset_size, padQ=True)
data_loaders.split_data(split=[1000, 2000])
train_size = len(data_loaders.train_loader)*batch_size
model = make_GPT(config)
model_size = sum([p.numel() for p in model.parameters()])
print(f'model_size: {model_size}, train_set_size: {train_size}')
warmup_steps = 3*len(data_loaders.train_loader)
# lr first increases in the warmup steps, and then descreases
lr_fn = lambda step: config.d_embed**(-0.5) * min([(step+1)**(-0.5), (step+1)*warmup_steps**(-1.5)])
optimizer = torch.optim.Adam(model.parameters(), lr=0.2, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_fn)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD)

train_loss, val_loss = train(model, data_loaders, epochs=50)

model_size: 271378, train_set_size: 7168


  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
100%|██████████| 28/28 [00:00<00:00, 67.97it/s]
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))


ep 0: train_loss: 2.50030, val_loss: 1.68551


100%|██████████| 28/28 [00:00<00:00, 69.15it/s]


ep 1: train_loss: 1.52802, val_loss: 1.29850


100%|██████████| 28/28 [00:00<00:00, 69.24it/s]


ep 2: train_loss: 1.23560, val_loss: 1.04376


100%|██████████| 28/28 [00:00<00:00, 69.31it/s]


ep 3: train_loss: 1.02151, val_loss: 0.89575


100%|██████████| 28/28 [00:00<00:00, 69.00it/s]


ep 4: train_loss: 0.92378, val_loss: 0.84775


100%|██████████| 28/28 [00:00<00:00, 68.75it/s]


ep 5: train_loss: 0.87335, val_loss: 0.81039


100%|██████████| 28/28 [00:00<00:00, 69.06it/s]


ep 6: train_loss: 0.83730, val_loss: 0.78523


100%|██████████| 28/28 [00:00<00:00, 68.38it/s]


ep 7: train_loss: 0.81336, val_loss: 0.74565


100%|██████████| 28/28 [00:00<00:00, 69.08it/s]


ep 8: train_loss: 0.77771, val_loss: 0.72891


100%|██████████| 28/28 [00:00<00:00, 68.22it/s]


ep 9: train_loss: 0.75249, val_loss: 0.71506


100%|██████████| 28/28 [00:00<00:00, 68.88it/s]


ep 10: train_loss: 0.72704, val_loss: 0.64269


100%|██████████| 28/28 [00:00<00:00, 69.43it/s]


ep 11: train_loss: 0.63846, val_loss: 0.47019


100%|██████████| 28/28 [00:00<00:00, 69.25it/s]


ep 12: train_loss: 0.44912, val_loss: 0.26273


100%|██████████| 28/28 [00:00<00:00, 69.34it/s]


ep 13: train_loss: 0.31334, val_loss: 0.16070


100%|██████████| 28/28 [00:00<00:00, 69.26it/s]


ep 14: train_loss: 0.22283, val_loss: 0.10100


100%|██████████| 28/28 [00:00<00:00, 69.47it/s]


ep 15: train_loss: 0.16839, val_loss: 0.06478


100%|██████████| 28/28 [00:00<00:00, 69.30it/s]


ep 16: train_loss: 0.12748, val_loss: 0.03534


100%|██████████| 28/28 [00:00<00:00, 69.57it/s]


ep 17: train_loss: 0.09619, val_loss: 0.01663


100%|██████████| 28/28 [00:00<00:00, 69.30it/s]


ep 18: train_loss: 0.07163, val_loss: 0.01034


100%|██████████| 28/28 [00:00<00:00, 69.35it/s]


ep 19: train_loss: 0.05693, val_loss: 0.00651


100%|██████████| 28/28 [00:00<00:00, 69.23it/s]


ep 20: train_loss: 0.05116, val_loss: 0.00334


100%|██████████| 28/28 [00:00<00:00, 69.35it/s]


ep 21: train_loss: 0.04267, val_loss: 0.00355


100%|██████████| 28/28 [00:00<00:00, 69.28it/s]


ep 22: train_loss: 0.03565, val_loss: 0.00313


100%|██████████| 28/28 [00:00<00:00, 69.25it/s]


ep 23: train_loss: 0.03339, val_loss: 0.00395


100%|██████████| 28/28 [00:00<00:00, 68.69it/s]


ep 24: train_loss: 0.03013, val_loss: 0.00219


100%|██████████| 28/28 [00:00<00:00, 69.02it/s]


ep 25: train_loss: 0.02725, val_loss: 0.00263


100%|██████████| 28/28 [00:00<00:00, 69.41it/s]


ep 26: train_loss: 0.02585, val_loss: 0.00189


100%|██████████| 28/28 [00:00<00:00, 69.13it/s]


ep 27: train_loss: 0.02414, val_loss: 0.00191


100%|██████████| 28/28 [00:00<00:00, 68.98it/s]


ep 28: train_loss: 0.02031, val_loss: 0.00120


100%|██████████| 28/28 [00:00<00:00, 68.07it/s]


ep 29: train_loss: 0.02168, val_loss: 0.00110


100%|██████████| 28/28 [00:00<00:00, 69.21it/s]


ep 30: train_loss: 0.01743, val_loss: 0.00170


100%|██████████| 28/28 [00:00<00:00, 69.33it/s]


ep 31: train_loss: 0.01855, val_loss: 0.00299


100%|██████████| 28/28 [00:00<00:00, 68.70it/s]


ep 32: train_loss: 0.01424, val_loss: 0.00272


100%|██████████| 28/28 [00:00<00:00, 69.33it/s]


ep 33: train_loss: 0.01330, val_loss: 0.00260


100%|██████████| 28/28 [00:00<00:00, 69.05it/s]


ep 34: train_loss: 0.01315, val_loss: 0.00038


100%|██████████| 28/28 [00:00<00:00, 69.36it/s]


ep 35: train_loss: 0.01269, val_loss: 0.00161


100%|██████████| 28/28 [00:00<00:00, 69.28it/s]


ep 36: train_loss: 0.01482, val_loss: 0.00026


100%|██████████| 28/28 [00:00<00:00, 69.39it/s]


ep 37: train_loss: 0.01232, val_loss: 0.00093


100%|██████████| 28/28 [00:00<00:00, 69.10it/s]


ep 38: train_loss: 0.00931, val_loss: 0.00194


100%|██████████| 28/28 [00:00<00:00, 69.38it/s]


ep 39: train_loss: 0.01202, val_loss: 0.00087


100%|██████████| 28/28 [00:00<00:00, 69.34it/s]


ep 40: train_loss: 0.01197, val_loss: 0.00020


100%|██████████| 28/28 [00:00<00:00, 69.37it/s]


ep 41: train_loss: 0.00932, val_loss: 0.00015


100%|██████████| 28/28 [00:00<00:00, 69.40it/s]


ep 42: train_loss: 0.00924, val_loss: 0.00061


100%|██████████| 28/28 [00:00<00:00, 69.19it/s]


ep 43: train_loss: 0.00988, val_loss: 0.00015


100%|██████████| 28/28 [00:00<00:00, 69.39it/s]


ep 44: train_loss: 0.00806, val_loss: 0.00015


100%|██████████| 28/28 [00:00<00:00, 69.24it/s]


ep 45: train_loss: 0.00643, val_loss: 0.00077


100%|██████████| 28/28 [00:00<00:00, 69.34it/s]


ep 46: train_loss: 0.00778, val_loss: 0.00245


100%|██████████| 28/28 [00:00<00:00, 69.15it/s]


ep 47: train_loss: 0.00688, val_loss: 0.00188


100%|██████████| 28/28 [00:00<00:00, 69.40it/s]


ep 48: train_loss: 0.00917, val_loss: 0.00018


100%|██████████| 28/28 [00:00<00:00, 69.15it/s]

ep 49: train_loss: 0.00858, val_loss: 0.00225





In [12]:
test_loss = validate(model, data_loaders.test_loader)
print('training set examples the model gives an incorrect result:')
train_acc = evaluate(model, data_loaders.train_loader, 20)
print('validataion set examples the model gives an incorrect result:')
val_acc = evaluate(model, data_loaders.test_loader)
print('test set examples the model gives an incorrect result:')
test_acc = evaluate(model, data_loaders.test_loader)
result = f'''train_size: {train_size}, train_loss: {train_loss},
                val_loss: {val_loss}, test_loss: {test_loss},
                test_acc: {test_acc}, val_acc: {val_acc}, train_acc: {train_acc}
                '''
print(result)

  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))


training set examples the model gives an incorrect result:
validataion set examples the model gives an incorrect result:
test set examples the model gives an incorrect result:
train_size: 7168, train_loss: 0.008579625598421054,
                val_loss: 0.002247950482342276, test_loss: 0.00010329616316084866,
                test_acc: 1.0, val_acc: 1.0, train_acc: 1.0
                


no incorrect results were observed

# 5 Digit Addition

In [15]:
max_ndigits = 5
# max_len is determined by 1+ max_ndigits + 1 + max_ndigits + 1 + max_ndigits +1 +1
max_len = 3*max_ndigits + 6
config = ModelConfig(decoder_vocab_size= vocab_size,
                     d_embed=128,
                     d_ff=256,
                     h=4,
                     N_decoder=2,
                     max_seq_len= max_len,
                     dropout=0.1)

dataset_size = 200000
data_loaders = DataLoaders(max_ndigits, dataset_size, padQ=True)
data_loaders.split_data(split=[10000, 20000])
train_size = len(data_loaders.train_loader)*batch_size
model = make_GPT(config)
model_size = sum([p.numel() for p in model.parameters()])
print(f'model_size: {model_size}, train_set_size: {train_size}')
warmup_steps = 5*len(data_loaders.train_loader)
# lr first increases in the warmup steps, and then descreases
lr_fn = lambda step: config.d_embed**(-0.5) * min([(step+1)**(-0.5), (step+1)*warmup_steps**(-1.5)])
optimizer = torch.optim.Adam(model.parameters(), lr=1, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_fn)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD)

train_loss, val_loss = train(model, data_loaders, epochs=70)

model_size: 272530, train_set_size: 170240


  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
ep: 0, train loss=1.606,lr=0.00030: 100%|██████████| 665/665 [00:10<00:00, 65.75it/s]


ep 0: train_loss: 2.07644, val_loss: 1.54197


ep: 1, train loss=1.298,lr=0.00061: 100%|██████████| 665/665 [00:10<00:00, 65.65it/s]


ep 1: train_loss: 1.47744, val_loss: 1.24604


ep: 2, train loss=0.321,lr=0.00091: 100%|██████████| 665/665 [00:10<00:00, 65.59it/s]


ep 2: train_loss: 0.73744, val_loss: 0.19745


ep: 3, train loss=0.070,lr=0.00122: 100%|██████████| 665/665 [00:10<00:00, 65.75it/s]


ep 3: train_loss: 0.09241, val_loss: 0.02036


ep: 4, train loss=0.035,lr=0.00153: 100%|██████████| 665/665 [00:10<00:00, 65.61it/s]


ep 4: train_loss: 0.04839, val_loss: 0.01729


ep: 5, train loss=0.025,lr=0.00140: 100%|██████████| 665/665 [00:10<00:00, 65.64it/s]


ep 5: train_loss: 0.03925, val_loss: 0.01674


ep: 6, train loss=0.031,lr=0.00130: 100%|██████████| 665/665 [00:10<00:00, 65.82it/s]


ep 6: train_loss: 0.03185, val_loss: 0.01425


ep: 7, train loss=0.034,lr=0.00121: 100%|██████████| 665/665 [00:10<00:00, 65.58it/s]


ep 7: train_loss: 0.02839, val_loss: 0.01523


ep: 8, train loss=0.024,lr=0.00114: 100%|██████████| 665/665 [00:10<00:00, 65.47it/s]


ep 8: train_loss: 0.02501, val_loss: 0.01177


ep: 9, train loss=0.023,lr=0.00108: 100%|██████████| 665/665 [00:10<00:00, 65.76it/s]


ep 9: train_loss: 0.02232, val_loss: 0.01145


ep: 10, train loss=0.026,lr=0.00103: 100%|██████████| 665/665 [00:10<00:00, 65.57it/s]


ep 10: train_loss: 0.02107, val_loss: 0.00966


ep: 11, train loss=0.013,lr=0.00099: 100%|██████████| 665/665 [00:10<00:00, 65.80it/s]


ep 11: train_loss: 0.01902, val_loss: 0.00929


ep: 12, train loss=0.015,lr=0.00095: 100%|██████████| 665/665 [00:10<00:00, 65.79it/s]


ep 12: train_loss: 0.01809, val_loss: 0.00882


ep: 13, train loss=0.020,lr=0.00092: 100%|██████████| 665/665 [00:10<00:00, 64.02it/s]


ep 13: train_loss: 0.01716, val_loss: 0.00806


ep: 14, train loss=0.032,lr=0.00089: 100%|██████████| 665/665 [00:10<00:00, 65.47it/s]


ep 14: train_loss: 0.01634, val_loss: 0.00899


ep: 15, train loss=0.017,lr=0.00086: 100%|██████████| 665/665 [00:10<00:00, 65.36it/s]


ep 15: train_loss: 0.01637, val_loss: 0.00901


ep: 16, train loss=0.012,lr=0.00083: 100%|██████████| 665/665 [00:10<00:00, 65.41it/s]


ep 16: train_loss: 0.01544, val_loss: 0.00883


ep: 17, train loss=0.010,lr=0.00081: 100%|██████████| 665/665 [00:10<00:00, 65.52it/s]


ep 17: train_loss: 0.01457, val_loss: 0.00846


ep: 18, train loss=0.012,lr=0.00079: 100%|██████████| 665/665 [00:10<00:00, 65.16it/s]


ep 18: train_loss: 0.01426, val_loss: 0.00728


ep: 19, train loss=0.018,lr=0.00077: 100%|██████████| 665/665 [00:10<00:00, 65.35it/s]


ep 19: train_loss: 0.01356, val_loss: 0.00768


ep: 20, train loss=0.015,lr=0.00075: 100%|██████████| 665/665 [00:10<00:00, 65.49it/s]


ep 20: train_loss: 0.01340, val_loss: 0.00794


ep: 21, train loss=0.005,lr=0.00073: 100%|██████████| 665/665 [00:10<00:00, 65.26it/s]


ep 21: train_loss: 0.01323, val_loss: 0.00800


ep: 22, train loss=0.012,lr=0.00071: 100%|██████████| 665/665 [00:10<00:00, 65.39it/s]


ep 22: train_loss: 0.01304, val_loss: 0.00715


ep: 23, train loss=0.003,lr=0.00070: 100%|██████████| 665/665 [00:10<00:00, 65.42it/s]


ep 23: train_loss: 0.01234, val_loss: 0.00651


ep: 24, train loss=0.008,lr=0.00069: 100%|██████████| 665/665 [00:10<00:00, 65.44it/s]


ep 24: train_loss: 0.01221, val_loss: 0.00707


ep: 25, train loss=0.008,lr=0.00067: 100%|██████████| 665/665 [00:10<00:00, 65.44it/s]


ep 25: train_loss: 0.01183, val_loss: 0.00620


ep: 26, train loss=0.015,lr=0.00066: 100%|██████████| 665/665 [00:10<00:00, 65.57it/s]


ep 26: train_loss: 0.01114, val_loss: 0.00626


ep: 27, train loss=0.008,lr=0.00065: 100%|██████████| 665/665 [00:10<00:00, 65.41it/s]


ep 27: train_loss: 0.01064, val_loss: 0.00535


ep: 28, train loss=0.008,lr=0.00064: 100%|██████████| 665/665 [00:10<00:00, 65.43it/s]


ep 28: train_loss: 0.00907, val_loss: 0.00360


ep: 29, train loss=0.019,lr=0.00063: 100%|██████████| 665/665 [00:10<00:00, 65.54it/s]


ep 29: train_loss: 0.00781, val_loss: 0.00331


ep: 30, train loss=0.005,lr=0.00062: 100%|██████████| 665/665 [00:10<00:00, 65.18it/s]


ep 30: train_loss: 0.00735, val_loss: 0.00326


ep: 31, train loss=0.014,lr=0.00061: 100%|██████████| 665/665 [00:10<00:00, 65.35it/s]


ep 31: train_loss: 0.00688, val_loss: 0.00323


ep: 32, train loss=0.004,lr=0.00060: 100%|██████████| 665/665 [00:10<00:00, 65.45it/s]


ep 32: train_loss: 0.00681, val_loss: 0.00312


ep: 33, train loss=0.007,lr=0.00059: 100%|██████████| 665/665 [00:10<00:00, 65.61it/s]


ep 33: train_loss: 0.00625, val_loss: 0.00322


ep: 34, train loss=0.009,lr=0.00058: 100%|██████████| 665/665 [00:10<00:00, 65.63it/s]


ep 34: train_loss: 0.00638, val_loss: 0.00357


ep: 35, train loss=0.002,lr=0.00057: 100%|██████████| 665/665 [00:10<00:00, 65.47it/s]


ep 35: train_loss: 0.00602, val_loss: 0.00207


ep: 36, train loss=0.001,lr=0.00056: 100%|██████████| 665/665 [00:10<00:00, 65.34it/s]


ep 36: train_loss: 0.00471, val_loss: 0.00084


ep: 37, train loss=0.009,lr=0.00056: 100%|██████████| 665/665 [00:10<00:00, 65.43it/s]


ep 37: train_loss: 0.00366, val_loss: 0.00046


ep: 38, train loss=0.001,lr=0.00055: 100%|██████████| 665/665 [00:10<00:00, 64.24it/s]


ep 38: train_loss: 0.00331, val_loss: 0.00043


ep: 39, train loss=0.001,lr=0.00054: 100%|██████████| 665/665 [00:10<00:00, 64.10it/s]


ep 39: train_loss: 0.00285, val_loss: 0.00030


ep: 40, train loss=0.001,lr=0.00054: 100%|██████████| 665/665 [00:10<00:00, 64.24it/s]


ep 40: train_loss: 0.00259, val_loss: 0.00020


ep: 41, train loss=0.001,lr=0.00053: 100%|██████████| 665/665 [00:10<00:00, 64.36it/s]


ep 41: train_loss: 0.00253, val_loss: 0.00019


ep: 42, train loss=0.001,lr=0.00052: 100%|██████████| 665/665 [00:10<00:00, 63.96it/s]


ep 42: train_loss: 0.00217, val_loss: 0.00010


ep: 43, train loss=0.003,lr=0.00052: 100%|██████████| 665/665 [00:10<00:00, 64.23it/s]


ep 43: train_loss: 0.00258, val_loss: 0.00013


ep: 44, train loss=0.001,lr=0.00051: 100%|██████████| 665/665 [00:10<00:00, 64.03it/s]


ep 44: train_loss: 0.00214, val_loss: 0.00013


ep: 45, train loss=0.003,lr=0.00051: 100%|██████████| 665/665 [00:10<00:00, 64.02it/s]


ep 45: train_loss: 0.00215, val_loss: 0.00014


ep: 46, train loss=0.002,lr=0.00050: 100%|██████████| 665/665 [00:10<00:00, 65.59it/s]


ep 46: train_loss: 0.00193, val_loss: 0.00010


ep: 47, train loss=0.000,lr=0.00049: 100%|██████████| 665/665 [00:10<00:00, 65.57it/s]


ep 47: train_loss: 0.00204, val_loss: 0.00006


ep: 48, train loss=0.000,lr=0.00049: 100%|██████████| 665/665 [00:10<00:00, 65.83it/s]


ep 48: train_loss: 0.00183, val_loss: 0.00005


ep: 49, train loss=0.001,lr=0.00048: 100%|██████████| 665/665 [00:10<00:00, 65.67it/s]


ep 49: train_loss: 0.00165, val_loss: 0.00004


ep: 50, train loss=0.001,lr=0.00048: 100%|██████████| 665/665 [00:10<00:00, 65.55it/s]


ep 50: train_loss: 0.00173, val_loss: 0.00006


ep: 51, train loss=0.003,lr=0.00048: 100%|██████████| 665/665 [00:10<00:00, 65.55it/s]


ep 51: train_loss: 0.00163, val_loss: 0.00003


ep: 52, train loss=0.001,lr=0.00047: 100%|██████████| 665/665 [00:10<00:00, 65.67it/s]


ep 52: train_loss: 0.00165, val_loss: 0.00005


ep: 53, train loss=0.000,lr=0.00047: 100%|██████████| 665/665 [00:10<00:00, 65.57it/s]


ep 53: train_loss: 0.00148, val_loss: 0.00002


ep: 54, train loss=0.002,lr=0.00046: 100%|██████████| 665/665 [00:10<00:00, 65.84it/s]


ep 54: train_loss: 0.00168, val_loss: 0.00001


ep: 55, train loss=0.001,lr=0.00046: 100%|██████████| 665/665 [00:10<00:00, 65.74it/s]


ep 55: train_loss: 0.00155, val_loss: 0.00002


ep: 56, train loss=0.001,lr=0.00045: 100%|██████████| 665/665 [00:10<00:00, 65.73it/s]


ep 56: train_loss: 0.00137, val_loss: 0.00001


ep: 57, train loss=0.000,lr=0.00045: 100%|██████████| 665/665 [00:10<00:00, 62.33it/s]


ep 57: train_loss: 0.00129, val_loss: 0.00001


ep: 58, train loss=0.004,lr=0.00045: 100%|██████████| 665/665 [00:10<00:00, 65.04it/s]


ep 58: train_loss: 0.00130, val_loss: 0.00001


ep: 59, train loss=0.000,lr=0.00044: 100%|██████████| 665/665 [00:10<00:00, 64.59it/s]


ep 59: train_loss: 0.00128, val_loss: 0.00006


ep: 60, train loss=0.001,lr=0.00044: 100%|██████████| 665/665 [00:10<00:00, 65.10it/s]


ep 60: train_loss: 0.00137, val_loss: 0.00001


ep: 61, train loss=0.001,lr=0.00044: 100%|██████████| 665/665 [00:10<00:00, 64.94it/s]


ep 61: train_loss: 0.00133, val_loss: 0.00000


ep: 62, train loss=0.000,lr=0.00043: 100%|██████████| 665/665 [00:10<00:00, 64.43it/s]


ep 62: train_loss: 0.00123, val_loss: 0.00000


ep: 63, train loss=0.002,lr=0.00043: 100%|██████████| 665/665 [00:10<00:00, 64.14it/s]


ep 63: train_loss: 0.00121, val_loss: 0.00001


ep: 64, train loss=0.000,lr=0.00043: 100%|██████████| 665/665 [00:10<00:00, 63.89it/s]


ep 64: train_loss: 0.00113, val_loss: 0.00001


ep: 65, train loss=0.001,lr=0.00042: 100%|██████████| 665/665 [00:10<00:00, 64.48it/s]


ep 65: train_loss: 0.00132, val_loss: 0.00001


ep: 66, train loss=0.000,lr=0.00042: 100%|██████████| 665/665 [00:10<00:00, 64.34it/s]


ep 66: train_loss: 0.00107, val_loss: 0.00001


ep: 67, train loss=0.000,lr=0.00042: 100%|██████████| 665/665 [00:10<00:00, 64.55it/s]


ep 67: train_loss: 0.00122, val_loss: 0.00001


ep: 68, train loss=0.000,lr=0.00041: 100%|██████████| 665/665 [00:10<00:00, 64.81it/s]


ep 68: train_loss: 0.00110, val_loss: 0.00000


ep: 69, train loss=0.013,lr=0.00041: 100%|██████████| 665/665 [00:10<00:00, 64.08it/s]


ep 69: train_loss: 0.00108, val_loss: 0.00001


In [16]:
test_loss = validate(model, data_loaders.test_loader)
print('training set examples the model gives an incorrect result:')
train_acc = evaluate(model, data_loaders.train_loader, 20)
print('validataion set examples the model gives an incorrect result:')
val_acc = evaluate(model, data_loaders.test_loader)
print('test set examples the model gives an incorrect result:')
test_acc = evaluate(model, data_loaders.test_loader)
result = f'''train_size: {train_size}, train_loss: {train_loss},
                val_loss: {val_loss}, test_loss: {test_loss},
                test_acc: {test_acc}, val_acc: {val_acc}, train_acc: {train_acc}
                '''
print(result)

  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))


training set examples the model gives an incorrect result:
validataion set examples the model gives an incorrect result:
test set examples the model gives an incorrect result:
train_size: 170240, train_loss: 0.0010843980230736334,
                val_loss: 1.0287986023627127e-05, test_loss: 3.6189080926213803e-06,
                test_acc: 1.0, val_acc: 1.0, train_acc: 1.0
                


no incorrect results were observed

# 10 Digit Addition

In [19]:
max_ndigits = 10
# max_len is determined by 1+ max_ndigits + 1 + max_ndigits + 1 + max_ndigits +1 +1
max_len = 3*max_ndigits + 6
config = ModelConfig(decoder_vocab_size= vocab_size,
                     d_embed=128,
                     d_ff=256,
                     h=4,
                     N_decoder=2,
                     max_seq_len= max_len,
                     dropout=0.1)
dataset_size = 300000
data_loaders = DataLoaders(max_ndigits, dataset_size, padQ=True)
data_loaders.split_data(split=[10000, 20000])
train_size = len(data_loaders.train_loader)*batch_size
model = make_GPT(config)
model_size = sum([p.numel() for p in model.parameters()])
print(f'model_size: {model_size}, train_set_size: {train_size}')
warmup_steps = 5*len(data_loaders.train_loader)
# lr first increases in the warmup steps, and then descreases
lr_fn = lambda step: config.d_embed**(-0.5) * min([(step+1)**(-0.5), (step+1)*warmup_steps**(-1.5)])
optimizer = torch.optim.Adam(model.parameters(), lr=1, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_fn)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD)

train_loss, val_loss = train(model, data_loaders, epochs=150)

model_size: 274450, train_set_size: 270080


  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
ep: 0, train loss=1.917,lr=0.00024: 100%|██████████| 1055/1055 [00:19<00:00, 53.42it/s]
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))


ep 0: train_loss: 2.22305, val_loss: 1.88622


ep: 1, train loss=1.678,lr=0.00049: 100%|██████████| 1055/1055 [00:19<00:00, 53.44it/s]


ep 1: train_loss: 1.81938, val_loss: 1.64982


ep: 2, train loss=0.578,lr=0.00073: 100%|██████████| 1055/1055 [00:19<00:00, 52.98it/s]


ep 2: train_loss: 1.12391, val_loss: 0.45079


ep: 3, train loss=0.172,lr=0.00097: 100%|██████████| 1055/1055 [00:19<00:00, 53.05it/s]


ep 3: train_loss: 0.30092, val_loss: 0.08800


ep: 4, train loss=0.118,lr=0.00122: 100%|██████████| 1055/1055 [00:19<00:00, 52.87it/s]


ep 4: train_loss: 0.14321, val_loss: 0.07182


ep: 5, train loss=0.045,lr=0.00111: 100%|██████████| 1055/1055 [00:19<00:00, 52.81it/s]


ep 5: train_loss: 0.08933, val_loss: 0.02312


ep: 6, train loss=0.112,lr=0.00103: 100%|██████████| 1055/1055 [00:19<00:00, 53.10it/s]


ep 6: train_loss: 0.04062, val_loss: 0.01445


ep: 7, train loss=0.021,lr=0.00096: 100%|██████████| 1055/1055 [00:19<00:00, 53.55it/s]


ep 7: train_loss: 0.03002, val_loss: 0.01244


ep: 8, train loss=0.019,lr=0.00091: 100%|██████████| 1055/1055 [00:19<00:00, 53.60it/s]


ep 8: train_loss: 0.02380, val_loss: 0.00946


ep: 9, train loss=0.012,lr=0.00086: 100%|██████████| 1055/1055 [00:19<00:00, 53.67it/s]


ep 9: train_loss: 0.01963, val_loss: 0.00828


ep: 10, train loss=0.016,lr=0.00082: 100%|██████████| 1055/1055 [00:19<00:00, 53.49it/s]


ep 10: train_loss: 0.01762, val_loss: 0.00787


ep: 11, train loss=0.014,lr=0.00079: 100%|██████████| 1055/1055 [00:19<00:00, 53.57it/s]


ep 11: train_loss: 0.01551, val_loss: 0.00801


ep: 12, train loss=0.013,lr=0.00075: 100%|██████████| 1055/1055 [00:19<00:00, 53.68it/s]


ep 12: train_loss: 0.01438, val_loss: 0.00726


ep: 13, train loss=0.021,lr=0.00073: 100%|██████████| 1055/1055 [00:19<00:00, 52.79it/s]


ep 13: train_loss: 0.01352, val_loss: 0.00584


ep: 14, train loss=0.014,lr=0.00070: 100%|██████████| 1055/1055 [00:19<00:00, 52.90it/s]


ep 14: train_loss: 0.01160, val_loss: 0.00518


ep: 15, train loss=0.008,lr=0.00068: 100%|██████████| 1055/1055 [00:19<00:00, 52.94it/s]


ep 15: train_loss: 0.01054, val_loss: 0.00456


ep: 16, train loss=0.008,lr=0.00066: 100%|██████████| 1055/1055 [00:19<00:00, 52.77it/s]


ep 16: train_loss: 0.01004, val_loss: 0.00439


ep: 17, train loss=0.014,lr=0.00064: 100%|██████████| 1055/1055 [00:19<00:00, 52.80it/s]


ep 17: train_loss: 0.00961, val_loss: 0.00498


ep: 18, train loss=0.009,lr=0.00062: 100%|██████████| 1055/1055 [00:19<00:00, 52.99it/s]


ep 18: train_loss: 0.00895, val_loss: 0.00433


ep: 19, train loss=0.008,lr=0.00061: 100%|██████████| 1055/1055 [00:19<00:00, 52.91it/s]


ep 19: train_loss: 0.00912, val_loss: 0.00456


ep: 20, train loss=0.008,lr=0.00059: 100%|██████████| 1055/1055 [00:19<00:00, 52.99it/s]


ep 20: train_loss: 0.00838, val_loss: 0.00459


ep: 21, train loss=0.008,lr=0.00058: 100%|██████████| 1055/1055 [00:19<00:00, 52.96it/s]


ep 21: train_loss: 0.00807, val_loss: 0.00423


ep: 22, train loss=0.010,lr=0.00057: 100%|██████████| 1055/1055 [00:20<00:00, 52.71it/s]


ep 22: train_loss: 0.00779, val_loss: 0.00395


ep: 23, train loss=0.008,lr=0.00056: 100%|██████████| 1055/1055 [00:20<00:00, 52.66it/s]


ep 23: train_loss: 0.00763, val_loss: 0.00400


ep: 24, train loss=0.013,lr=0.00054: 100%|██████████| 1055/1055 [00:19<00:00, 52.99it/s]


ep 24: train_loss: 0.00735, val_loss: 0.00375


ep: 25, train loss=0.007,lr=0.00053: 100%|██████████| 1055/1055 [00:19<00:00, 52.90it/s]


ep 25: train_loss: 0.00740, val_loss: 0.00361


ep: 26, train loss=0.007,lr=0.00052: 100%|██████████| 1055/1055 [00:19<00:00, 52.91it/s]


ep 26: train_loss: 0.00686, val_loss: 0.00362


ep: 27, train loss=0.004,lr=0.00051: 100%|██████████| 1055/1055 [00:20<00:00, 52.52it/s]


ep 27: train_loss: 0.00674, val_loss: 0.00353


ep: 28, train loss=0.008,lr=0.00051: 100%|██████████| 1055/1055 [00:20<00:00, 52.55it/s]


ep 28: train_loss: 0.00665, val_loss: 0.00374


ep: 29, train loss=0.005,lr=0.00050: 100%|██████████| 1055/1055 [00:20<00:00, 52.60it/s]


ep 29: train_loss: 0.00639, val_loss: 0.00355


ep: 30, train loss=0.005,lr=0.00049: 100%|██████████| 1055/1055 [00:19<00:00, 52.93it/s]


ep 30: train_loss: 0.00643, val_loss: 0.00316


ep: 31, train loss=0.004,lr=0.00048: 100%|██████████| 1055/1055 [00:19<00:00, 52.86it/s]


ep 31: train_loss: 0.00637, val_loss: 0.00357


ep: 32, train loss=0.005,lr=0.00047: 100%|██████████| 1055/1055 [00:19<00:00, 52.84it/s]


ep 32: train_loss: 0.00609, val_loss: 0.00314


ep: 33, train loss=0.008,lr=0.00047: 100%|██████████| 1055/1055 [00:20<00:00, 52.73it/s]


ep 33: train_loss: 0.00598, val_loss: 0.00289


ep: 34, train loss=0.009,lr=0.00046: 100%|██████████| 1055/1055 [00:20<00:00, 52.42it/s]


ep 34: train_loss: 0.00588, val_loss: 0.00283


ep: 35, train loss=0.003,lr=0.00045: 100%|██████████| 1055/1055 [00:20<00:00, 52.71it/s]


ep 35: train_loss: 0.00565, val_loss: 0.00259


ep: 36, train loss=0.003,lr=0.00045: 100%|██████████| 1055/1055 [00:19<00:00, 52.98it/s]


ep 36: train_loss: 0.00551, val_loss: 0.00248


ep: 37, train loss=0.004,lr=0.00044: 100%|██████████| 1055/1055 [00:20<00:00, 52.71it/s]


ep 37: train_loss: 0.00515, val_loss: 0.00290


ep: 38, train loss=0.008,lr=0.00044: 100%|██████████| 1055/1055 [00:20<00:00, 52.66it/s]


ep 38: train_loss: 0.00506, val_loss: 0.00285


ep: 39, train loss=0.008,lr=0.00043: 100%|██████████| 1055/1055 [00:19<00:00, 52.81it/s]


ep 39: train_loss: 0.00509, val_loss: 0.00205


ep: 40, train loss=0.007,lr=0.00043: 100%|██████████| 1055/1055 [00:20<00:00, 52.23it/s]


ep 40: train_loss: 0.00474, val_loss: 0.00219


ep: 41, train loss=0.003,lr=0.00042: 100%|██████████| 1055/1055 [00:20<00:00, 52.52it/s]


ep 41: train_loss: 0.00436, val_loss: 0.00237


ep: 42, train loss=0.003,lr=0.00042: 100%|██████████| 1055/1055 [00:19<00:00, 52.83it/s]


ep 42: train_loss: 0.00431, val_loss: 0.00183


ep: 43, train loss=0.004,lr=0.00041: 100%|██████████| 1055/1055 [00:20<00:00, 52.60it/s]


ep 43: train_loss: 0.00410, val_loss: 0.00167


ep: 44, train loss=0.005,lr=0.00041: 100%|██████████| 1055/1055 [00:19<00:00, 52.75it/s]


ep 44: train_loss: 0.00403, val_loss: 0.00191


ep: 45, train loss=0.001,lr=0.00040: 100%|██████████| 1055/1055 [00:20<00:00, 51.70it/s]


ep 45: train_loss: 0.00409, val_loss: 0.00166


ep: 46, train loss=0.002,lr=0.00040: 100%|██████████| 1055/1055 [00:20<00:00, 51.91it/s]


ep 46: train_loss: 0.00385, val_loss: 0.00175


ep: 47, train loss=0.003,lr=0.00039: 100%|██████████| 1055/1055 [00:19<00:00, 52.79it/s]


ep 47: train_loss: 0.00385, val_loss: 0.00209


ep: 48, train loss=0.009,lr=0.00039: 100%|██████████| 1055/1055 [00:20<00:00, 52.63it/s]


ep 48: train_loss: 0.00370, val_loss: 0.00164


ep: 49, train loss=0.003,lr=0.00038: 100%|██████████| 1055/1055 [00:20<00:00, 52.50it/s]


ep 49: train_loss: 0.00380, val_loss: 0.00177


ep: 50, train loss=0.005,lr=0.00038: 100%|██████████| 1055/1055 [00:19<00:00, 52.81it/s]


ep 50: train_loss: 0.00379, val_loss: 0.00158


ep: 51, train loss=0.003,lr=0.00038: 100%|██████████| 1055/1055 [00:19<00:00, 52.81it/s]


ep 51: train_loss: 0.00361, val_loss: 0.00167


ep: 52, train loss=0.007,lr=0.00037: 100%|██████████| 1055/1055 [00:20<00:00, 52.74it/s]


ep 52: train_loss: 0.00361, val_loss: 0.00164


ep: 53, train loss=0.003,lr=0.00037: 100%|██████████| 1055/1055 [00:19<00:00, 52.78it/s]


ep 53: train_loss: 0.00353, val_loss: 0.00157


ep: 54, train loss=0.005,lr=0.00037: 100%|██████████| 1055/1055 [00:19<00:00, 52.87it/s]


ep 54: train_loss: 0.00341, val_loss: 0.00176


ep: 55, train loss=0.004,lr=0.00036: 100%|██████████| 1055/1055 [00:19<00:00, 52.88it/s]


ep 55: train_loss: 0.00340, val_loss: 0.00143


ep: 56, train loss=0.006,lr=0.00036: 100%|██████████| 1055/1055 [00:19<00:00, 52.85it/s]


ep 56: train_loss: 0.00335, val_loss: 0.00164


ep: 57, train loss=0.002,lr=0.00036: 100%|██████████| 1055/1055 [00:19<00:00, 53.11it/s]


ep 57: train_loss: 0.00321, val_loss: 0.00135


ep: 58, train loss=0.002,lr=0.00035: 100%|██████████| 1055/1055 [00:20<00:00, 52.69it/s]


ep 58: train_loss: 0.00348, val_loss: 0.00139


ep: 59, train loss=0.001,lr=0.00035: 100%|██████████| 1055/1055 [00:20<00:00, 52.62it/s]


ep 59: train_loss: 0.00327, val_loss: 0.00180


ep: 60, train loss=0.002,lr=0.00035: 100%|██████████| 1055/1055 [00:19<00:00, 52.84it/s]


ep 60: train_loss: 0.00332, val_loss: 0.00133


ep: 61, train loss=0.004,lr=0.00035: 100%|██████████| 1055/1055 [00:19<00:00, 52.81it/s]


ep 61: train_loss: 0.00322, val_loss: 0.00136


ep: 62, train loss=0.003,lr=0.00034: 100%|██████████| 1055/1055 [00:19<00:00, 52.96it/s]


ep 62: train_loss: 0.00318, val_loss: 0.00126


ep: 63, train loss=0.004,lr=0.00034: 100%|██████████| 1055/1055 [00:19<00:00, 53.04it/s]


ep 63: train_loss: 0.00320, val_loss: 0.00149


ep: 64, train loss=0.003,lr=0.00034: 100%|██████████| 1055/1055 [00:20<00:00, 52.66it/s]


ep 64: train_loss: 0.00303, val_loss: 0.00146


ep: 65, train loss=0.001,lr=0.00033: 100%|██████████| 1055/1055 [00:20<00:00, 52.69it/s]


ep 65: train_loss: 0.00300, val_loss: 0.00127


ep: 66, train loss=0.004,lr=0.00033: 100%|██████████| 1055/1055 [00:19<00:00, 52.86it/s]


ep 66: train_loss: 0.00301, val_loss: 0.00117


ep: 67, train loss=0.003,lr=0.00033: 100%|██████████| 1055/1055 [00:20<00:00, 52.08it/s]


ep 67: train_loss: 0.00301, val_loss: 0.00165


ep: 68, train loss=0.005,lr=0.00033: 100%|██████████| 1055/1055 [00:20<00:00, 52.28it/s]


ep 68: train_loss: 0.00296, val_loss: 0.00143


ep: 69, train loss=0.002,lr=0.00033: 100%|██████████| 1055/1055 [00:20<00:00, 52.30it/s]


ep 69: train_loss: 0.00318, val_loss: 0.00123


ep: 70, train loss=0.002,lr=0.00032: 100%|██████████| 1055/1055 [00:20<00:00, 51.86it/s]


ep 70: train_loss: 0.00282, val_loss: 0.00120


ep: 71, train loss=0.003,lr=0.00032: 100%|██████████| 1055/1055 [00:20<00:00, 52.22it/s]


ep 71: train_loss: 0.00287, val_loss: 0.00118


ep: 72, train loss=0.002,lr=0.00032: 100%|██████████| 1055/1055 [00:20<00:00, 52.45it/s]


ep 72: train_loss: 0.00283, val_loss: 0.00127


ep: 73, train loss=0.005,lr=0.00032: 100%|██████████| 1055/1055 [00:20<00:00, 51.83it/s]


ep 73: train_loss: 0.00284, val_loss: 0.00136


ep: 74, train loss=0.002,lr=0.00031: 100%|██████████| 1055/1055 [00:19<00:00, 53.00it/s]


ep 74: train_loss: 0.00278, val_loss: 0.00123


ep: 75, train loss=0.005,lr=0.00031: 100%|██████████| 1055/1055 [00:19<00:00, 53.67it/s]


ep 75: train_loss: 0.00284, val_loss: 0.00132


ep: 76, train loss=0.004,lr=0.00031: 100%|██████████| 1055/1055 [00:19<00:00, 53.57it/s]


ep 76: train_loss: 0.00281, val_loss: 0.00117


ep: 77, train loss=0.002,lr=0.00031: 100%|██████████| 1055/1055 [00:19<00:00, 53.56it/s]


ep 77: train_loss: 0.00273, val_loss: 0.00131


ep: 78, train loss=0.002,lr=0.00031: 100%|██████████| 1055/1055 [00:19<00:00, 53.59it/s]


ep 78: train_loss: 0.00276, val_loss: 0.00165


ep: 79, train loss=0.004,lr=0.00030: 100%|██████████| 1055/1055 [00:19<00:00, 53.42it/s]


ep 79: train_loss: 0.00271, val_loss: 0.00145


ep: 80, train loss=0.004,lr=0.00030: 100%|██████████| 1055/1055 [00:19<00:00, 53.63it/s]


ep 80: train_loss: 0.00264, val_loss: 0.00106


ep: 81, train loss=0.001,lr=0.00030: 100%|██████████| 1055/1055 [00:19<00:00, 53.53it/s]


ep 81: train_loss: 0.00277, val_loss: 0.00130


ep: 82, train loss=0.003,lr=0.00030: 100%|██████████| 1055/1055 [00:19<00:00, 53.00it/s]


ep 82: train_loss: 0.00266, val_loss: 0.00123


ep: 83, train loss=0.002,lr=0.00030: 100%|██████████| 1055/1055 [00:20<00:00, 52.47it/s]


ep 83: train_loss: 0.00266, val_loss: 0.00116


ep: 84, train loss=0.001,lr=0.00030: 100%|██████████| 1055/1055 [00:20<00:00, 52.40it/s]


ep 84: train_loss: 0.00268, val_loss: 0.00105


ep: 89, train loss=0.000,lr=0.00029: 100%|██████████| 1055/1055 [00:19<00:00, 53.67it/s]


ep 89: train_loss: 0.00253, val_loss: 0.00100


ep: 90, train loss=0.001,lr=0.00029: 100%|██████████| 1055/1055 [00:19<00:00, 53.82it/s]


ep 90: train_loss: 0.00262, val_loss: 0.00123


ep: 91, train loss=0.004,lr=0.00028: 100%|██████████| 1055/1055 [00:20<00:00, 52.64it/s]


ep 91: train_loss: 0.00245, val_loss: 0.00111


ep: 92, train loss=0.002,lr=0.00028: 100%|██████████| 1055/1055 [00:19<00:00, 53.52it/s]


ep 92: train_loss: 0.00252, val_loss: 0.00121


ep: 93, train loss=0.001,lr=0.00028: 100%|██████████| 1055/1055 [00:19<00:00, 53.49it/s]


ep 93: train_loss: 0.00243, val_loss: 0.00111


ep: 94, train loss=0.001,lr=0.00028: 100%|██████████| 1055/1055 [00:19<00:00, 53.49it/s]


ep 94: train_loss: 0.00268, val_loss: 0.00114


ep: 95, train loss=0.001,lr=0.00028: 100%|██████████| 1055/1055 [00:19<00:00, 53.66it/s]


ep 95: train_loss: 0.00233, val_loss: 0.00142


ep: 96, train loss=0.001,lr=0.00028: 100%|██████████| 1055/1055 [00:19<00:00, 53.54it/s]


ep 96: train_loss: 0.00237, val_loss: 0.00113


ep: 97, train loss=0.002,lr=0.00027: 100%|██████████| 1055/1055 [00:19<00:00, 53.63it/s]


ep 97: train_loss: 0.00241, val_loss: 0.00096


ep: 98, train loss=0.005,lr=0.00027: 100%|██████████| 1055/1055 [00:19<00:00, 53.58it/s]


ep 98: train_loss: 0.00244, val_loss: 0.00124


ep: 99, train loss=0.003,lr=0.00027: 100%|██████████| 1055/1055 [00:20<00:00, 52.59it/s]


ep 99: train_loss: 0.00237, val_loss: 0.00124


ep: 100, train loss=0.002,lr=0.00027: 100%|██████████| 1055/1055 [00:20<00:00, 52.39it/s]


ep 100: train_loss: 0.00242, val_loss: 0.00108


ep: 101, train loss=0.002,lr=0.00027: 100%|██████████| 1055/1055 [00:20<00:00, 52.62it/s]


ep 101: train_loss: 0.00239, val_loss: 0.00135


ep: 102, train loss=0.002,lr=0.00027: 100%|██████████| 1055/1055 [00:19<00:00, 53.19it/s]


ep 102: train_loss: 0.00234, val_loss: 0.00117


ep: 103, train loss=0.002,lr=0.00027: 100%|██████████| 1055/1055 [00:19<00:00, 53.50it/s]


ep 103: train_loss: 0.00232, val_loss: 0.00099


ep: 104, train loss=0.001,lr=0.00027: 100%|██████████| 1055/1055 [00:19<00:00, 53.59it/s]


ep 104: train_loss: 0.00229, val_loss: 0.00114


ep: 105, train loss=0.001,lr=0.00026: 100%|██████████| 1055/1055 [00:19<00:00, 53.51it/s]


ep 105: train_loss: 0.00233, val_loss: 0.00117


ep: 106, train loss=0.002,lr=0.00026: 100%|██████████| 1055/1055 [00:19<00:00, 53.46it/s]


ep 106: train_loss: 0.00222, val_loss: 0.00105


ep: 107, train loss=0.001,lr=0.00026: 100%|██████████| 1055/1055 [00:19<00:00, 53.67it/s]


ep 107: train_loss: 0.00226, val_loss: 0.00112


ep: 108, train loss=0.004,lr=0.00026: 100%|██████████| 1055/1055 [00:19<00:00, 53.47it/s]


ep 108: train_loss: 0.00220, val_loss: 0.00089


ep: 109, train loss=0.001,lr=0.00026: 100%|██████████| 1055/1055 [00:19<00:00, 53.46it/s]


ep 109: train_loss: 0.00228, val_loss: 0.00093


ep: 110, train loss=0.001,lr=0.00026: 100%|██████████| 1055/1055 [00:19<00:00, 53.64it/s]


ep 110: train_loss: 0.00224, val_loss: 0.00093


ep: 111, train loss=0.002,lr=0.00026: 100%|██████████| 1055/1055 [00:19<00:00, 53.63it/s]


ep 111: train_loss: 0.00223, val_loss: 0.00083


ep: 112, train loss=0.001,lr=0.00026: 100%|██████████| 1055/1055 [00:20<00:00, 52.58it/s]


ep 112: train_loss: 0.00225, val_loss: 0.00083


ep: 113, train loss=0.004,lr=0.00025: 100%|██████████| 1055/1055 [00:19<00:00, 52.99it/s]


ep 113: train_loss: 0.00222, val_loss: 0.00115


ep: 114, train loss=0.002,lr=0.00025: 100%|██████████| 1055/1055 [00:19<00:00, 53.51it/s]


ep 114: train_loss: 0.00223, val_loss: 0.00107


ep: 115, train loss=0.003,lr=0.00025: 100%|██████████| 1055/1055 [00:19<00:00, 53.55it/s]


ep 115: train_loss: 0.00214, val_loss: 0.00097


ep: 116, train loss=0.001,lr=0.00025: 100%|██████████| 1055/1055 [00:19<00:00, 53.62it/s]


ep 116: train_loss: 0.00205, val_loss: 0.00111


ep: 117, train loss=0.001,lr=0.00025: 100%|██████████| 1055/1055 [00:19<00:00, 53.23it/s]


ep 117: train_loss: 0.00223, val_loss: 0.00091


ep: 118, train loss=0.002,lr=0.00025: 100%|██████████| 1055/1055 [00:19<00:00, 53.44it/s]


ep 118: train_loss: 0.00217, val_loss: 0.00089


ep: 119, train loss=0.001,lr=0.00025: 100%|██████████| 1055/1055 [00:19<00:00, 53.50it/s]


ep 119: train_loss: 0.00214, val_loss: 0.00073


ep: 120, train loss=0.005,lr=0.00025: 100%|██████████| 1055/1055 [00:19<00:00, 53.10it/s]


ep 120: train_loss: 0.00213, val_loss: 0.00081


ep: 121, train loss=0.008,lr=0.00025: 100%|██████████| 1055/1055 [00:19<00:00, 53.06it/s]


ep 121: train_loss: 0.00203, val_loss: 0.00107


ep: 122, train loss=0.001,lr=0.00025: 100%|██████████| 1055/1055 [00:19<00:00, 53.11it/s]


ep 122: train_loss: 0.00220, val_loss: 0.00074


ep: 123, train loss=0.002,lr=0.00024: 100%|██████████| 1055/1055 [00:20<00:00, 52.31it/s]


ep 123: train_loss: 0.00204, val_loss: 0.00092


ep: 124, train loss=0.002,lr=0.00024: 100%|██████████| 1055/1055 [00:20<00:00, 52.50it/s]


ep 124: train_loss: 0.00197, val_loss: 0.00074


ep: 125, train loss=0.000,lr=0.00024: 100%|██████████| 1055/1055 [00:19<00:00, 52.79it/s]


ep 125: train_loss: 0.00197, val_loss: 0.00075


ep: 126, train loss=0.003,lr=0.00024: 100%|██████████| 1055/1055 [00:20<00:00, 52.40it/s]


ep 126: train_loss: 0.00198, val_loss: 0.00070


ep: 127, train loss=0.002,lr=0.00024: 100%|██████████| 1055/1055 [00:19<00:00, 53.03it/s]


ep 127: train_loss: 0.00185, val_loss: 0.00072


ep: 128, train loss=0.001,lr=0.00024: 100%|██████████| 1055/1055 [00:19<00:00, 52.91it/s]


ep 128: train_loss: 0.00186, val_loss: 0.00060


ep: 129, train loss=0.004,lr=0.00024: 100%|██████████| 1055/1055 [00:20<00:00, 52.40it/s]


ep 129: train_loss: 0.00197, val_loss: 0.00076


ep: 130, train loss=0.002,lr=0.00024: 100%|██████████| 1055/1055 [00:20<00:00, 52.58it/s]


ep 130: train_loss: 0.00189, val_loss: 0.00063


ep: 131, train loss=0.004,lr=0.00024: 100%|██████████| 1055/1055 [00:19<00:00, 52.91it/s]


ep 131: train_loss: 0.00170, val_loss: 0.00070


ep: 132, train loss=0.003,lr=0.00024: 100%|██████████| 1055/1055 [00:20<00:00, 52.68it/s]


ep 132: train_loss: 0.00161, val_loss: 0.00041


ep: 133, train loss=0.004,lr=0.00024: 100%|██████████| 1055/1055 [00:20<00:00, 52.47it/s]


ep 133: train_loss: 0.00158, val_loss: 0.00030


ep: 134, train loss=0.001,lr=0.00023: 100%|██████████| 1055/1055 [00:19<00:00, 52.78it/s]


ep 134: train_loss: 0.00149, val_loss: 0.00027


ep: 135, train loss=0.001,lr=0.00023: 100%|██████████| 1055/1055 [00:20<00:00, 52.27it/s]


ep 135: train_loss: 0.00131, val_loss: 0.00035


ep: 136, train loss=0.000,lr=0.00023: 100%|██████████| 1055/1055 [00:20<00:00, 52.36it/s]


ep 136: train_loss: 0.00115, val_loss: 0.00035


ep: 137, train loss=0.003,lr=0.00023: 100%|██████████| 1055/1055 [00:20<00:00, 52.71it/s]


ep 137: train_loss: 0.00115, val_loss: 0.00024


ep: 138, train loss=0.000,lr=0.00023: 100%|██████████| 1055/1055 [00:20<00:00, 52.60it/s]


ep 138: train_loss: 0.00112, val_loss: 0.00020


ep: 139, train loss=0.001,lr=0.00023: 100%|██████████| 1055/1055 [00:20<00:00, 52.50it/s]


ep 139: train_loss: 0.00106, val_loss: 0.00025


ep: 140, train loss=0.003,lr=0.00023: 100%|██████████| 1055/1055 [00:20<00:00, 52.69it/s]


ep 140: train_loss: 0.00109, val_loss: 0.00022


ep: 141, train loss=0.003,lr=0.00023: 100%|██████████| 1055/1055 [00:20<00:00, 52.67it/s]


ep 141: train_loss: 0.00109, val_loss: 0.00026


ep: 142, train loss=0.000,lr=0.00023: 100%|██████████| 1055/1055 [00:19<00:00, 53.04it/s]


ep 142: train_loss: 0.00097, val_loss: 0.00023


ep: 143, train loss=0.003,lr=0.00023: 100%|██████████| 1055/1055 [00:19<00:00, 53.12it/s]


ep 143: train_loss: 0.00106, val_loss: 0.00015


ep: 144, train loss=0.002,lr=0.00023: 100%|██████████| 1055/1055 [00:19<00:00, 52.99it/s]


ep 144: train_loss: 0.00100, val_loss: 0.00019


ep: 145, train loss=0.000,lr=0.00023: 100%|██████████| 1055/1055 [00:19<00:00, 53.05it/s]


ep 145: train_loss: 0.00093, val_loss: 0.00024


ep: 146, train loss=0.000,lr=0.00022: 100%|██████████| 1055/1055 [00:19<00:00, 53.05it/s]


ep 146: train_loss: 0.00107, val_loss: 0.00026


ep: 147, train loss=0.000,lr=0.00022: 100%|██████████| 1055/1055 [00:19<00:00, 52.90it/s]


ep 147: train_loss: 0.00093, val_loss: 0.00018


ep: 148, train loss=0.001,lr=0.00022: 100%|██████████| 1055/1055 [00:19<00:00, 52.78it/s]


ep 148: train_loss: 0.00098, val_loss: 0.00017


ep: 149, train loss=0.000,lr=0.00022: 100%|██████████| 1055/1055 [00:19<00:00, 53.05it/s]


ep 149: train_loss: 0.00086, val_loss: 0.00023


In [20]:
test_loss = validate(model, data_loaders.test_loader)
print('training set examples the model gives an incorrect result:')
train_acc = evaluate(model, data_loaders.train_loader, 20)
print('validataion set examples the model gives an incorrect result:')
val_acc = evaluate(model, data_loaders.test_loader)
print('test set examples the model gives an incorrect result:')
test_acc = evaluate(model, data_loaders.test_loader)
result = f'''train_size: {train_size}, train_loss: {train_loss},
                val_loss: {val_loss}, test_loss: {test_loss},
                test_acc: {test_acc}, val_acc: {val_acc}, train_acc: {train_acc}
                '''
print(result)

  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))


training set examples the model gives an incorrect result:
correct equation: 5986250995+1651849091=7638100086
predicted:        5986250995+1651849091=7638000086
correct equation: 6220698785+3774201250=9994900035
predicted:        6220698785+3774201250=9994800035
correct equation: 1479270551+2482219449=3961490000
predicted:        1479270551+2482219449=3961480000
validataion set examples the model gives an incorrect result:
correct equation: 1429344373+4270655709=5700000082
predicted:        1429344373+4270655709=5699000082
correct equation: 2974793598+0568206458=3543000056
predicted:        2974793598+0568206458=3543900056
correct equation: 1991526364+9108473724=11100000088
predicted:        1991526364+9108473724=11000000088
correct equation: 5446176784+3754473218=9200650002
predicted:        5446176784+3754473218=9200640002
correct equation: 8134374390+9895622045=18029996435
predicted:        8134374390+9895622045=18039996435
test set examples the model gives an incorrect result:
corr

minute single digit variation was observed

# 18 Digit Addition

In [None]:
max_ndigits = 18
# max_len is determined by 1+ max_ndigits + 1 + max_ndigits + 1 + max_ndigits +1 +1
max_len = 3*max_ndigits + 6
config = ModelConfig(decoder_vocab_size= vocab_size,
                     d_embed=128,
                     d_ff=256,
                     h=4,
                     N_decoder=2,
                     max_seq_len= max_len,
                     dropout=0.1)
dataset_size = 400000
data_loaders = DataLoaders(max_ndigits, dataset_size, padQ=True)
data_loaders.split_data(split=[10000, 20000])
train_size = len(data_loaders.train_loader)*batch_size
model = make_GPT(config)
model_size = sum([p.numel() for p in model.parameters()])
print(f'model_size: {model_size}, train_set_size: {train_size}')
warmup_steps = 5*len(data_loaders.train_loader)
# lr first increases in the warmup steps, and then decreases
lr_fn = lambda step: config.d_embed**(-0.5) * min([(step+1)**(-0.5), (step+1)*warmup_steps**(-1.5)])
optimizer = torch.optim.Adam(model.parameters(), lr=1, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_fn)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD)

train_loss, val_loss = train(model, data_loaders, epochs=150)

model_size: 277522, train_set_size: 370176


  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
ep: 0, train loss=2.081,lr=0.00020: 100%|██████████| 1446/1446 [00:36<00:00, 39.79it/s]
  trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))


ep 0: train_loss: 2.28385, val_loss: 2.05867


ep: 1, train loss=2.011,lr=0.00041: 100%|██████████| 1446/1446 [00:36<00:00, 39.99it/s]


ep 1: train_loss: 2.03676, val_loss: 1.99809


ep: 2, train loss=1.727,lr=0.00062: 100%|██████████| 1446/1446 [00:36<00:00, 39.97it/s]


ep 2: train_loss: 1.88767, val_loss: 1.63243


ep: 3, train loss=0.289,lr=0.00083: 100%|██████████| 1446/1446 [00:36<00:00, 39.86it/s]


ep 3: train_loss: 0.88012, val_loss: 0.17052


ep: 4, train loss=0.094,lr=0.00103: 100%|██████████| 1446/1446 [00:36<00:00, 40.06it/s]


ep 4: train_loss: 0.15666, val_loss: 0.02681


ep: 5, train loss=0.044,lr=0.00095: 100%|██████████| 1446/1446 [00:36<00:00, 39.88it/s]


ep 5: train_loss: 0.06800, val_loss: 0.01589


ep: 6, train loss=0.041,lr=0.00088: 100%|██████████| 1446/1446 [00:36<00:00, 39.80it/s]


ep 6: train_loss: 0.04235, val_loss: 0.01189


ep: 7, train loss=0.024,lr=0.00082: 100%|██████████| 1446/1446 [00:36<00:00, 39.88it/s]


ep 7: train_loss: 0.03190, val_loss: 0.01222


ep: 8, train loss=0.026,lr=0.00078: 100%|██████████| 1446/1446 [00:36<00:00, 39.83it/s]


ep 8: train_loss: 0.02600, val_loss: 0.00905


ep: 9, train loss=0.021,lr=0.00074: 100%|██████████| 1446/1446 [00:36<00:00, 39.94it/s]


ep 9: train_loss: 0.02160, val_loss: 0.00786


ep: 10, train loss=0.013,lr=0.00070: 100%|██████████| 1446/1446 [00:36<00:00, 39.81it/s]


ep 10: train_loss: 0.01825, val_loss: 0.00675


ep: 11, train loss=0.012,lr=0.00067: 100%|██████████| 1446/1446 [00:36<00:00, 39.96it/s]


ep 11: train_loss: 0.01638, val_loss: 0.00654


ep: 12, train loss=0.017,lr=0.00065: 100%|██████████| 1446/1446 [00:36<00:00, 40.03it/s]


ep 12: train_loss: 0.01457, val_loss: 0.00560


ep: 13, train loss=0.013,lr=0.00062: 100%|██████████| 1446/1446 [00:36<00:00, 39.78it/s]


ep 13: train_loss: 0.01326, val_loss: 0.00526


ep: 14, train loss=0.009,lr=0.00060: 100%|██████████| 1446/1446 [00:36<00:00, 40.02it/s]


ep 14: train_loss: 0.01214, val_loss: 0.00467


ep: 15, train loss=0.013,lr=0.00058: 100%|██████████| 1446/1446 [00:36<00:00, 39.85it/s]


ep 15: train_loss: 0.01127, val_loss: 0.00510


ep: 16, train loss=0.012,lr=0.00056: 100%|██████████| 1446/1446 [00:36<00:00, 39.91it/s]


ep 16: train_loss: 0.01075, val_loss: 0.00498


ep: 17, train loss=0.011,lr=0.00055: 100%|██████████| 1446/1446 [00:36<00:00, 40.08it/s]


ep 17: train_loss: 0.01022, val_loss: 0.00466


ep: 18, train loss=0.007,lr=0.00053: 100%|██████████| 1446/1446 [00:36<00:00, 39.90it/s]


ep 18: train_loss: 0.00997, val_loss: 0.00432


ep: 19, train loss=0.006,lr=0.00052: 100%|██████████| 1446/1446 [00:36<00:00, 39.92it/s]


ep 19: train_loss: 0.00947, val_loss: 0.00473


ep: 20, train loss=0.013,lr=0.00051: 100%|██████████| 1446/1446 [00:36<00:00, 39.77it/s]


ep 20: train_loss: 0.00914, val_loss: 0.00437


ep: 21, train loss=0.007,lr=0.00050: 100%|██████████| 1446/1446 [00:36<00:00, 39.89it/s]


ep 21: train_loss: 0.00888, val_loss: 0.00422


ep: 22, train loss=0.006,lr=0.00048: 100%|██████████| 1446/1446 [00:36<00:00, 39.90it/s]


ep 22: train_loss: 0.00853, val_loss: 0.00415


ep: 23, train loss=0.005,lr=0.00047: 100%|██████████| 1446/1446 [00:36<00:00, 39.90it/s]


ep 23: train_loss: 0.00832, val_loss: 0.00401


ep: 24, train loss=0.011,lr=0.00047: 100%|██████████| 1446/1446 [00:36<00:00, 40.11it/s]


ep 24: train_loss: 0.00806, val_loss: 0.00395


ep: 25, train loss=0.010,lr=0.00046: 100%|██████████| 1446/1446 [00:36<00:00, 39.83it/s]


ep 25: train_loss: 0.00787, val_loss: 0.00419


ep: 26, train loss=0.012,lr=0.00045: 100%|██████████| 1446/1446 [00:36<00:00, 40.08it/s]


ep 26: train_loss: 0.00768, val_loss: 0.00398


ep: 27, train loss=0.010,lr=0.00044: 100%|██████████| 1446/1446 [00:36<00:00, 39.91it/s]


ep 27: train_loss: 0.00757, val_loss: 0.00365


ep: 28, train loss=0.010,lr=0.00043: 100%|██████████| 1446/1446 [00:36<00:00, 39.68it/s]


ep 28: train_loss: 0.00734, val_loss: 0.00427


ep: 29, train loss=0.009,lr=0.00042: 100%|██████████| 1446/1446 [00:36<00:00, 39.78it/s]


ep 29: train_loss: 0.00720, val_loss: 0.00360


ep: 30, train loss=0.005,lr=0.00042: 100%|██████████| 1446/1446 [00:36<00:00, 39.86it/s]


ep 30: train_loss: 0.00709, val_loss: 0.00394


ep: 31, train loss=0.007,lr=0.00041: 100%|██████████| 1446/1446 [00:36<00:00, 39.72it/s]


ep 31: train_loss: 0.00697, val_loss: 0.00375


ep: 32, train loss=0.011,lr=0.00040: 100%|██████████| 1446/1446 [00:36<00:00, 39.97it/s]


ep 32: train_loss: 0.00682, val_loss: 0.00359


ep: 33, train loss=0.008,lr=0.00040: 100%|██████████| 1446/1446 [00:36<00:00, 39.85it/s]


ep 33: train_loss: 0.00676, val_loss: 0.00395


ep: 34, train loss=0.010,lr=0.00039: 100%|██████████| 1446/1446 [00:36<00:00, 39.65it/s]


ep 34: train_loss: 0.00662, val_loss: 0.00375


ep: 35, train loss=0.005,lr=0.00039: 100%|██████████| 1446/1446 [00:36<00:00, 39.60it/s]


ep 35: train_loss: 0.00646, val_loss: 0.00386


ep: 36, train loss=0.006,lr=0.00038: 100%|██████████| 1446/1446 [00:36<00:00, 39.42it/s]


ep 36: train_loss: 0.00639, val_loss: 0.00380


ep: 37, train loss=0.006,lr=0.00038: 100%|██████████| 1446/1446 [00:36<00:00, 39.52it/s]


ep 37: train_loss: 0.00637, val_loss: 0.00376


ep: 38, train loss=0.008,lr=0.00037: 100%|██████████| 1446/1446 [00:36<00:00, 39.38it/s]


ep 38: train_loss: 0.00616, val_loss: 0.00319


ep: 39, train loss=0.008,lr=0.00037: 100%|██████████| 1446/1446 [00:36<00:00, 39.61it/s]


ep 39: train_loss: 0.00602, val_loss: 0.00332


ep: 40, train loss=0.004,lr=0.00036: 100%|██████████| 1446/1446 [00:36<00:00, 39.49it/s]


ep 40: train_loss: 0.00594, val_loss: 0.00345


ep: 41, train loss=0.007,lr=0.00036: 100%|██████████| 1446/1446 [00:36<00:00, 39.53it/s]


ep 41: train_loss: 0.00586, val_loss: 0.00301


ep: 42, train loss=0.003,lr=0.00035: 100%|██████████| 1446/1446 [00:36<00:00, 39.64it/s]


ep 42: train_loss: 0.00595, val_loss: 0.00375


ep: 43, train loss=0.003,lr=0.00035: 100%|██████████| 1446/1446 [00:36<00:00, 39.73it/s]


ep 43: train_loss: 0.00576, val_loss: 0.00375


ep: 44, train loss=0.004,lr=0.00035: 100%|██████████| 1446/1446 [00:36<00:00, 39.76it/s]


ep 44: train_loss: 0.00570, val_loss: 0.00304


ep: 45, train loss=0.007,lr=0.00034: 100%|██████████| 1446/1446 [00:35<00:00, 40.41it/s]


ep 45: train_loss: 0.00559, val_loss: 0.00341


ep: 46, train loss=0.005,lr=0.00034: 100%|██████████| 1446/1446 [00:35<00:00, 40.47it/s]


ep 46: train_loss: 0.00555, val_loss: 0.00313


ep: 47, train loss=0.005,lr=0.00034: 100%|██████████| 1446/1446 [00:35<00:00, 40.41it/s]


ep 47: train_loss: 0.00552, val_loss: 0.00333


ep: 48, train loss=0.008,lr=0.00033: 100%|██████████| 1446/1446 [00:35<00:00, 40.50it/s]


ep 48: train_loss: 0.00547, val_loss: 0.00290


ep: 49, train loss=0.003,lr=0.00033: 100%|██████████| 1446/1446 [00:35<00:00, 40.47it/s]


ep 49: train_loss: 0.00538, val_loss: 0.00292


ep: 50, train loss=0.005,lr=0.00033: 100%|██████████| 1446/1446 [00:35<00:00, 40.50it/s]


ep 50: train_loss: 0.00533, val_loss: 0.00335


ep: 51, train loss=0.005,lr=0.00032: 100%|██████████| 1446/1446 [00:35<00:00, 40.47it/s]


ep 51: train_loss: 0.00536, val_loss: 0.00364


ep: 52, train loss=0.003,lr=0.00032: 100%|██████████| 1446/1446 [00:36<00:00, 40.09it/s]


ep 52: train_loss: 0.00527, val_loss: 0.00297


ep: 53, train loss=0.005,lr=0.00032: 100%|██████████| 1446/1446 [00:36<00:00, 39.83it/s]


ep 53: train_loss: 0.00518, val_loss: 0.00319


ep: 54, train loss=0.009,lr=0.00031: 100%|██████████| 1446/1446 [00:36<00:00, 39.92it/s]


ep 54: train_loss: 0.00514, val_loss: 0.00465


ep: 55, train loss=0.004,lr=0.00031: 100%|██████████| 1446/1446 [00:36<00:00, 40.16it/s]


ep 55: train_loss: 0.00511, val_loss: 0.00311


ep: 56, train loss=0.006,lr=0.00031: 100%|██████████| 1446/1446 [00:36<00:00, 39.95it/s]


ep 56: train_loss: 0.00505, val_loss: 0.00305


ep: 57, train loss=0.006,lr=0.00031: 100%|██████████| 1446/1446 [00:36<00:00, 40.06it/s]


ep 57: train_loss: 0.00500, val_loss: 0.00412


ep: 58, train loss=0.004,lr=0.00030: 100%|██████████| 1446/1446 [00:36<00:00, 40.07it/s]


ep 58: train_loss: 0.00504, val_loss: 0.00328


ep: 59, train loss=0.003,lr=0.00030: 100%|██████████| 1446/1446 [00:36<00:00, 39.92it/s]


ep 59: train_loss: 0.00493, val_loss: 0.00281


ep: 60, train loss=0.003,lr=0.00030:  50%|█████     | 726/1446 [00:18<00:18, 39.52it/s]