In [1]:
#!export CUDA_VISIBLE_DEVICES=0,1

In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="4,5,6,7"
import torch
print(torch.cuda.device_count())
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from random import randint

import tqdm

0


In [2]:
from rotation_utils import random_orthogonal_matrix
from hadamard_utils import random_hadamard_matrix, apply_exact_had_to_linear
from quant_utils import ActQuantWrapper

import utils
import model_utils
import data_utils
import transformers
import quant_utils
import rotation_utils
import gptq_utils
import eval_utils
import hadamard_utils

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class RotatedEmbedding(nn.Embedding):
    
    def forward(self, x, Q=None):
        W = self.weight
        if Q is not None:
            W_ = torch.matmul(W, Q)
        else:
            W_ = W
        
        return F.embedding(
            x, W_,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse)

class RotatedHead(nn.Linear):
    
    def forward(self, x, Q=None):
        W = self.weight
        
        if Q is not None:
            W_ = torch.matmul(W, Q)
        else:
            W_ = W
        
        return F.linear(
            x, W_,
        )

class RotatedLinearIn(nn.Linear):
    def forward(self, x, Q=None):
        W = self.weight
        
        if Q is not None:
            W_ = torch.matmul(W, Q)
        else:
            W_ = W
        
        return F.linear(
            x, W_,
        )

class RotatedLinearOut(nn.Linear):
    def forward(self, x, Q=None):
        W = self.weight
        b = self.bias
        
        if Q is not None:
            W_ = torch.matmul(Q.T, W)
            if b is not None:
                b_ = torch.matmul(Q.T, b)
            else:
                b_ = b
        else:
            W_ = W
            b_ = b
        
        return F.linear(
            x, W_, b_
        )

In [4]:
from transformers.models.llama.modeling_llama import LlamaForCausalLM

In [5]:

class RotatedEmbedding(nn.Embedding):
    
    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
        _weight=None,
        _freeze=False,
        device=None,
        dtype=None,
        Q=None):
        super().__init__(num_embeddings, embedding_dim, padding_idx,
                         max_norm, norm_type, scale_grad_by_freq, sparse,
                         _weight, _freeze, device, dtype)
        
        if Q is not None:
            self.register_buffer("Q", Q)
        else:
            self.Q = None
    
    def forward(self, x):
        W = self.weight
        if self.Q is not None:
            W_ = torch.matmul(W.to(dtype=self.Q.dtype), self.Q.to(W.device)).to(dtype=W.dtype)
            #print('emb')
            #print(W_.grad_fn)
        else:
            W_ = W
            
        return F.embedding(
            x, W_,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse)

class RotatedHead(nn.Linear):
    
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        device=None,
        dtype=None,
        Q=None):
        super().__init__(in_features, out_features, bias, device, dtype)
    
        if Q is not None:
            self.register_buffer("Q", Q)
        else:
            self.Q = None
    
    def forward(self, x, Q=None):
        W = self.weight
        
        if self.Q is not None:
            W_ = torch.matmul(W.to(dtype=self.Q.dtype), self.Q.to(W.device)).to(dtype=W.dtype)
            
            #print('head')
            #print(W_.grad_fn)
        else:
            W_ = W
        
        return F.linear(
            x, W_,
        )

class RotatedLinearIn(nn.Linear):
    
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        device=None,
        dtype=None,
        Q=None):
        super().__init__(in_features, out_features, bias, device, dtype)
    
        if Q is not None:
            self.register_buffer("Q", Q)
        else:
            self.Q = None
    
    def forward(self, x, Q=None):
        W = self.weight
        
        if self.Q is not None:
            W_ = torch.matmul(W.to(dtype=self.Q.dtype), self.Q.to(W.device)).to(dtype=W.dtype)
            
            #print('linear in')
            #print(W_.grad_fn)
        else:
            W_ = W
        
        return F.linear(
            x, W_,
        )


