# Train a decoder-only transformer (GPT-like) to do addition

In this notebook I train a GPT model to do n-digit addition.

I learned a lot from https://github.com/karpathy/minGPT, but I have rewritten all the code based on my own understanding.


In [2]:
import math
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data.dataset import Dataset
from tqdm import tqdm

seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


## Data processing
### Generate an addition dataset

In [3]:
PLUS_SIGN = 10
MUL_SIGN  = 11
MINUS_SIGN = 12
EQUAL_SIGN = 13
EOS = 14
BOS = 15
PAD = 16
UNK = 17

symbol_to_int_dict = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4,
                      "5": 5, "6": 6, "7": 7, "8": 8, "9": 9,
                      "+": PLUS_SIGN, "*": MUL_SIGN, "-": MINUS_SIGN,
                      "=": EQUAL_SIGN,  "<EOS>": EOS, "<BOS>": BOS,
                      "<pad>": PAD, "??": UNK
                      }

int_to_symbol_dict = {y:x for (x,y) in symbol_to_int_dict.items()}
vocab_size = len(symbol_to_int_dict)

def decode_equation(equation):
    '''convert an equation in list format to string format '''
    res = "".join([str(int_to_symbol_dict.get(x, UNK)) for x in equation.tolist()])
    return res.replace("<BOS>", "").replace("<EOS>", "")

def encode_equation(equation, max_ndigits, padQ=True):
    '''convert an equation(up to the equal sign in it) in string format to a list'''
    equal_size_loc = equation.index('=')
    plus_size_loc = equation.index('+')
    num1 = pad_number(equation[0:plus_size_loc], max_ndigits)
    num2 = pad_number(equation[plus_size_loc+1:equal_size_loc], max_ndigits)
    new_equation = num1 + "+" + num2 + "="
    return torch.tensor([BOS]+[symbol_to_int_dict.get(n, UNK) for n in new_equation]).to(DEVICE)


def pad_number(num, max_ndigits)->str:
    'pad numbers with zeros in front so that they have the same length max_ndigits'
    s = str(num)
    while len(s)<max_ndigits:
      s = "0"+s
    return s

def create_add_dataset(max_ndigits, dataset_size, padQ=True):
    ''' Function for creating an addition dataset.
    if padQ=True, pre-padding of 0s will be added on the numbers such that all the 
    numbers has the same length max_ndigits, for example, with max_ndigits=3,  
    32 will be represented 032.
    '''
    dataset_str = []
    for i in range(dataset_size):
        num1, num2 = np.random.randint(0, 10**max_ndigits, 2)
        ans = num1 + num2
        # If padQ=True, we pad all the numbers with '0' in front
        # such that they all have length max_ndigits
        if padQ:
            equation = pad_number(num1, max_ndigits) + '+' + pad_number(num2, max_ndigits) + "=" + pad_number(ans, max_ndigits)
        else:
            equation = str(num1) + '+' + str(num2) + "=" + str(ans)
        dataset_str.append(equation)

    dataset = [torch.tensor([BOS]+[symbol_to_int_dict.get(n, UNK) for n in x]+[EOS])
               for x in dataset_str]
    return dataset, dataset_str

print(create_add_dataset(2, 4, padQ=False))
print(create_add_dataset(2, 4, padQ=True))

([tensor([15,  5,  1, 10,  9,  2, 13,  1,  4,  3, 14]), tensor([15,  1,  4, 10,  7,  1, 13,  8,  5, 14]), tensor([15,  6,  0, 10,  2,  0, 13,  8,  0, 14]), tensor([15,  8,  2, 10,  8,  6, 13,  1,  6,  8, 14])], ['51+92=143', '14+71=85', '60+20=80', '82+86=168'])
([tensor([15,  7,  4, 10,  7,  4, 13,  1,  4,  8, 14]), tensor([15,  8,  7, 10,  9,  9, 13,  1,  8,  6, 14]), tensor([15,  2,  3, 10,  0,  2, 13,  2,  5, 14]), tensor([15,  2,  1, 10,  5,  2, 13,  7,  3, 14])], ['74+74=148', '87+99=186', '23+02=25', '21+52=73'])


### Create dataloders for the train, validation and test sets

In [4]:
class TranslationDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

batch_size = 256

def pad_sequence(batch):
    input_padded = torch.nn.utils.rnn.pad_sequence(batch,
                                batch_first=True, padding_value = PAD)
    return input_padded

@dataclass
class DataLoaders:
    max_ndigits: int
    dataset_size: int
    padQ: bool = True
    val_loader = None
    test_loader = None
    train_loader = None

    def split_data(self, split=[0.7, 0.1, 0.2]):
        # If split consists of floats whose sum is equal to 1 then we split the
        # dataset by the percentages given by split. If split contains integers, then
        # it is understood that the two first integers in split are the number of examples
        # in the validation and test sets.
        if isinstance(split[0], float):
            train_size  = round(self.dataset_size * split[0])
            val_size = round(self.dataset_size * split[1])
            test_size = self.dataset_size - train_size - val_size

        elif isinstance(split[0], int):
            val_size = split[0]
            test_size = split[1]
            train_size  = dataset_size - test_size - val_size


        dataset, _ = create_add_dataset(self.max_ndigits, self.dataset_size, padQ=self.padQ)
        train_set, val_set, test_set = torch.utils.data.random_split(dataset,
                                                             [train_size, val_size, test_size],
                                                    generator=torch.Generator().manual_seed(42) )

        self.train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                           shuffle=True, collate_fn = pad_sequence)
        self.test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,
                                           shuffle=True, collate_fn=pad_sequence)
        self.val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size,
                                           shuffle=True, collate_fn=pad_sequence)


## GPT model
Here is my implementation of the GPT model, including the multi-headed self-attention module.

In [5]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_embed, dropout=0.0):
        super(MultiHeadedAttention, self).__init__()
        assert d_embed % h == 0 # check the h number
        self.d_k = d_embed//h
        self.d_embed = d_embed
        self.h = h
        self.WQ = nn.Linear(d_embed, d_embed)
        self.WK = nn.Linear(d_embed, d_embed)
        self.WV = nn.Linear(d_embed, d_embed)
        self.linear = nn.Linear(d_embed, d_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_query, x_key, x_value, mask=None):
        nbatch = x_query.size(0) # get batch size
        # 1) Linear projections to get the multi-head query, key and value tensors
        # x_query, x_key, x_value dimension: nbatch * seq_len * d_embed
        # LHS query, key, value dimensions: nbatch * h * seq_len * d_k
        query = self.WQ(x_query).view(nbatch, -1, self.h, self.d_k).transpose(1,2)
        key   = self.WK(x_key).view(nbatch, -1, self.h, self.d_k).transpose(1,2)
        value = self.WV(x_value).view(nbatch, -1, self.h, self.d_k).transpose(1,2)
        # 2) Attention
        # scores has dimensions: nbatch * h * seq_len * seq_len
        scores = torch.matmul(query, key.transpose(-2, -1))/math.sqrt(self.d_k)
        # 3) Mask out padding tokens and future tokens
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        # p_atten dimensions: nbatch * h * seq_len * seq_len
        p_atten = torch.nn.functional.softmax(scores, dim=-1)
        p_atten = self.dropout(p_atten)
        # x dimensions: nbatch * h * seq_len * d_k
        x = torch.matmul(p_atten, value)
        # x now has dimensions:nbtach * seq_len * d_embed
        x = x.transpose(1, 2).contiguous().view(nbatch, -1, self.d_embed)
        return self.linear(x) # final linear layer


