In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
from einops import rearrange, reduce, repeat
import torchvision
from torchvision import datasets, transforms
import torch 
from gkpd import gkpd, KroneckerConv2d
from gkpd.tensorops import kron



In [2]:
# torch.cuda.is_available()
# torch.cuda.device_count()

In [3]:
# Preprocessing shared by both splits: convert to a tensor, then normalize the single
# channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Only the training stream is shuffled, so every epoch sees a different batch order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# Evaluation gains nothing from shuffling; a fixed order keeps test passes repeatable.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

In [4]:
class KronLinear(nn.Module):
    """Linear layer whose weight is assembled from ``rank`` pairs of Kronecker factors.

    Learnable tensors:
        s:    shape ``a_shape``; multiplied element-wise into every slice of ``a``.
        a:    left factors, shape ``(rank, *a_shape)``, Xavier-uniform initialised.
        b:    right factors, shape ``(rank, *b_shape)``, Xavier-uniform initialised.
        bias: optional; its shape is the entry-wise product of ``a_shape`` and
              ``b_shape`` with the first entry dropped. ``None`` when ``bias=False``.

    The weight is built on every forward pass by ``kron`` (imported from
    ``gkpd.tensorops``) and applied as ``x @ weight``.
    # NOTE(review): presumably ``kron`` sums the per-rank Kronecker products into a
    # 2-D matrix -- confirm against the gkpd package.
    """

    def __init__(self, rank, a_shape, b_shape, bias=True) -> None:
        super().__init__()
        self.rank = rank
        # Creation order (s, a, b, bias) fixes both the registration order of the
        # parameters and the order in which random numbers are drawn.
        self.s = nn.Parameter(torch.randn(*a_shape))
        self.a = nn.Parameter(torch.randn(rank, *a_shape))
        self.b = nn.Parameter(torch.randn(rank, *b_shape))
        for factor in (self.a, self.b):
            nn.init.xavier_uniform_(factor)
        # Computed unconditionally, so mismatched shapes fail even when bias is disabled.
        bias_dims = np.multiply(a_shape, b_shape)[1:]
        self.bias = nn.Parameter(torch.randn(*bias_dims)) if bias else None

    def forward(self, x):
        """Return ``x @ kron(s * a, b)``, plus the bias when one exists."""
        scaled_a = self.s.unsqueeze(0) * self.a
        weight = kron(scaled_a, self.b)

        out = torch.matmul(x, weight)
        if self.bias is None:
            return out
        out += self.bias.unsqueeze(0)
        return out
    
# Smoke test: send a random (64, 256) batch through a rank-10 layer; the cell
# displays the shape of the result.
x = torch.randn((64, 256))
m = KronLinear(rank=10, a_shape=(16, 64), b_shape=(16, 64), bias=False)

m(x).shape



torch.Size([64, 4096])

In [5]:

        
        
class KronLeNet(nn.Module):
    """LeNet-style classifier whose three fully connected layers are ``KronLinear`` layers.

    Two conv -> LeakyReLU -> max-pool stages feed a flattened vector of
    ``16 * 4 * 4`` features into three bias-free ``KronLinear`` layers. The
    entry-wise products of their factor shapes are (256, 120), (120, 84) and (84, 10).
    """

    def __init__(self) -> None:
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.relu1 = nn.LeakyReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.relu2 = nn.LeakyReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Kronecker-factored head: (rank, left factor shape, right factor shape).
        self.kronfc1 = KronLinear(21, (16, 10), (16, 12), bias=False)
        self.kronfc2 = KronLinear(10, (10, 12), (12, 7), bias=False)
        self.kronfc3 = KronLinear(4, (12, 2), (7, 5), bias=False)
        self.relu3 = nn.LeakyReLU()
        self.relu4 = nn.LeakyReLU()

    def forward(self, x):
        """Run both conv stages, flatten, then apply the three Kronecker layers."""
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)

        flat_features = 16 * 4 * 4
        x = x.view(-1, flat_features)

        x = self.kronfc1(x)
        x = self.relu3(x)
        x = self.kronfc2(x)
        x = self.relu4(x)
        return self.kronfc3(x)
        

In [6]:
# kl = KronLinear(2, (2, 2), (2, 2), bias=False)
# optimizer = optim.Adam(kl.parameters(), lr=0.001)
# for i in kl.parameters():
#     print(i)
# x = torch.randn(2, 4)
# y = torch.randint(0, 2, (2, ))
# kl(x).shape
# loss = F.cross_entropy(kl(x), y)
# loss.backward()