class RotatedOVProj(nn.Linear):
    
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        device=None,
        dtype=None,
        Qin=None,
        Qout=None,
        output=False,
        nheads=None):
        super().__init__(in_features, out_features, bias, device, dtype)
    
        if Qin is not None:
            self.register_buffer("Qin", Qin)
        else:
            self.Qin = None
        
        if Qout is not None:
            self.register_buffer("Qout", Qout)
        else:
            self.Qout = None
        
        self.output = output
        self.nheads = nheads
    
    def forward(self, x):
        W = self.weight
        
        # O
        if self.Qin is not None:
            if self.output:
                W_ = torch.matmul(W.to(dtype=self.Qin.dtype), self.Qin.to(W.device)).to(dtype=W.dtype)
            else:
                W_ = W.to(dtype=self.Qin.dtype).reshape(W.size(0), self.nheads, -1)
                W_ = torch.einsum('inh,hj->inj', W_, self.Qin.to(W.device)).reshape(W.size(0), -1).to(dtype=W.dtype)
                
                #print('linear o')
                #print(W_.grad_fn)
        else:
            W_ = W
        
        # V
        if self.Qout is not None:
            if self.output:
                W_ = W_.to(dtype=self.Qout.dtype).reshape(self.nheads, -1, W.size(1))
                W_ = torch.einsum('ih,nhj->nij', self.Qout.to(W.device).T, W_).reshape(W.size(0), -1).to(dtype=W.dtype)
                
                #print('linear v')
                #print(W_.grad_fn)
            else:
                W_ = torch.matmul(self.Qout.to(W.device).T, W_.to(dtype=self.Qout.dtype)).to(dtype=W.dtype)
        else:
            pass
        
        return F.linear(
            x, W_,
        )


class RotatedLinearOut(nn.Linear):
    
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        device=None,
        dtype=None,
        Q=None):
        super().__init__(in_features, out_features, bias, device, dtype)
    
        if Q is not None:
            self.register_buffer("Q", Q)
        else:
            self.Q = None
    
    def forward(self, x, Q=None):
        W = self.weight
        b = self.bias
        
        if self.Q is not None:
            W_ = torch.matmul(self.Q.to(W.device).T, W.to(dtype=self.Q.dtype)).to(dtype=W.dtype)
            
            #print('linear out')
            #print(W_.grad_fn)
            if b is not None:
                b_ = torch.matmul(self.Q.to(W.device).T, b.to(dtype=self.Q.dtype)).to(dtype=b.dtype)
            else:
                b_ = b
        else:
            W_ = W
            b_ = b
        
        return F.linear(
            x, W_, b_
        )

In [6]:
def rotate_embeddings(model, Q):
    
    original_emb = model.model.embed_tokens
    
    new_emb = RotatedEmbedding(
        original_emb.num_embeddings,
        original_emb.embedding_dim,
        original_emb.padding_idx,
        original_emb.max_norm,
        original_emb.norm_type,
        original_emb.scale_grad_by_freq,
        original_emb.sparse,
        original_emb.weight.data,
        not original_emb.weight.requires_grad,
        original_emb.weight.data.device,
        original_emb.weight.data.dtype,
        Q
    )
    
    setattr(model.model, 'embed_tokens', new_emb)


def rotate_attention_inputs(layer, Q) -> None:
    # Rotate the WQ, WK and WV matrices of the self-attention layer.
    for name in ['q_proj', 'k_proj']:#, 'v_proj']:
        original_matrix = getattr(layer.self_attn, name)
        
        new_matrix = RotatedLinearIn(
            original_matrix.in_features,
            original_matrix.out_features,
            original_matrix.bias is not None,
            original_matrix.weight.data.device,
            original_matrix.weight.data.dtype,
            Q
        )
        
        new_matrix.weight.data = original_matrix.weight.data.clone()
        if original_matrix.bias is not None:
            new_matrix.bias.data = original_matrix.bias.data.clone()
        
        setattr(layer.self_attn, name, new_matrix)
        del original_matrix


