In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
from einops import rearrange, reduce, repeat
import torchvision
from data_factorcy.data_factorcy import loader_generate 
from utils.decomposition import kron

In [2]:
resnet50 = torchvision.models.resnet50(pretrained=True)
train_loader, test_loader = loader_generate('cifar10')



Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:01<00:00, 102558972.38it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [11]:
test_conv2d = nn.Conv2d(8, 64, 3)
test_conv2d.weight.data.shape

torch.Size([64, 8, 3, 3])

In [15]:
type(resnet50)
model = torchvision.models.resnet.ResNet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], num_classes=10)
class KronConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias, a_shape, b_shape) -> None:
        super().__init__()
        weight_shape = [out_channels, in_channels, kernel_size, kernel_size]
        # assert each i, a[i] * b[i] == weight[i]
        assert np.array(a_shape[1:]) * np.array(b_shape[1:]) == np.array(weight_shape)
        self.a = nn.Parameter(torch.randn(a_shape))
        self.b = nn.Parameter(torch.randn(b_shape))
        self.stride = stride
        self.padding = padding
        self.bias = bias
        self.dilation = 1
        self.groups = 1
        
    def forward(self, x):
        weight = kron(self.a, self.b)
        return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
    
KronConv2d(3, 64, 7, 2, 3, False, [3, 8, 3, 7, 1], [3, 8, 1, 1, 7])
        

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

In [3]:
# calcu the params of the model 
def model_params(model):
    params = 0
    for p in model.parameters():
        if p.requires_grad:
            params += p.numel()
    return params
# print the params of the model, for example, 25555 is 25,555
model_params(resnet50)

25557032

In [4]:
# examine the model 
resnet50

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [7]:
from gkpd.convolution import kroneckerconv2d
resnet50.conv1
nn.Conv2d

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

In [None]:
# use minist to fine-tune resnet50
import torchvision.transforms as transforms


In [None]:
kroneckerconv2d(
    3, 64, (7, 7), (2, 2), (3, 3)
)

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

# Define data transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to the input size expected by ResNet-50
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load MNIST dataset
mnist_train = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

# Create data loaders
train_loader = DataLoader(mnist_train, batch_size=32, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=32, shuffle=False)

# Load pre-trained ResNet-50 model
resnet50 = models.resnet50(pretrained=True)

# Modify the last fully connected layer to match the number of classes in MNIST (10)
num_ftrs = resnet50.fc.in_features
resnet50.fc = nn.Linear(num_ftrs, 10)  # Assuming 10 classes in MNIST

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(resnet50.parameters(), lr=0.001, momentum=0.9)

# Fine-tuning the model
num_epochs = 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet50.to(device)

for epoch in range(num_epochs):
    resnet50.train()
    running_loss = 0.0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = resnet50(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {running_loss / len(train_loader)}")

# Evaluate the fine-tuned model on the test set
resnet50.eval()
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = resnet50(inputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f"Test Accuracy: {accuracy * 100:.2f}%")



RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

# Define data transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),  # Convert single-channel to three channels
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load MNIST dataset
mnist_train = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

# Create data loaders
train_loader = DataLoader(mnist_train, batch_size=32, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=32, shuffle=False)

# Load pre-trained ResNet-50 model
resnet50 = models.resnet50(pretrained=True)

# Modify the last fully connected layer to match the number of classes in MNIST (10)
num_ftrs = resnet50.fc.in_features
resnet50.fc = nn.Linear(num_ftrs, 10)  # Assuming 10 classes in MNIST

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(resnet50.parameters(), lr=0.001, momentum=0.9)

# Fine-tuning the model
num_epochs = 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet50.to(device)

for epoch in range(num_epochs):
    resnet50.train()
    running_loss = 0.0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = resnet50(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {running_loss / len(train_loader)}")

# Evaluate the fine-tuned model on the test set
resnet50.eval()
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = resnet50(inputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f"Test Accuracy: {accuracy * 100:.2f}%")


Epoch 1/5, Loss: 0.11093835338565211
Epoch 2/5, Loss: 0.021187182665864626
Epoch 3/5, Loss: 0.011634554800235977
Epoch 4/5, Loss: 0.006403804848863122
Epoch 5/5, Loss: 0.004385608656009814
Test Accuracy: 99.60%
