In [1]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.models import resnet18
from tqdm import tqdm
import torch.nn.utils.parametrize as parametrize
from torchvision.models import mobilenet_v2
import math
from clearml import Task

torch.manual_seed(0)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
    autocast = torch.cuda.amp.autocast(dtype=torch.float16)
    grad_scaler = torch.cuda.amp.GradScaler()
else:
    autocast = torch.cpu.amp.autocast(dtype=torch.float16)
    grad_scaler = torch.cpu.amp.GradScaler()


batch_size = 64
lr = 0.01
momentum = 0.9
epochs = 10

In [2]:
class MLP(nn.Module):
    def __init__(self, input_size, output_size, num_params):
        super(MLP, self).__init__()
        
        depth = int(num_params ** 0.2)  # Adjust this factor as needed
        hidden_size = int((num_params / (depth))**0.5)  # Adjust this factor as needed
        layers = []
        for _ in range(depth):
            layers.append(nn.Linear(input_size, hidden_size))
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm1d(hidden_size))
            input_size = hidden_size
        layers.append(nn.Linear(input_size, output_size))  # Output size matching the conv output channels
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)

In [3]:
class MLPParametrization(torch.nn.Module):
     def __init__(self, kernel_size, in_channels, out_channels, groups):
          super(MLPParametrization, self).__init__()
          self.MLPLayer = MLP(2, kernel_size**2, kernel_size*kernel_size*in_channels*out_channels / groups)
          self.kernel_size = kernel_size
          self.in_channels = in_channels
          self.out_channels = out_channels
          self.groups = groups
          self.x_grid = torch.linspace(-1, 1, int(in_channels / self.groups))
          self.y_grid = torch.linspace(-1, 1, out_channels)
          self.grid = torch.cartesian_prod(self.x_grid, self.y_grid)

     def forward(self, x):
          A = self.MLPLayer(self.grid.to(x.device)).view((self.out_channels, int(self.in_channels / self.groups), self.kernel_size, self.kernel_size))
          return A
          

In [4]:
def parametrize_resnet(model):
    for name, layer in model.named_modules():
        if isinstance(layer, nn.Conv2d):
            parametrizator = MLPParametrization(layer.kernel_size[0], layer.in_channels, layer.out_channels, layer.groups)
            parametrize.register_parametrization(layer, "weight", parametrizator)

In [5]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

configuration_dict = {'batch_size': batch_size, 'lr': lr, 'momentum': momentum, 'epochs': epochs}
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)
iterations_in_epoch = math.ceil(len(trainset) / batch_size)
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [6]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()


        # input channel, output filter, kernel size
        self.conv1 = nn.Conv2d(3, 64, 5, padding=2)
        self.conv2 = nn.Conv2d(64, 128, 5, padding=2)
        self.conv3 = nn.Conv2d(128, 256, 3)

        self.BatchNorm2d1 = nn.BatchNorm2d(64)
        self.BatchNorm2d2 = nn.BatchNorm2d(128)
        self.BatchNorm2d3 = nn.BatchNorm2d(256)

        self.BatchNorm2d4 = nn.BatchNorm2d(120)



        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(p=0)

        self.fc1 = nn.Linear(256 * 3 * 3, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.BatchNorm2d1(self.pool(F.leaky_relu(self.conv1(x))))
        x = self.BatchNorm2d2(self.pool(F.leaky_relu(self.conv2(x))))
        x = self.BatchNorm2d3(self.pool(F.leaky_relu(self.conv3(x))))
        x = x.view(-1, 256 * 3 * 3)
        x = self.dropout(F.leaky_relu(self.fc1(x)))
        x = self.dropout(F.leaky_relu(self.fc2(x)))
        x = self.fc3(x)
        return x

In [7]:
net = Net()
print(f'Number of before parametrizing {sum(p.numel() for p in net.parameters())}')
parametrize_resnet(net)
net.to(device)
print(f'Number of params after parametrizing {sum(p.numel() for p in net.parameters())}')

Number of before parametrizing 793710
Number of params after parametrizing 1266777


In [8]:
task_name = f'MLP_batch{batch_size}_lr{lr}_epochs{epochs}_simple_convnet_bn'
task = Task.init(project_name="mlp_parametrization", task_name=task_name)

configuration_dict = task.connect(
    configuration_dict
)
logger = task.get_logger()

ClearML Task: created new task id=c9b80a44d1394705b31fa6947942a24b
2023-08-11 12:32:34,121 - clearml.Task - INFO - Storing jupyter notebook directly as code
ClearML results page: https://app.clear.ml/projects/f88e6dc889294a42833934ab069b799f/experiments/c9b80a44d1394705b31fa6947942a24b/output/log


In [9]:
import torch.optim as optim

param_list = [param for name, param in net.named_parameters() if "parametrizations." in name]
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(param_list, lr=lr, momentum=momentum)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0)

In [None]:
for epoch in range(epochs):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(tqdm(trainloader)):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        # zero the parameter gradients
        optimizer.zero_grad()
        inputs = inputs.to(device)
        # forward + backward + optimize
        with autocast:
            outputs = net(inputs)
            loss = criterion(outputs, labels.to(device))
        grad_scaler.scale(loss).backward()
        grad_scaler.step(optimizer)
        grad_scaler.update()
        scheduler.step()
        # print statistics
        running_loss += loss.item()
        if i % 50 == 49:    # print every 2000 mini-batches
            logger.report_scalar(
                title="Loss",
                series="running_loss",
                iteration=(i + 1) + iterations_in_epoch * (epoch + 1),
                value=running_loss / 50,
            )
            print(running_loss / 50)
            running_loss = 0.0

In [11]:
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in testloader:
        images, labels = data
        # calculate outputs by running images through the network
        outputs = net(images.to(device))
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted.cpu() == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')
logger.report_scalar(
    title="Metrics",
    series="accuracy",
    iteration=len(trainset),
    value=100 * correct // total,
)

Accuracy of the network on the 10000 test images: 44 %


In [None]:
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

# again no gradients needed
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images.to(device))
        _, predictions = torch.max(outputs, 1)
        # collect the correct predictions for each class
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1


# print accuracy for each class
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')
