In [1]:
#!export CUDA_VISIBLE_DEVICES=0,1

In [2]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2,3,4,5,6,7"
import torch
print(torch.cuda.device_count())
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from random import randint

import tqdm

8


In [3]:
from rotation_utils import random_orthogonal_matrix
from hadamard_utils import random_hadamard_matrix, apply_exact_had_to_linear
from quant_utils import ActQuantWrapper

import utils
import model_utils
import data_utils
import transformers
import quant_utils
import rotation_utils
import gptq_utils
import eval_utils
import hadamard_utils

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from transformers.models.llama.modeling_llama import LlamaForCausalLM

In [5]:
class RotatedEmbedding(nn.Embedding):
    
    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
        _weight=None,
        _freeze=False,
        device=None,
        dtype=None,
        Q=None):
        super().__init__(num_embeddings, embedding_dim, padding_idx,
                         max_norm, norm_type, scale_grad_by_freq, sparse,
                         _weight, _freeze, device, dtype)
        
        if Q is not None:
            self.register_buffer("Q", Q)
            #self.register_parameter("Q", Q)
        else:
            self.Q = None
    
    def forward(self, x):
        W = self.weight
        
        #if W.device != self.Q.device:
        #    self.Q = self.Q.to(W.device)
        
        if self.Q is not None:
            W_ = torch.matmul(W.to(dtype=self.Q.dtype), self.Q.to(W.device)).to(dtype=W.dtype)
            #print('emb')
            #print(W_.grad_fn)
        else:
            W_ = W
        
        #print(x.device, W.device)
        
        return F.embedding(
            x, W_,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse)

class RotatedHead(nn.Linear):
    
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        device=None,
        dtype=None,
        Q=None):
        super().__init__(in_features, out_features, bias, device, dtype)
    
        if Q is not None:
            self.register_buffer("Q", Q)
            #self.register_parameter("Q", Q)
        else:
            self.Q = None
    
    def forward(self, x, Q=None):
        W = self.weight
        
        #if W.device != self.Q.device:
        #    self.Q = self.Q.to(W.device)
        
        if self.Q is not None:
            W_ = torch.matmul(W.to(dtype=self.Q.dtype), self.Q.to(W.device)).to(dtype=W.dtype)
            
            #print('head')
            #print(W_.grad_fn)
        else:
            W_ = W
        
        return F.linear(
            x, W_,
        )

class RotatedLinearIn(nn.Linear):
    
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        device=None,
        dtype=None,
        Q=None):
        super().__init__(in_features, out_features, bias, device, dtype)
    
        if Q is not None:
            self.register_buffer("Q", Q)
            #self.register_parameter("Q", Q)
        else:
            self.Q = None
    
    def forward(self, x, Q=None):
        W = self.weight
        
        #if W.device != self.Q.device:
        #    self.Q = self.Q.to(W.device)
        
        if self.Q is not None:
            W_ = torch.matmul(W.to(dtype=self.Q.dtype), self.Q.to(W.device)).to(dtype=W.dtype)
            
            #print('linear in')
            #print(W_.grad_fn)
        else:
            W_ = W
        
        return F.linear(
            x, W_,
        )