class ResidualConnection(nn.Module):
  '''residual connection: x + dropout(sublayer(layernorm(x))) '''
  def __init__(self, dim, dropout):
      super().__init__()
      self.drop = nn.Dropout(dropout)
      self.norm = nn.LayerNorm(dim)

  def forward(self, x, sublayer):
      return x + self.drop(sublayer(self.norm(x)))



class Decoder(nn.Module):
    '''Decoder = token embedding + positional embedding -> a stack of N DecoderBlock -> fully-connected layer'''
    def __init__(self, config):
        super().__init__()
        self.d_embed = config.d_embed
        self.tok_embed = nn.Embedding(config.decoder_vocab_size, config.d_embed)
        self.pos_embed = nn.Parameter(torch.zeros(1, config.max_seq_len, config.d_embed))
        self.dropout = nn.Dropout(config.dropout)
        self.decoder_blocks = nn.ModuleList([DecoderBlock(config) for _ in range(config.N_decoder)])
        self.norm = nn.LayerNorm(config.d_embed)
        self.linear = nn.Linear(config.d_embed, config.decoder_vocab_size)

    def future_mask(self, seq_len):
        '''mask out tokens at future positions'''
        mask = (torch.triu(torch.ones(seq_len, seq_len, requires_grad=False), diagonal=1)!=0).to(DEVICE)
        return mask.view(1, 1, seq_len, seq_len)

    def forward(self, input, pad_mask):
        seq_len = input.size(1)
        trg_mask = torch.logical_or(pad_mask, self.future_mask(seq_len))
        x = self.tok_embed(input) + self.pos_embed[:, :input.size(1), :]
        x = self.dropout(x)
        for layer in self.decoder_blocks:
            x = layer( x, trg_mask)
        x = self.norm(x)
        logits = self.linear(x)
        return logits


class DecoderBlock(nn.Module):
    ''' EncoderBlock: self-attention -> position-wise feed-forward (fully connected) layer'''
    def __init__(self, config):
        super().__init__()
        self.atten = MultiHeadedAttention(config.h, config.d_embed)
        self.feed_forward = nn.Sequential(
            nn.Linear(config.d_embed, config.d_ff),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_ff, config.d_embed)
        )
        self.residuals = nn.ModuleList([ResidualConnection(config.d_embed, config.dropout)
                                       for i in range(2)])

    def forward(self, decoder_layer_input, trg_mask):
        y = decoder_layer_input
        y = self.residuals[0](y, lambda y: self.atten(y, y, y, mask=trg_mask))
        return self.residuals[1](y, self.feed_forward)


### Let's creat a GPT!

In [6]:
@dataclass
class ModelConfig:
    decoder_vocab_size: int
    d_embed: int
    # d_ff is the dimension of the fully-connected  feed-forward layer
    d_ff: int
    # h is the number of attention head
    h: int
    N_decoder: int
    max_seq_len: int
    dropout: float


def make_GPT(config):
    model = Decoder(config).to(DEVICE)
    # initialize model parameters
    # it seems that this initialization is very important!
    for p in model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    return model

## Functions for training and input/output processing

In [7]:
def make_batch_input(x):
        'function for generating model input, target and pad_mask from raw input x'
        input = x[:, :-1].to(DEVICE)
        equal_sign_loc = [(equation==EQUAL_SIGN).nonzero().item() for equation in x]
        # for the target, we mask out the tokens before the equal sign (including the equal sign)
        target = [torch.cat(
            (torch.tensor([PAD]*equal_sign_loc[i]), x[i][equal_sign_loc[i]+1:])) for i in range(len(x))]
        target = torch.cat(target, 0).contiguous().view(-1).to(DEVICE)
        pad_mask = (input == PAD).view(input.size(0), 1, 1, input.size(-1))
        return input, target, pad_mask

In [8]:
def train_epoch(model, dataloader):
    model.train()
    grad_norm_clip = 1.0
    losses, acc, count = [], 0, 0
    num_batches = len(dataloader)
    pbar = tqdm(enumerate(dataloader), total=num_batches)
    for idx, x  in  pbar:
        optimizer.zero_grad()
        input, target, pad_mask = make_batch_input(x)
        pred = model(input, pad_mask).to(DEVICE)
        pred = pred.view(-1, pred.size(-1))
        loss = loss_fn(pred, target).to(DEVICE)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm_clip)
        optimizer.step()
        scheduler.step()
        losses.append(loss.item())
        # report progress
        if idx>0 and idx%50 == 0:
            pbar.set_description(f"ep: {scheduler.last_epoch//num_batches}, train loss={loss.item():.3f},lr={scheduler.get_last_lr()[0]:.5f}")
    return np.mean(losses)

def train(model, dataloaders, epochs):
    global early_stop_count
    train_size = len(dataloaders.train_loader)*batch_size
    for ep in range(epochs):
        train_loss = train_epoch(model, dataloaders.train_loader)
        val_loss = validate(model, dataloaders.val_loader)
        print(f'ep {ep}: train_loss: {train_loss:.5f}, val_loss: {val_loss:.5f}')

    return train_loss, val_loss


def validate(model, dataloder):
    'function for computing the loss on the validation set'
    model.eval()
    losses = []
    with torch.no_grad():
        for i, x in enumerate(dataloder):
            input, target, pad_mask = make_batch_input(x)
            pred = model(input, pad_mask).to(DEVICE)
            pred = pred.view(-1, pred.size(-1))
            losses.append(loss_fn(pred, target).item())
    return np.mean(losses)

In [12]:
@torch.no_grad()
def compute_sum(model, x):
    'Function for computing the sum of two numbers.'
    for i in range(max_ndigits+2):
        pad_mask = (x == PAD).view(1, 1, 1, x.size(-1)).to(DEVICE)
        logits = model(x, pad_mask)
        last_output = logits.argmax(-1)[:,-1].view(1,1)
        x = torch.cat((x, last_output), 1).to(DEVICE)
        if last_output.item() == EOS:
            break
    return x[0]

def evaluate(model, dataloader, num_batch=None):
    '''Function for evaluation the model.
    This function take equations, and truncate them up to the equal-sign, and feed them to the
    model to get the predictions, compare them with the correct answers, and output the accuracy.
    '''
    model.eval()
    acc, count = 0, 0
    num_wrong_to_display = 5
    for idx, x in enumerate(dataloader):
        for equation in x:
            loc_equal_sign = equation.tolist().index(EQUAL_SIGN)
            loc_EOS = equation.tolist().index(EOS)
            input = equation[0:loc_equal_sign+1].view(1, -1).to(DEVICE)
            ans = equation[:loc_EOS+1].tolist()
            ans_pred = compute_sum(model, input)
            count += 1

            if ans == ans_pred.tolist():
                acc +=1
            else:
                if num_wrong_to_display > 0:
                    print(f'correct equation: {decode_equation(equation).replace("<pad>","")}')
                    print(f'predicted:        {decode_equation(ans_pred)}')
                    num_wrong_to_display -= 1
        if num_batch and idx>num_batch:
            break
    return acc/count

