<a href="https://colab.research.google.com/github/mandliya/mandliya.github.io/blob/main/parked/gpt_2_complete_loop.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q datasets jaxtyping tiktoken

In [2]:
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from jaxtyping import Float, Int
import math
from typing import Optional, Tuple
import tiktoken
from datasets import load_dataset
from google.colab import userdata
import pathlib
from tqdm import tqdm
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader

In [3]:
@dataclass
class GPT2Config:
  n_layers: int = 12
  d_model: int = 768
  n_heads: int = 12
  vocab_size: int = 50257
  layer_norm_eps: float = 1e-5
  init_range: float = 0.02
  dropout: float = 0.1
  n_ctx: int = 1024
  d_mlp: int = 4 * 768
  weight_tying: bool = True

In [35]:
class GPT2Attention(nn.Module):
  def __init__(self, cfg: GPT2Config):
    super().__init__()
    assert cfg.d_model % cfg.n_heads == 0, (
        f"{cfg.d_model} should be divisible by {cfg.n_heads}"
    )
    self.cfg = cfg
    self.c_attn = nn.Linear(cfg.d_model, 3 * cfg.d_model)
    self.attn_dropout = nn.Dropout(cfg.dropout)
    self.c_proj = nn.Linear(cfg.d_model, cfg.d_model)
    self.resid_dropout = nn.Dropout(cfg.dropout)
    self.register_buffer(
        'mask',
        torch.tril(torch.ones(cfg.n_ctx, cfg.n_ctx))
        .view(1, 1, cfg.n_ctx, cfg.n_ctx)
    )

  def forward(
      self,
      x: Float[Tensor, "B T d_model"]) -> Float[Tensor, "B T d_model"]:
      B, T, d_model = x.shape
      n_heads = self.cfg.n_heads
      d_head = d_model // n_heads
      qkv = self.c_attn(x) #[B, T, d_model * 3]
      q, k, v = qkv.split(d_model, dim=2)
      q = q.view(B, T, n_heads, d_head).transpose(1, 2) #[B nh T dh]
      k = k.view(B, T, n_heads, d_head).transpose(1, 2)
      v = v.view(B, T, n_heads, d_head).transpose(1, 2)

      attn = q @ k.transpose(-2, -1) * (1.0 / math.sqrt(d_head))
      attn = attn.masked_fill(
          self.mask[:, :, :T, :T] == 0,
          float('-inf')
      )
      attn = attn.softmax(dim=-1)
      attn = self.attn_dropout(attn)
      out = attn @ v #[B, nh, T, dh]
      out = out.transpose(1, 2).contiguous().view(B, T, d_model)
      out = self.c_proj(out)
      return self.resid_dropout(out)

In [5]:
class GPT2MLP(nn.Module):
  def __init__(self, cfg: GPT2Config):
    super().__init__()
    self.cfg = cfg
    self.c_fc = nn.Linear(cfg.d_model, cfg.d_mlp)
    self.gelu = nn.GELU()
    self.c_proj = nn.Linear(cfg.d_mlp, cfg.d_model)
    self.dropout = nn.Dropout(cfg.dropout)

  def forward(self, x: Float[Tensor, "B T d_model"]) -> Float[Tensor, "B T d_model"]:
    x = self.c_fc(x)
    x = self.gelu(x)
    x = self.c_proj(x)
    return self.dropout(x)

In [6]:
class GPT2Block(nn.Module):
  def __init__(self, cfg: GPT2Config):
    super().__init__()
    self.cfg = cfg
    self.ln_1 = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
    self.attn = GPT2Attention(cfg)
    self.ln_2 = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
    self.mlp = GPT2MLP(cfg)

  def forward(self, x: Float[Tensor, "B T d_model"]) -> Float[Tensor, "B T d_model"]:
    x = x + self.attn(self.ln_1(x))
    x = x + self.mlp(self.ln_2(x))
    return x

