In [1]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

In [2]:
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Resize((32, 32)),
     transforms.Normalize((0.5,), (0.5,))])

batch_size = 4

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [3]:
class DCT_layer(nn.Module):
    def __init__(self, out_features: int):
        super(DCT_layer, self).__init__()
        
        self.out_features = out_features
        
        default_dtype = torch.get_default_dtype()
        self.register_parameter('y', torch.nn.Parameter(torch.zeros(self.out_features, dtype=default_dtype)))
        
        self.y.register_hook(lambda grad: grad )
    
    def dct2(self, x, norm=None):
        N = len(x)
        n = np.arange(N)

        coeff_vec = []

        for k in n:
            yk = 0
            for i in n:
                yk += x[i] * np.cos(np.pi* k*(2*i + 1)/(2*N))
            if norm=='ortho':
                if k == 0:
                    yk = np.sqrt(1/(4*N)) * yk
                else:
                    yk = np.sqrt(1/(2*N)) * yk
            yk = yk*2
            coeff_vec.append(yk)
        
#         print(torch.stack(coeff_vec))
        return np.array(coeff_vec)

    def dct_mat(self,x):  ## generate dct matrix
        dct_mat = self.dct2(np.eye(x.shape[1],x.shape[1]))
        return torch.from_numpy(dct_mat).float()
        
    def forward(self,x):
        y = torch.matmul(x,self.dct_mat(x)) ## multiply with DCT matrix
        return y

In [4]:
# device = "cuda" if torch.cuda.is_available() else "cpu"

In [5]:
class DCT_LeNet(nn.Module):

    def __init__(self):
        super(DCT_LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.conv3 = nn.Conv2d(16, 120, 5)  
        
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 10)
        self.dct = DCT_layer(10)


    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)),2)
        x = F.max_pool2d(F.relu(self.conv2(x)),2)
        x = F.relu(self.conv3(x))
        
        x = torch.flatten(x, 1) # flatten all dimensions except the batch dimension
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = self.dct(x) 
        return x


dct_net = DCT_LeNet()
print(dct_net)


DCT_LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (conv3): Conv2d(16, 120, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=120, out_features=84, bias=True)
  (fc2): Linear(in_features=84, out_features=10, bias=True)
  (dct): DCT_layer()
)


In [6]:
class LeNet(nn.Module):

    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.conv3 = nn.Conv2d(16, 120, 5)  
        
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)),2)
        x = F.max_pool2d(F.relu(self.conv2(x)),2)
        x = F.relu(self.conv3(x))
        
        x = torch.flatten(x, 1) # flatten all dimensions except the batch dimension
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


lenet = LeNet()
print(lenet)

LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (conv3): Conv2d(16, 120, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=120, out_features=84, bias=True)
  (fc2): Linear(in_features=84, out_features=10, bias=True)
)


In [7]:
def train(dataloader,model,criterion,optimizer):

    train_loss = 0.0
    for X, y in dataloader:
        inputs, labels = X, y
        
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # print statistics
        train_loss += loss.item()*inputs.size(0)
    train_loss = train_loss/len(dataloader)
    
    print(f'Training Loss: {train_loss:.8f}')

In [8]:
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X, y
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [9]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(dct_net.parameters(), lr=0.001, momentum=0.9)

print("DCT_LeNet")
epochs = 3
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(trainloader, dct_net, criterion, optimizer)
    test(testloader, dct_net, criterion)


DCT_LeNet
Epoch 1
-------------------------------


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Training Loss: 0.63259602
Test Error: 
 Accuracy: 98.2%, Avg loss: 0.053861 

Epoch 2
-------------------------------
Training Loss: 0.22159033
Test Error: 
 Accuracy: 98.5%, Avg loss: 0.046012 

Epoch 3
-------------------------------
Training Loss: 0.16187211
Test Error: 
 Accuracy: 98.1%, Avg loss: 0.060557 



In [10]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(lenet.parameters(), lr=0.001, momentum=0.9)

print("LeNet")
epochs = 3
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(trainloader, lenet, criterion, optimizer)
    test(testloader, lenet, criterion)


LeNet
Epoch 1
-------------------------------
Training Loss: 0.96894082
Test Error: 
 Accuracy: 98.3%, Avg loss: 0.049676 

Epoch 2
-------------------------------
Training Loss: 0.23556465
Test Error: 
 Accuracy: 98.2%, Avg loss: 0.057111 

Epoch 3
-------------------------------
Training Loss: 0.16170212
Test Error: 
 Accuracy: 98.7%, Avg loss: 0.037079 