def what_is(question:str)->str:
    'function for computing the sum of two numbers with input in literal string format'
    pred = compute_sum(model, encode_equation(question, max_ndigits).view(1,-1))
    pred = decode_equation(pred)
    pred = pred[pred.index("=")+1:]
    return question+pred


## 2-digit addition

In [None]:
max_ndigits = 2
# max_len is determined by 1+ max_ndigits + 1 + max_ndigits + 1 + max_ndigits + 1 + 1
# where the 1s represent BOS, Plus sign, Equal sign, the extra digit in the sum, EOS, respectively.
max_len = 3*max_ndigits + 6
config = ModelConfig(decoder_vocab_size= vocab_size,
                     d_embed=128,
                     d_ff=256,
                     h=4,
                     N_decoder=2,
                     max_seq_len= max_len,
                     dropout=0.1)
dataset_size = 10000
data_loaders = DataLoaders(max_ndigits, dataset_size, padQ=True)
data_loaders.split_data(split=[1000, 2000])
train_size = len(data_loaders.train_loader)*batch_size
model = make_GPT(config)
model_size = sum([p.numel() for p in model.parameters()])
print(f'model_size: {model_size}, train_set_size: {train_size}')
warmup_steps = 3*len(data_loaders.train_loader)
# lr first increases in the warmup steps, and then descreases
lr_fn = lambda step: config.d_embed**(-0.5) * min([(step+1)**(-0.5), (step+1)*warmup_steps**(-1.5)])
optimizer = torch.optim.Adam(model.parameters(), lr=0.2, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_fn)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD)

train_loss, val_loss = train(model, data_loaders, epochs=30)


model_size: 271378, train_set_size: 7168


100%|██████████| 28/28 [00:00<00:00, 45.02it/s]


ep 0: train_loss: 2.33063, val_loss: 1.71811


100%|██████████| 28/28 [00:00<00:00, 45.65it/s]


ep 1: train_loss: 1.50895, val_loss: 1.26140


100%|██████████| 28/28 [00:00<00:00, 45.92it/s]


ep 2: train_loss: 1.20016, val_loss: 1.02070


100%|██████████| 28/28 [00:00<00:00, 44.94it/s]


ep 3: train_loss: 1.00823, val_loss: 0.92573


100%|██████████| 28/28 [00:00<00:00, 45.58it/s]


ep 4: train_loss: 0.91627, val_loss: 0.82485


100%|██████████| 28/28 [00:00<00:00, 45.32it/s]


ep 5: train_loss: 0.79150, val_loss: 0.60109


100%|██████████| 28/28 [00:00<00:00, 46.29it/s]


ep 6: train_loss: 0.57075, val_loss: 0.37401


100%|██████████| 28/28 [00:00<00:00, 43.61it/s]


ep 7: train_loss: 0.40913, val_loss: 0.25278


100%|██████████| 28/28 [00:00<00:00, 44.68it/s]


ep 8: train_loss: 0.31168, val_loss: 0.15657


100%|██████████| 28/28 [00:00<00:00, 45.71it/s]


ep 9: train_loss: 0.23823, val_loss: 0.10147


100%|██████████| 28/28 [00:00<00:00, 46.04it/s]


ep 10: train_loss: 0.18537, val_loss: 0.07394


100%|██████████| 28/28 [00:00<00:00, 45.64it/s]


ep 11: train_loss: 0.14428, val_loss: 0.03609


100%|██████████| 28/28 [00:00<00:00, 45.77it/s]


ep 12: train_loss: 0.11255, val_loss: 0.03152


100%|██████████| 28/28 [00:00<00:00, 44.90it/s]


ep 13: train_loss: 0.09997, val_loss: 0.01707


100%|██████████| 28/28 [00:00<00:00, 44.80it/s]


ep 14: train_loss: 0.08343, val_loss: 0.01232


100%|██████████| 28/28 [00:00<00:00, 45.10it/s]


ep 15: train_loss: 0.06938, val_loss: 0.00635


100%|██████████| 28/28 [00:00<00:00, 44.99it/s]


ep 16: train_loss: 0.06354, val_loss: 0.00647


100%|██████████| 28/28 [00:00<00:00, 45.02it/s]


ep 17: train_loss: 0.05574, val_loss: 0.00519


100%|██████████| 28/28 [00:00<00:00, 46.62it/s]


ep 18: train_loss: 0.04978, val_loss: 0.00308


100%|██████████| 28/28 [00:00<00:00, 45.97it/s]


ep 19: train_loss: 0.04401, val_loss: 0.00269


100%|██████████| 28/28 [00:00<00:00, 46.62it/s]


ep 20: train_loss: 0.03893, val_loss: 0.00346


100%|██████████| 28/28 [00:00<00:00, 43.98it/s]


ep 21: train_loss: 0.03643, val_loss: 0.00188


100%|██████████| 28/28 [00:00<00:00, 45.72it/s]


ep 22: train_loss: 0.02607, val_loss: 0.00110


100%|██████████| 28/28 [00:00<00:00, 45.62it/s]


ep 23: train_loss: 0.02838, val_loss: 0.00131


100%|██████████| 28/28 [00:00<00:00, 45.68it/s]


ep 24: train_loss: 0.02957, val_loss: 0.00150


100%|██████████| 28/28 [00:00<00:00, 45.86it/s]


ep 25: train_loss: 0.02499, val_loss: 0.00124


100%|██████████| 28/28 [00:00<00:00, 46.09it/s]


ep 26: train_loss: 0.02530, val_loss: 0.00088


100%|██████████| 28/28 [00:00<00:00, 45.55it/s]


ep 27: train_loss: 0.02412, val_loss: 0.00053


100%|██████████| 28/28 [00:00<00:00, 46.27it/s]


ep 28: train_loss: 0.02042, val_loss: 0.00068


100%|██████████| 28/28 [00:00<00:00, 45.31it/s]


ep 29: train_loss: 0.01903, val_loss: 0.00057


In [None]:
test_loss = validate(model, data_loaders.test_loader)
print('training set examples the model gives an incorrect result:')
train_acc = evaluate(model, data_loaders.train_loader, 20)
print('validataion set examples the model gives an incorrect result:')
val_acc = evaluate(model, data_loaders.test_loader)
print('test set examples the model gives an incorrect result:')
test_acc = evaluate(model, data_loaders.test_loader)
result = f'''train_size: {train_size}, train_loss: {train_loss},
                val_loss: {val_loss}, test_loss: {test_loss},
                test_acc: {test_acc}, val_acc: {val_acc}, train_acc: {train_acc}
                '''
print(result)