In [7]:
class GPT2Model(nn.Module):
  def __init__(self, cfg: GPT2Config):
    super().__init__()
    self.cfg = cfg
    self.transformer = nn.ModuleDict(dict(
        wte = nn.Embedding(cfg.vocab_size, cfg.d_model),
        wpe = nn.Embedding(cfg.n_ctx, cfg.d_model),
        embd_dropout = nn.Dropout(cfg.dropout),
        h = nn.ModuleList([
            GPT2Block(cfg) for _ in range(self.cfg.n_layers)
        ]),
        ln_f = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
    ))
    self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size)
    if cfg.weight_tying:
      self.lm_head.weight = self.transformer.wte.weight

    self.apply(self._init_weights)
    for np, p in self.named_parameters():
      if np.endswith('c_proj.weight'):
        nn.init.normal_(p, mean=0.0, std=(cfg.init_range/math.sqrt(2 * cfg.n_layers)))

  def _init_weights(self, module: nn.Module):
    if isinstance(module, nn.Linear):
      nn.init.normal_(module.weight, mean=0.0, std=self.cfg.init_range)
      if module.bias is None:
        nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
      nn.init.normal_(module.weight, mean=0.0, std=self.cfg.init_range)

  def forward(
      self,
      tokens: Int[Tensor, "B T"],
      targets: Optional[Int[Tensor, "B T"]]
    ) -> Tuple[Int[Tensor, "B T vocab_size"], Float[Tensor, ""]]:
    B, T = tokens.shape
    assert T <= self.cfg.n_ctx, (
        f"Sequence length {T} is longer than max sequence length: {self.cfg.n_ctx}"
    )
    tok_emb = self.transformer.wte(tokens)
    pos = torch.arange(0, T, dtype=torch.long, device=tokens.device)
    pos_emb = self.transformer.wpe(pos)
    residual = pos_emb + tok_emb
    residual = self.transformer.embd_dropout(residual)
    for block in self.transformer.h:
      residual = block(residual)

    residual = self.transformer.ln_f(residual)
    if targets is not None:
      logits = self.lm_head(residual) #[B T vocab_size]
      loss = F.cross_entropy(
          logits.view(-1, logits.size(-1)), #[B*T vocab_size]
          targets.view(-1), #[B * T]
          ignore_index=-1
      )
    else:
      logits = self.lm_head(residual[:, [-1], :]) #[B 1 vocab_size]
      loss = None

    return logits, loss

  @torch.no_grad()
  def generate(
      self,
      tokens: Int[Tensor, "B T"],
      temperature: float = 1.0,
      max_num_tokens: int = 256,
      top_k: Optional[int] = None
    ) -> Int[Tensor, "B T+max_num_tokens"]:
    for _ in range(max_num_tokens):
      tok_cond = (
          tokens
          if tokens.size(-1) <= self.cfg.n_ctx
          else tokens[:, -self.cfg.n_ctx:]
      )
      logits, _ = self(tok_cond, targets=None) #[B T vocab_size]
      logits = logits[:, -1, :] #[B vocab_size]
      logits = logits / temperature
      if top_k is not None:
        k = min(k, self.cfg.vocab_size)
        v, _ = torch.topk(logits, k) #[B k]
        threshold = v[:, [-1]] #[B, 1]
        logits.masked_fill_(logits < threshold, float('-inf'))
      probs = logits.softmax(dim=-1) #[B vocab_size]
      next_token = torch.multinomial(probs, num_samples=1) #[B, 1]
      tokens = torch.cat((tokens, next_token), dim=1)
    return tokens




In [8]:
text = "Hello, I am a large language model." * 500
enc = tiktoken.get_encoding('gpt2')
tokens = enc.encode(text)
tokens = torch.tensor(tokens, dtype=torch.long)
print(f'Total tokens: {len(tokens)}')

Total tokens: 4500


In [9]:
def get_batch(tokens: Tensor, block_size: int, batch_size: int, device: str):
  idx = torch.randint(0, len(tokens)-block_size, (batch_size,))
  x = torch.stack([tokens[i: i+block_size] for i in idx])
  y = torch.stack([tokens[i+1: i+block_size+1] for i in idx])
  return x.to(device), y.to(device)

In [10]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cfg = GPT2Config(dropout=0)
model = GPT2Model(cfg)
model = model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=3e-5)
steps = 200
for step in range(steps):
  x, y = get_batch(tokens, block_size=cfg.n_ctx, batch_size=2, device=device)
  optimizer.zero_grad()
  _, loss = model(x, targets=y)
  loss.backward()
  optimizer.step()
  if step % 10 == 0:
    print(f'{step=:4} | Loss: {loss.item():.6f}')

