# Test `torch.linalg.eig` vs `np.linalg.eig`

Test the stability of `torch.linalg.eig` and compare it to `np.linalg.eig`.
I observed that torch `eig` often returns complex values, whereas numpy `eig` almost never does.

**Conclusion**: I will use `np.linalg.eig`; see the results below.

In [1]:
import alexnet.data as data
import alexnet.transforms as transforms
import torch
import numpy as np
import matplotlib.pyplot as plt
import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def torch_pca(tensor: torch.Tensor) -> torch.Tensor:
    """Apply a random PCA-based colour shift using ``torch.linalg.eig``.

    The channel-by-channel correlation matrix of the image is
    eigendecomposed. A random combination of the eigenvectors, weighted by
    the eigenvalues and by Gaussian noise scaled by 0.1, gives one offset per
    channel, which is added to every pixel of that channel.

    Args:
        tensor: Image tensor of shape ``C x H x W`` with values in ``[0, 1]``.

    Returns:
        A tensor with the same shape as ``tensor``, clamped to ``[0, 1]``.

    Raises:
        AssertionError: If ``tensor`` is not a ``torch.Tensor``, contains
            values outside ``[0, 1]``, or the eigendecomposition has a
            non-zero imaginary part.
    """
    # NOTE: the checks are deliberately kept as ``assert`` statements; the
    # scan loops in this notebook detect failures by catching AssertionError.
    # The type check comes first so that non-tensor inputs fail with the
    # intended message instead of an unrelated error from the comparisons.
    assert isinstance(tensor, torch.Tensor), "PCAAugment applies only to tensors"
    assert (tensor >= 0).all() and (tensor <= 1).all()
    nchannels = tensor.shape[0]
    # C x (H*W); ``reshape`` (unlike ``view``) also accepts non-contiguous
    # inputs such as spatial crops.
    pixels = tensor.reshape(nchannels, -1)
    # subtracting the mean is the first step of PCA
    pixels = pixels - torch.mean(pixels, dim=1, keepdim=True)
    # shape: C x C
    corr = torch.corrcoef(pixels)
    # C          C x C
    eigenvalues, eigenvectors = torch.linalg.eig(corr)
    # ``torch.linalg.eig`` always returns complex tensors; require that the
    # imaginary parts are zero before dropping them.
    assert torch.isreal(eigenvalues).all() and torch.isreal(eigenvectors).all()
    eigenvalues, eigenvectors = torch.real(eigenvalues), torch.real(eigenvectors)
    # C -- one random weight per principal component (was hard-coded to 3)
    alpha = 0.1 * torch.randn(nchannels)
    # C
    delta: torch.Tensor = eigenvectors @ (alpha * eigenvalues)
    # broadcast the per-channel offset over H and W
    return torch.clamp(tensor + delta[:, None, None], 0, 1)


def numpy_pca(tensor: torch.Tensor) -> torch.Tensor:
    """Apply a random PCA-based colour shift using ``np.linalg.eig``.

    Same computation as the torch-based variant, except that the correlation
    matrix is converted to a numpy array for the eigendecomposition and the
    results are converted back to tensors afterwards.

    Args:
        tensor: Image tensor of shape ``C x H x W`` with values in ``[0, 1]``.

    Returns:
        A tensor with the same shape as ``tensor``, clamped to ``[0, 1]``.

    Raises:
        AssertionError: If ``tensor`` is not a ``torch.Tensor``, contains
            values outside ``[0, 1]``, or the eigendecomposition has a
            non-zero imaginary part.
    """
    # NOTE: the checks are deliberately kept as ``assert`` statements; the
    # scan loops in this notebook detect failures by catching AssertionError.
    # The type check comes first so that non-tensor inputs fail with the
    # intended message instead of an unrelated error from the comparisons.
    assert isinstance(tensor, torch.Tensor), "PCAAugment applies only to tensors"
    assert (tensor >= 0).all() and (tensor <= 1).all()
    nchannels = tensor.shape[0]
    # C x (H*W); ``reshape`` (unlike ``view``) also accepts non-contiguous
    # inputs such as spatial crops.
    pixels = tensor.reshape(nchannels, -1)
    # subtracting the mean is the first step of PCA
    pixels = pixels - torch.mean(pixels, dim=1, keepdim=True)
    # shape: C x C
    corr = torch.corrcoef(pixels).numpy()
    # C          C x C
    eigenvalues, eigenvectors = np.linalg.eig(corr)
    # require zero imaginary parts before converting back to real tensors
    assert np.isreal(eigenvalues).all() and np.isreal(eigenvectors).all()
    eigenvalues = torch.from_numpy(np.real(eigenvalues))
    eigenvectors = torch.from_numpy(np.real(eigenvectors))
    # C -- one random weight per principal component (was hard-coded to 3)
    alpha = 0.1 * torch.randn(nchannels)
    # C
    delta: torch.Tensor = eigenvectors @ (alpha * eigenvalues)
    # broadcast the per-channel offset over H and W
    return torch.clamp(tensor + delta[:, None, None], 0, 1)

