In [1]:
import torch
import math
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# torch set seed
torch.manual_seed(33)

<torch._C.Generator at 0x7f9becfdd0b0>

# Dimensions

In [3]:
n = 32
d0_in = 1000
d0_out = 512
d0_out_new = 768

d1_in = d0_out
d1_in_new = d0_out_new
d1_out = 2048
d1_out_new = 3072

d2_in = d1_out
d2_in_new = d1_out_new
d2_out = 1000

In [4]:
# target weight shape
print(f"W_0 - src: 2x[{d0_out}, {d0_in}] -> tgt: [{d0_out_new}, 2x{d0_in}]")
print(f"W_1 - src: 2x[{d1_out}, {d1_in}] -> tgt: [{d1_out_new}, {d1_in_new}]")
print(f"W_2 - src: 3x[{d2_out}, {d2_in}] -> tgt: [2x{d2_out}, {d2_in_new}]")

W_0 - src: 2x[512, 1000] -> tgt: [768, 2x1000]
W_1 - src: 2x[2048, 512] -> tgt: [3072, 768]
W_2 - src: 3x[1000, 2048] -> tgt: [2x1000, 3072]


# Weights

In [5]:
# input
x1 = torch.rand(n, d0_in) / math.sqrt(d0_in)
x2 = torch.rand(n, d0_in) / math.sqrt(d0_in)

x = torch.concat((x1, x2), dim=-1)
x.shape

torch.Size([32, 2000])

In [6]:
# 0th layer weights
A0 = torch.rand(d0_out, d0_in) / math.sqrt(d0_out)
B0 = torch.rand(d0_out, d0_in) / math.sqrt(d0_out)
zeros = torch.zeros(d0_out, d0_in)
W0_diag = torch.concat((torch.concat((A0, zeros), dim=0), torch.concat((zeros, B0), dim=0)), dim=-1)

print(A0.shape, B0.shape)
print(W0_diag.shape)

torch.Size([512, 1000]) torch.Size([512, 1000])
torch.Size([1024, 2000])


In [7]:
# 1th layer weights
A1 = torch.rand(d1_out, d1_in) / math.sqrt(d1_out)
B1 = torch.rand(d1_out, d1_in) / math.sqrt(d1_out)
zeros = torch.zeros(d1_out, d1_in)
W1_diag = torch.concat((torch.concat((A1, zeros), dim=0), torch.concat((zeros, B1), dim=0)), dim=-1)

print(A1.shape, B1.shape)
print(W1_diag.shape)

torch.Size([2048, 512]) torch.Size([2048, 512])
torch.Size([4096, 1024])


In [8]:
# 2th layer weights
A2 = torch.rand(d2_out, d2_in) / math.sqrt(d2_out)
B2 = torch.rand(d2_out, d2_in) / math.sqrt(d2_out)
zeros = torch.zeros(d2_out, d2_in)
W2_diag = torch.concat((torch.concat((A2, zeros), dim=0), torch.concat((zeros, B2), dim=0)), dim=-1)

print(A2.shape, B2.shape)
print(W2_diag.shape)

torch.Size([1000, 2048]) torch.Size([1000, 2048])
torch.Size([2000, 4096])


In [9]:
diag_weights = [W0_diag, W1_diag, W2_diag]
separate_weights = [A0, B0, A1, B1, A2, B2]
inputs = [x1, x2, x]

for m in diag_weights + separate_weights + inputs:
    m.to("cuda:2")

# PCA

In [3]:
def pca_covariance(W, n_components):
    # PCA on covariance matrix
    U, S, V = torch.pca_lowrank(W @ W.T, q=n_components, center=False)
        
    # expansion/reduction matrix
    E = torch.sqrt(torch.diag(S)) @ U.T
    E_inv = U @ torch.linalg.pinv(torch.sqrt(torch.diag(S)))  # TODO: check if V -> U makes sense

    # output weight
    W_out_expand = E @ W
    
    return W_out_expand, E_inv

In [4]:
def pca_weight(W, n_components):
    # PCA on weight matrix
    U, S, V = torch.pca_lowrank(W, q=n_components, center=False)
        
    # expansion/reduction matrix
    E = U.T
    E_inv = E.T
    
    # output weight
    W_out_expand = torch.diag(S) @ V.T
    
    return W_out_expand, E_inv

## Diagonal

In [5]:
def expand_out_pca_diag(W_diag, d_out_new):
    # dimension check
    d_out_double, d_in_double = W_diag.shape
    assert d_out_double >= d_out_new > (d_out_double / 2)
    
    if min(d_out_double, d_in_double) <= d_out_new:
        print(f"PCA on covariance matrix, since min({d_out_double}, {d_in_double}) <= {d_out_new}")
        W_out_expand, E_inv = pca_covariance(W_diag, n_components=d_out_new)
    else:
        # NOTE: this can be done with covariance matrix as well
        print(f"PCA on weight matrix")
        W_out_expand, E_inv = pca_weight(W_diag, n_components=d_out_new)
        
    print(f"Mean reconstruction error: {torch.abs(E_inv @ W_out_expand - W_diag).mean()}")
    
    # output dimension sanity check
    assert W_out_expand.shape == (d_out_new, d_in_double)
    assert E_inv.shape == (d_out_double, d_out_new)
    
    return W_out_expand, E_inv  

In [13]:
# PCA on W0_diag
W0_out_expand, E0_inv = expand_out_pca_diag(W0_diag, d0_out_new)
print(W0_out_expand.shape, E0_inv.shape)

PCA on weight matrix
Mean reconstruction error: 0.0017343397485092282
torch.Size([768, 2000]) torch.Size([1024, 768])


