In [1]:
import numpy as np
from scipy.fft import fft, dct, idct
import matplotlib.pyplot as plt
import torch
from torch import nn

In [2]:
# Seed every RNG the notebook uses so Restart & Run All reproduces the outputs
# (torch's RNG drives the random weight init of nn.Linear used further down).
SEED = 0
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)

# 8x8 block of standard-normal samples used as the test signal throughout.
x = rng.standard_normal((8, 8))
x

array([[ 1.73215191,  0.54294611,  1.56590685, -0.91695743, -1.00120418,
        -0.80280446, -0.97943249,  0.17296584],
       [ 0.43183087,  1.48997897, -0.49781624,  1.34170328, -0.91755683,
         0.05664579,  0.16711643,  0.71790835],
       [-0.46605828, -0.84193941,  2.27934631,  2.14197876,  0.29502806,
         0.62191721,  0.29632082, -0.85321215],
       [-0.85501491,  0.72052904,  0.19869512,  0.31871158, -1.90068156,
         0.7199079 ,  0.51315394, -1.54138014],
       [ 0.53158408,  0.89551265,  0.42123485, -0.14232498, -0.34953415,
         1.07922836,  0.61428987,  0.18008172],
       [ 0.1202554 , -1.726     ,  1.84161538,  1.36929263,  1.10586899,
         0.64084555, -0.59963633, -0.18126166],
       [ 0.17428156,  1.46111922, -0.14085399, -0.18912921,  0.25636288,
         2.2096622 , -0.75866467,  1.01377409],
       [-0.78908171,  0.4447428 ,  0.07465252,  1.49430559, -0.27393362,
        -0.71875   , -0.72012747,  2.24360725]])

In [3]:
def dct2(x, norm=None, axis=0):
    """Type-II discrete cosine transform along one axis (matches ``scipy.fft.dct(x, type=2)``).

    For an axis of length N the unnormalised transform is

        y[k] = 2 * sum_{i=0}^{N-1} x[i] * cos(pi * k * (2*i + 1) / (2*N)),

    and ``norm='ortho'`` additionally scales y[0] by sqrt(1/(4N)) and every
    other y[k] by sqrt(1/(2N)), which makes the transform orthonormal.

    Parameters
    ----------
    x : array_like
        Input data. The transform runs along ``axis``; every other axis is
        treated as an independent signal (e.g. each column of a 2-D array
        when ``axis=0``).
    norm : {None, 'backward', 'ortho'}, optional
        Normalisation mode. ``None`` and ``'backward'`` both mean no scaling.
    axis : int, optional
        Axis to transform. Defaults to 0, the axis the original loop
        implementation (which indexed ``x[i]``) transformed.

    Returns
    -------
    numpy.ndarray
        DCT coefficients, same shape as ``x``.

    Raises
    ------
    ValueError
        If ``norm`` is not a supported mode (unknown modes used to be
        silently treated as unnormalised).
    """
    if norm not in (None, "backward", "ortho"):
        raise ValueError(
            f"unsupported norm {norm!r}; expected None, 'backward' or 'ortho'"
        )

    # Bring the transform axis to the front so the basis contracts over axis 0.
    data = np.moveaxis(np.asarray(x), axis, 0)
    N = data.shape[0]
    n = np.arange(N)

    # basis[k, i] = cos(pi*k*(2i+1)/(2N)); one tensor contraction replaces the
    # O(N^2) Python double loop.
    basis = np.cos(np.pi * np.outer(n, 2 * n + 1) / (2 * N))
    coeffs = 2 * np.tensordot(basis, data, axes=([1], [0]))

    if norm == "ortho" and N > 0:
        coeffs *= np.sqrt(1 / (2 * N))
        # k = 0 gets an extra 1/sqrt(2): sqrt(1/(4N)) == sqrt(1/(2N)) / sqrt(2).
        coeffs[0] /= np.sqrt(2)

    return np.moveaxis(coeffs, 0, axis)