training set examples the model gives an incorrect result:
validataion set examples the model gives an incorrect result:
test set examples the model gives an incorrect result:
train_size: 7168, train_loss: 0.019029294240421483,
                val_loss: 0.0005657454748870805, test_loss: 0.0006502777905552648,
                test_acc: 1.0, val_acc: 1.0, train_acc: 1.0
                


No wrong example shown. The model got all the 2-digit addition right.

## 5-digit addition 
<!-- and scaling laws
For 5-digit addition, there are 10<sup>10</sup> possible data points, so we will have enough data to study the scaling laws. For example, we can study how the performance of the model (with fixed number of parameters) improves as we increase the training set size.   -->

In [None]:
max_ndigits = 5
# max_len is determined by 1+ max_ndigits + 1 + max_ndigits + 1 + max_ndigits +1 +1
max_len = 3*max_ndigits + 6
config = ModelConfig(decoder_vocab_size= vocab_size,
                     d_embed=128,
                     d_ff=256,
                     h=4,
                     N_decoder=2,
                     max_seq_len= max_len,
                     dropout=0.1)

dataset_size = 200000
data_loaders = DataLoaders(max_ndigits, dataset_size, padQ=True)
data_loaders.split_data(split=[10000, 20000])
train_size = len(data_loaders.train_loader)*batch_size
model = make_GPT(config)
model_size = sum([p.numel() for p in model.parameters()])
print(f'model_size: {model_size}, train_set_size: {train_size}')
warmup_steps = 5*len(data_loaders.train_loader)
# lr first increases in the warmup steps, and then descreases
lr_fn = lambda step: config.d_embed**(-0.5) * min([(step+1)**(-0.5), (step+1)*warmup_steps**(-1.5)])
optimizer = torch.optim.Adam(model.parameters(), lr=1, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_fn)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD)

train_loss, val_loss = train(model, data_loaders, epochs=30)


model_size: 272530, train_set_size: 170240


ep: 0, train loss=1.600,lr=0.00030: 100%|██████████| 665/665 [00:16<00:00, 41.38it/s]


ep 0: train_loss: 2.04759, val_loss: 1.55023


ep: 1, train loss=1.195,lr=0.00061: 100%|██████████| 665/665 [00:16<00:00, 40.65it/s]


ep 1: train_loss: 1.42975, val_loss: 1.11064


ep: 2, train loss=0.497,lr=0.00091: 100%|██████████| 665/665 [00:16<00:00, 41.47it/s]


ep 2: train_loss: 0.84462, val_loss: 0.33478


ep: 3, train loss=0.211,lr=0.00122: 100%|██████████| 665/665 [00:16<00:00, 41.47it/s]


ep 3: train_loss: 0.31756, val_loss: 0.15157


ep: 4, train loss=0.108,lr=0.00153: 100%|██████████| 665/665 [00:16<00:00, 41.13it/s]


ep 4: train_loss: 0.16307, val_loss: 0.06498


ep: 5, train loss=0.057,lr=0.00140: 100%|██████████| 665/665 [00:16<00:00, 41.38it/s]


ep 5: train_loss: 0.09666, val_loss: 0.03029


ep: 6, train loss=0.050,lr=0.00130: 100%|██████████| 665/665 [00:16<00:00, 41.51it/s]


ep 6: train_loss: 0.06794, val_loss: 0.03451


ep: 7, train loss=0.055,lr=0.00121: 100%|██████████| 665/665 [00:15<00:00, 41.71it/s]


ep 7: train_loss: 0.05177, val_loss: 0.01648


ep: 8, train loss=0.052,lr=0.00114: 100%|██████████| 665/665 [00:16<00:00, 41.33it/s]


ep 8: train_loss: 0.04194, val_loss: 0.01408


ep: 9, train loss=0.040,lr=0.00108: 100%|██████████| 665/665 [00:16<00:00, 41.56it/s]


ep 9: train_loss: 0.03560, val_loss: 0.01517


ep: 10, train loss=0.036,lr=0.00103: 100%|██████████| 665/665 [00:15<00:00, 41.76it/s]


ep 10: train_loss: 0.03015, val_loss: 0.01296


ep: 11, train loss=0.020,lr=0.00099: 100%|██████████| 665/665 [00:16<00:00, 41.22it/s]


ep 11: train_loss: 0.02715, val_loss: 0.01153


ep: 12, train loss=0.016,lr=0.00095: 100%|██████████| 665/665 [00:16<00:00, 41.11it/s]


ep 12: train_loss: 0.02448, val_loss: 0.01027


ep: 13, train loss=0.025,lr=0.00092: 100%|██████████| 665/665 [00:16<00:00, 41.48it/s]


ep 13: train_loss: 0.02179, val_loss: 0.01033


ep: 14, train loss=0.018,lr=0.00089: 100%|██████████| 665/665 [00:16<00:00, 41.15it/s]


ep 14: train_loss: 0.01895, val_loss: 0.00714


ep: 15, train loss=0.007,lr=0.00086: 100%|██████████| 665/665 [00:16<00:00, 41.11it/s]


ep 15: train_loss: 0.01562, val_loss: 0.00556


ep: 16, train loss=0.007,lr=0.00083: 100%|██████████| 665/665 [00:16<00:00, 41.03it/s]


ep 16: train_loss: 0.01351, val_loss: 0.00463


ep: 17, train loss=0.012,lr=0.00081: 100%|██████████| 665/665 [00:16<00:00, 41.46it/s]


ep 17: train_loss: 0.01164, val_loss: 0.00386


ep: 18, train loss=0.013,lr=0.00079: 100%|██████████| 665/665 [00:16<00:00, 41.40it/s]


ep 18: train_loss: 0.01060, val_loss: 0.00386


ep: 19, train loss=0.008,lr=0.00077: 100%|██████████| 665/665 [00:16<00:00, 41.30it/s]


ep 19: train_loss: 0.00987, val_loss: 0.00347


ep: 20, train loss=0.002,lr=0.00075: 100%|██████████| 665/665 [00:15<00:00, 41.63it/s]


ep 20: train_loss: 0.00926, val_loss: 0.00350


ep: 21, train loss=0.004,lr=0.00073: 100%|██████████| 665/665 [00:16<00:00, 41.19it/s]


ep 21: train_loss: 0.00867, val_loss: 0.00303


ep: 22, train loss=0.006,lr=0.00071: 100%|██████████| 665/665 [00:16<00:00, 41.52it/s]


ep 22: train_loss: 0.00811, val_loss: 0.00279


ep: 23, train loss=0.007,lr=0.00070: 100%|██████████| 665/665 [00:16<00:00, 41.38it/s]


ep 23: train_loss: 0.00764, val_loss: 0.00257


ep: 24, train loss=0.004,lr=0.00069: 100%|██████████| 665/665 [00:16<00:00, 41.24it/s]


ep 24: train_loss: 0.00741, val_loss: 0.00224


ep: 25, train loss=0.010,lr=0.00067: 100%|██████████| 665/665 [00:16<00:00, 41.02it/s]


ep 25: train_loss: 0.00696, val_loss: 0.00216


ep: 26, train loss=0.003,lr=0.00066: 100%|██████████| 665/665 [00:16<00:00, 40.99it/s]


ep 26: train_loss: 0.00677, val_loss: 0.00213