In [3]:
# Build the "train" dataset with the torch-based PCA transform applied after
# ToTensor, then touch every item and remember which indices trip an assert.
torch_pipeline = transforms.Compose([transforms.ToTensor(), torch_pca])
dataset = data.ImageNet("../data", "train", torch_pipeline)

bad_tensor_pca = []

# Items are fetched by index so that one failing sample does not stop the scan.
progress = tqdm.trange(len(dataset), ncols=120)
for i in progress:
    try:
        dataset[i]
    except AssertionError:
        bad_tensor_pca.append(i)

100%|██████████████████████████████████████████████████████████████████████| 1281167/1281167 [2:28:49<00:00, 143.47it/s]


In [4]:
# Number of failing indices, followed by the indices themselves.
print(f"{len(bad_tensor_pca)} {bad_tensor_pca}")

2539 [2483, 2652, 3139, 3506, 3842, 4048, 4059, 4855, 4894, 5358, 5535, 5818, 8375, 8665, 9649, 10416, 12758, 13037, 13162, 13333, 13699, 14054, 14280, 14983, 17084, 17265, 17344, 17494, 17855, 17912, 18289, 18682, 19837, 20323, 23488, 24091, 24484, 25106, 25137, 25641, 29743, 29772, 29912, 29936, 30285, 30300, 30352, 30487, 30539, 30724, 30928, 30953, 31643, 32340, 32398, 33628, 35555, 36206, 36356, 36407, 36411, 36466, 36900, 36968, 37010, 37021, 37030, 37046, 37059, 37248, 37412, 38507, 39074, 39737, 40550, 40774, 40920, 40945, 41029, 41549, 41560, 42355, 42437, 42565, 42699, 43409, 43503, 43548, 44450, 45492, 45526, 46073, 49494, 49534, 49796, 51662, 53109, 53654, 54393, 54925, 55104, 55105, 57754, 57849, 58602, 58918, 59950, 60154, 60772, 60781, 60921, 61094, 63349, 63683, 64271, 64387, 64473, 64551, 64586, 64676, 64741, 64886, 64946, 65041, 65064, 65478, 65665, 65733, 67184, 67188, 67274, 67492, 67806, 67896, 68059, 68539, 70258, 73301, 73312, 73313, 73346, 73351, 73371, 73390, 7

In [5]:
# Build the "train" dataset with the numpy-based PCA transform applied after
# ToTensor, then touch every item and remember which indices trip an assert.
numpy_pipeline = transforms.Compose([transforms.ToTensor(), numpy_pca])
dataset = data.ImageNet("../data", "train", numpy_pipeline)

bad_numpy_pca = []

# Items are fetched by index so that one failing sample does not stop the scan.
progress = tqdm.trange(len(dataset), ncols=120)
for i in progress:
    try:
        dataset[i]
    except AssertionError:
        bad_numpy_pca.append(i)

100%|██████████████████████████████████████████████████████████████████████| 1281167/1281167 [2:25:54<00:00, 146.35it/s]


In [6]:
# Number of failing indices, followed by the indices themselves.
print(f"{len(bad_numpy_pca)} {bad_numpy_pca}")

0 []
