# SVD with a padded row demo

In [8]:
import torch
import torch.nn.functional as F
dp_size = 2
pad_dim = 0 # add a new row of zeros
rank = 2

a = torch.randn(3, 3) # working param grads
a_copy = a.clone()
working_shape = a.shape

def shard(a, rank):
    a_local = a.chunk(dp_size)[rank]
    # Galore padding for SVD
    # make it reshapable to a matrix potentially with new all-zero rows
    padding = working_shape[pad_dim] - a_local.numel() % working_shape[pad_dim]
    a_local = F.pad(a_local, [0, padding]) 
    return a_local, padding

# ZeRO padding 
a = a.view(-1)
zero_padding = a.numel() % dp_size
a_zero = F.pad(a, [0,  zero_padding])
print("padded master param shape:", a_zero.shape)

# Sharding
a_rank0, padding = shard(a_zero, 0)
a_rank1, _ = shard(a_zero, 1)
print("rank 0 padded shape:", a_rank0.shape)

# Gather
# Galore unpad 
if padding > 0:
    a_rank0 = a_rank0[:-padding]
    a_rank1 = a_rank1[:-padding]
    
a = torch.cat([a_rank0, a_rank1])
# ZeRO unpad
if zero_padding > 0:
    a = a[:-zero_padding]
# Must "re-pad" to add a zero row/col
a = a.reshape((working_shape[pad_dim], -1) if pad_dim == 0 else (-1, working_shape[pad_dim])) 
U, s, _ = torch.linalg.svd(a)

# Galore projector
galore_m = U[:rank].T
print(f"galore projector: {galore_m}")
_U, _s, _ = torch.linalg.svd(a_copy)
print(f"projector w/o padding:{_U[:rank]}")

# NOTE: We can't correctly do this...must pad whole-rows or columns 
a_rank0, padding = shard(a_zero, 0)
a_rank0 = a_rank0.reshape((working_shape[pad_dim], -1) if pad_dim == 0 else (-1, working_shape[pad_dim]))
galore_m.T @ a_rank0, galore_m.T @ a

padded master param shape: torch.Size([10])
rank 0 padded shape: torch.Size([6])
galore projector: tensor([[-0.2985, -0.0810],
        [ 0.8539, -0.4677],
        [-0.4263, -0.8802]])
projector w/o padding:tensor([[-0.2985,  0.8539, -0.4263],
        [-0.0810, -0.4677, -0.8802]])


(tensor([[-0.9201,  0.2474],
         [-0.4949,  0.1362]]),
 tensor([[-1.3687,  0.5449,  0.3914],
         [-2.0443, -0.9566,  1.4912]]))

In [9]:
a_rank0, a_rank1, a

(tensor([[ 1.0492, -1.1110],
         [-0.3780, -0.0987],
         [ 0.6666,  0.0000]]),
 tensor([-0.3969,  2.2785,  0.8350, -1.4485,  0.0000]),
 tensor([[ 1.0492, -1.1110, -0.3780],
         [-0.0987,  0.6666, -0.3969],
         [ 2.2785,  0.8350, -1.4485]]))

## This is because the left singular matrix is the eigenvectors of XX^T, which when X is padded with a zero row remain the same.
**Reference: https://en.wikipedia.org/wiki/Singular_value_decomposition**

In [10]:
a = torch.cat([torch.ones(3, 3), torch.zeros(3, 1)], dim=-1)
U, s, v = torch.linalg.svd(a @ a.T)
_, eigvec = torch.linalg.eig(a @ a.T)
_U, *_ = torch.linalg.svd(torch.ones(3,3))
U, _U, eigvec

(tensor([[-5.7735e-01,  8.1650e-01,  9.4017e-08],
         [-5.7735e-01, -4.0825e-01, -7.0711e-01],
         [-5.7735e-01, -4.0825e-01,  7.0711e-01]]),
 tensor([[-0.5774,  0.8165,  0.0000],
         [-0.5774, -0.4082, -0.7071],
         [-0.5774, -0.4082,  0.7071]]),
 tensor([[-5.7735e-01+0.j, -4.5907e-04+0.j, -4.5923e-04+0.j],
         [-5.7735e-01+0.j,  7.0734e-01+0.j, -7.0688e-01+0.j],
         [-5.7735e-01+0.j, -7.0688e-01+0.j,  7.0734e-01+0.j]]),
 tensor([[ 3.0000e+00,  4.2426e+00, -3.0526e-07],
         [ 3.0000e+00,  4.2426e+00,  1.5263e-07],
         [ 3.0000e+00,  4.2426e+00,  1.5263e-07]]))

In [7]:
import torch
import gc
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
            pretrained_model_name_or_path='roberta-large',
        ).to("cuda")

# del model
# gc.collect()
# torch.cuda.empty_cache()
# print(torch.cuda.memory_summary())
for obj in gc.get_objects():
    try:
        if isinstance(obj, torch.Tensor) or (hasattr(obj, 'data') and isinstance(obj.data, torch.Tensor)):
            print(type(obj), obj.shape())
    except:
        pass

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
