In [None]:
from datasets import load_dataset

In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2,3,4,5,6,7"

import torch
import torch.nn as nn
import torch.nn.functional as F

from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory

from quant_utils import ActQuantizer

from utils import distribute_model
import hadamard_utils
import fast_hadamard_transform

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class RotatedLinear(nn.Linear):
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        device=None,
        dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
    
    def forward(self, x, Q=None):
        W = self.weight
        b = self.bias
        if Q is None:
            W_ = W
            b_ = b

        else:
        
        #if W.device != self.Q.device:
        #    self.Q = self.Q.to(W.device)
        
            W_ = torch.matmul(Q.to(W.device).T, W.to(dtype=Q.dtype)).to(dtype=W.dtype)
            
            #print('linear out')
            #print(W_.grad_fn)
            if b is not None:
                b_ = torch.matmul(Q.to(W.device).T, b.to(dtype=Q.dtype)).to(dtype=b.dtype)
            else:
                b_ = b
        
        x = torch.nn.functional.linear(x, W_, b_)
        return x
        # return F.linear(
        #     x, W_, b_
        # )

class RotatedOVProj(nn.Linear):
    
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        device=None,
        dtype=None,
        output=False,
        nheads=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        
        self.output = output
        self.nheads = nheads
    
    def forward(self, x, Qin=None, Qout=None):
        W = self.weight
        
        # if W.device != self.Qin.device:
        #     self.Qin = self.Qin.to(W.device)
        
        # if W.device != self.Qout.device:
        #     self.Qout = self.Qout.to(W.device)
        
        if Qin is not None:
            if self.output:
                W_ = torch.matmul(W.to(dtype=Qin.dtype), Qin.to(W.device)).to(dtype=W.dtype)
            else:
                W_ = W.to(dtype=Qin.dtype).reshape(W.size(0), self.nheads, -1)
                print('linear o', W_.size(), Qin.size())
                W_ = torch.einsum('inh,hj->inj', W_, Qin.to(W.device)).reshape(W.size(0), -1).to(dtype=W.dtype)
                
                #print('linear o', W.grad_fn, W_.grad_fn, self.Qin.grad_fn)
        else:
            W_ = W
        
        if Qout is not None:
            if self.output:
                W_ = W_.to(dtype=Qout.dtype).reshape(self.nheads, -1, W.size(1))
                print('linear v', W_.size(), Qin.size())
                W_ = torch.einsum('ih,nhj->nij', Qout.to(W.device).T, W_).reshape(W.size(0), -1).to(dtype=W.dtype)
                
                #print('linear v', W.grad_fn, W_.grad_fn, self.Qout.grad_fn)
                #print(W_.grad_fn)
            else:
                W_ = torch.matmul(Qout.to(W.device).T, W_.to(dtype=Qout.dtype)).to(dtype=W.dtype)
        else:
            pass
        
        x = torch.nn.functional.linear(x, W_)
        return x