class RotatedOVProj(nn.Linear):
    
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        device=None,
        dtype=None,
        Qin=None,
        Qout=None,
        output=False,
        nheads=None):
        super().__init__(in_features, out_features, bias, device, dtype)
    
        if Qin is not None:
            self.register_buffer("Qin", Qin)
            #self.register_parameter("Qin", Qin)
        else:
            self.Qin = None
        
        if Qout is not None:
            self.register_buffer("Qout", Qout)
            #self.register_parameter("Qout", Qout)
        else:
            self.Qout = None
        
        self.output = output
        self.nheads = nheads
    
    def forward(self, x):
        W = self.weight
        
        # if W.device != self.Qin.device:
        #     self.Qin = self.Qin.to(W.device)
        
        # if W.device != self.Qout.device:
        #     self.Qout = self.Qout.to(W.device)
        
        if self.Qin is not None:
            if self.output:
                W_ = torch.matmul(W.to(dtype=self.Qin.dtype), self.Qin.to(W.device)).to(dtype=W.dtype)
            else:
                W_ = W.to(dtype=self.Qin.dtype).reshape(W.size(0), self.nheads, -1)
                W_ = torch.einsum('inh,hj->inj', W_, self.Qin.to(W.device)).reshape(W.size(0), -1).to(dtype=W.dtype)
                
                #print('linear o', W.grad_fn, W_.grad_fn, self.Qin.grad_fn)
        else:
            W_ = W
        
        if self.Qout is not None:
            if self.output:
                W_ = W_.to(dtype=self.Qout.dtype).reshape(self.nheads, -1, W.size(1))
                W_ = torch.einsum('ih,nhj->nij', self.Qout.to(W.device).T, W_).reshape(W.size(0), -1).to(dtype=W.dtype)
                
                #print('linear v', W.grad_fn, W_.grad_fn, self.Qout.grad_fn)
                #print(W_.grad_fn)
            else:
                W_ = torch.matmul(self.Qout.to(W.device).T, W_.to(dtype=self.Qout.dtype)).to(dtype=W.dtype)
        else:
            pass
        
        return F.linear(
            x, W_,
        )


class RotatedLinearOut(nn.Linear):
    
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        device=None,
        dtype=None,
        Q=None):
        super().__init__(in_features, out_features, bias, device, dtype)
    
        if Q is not None:
            self.register_buffer("Q", Q)
            #self.register_parameter("Q", Q)
        else:
            self.Q = None
    
    def forward(self, x, Q=None):
        W = self.weight
        b = self.bias
        
        # if W.device != self.Q.device:
        #     self.Q = self.Q.to(W.device)
        
        if self.Q is not None:
            W_ = torch.matmul(self.Q.to(W.device).T, W.to(dtype=self.Q.dtype)).to(dtype=W.dtype)
            
            #print('linear out')
            #print(W_.grad_fn)
            if b is not None:
                b_ = torch.matmul(self.Q.to(W.device).T, b.to(dtype=self.Q.dtype)).to(dtype=b.dtype)
            else:
                b_ = b
        else:
            W_ = W
            b_ = b
        
        return F.linear(
            x, W_, b_
        )

In [6]:
def rotate_embeddings(model, Q):
    
    original_emb = model.model.embed_tokens
    
    new_emb = RotatedEmbedding(
        original_emb.num_embeddings,
        original_emb.embedding_dim,
        original_emb.padding_idx,
        original_emb.max_norm,
        original_emb.norm_type,
        original_emb.scale_grad_by_freq,
        original_emb.sparse,
        original_emb.weight.data,
        not original_emb.weight.requires_grad,
        original_emb.weight.data.device,
        original_emb.weight.data.dtype,
        Q
    )
    
    setattr(model.model, 'embed_tokens', new_emb)


def rotate_attention_inputs(layer, Q) -> None:
    # Rotate the WQ, WK and WV matrices of the self-attention layer.
    for name in ['q_proj', 'k_proj']:#, 'v_proj']:
        original_matrix = getattr(layer.self_attn, name)
        
        new_matrix = RotatedLinearIn(
            original_matrix.in_features,
            original_matrix.out_features,
            original_matrix.bias is not None,
            original_matrix.weight.data.device,
            original_matrix.weight.data.dtype,
            Q
        )
        
        new_matrix.weight.data = original_matrix.weight.data.clone()
        if original_matrix.bias is not None:
            new_matrix.bias.data = original_matrix.bias.data.clone()
        
        setattr(layer.self_attn, name, new_matrix)
        del original_matrix


def rotate_attention_output(layer, Q) -> None:
    # Rotate output matrix of the self-attention layer.
    original_matrix = layer.self_attn.o_proj
    
    new_matrix = RotatedLinearOut(
        original_matrix.in_features,
        original_matrix.out_features,
        original_matrix.bias is not None,
        original_matrix.weight.data.device,
        original_matrix.weight.data.dtype,
        Q
    )
    
    new_matrix.weight.data = original_matrix.weight.data.clone()
    if original_matrix.bias is not None:
        new_matrix.bias.data = original_matrix.bias.data.clone()
    
    setattr(layer.self_attn, 'o_proj', new_matrix)
    del original_matrix