step=   0 | Loss: 11.091410
step=  10 | Loss: 5.922181
step=  20 | Loss: 5.396425
step=  30 | Loss: 5.022036
step=  40 | Loss: 4.661715
step=  50 | Loss: 4.316027
step=  60 | Loss: 3.979648
step=  70 | Loss: 3.658325
step=  80 | Loss: 3.350863
step=  90 | Loss: 3.031948
step= 100 | Loss: 2.147627
step= 110 | Loss: 1.343678
step= 120 | Loss: 0.310863
step= 130 | Loss: 0.093907
step= 140 | Loss: 0.061922
step= 150 | Loss: 0.048684
step= 160 | Loss: 0.039142
step= 170 | Loss: 0.033269
step= 180 | Loss: 0.029778
step= 190 | Loss: 0.025852


In [11]:
HF_TOKEN = userdata.get('HF_TOKEN')


In [12]:
@dataclass
class DatasetConfig:
  out_dir: str = './data'
  write_batch_size: int = 128
  hf_dataset:str = 'roneneldan/TinyStories'
  max_examples: int = 600000
  n_ctx: int = 1024

In [13]:
def pretokenize_and_save(config: DatasetConfig, split='train'):
  dataset = load_dataset(
      config.hf_dataset,
      split=split,
      streaming=True,
  )
  enc = tiktoken.get_encoding('gpt2')
  os.makedirs(config.out_dir, exist_ok=True)
  data_path = pathlib.Path(config.out_dir) / f'{split}.bin'
  total_tokens = 0
  with open(data_path, mode='wb') as f:
    batch_tokens = []
    for i, example in enumerate(tqdm(dataset)):
      if i >= config.max_examples:
        break
      tokens = enc.encode_ordinary(example['text'])
      batch_tokens.extend(tokens)
      total_tokens += len(tokens)
      if (i + 1) % config.write_batch_size == 0:
        chunk = np.array(batch_tokens, dtype=np.uint16)
        f.write(chunk.tobytes())
        batch_tokens = []

    if batch_tokens:
      chunk = np.array(chunk, dtype=np.uint16)
      f.write(chunk.tobytes())

  print(f'Wrote {total_tokens=:,} to path: {str(data_path)}')


In [14]:
config = DatasetConfig()
pretokenize_and_save(config)

README.md: 0.00B [00:00, ?B/s]

600000it [02:20, 4269.56it/s]

Wrote total_tokens=134,054,957 to path: data/train.bin





In [15]:
pretokenize_and_save(config, split='validation')

21990it [00:06, 3316.04it/s]

Wrote total_tokens=4,743,928 to path: data/validation.bin





In [14]:
class TokenDataset(Dataset):
  def __init__(self, config: DatasetConfig, split='train'):
    super().__init__()
    self.block_size = config.n_ctx
    path = pathlib.Path(config.out_dir) / f'{split}.bin'
    self.tokens = np.memmap(path, dtype=np.uint16, mode='r')

  def __len__(self):
    return len(self.tokens) // self.block_size

  def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
    start = index * self.block_size
    chunk = torch.from_numpy(
        self.tokens[start: start+1+self.block_size].astype(np.int64)
    )
    x = chunk[:-1]
    y = chunk[1:]
    return x, y

In [28]:
@dataclass
class TrainingConfig:
  lr: float = 3e-5
  log_steps: int = 200
  max_iters: int = 100
  train_batch_size: int = 8
  val_batch_size: int = 4
  out_path: str = './out'


In [29]:
def evaluate(
    model: GPT2Model,
    loader: DataLoader,
    max_batches: int = 20,
    device:str='cuda') -> float:
    model.eval()
    losses = []
    enc = tiktoken.get_encoding('gpt2')
    with torch.no_grad():
      for i, (x, y) in enumerate(tqdm(loader)):
        if i >= max_batches:
          break
        x, y = x.to(device), y.to(device)
        logits, loss = model(x, targets=y)
        losses.append(loss.item())
        if i == 0:
          preds = logits.argmax(dim=-1)
          print(f'Input: {enc.decode(x[0].tolist())}\nOutput: {enc.decode(preds[0].tolist())}')
    return sum(losses) / len(losses)