In [3]:
class ActQuantWrapper(torch.nn.Module):
    '''
        This class is a wrapper for the activation quantization.
        We extract the FP features in the forward pass and quantize the rest using
        the self.quantizer object.
        If a rotation Q is provided, the weight matrix will be rotated,
        a pre-forward hook will be registerd to rotate the activation before quantization.
    '''

    def __init__(self, module:torch.nn.Linear):
        super(ActQuantWrapper, self).__init__()
        assert isinstance(module, torch.nn.Linear)
        self.module = module
        self.weight = module.weight
        self.bias = module.bias
        self.quantizer = ActQuantizer()
        self.out_quantizer = ActQuantizer()
        self.register_buffer('had_K', torch.tensor(0))
        self._buffers['had_K'] = None
        self.K = 1
        self.online_full_had = False
        self.online_partial_had = False
        self.had_dim = 0
        self.fp32_had = False

    def extra_repr(self) -> str:
        str_ = f'Input Quantizer Bits: {self.quantizer.bits}'
        if self.quantizer.bits < 16:
            str_ += f' (Asymmetric Per-Token)' if not self.quantizer.sym else f' (Symmetric Per-Token)'

        str_ += f'\nOutput Quantizer Bits: {self.out_quantizer.bits}'
        if self.out_quantizer.bits < 16:
            str_ += f' (Asymmetric Per-Token)' if not self.out_quantizer.sym else f' (Symmetric Per-Token)'

        return str_

    def forward(self, x, Q=None, Qin=None, Qout=None):
        x_dtype = x.dtype

        # Rotate, if needed
        if self.online_full_had:
            
            if self.fp32_had: # Full Hadamard in FP32
                x = hadamard_utils.matmul_hadU_cuda(x.float(), self.had_K, self.K).to(x_dtype)
            else: # Full Hadamard in FP16
                x = hadamard_utils.matmul_hadU_cuda(x, self.had_K, self.K)
            
        elif self.online_partial_had:
            # todo: implement this in QAttention to avoid reshaping!
            
            if self.fp32_had:
                x = x.float()
                
            init_shape = x.shape
            if self.K == 1:
                x = fast_hadamard_transform.hadamard_transform(x.reshape(-1, init_shape[-1]//self.had_dim, self.had_dim).transpose(1, 2),
                                                               scale=1/math.sqrt(init_shape[-1]//self.had_dim)).transpose(1, 2)
            else:
                x = (self.had_K.to(x.dtype) @ x.reshape(-1, init_shape[-1]//self.had_dim, self.had_dim)) / math.sqrt(init_shape[-1]//self.had_dim)
                
            if self.fp32_had:
                x = x.to(x_dtype)
            x = x.reshape(init_shape)

        if self.quantizer.bits < 16: #Quantize, if needed
            self.quantizer.find_params(x)
            x = self.quantizer(x).to(x_dtype)
            self.quantizer.free()

        if isinstance(self.module, RotatedLinear):
            x = self.module(x, Q=Q).to(x_dtype)
        elif isinstance(self.module, RotatedOVProj):
            x = self.module(x, Qin=Qin, Qout=Qout).to(x_dtype)
        else:
            x = self.module(x).to(x_dtype)

        if self.out_quantizer.bits < 16: #Quantize the output, if needed
            self.out_quantizer.find_params(x)
            x = self.out_quantizer(x).to(x_dtype)
            self.out_quantizer.free()

        return x

In [4]:
a = RotatedLinear(4, 4, False)
b = RotatedLinear(4, 4, False)
c = RotatedOVProj(4, 4, False, output=True, nheads=2)
d = RotatedOVProj(4, 4, False, output=False, nheads=2)
e = RotatedOVProj(4, 4, False, output=True, nheads=2)
f = RotatedOVProj(4, 4, False, output=False, nheads=2)

In [5]:
class OVWrapper(nn.Module):
    def __init__(self, O, V):
        super().__init__()
        self.O = O
        self.V = V
    
    def forward(self, x, Q1, Q2):
        x = self.V(x, Qin=Q1, Qout=Q2)
        x = self.O(x, Qin=Q2, Qout=Q1)
        return x

In [6]:
A = ActQuantWrapper(a)
B = ActQuantWrapper(b)
C = ActQuantWrapper(c)
D = ActQuantWrapper(d)
E = ActQuantWrapper(e)
F = ActQuantWrapper(f)

In [7]:
class Wrapper(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module
    
    def forward(self, x, Q=None):
        x = self.module(x, Q=Q)
        
        return x

In [8]:
A_ = Wrapper(A)
B_ = Wrapper(B)
# C_ = Wrapper(C)
# D_ = Wrapper(D)
# E_ = Wrapper(E)
# F_ = Wrapper(F)
L1 = OVWrapper(D, C)
L2 = OVWrapper(F, E)

In [9]:
class TopModel(nn.Module):
    def __init__(self, heads, modules):
        super().__init__()
        self.heads = nn.ModuleList(heads)
        self.layers = nn.ModuleList(modules)
    
    def forward(self, x, Q1=None, Q2s=None):
        for layer in self.heads:
            x = layer(x, Q=Q1)
        
        for idx, layer in enumerate(self.layers):
            if Q2s is not None:
                Q2 = Q2s[idx]
            x = layer(x, Q1=Q1, Q2=Q2)
        
        return x

In [10]:
model = TopModel([A_, B_], [L1, L2])
q = torch.ones((4, 4)) / 4
Q1 = nn.Parameter(q, requires_grad=True)

Q2s = torch.ones((2, 2, 2)) / 2
Q2s = nn.Parameter(Q2s, requires_grad=True)

In [11]:
model

TopModel(
  (heads): ModuleList(
    (0-1): 2 x Wrapper(
      (module): ActQuantWrapper(
        Input Quantizer Bits: 16
        Output Quantizer Bits: 16
        (module): RotatedLinear(in_features=4, out_features=4, bias=False)
        (quantizer): ActQuantizer()
        (out_quantizer): ActQuantizer()
      )
    )
  )
  (layers): ModuleList(
    (0-1): 2 x OVWrapper(
      (O): ActQuantWrapper(
        Input Quantizer Bits: 16
        Output Quantizer Bits: 16
        (module): RotatedOVProj(in_features=4, out_features=4, bias=False)
        (quantizer): ActQuantizer()
        (out_quantizer): ActQuantizer()
      )
      (V): ActQuantWrapper(
        Input Quantizer Bits: 16
        Output Quantizer Bits: 16
        (module): RotatedOVProj(in_features=4, out_features=4, bias=False)
        (quantizer): ActQuantizer()
        (out_quantizer): ActQuantizer()
      )
    )
  )
)

In [12]:
for p in model.parameters():
    p.requires_grad_(False)

In [13]:
distribute_model(model)

{0: 50758680576, 1: 50758680576, 2: 50758680576, 3: 50758680576, 4: 50758680576, 5: 50758680576, 6: 50758680576, 7: 50758680576, 'cpu': 255111868416}
{0: 92, 1: 92, 2: 92, 3: 92, 4: 92, 5: 92, 6: 92, 7: 50758680576, 'cpu': 255111868416}

Treating module heads.
Not enough space on 0 to put heads (space available 76, module size 192).
Splitting heads.

Treating module heads.0.
Not enough space on 0 to put heads.0 (space available 76, module size 96).
Splitting heads.0.

Treating module heads.0.module.
Not enough space on 0 to put heads.0.module (space available 76, module size 96).
Splitting heads.0.module.

Treating module heads.0.module.weight.
  Found the relevant tied param groups [['heads.0.module.module.weight', 'heads.0.module.weight']]
  So those parameters need to be taken into account ['heads.0.module.module.weight']
Modules to treat ['heads.0.module.module', 'heads.0.module.quantizer', 'heads.0.module.out_quantizer', 'heads.1', 'layers']
Tied params ['heads.0.module.module.wei

In [14]:
x = torch.randn(1, 2, 4)

In [15]:
y = model(x, Q1, Q2s)

linear v torch.Size([2, 2, 4]) torch.Size([4, 4])
linear o torch.Size([4, 2, 2]) torch.Size([2, 2])
linear v torch.Size([2, 2, 4]) torch.Size([4, 4])
linear o torch.Size([4, 2, 2]) torch.Size([2, 2])


In [18]:
target = torch.ones(1, 2, 4)

In [19]:
loss = (y - target).pow(2).mean()
loss.backward()

In [20]:
optimizer = torch.optim.SGD([Q1, Q2s], 1e-3)

In [21]:
Q1.grad

tensor([[0.0094, 0.0114, 0.0078, 0.0095],
        [0.0219, 0.0235, 0.0207, 0.0220],
        [0.0153, 0.0186, 0.0126, 0.0155],
        [0.0018, 0.0018, 0.0018, 0.0018]])

In [22]:
Q2s.grad

tensor([[[0.0044, 0.0044],
         [0.0119, 0.0119]],

        [[0.0027, 0.0027],
         [0.0136, 0.0136]]])

In [5]:
a = a.to('cuda:0')
b = b.to('cuda:1')

In [6]:
q = torch.ones((4, 4)) / 4
Q = nn.Parameter(q, requires_grad=True)

In [7]:
x = torch.randn(4).to('cuda:0')

In [8]:
x1 = a(x, Q)

In [9]:
x1 = x1.to('cuda:1')

In [10]:
y = b(x1)

In [11]:
target = torch.ones(4, 4).to('cuda:1')

In [12]:
loss = (y - target).pow(2).mean()

In [13]:
loss.backward()

In [14]:
Q.grad

tensor([[-0.0046,  0.0215, -0.0045, -0.0481],
        [ 0.0032, -0.0150,  0.0032,  0.0336],
        [ 0.0069, -0.0322,  0.0068,  0.0720],
        [ 0.0020, -0.0091,  0.0019,  0.0205]])

In [14]:
Q.grad

tensor([[-0.0647, -0.0342,  0.2066,  0.0121],
        [ 0.0619,  0.0657,  0.0959,  0.0715],
        [-0.0967, -0.0812,  0.0414, -0.0577],
        [ 0.0437,  0.0336, -0.0464,  0.0182]])

In [13]:
class TopModel(nn.Module):
    def __init__(self, modules):
        super().__init__()
        self.layers = nn.ModuleList(modules)
    
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        
        return x

In [14]:
class Wrapper(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module
    
    def forward(self, x):
        x = self.module(x)
        
        return x

In [15]:
class RotatedLinear(nn.Linear):
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        device=None,
        dtype=None,
        Q=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        
        if Q is not None:
            self.register_parameter('Q', Q)
        else:
            self.Q = None
    
    def forward(self, x):
        W = self.weight
        b = self.bias
        
        #if W.device != self.Q.device:
        #    self.Q = self.Q.to(W.device)
        
        if self.Q is not None:
            W_ = torch.matmul(self.Q.to(W.device).T, W.to(dtype=self.Q.dtype)).to(dtype=W.dtype)
            
            #print('linear out')
            #print(W_.grad_fn)
            if b is not None:
                b_ = torch.matmul(self.Q.to(W.device).T, b.to(dtype=self.Q.dtype)).to(dtype=b.dtype)
            else:
                b_ = b
        else:
            W_ = W
            b_ = b
        
        return F.linear(
            x, W_, b_
        )

In [16]:
class RotatedOVProj(nn.Linear):
    
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        device=None,
        dtype=None,
        Qin=None,
        Qout=None,
        output=False,
        nheads=None):
        super().__init__(in_features, out_features, bias, device, dtype)
    
        if Qin is not None:
            self.register_buffer("Qin", Qin)
            #self.register_parameter("Qin", Qin)
        else:
            self.Qin = None
        
        if Qout is not None:
            self.register_buffer("Qout", Qout)
            #self.register_parameter("Qout", Qout)
        else:
            self.Qout = None
        
        self.output = output
        self.nheads = nheads
    
    def forward(self, x):
        W = self.weight
        
        # if W.device != self.Qin.device:
        #     self.Qin = self.Qin.to(W.device)
        
        # if W.device != self.Qout.device:
        #     self.Qout = self.Qout.to(W.device)
        
        if self.Qin is not None:
            if self.output:
                W_ = torch.matmul(W.to(dtype=self.Qin.dtype), self.Qin.to(W.device)).to(dtype=W.dtype)
            else:
                W_ = W.to(dtype=self.Qin.dtype).reshape(W.size(0), self.nheads, -1)
                W_ = torch.einsum('inh,hj->inj', W_, self.Qin.to(W.device)).reshape(W.size(0), -1).to(dtype=W.dtype)
                
                #print('linear o', W.grad_fn, W_.grad_fn, self.Qin.grad_fn)
        else:
            W_ = W
        
        if self.Qout is not None:
            if self.output:
                W_ = W_.to(dtype=self.Qout.dtype).reshape(self.nheads, -1, W.size(1))
                W_ = torch.einsum('ih,nhj->nij', self.Qout.to(W.device).T, W_).reshape(W.size(0), -1).to(dtype=W.dtype)
                
                #print('linear v', W.grad_fn, W_.grad_fn, self.Qout.grad_fn)
                #print(W_.grad_fn)
            else:
                W_ = torch.matmul(self.Qout.to(W.device).T, W_.to(dtype=self.Qout.dtype)).to(dtype=W.dtype)
        else:
            pass
        
        return F.linear(
            x, W_,
        )

In [17]:
q = torch.eye(4)
Q = nn.Parameter(q, requires_grad=True)

q1 = torch.eye(2)
Q1 = nn.Parameter(q1, requires_grad=True)
q2 = torch.eye(2)
Q2 = nn.Parameter(q2, requires_grad=True)

In [18]:
a = RotatedLinear(4, 4, False, Q=Q)
b = RotatedLinear(4, 4, False, Q=Q)
A = ActQuantWrapper(a)
B = ActQuantWrapper(b)
c = RotatedOVProj(4, 4, False, Qin=Q, Qout=Q1, nheads=2, output=True)
d = RotatedOVProj(4, 4, False, Qin=Q1, Qout=Q, nheads=2, output=False)
C = ActQuantWrapper(c)
D = ActQuantWrapper(d)

In [19]:
model = TopModel([A, B, C, D])

In [20]:
max_memory = get_balanced_memory(
    model,
    no_split_module_classes=[RotatedLinear, RotatedOVProj],
)

In [21]:
max_memory

{0: 121,
 1: 121,
 2: 121,
 3: 121,
 4: 121,
 5: 121,
 6: 121,
 7: 50482708480,
 'cpu': 258156986368}

In [22]:
device_map = infer_auto_device_map(
    model, max_memory=max_memory, no_split_module_classes=[RotatedLinear, RotatedOVProj], verbose=True
)


Treating module layers.
Not enough space on 0 to put layers (space available 121, module size 528).
Splitting layers.

Treating module layers.0.
  Found the relevant tied param groups [['layers.0.module.Q', 'layers.1.module.Q']]
  So those parameters need to be taken into account ['layers.1.module.Q']
Not enough space on 0 to put layers.0 (space available -23, module size 160).
Splitting layers.0.

Treating module layers.0.module.
  Found the relevant tied param groups [['layers.0.module.Q', 'layers.1.module.Q']]
  So those parameters need to be taken into account ['layers.1.module.Q']
Not enough space on 0 to put layers.0.module (space available -23, module size 128).
This module cannot be split, going to the next device.

Treating module layers.0.module.
  Found the relevant tied param groups [['layers.0.module.Q', 'layers.1.module.Q']]
  So those parameters need to be taken into account ['layers.1.module.Q']
Not enough space on 1 to put layers.0.module (space available 121, module 

In [23]:
for p in model.parameters():
    p.requires_grad_(False)

Q.requires_grad_(True)
Q1.requires_grad_(True)

Parameter containing:
tensor([[1., 0.],
        [0., 1.]], requires_grad=True)

In [24]:
x = torch.randn(1, 2, 4)
y = torch.randn(1, 2, 4)

In [25]:
optimizer = torch.optim.SGD([Q, Q1], 1e-3)

In [26]:
pred = model(x)
loss = (y - pred).pow(2).mean()
loss.backward()
optimizer.step()

In [27]:
Q

Parameter containing:
tensor([[ 1.0001e+00,  3.7043e-05,  6.6987e-05,  9.8963e-05],
        [ 1.1188e-04,  1.0001e+00, -6.4361e-05, -9.0941e-05],
        [ 8.1210e-06, -7.2100e-05,  9.9999e-01,  1.3119e-05],
        [-3.2840e-05, -1.2416e-04,  2.6047e-05,  1.0001e+00]],
       requires_grad=True)

In [28]:
Q1

Parameter containing:
tensor([[ 9.9994e-01, -4.9638e-05],
        [-4.9638e-05,  1.0002e+00]], requires_grad=True)

In [29]:
optimizer.zero_grad()
pred = model(x)
loss = (y - pred).pow(2).mean()
loss.backward()
optimizer.step()

In [30]:
Q

Parameter containing:
tensor([[ 1.0003e+00,  7.4198e-05,  1.3396e-04,  1.9791e-04],
        [ 2.2387e-04,  1.0002e+00, -1.2878e-04, -1.8199e-04],
        [ 1.6252e-05, -1.4422e-04,  9.9997e-01,  2.6272e-05],
        [-6.5733e-05, -2.4846e-04,  5.2157e-05,  1.0001e+00]],
       requires_grad=True)

In [31]:
Q1

Parameter containing:
tensor([[ 9.9987e-01, -9.9302e-05],
        [-9.9302e-05,  1.0004e+00]], requires_grad=True)