def rotate_mlp_input(layer, Q):
    # Rotate the MLP input weights.
    
    for name in ['up_proj', 'gate_proj']:
        original_matrix = getattr(layer.mlp, name)
        
        new_matrix = RotatedLinearIn(
            original_matrix.in_features,
            original_matrix.out_features,
            original_matrix.bias is not None,
            original_matrix.weight.data.device,
            original_matrix.weight.data.dtype,
            Q
        )
        
        new_matrix.weight.data = original_matrix.weight.data.clone()
        if original_matrix.bias is not None:
            new_matrix.bias.data = original_matrix.bias.data.clone()
        
        setattr(layer.mlp, name, new_matrix)
        del original_matrix


def rotate_mlp_output(layer, Q):
    # Rotate the MLP output weights and bias.
    original_matrix = layer.mlp.down_proj
    
    new_matrix = RotatedLinearOut(
        original_matrix.in_features,
        original_matrix.out_features,
        original_matrix.bias is not None,
        original_matrix.weight.data.device,
        original_matrix.weight.data.dtype,
        Q
    )
    
    new_matrix.weight.data = original_matrix.weight.data.clone()
    if original_matrix.bias is not None:
        new_matrix.bias.data = original_matrix.bias.data.clone()
    
    setattr(layer.mlp, 'down_proj', new_matrix)
    del original_matrix


def rotate_head(model, Q: torch.Tensor) -> None:
    # Rotate the head.
    original_matrix = model.lm_head
    
    new_matrix = RotatedLinearIn(
        original_matrix.in_features,
        original_matrix.out_features,
        original_matrix.bias is not None,
        original_matrix.weight.data.device,
        original_matrix.weight.data.dtype,
        Q
    )
    
    new_matrix.weight.data = original_matrix.weight.data.clone()
    if original_matrix.bias is not None:
        new_matrix.bias.data = original_matrix.bias.data.clone()
    
    setattr(model, 'lm_head', new_matrix)
    del original_matrix


def rotate_ov_proj(layer, Q1, Q2, nheads):
    #print(nheads)
    original_matrix = layer.self_attn.o_proj
    
    new_matrix = RotatedOVProj(
        original_matrix.in_features,
        original_matrix.out_features,
        original_matrix.bias is not None,
        original_matrix.weight.data.device,
        original_matrix.weight.data.dtype,
        Q2, Q1, False, nheads
    )
    
    new_matrix.weight.data = original_matrix.weight.data.clone()
    if original_matrix.bias is not None:
        new_matrix.bias.data = original_matrix.bias.data.clone()
    
    setattr(layer.self_attn, 'o_proj', new_matrix)
    del original_matrix
    
    original_matrix = layer.self_attn.v_proj
    
    new_matrix = RotatedOVProj(
        original_matrix.in_features,
        original_matrix.out_features,
        original_matrix.bias is not None,
        original_matrix.weight.data.device,
        original_matrix.weight.data.dtype,
        Q1, Q2, True, nheads
    )
    
    new_matrix.weight.data = original_matrix.weight.data.clone()
    if original_matrix.bias is not None:
        new_matrix.bias.data = original_matrix.bias.data.clone()
    
    setattr(layer.self_attn, 'v_proj', new_matrix)
    del original_matrix


