In [None]:
# GPT-2 as built in Andrej Karpathy's video
# Train it from scratch with the regular Adam (AdamW) optimizer
# Also train it with GaLore



Karpathy GPT2 Model

In [None]:
import os
import math
import time
import inspect
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with a single fused q/k/v projection.

    The config must provide ``n_embd``, ``n_head`` and ``block_size``;
    ``n_embd`` has to be divisible by ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0, 'heads dont match'
        width = config.n_embd
        ctx = config.block_size
        # One bias-free linear layer emits queries, keys and values for all heads at once.
        self.c_attn = nn.Linear(width, 3 * width, bias=False)
        # Output projection; the flag is read by GPT._init_weights to shrink its init std.
        self.c_proj = nn.Linear(width, width)
        self.c_proj.NANOGPT_SCALE_INIT = 1.0
        self.n_head = config.n_head
        self.n_embed = width
        # Lower-triangular mask kept as a buffer named 'bias'. forward() does not read it
        # (the fused kernel applies causality itself), but it is part of the state dict.
        lower_triangle = torch.tril(torch.ones(ctx, ctx))
        self.register_buffer('bias', lower_triangle.view(1, 1, ctx, ctx))

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (B, T, C) and return (B, T, C)."""
        B, T, C = x.size()
        head_dim = C // self.n_head
        # Split the fused projection into q, k, v, then move the head axis forward:
        # (B, T, C) -> (B, n_head, T, head_dim).
        q, k, v = (
            chunk.view(B, T, self.n_head, head_dim).transpose(1, 2)
            for chunk in self.c_attn(x).split(self.n_embed, dim=2)
        )
        # Fused attention with is_causal=True: each position attends only to itself and
        # earlier positions, without explicitly building the (T, T) score matrix here.
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # Put the heads back side by side: (B, n_head, T, head_dim) -> (B, T, C).
        merged = attended.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(merged)



class MLP(nn.Module):
    """Position-wise feed-forward network: expand to 4x width, GELU (tanh form), project back.

    The config must provide ``n_embd``.
    """

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
        # This projection writes into the residual stream just like the attention output
        # projection does. GPT._init_weights scales the init std of flagged layers by
        # (2 * n_layer) ** -0.5, counting two residual writes per layer (attention and MLP),
        # so this layer needs the flag as well.
        self.c_proj.NANOGPT_SCALE_INIT = 1.0
        self.gelu = nn.GELU(approximate = 'tanh')

    def forward(self, x):
        """Apply the feed-forward network to the last dimension of ``x``; shape is preserved."""
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x

class Block(nn.Module):
    """One transformer layer: pre-LayerNorm attention and pre-LayerNorm MLP, each residual.

    Author's mental model: attention acts as the "reduce" step (it combines information
    across positions) and the MLP as the "map" step, so a layer behaves like a map-reduce.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.ln_1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates (same shape as ``x``)."""
        after_attention = x + self.attn(self.ln_1(x))
        return after_attention + self.mlp(self.ln_2(after_attention))

@dataclass
class GPTConfig:
    """Model hyperparameters consumed by GPT and its sub-modules.

    The defaults describe a 12-layer, 12-head, 768-wide model. Note that the default
    vocabulary size (50304) differs from the 50257 that GPT.from_pretrained sets for the
    pretrained checkpoints.
    """

    # Number of rows in the token embedding table and width of the output logits.
    vocab_size: int = 50304
    # Maximum sequence length; also the number of learned position embeddings.
    block_size: int = 1024
    # Number of stacked transformer blocks.
    n_layer: int = 12
    # Number of attention heads per block; must divide n_embd.
    n_head: int = 12
    # Width of the residual stream / embedding dimension.
    n_embd: int = 768

class GPT(nn.Module):
    """GPT-2 style decoder-only language model.

    Layout: token embedding (``wte``) + learned position embedding (``wpe``), a stack of
    ``config.n_layer`` Blocks, a final LayerNorm (``ln_f``) and a linear ``lm_head`` whose
    weight matrix is shared with ``wte``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            h = nn.ModuleList(Block(config) for _ in range(config.n_layer)),
            ln_f = nn.LayerNorm(config.n_embd)
        ))
        # GPT-2 has no bias on the output head: the head is just the (tied) embedding matrix.
        # A bias here would add a parameter that the Hugging Face checkpoints do not contain.
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight sharing: both modules now reference the same Parameter object.
        self.transformer.wte.weight = self.lm_head.weight
        # Initialise every sub-module's parameters.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear / Embedding weights from N(0, 0.02^2) and zero Linear biases.

        Linear layers carrying a ``NANOGPT_SCALE_INIT`` attribute use a std additionally
        scaled by ``(2 * n_layer) ** -0.5``.
        """
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'NANOGPT_SCALE_INIT'):
                # Two residual contributions per layer (attention and MLP).
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute logits, and the cross-entropy loss when ``targets`` is given.

        Args:
            idx: integer token ids of shape (B, T) with T <= config.block_size.
            targets: optional token ids of shape (B, T).

        Returns:
            ``(logits, loss)`` where logits has shape (B, T, vocab_size) and loss is
            ``None`` when no targets were passed.
        """
        B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        pos = torch.arange(0, T, device=idx.device)
        pos_emb = self.transformer.wpe(pos)  # (T, n_embd), broadcast over the batch below
        tok_emb = self.transformer.wte(idx)  # (B, T, n_embd)
        x = tok_emb + pos_emb
        for block in self.transformer.h:
            x = block(x)

        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        loss = None
        if targets is not None:
            # Flatten batch and time so every position is one classification example.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @classmethod
    def from_pretrained(cls, model_type):
        """Loads pretrained GPT-2 model weights from huggingface.

        Args:
            model_type: one of 'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'.

        Returns:
            A GPT instance whose parameters were copied from the Hugging Face model.

        NOTE(review): CausalSelfAttention in this file builds ``c_attn`` with ``bias=False``,
        whereas the Hugging Face GPT-2 state dict is expected to contain ``attn.c_attn.bias``
        tensors. Until that is reconciled the key-count assertion below is expected to
        fail -- confirm before relying on this loader.
        """
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        from transformers import GPT2LMHeadModel
        print("loading weights from pretrained gpt: %s" % model_type)

        # n_layer, n_head and n_embd are determined from model_type
        config_args = {
            'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
            'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
            'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
            'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
        }[model_type]
        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
        # create a from-scratch initialized model
        config = GPTConfig(**config_args)
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = sd.keys()
        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param

        # init a huggingface/transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
        # this means that we have to transpose these weights when we import them
        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                # special treatment for the Conv1D weights we need to transpose
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                # vanilla copy over the other parameters
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        return model

    def configure_optimizer(self, weight_decay, learning_rate, device):
        """Build a torch AdamW optimizer with weight decay applied only to >=2-D parameters.

        Args:
            weight_decay: decay coefficient for matrices and embeddings.
            learning_rate: initial learning rate for both groups.
            device: device string or ``torch.device``; the fused kernel is requested only
                when its string form contains 'cuda'.

        Returns:
            The configured ``torch.optim.AdamW`` instance (betas=(0.9, 0.95), eps=1e-8).
        """
        trainable = [p for p in self.parameters() if p.requires_grad]
        # Any parameter that is at least 2-D (matmul weights, embeddings) gets decayed;
        # 1-D parameters (biases, LayerNorm gains) do not.
        decay_params = [p for p in trainable if p.dim() >= 2]
        nodecay_params = [p for p in trainable if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f'num decay params: {num_decay_params}')
        print(f'num nodecay params: {num_nodecay_params}')
        # Only pass `fused` when this torch version's AdamW accepts it; passing an unknown
        # keyword would raise a TypeError and make the signature probe pointless.
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        # str() so that a torch.device works as well as a plain string.
        use_fused = fused_available and 'cuda' in str(device)
        print(f'using fusedAdamW: {use_fused}')
        extra_args = {'fused': True} if use_fused else {}
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, **extra_args)
        return optimizer