In [36]:
torch.cuda.empty_cache()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = GPT2Config()
model = GPT2Model(config).to(device)
data_config = DatasetConfig()
train_dataset = TokenDataset(data_config, split='train')
val_dataset = TokenDataset(data_config, split='validation')




def train(
    model: GPT2Model,
    train_dataset: TokenDataset,
    valid_dataset: TokenDataset,
    train_config: TrainingConfig):

  optimizer = optim.AdamW(model.parameters(), lr=train_config.lr)

  train_dataloader = DataLoader(
      train_dataset,
      batch_size=train_config.train_batch_size,
      shuffle=True,
      num_workers=4,
      pin_memory=True
  )

  val_dataloader = DataLoader(
      val_dataset,
      batch_size=train_config.val_batch_size,
      shuffle=False,
      pin_memory=False
  )

  best_val_loss = float('inf')
  best_model_path = pathlib.Path(train_config.out_path) / f'best_model.pt'
  os.makedirs(train_config.out_path, exist_ok=True)

  for epoch in range(train_config.max_iters):
    model.train()
    for step, (x, y) in enumerate(train_dataloader):
      x, y = x.to(device), y.to(device)
      optimizer.zero_grad()
      _, loss = model(x, targets=y)
      loss.backward()
      torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
      optimizer.step()

      if step % train_config.log_steps == 0:
        print(f'epoch {epoch:4} | step: {step:5d} | train_loss {loss.item():.4f}')


    val_loss = evaluate(model, val_dataloader, max_batches=2000, device=device)
    print(f'epoch: {epoch} | val loss: {val_loss:.4f}')

    if val_loss < best_val_loss:
      best_val_loss = val_loss
      torch.save({
          'epoch' : epoch,
          'model' : model.state_dict(),
          'optimizer': optimizer.state_dict(),
          'val_loss' : val_loss,
          'model_config': config,
          'data_config': data_config,
          'train_config': train_config
      }, best_model_path)

In [None]:
train_config = TrainingConfig()
train(model, train_dataset=train_dataset, valid_dataset=val_dataset, train_config=train_config)

epoch    0 | step:     0 | train_loss 10.8672
epoch    0 | step:   200 | train_loss 4.7543
epoch    0 | step:   400 | train_loss 4.2919
epoch    0 | step:   600 | train_loss 4.1205
epoch    0 | step:   800 | train_loss 3.8463
epoch    0 | step:  1000 | train_loss 3.8636
epoch    0 | step:  1200 | train_loss 3.4059
epoch    0 | step:  1400 | train_loss 3.4324
epoch    0 | step:  1600 | train_loss 3.4059
epoch    0 | step:  1800 | train_loss 3.4817
epoch    0 | step:  2000 | train_loss 3.1964
epoch    0 | step:  2200 | train_loss 3.0764
epoch    0 | step:  2400 | train_loss 3.0702
epoch    0 | step:  2600 | train_loss 2.9180
epoch    0 | step:  2800 | train_loss 3.0865
epoch    0 | step:  3000 | train_loss 2.6199
epoch    0 | step:  3200 | train_loss 2.7768
epoch    0 | step:  3400 | train_loss 2.7093
epoch    0 | step:  3600 | train_loss 2.7877
epoch    0 | step:  3800 | train_loss 2.8743
epoch    0 | step:  4000 | train_loss 2.7639
epoch    0 | step:  4200 | train_loss 2.7605
epoch    

  0%|          | 3/1160 [00:00<01:54, 10.12it/s]

Input: Spot. Spot saw the shiny car and said, "Wow, Kitty, your car is so bright and clean!" Kitty smiled and replied, "Thank you, Spot. I polish it every day."

After playing with the car, Kitty and Spot felt thirsty. They found a small pond with clear water. They drank the water and felt very happy. They played together all day and became best friends.Once upon a time, in a big forest, there lived a rhinoceros named Roxy. Roxy loved to climb. She climbed trees, rocks, and hills. One day, Roxy found an icy hill. She had never seen anything like it before. It was shiny and cold, and she wanted to climb it.

Roxy tried to climb the icy hill, but it was very slippery. She tried again and again, but she kept falling down. Roxy was sad. She wanted to climb the icy hill so much. Then, she saw a little bird named Billy. Billy saw that Roxy was sad and asked, "Why are you sad, Roxy?"