ep: 27, train loss=0.003,lr=0.00065: 100%|██████████| 665/665 [00:16<00:00, 41.39it/s]


ep 27: train_loss: 0.00620, val_loss: 0.00177


ep: 28, train loss=0.006,lr=0.00064: 100%|██████████| 665/665 [00:16<00:00, 41.02it/s]


ep 28: train_loss: 0.00581, val_loss: 0.00175


ep: 29, train loss=0.004,lr=0.00063: 100%|██████████| 665/665 [00:16<00:00, 41.39it/s]


ep 29: train_loss: 0.00580, val_loss: 0.00191


In [None]:
test_loss = validate(model, data_loaders.test_loader)
print('training set examples the model gives an incorrect result:')
train_acc = evaluate(model, data_loaders.train_loader, 20)
print('validataion set examples the model gives an incorrect result:')
val_acc = evaluate(model, data_loaders.test_loader)
print('test set examples the model gives an incorrect result:')
test_acc = evaluate(model, data_loaders.test_loader)
result = f'''train_size: {train_size}, train_loss: {train_loss},
                val_loss: {val_loss}, test_loss: {test_loss},
                test_acc: {test_acc}, val_acc: {val_acc}, train_acc: {train_acc}
                '''
print(result)

training set examples the model gives an incorrect result:
correct equation: 08062+11419=19481
predicted:        08062+11419=19471
correct equation: 00654+18876=19530
predicted:        00654+18876=19520
correct equation: 08899+01130=10029
predicted:        08899+01130=10039
correct equation: 04796+09763=14559
predicted:        04796+09763=14569
correct equation: 05801+07910=13711
predicted:        05801+07910=13721
validataion set examples the model gives an incorrect result:
correct equation: 07726+09569=17295
predicted:        07726+09569=17285
correct equation: 05021+14164=19185
predicted:        05021+14164=19175
correct equation: 09069+04551=13620
predicted:        09069+04551=13610
correct equation: 01863+17203=19066
predicted:        01863+17203=19076
correct equation: 03907+14952=18859
predicted:        03907+14952=18869
test set examples the model gives an incorrect result:
correct equation: 01863+17203=19066
predicted:        01863+17203=19076
correct equation: 09089+09305=18

The model got a very small fraction of the 5-digit addition wrong, and the answers the model gave are mostly one or two digits off. 

# 10-digit addition

In [10]:
max_ndigits = 10
# max_len is determined by 1+ max_ndigits + 1 + max_ndigits + 1 + max_ndigits +1 +1
max_len = 3*max_ndigits + 6
config = ModelConfig(decoder_vocab_size= vocab_size,
                     d_embed=128,
                     d_ff=256,
                     h=4,
                     N_decoder=2,
                     max_seq_len= max_len,
                     dropout=0.1)
dataset_size = 300000
data_loaders = DataLoaders(max_ndigits, dataset_size, padQ=True)
data_loaders.split_data(split=[10000, 20000])
train_size = len(data_loaders.train_loader)*batch_size
model = make_GPT(config)
model_size = sum([p.numel() for p in model.parameters()])
print(f'model_size: {model_size}, train_set_size: {train_size}')
warmup_steps = 5*len(data_loaders.train_loader)
# lr first increases in the warmup steps, and then descreases
lr_fn = lambda step: config.d_embed**(-0.5) * min([(step+1)**(-0.5), (step+1)*warmup_steps**(-1.5)])
optimizer = torch.optim.Adam(model.parameters(), lr=1, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_fn)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD)

train_loss, val_loss = train(model, data_loaders, epochs=50)


model_size: 274450, train_set_size: 270080


ep: 0, train loss=1.908,lr=0.00024: 100%|██████████| 1055/1055 [00:22<00:00, 47.04it/s]


ep 0: train_loss: 2.21616, val_loss: 1.87900


ep: 1, train loss=1.679,lr=0.00049: 100%|██████████| 1055/1055 [00:22<00:00, 47.27it/s]


ep 1: train_loss: 1.81667, val_loss: 1.65460


ep: 2, train loss=0.640,lr=0.00073: 100%|██████████| 1055/1055 [00:22<00:00, 47.31it/s]


ep 2: train_loss: 1.05677, val_loss: 0.51183


ep: 3, train loss=0.391,lr=0.00097: 100%|██████████| 1055/1055 [00:22<00:00, 47.36it/s]


ep 3: train_loss: 0.47011, val_loss: 0.31277


ep: 4, train loss=0.322,lr=0.00122: 100%|██████████| 1055/1055 [00:22<00:00, 47.28it/s]


ep 4: train_loss: 0.35374, val_loss: 0.27067


ep: 5, train loss=0.103,lr=0.00111: 100%|██████████| 1055/1055 [00:22<00:00, 47.45it/s]


ep 5: train_loss: 0.24741, val_loss: 0.05725


ep: 6, train loss=0.041,lr=0.00103: 100%|██████████| 1055/1055 [00:22<00:00, 47.40it/s]


ep 6: train_loss: 0.07131, val_loss: 0.02209


ep: 7, train loss=0.032,lr=0.00096: 100%|██████████| 1055/1055 [00:22<00:00, 47.36it/s]


ep 7: train_loss: 0.04485, val_loss: 0.01444


ep: 8, train loss=0.036,lr=0.00091: 100%|██████████| 1055/1055 [00:22<00:00, 47.20it/s]


ep 8: train_loss: 0.03369, val_loss: 0.01287


ep: 9, train loss=0.019,lr=0.00086: 100%|██████████| 1055/1055 [00:22<00:00, 47.24it/s]


ep 9: train_loss: 0.02758, val_loss: 0.01087


ep: 10, train loss=0.016,lr=0.00082: 100%|██████████| 1055/1055 [00:22<00:00, 47.04it/s]


ep 10: train_loss: 0.02359, val_loss: 0.00935


ep: 11, train loss=0.018,lr=0.00079: 100%|██████████| 1055/1055 [00:22<00:00, 47.18it/s]


ep 11: train_loss: 0.02068, val_loss: 0.00832


ep: 12, train loss=0.018,lr=0.00075: 100%|██████████| 1055/1055 [00:22<00:00, 47.39it/s]


ep 12: train_loss: 0.01868, val_loss: 0.00753


ep: 13, train loss=0.018,lr=0.00073: 100%|██████████| 1055/1055 [00:22<00:00, 47.15it/s]


ep 13: train_loss: 0.01724, val_loss: 0.00702


ep: 14, train loss=0.023,lr=0.00070: 100%|██████████| 1055/1055 [00:22<00:00, 47.14it/s]


ep 14: train_loss: 0.01598, val_loss: 0.00725


ep: 15, train loss=0.014,lr=0.00068: 100%|██████████| 1055/1055 [00:22<00:00, 47.19it/s]


ep 15: train_loss: 0.01453, val_loss: 0.00695


ep: 16, train loss=0.015,lr=0.00066: 100%|██████████| 1055/1055 [00:22<00:00, 47.39it/s]


ep 16: train_loss: 0.01377, val_loss: 0.00677