def rotate_attention_output(layer, Q) -> None:
    # Rotate output matrix of the self-attention layer.
    original_matrix = layer.self_attn.o_proj
    
    new_matrix = RotatedLinearOut(
        original_matrix.in_features,
        original_matrix.out_features,
        original_matrix.bias is not None,
        original_matrix.weight.data.device,
        original_matrix.weight.data.dtype,
        Q
    )
    
    new_matrix.weight.data = original_matrix.weight.data.clone()
    if original_matrix.bias is not None:
        new_matrix.bias.data = original_matrix.bias.data.clone()
    
    setattr(layer.self_attn, 'o_proj', new_matrix)
    del original_matrix


def rotate_mlp_input(layer, Q):
    # Rotate the MLP input weights.
    
    for name in ['up_proj', 'gate_proj']:
        original_matrix = getattr(layer.mlp, name)
        
        new_matrix = RotatedLinearIn(
            original_matrix.in_features,
            original_matrix.out_features,
            original_matrix.bias is not None,
            original_matrix.weight.data.device,
            original_matrix.weight.data.dtype,
            Q
        )
        
        new_matrix.weight.data = original_matrix.weight.data.clone()
        if original_matrix.bias is not None:
            new_matrix.bias.data = original_matrix.bias.data.clone()
        
        setattr(layer.mlp, name, new_matrix)
        del original_matrix


def rotate_mlp_output(layer, Q):
    # Rotate the MLP output weights and bias.
    original_matrix = layer.mlp.down_proj
    
    new_matrix = RotatedLinearOut(
        original_matrix.in_features,
        original_matrix.out_features,
        original_matrix.bias is not None,
        original_matrix.weight.data.device,
        original_matrix.weight.data.dtype,
        Q
    )
    
    new_matrix.weight.data = original_matrix.weight.data.clone()
    if original_matrix.bias is not None:
        new_matrix.bias.data = original_matrix.bias.data.clone()
    
    setattr(layer.mlp, 'down_proj', new_matrix)
    del original_matrix


def rotate_head(model, Q: torch.Tensor) -> None:
    # Rotate the head.
    original_matrix = model.lm_head
    
    new_matrix = RotatedLinearIn(
        original_matrix.in_features,
        original_matrix.out_features,
        original_matrix.bias is not None,
        original_matrix.weight.data.device,
        original_matrix.weight.data.dtype,
        Q
    )
    
    new_matrix.weight.data = original_matrix.weight.data.clone()
    if original_matrix.bias is not None:
        new_matrix.bias.data = original_matrix.bias.data.clone()
    
    setattr(model, 'lm_head', new_matrix)
    del original_matrix


def rotate_ov_proj(layer, Q1, Q2, nheads):
    #print(nheads)
    original_matrix = layer.self_attn.o_proj
    
    new_matrix = RotatedOVProj(
        original_matrix.in_features,
        original_matrix.out_features,
        original_matrix.bias is not None,
        original_matrix.weight.data.device,
        original_matrix.weight.data.dtype,
        Q2, Q1, False, nheads
    )
    
    new_matrix.weight.data = original_matrix.weight.data.clone()
    if original_matrix.bias is not None:
        new_matrix.bias.data = original_matrix.bias.data.clone()
    
    setattr(layer.self_attn, 'o_proj', new_matrix)
    del original_matrix
    
    original_matrix = layer.self_attn.v_proj
    
    new_matrix = RotatedOVProj(
        original_matrix.in_features,
        original_matrix.out_features,
        original_matrix.bias is not None,
        original_matrix.weight.data.device,
        original_matrix.weight.data.dtype,
        Q1, Q2, True, nheads
    )
    
    new_matrix.weight.data = original_matrix.weight.data.clone()
    if original_matrix.bias is not None:
        new_matrix.bias.data = original_matrix.bias.data.clone()
    
    setattr(layer.self_attn, 'v_proj', new_matrix)
    del original_matrix