def rotate_model(model, args):
    #q = random_orthogonal_matrix(model.config.hidden_size, utils.DEV).to(dtype=torch.float32)
    
    config = model.config
    num_heads = config.num_attention_heads
    model_dim = config.hidden_size
    head_dim = model_dim // num_heads
    
    Q1 = random_hadamard_matrix(model.config.hidden_size, 'cpu').to(dtype=torch.float64).requires_grad_(True)
    Q1 = nn.Parameter(Q1, requires_grad=True)
    Q2s = []
    #q2 = random_hadamard_matrix(head_dim, utils.DEV).to(dtype=torch.float32)
    #Q2 = nn.Parameter(q2, requires_grad=True)
    
    model_type = model_utils.model_type_extractor(model)
    rotate_embeddings(model, Q1)
    rotate_head(model, Q1)
    utils.cleanup_memory()
    layers = model_utils.get_transformer_layers(model, 
                                                model_type=model_type)
    
    for idx, layer in enumerate(tqdm.tqdm(layers, unit="layer", desc="Rotating")):
        Q2 = random_hadamard_matrix(head_dim, layers[idx].self_attn.v_proj.weight.device).to(dtype=torch.float64).requires_grad_(True)#.clone().detach().requires_grad_(True)
        #q2 = random_hadamard_matrix(head_dim, utils.DEV).to(dtype=torch.float64)#.clone().detach().requires_grad_(True)
        #print(q2.device)
        Q2 = nn.Parameter(Q2, requires_grad=True)
        rotate_attention_inputs(layers[idx], Q1)
        #rotate_attention_output(layers[idx], Q1)
        rotate_mlp_input(layers[idx], Q1)
        rotate_mlp_output(layers[idx], Q1)
        rotate_ov_proj(layers[idx], Q1, Q2, num_heads)
        
        Q2s.append(Q2)
        #print(str(idx) + '-----------')
        #print(layer.self_attn.v_proj)
        #print(layer.self_attn.o_proj)
    return Q1, Q2s

In [7]:
args = utils.parser_gen('--model meta-llama/Llama-2-7b-hf --rotate --a_bits 4 --v_bits 4 --k_bits 4 --w_bits 16 --w_clip --bsz 1'.split())

Arguments: 
{'a_asym': False,
 'a_bits': 4,
 'a_clip_ratio': 1.0,
 'a_groupsize': -1,
 'act_order': False,
 'bsz': 1,
 'cal_dataset': 'wikitext2',
 'capture_layer_io': False,
 'distribute': False,
 'eval_dataset': 'wikitext2',
 'fp32_had': False,
 'hf_token': None,
 'int8_down_proj': False,
 'k_asym': False,
 'k_bits': 4,
 'k_clip_ratio': 1.0,
 'k_groupsize': -1,
 'k_pre_rope': False,
 'layer_idx': 10,
 'lm_eval': False,
 'lm_eval_batch_size': 128,
 'load_qmodel_path': None,
 'model': 'meta-llama/Llama-2-7b-hf',
 'nsamples': 128,
 'percdamp': 0.01,
 'rotate': True,
 'rotate_mode': 'hadamard',
 'rotation_seed': -1,
 'save_name': '20240618_075552',
 'save_path': '/ceph/echoi/codes/QuaRot/fake_quant/experiments/meta-llama/Llama-2-7b-hf/20240618_075552',
 'save_qmodel_path': None,
 'seed': 0,
 'tasks': ['piqa',
           'hellaswag',
           'arc_easy',
           'arc_challenge',
           'winogrande',
           'lambada'],
 'v_asym': False,
 'v_bits': 4,
 'v_clip_ratio': 1.0,
 'v_

In [8]:
transformers.set_seed(args.seed)
model = model_utils.get_model(args.model, args.hf_token)
model.eval()

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  9.96it/s]
---> Loading meta-llama/Llama-2-7b-hf Model with seq_len: 2048


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head):

In [9]:
rotation_utils.fuse_layer_norms(model)
utils.cleanup_memory(verbos=True)


GPU memory (from <module>): 0.00 -> 0.00 GB (0.00 GB)


In [10]:
Q1, Q2s = rotate_model(model, args)
utils.cleanup_memory(verbos=True)

quant_utils.add_actquant(
    model,
    layers=[nn.Linear,
            ActQuantWrapper,
            RotatedHead,
            RotatedLinearIn,
            RotatedLinearOut,
            RotatedOVProj]
)

qlayers = quant_utils.find_qlayers(
    model.model,
    [nn.Linear,
     ActQuantWrapper,
     RotatedHead,
     RotatedLinearIn,
     RotatedLinearOut,
     RotatedOVProj])