# for i in kl.parameters():
#     print(i.grad.numpy())
# optimizer.step()
# for i in kl.parameters():
#     print(i)

In [7]:
class LeNet(nn.Module):
    """Dense LeNet-style classifier: two conv stages followed by three linear layers.

    Each conv stage is conv (5x5) -> LeakyReLU -> 2x2 max-pool. The flattened
    ``16 * 4 * 4`` features pass through linear layers of width 120, 84 and 10.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.relu1 = nn.LeakyReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.relu2 = nn.LeakyReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Dense classifier head.
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.relu3 = nn.LeakyReLU()
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.relu4 = nn.LeakyReLU()
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the 10 raw class scores for each input image."""
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)

        flat_features = 16 * 4 * 4
        x = x.view(-1, flat_features)

        x = self.relu3(self.fc1(x))
        x = self.relu4(self.fc2(x))
        return self.fc3(x)
# calculate the number of parameters in LeNet
def count_parameter(model):
    """Print and return the total number of scalar entries in ``model.parameters()``.

    Args:
        model: any object exposing ``parameters()`` that yields tensors
            (for example an ``nn.Module``).

    Returns:
        int: the sum of ``numel()`` over all parameters; 0 when there are none.
        The value is still printed, as before, so existing cells keep their output.
    """
    total = sum(param.numel() for param in model.parameters())
    print(total)
    # Returning the count lets callers reuse it instead of re-implementing the loop.
    return total


In [9]:
# Use the GPU when CUDA is available, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = KronLeNet()
model = model.to(device)
criterion = nn.CrossEntropyLoss()

# Plain SGD over every parameter of the model.
optimizer = optim.SGD(model.parameters(), lr=0.01)
count_parameter(model)

12544


In [10]:
# for inputs, labels in train_loader:
#     inputs, labels = inputs.to(device), labels.to(device)
#     optimizer.zero_grad()
#     outputs = model(inputs)
#     loss = criterion(outputs, labels)
#     loss.backward()
#     optimizer.step()
#     break
# loss
# for i in model.parameters():
#     print(i.shape)
# model.kronfc1.a.grad.cpu().numpy()
def calculate_sparsity(model, threshold=1e-6):
    """Measure how many parameter entries of ``model`` are (near) zero.

    Args:
        model: any object exposing ``parameters()`` that yields tensors.
        threshold: an entry counts as sparse when its absolute value is
            strictly below this number.

    Returns:
        tuple: ``(sparsity, sparse_params, total_params)`` where ``sparsity`` is
        ``sparse_params / total_params``. A model without parameters yields
        ``(0.0, 0, 0)`` instead of raising ``ZeroDivisionError``.
    """
    total_params = 0
    sparse_params = 0

    for param in model.parameters():
        total_params += param.numel()  # running count of all entries
        # Count the entries whose magnitude falls below the threshold.
        sparse_params += int((param.detach().abs() < threshold).sum().item())

    # Nothing to measure on a parameter-free model; report zero sparsity.
    sparsity = sparse_params / total_params if total_params else 0.0
    return sparsity, sparse_params, total_params



In [11]:
def train(model, train_loader, criterion, optimizer, epochs, l1_weight=0.01):
    """Train ``model`` with the task loss plus an L1 penalty on its Kronecker scale tensors.

    The penalty is ``weight * ||s||_1`` summed over ``model.kronfc1.s``,
    ``model.kronfc2.s`` and ``model.kronfc3.s``. ``weight`` starts at
    ``l1_weight`` and is multiplied by the next entry of
    ``[1, 0.1, 0.01, 0.001, 0.0001]`` every 20 epochs; past epoch 100 the last
    multiplier is kept.

    After every epoch the function prints the mean batch loss, the result of
    ``calculate_sparsity(model)``, and the number of ``s`` entries with absolute
    value below 1e-5 per layer and in total. The printed labels say "sparsity",
    but the values are counts.

    Args:
        model: module exposing ``kronfc1``, ``kronfc2`` and ``kronfc3``, each with an ``s`` tensor.
        train_loader: iterable of ``(inputs, labels)`` batches supporting ``len()``.
        criterion: loss callable taking ``(outputs, labels)``.
        optimizer: optimizer over the model's parameters.
        epochs: number of passes over ``train_loader``.
        l1_weight: base weight of the L1 penalty.
    """
    decay_weight = [1, 0.1, 0.01, 0.001, 0.0001]
    weight1 = l1_weight
    # Send batches to wherever the model lives instead of relying on a notebook global.
    device = next(model.parameters()).device
    kron_layers = (model.kronfc1, model.kronfc2, model.kronfc3)
    # Guard the mean against an empty loader.
    num_batches = max(len(train_loader), 1)

    for epoch in range(epochs):
        running_loss = 0.0
        # Clamp the schedule index so runs longer than 100 epochs do not raise IndexError.
        stage = min(epoch // 20, len(decay_weight) - 1)
        l1_weight = decay_weight[stage] * weight1
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            for layer in kron_layers:
                loss = loss + l1_weight * torch.norm(layer.s, p=1)

            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss / num_batches}")
        # Whole-model sparsity (uses calculate_sparsity's own default threshold).
        print(calculate_sparsity(model))
        # Per-layer count of scale entries that are effectively zero (|s| < 1e-5).
        fc1_sparse, fc2_sparse, fc3_sparse = (
            torch.sum(torch.abs(layer.s) < 1e-5).item() for layer in kron_layers
        )

        print(f"fc1 sparsity: {fc1_sparse}, fc2 sparsity: {fc2_sparse}, fc3 sparsity: {fc3_sparse}")
        print(f"total sparse params: {fc1_sparse + fc2_sparse + fc3_sparse}")
        