def rotate_model(model, args):
    #q = random_orthogonal_matrix(model.config.hidden_size, utils.DEV).to(dtype=torch.float32)
    
    config = model.config
    num_heads = config.num_attention_heads
    model_dim = config.hidden_size
    head_dim = model_dim // num_heads
    
    q1 = random_hadamard_matrix(model.config.hidden_size, utils.DEV).to(dtype=torch.float64)
    Q1 = nn.Parameter(q1, requires_grad=True)
    Q2s = []
    #q2 = random_hadamard_matrix(head_dim, utils.DEV).to(dtype=torch.float32)
    #Q2 = nn.Parameter(q2, requires_grad=True)
    
    model_type = model_utils.model_type_extractor(model)
    rotate_embeddings(model, Q1)
    rotate_head(model, Q1)
    utils.cleanup_memory()
    layers = model_utils.get_transformer_layers(model, 
                                                model_type=model_type)
    
    for idx, layer in enumerate(tqdm.tqdm(layers, unit="layer", desc="Rotating")):
        q2 = random_hadamard_matrix(head_dim, layers[idx].self_attn.v_proj.weight.device).to(dtype=torch.float64)#.clone().detach().requires_grad_(True)
        Q2 = nn.Parameter(q2, requires_grad=True)
        rotate_attention_inputs(layers[idx], Q1)
        #rotate_attention_output(layers[idx], Q1)
        rotate_mlp_input(layers[idx], Q1)
        rotate_mlp_output(layers[idx], Q1)
        rotate_ov_proj(layers[idx], Q1, Q2, num_heads)
        
        Q2s.append(Q2)
        #print(str(idx) + '-----------')
        #print(layer.self_attn.v_proj)
        #print(layer.self_attn.o_proj)
    return Q1, Q2s

In [7]:
args = utils.parser_gen('--model meta-llama/Llama-2-7b-hf --rotate --a_bits 4 --v_bits 4 --k_bits 4 --w_bits 4 --w_clip --bsz 1'.split())

Arguments: 
{'a_asym': False,
 'a_bits': 4,
 'a_clip_ratio': 1.0,
 'a_groupsize': -1,
 'act_order': False,
 'bsz': 1,
 'cal_dataset': 'wikitext2',
 'capture_layer_io': False,
 'distribute': False,
 'eval_dataset': 'wikitext2',
 'fp32_had': False,
 'hf_token': None,
 'int8_down_proj': False,
 'k_asym': False,
 'k_bits': 4,
 'k_clip_ratio': 1.0,
 'k_groupsize': -1,
 'k_pre_rope': False,
 'layer_idx': 10,
 'learn_r1': True,
 'learn_r2': True,
 'lm_eval': False,
 'lm_eval_batch_size': 128,
 'load_qmodel_path': None,
 'model': 'meta-llama/Llama-2-7b-hf',
 'momentum': 0.0,
 'nsamples': 128,
 'percdamp': 0.01,
 'prefix_r': '',
 'rotate': True,
 'rotate_mode': 'hadamard',
 'rotation_seed': -1,
 'save_name': '20240712_120507',
 'save_path': '/ceph/echoi/codes/QuaRot/fake_quant/experiments/meta-llama/Llama-2-7b-hf/20240712_120507',
 'save_qmodel_path': None,
 'seed': 0,
 'tasks': ['piqa',
           'hellaswag',
           'arc_easy',
           'arc_challenge',
           'winogrande',
        

In [None]:
transformers.set_seed(args.seed)
model = model_utils.get_model(args.model, args.hf_token)
model.eval()

Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.81s/it]
---> Loading meta-llama/Llama-2-7b-hf Model with seq_len: 2048


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head):

In [None]:
for name, p in model.named_parameters():
    print(name, p.dtype)

model.embed_tokens.weight torch.float16
model.layers.0.self_attn.q_proj.weight torch.float16
model.layers.0.self_attn.k_proj.weight torch.float16
model.layers.0.self_attn.v_proj.weight torch.float16
model.layers.0.self_attn.o_proj.weight torch.float16
model.layers.0.mlp.gate_proj.weight torch.float16
model.layers.0.mlp.up_proj.weight torch.float16
model.layers.0.mlp.down_proj.weight torch.float16
model.layers.0.input_layernorm.weight torch.float16
model.layers.0.post_attention_layernorm.weight torch.float16
model.layers.1.self_attn.q_proj.weight torch.float16
model.layers.1.self_attn.k_proj.weight torch.float16
model.layers.1.self_attn.v_proj.weight torch.float16
model.layers.1.self_attn.o_proj.weight torch.float16
model.layers.1.mlp.gate_proj.weight torch.float16
model.layers.1.mlp.up_proj.weight torch.float16
model.layers.1.mlp.down_proj.weight torch.float16
model.layers.1.input_layernorm.weight torch.float16
model.layers.1.post_attention_layernorm.weight torch.float16
model.layers.2