Roxy told Billy about the icy hill and how she couldn't climb it. Billy said, "I have an idea! Let's find som

100%|██████████| 1160/1160 [01:47<00:00, 10.79it/s]


epoch: 0 | val loss: 1.8179
epoch    1 | step:     0 | train_loss 1.6890
epoch    1 | step:   200 | train_loss 1.8509
epoch    1 | step:   400 | train_loss 2.0580
epoch    1 | step:   600 | train_loss 1.8764
epoch    1 | step:   800 | train_loss 1.7376
epoch    1 | step:  1000 | train_loss 1.9256
epoch    1 | step:  1200 | train_loss 1.9663
epoch    1 | step:  1400 | train_loss 1.8613
epoch    1 | step:  1600 | train_loss 1.7877
epoch    1 | step:  1800 | train_loss 1.8273
epoch    1 | step:  2000 | train_loss 1.7959
epoch    1 | step:  2200 | train_loss 1.7887
epoch    1 | step:  2400 | train_loss 1.9185
epoch    1 | step:  2600 | train_loss 1.7561
epoch    1 | step:  2800 | train_loss 1.7235
epoch    1 | step:  3000 | train_loss 1.8207
epoch    1 | step:  3200 | train_loss 1.6815
epoch    1 | step:  3400 | train_loss 1.8177
epoch    1 | step:  3600 | train_loss 1.8451
epoch    1 | step:  3800 | train_loss 1.8381
epoch    1 | step:  4000 | train_loss 1.8918
epoch    1 | step:  4200 | 

  0%|          | 2/1160 [00:00<01:49, 10.59it/s]

Input: Spot. Spot saw the shiny car and said, "Wow, Kitty, your car is so bright and clean!" Kitty smiled and replied, "Thank you, Spot. I polish it every day."

After playing with the car, Kitty and Spot felt thirsty. They found a small pond with clear water. They drank the water and felt very happy. They played together all day and became best friends.Once upon a time, in a big forest, there lived a rhinoceros named Roxy. Roxy loved to climb. She climbed trees, rocks, and hills. One day, Roxy found an icy hill. She had never seen anything like it before. It was shiny and cold, and she wanted to climb it.

Roxy tried to climb the icy hill, but it was very slippery. She tried again and again, but she kept falling down. Roxy was sad. She wanted to climb the icy hill so much. Then, she saw a little bird named Billy. Billy saw that Roxy was sad and asked, "Why are you sad, Roxy?"

Roxy told Billy about the icy hill and how she couldn't climb it. Billy said, "I have an idea! Let's find som

100%|██████████| 1160/1160 [01:47<00:00, 10.81it/s]


epoch: 1 | val loss: 1.5978
epoch    2 | step:     0 | train_loss 1.4615
epoch    2 | step:   200 | train_loss 1.6890
epoch    2 | step:   400 | train_loss 1.5846
epoch    2 | step:   600 | train_loss 1.6913
epoch    2 | step:   800 | train_loss 1.6966
epoch    2 | step:  1000 | train_loss 1.5864
epoch    2 | step:  1200 | train_loss 1.5298
epoch    2 | step:  1400 | train_loss 1.5734
epoch    2 | step:  1600 | train_loss 1.5904
epoch    2 | step:  1800 | train_loss 1.7103
epoch    2 | step:  2000 | train_loss 1.5467
epoch    2 | step:  2200 | train_loss 1.4878
epoch    2 | step:  2400 | train_loss 1.4842
epoch    2 | step:  2600 | train_loss 1.5420
epoch    2 | step:  2800 | train_loss 1.6234
epoch    2 | step:  3000 | train_loss 1.6033
epoch    2 | step:  3200 | train_loss 1.7546
epoch    2 | step:  3400 | train_loss 1.4086
epoch    2 | step:  3600 | train_loss 1.6263
epoch    2 | step:  3800 | train_loss 1.6853
epoch    2 | step:  4000 | train_loss 1.5591
epoch    2 | step:  4200 | 

  0%|          | 3/1160 [00:00<01:50, 10.50it/s]

Input: Spot. Spot saw the shiny car and said, "Wow, Kitty, your car is so bright and clean!" Kitty smiled and replied, "Thank you, Spot. I polish it every day."