class DataLoaderLite:
    """Minimal in-memory loader that serves consecutive (input, target) token batches.

    Args:
        B: batch size (number of sequences per batch).
        T: sequence length in tokens.
        path: text file to tokenize; defaults to 'input.txt' as before.

    Raises:
        ValueError: if the tokenized text is too short to form a single batch.
    """

    def __init__(self, B, T, path='input.txt'):
        self.B = B
        self.T = T

        # Load the text once, tokenize it with the GPT-2 encoding and keep it in memory.
        # An explicit encoding avoids depending on the platform's default codec.
        with open(path, 'r', encoding='utf-8') as f:
            text = f.read()
        enc = tiktoken.get_encoding('gpt2')
        tokens = enc.encode(text)
        self.tokens = torch.tensor(tokens)
        print(f'loaded {len(self.tokens)} tokens')

        # One batch needs B*T inputs plus one extra token for the shifted targets.
        if len(self.tokens) < B * T + 1:
            raise ValueError(
                f'need at least {B * T + 1} tokens for one batch of shape ({B}, {T}), '
                f'got {len(self.tokens)}'
            )

        self.current_position = 0

    def next_batch(self):
        """Return the next ``(x, y)`` pair, each of shape (B, T); ``y`` is ``x`` shifted by one.

        The read position advances by B*T and wraps to 0 when the following batch
        would run past the end of the token buffer.
        """
        B, T = self.B, self.T
        buf = self.tokens[self.current_position:self.current_position + B * T + 1]
        x = buf[:-1].view(B, T)  # inputs
        y = buf[1:].view(B, T)   # targets: next token for every input position
        self.current_position += B * T

        if self.current_position + (B * T + 1) > len(self.tokens):
            self.current_position = 0
        return x, y