In [None]:
rotation_utils.fuse_layer_norms(model)
utils.cleanup_memory(verbos=True)
utils.distribute_model(model)

GPU memory (from <module>): 0.00 -> 0.00 GB (0.00 GB)
GPU memory (from distribute_model): 12.62 -> 12.62 GB (0.00 GB)


In [None]:
Q1, Q2s = rotate_model(model, args)
utils.cleanup_memory(verbos=True)

# quant_utils.add_actquant(
#     model,
#     layers=[nn.Linear,
#             ActQuantWrapper,
#             RotatedHead,
#             RotatedLinearIn,
#             RotatedLinearOut,
#             RotatedOVProj]
# )

# qlayers = quant_utils.find_qlayers(model.model, [nn.Linear, ActQuantWrapper, RotatedHead, RotatedLinearIn, RotatedLinearOut, RotatedOVProj])

# for name in qlayers:
#     if 'down_proj' in name:
#         had_K, K = hadamard_utils.get_hadK(model.config.intermediate_size)
#         qlayers[name].online_full_had = True
#         qlayers[name].had_K = had_K
#         qlayers[name].K = K
#         qlayers[name].fp32_had = args.fp32_had
#     if 'o_proj' in name:
#         had_K, K = hadamard_utils.get_hadK(model.config.num_attention_heads)
#         qlayers[name].online_partial_had = True
#         qlayers[name].had_K = had_K
#         qlayers[name].K = K
#         qlayers[name].had_dim = model.config.hidden_size//model.config.num_attention_heads
#         qlayers[name].fp32_had = args.fp32_had

# rope_function_name = model_utils.get_rope_function_name(model)
# layers = model_utils.get_layers(model)
# k_quant_config = {'k_bits': 16, "k_groupsize": args.k_groupsize,
#                     "k_sym": not(args.k_asym), "k_clip_ratio": args.k_clip_ratio}

# for idx, layer in enumerate(layers):
#     print('Wrapping QK', idx)
#     rotation_utils.add_qk_rotation_wrapper_after_function_call_in_forward(
#                 layer.self_attn, 
#                 rope_function_name, 
#                 config=model.config,
#                 **k_quant_config)


for p in model.parameters():
    p.requires_grad = False

Q1.requires_grad=True
for Q2 in Q2s:
    Q2.requires_grad=True

#optimizer = torch.optim.Adam([Q1, Q2], lr=1e-4)

GPU memory (from rotate_model): 13.30 -> 12.68 GB (-0.61 GB)
Rotating: 100%|██████████| 32/32 [00:00<00:00, 283.84layer/s]
GPU memory (from <module>): 25.56 -> 13.37 GB (-12.19 GB)


In [None]:
trainloader = data_utils.get_loaders(
    args.cal_dataset, nsamples=args.nsamples,
    seed=args.seed, model=args.model,
    seqlen=model.seqlen, eval_mode=False
)

from transformers.trainer_pt_utils import LabelSmoother
label_smoother = LabelSmoother(0.0)

In [13]:
pbar = tqdm.tqdm(range(0 + 1, 100 + 1), desc="Training progress",
            total=100, dynamic_ncols=True)


Training progress:   0%|          | 0/100 [00:00<?, ?it/s]

In [14]:
idx_stack = None

In [15]:
# Straightforward ortho reg w/ SFTT
# for iteration in pbar:
#     if not idx_stack:
#         idx_stack = list(range(0, len(trainloader)))
    
#     idx = idx_stack.pop(randint(0, len(idx_stack) - 1))
    