for name in qlayers:
    if 'down_proj' in name:
        had_K, K = hadamard_utils.get_hadK(model.config.intermediate_size)
        qlayers[name].online_full_had = True
        qlayers[name].had_K = had_K
        qlayers[name].K = K
        qlayers[name].fp32_had = args.fp32_had
    # if 'o_proj' in name:
    #     had_K, K = hadamard_utils.get_hadK(model.config.num_attention_heads)
    #     qlayers[name].online_partial_had = True
    #     qlayers[name].had_K = had_K
    #     qlayers[name].K = K
    #     qlayers[name].had_dim = model.config.hidden_size//model.config.num_attention_heads
    #     qlayers[name].fp32_had = args.fp32_had

if args.w_bits < 16:
    save_dict = {}
    if args.load_qmodel_path: # Load Quantized Rotated Model
        assert args.rotate, "Model should be rotated to load a quantized model!"
        assert not args.save_qmodel_path, "Cannot save a quantized model if it is already loaded!"
        print("Load quantized model from ", args.load_qmodel_path)
        save_dict = torch.load(args.load_qmodel_path)
        model.load_state_dict(save_dict["model"])
        
    elif not args.w_rtn: # GPTQ Weight Quantization
        assert "llama" in args.model, "Only llama is supported for GPTQ!"
        
        trainloader = data_utils.get_loaders(
            args.cal_dataset, nsamples=args.nsamples,
            seed=args.seed, model=args.model,
            seqlen=model.seqlen, eval_mode=False
        )
        quantizers = gptq_utils.gptq_fwrd(
            model,
            trainloader,
            utils.DEV,
            args,
            [RotatedHead, RotatedLinearIn, RotatedLinearOut, RotatedOVProj])
        save_dict["w_quantizers"] = quantizers
    else: # RTN Weight Quantization
        quantizers = gptq_utils.rtn_fwrd(model, utils.DEV, args)
        save_dict["w_quantizers"] = quantizers
        
    if args.save_qmodel_path:
        save_dict["model"] = model.state_dict()
        torch.save(save_dict, args.save_qmodel_path)

if args.a_bits < 16 or args.v_bits < 16:
    qlayers = quant_utils.find_qlayers(model, layers=[quant_utils.ActQuantWrapper])
    down_proj_groupsize = -1
    if args.a_groupsize > 0 and "llama" in args.model:
        down_proj_groupsize = utils.llama_down_proj_groupsize(model, args.a_groupsize)
    
    for name in qlayers:            
        layer_input_bits = args.a_bits
        layer_groupsize = args.a_groupsize
        layer_a_sym = not(args.a_asym)
        layer_a_clip = args.a_clip_ratio
        
        if 'v_proj' in name and args.v_bits < 16: #Set the v_proj precision
            qlayers[name].out_quantizer.configure(bits=args.v_bits,
                                            groupsize=args.v_groupsize,
                                            sym=not(args.v_asym),
                                            clip_ratio=args.v_clip_ratio)
        
        if 'lm_head' in name: #Skip lm_head quantization   
            layer_input_bits = 16
        
        if 'down_proj' in name: #Set the down_proj precision
            if args.int8_down_proj:
                layer_input_bits = 8
            layer_groupsize = down_proj_groupsize

            
        qlayers[name].quantizer.configure(bits=layer_input_bits,
                                            groupsize=layer_groupsize,
                                            sym=layer_a_sym,
                                            clip_ratio=layer_a_clip)

if args.k_bits < 16:
    if args.k_pre_rope:
        raise NotImplementedError("Pre-RoPE quantization is not supported yet!")
    else:
        rope_function_name = model_utils.get_rope_function_name(model)
        layers = model_utils.get_layers(model)
        k_quant_config = {'k_bits':args.k_bits, "k_groupsize": args.k_groupsize,
                                        "k_sym": not(args.k_asym), "k_clip_ratio": args.k_clip_ratio}
        for layer in layers:
            rotation_utils.add_qk_rotation_wrapper_after_function_call_in_forward(
                        layer.self_attn, 
                        rope_function_name, 
                        config=model.config,
                        **k_quant_config)