Dataset Retrieval

In [None]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
!pip install tiktoken
import tiktoken

--2024-07-07 19:25:36--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2024-07-07 19:25:37 (139 MB/s) - ‘input.txt’ saved [1115394/1115394]

Collecting tiktoken
  Downloading tiktoken-0.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m20.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tiktoken
Successfully installed tiktoken-0.7.0


In [None]:
# Learning-rate schedule settings: peak rate, floor (10% of peak), warmup length and
# the step at which the cosine decay reaches the floor.
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_iters = 10
max_steps = 50


def get_lr(step):
    """Return the learning rate for ``step``: linear warmup, cosine decay, then a floor.

    * ``step < warmup_iters``: ramps linearly, reaching ``max_lr`` on the last warmup step.
    * ``step > max_steps``: constant ``min_lr``.
    * otherwise: cosine interpolation from ``max_lr`` down to ``min_lr``.
    """
    in_warmup = step < warmup_iters
    if in_warmup:
        return max_lr * (step + 1) / warmup_iters

    past_schedule = step > max_steps
    if past_schedule:
        return min_lr

    # Fraction of the decay phase completed so far, in [0, 1].
    progress = (step - warmup_iters) / (max_steps - warmup_iters)
    assert 0 <= progress <= 1
    # Goes from 1 (start of decay) to 0 (end of decay).
    cosine_weight = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + cosine_weight * (max_lr - min_lr)

import time

# Seed BEFORE constructing the model: the weights are drawn from the global RNG inside
# GPT.__init__, so seeding afterwards leaves the initialisation unseeded.
torch.manual_seed(1337)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = GPT(GPTConfig())
model = torch.compile(model)
model.to(device)


# Batches of 16 sequences x 1024 tokens.
train_loader = DataLoaderLite(16,1024)

# Allow reduced-precision float32 matmuls where the backend supports it.
torch.set_float32_matmul_precision('high')