ep: 17, train loss=0.015,lr=0.00064: 100%|██████████| 1055/1055 [00:22<00:00, 47.22it/s]


ep 17: train_loss: 0.01291, val_loss: 0.00653


ep: 18, train loss=0.008,lr=0.00062: 100%|██████████| 1055/1055 [00:22<00:00, 47.11it/s]


ep 18: train_loss: 0.01233, val_loss: 0.00611


ep: 19, train loss=0.009,lr=0.00061: 100%|██████████| 1055/1055 [00:22<00:00, 47.38it/s]


ep 19: train_loss: 0.01209, val_loss: 0.00643


ep: 20, train loss=0.009,lr=0.00059: 100%|██████████| 1055/1055 [00:22<00:00, 47.15it/s]


ep 20: train_loss: 0.01146, val_loss: 0.00556


ep: 21, train loss=0.015,lr=0.00058: 100%|██████████| 1055/1055 [00:22<00:00, 47.05it/s]


ep 21: train_loss: 0.01100, val_loss: 0.00580


ep: 22, train loss=0.005,lr=0.00057: 100%|██████████| 1055/1055 [00:22<00:00, 46.88it/s]


ep 22: train_loss: 0.01079, val_loss: 0.00531


ep: 23, train loss=0.009,lr=0.00056: 100%|██████████| 1055/1055 [00:22<00:00, 47.06it/s]


ep 23: train_loss: 0.01037, val_loss: 0.00574


ep: 24, train loss=0.011,lr=0.00054: 100%|██████████| 1055/1055 [00:22<00:00, 47.24it/s]


ep 24: train_loss: 0.00996, val_loss: 0.00505


ep: 25, train loss=0.008,lr=0.00053: 100%|██████████| 1055/1055 [00:22<00:00, 47.22it/s]


ep 25: train_loss: 0.00974, val_loss: 0.00512


ep: 26, train loss=0.009,lr=0.00052: 100%|██████████| 1055/1055 [00:22<00:00, 47.22it/s]


ep 26: train_loss: 0.00937, val_loss: 0.00487


ep: 27, train loss=0.008,lr=0.00051: 100%|██████████| 1055/1055 [00:22<00:00, 47.24it/s]


ep 27: train_loss: 0.00915, val_loss: 0.00511


ep: 28, train loss=0.007,lr=0.00051: 100%|██████████| 1055/1055 [00:22<00:00, 47.04it/s]


ep 28: train_loss: 0.00894, val_loss: 0.00481


ep: 29, train loss=0.010,lr=0.00050: 100%|██████████| 1055/1055 [00:22<00:00, 47.14it/s]


ep 29: train_loss: 0.00880, val_loss: 0.00476


ep: 30, train loss=0.006,lr=0.00049: 100%|██████████| 1055/1055 [00:22<00:00, 47.36it/s]


ep 30: train_loss: 0.00858, val_loss: 0.00537


ep: 31, train loss=0.006,lr=0.00048: 100%|██████████| 1055/1055 [00:22<00:00, 47.36it/s]


ep 31: train_loss: 0.00851, val_loss: 0.00455


ep: 32, train loss=0.008,lr=0.00047: 100%|██████████| 1055/1055 [00:22<00:00, 47.20it/s]


ep 32: train_loss: 0.00837, val_loss: 0.00451


ep: 33, train loss=0.009,lr=0.00047: 100%|██████████| 1055/1055 [00:22<00:00, 47.39it/s]


ep 33: train_loss: 0.00805, val_loss: 0.00445


ep: 34, train loss=0.011,lr=0.00046: 100%|██████████| 1055/1055 [00:22<00:00, 47.46it/s]


ep 34: train_loss: 0.00802, val_loss: 0.00424


ep: 35, train loss=0.008,lr=0.00045: 100%|██████████| 1055/1055 [00:22<00:00, 47.21it/s]


ep 35: train_loss: 0.00790, val_loss: 0.00427


ep: 36, train loss=0.013,lr=0.00045: 100%|██████████| 1055/1055 [00:22<00:00, 47.27it/s]


ep 36: train_loss: 0.00770, val_loss: 0.00390


ep: 37, train loss=0.008,lr=0.00044: 100%|██████████| 1055/1055 [00:22<00:00, 47.12it/s]


ep 37: train_loss: 0.00769, val_loss: 0.00397


ep: 38, train loss=0.005,lr=0.00044: 100%|██████████| 1055/1055 [00:22<00:00, 47.45it/s]


ep 38: train_loss: 0.00738, val_loss: 0.00424


ep: 39, train loss=0.007,lr=0.00043: 100%|██████████| 1055/1055 [00:22<00:00, 47.33it/s]


ep 39: train_loss: 0.00733, val_loss: 0.00387


ep: 40, train loss=0.007,lr=0.00043: 100%|██████████| 1055/1055 [00:22<00:00, 47.22it/s]


ep 40: train_loss: 0.00737, val_loss: 0.00409


ep: 41, train loss=0.008,lr=0.00042: 100%|██████████| 1055/1055 [00:22<00:00, 47.67it/s]


ep 41: train_loss: 0.00709, val_loss: 0.00403


ep: 42, train loss=0.007,lr=0.00042: 100%|██████████| 1055/1055 [00:22<00:00, 47.22it/s]


ep 42: train_loss: 0.00704, val_loss: 0.00373


ep: 43, train loss=0.007,lr=0.00041: 100%|██████████| 1055/1055 [00:22<00:00, 47.24it/s]


ep 43: train_loss: 0.00688, val_loss: 0.00379


ep: 44, train loss=0.006,lr=0.00041: 100%|██████████| 1055/1055 [00:22<00:00, 47.13it/s]


ep 44: train_loss: 0.00693, val_loss: 0.00388


ep: 45, train loss=0.005,lr=0.00040: 100%|██████████| 1055/1055 [00:22<00:00, 47.25it/s]


ep 45: train_loss: 0.00690, val_loss: 0.00365


ep: 46, train loss=0.007,lr=0.00040: 100%|██████████| 1055/1055 [00:22<00:00, 47.14it/s]


ep 46: train_loss: 0.00667, val_loss: 0.00426


ep: 47, train loss=0.009,lr=0.00039: 100%|██████████| 1055/1055 [00:22<00:00, 47.20it/s]


ep 47: train_loss: 0.00662, val_loss: 0.00391


ep: 48, train loss=0.008,lr=0.00039: 100%|██████████| 1055/1055 [00:22<00:00, 47.33it/s]


ep 48: train_loss: 0.00658, val_loss: 0.00423


ep: 49, train loss=0.009,lr=0.00038: 100%|██████████| 1055/1055 [00:22<00:00, 47.30it/s]


ep 49: train_loss: 0.00651, val_loss: 0.00356


In [13]:
test_loss = validate(model, data_loaders.test_loader)
print('training set examples the model gives an incorrect result:')
train_acc = evaluate(model, data_loaders.train_loader, 20)
print('validataion set examples the model gives an incorrect result:')
val_acc = evaluate(model, data_loaders.test_loader)
print('test set examples the model gives an incorrect result:')
test_acc = evaluate(model, data_loaders.test_loader)
result = f'''train_size: {train_size}, train_loss: {train_loss},
                val_loss: {val_loss}, test_loss: {test_loss},
                test_acc: {test_acc}, val_acc: {val_acc}, train_acc: {train_acc}
                '''