In [14]:
# PCA on W1_diag
W1_out_expand, E1_inv = expand_out_pca_diag(W1_diag, d1_out_new)
print(W1_out_expand.shape, E1_inv.shape)

PCA on covariance matrix, since min(4096, 1024) <= 3072
Mean reconstruction error: 8.01863251354007e-08
torch.Size([3072, 1024]) torch.Size([4096, 3072])


In [15]:
# new weights
W0_new = W0_out_expand
W1_new = W1_out_expand @ E0_inv
W2_new = W2_diag @ E1_inv

In [16]:
# compare outputs
out = x @ W0_diag.T @ W1_diag.T @ W2_diag.T
out_new = x @ W0_new.T @ W1_new.T @ W2_new.T
torch.abs(out - out_new).mean()

tensor(0.0002)

## Separately

In [14]:
torch.Tensor([1,2,3,4]).data[:]

tensor([1., 2., 3., 4.])

In [7]:
def expand_out_pca_separately(A, B, d_out_new):
    # dimension check
    d_out, d_in = A.shape
    assert A.shape == B.shape
    assert d_out * 2 >= d_out_new > d_out
    
    d_out_new_half = d_out_new // 2
    
    W_out_expands = []
    E_invs = []
    for W in (A, B):
        if min(d_out, d_in) <= d_out_new_half:
            print(f"PCA on covariance matrix, since min({d_out}, {d_in}) <= {d_out_new_half}")
            W_out_expand, E_inv = pca_covariance(W, n_components=d_out_new_half)
        else:
            print(f"PCA on weight matrix")
            # NOTE: this can be done with covariance matrix as well
            W_out_expand, E_inv = pca_weight(W, n_components=d_out_new_half)
            
            print(f"Mean reconstruction error: {torch.abs(E_inv @ W_out_expand - W).mean()}")
            
        print(W_out_expand.shape, E_inv.shape)
            
        W_out_expands.append(W_out_expand)
        E_invs.append(E_inv)
        
    return W_out_expands, E_invs
    
    # output dimension sanity check
    assert W_out_expand.shape == (d_out_new, d_in_double)
    assert E_inv.shape == (d_out_double, d_out_new)
    
    return W_out_expand, E_inv
        
        

In [18]:
# PCA on A0, B0 separately
# diagonal weights -> W0_out_expand: [768, 2000], E0_inv: [1024, 768]
# equivelent to concatnating the output
W0_out_expands, E0_invs = expand_out_pca_separately(A0, B0, d0_out_new)

PCA on weight matrix
Mean reconstruction error: 0.0026181077118963003
torch.Size([384, 1000]) torch.Size([512, 384])
PCA on weight matrix
Mean reconstruction error: 0.002612665994092822
torch.Size([384, 1000]) torch.Size([512, 384])


In [19]:
# PCA on A1, B1 separately
# diagonal weights -> W1_out_expand: [3072, 1024], E0_inv: [4096, 3072]
W1_out_expands, E1_invs = expand_out_pca_separately(A1, B1, d1_out_new)

PCA on covariance matrix, since min(2048, 512) <= 1536
torch.Size([1536, 512]) torch.Size([2048, 1536])
PCA on covariance matrix, since min(2048, 512) <= 1536
torch.Size([1536, 512]) torch.Size([2048, 1536])


In [20]:
# new weights
W0_news = W0_out_expands
W1_news = [W1_out_expand @ E0_inv for W1_out_expand, E0_inv in zip(W1_out_expands, E0_invs)]
W2_news = [W2 @ E1_inv for W2, E1_inv in zip([A2, B2], E1_invs)]

In [None]:
# compare outputs
out = x @ W0_diag.T @ W1_diag.T @ W2_diag.T
out.shape

In [21]:
# compare outputs
out = x @ W0_diag.T @ W1_diag.T @ W2_diag.T
out.shape

torch.Size([32, 2000])

In [22]:
out_new_1 = x1 @ W0_news[0].T @ W1_news[0].T @ W2_news[0].T
out_new_2 = x2 @ W0_news[1].T @ W1_news[1].T @ W2_news[1].T
out_new_1.shape, out_new_2.shape

(torch.Size([32, 1000]), torch.Size([32, 1000]))

In [23]:
out_new = torch.concat([out_new_1, out_new_2], dim=-1)
out_new.shape

torch.Size([32, 2000])

In [24]:
torch.abs(out - out_new).mean()

tensor(0.0002)

# BERT

In [6]:
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")

Downloading: 100%|██████████| 570/570 [00:00<00:00, 4.64MB/s]
Downloading: 100%|██████████| 440M/440M [00:03<00:00, 114MB/s] 
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [8]:
embedding_a = model.bert.embeddings.word_embeddings.weight.detach()[:, :384]
embedding_b = model.bert.embeddings.word_embeddings.weight.detach()[:, -384:]

print(embedding_a.shape, embedding_b.shape)

torch.Size([30522, 384]) torch.Size([30522, 384])


In [10]:
# pca on concatenated embeddings
emb_out_expand, emb_inv = expand_out_pca_diag(torch.concat((embedding_a, embedding_b), dim=-1).T, 512)
emb_out_expand.shape, emb_inv.shape

PCA on weight matrix
Mean reconstruction error: 0.011709071695804596


(torch.Size([512, 30522]), torch.Size([768, 512]))

In [11]:
# pca on separate embeddings
emb_out_expands, emb_invs = expand_out_pca_separately(embedding_a.T, embedding_b.T, 512)

PCA on weight matrix
Mean reconstruction error: 0.013286596164107323
torch.Size([256, 30522]) torch.Size([384, 256])
PCA on weight matrix
Mean reconstruction error: 0.01326266210526228
torch.Size([256, 30522]) torch.Size([384, 256])