#     data = trainloader[idx]
    
#     input = data[0]
#     target = data[1]
    
#     output = model(input)
    
#     loss = label_smoother(output, input, shift_labels=True)
#     sym = torch.mm(Q, torch.t(Q))
#     sym -= torch.eye(Q.shape[0]).to(sym.device)
#     # ls_ort = sym.abs().sum()   # poor match to geometry of orthogonal matrices
#     ortho_reg = sym.pow(2.0).sum()
#     loss = loss + 10.0 * ortho_reg
    
#     with torch.no_grad():
#         pbar.set_postfix(
#             {'CE': f'{loss.item():.3f}',
#             'Ortho': f'{ortho_reg.item():.3f}',
#             'det(Q)': f'{torch.linalg.det(Q):.3f}'}
#         )
    
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
# torch.save(Q, 'Q_Hreg1k.pt')
# Q


In [16]:
# Pre-forward hook-style test
# q = random_orthogonal_matrix(2, device).to(dtype=torch.float32)

# Q = nn.Parameter(q, requires_grad=True)

# print('Q', Q)

# layer = nn.Linear(2, 2, False).to(device)

# for p in layer.parameters():
#     p.requires_grad = False
    
# layer.register_buffer('Q', Q)

# def hook(module, input):
#     x = F.linear(input[0], module.Q)
#     return (x,) + input[1:]

# layer.register_forward_pre_hook(hook)

# x = torch.eye(2).to(device)
# target = torch.eye(2).to(device)

# optim = torch.optim.Adam([Q], lr=1e-2)

# for i in range(1000000):
#     y = layer(x)
#     loss = F.mse_loss(y, target)
    
#     sym = torch.mm(Q, torch.t(Q))
#     sym -= torch.eye(Q.shape[0]).to(device)
#     # ls_ort = sym.abs().sum()   # poor match to geometry of orthogonal matrices
#     ortho_reg = sym.pow(2.0).sum()
    
#     loss = loss + 0.1 * ortho_reg
    
#     optim.zero_grad()
#     loss.backward()
#     optim.step()
    
#     if i % 10000 == 0:
#         #print("-----W-----")
#         #print(layer.weight)
#         print("-----Q-----")
#         print(Q)
#         with torch.no_grad():
#             print(torch.linalg.det(Q).item(), ortho_reg.item())
#         print("-----Y-----")
#         print(y)

In [17]:
M1 = torch.zeros_like(Q1).requires_grad_(False)
M2s = []

for Q2 in Q2s:
    M2 = torch.zeros_like(Q2).requires_grad_(False)
    M2s.append(M2)

beta = 0.9
epsilon = 1e-8
s = 5

@torch.no_grad()
def cayley_sgd(X, M, l, beta, epsilon, q, s):
    M = beta * M - X.grad
    #print('M', M.isnan().any().item())
    MK = torch.matmul(M, X.T)
    W_hat = MK - 0.5 * torch.matmul(X, torch.matmul(X.T, MK))
    #print('W_hat', W_hat.isnan().any().item())
    W = W_hat - W_hat.T
    
    M = torch.matmul(W, X)
    
    alpha = min(l, 2. * q / (torch.norm(W) + epsilon))
    #print('alpha', alpha)
    Y = X + alpha * M
    
    for i in range(s):
        Y = X + alpha / 2 * torch.matmul(W, X + Y)
    
    X.data = Y
    X.grad.fill_(0)

In [18]:
def lr_schedule(iter, total_iter, max_lr, min_lr):
    return max_lr - iter / total_iter * (max_lr - min_lr)