loaded 338025 tokens


Import the GaLore Parameters

In [None]:
# GaLore settings: projection rank and how many steps a subspace is kept before it is refreshed
rank = 5
subspace_change_freq = 50

# Linear layers whose module path contains one of these substrings get GaLore.
target_modules_list = ['attn','mlp']
galore_params = []

for module_name, module in model.named_modules():
    selected = isinstance(module, nn.Linear) and any(
        target_key in module_name for target_key in target_modules_list
    )
    if selected:
        print('enable GaLore for weights in module: ', module_name)
        galore_params.append(module.weight)

# Parameters are compared by identity (id) rather than by value.
id_galore_params = [id(p) for p in galore_params]
# Everything that is not a GaLore weight (embeddings, LayerNorms, biases, ...) goes into a plain group.
regular_params = [p for p in model.parameters() if id(p) not in id_galore_params]
# Two optimizer groups: the second carries the extra keys the GaLore optimizer looks for.
param_groups = [
    {'params': regular_params},
    {'params': galore_params, 'rank': rank, 'subspace_change_freq': subspace_change_freq},
]


enable GaLore for weights in module:  _orig_mod.transformer.h.0.attn.c_attn
enable GaLore for weights in module:  _orig_mod.transformer.h.0.attn.c_proj
enable GaLore for weights in module:  _orig_mod.transformer.h.0.mlp.c_fc
enable GaLore for weights in module:  _orig_mod.transformer.h.0.mlp.c_proj
enable GaLore for weights in module:  _orig_mod.transformer.h.1.attn.c_attn
enable GaLore for weights in module:  _orig_mod.transformer.h.1.attn.c_proj
enable GaLore for weights in module:  _orig_mod.transformer.h.1.mlp.c_fc
enable GaLore for weights in module:  _orig_mod.transformer.h.1.mlp.c_proj
enable GaLore for weights in module:  _orig_mod.transformer.h.2.attn.c_attn
enable GaLore for weights in module:  _orig_mod.transformer.h.2.attn.c_proj
enable GaLore for weights in module:  _orig_mod.transformer.h.2.mlp.c_fc
enable GaLore for weights in module:  _orig_mod.transformer.h.2.mlp.c_proj
enable GaLore for weights in module:  _orig_mod.transformer.h.3.attn.c_attn
enable GaLore for weight

In [None]:
class GaloreProjecter:
    """Projects 2-D gradients onto a low-rank subspace obtained from their SVD.

    Args:
        rank: number of singular vectors kept.
        subspace_change_freq: the basis is recomputed whenever ``iter`` is a multiple of it.
    """

    def __init__(self, rank, subspace_change_freq):
        self.rank = rank
        self.subspace_change_freq = subspace_change_freq
        # Cached orthonormal basis; None until the first call to project().
        self.ortho_matrix = None

    def project(self, full_rank_grad, iter):
        """Return ``full_rank_grad`` expressed in the current low-rank subspace.

        Tall or square gradients (rows >= cols) are multiplied on the right by the top
        right-singular vectors, giving (rows, rank); wide gradients are multiplied on the
        left by the top left-singular vectors, giving (rank, cols).
        """
        n_rows, n_cols = full_rank_grad.shape[0], full_rank_grad.shape[1]
        side = 'right' if n_rows >= n_cols else 'left'

        stale = self.ortho_matrix is None or iter % self.subspace_change_freq == 0
        if stale:
            self.ortho_matrix = self.SVD(full_rank_grad, self.rank, side)

        basis_t = self.ortho_matrix.t()
        if side == 'right':
            return torch.matmul(full_rank_grad, basis_t)
        return torch.matmul(basis_t, full_rank_grad)

    def projectback(self, low_rank_grad):
        """Map a low-rank tensor back to the full shape using the cached basis.

        The multiplication side is chosen from the shape of ``low_rank_grad`` itself.
        """
        tall = low_rank_grad.shape[0] >= low_rank_grad.shape[1]
        left, right = (low_rank_grad, self.ortho_matrix) if tall else (self.ortho_matrix, low_rank_grad)
        return torch.matmul(left, right)

    def SVD(self, weights, rank, type):
        """Return the top-``rank`` singular vectors of ``weights``.

        ``type='right'`` yields the first ``rank`` rows of Vh, ``type='left'`` the first
        ``rank`` columns of U. Non-float32 inputs are decomposed in float32 and the result
        is cast back to the input's dtype and device.

        Raises:
            ValueError: for any other ``type``.
        """
        source = weights.data
        needs_cast = source.dtype != torch.float
        matrix = source.float() if needs_cast else source

        U, S, Vh = torch.linalg.svd(matrix, full_matrices=False)

        if type == 'right':
            basis = Vh[:rank, :]
        elif type == 'left':
            basis = U[:, :rank]
        else:
            raise ValueError(f'unknown type {type}')

        if needs_cast:
            basis = basis.to(source.device).type(source.dtype)
        return basis