# rope_function_name = model_utils.get_rope_function_name(model)
# layers = model_utils.get_layers(model)
# k_quant_config = {'k_bits': 16, "k_groupsize": args.k_groupsize,
#                     "k_sym": not(args.k_asym), "k_clip_ratio": args.k_clip_ratio}

# for idx, layer in enumerate(layers):
#     print('Wrapping QK', idx)
#     rotation_utils.add_qk_rotation_wrapper_after_function_call_in_forward(
#                 layer.self_attn, 
#                 rope_function_name, 
#                 config=model.config,
#                 **k_quant_config)

#model = model.to(utils.DEV)



GPU memory (from rotate_model): 0.00 -> 0.00 GB (0.00 GB)
Rotating: 100%|██████████| 32/32 [00:33<00:00,  1.05s/layer]
GPU memory (from <module>): 0.00 -> 0.00 GB (0.00 GB)


In [11]:
utils.distribute_model(model)

#for p in model.parameters():
#    p.requires_grad = False

Q1.requires_grad=True
for Q2 in Q2s:
    Q2.requires_grad=True

#optimizer = torch.optim.Adam([Q1, Q2], lr=1e-4)

{0: 50758680576, 1: 50758680576, 2: 50758680576, 3: 50758680576, 4: 50758680576, 5: 50758680576, 6: 50758680576, 7: 50758680576, 'cpu': 245580857344}
{0: 2219574570, 1: 2219574570, 2: 2219574570, 3: 2219574570, 4: 2219574570, 5: 2219574570, 6: 2219574570, 7: 50758680576, 'cpu': 245580857344}

Treating module model.
Not enough space on 0 to put model (space available 2219574554, module size 13423478532).
Splitting model.

Treating module model.embed_tokens.
Putting model.embed_tokens (size=396361728) on 0 (available=1812477170).

Treating module model.layers.
Not enough space on 0 to put model.layers (space available 1823212826, module size 13027116800).
Splitting model.layers.

Treating module model.layers.0.
Putting model.layers.0 (size=407097400) on 0 (available=1416115442).

Treating module model.layers.1.
Putting model.layers.1 (size=407097400) on 0 (available=1009018042).

Treating module model.layers.2.
Putting model.layers.2 (size=407097400) on 0 (available=601920642).

Treating

GPU memory (from distribute_model): 40.88 -> 40.88 GB (0.00 GB)


In [12]:
from accelerate import Accelerator

accelerator = Accelerator()

In [13]:
from stiefel import stiefel_optimizer
optimizer = stiefel_optimizer.SGDG([
    {'params': [Q1] + Q2s, 'lr': 1.5, 'momentum': 0.0, 'stiefel': True}
])

In [14]:
trainloader = data_utils.get_loaders(
    args.cal_dataset, nsamples=args.nsamples,
    seed=args.seed, model=args.model,
    seqlen=model.seqlen, eval_mode=False
)

from transformers.trainer_pt_utils import LabelSmoother
label_smoother = LabelSmoother(0.0)

In [15]:
#model, optimizer, trainloader = accelerator.prepare(model, optimizer, trainloader)

In [16]:
pbar = tqdm.tqdm(range(0 + 1, 100 + 1), desc="Training progress",
            total=100, dynamic_ncols=True)


Training progress:   0%|          | 0/100 [00:00<?, ?it/s]

In [17]:
idx_stack = None

In [18]:
# M1 = torch.zeros_like(Q1).requires_grad_(False)
# M2s = []

# for Q2 in Q2s:
#     M2 = torch.zeros_like(Q2).requires_grad_(False)
#     M2s.append(M2)

# beta = 0.9
# epsilon = 1e-8
# s = 5

# @torch.no_grad()
# def cayley_sgd(X, M, l, beta, epsilon, q, s):
#     if X.grad is not None:
#         M = beta * M - X.grad
#         #print('M', M.isnan().any().item())
#         MK = torch.matmul(M, X.T)
#         W_hat = MK - 0.5 * torch.matmul(X, torch.matmul(X.T, MK))
#         #print('W_hat', W_hat.isnan().any().item())
#         W = W_hat - W_hat.T
        