In [19]:
for iteration in pbar:
    if not idx_stack:
        idx_stack = list(range(0, len(trainloader)))
    
    idx = idx_stack.pop(randint(0, len(idx_stack) - 1))
    
    data = trainloader[idx]
    
    input = data[0]
    target = data[1]
    #print()
    output = model(input)
    
    loss = label_smoother(output, input, shift_labels=True)
    
    #Q1.retain_grad()
    #for Q2 in Q2s:
    #    Q2.retain_grad()
    loss.backward()
    lr = lr_schedule(iteration, 100, 1.5, 0)
    #print(iteration, output['logits'].isnan().any().item(), loss.isnan().any().item(), lr)
    cayley_sgd(Q1, M1, lr, 0.9, 1e-8, 0.5, 5)
    count = 0
    for Q2, M2 in zip(Q2s, M2s):
        cayley_sgd(Q2, M2, lr, 0.9, 1e-8, 0.5, 5)
        #if Q2.grad is not None:
        #    count += 1
        try:
            with torch.no_grad():
                _ = torch.linalg.cholesky(Q2)
            count += 1
        except:
            pass
    print(iteration, count)
    
    with torch.no_grad():
        pbar.set_postfix(
            {'CE': f'{loss.item():.3f}',
             'ortho1': f'{(torch.matmul(Q1, Q1.T) - torch.eye(Q1.size(0)).to(Q1.device)).sum().item():.3f}',
             'ortho2': f'{(torch.matmul(Q2, Q2.T) - torch.eye(Q2.size(0)).to(Q2.device)).sum().item():.3f}',
             #'det(Q)': f'{torch.linalg.det(Q):.3f}'
             }
        )

Training progress:   1%|          | 1/100 [00:08<14:38,  8.88s/it, CE=1.801, ortho1=-0.000, ortho2=0.000]

1 0


Training progress:   2%|▏         | 2/100 [00:17<13:46,  8.44s/it, CE=1.725, ortho1=-0.000, ortho2=-0.000]

2 0


Training progress:   3%|▎         | 3/100 [00:24<13:12,  8.17s/it, CE=1.901, ortho1=-0.000, ortho2=0.000] 

3 0


Training progress:   4%|▍         | 4/100 [00:32<12:51,  8.03s/it, CE=1.074, ortho1=-0.000, ortho2=-0.000]

4 0


Training progress:   5%|▌         | 5/100 [00:40<12:36,  7.96s/it, CE=1.887, ortho1=-0.000, ortho2=0.000] 

5 0


Training progress:   6%|▌         | 6/100 [00:48<12:23,  7.91s/it, CE=1.509, ortho1=-0.000, ortho2=0.000]

6 0


Training progress:   7%|▋         | 7/100 [00:56<12:09,  7.84s/it, CE=1.758, ortho1=-0.000, ortho2=-0.000]

7 0


Training progress:   8%|▊         | 8/100 [01:03<12:01,  7.84s/it, CE=2.013, ortho1=-0.000, ortho2=-0.000]

8 0


Training progress:   9%|▉         | 9/100 [01:11<11:53,  7.84s/it, CE=1.449, ortho1=-0.000, ortho2=-0.000]

9 0


Training progress:  10%|█         | 10/100 [01:19<11:46,  7.85s/it, CE=1.852, ortho1=-0.000, ortho2=-0.000]

10 0


Training progress:  11%|█         | 11/100 [01:27<11:40,  7.87s/it, CE=1.910, ortho1=-0.000, ortho2=-0.000]

11 0


Training progress:  12%|█▏        | 12/100 [01:35<11:34,  7.89s/it, CE=1.834, ortho1=-0.000, ortho2=-0.000]

12 0


Training progress:  13%|█▎        | 13/100 [01:43<11:25,  7.88s/it, CE=2.189, ortho1=-0.000, ortho2=-0.000]

13 0


Training progress:  14%|█▍        | 14/100 [01:51<11:21,  7.92s/it, CE=2.136, ortho1=-0.000, ortho2=-0.000]

14 0


Training progress:  15%|█▌        | 15/100 [01:59<11:12,  7.91s/it, CE=1.591, ortho1=-0.000, ortho2=-0.000]

15 0


Training progress:  16%|█▌        | 16/100 [02:07<11:03,  7.90s/it, CE=1.766, ortho1=-0.000, ortho2=-0.000]

16 0


Training progress:  17%|█▋        | 17/100 [02:14<10:56,  7.90s/it, CE=1.645, ortho1=-0.000, ortho2=-0.000]