In [None]:
import torch
from torch import nn
from torch.optim import Optimizer
import math
import warnings
from typing import Callable, Iterable, Tuple

class AdamW(Optimizer):
    """Adam with decoupled weight decay.

    Args:
        params: parameters or parameter-group dicts to optimize.
        lr: learning rate.
        betas: decay rates for the first and second moment estimates.
        eps: term added to the denominator for numerical stability.
        weight_decay: decoupled decay coefficient, applied after the Adam update.
        correct_bias: whether to apply Adam's bias correction to the step size.
        no_deprecation_warning: accepted for signature compatibility; not used here.

    Raises:
        ValueError: if ``lr``, ``betas`` or ``eps`` are outside their valid range.
    """

    def __init__(
        self,
        params: Iterable[nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
        no_deprecation_warning: bool = False,
    ):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
        defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: Callable = None):
        """
        Performs a single optimization step.

        Arguments:
            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.

        Returns:
            The closure's return value, or ``None`` when no closure was given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure has to build a graph and call
            # backward(), so gradient tracking is re-enabled just for it.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                state = self.state[p]

                if 'step' not in state:
                    state['step'] = 0

                # Lazily create the moment buffers on the first update of this parameter.
                if 'exp_avg' not in state:
                    state['exp_avg'] = torch.zeros_like(grad)     # first moment, m_t
                    state['exp_avg_sq'] = torch.zeros_like(grad)  # second moment, v_t

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # In-place exponential moving averages of the gradient and its square.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                denom = exp_avg_sq.sqrt().add_(group['eps'])

                step_size = group['lr']
                if group['correct_bias']:
                    # Bias correction is folded into the step size instead of the moments.
                    bias_correction1 = 1 - beta1 ** state['step']
                    bias_correction2 = 1 - beta2 ** state['step']
                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                norm_grad = exp_avg / denom

                p.add_(norm_grad, alpha=-step_size)

                # Decoupled weight decay: shrink the (already updated) weights directly,
                # so the decay never enters the moment estimates.
                if group["weight_decay"] > 0.0:
                    p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))

        return loss



def print_memory_usage(self):
    """Print, per parameter, the size in MB of the optimizer state tensors kept for it.

    Args:
        self: an optimizer-like object exposing ``param_groups`` and ``state``.

    Reports ``exp_avg``, ``exp_avg_sq`` and, when present, the tensors held by the
    ``projector`` entry.
    """
    mb = 1024 ** 2

    def _nbytes(t):
        return t.numel() * t.element_size()

    for group in self.param_groups:
        for p in group['params']:
            state = self.state[p]
            print(f"Parameter {p.shape}:")
            if 'exp_avg' in state:
                exp_avg_size = _nbytes(state['exp_avg'])
                print(f"  exp_avg: {exp_avg_size / mb:.2f} MB")
            if 'exp_avg_sq' in state:
                exp_avg_sq_size = _nbytes(state['exp_avg_sq'])
                print(f"  exp_avg_sq: {exp_avg_sq_size / mb:.2f} MB")
            if 'projector' in state:
                # The projector is a plain Python object (GaloreProjecter keeps its basis
                # in `ortho_matrix`), not an nn.Module, so it has no .parameters();
                # sum every tensor stored as an attribute instead.
                attributes = getattr(state['projector'], '__dict__', {})
                projector_size = sum(_nbytes(t) for t in attributes.values() if torch.is_tensor(t))
                print(f"  projector: {projector_size / mb:.2f} MB")



In [None]:
import torch
from torch import nn
from torch.optim import Optimizer
import math
import warnings
from typing import Callable, Iterable, Tuple

class AdamWGalore(Optimizer):
    """AdamW whose moment estimates can live in a low-rank (GaLore) gradient subspace.

    A parameter group that contains a ``'rank'`` key must also contain
    ``'subspace_change_freq'``; gradients of its parameters are projected with a
    per-parameter GaloreProjecter before the Adam update and the normalised update is
    projected back before being applied. Groups without ``'rank'`` get a plain AdamW update.

    Args:
        params: parameters or parameter-group dicts to optimize.
        lr: learning rate.
        betas: decay rates for the first and second moment estimates.
        eps: term added to the denominator for numerical stability.
        weight_decay: decoupled decay coefficient, applied after the Adam update.
        correct_bias: whether to apply Adam's bias correction to the step size.
        no_deprecation_warning: accepted for signature compatibility; not used here.

    Raises:
        ValueError: if ``lr``, ``betas`` or ``eps`` are outside their valid range.
    """

    def __init__(
        self,
        params: Iterable[nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
        no_deprecation_warning: bool = False,
    ):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
        defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: Callable = None):
        """
        Performs a single optimization step.

        Arguments:
            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.

        Returns:
            The closure's return value, or ``None`` when no closure was given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure has to build a graph and call
            # backward(), so gradient tracking is re-enabled just for it.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                state = self.state[p]

                if 'step' not in state:
                    state['step'] = 0

                if 'dim' not in group:
                    group['dim'] = 2

                # GaLore projection. The projector is stored in -- and therefore must be
                # looked up in -- the per-parameter state. Checking the group instead
                # would build a fresh projector (and run a full SVD) on every step and
                # make subspace_change_freq ineffective.
                if 'rank' in group:
                    if 'projector' not in state:
                        if group['dim'] <= 2:
                            state['projector'] = GaloreProjecter(group['rank'], group['subspace_change_freq'])

                    grad = state['projector'].project(grad, state['step'])

                # Lazily create the moment buffers; for GaLore groups they have the
                # shape of the projected (low-rank) gradient.
                if 'exp_avg' not in state:
                    state['exp_avg'] = torch.zeros_like(grad)     # first moment, m_t
                    state['exp_avg_sq'] = torch.zeros_like(grad)  # second moment, v_t

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # In-place exponential moving averages of the gradient and its square.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                denom = exp_avg_sq.sqrt().add_(group['eps'])

                step_size = group['lr']
                if group['correct_bias']:
                    # Bias correction is applied exactly once, via the step size (same
                    # convention as the plain AdamW class). Also rescaling the moments
                    # would apply the correction twice and inflate the early updates.
                    bias_correction1 = 1 - beta1 ** state['step']
                    bias_correction2 = 1 - beta2 ** state['step']
                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                norm_grad = exp_avg / denom

                # Bring the normalised update back to the parameter's full shape.
                if 'rank' in group:
                    norm_grad = state['projector'].projectback(norm_grad)

                p.add_(norm_grad, alpha=-step_size)

                # Decoupled weight decay on the already-updated weights.
                if group["weight_decay"] > 0.0:
                    p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))

        return loss