#         M = torch.matmul(W, X)
        
#         alpha = min(l, 2. * q / (torch.norm(W) + epsilon))
#         #print('alpha', alpha)
#         Y = X + alpha * M
        
#         for i in range(s):
#             Y = X + alpha / 2 * torch.matmul(W, X + Y)
        
#         X.data = Y
#         X.grad.fill_(0)

In [19]:
def lr_schedule(iter, total_iter, max_lr, min_lr):
    return max_lr - iter / total_iter * (max_lr - min_lr)

In [20]:
for iteration in pbar:
    if not idx_stack:
        idx_stack = list(range(0, len(trainloader)))
    
    idx = idx_stack.pop(randint(0, len(idx_stack) - 1))
    
    data = trainloader[idx]
    
    input = data[0]#.to(utils.DEV)
    target = data[1]#.to(utils.DEV)
    #print()
    output = model(input)
    
    loss = label_smoother(output, input, shift_labels=True)
    
    Q1.retain_grad()
    for Q2 in Q2s:
        Q2.retain_grad()
    optimizer.zero_grad()
    loss.backward()
    #accelerator.backward(loss)
    lr = lr_schedule(iteration, 100, 1.5, 0)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    for p in [Q1]+Q2s:
        print(p.grad is None)
    optimizer.step()
    
    with torch.no_grad():
        pbar.set_postfix(
            {'CE': f'{loss.item():.3f}',
             'ortho1': f'{(torch.matmul(Q1, Q1.T) - torch.eye(Q1.size(0)).to(Q1.device)).sum().item():.3f}',
             'ortho2': f'{(torch.matmul(Q2, Q2.T) - torch.eye(Q2.size(0)).to(Q2.device)).sum().item():.3f}',
             #'det(Q)': f'{torch.linalg.det(Q):.3f}'
             }
        )

Training progress:   1%|          | 1/100 [03:25<5:38:29, 205.15s/it, CE=12.595, ortho1=0.000, ortho2=0.000]

True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True


In [21]:
Q1.requires_grad

True

In [21]:
for layer in model.model.layers:
    print(layer.self_attn.v_proj.weight.device)

cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0


In [21]:
Q1

Parameter containing:
tensor([[-0.0154, -0.0159, -0.0153,  ..., -0.0157, -0.0155, -0.0157],
        [ 0.0160, -0.0157,  0.0155,  ..., -0.0158,  0.0159, -0.0156],
        [ 0.0156,  0.0157, -0.0153,  ...,  0.0157, -0.0157, -0.0158],
        ...,
        [-0.0156,  0.0157, -0.0158,  ...,  0.0157, -0.0155,  0.0157],
        [-0.0157, -0.0155,  0.0158,  ..., -0.0163,  0.0156,  0.0155],
        [-0.0155,  0.0157,  0.0156,  ...,  0.0156,  0.0159, -0.0160]],
       device='cuda:0', requires_grad=True)

In [22]:
Q2

Parameter containing:
tensor([[-0.0884, -0.0884, -0.0884,  ..., -0.0884, -0.0884, -0.0884],
        [ 0.0884, -0.0884,  0.0884,  ..., -0.0884,  0.0884, -0.0884],
        [ 0.0884,  0.0884, -0.0884,  ...,  0.0884, -0.0884, -0.0884],
        ...,
        [-0.0884,  0.0884, -0.0884,  ..., -0.0884,  0.0884, -0.0884],
        [ 0.0884,  0.0884, -0.0884,  ..., -0.0884,  0.0884,  0.0884],
        [-0.0884,  0.0884,  0.0884,  ..., -0.0884, -0.0884,  0.0884]],
       requires_grad=True)

In [24]:
Q2.grad

In [21]:
torch.save(Q1, 'Q1.pt')
torch.save(Q2, 'Q2.pt')

In [None]:
model.layers[0].self_attn.head_dim

NameError: name 'model' is not defined