In [12]:
# Train for 100 epochs with the default L1 weight; progress is printed once per epoch.
train(
    model=model,
    train_loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    epochs=100,
)


Epoch 1/100, Loss: 4.574455800086959
(0.0, 0, 12544)
fc1 sparsity: 0, fc2 sparsity: 2, fc3 sparsity: 0
total sparse params: 2
Epoch 2/100, Loss: 4.304711117673276
(0.0, 0, 12544)
fc1 sparsity: 1, fc2 sparsity: 2, fc3 sparsity: 0
total sparse params: 3
Epoch 3/100, Loss: 3.6294926471039175
(7.971938775510203e-05, 1, 12544)
fc1 sparsity: 3, fc2 sparsity: 3, fc3 sparsity: 0
total sparse params: 6
Epoch 4/100, Loss: 2.1339125919189534
(7.971938775510203e-05, 1, 12544)
fc1 sparsity: 1, fc2 sparsity: 3, fc3 sparsity: 1
total sparse params: 5
Epoch 5/100, Loss: 1.651762590352406
(0.0, 0, 12544)
fc1 sparsity: 10, fc2 sparsity: 3, fc3 sparsity: 1
total sparse params: 14
Epoch 6/100, Loss: 1.41234271511086
(0.0, 0, 12544)
fc1 sparsity: 4, fc2 sparsity: 6, fc3 sparsity: 1
total sparse params: 11
Epoch 7/100, Loss: 1.2173531121536614
(0.00047831632653061223, 6, 12544)
fc1 sparsity: 10, fc2 sparsity: 5, fc3 sparsity: 1
total sparse params: 16
Epoch 8/100, Loss: 1.055313567108691
(7.971938775510203e

In [13]:
def test(model, test_loader):
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum()
    print(f"Accuracy: {100 * correct/total}")

In [14]:
# Evaluate the trained model on the held-out split.
test(model=model, test_loader=test_loader)

Accuracy: 98.45999908447266


In [15]:
# Take detached views of the three fully connected weight matrices.
# NOTE(review): this needs `model` to expose fc1/fc2/fc3 (as the LeNet class does).
# The KronLeNet built earlier has no such attributes, which is the AttributeError
# recorded under this cell -- a trained LeNet must be bound to `model` first.
weight1, weight2, weight3 = (
    layer.weight.data.detach() for layer in (model.fc1, model.fc2, model.fc3)
)


AttributeError: 'KronLeNet' object has no attribute 'fc1'

In [13]:
# Total number of scalar entries across all parameters of `model`.
params = sum(tensor.numel() for tensor in model.parameters())
params

21644

In [11]:
print(*(weight.shape for weight in (weight1, weight2, weight3)))
# Kronecker factor shapes, one pair per layer. Multiplied entry by entry, each pair
# gives (10*12, 16*16) = (120, 256), (7*12, 12*10) = (84, 120) and (5*2, 7*12) = (10, 84).
w11_shape, w12_shape = [10, 16], [12, 16]
w21_shape, w22_shape = [7, 12], [12, 10]
w31_shape, w32_shape = [5, 7], [2, 12]



torch.Size([120, 256]) torch.Size([84, 120]) torch.Size([10, 84])


In [20]:
# Truncation ranks used by the following cells (trailing numbers are the author's notes).
r1 = 50  # 160
r2 = 31  # 84
r3 = 8  # 24
# Factor each dense weight with gkpd into stacks of the shapes chosen above.
w11_hat, w12_hat = gkpd(weight1, w11_shape, w12_shape, atol=1e-1)
w21_hat, w22_hat = gkpd(weight2, w21_shape, w22_shape, atol=1e-1)
w31_hat, w32_hat = gkpd(weight3, w31_shape, w32_shape, atol=1e-1)
# Display the shape of every factor stack.
tuple(
    factor.shape
    for factor in (w11_hat, w12_hat, w21_hat, w22_hat, w31_hat, w32_hat)
)

(torch.Size([152, 10, 16]),
 torch.Size([152, 12, 16]),
 torch.Size([84, 7, 12]),
 torch.Size([84, 12, 10]),
 torch.Size([24, 5, 7]),
 torch.Size([24, 2, 12]))

In [21]:
# Entries kept when each of the six factor stacks is cut to its first r1 / r2 / r3 slices.
params = sum(
    factor.numel()
    for factor in (
        w11_hat[:r1], w12_hat[:r1],
        w21_hat[:r2], w22_hat[:r2],
        w31_hat[:r3], w32_hat[:r3],
    )
)
params

24396

In [22]:
# Rebuild each dense weight from its leading factor slices with kron, write it into
# the matching layer, then re-evaluate the model.
# NOTE(review): like the weight-extraction cell, this requires `model` to have fc1/fc2/fc3.
w1_hat = kron(w11_hat[:r1], w12_hat[:r1])
model.fc1.weight.data = w1_hat

w2_hat = kron(w21_hat[:r2], w22_hat[:r2])
model.fc2.weight.data = w2_hat

w3_hat = kron(w31_hat[:r3], w32_hat[:r3])
model.fc3.weight.data = w3_hat

model = model.to(device)
test(model, test_loader)

Accuracy: 79.9000015258789


In [58]:
# Shapes for decomposing the first factor stack of each layer a second time.
# Entry-wise products of the pairs: [160, 10, 16], [84, 7, 12] and [24, 5, 7].
w111_shape = [16, 2, 4]
w112_shape = [10, 5, 4]
w211_shape = [28, 1, 3]
w212_shape = [3, 7, 4]
w311_shape = [4, 1, 7]
w312_shape = [6, 5, 1]


In [63]:
# Truncation ranks for the second-level factors (trailing numbers are the author's notes).
# NOTE(review): these rebind r1/r2/r3 from the first decomposition, so re-running an
# earlier cell afterwards would pick up these values instead.
r1 = 64  # 128
r2 = 41  # 83
r3 = 14  # 28
# Decompose the first factor stack of each layer again with gkpd.
# (A stray mid-cell `w211_hat.shape` expression, which displayed nothing, was removed.)
w111_hat, w112_hat = gkpd(w11_hat, w111_shape, w112_shape)
w211_hat, w212_hat = gkpd(w21_hat, w211_shape, w212_shape)
w311_hat, w312_hat = gkpd(w31_hat, w311_shape, w312_shape)


In [64]:
# Entries kept after the second decomposition: the six truncated second-level stacks
# plus the three untruncated right-hand stacks from the first decomposition.
params = sum(
    factor.numel()
    for factor in (
        w111_hat[:r1], w112_hat[:r1],
        w211_hat[:r2], w212_hat[:r2],
        w311_hat[:r3], w312_hat[:r3],
    )
)
params += sum(factor.numel() for factor in (w12_hat, w22_hat, w32_hat))
params
    

70068

In [5]:
def group_transpose(param):
    """Flatten every leading-index slice of a 3-D tensor and return the transpose.

    Args:
        param: tensor with exactly three dimensions, shape ``(N, A, B)``.

    Returns:
        Tensor of shape ``(A * B, N)`` whose column ``i`` holds the flattened
        slice ``param[i]``.

    Raises:
        ValueError: when ``param`` does not have exactly three dimensions
            (the shape cannot be unpacked into three names).
    """
    num_groups, _, _ = param.shape
    flattened = param.reshape(num_groups, -1)
    return flattened.transpose(0, 1)
# Quick check on a random (2, 3, 4) tensor; the cell displays the resulting shape.
a = torch.randn((2, 3, 4))
group_transpose(param=a).shape

torch.Size([12, 2])

In [None]:
# Start main.py as a background shell job (nohup ... &) with unbuffered Python output (-u);
# stdout and stderr are both redirected to log.txt.
!nohup python -u main.py > log.txt 2>&1 &