17 0


Training progress:  18%|█▊        | 18/100 [02:23<10:51,  7.94s/it, CE=1.564, ortho1=-0.000, ortho2=-0.000]

18 0


Training progress:  19%|█▉        | 19/100 [02:31<10:47,  7.99s/it, CE=2.075, ortho1=-0.000, ortho2=-0.000]

19 0


Training progress:  20%|██        | 20/100 [02:39<10:38,  7.98s/it, CE=1.623, ortho1=-0.000, ortho2=-0.000]

20 0


Training progress:  21%|██        | 21/100 [02:47<10:29,  7.97s/it, CE=1.642, ortho1=-0.000, ortho2=0.000] 

21 0


Training progress:  22%|██▏       | 22/100 [02:54<10:19,  7.94s/it, CE=1.613, ortho1=-0.000, ortho2=-0.000]

22 0


Training progress:  23%|██▎       | 23/100 [03:02<10:10,  7.93s/it, CE=1.514, ortho1=-0.000, ortho2=-0.000]

23 0


Training progress:  24%|██▍       | 24/100 [03:10<10:03,  7.95s/it, CE=1.578, ortho1=-0.000, ortho2=-0.000]

24 0


Training progress:  25%|██▌       | 25/100 [03:18<09:56,  7.96s/it, CE=1.837, ortho1=-0.000, ortho2=0.000] 

25 0


In [20]:
Q1

Parameter containing:
tensor([[-0.0156, -0.0158, -0.0156,  ..., -0.0156, -0.0156, -0.0156],
        [ 0.0156, -0.0157,  0.0157,  ..., -0.0157,  0.0156, -0.0156],
        [ 0.0157,  0.0156, -0.0156,  ...,  0.0157, -0.0156, -0.0157],
        ...,
        [-0.0155,  0.0155, -0.0155,  ...,  0.0156, -0.0157,  0.0156],
        [-0.0156, -0.0156,  0.0156,  ..., -0.0157,  0.0157,  0.0157],
        [-0.0156,  0.0155,  0.0156,  ...,  0.0157,  0.0155, -0.0157]],
       device='cuda:0', requires_grad=True)

In [21]:
for Q2 in Q2s:
    print(Q2)

Parameter containing:
tensor([[-0.0884, -0.0884, -0.0884,  ..., -0.0884, -0.0884, -0.0884],
        [ 0.0884, -0.0884,  0.0884,  ..., -0.0884,  0.0884, -0.0884],
        [ 0.0884,  0.0884, -0.0884,  ...,  0.0884, -0.0884, -0.0884],
        ...,
        [-0.0884,  0.0884, -0.0884,  ..., -0.0884,  0.0884, -0.0884],
        [ 0.0884,  0.0884, -0.0884,  ..., -0.0884,  0.0884,  0.0884],
        [-0.0884,  0.0884,  0.0884,  ..., -0.0884, -0.0884,  0.0884]],
       device='cuda:0', requires_grad=True)
Parameter containing:
tensor([[ 0.0884,  0.0884,  0.0884,  ...,  0.0884,  0.0884,  0.0884],
        [ 0.0884, -0.0884,  0.0884,  ..., -0.0884,  0.0884, -0.0884],
        [-0.0884, -0.0884,  0.0884,  ..., -0.0884,  0.0884,  0.0884],
        ...,
        [ 0.0884, -0.0884,  0.0884,  ...,  0.0884, -0.0884,  0.0884],
        [ 0.0884,  0.0884, -0.0884,  ..., -0.0884,  0.0884,  0.0884],
        [-0.0884,  0.0884,  0.0884,  ..., -0.0884, -0.0884,  0.0884]],
       device='cuda:0', requires_grad=True)


In [20]:
torch.save(Q1, 'Q1.pt')
for idx, Q2 in enumerate(Q2s):
    torch.save(Q2, f'Q2_{idx}.pt')

In [None]:
model.layers[0].self_attn.head_dim

NameError: name 'model' is not defined

In [21]:
Q2s[0].size()

torch.Size([128, 128])