print(result)

training set examples the model gives an incorrect result:
correct equation: 5588619202+3569680742=9158299944
predicted:        5588619202+3569680742=9158309944
correct equation: 2246603235+4520943775=6767547010
predicted:        2246603235+4520943775=6767546010
correct equation: 0633806046+2337190109=2970996155
predicted:        0633806046+2337190109=2971996155
correct equation: 0400816738+2163010270=2563827008
predicted:        0400816738+2163010270=2563826008
correct equation: 0247470421+0862429184=1109899605
predicted:        0247470421+0862429184=1109999605
validataion set examples the model gives an incorrect result:
correct equation: 4205217864+3893393129=8098610993
predicted:        4205217864+3893393129=8098611993
correct equation: 8752428096+0517770891=9270198987
predicted:        8752428096+0517770891=9270298987
correct equation: 3015532723+6457643298=9473176021
predicted:        3015532723+6457643298=9473175021
correct equation: 0600963084+0001658910=0602621994
predicted:  

# 18-digit addition

In [14]:
max_ndigits = 18
# max_len is determined by 1+ max_ndigits + 1 + max_ndigits + 1 + max_ndigits +1 +1
max_len = 3*max_ndigits + 6
config = ModelConfig(decoder_vocab_size= vocab_size,
                     d_embed=128,
                     d_ff=256,
                     h=4,
                     N_decoder=2,
                     max_seq_len= max_len,
                     dropout=0.1)
dataset_size = 400000
data_loaders = DataLoaders(max_ndigits, dataset_size, padQ=True)
data_loaders.split_data(split=[10000, 20000])
train_size = len(data_loaders.train_loader)*batch_size
model = make_GPT(config)
model_size = sum([p.numel() for p in model.parameters()])
print(f'model_size: {model_size}, train_set_size: {train_size}')
warmup_steps = 5*len(data_loaders.train_loader)
# lr first increases in the warmup steps, and then decreases
lr_fn = lambda step: config.d_embed**(-0.5) * min([(step+1)**(-0.5), (step+1)*warmup_steps**(-1.5)])
optimizer = torch.optim.Adam(model.parameters(), lr=1, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_fn)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD)

train_loss, val_loss = train(model, data_loaders, epochs=50)


model_size: 277522, train_set_size: 370176


ep: 0, train loss=2.096,lr=0.00020: 100%|██████████| 1446/1446 [00:38<00:00, 37.10it/s]


ep 0: train_loss: 2.28885, val_loss: 2.06631


ep: 1, train loss=1.960,lr=0.00041: 100%|██████████| 1446/1446 [00:38<00:00, 37.14it/s]


ep 1: train_loss: 2.03282, val_loss: 1.93960


ep: 2, train loss=1.807,lr=0.00062: 100%|██████████| 1446/1446 [00:38<00:00, 37.21it/s]


ep 2: train_loss: 1.88473, val_loss: 1.79065


ep: 3, train loss=1.537,lr=0.00083: 100%|██████████| 1446/1446 [00:38<00:00, 37.18it/s]


ep 3: train_loss: 1.76677, val_loss: 1.45359


ep: 4, train loss=0.489,lr=0.00103: 100%|██████████| 1446/1446 [00:38<00:00, 37.15it/s]


ep 4: train_loss: 0.82969, val_loss: 0.39867


ep: 5, train loss=0.083,lr=0.00095: 100%|██████████| 1446/1446 [00:38<00:00, 37.19it/s]


ep 5: train_loss: 0.22921, val_loss: 0.03232


ep: 6, train loss=0.049,lr=0.00088: 100%|██████████| 1446/1446 [00:38<00:00, 37.21it/s]


ep 6: train_loss: 0.06458, val_loss: 0.02105


ep: 7, train loss=0.043,lr=0.00082: 100%|██████████| 1446/1446 [00:38<00:00, 37.21it/s]


ep 7: train_loss: 0.04325, val_loss: 0.01776


ep: 8, train loss=0.036,lr=0.00078: 100%|██████████| 1446/1446 [00:38<00:00, 37.13it/s]


ep 8: train_loss: 0.03480, val_loss: 0.01582


ep: 9, train loss=0.020,lr=0.00074: 100%|██████████| 1446/1446 [00:38<00:00, 37.23it/s]


ep 9: train_loss: 0.03010, val_loss: 0.01396


ep: 10, train loss=0.026,lr=0.00070: 100%|██████████| 1446/1446 [00:38<00:00, 37.08it/s]


ep 10: train_loss: 0.02628, val_loss: 0.01287


ep: 11, train loss=0.030,lr=0.00067: 100%|██████████| 1446/1446 [00:39<00:00, 37.04it/s]


ep 11: train_loss: 0.02355, val_loss: 0.01112


ep: 12, train loss=0.020,lr=0.00065: 100%|██████████| 1446/1446 [00:38<00:00, 37.15it/s]


ep 12: train_loss: 0.02127, val_loss: 0.01131


ep: 13, train loss=0.020,lr=0.00062: 100%|██████████| 1446/1446 [00:38<00:00, 37.14it/s]


ep 13: train_loss: 0.01980, val_loss: 0.01057


ep: 14, train loss=0.016,lr=0.00060: 100%|██████████| 1446/1446 [00:38<00:00, 37.11it/s]


ep 14: train_loss: 0.01870, val_loss: 0.01111


ep: 15, train loss=0.017,lr=0.00058: 100%|██████████| 1446/1446 [00:39<00:00, 36.82it/s]


ep 15: train_loss: 0.01762, val_loss: 0.00967


ep: 16, train loss=0.018,lr=0.00056: 100%|██████████| 1446/1446 [00:40<00:00, 35.79it/s]


ep 16: train_loss: 0.01695, val_loss: 0.00929


ep: 17, train loss=0.013,lr=0.00055: 100%|██████████| 1446/1446 [00:39<00:00, 36.46it/s]


ep 17: train_loss: 0.01605, val_loss: 0.00939


ep: 18, train loss=0.015,lr=0.00053: 100%|██████████| 1446/1446 [00:39<00:00, 36.43it/s]


ep 18: train_loss: 0.01547, val_loss: 0.00886


ep: 19, train loss=0.015,lr=0.00052: 100%|██████████| 1446/1446 [00:39<00:00, 36.62it/s]


ep 19: train_loss: 0.01507, val_loss: 0.00905


ep: 20, train loss=0.011,lr=0.00051: 100%|██████████| 1446/1446 [00:39<00:00, 36.71it/s]


ep 20: train_loss: 0.01462, val_loss: 0.00906


ep: 21, train loss=0.017,lr=0.00050: 100%|██████████| 1446/1446 [00:39<00:00, 36.64it/s]


ep 21: train_loss: 0.01410, val_loss: 0.00880


ep: 22, train loss=0.014,lr=0.00048: 100%|██████████| 1446/1446 [00:39<00:00, 36.78it/s]


