In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import numpy as np


In [2]:
torch.set_default_dtype(torch.double)
torch.manual_seed(0)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [3]:
# Load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

100%|██████████| 9.91M/9.91M [00:06<00:00, 1.64MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 241kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.59MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 1.55MB/s]


In [10]:
# Standard model without SVD
class Standart_model(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Standart_model, self).__init__()
        self.dense1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.relu(self.dense1(x))
        x = self.dense2(x)
        return x

In [25]:
# SVD grad method for SVD_dense
class SVD_dense(nn.Module):
    def __init__(self, in_features, out_features, rank=None):
        super(SVD_dense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank if rank is not None else min(in_features, out_features)

        # Initialize learnable parameters
        self.U = nn.Parameter(torch.randn(out_features, self.rank))
        self.raw_S = nn.Parameter(torch.randn(self.rank))
        self.Vh = nn.Parameter(torch.randn(self.rank, in_features))
        
        # Initialize orthogonal matrix
        with torch.no_grad():
            Q_u, _ = torch.linalg.qr(self.U.data)
            self.U.data.copy_(Q_u)
            Q_v, _ = torch.linalg.qr(self.Vh.T.data)
            self.Vh.data.copy_(Q_v.T)
    
    def get_S(self):
        return nn.functional.softplus(self.raw_S)
    
    def forward(self, x):
        S = self.get_S()
        weight = (self.U * S.unsqueeze(0)) @ self.Vh
        return x @ weight.T
    
    def prune(self, threshold_ratio=0.1):
        S = self.get_S().detach()
        max_s = torch.max(S)
        threshold = max_s * threshold_ratio
        
        mask = S > threshold
        new_rank = torch.sum(mask).item()
        
        if new_rank == 0:
            new_rank = 1
            mask[0] = True
        
        with torch.no_grad():
            self.U.data = self.U.data[:, mask]
            self.raw_S.data = self.raw_S.data[mask]
            self.Vh.data = self.Vh.data[mask, :]
        
        self.rank = new_rank
        return new_rank
    
    def orthogonality_regularization(self):
        I_r = torch.eye(self.rank, device=self.U.device)
        U_loss = torch.norm(self.U.T @ self.U - I_r, p='fro')**2
        V_loss = torch.norm(self.Vh @ self.Vh.T - I_r, p='fro')**2
        return U_loss + V_loss

In [12]:
# SVD_model with SVD
class SVD_Model(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, rank_ratio=0.5):
        super(SVD_Model, self).__init__()
        self.svd1 = SVD_dense(input_size, hidden_size, int(hidden_size * rank_ratio))
        self.relu = nn.ReLU()
        self.svd2 = SVD_dense(hidden_size, output_size, int(output_size * rank_ratio))
    
    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.relu(self.svd1(x))
        x = self.svd2(x)
        return x
    
    def ortho_loss(self):
        return self.svd1.orthogonality_regularization() + self.svd2.orthogonality_regularization()
    
    def prune(self, threshold_ratio=0.1):
        rank1 = self.svd1.prune(threshold_ratio)
        rank2 = self.svd2.prune(threshold_ratio)
        return rank1, rank2

In [18]:
# train 
def train_model(model, train_loader, test_loader, epochs, ortho_weight=0.1, is_svd=False):
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    train_losses, test_accuracies = [], []
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch_idx, (data, target) in enumerate(train_loader):

            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            
            if is_svd:
                ortho_loss = model.ortho_loss()
                loss += ortho_weight * ortho_loss
            
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        train_losses.append(total_loss / len(train_loader))
        
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in test_loader:

                data, target = data.to(device), target.to(device)

                output = model(data)
                _, predicted = torch.max(output.data, 1)
                total += target.size(0)
                correct += (predicted == target).sum().item()
        
        accuracy = 100 * correct / total
        test_accuracies.append(accuracy)
        
        if epoch % 10 == 0:
            print(f'Epoch {epoch}: Loss = {train_losses[-1]:.4f}, Accuracy = {accuracy:.2f}%')
    
    return train_losses, test_accuracies

In [19]:
input_size = 28 * 28
hidden_size = 256
output_size = 10
epochs = 50

In [20]:
with torch.cuda.device(0):
    print("Обучаем стандартную модель")
    standard_model = Standart_model(input_size, hidden_size, output_size)
    standard_loss, standard_acc = train_model(standard_model, train_loader, test_loader, epochs)

Обучаем стандартную модель
Epoch 0: Loss = 0.3997, Accuracy = 92.86%
Epoch 10: Loss = 0.0437, Accuracy = 97.45%
Epoch 20: Loss = 0.0191, Accuracy = 97.65%
Epoch 30: Loss = 0.0125, Accuracy = 97.83%
Epoch 40: Loss = 0.0085, Accuracy = 97.99%


In [23]:
standard_model_cpu = standard_model.cpu()

print(standard_model)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = count_parameters(standard_model)
print(f"Общее количество параметров: {total_params:,}")



Standart_model(
  (dense1): Linear(in_features=784, out_features=256, bias=True)
  (relu): ReLU()
  (dense2): Linear(in_features=256, out_features=10, bias=True)
)
Общее количество параметров: 203,530


In [27]:
with torch.cuda.device(0):
    print("Обучаем SVD модель")
    svd_model = SVD_Model(input_size, hidden_size, output_size)
    svd_loss, svd_acc = train_model(svd_model, train_loader, test_loader, epochs, is_svd=True)

Обучаем SVD модель
Epoch 0: Loss = 0.8425, Accuracy = 88.09%
Epoch 10: Loss = 0.1273, Accuracy = 95.72%
Epoch 20: Loss = 0.0849, Accuracy = 96.96%
Epoch 30: Loss = 0.0765, Accuracy = 96.94%
Epoch 40: Loss = 0.0768, Accuracy = 97.26%


In [28]:
print(svd_model)

total_params_svd = count_parameters(svd_model)
print(f"Общее количество параметров: {total_params_svd:,}")

SVD_Model(
  (svd1): SVD_dense()
  (relu): ReLU()
  (svd2): SVD_dense()
)
Общее количество параметров: 134,583


In [29]:
with torch.cuda.device(0): 
    print("Применяем прунинг")
    rank1, rank2 = svd_model.prune(threshold_ratio=0.1)
    print(f"Новые ранги после прунинга: fc1={rank1}, fc2={rank2}")

    print("Дообучаем после прунинга")
    svd_pruned_loss, svd_pruned_acc = train_model(svd_model, train_loader, test_loader, epochs//2, is_svd=True)

Применяем прунинг
Новые ранги после прунинга: fc1=70, fc2=5
Дообучаем после прунинга
Epoch 0: Loss = 0.0698, Accuracy = 96.90%
Epoch 10: Loss = 0.0618, Accuracy = 96.82%
Epoch 20: Loss = 0.0661, Accuracy = 97.16%


In [31]:
print(svd_pruned_acc[-1])

97.31


In [30]:
print(svd_model)

total_params_svd_2 = count_parameters(svd_model)
print(f"Общее количество параметров: {total_params_svd_2:,}")

SVD_Model(
  (svd1): SVD_dense()
  (relu): ReLU()
  (svd2): SVD_dense()
)
Общее количество параметров: 74,205
