In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

from torch.utils.data import DataLoader

In [2]:
# 1 Preprocessing: convert each image to a tensor, then normalize its single
#   channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# 2 Training split of MNIST (stored under ./data), served in shuffled
#   mini-batches of 64 samples.
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = DataLoader(dataset=trainset, batch_size=64, shuffle=True)

# 3 Test split of MNIST with the same preprocessing and batching.
testset = torchvision.datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)
testloader = DataLoader(dataset=testset, batch_size=64, shuffle=True)

In [3]:
# Select the compute device name: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [4]:
class NeuralNetwork(nn.Module):
    """Fully connected classifier mapping 784 input features to 10 class scores.

    Layer widths are 784 -> 512 -> 256 -> 128 -> 10, with a ReLU after every
    layer except the last one.
    """

    def __init__(self):
        super().__init__()

        self.fc1 = nn.Linear(28 * 28, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 10)

    def forward(self, x):
        """Compute raw class scores (logits) for a batch.

        Args:
            x: Tensor whose first dimension is the batch dimension and whose
                remaining dimensions multiply to 784 (e.g. ``(B, 1, 28, 28)``).

        Returns:
            Tensor of shape ``(B, 10)``. No softmax is applied here.
        """
        # flatten() has reshape semantics, so unlike view() it also accepts
        # non-contiguous inputs (e.g. transposed batches) and empty batches.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return self.fc4(x)

In [5]:
# Number of passes the training loop makes over the training loader.
epoch_size = 5

# Model and loss are moved to the selected device; Adam updates every model
# parameter with a learning rate of 1e-3.
net = NeuralNetwork().to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(params=net.parameters(), lr=1e-3)

In [6]:
# Switch the network to training mode before optimizing.
net.train()

for epoch in range(epoch_size):
    # Running totals used for this epoch's reported statistics.
    epoch_loss = 0.0
    correct_predictions = 0
    total_predictions = 0

    for X, y in trainloader:
        X, y = X.to(device), y.to(device)

        # Clear gradients first so nothing left over from an earlier backward
        # pass (e.g. a previous cell run) leaks into this update.
        optimizer.zero_grad()

        preds = net(X)
        loss = criterion(preds, y)
        loss.backward()
        optimizer.step()

        # Accumulate a plain Python float. Adding the loss tensor itself would
        # keep the autograd history of every batch alive for the whole epoch.
        epoch_loss += loss.item()

        # Accuracy bookkeeping: the predicted class is the index of the
        # largest score along the class dimension.
        predicted = preds.argmax(dim=1)
        correct_predictions += (predicted == y).sum().item()
        total_predictions += y.size(0)

    accuracy = correct_predictions / total_predictions

    # Mean of the per-batch losses over the epoch.
    average_loss = epoch_loss / len(trainloader)

    print(f"Epoch {epoch+1} Loss: {average_loss:.4f} Accuracy: {accuracy:.4f}")
    


Epoch 1 Loss: 0.3217 Accuracy: 0.8987
Epoch 2 Loss: 0.1440 Accuracy: 0.9557
Epoch 3 Loss: 0.1093 Accuracy: 0.9660
Epoch 4 Loss: 0.0886 Accuracy: 0.9720
Epoch 5 Loss: 0.0749 Accuracy: 0.9762


In [12]:
# Factorize the chosen layer (fc3) with a singular value decomposition.

print(net.fc3.weight.shape,end ="\n" + "-------"*10 + "\n")

# torch.svd is deprecated; torch.linalg.svd is its replacement.
# - detach(): the weight is a trainable Parameter, and the factorization should
#   not be recorded by autograd.
# - full_matrices=False: reduced SVD, matching the shapes torch.svd produced
#   by default.
U, S, Vh = torch.linalg.svd(net.fc3.weight.detach(), full_matrices=False)

# torch.linalg.svd returns the transposed right factor. Transpose it back so
# that, as before, weight == U @ diag(S) @ V.T with singular vectors in the
# columns of V.
V = Vh.T

print(U.shape)
print(S.shape)
print(V.shape)

torch.Size([128, 256])
----------------------------------------------------------------------
torch.Size([128, 128])
torch.Size([128])
torch.Size([256, 128])


In [13]:
# Number of singular components to keep.
rank = 64

# Keep the first `rank` columns of U and V, and place the first `rank`
# singular values on the diagonal of a square matrix.
U_low_rank = U[:, :rank]
V_low_rank = V[:, :rank]
S_low_rank = S[:rank].diag()

In [14]:
# Show the shapes of the truncated factors, in the order U, S, V.
# With rank = 64 these are [128, 64], [64, 64] and [256, 64].
for factor in (U_low_rank, S_low_rank, V_low_rank):
    print(factor.shape)

torch.Size([128, 64])
torch.Size([64, 64])
torch.Size([256, 64])