ep 22: train_loss: 0.01382, val_loss: 0.00877


ep: 23, train loss=0.015,lr=0.00047: 100%|██████████| 1446/1446 [00:39<00:00, 36.86it/s]


ep 23: train_loss: 0.01337, val_loss: 0.00837


ep: 24, train loss=0.014,lr=0.00047: 100%|██████████| 1446/1446 [00:39<00:00, 36.61it/s]


ep 24: train_loss: 0.01300, val_loss: 0.00816


ep: 25, train loss=0.010,lr=0.00046: 100%|██████████| 1446/1446 [00:39<00:00, 36.49it/s]


ep 25: train_loss: 0.01283, val_loss: 0.00851


ep: 26, train loss=0.015,lr=0.00045: 100%|██████████| 1446/1446 [00:39<00:00, 36.37it/s]


ep 26: train_loss: 0.01252, val_loss: 0.00813


ep: 27, train loss=0.011,lr=0.00044: 100%|██████████| 1446/1446 [00:39<00:00, 36.30it/s]


ep 27: train_loss: 0.01193, val_loss: 0.00673


ep: 28, train loss=0.012,lr=0.00043: 100%|██████████| 1446/1446 [00:39<00:00, 36.42it/s]


ep 28: train_loss: 0.01069, val_loss: 0.00611


ep: 29, train loss=0.014,lr=0.00042: 100%|██████████| 1446/1446 [00:39<00:00, 36.51it/s]


ep 29: train_loss: 0.01008, val_loss: 0.00514


ep: 30, train loss=0.012,lr=0.00042: 100%|██████████| 1446/1446 [00:39<00:00, 36.72it/s]


ep 30: train_loss: 0.00953, val_loss: 0.00506


ep: 31, train loss=0.013,lr=0.00041: 100%|██████████| 1446/1446 [00:39<00:00, 36.57it/s]


ep 31: train_loss: 0.00919, val_loss: 0.00495


ep: 32, train loss=0.010,lr=0.00040: 100%|██████████| 1446/1446 [00:39<00:00, 36.65it/s]


ep 32: train_loss: 0.00888, val_loss: 0.00467


ep: 33, train loss=0.010,lr=0.00040: 100%|██████████| 1446/1446 [00:39<00:00, 36.68it/s]


ep 33: train_loss: 0.00861, val_loss: 0.00501


ep: 34, train loss=0.011,lr=0.00039: 100%|██████████| 1446/1446 [00:39<00:00, 36.69it/s]


ep 34: train_loss: 0.00841, val_loss: 0.00489


ep: 35, train loss=0.010,lr=0.00039: 100%|██████████| 1446/1446 [00:39<00:00, 36.51it/s]


ep 35: train_loss: 0.00820, val_loss: 0.00463


ep: 36, train loss=0.009,lr=0.00038: 100%|██████████| 1446/1446 [00:39<00:00, 36.63it/s]


ep 36: train_loss: 0.00810, val_loss: 0.00476


ep: 37, train loss=0.007,lr=0.00038: 100%|██████████| 1446/1446 [00:39<00:00, 36.56it/s]


ep 37: train_loss: 0.00795, val_loss: 0.00471


ep: 38, train loss=0.004,lr=0.00037: 100%|██████████| 1446/1446 [00:39<00:00, 36.45it/s]


ep 38: train_loss: 0.00787, val_loss: 0.00480


ep: 39, train loss=0.007,lr=0.00037: 100%|██████████| 1446/1446 [00:39<00:00, 36.44it/s]


ep 39: train_loss: 0.00766, val_loss: 0.00503


ep: 40, train loss=0.008,lr=0.00036: 100%|██████████| 1446/1446 [00:39<00:00, 36.52it/s]


ep 40: train_loss: 0.00750, val_loss: 0.00462


ep: 41, train loss=0.008,lr=0.00036: 100%|██████████| 1446/1446 [00:39<00:00, 36.65it/s]


ep 41: train_loss: 0.00751, val_loss: 0.00474


ep: 42, train loss=0.011,lr=0.00035: 100%|██████████| 1446/1446 [00:39<00:00, 36.53it/s]


ep 42: train_loss: 0.00731, val_loss: 0.00459


ep: 43, train loss=0.007,lr=0.00035: 100%|██████████| 1446/1446 [00:39<00:00, 36.52it/s]


ep 43: train_loss: 0.00728, val_loss: 0.00473


ep: 44, train loss=0.006,lr=0.00035: 100%|██████████| 1446/1446 [00:39<00:00, 36.54it/s]


ep 44: train_loss: 0.00717, val_loss: 0.00465


ep: 45, train loss=0.008,lr=0.00034: 100%|██████████| 1446/1446 [00:39<00:00, 36.52it/s]


ep 45: train_loss: 0.00697, val_loss: 0.00448


ep: 46, train loss=0.004,lr=0.00034: 100%|██████████| 1446/1446 [00:39<00:00, 36.49it/s]


ep 46: train_loss: 0.00697, val_loss: 0.00434


ep: 47, train loss=0.009,lr=0.00034: 100%|██████████| 1446/1446 [00:39<00:00, 36.44it/s]


ep 47: train_loss: 0.00684, val_loss: 0.00416


ep: 48, train loss=0.007,lr=0.00033: 100%|██████████| 1446/1446 [00:39<00:00, 36.53it/s]


ep 48: train_loss: 0.00694, val_loss: 0.00431


ep: 49, train loss=0.004,lr=0.00033: 100%|██████████| 1446/1446 [00:39<00:00, 36.48it/s]


ep 49: train_loss: 0.00675, val_loss: 0.00430


In [15]:
test_loss = validate(model, data_loaders.test_loader)
print('training set examples the model gives an incorrect result:')
train_acc = evaluate(model, data_loaders.train_loader, 20)
print('validataion set examples the model gives an incorrect result:')
val_acc = evaluate(model, data_loaders.test_loader)
print('test set examples the model gives an incorrect result:')
test_acc = evaluate(model, data_loaders.test_loader)
result = f'''train_size: {train_size}, train_loss: {train_loss},
                val_loss: {val_loss}, test_loss: {test_loss},
                test_acc: {test_acc}, val_acc: {val_acc}, train_acc: {train_acc}
                '''
print(result)

training set examples the model gives an incorrect result:
correct equation: 981032270596811016+441027710402717948=1422059980999528964
predicted:        981032270596811016+441027710402717948=1422059981999528964
correct equation: 016450324225214967+900569997274624451=917020321499839418
predicted:        016450324225214967+900569997274624451=917020321599839418
correct equation: 160302195048473699+117821422051169265=278123617099642964
predicted:        160302195048473699+117821422051169265=278123617199642964
correct equation: 460436849321738648+060481920177649533=520918769499388181
predicted:        460436849321738648+060481920177649533=520918769599388181
correct equation: 677132007489206030+743341927832790797=1420473935321996827
predicted:        677132007489206030+743341927832790797=1420473935322996827
validataion set examples the model gives an incorrect result:
correct equation: 052868694601456094+251576304088699597=304444998690155691
predicted:        052868694601456094+2515763040886