In [None]:
#rank, subspace change frequency
def train_rank_subspace(rank,subspace_change_freq):
    """Train a fresh GPT for 50 steps with AdamWGalore and print per-step statistics.

    Args:
        rank: GaLore projection rank used for the attention / MLP linear weights.
        subspace_change_freq: number of steps between recomputations of the projection basis.
    """
    print('-----------------------------------------------------------')
    print(f'rank: {rank}, subspace_change_freq: {subspace_change_freq}')
    print('-----------------------------------------------------------')
    # Seed BEFORE constructing the model: the weights are drawn from the global RNG
    # inside GPT.__init__, so every (rank, frequency) run starts from the same weights.
    torch.manual_seed(1337)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = GPT(GPTConfig())
    model = torch.compile(model)
    model.to(device)

    train_loader = DataLoaderLite(16,1024)

    torch.set_float32_matmul_precision('high')

    # GaLore is applied to the weight matrices of Linear layers inside attention and MLP.
    target_modules_list = ['attn','mlp']
    galore_params = [
        module.weight
        for module_name, module in model.named_modules()
        if isinstance(module, nn.Linear)
        and any(target_key in module_name for target_key in target_modules_list)
    ]

    # Everything else (embeddings, LayerNorms, biases) goes into a regular group.
    # A set keeps the identity lookup O(1) per parameter.
    galore_ids = {id(p) for p in galore_params}
    regular_params = [p for p in model.parameters() if id(p) not in galore_ids]

    param_groups = [{'params': regular_params},
                    {'params': galore_params, 'rank': rank, 'subspace_change_freq': subspace_change_freq}]

    optimizer = AdamWGalore(param_groups, lr = 3e-4, betas = (0.9,0.95),eps = 1e-8)

    num_steps = 50
    for i in range(num_steps):
        t0 = time.time()
        optimizer.zero_grad()
        x,y = train_loader.next_batch()
        x = x.to(device)
        y = y.to(device)
        with torch.autocast(device, dtype=torch.bfloat16):
            logits,loss = model(x,y)
        loss.backward()
        # Clip the global gradient norm to 1.0 so an unlucky batch cannot produce a huge update.
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)
        # The schedule overrides the lr passed to the optimizer constructor.
        lr = get_lr(i)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        optimizer.step()
        if device == 'cuda':
            # Wait for queued GPU work so the timing below is meaningful; skipped on CPU,
            # where calling into torch.cuda would fail.
            torch.cuda.synchronize()
        t1 = time.time()
        dt = (t1-t0)*1000
        tokens_per_sec = (train_loader.B*train_loader.T)/(t1-t0)
        print(f"step {i}, loss:{loss.item()},lr: {lr:.10f} ,norm: {norm:.4f} time {dt:.2f}ms, tokens/sec {tokens_per_sec:.2f}")