In [4]:
# Torch copy of `x` for the layer experiments below. `as_tensor` keeps the
# float64 dtype and `.clone()` detaches it from x's memory.
x_t = torch.as_tensor(x).clone()
x_t

tensor([[ 1.7322,  0.5429,  1.5659, -0.9170, -1.0012, -0.8028, -0.9794,  0.1730],
        [ 0.4318,  1.4900, -0.4978,  1.3417, -0.9176,  0.0566,  0.1671,  0.7179],
        [-0.4661, -0.8419,  2.2793,  2.1420,  0.2950,  0.6219,  0.2963, -0.8532],
        [-0.8550,  0.7205,  0.1987,  0.3187, -1.9007,  0.7199,  0.5132, -1.5414],
        [ 0.5316,  0.8955,  0.4212, -0.1423, -0.3495,  1.0792,  0.6143,  0.1801],
        [ 0.1203, -1.7260,  1.8416,  1.3693,  1.1059,  0.6408, -0.5996, -0.1813],
        [ 0.1743,  1.4611, -0.1409, -0.1891,  0.2564,  2.2097, -0.7587,  1.0138],
        [-0.7891,  0.4447,  0.0747,  1.4943, -0.2739, -0.7187, -0.7201,  2.2436]],
       dtype=torch.float64)

In [10]:
x_t.size()[:2]  # leading two dimensions as a torch.Size

torch.Size([8, 8])

In [12]:
# DCT-II basis as a matrix. scipy's dct works along the last axis, so row i of
# dct(I) is dct(e_i). By linearity, `v @ dct_mat == dct(v)` for any length-N
# vector v, and `x @ dct_mat == dct(x, axis=-1)`.
# Size it from the data instead of a hard-coded 8 so the two stay in sync.
n_features = x.shape[-1]
dct_mat = dct(np.eye(n_features))

dct_mat_t = torch.from_numpy(dct_mat)  # shares memory with dct_mat

In [15]:
# Fully connected layer from x's row length to the DCT size. Its weights come
# from nn.Linear's random init; nothing here sets them to dct_mat.
# NOTE(review): the TypeError recorded for this cell cannot come from this
# two-argument call. It looks stale; re-run to refresh the output.
w = nn.Linear(in_features=x.shape[-1], out_features=dct_mat.shape[-1])
w

TypeError: __init__() missing 1 required positional argument: 'out_features'

In [17]:
x_t.to(torch.float32)  # nn.Linear parameters default to float32; x_t is float64

tensor([[ 1.4712, -0.6680,  1.4762, -0.3055,  0.7022,  0.2845, -0.0600, -0.2740],
        [ 1.7415, -0.2349, -1.3155, -0.9800, -1.4243,  0.0251,  0.8457,  0.0347],
        [-0.8676,  0.3023,  1.4543,  0.0264, -0.8784, -1.2394, -0.1875, -0.0977],
        [-1.8637,  0.6358,  1.3720,  0.8934, -0.8793,  0.4587,  2.0996, -0.5554],
        [ 0.6220,  1.3647,  1.5779, -0.4486, -1.6523,  0.1901, -0.5762,  0.1141],
        [ 0.1218,  0.8237,  0.5210,  1.1253, -0.6004,  0.1044,  0.4689, -0.5826],
        [-1.2013,  0.0945, -1.3161,  1.0133, -0.1885,  0.0860, -0.3563, -0.8378],
        [-0.2395,  0.6671,  0.3260,  0.5894,  0.8885,  0.7625, -0.0207,  1.5142]])

In [16]:
w(x_t.to(torch.float32))  # cast to the layer's float32 dtype before the forward pass