After playing with the car, Kitty and Spot felt thirsty. They found a small pond with clear water. They drank the water and felt very happy. They played together all day and became best friends.Once upon a time, in a big forest, there lived a rhinoceros named Roxy. Roxy loved to climb. She climbed trees, rocks, and hills. One day, Roxy found an icy hill. She had never seen anything like it before. It was shiny and cold, and she wanted to climb it.

Roxy tried to climb the icy hill, but it was very slippery. She tried again and again, but she kept falling down. Roxy was sad. She wanted to climb the icy hill so much. Then, she saw a little bird named Billy. Billy saw that Roxy was sad and asked, "Why are you sad, Roxy?"

Roxy told Billy about the icy hill and how she couldn't climb it. Billy said, "I have an idea! Let's find som

100%|██████████| 1160/1160 [01:46<00:00, 10.86it/s]


epoch: 2 | val loss: 1.5028
epoch    3 | step:     0 | train_loss 1.5217
epoch    3 | step:   200 | train_loss 1.4996
epoch    3 | step:   400 | train_loss 1.4721
epoch    3 | step:   600 | train_loss 1.4095
epoch    3 | step:   800 | train_loss 1.4377
epoch    3 | step:  1000 | train_loss 1.4539
epoch    3 | step:  1200 | train_loss 1.5956
epoch    3 | step:  1400 | train_loss 1.6113
epoch    3 | step:  1600 | train_loss 1.6243
epoch    3 | step:  1800 | train_loss 1.7765
epoch    3 | step:  2000 | train_loss 1.4571
epoch    3 | step:  2200 | train_loss 1.4983
epoch    3 | step:  2400 | train_loss 1.3908
epoch    3 | step:  2600 | train_loss 1.6217
epoch    3 | step:  2800 | train_loss 1.7734
epoch    3 | step:  3000 | train_loss 1.4773
epoch    3 | step:  3200 | train_loss 1.4317
epoch    3 | step:  3400 | train_loss 1.5083
epoch    3 | step:  3600 | train_loss 1.6837
epoch    3 | step:  3800 | train_loss 1.3481
epoch    3 | step:  4000 | train_loss 1.5909
epoch    3 | step:  4200 | 

  0%|          | 3/1160 [00:00<01:50, 10.47it/s]

Input: Spot. Spot saw the shiny car and said, "Wow, Kitty, your car is so bright and clean!" Kitty smiled and replied, "Thank you, Spot. I polish it every day."

After playing with the car, Kitty and Spot felt thirsty. They found a small pond with clear water. They drank the water and felt very happy. They played together all day and became best friends.Once upon a time, in a big forest, there lived a rhinoceros named Roxy. Roxy loved to climb. She climbed trees, rocks, and hills. One day, Roxy found an icy hill. She had never seen anything like it before. It was shiny and cold, and she wanted to climb it.

Roxy tried to climb the icy hill, but it was very slippery. She tried again and again, but she kept falling down. Roxy was sad. She wanted to climb the icy hill so much. Then, she saw a little bird named Billy. Billy saw that Roxy was sad and asked, "Why are you sad, Roxy?"

Roxy told Billy about the icy hill and how she couldn't climb it. Billy said, "I have an idea! Let's find som

100%|██████████| 1160/1160 [01:46<00:00, 10.85it/s]


epoch: 3 | val loss: 1.4441
epoch    4 | step:     0 | train_loss 1.3455
epoch    4 | step:   200 | train_loss 1.5556
epoch    4 | step:   400 | train_loss 1.3295
epoch    4 | step:   600 | train_loss 1.4188
epoch    4 | step:   800 | train_loss 1.4042
epoch    4 | step:  1000 | train_loss 1.3990
epoch    4 | step:  1200 | train_loss 1.3903
epoch    4 | step:  1400 | train_loss 1.4777
epoch    4 | step:  1600 | train_loss 1.4189
epoch    4 | step:  1800 | train_loss 1.4376
epoch    4 | step:  2000 | train_loss 1.3898
epoch    4 | step:  2200 | train_loss 1.4327
epoch    4 | step:  2400 | train_loss 1.5217
epoch    4 | step:  2600 | train_loss 1.4153
epoch    4 | step:  2800 | train_loss 1.4220
epoch    4 | step:  3000 | train_loss 1.3315