In [None]:

# Grid of GaLore settings to compare: every rank is trained with every refresh frequency.
ranks = [1,100,500]
subspace_change_freqs = [5,25]

# Rank is the outer dimension of the sweep, refresh frequency the inner one.
sweep = [(r, f) for r in ranks for f in subspace_change_freqs]
for rank, subspace_change_freq in sweep:
    train_rank_subspace(rank,subspace_change_freq)

-----------------------------------------------------------
rank: 1, subspace_change_freq: 5
-----------------------------------------------------------
loaded 338025 tokens
step 0, loss:11.00564956665039,lr: 0.0000600000 ,norm: 7.1829 time 24076.71ms, tokens/sec 680.49
step 1, loss:10.729747772216797,lr: 0.0001200000 ,norm: 7.6998 time 2989.19ms, tokens/sec 5481.08
step 2, loss:10.444694519042969,lr: 0.0001800000 ,norm: 6.5645 time 2933.63ms, tokens/sec 5584.89
step 3, loss:10.155477523803711,lr: 0.0002400000 ,norm: 6.4403 time 2934.95ms, tokens/sec 5582.37
step 4, loss:9.812883377075195,lr: 0.0003000000 ,norm: 6.5733 time 2932.63ms, tokens/sec 5586.79
step 5, loss:9.488386154174805,lr: 0.0003600000 ,norm: 5.8322 time 2957.08ms, tokens/sec 5540.60
step 6, loss:9.120413780212402,lr: 0.0004200000 ,norm: 5.0925 time 2951.23ms, tokens/sec 5551.58
step 7, loss:8.64113998413086,lr: 0.0004800000 ,norm: 4.0818 time 2955.20ms, tokens/sec 5544.13
step 8, loss:8.146684646606445,lr: 0.0005400000 