tensor([[-0.9320,  0.8032,  0.0978, -1.0619,  0.5850,  0.0866, -0.5464,  0.8784],
        [-0.2679, -0.3967,  0.1066, -1.9992, -0.0438, -0.0728, -0.3301, -0.2624],
        [-0.0883,  0.4035, -0.1068,  0.1774, -0.1371,  0.5819, -0.0672,  0.3043],
        [ 0.3696, -1.3498,  0.2293, -0.0891, -0.2107,  1.3288,  0.0360, -0.2566],
        [-1.0969, -0.1031, -1.0905, -0.8541,  0.6813,  0.2211, -0.8763,  0.3768],
        [-0.1404, -0.8817, -0.1616, -0.0644,  0.0186,  0.3967,  0.4071,  0.2387],
        [ 0.7883, -0.9928,  0.5209,  0.5251, -0.6042,  0.1857,  1.1361, -0.6475],
        [-0.2791, -0.1745, -0.2613,  0.0621, -0.2188, -0.7187,  0.0553, -0.0798]],
       grad_fn=<AddmmBackward>)

In [14]:
# x_t @ dct_mat_t applies the DCT along the last axis (row i of dct_mat_t is dct(e_i)).
nn.Parameter(x_t @ dct_mat_t).shape

torch.Size([8, 8])

In [12]:
dct(x, type=2, axis=-1)  # scipy's defaults spelled out: type-II along the last axis, no normalisation

array([[ 5.25297253,  3.34320115, -0.42555331,  1.92134655,  0.79365402,
         1.92096491,  5.2112503 ,  5.31501372],
       [-2.61543931,  0.23481018,  9.17980728,  5.39598995,  0.07292985,
         4.23160028, -0.3133831 , -1.23432864],
       [-2.97499493,  2.65037318, -0.2858436 , -7.76052158, -3.03637589,
         0.73940275,  0.09809528,  1.85990469],
       [ 4.32203494, -3.29397679, -3.80359423, -5.36573971, -9.85867482,
         4.72206198, -3.53415001, -0.84242621],
       [ 2.38338335,  6.23563183,  4.49240541, -3.97266837, -5.54568973,
        -0.69964077,  3.98124922, -2.01173588],
       [ 3.96406831,  3.1078359 , -1.31069221, -1.70185721, -2.62177314,
         3.11895831, -1.98704827, -2.81187932],
       [-5.41240231, -1.05216257, -4.55057255,  0.63457023,  0.39265741,
         0.16332852, -3.98109797, -5.33190049],
       [ 8.97497683, -2.89809481, -0.71397381, -1.99627371,  1.43935471,
        -3.96548557,  0.6615237 , -1.5876508 ]])

In [None]:
# dct2 transforms along axis 0, so row k of dct2(I) is the unnormalised k-th
# DCT-II basis vector 2*cos(pi*k*(2j+1)/(2N)). Show a small corner instead of
# dumping all 256x256 values.
dct_basis_256 = dct2(np.eye(256, 256))
dct_basis_256[:4, :8]

In [None]:
# Plot the first few non-constant DCT-II basis functions: rows k = 1, 2, 3 of
# dct2(I), each equal to 2*cos(pi*k*(2j+1)/(2N)).
# Compute the 256x256 transform once instead of once per curve.
dct_basis = dct2(np.eye(256, 256))

fig, ax = plt.subplots()
for k in (1, 2, 3):
    ax.plot(dct_basis[k], label=f"k = {k}")
ax.set(title="DCT-II basis functions (N = 256)", xlabel="sample index j", ylabel="amplitude")
ax.legend();

In [None]:
# Validate the hand-written dct2 against scipy along axis 0, both unnormalised
# and orthonormal. Assert instead of eyeballing an array of residuals.
for norm_mode in (None, "ortho"):
    np.testing.assert_allclose(
        dct2(x, norm=norm_mode), dct(x, type=2, axis=0, norm=norm_mode), atol=1e-10
    )

# Largest absolute deviation for the orthonormal case (should be ~1e-15).
np.abs(dct2(x, norm="ortho") - dct(x, type=2, axis=0, norm="ortho")).max()