KeyboardInterrupt: 

Normal AdamW implementation

In [None]:
# Baseline: same model and schedule, trained with the plain AdamW class.
# Seed BEFORE constructing the model so the initial weights are reproducible.
torch.manual_seed(1337)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = GPT(GPTConfig())
model = torch.compile(model)
model.to(device)


train_loader = DataLoaderLite(16,1024)

torch.set_float32_matmul_precision('high')

###Optimizer initialization
# The optimizer must be built from THIS model's parameters. The global `param_groups`
# was assembled from an earlier model instance, so an optimizer built from it never
# touches the model trained below and the loss stays flat.
optimizer = AdamW(model.parameters(), lr = 3e-4, betas = (0.9,0.95),eps = 1e-8)


epochs = 50  # number of optimizer steps (not passes over the dataset)
for i in range(epochs):
    t0 = time.time()
    optimizer.zero_grad()
    x,y = train_loader.next_batch()
    x = x.to(device)
    y = y.to(device)
    with torch.autocast(device, dtype=torch.bfloat16):
        logits,loss = model(x,y)
    loss.backward()
    # Clip the global gradient norm to 1.0 so an unlucky batch cannot produce a huge update.
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)
    # The schedule overrides the lr passed to the optimizer constructor.
    lr = get_lr(i)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    optimizer.step()
    if device == 'cuda':
        # Wait for queued GPU work so the timing below is meaningful; skipped on CPU.
        torch.cuda.synchronize()
    t1 = time.time()
    dt = (t1-t0)*1000
    tokens_per_sec = (train_loader.B*train_loader.T)/(t1-t0)
    print(f"step {i}, loss:{loss.item()},lr: {lr:.10f} ,norm: {norm:.4f} time {dt:.2f}ms, tokens/sec {tokens_per_sec:.2f}")

loaded 338025 tokens
step 0, loss:11.00564956665039,lr: 0.0000600000 ,norm: 7.1830 time 41452.72ms, tokens/sec 395.25
step 1, loss:11.014896392822266,lr: 0.0001200000 ,norm: 8.9117 time 102.46ms, tokens/sec 159912.40
step 2, loss:11.011943817138672,lr: 0.0001800000 ,norm: 7.9937 time 101.94ms, tokens/sec 160716.49
step 3, loss:10.99173355102539,lr: 0.0002400000 ,norm: 7.9452 time 102.49ms, tokens/sec 159852.51
step 4, loss:11.011993408203125,lr: 0.0003000000 ,norm: 8.1919 time 102.21ms, tokens/sec 160291.37
step 5, loss:11.006282806396484,lr: 0.0003600000 ,norm: 7.5789 time 102.31ms, tokens/sec 160134.87
step 6, loss:10.990325927734375,lr: 0.0004200000 ,norm: 7.1289 time 101.96ms, tokens/sec 160690.93
step 7, loss:10.999454498291016,lr: 0.0004800000 ,norm: 7.2375 time 102.28ms, tokens/sec 160181.90
step 8, loss:10.989326477050781,lr: 0.0005400000 ,norm: 7.9409 time 102.77ms, tokens/sec 159429.73
step 9, loss:10.975311279296875,lr: 0.0006000000 ,norm: 7.8172 time 102.59ms, tokens/sec 15