# TUCKER DECOMPOSITION EXPERIMENTATION


In [3]:
from toolbox.models import ResNet112, ResNet56

import torch
import torchvision
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt

from rich import print as pprint
from tqdm import trange

device = "cuda"
plt.rcParams["image.cmap"] = "magma"

In [4]:
batch_size = 128
mean, std = (0.5071, 0.4867, 0.4409), (0.267, 0.256, 0.276)
testloader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100(
        root="../data",
        train=False,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        ),
    ),
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)


trainloader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100(
        root="../data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        ),
    ),
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

testloader_data = next(iter(testloader))

model_path = r"../experiments-24-03\Cifar100\ResNet112\1\Cifar100_ResNet112.pth"
model = ResNet112(100).to(device)
model.load_state_dict(torch.load(model_path, weights_only=True)["weights"])
print("Loaded ResNet112")

Files already downloaded and verified
Files already downloaded and verified
Loaded ResNet112


In [5]:
model.eval()

inputs, targets = testloader_data
inputs, targets = inputs.to(device), targets.to(device)

with torch.no_grad():
    outputs = model(inputs)

probs = torch.nn.functional.softmax(outputs[3], dim=1)
confidence, predicted = torch.max(probs.data, 1)

predicted_arr = predicted.eq(targets.data).cpu().float()
predicted_class_arr = np.array(testloader.dataset.classes)[predicted.data.cpu()]
correct_class_arr = np.array(testloader.dataset.classes)[targets.data.cpu()]

sample = [
    {
        "feature_map": outputs[2][i],
        "correct": predicted_arr[i].item() == 1.0,
        "predicted_class": predicted_class_arr[i],
        "correct_class": correct_class_arr[i],
        "confidence": confidence[i].item(),
    }
    for i in range(batch_size)
]

whole_fmap = outputs[2]


In [6]:
outputs[2].shape

torch.Size([128, 64, 8, 8])

In [None]:
def visualize_channels(index):
    feature_map = sample[index]["feature_map"]

    fig, axs = plt.subplots(8, 8, figsize=(5, 5))
    for i in range(8):
        for j in range(8):
            # axs[i, j*8].matshow(feature_map[i+j*8].cpu(), vmin=0)
            axs[i, j].matshow(feature_map[i + j * 8].cpu(), vmin=0)
            axs[i, j].axis("off")
    plt.tight_layout()
    plt.show()

In [None]:
def print_sample_data(index):
    batch_sample = sample[index]
    print(
        batch_sample["correct"],
        f"| {batch_sample["predicted_class"]} {batch_sample['confidence'] * 100:.2f}% |",
        batch_sample["correct_class"],
    )


def show_sample(index):
    images, labels = testloader_data
    img = images[index]
    img = img * torch.tensor(std).view(3, 1, 1) + torch.tensor(mean).view(3, 1, 1)
    img = img.permute(1, 2, 0).numpy()
    plt.imshow(img)
    plt.axis("off")
    plt.show()


def visualize_processed_feature_map(feature_map):  # feature_map shape: [64,8,8]
    feature_map = torch.mean(feature_map, dim=0, keepdim=True)  # Shape: [1, H, W]
    feature_map = feature_map.view(1, -1)  # Shape: [1, H*W]
    feature_map = torch.softmax(feature_map, dim=1)  # Shape: [1, H*W]
    feature_map = feature_map.view(8, 8).cpu()  # Shape: [H, W]

    plt.figure(figsize=(5, 5))
    plt.imshow(feature_map, vmin=0)
    plt.axis("off")
    plt.title(f"Processed Feature Map for sample")
    plt.show()


def visualize_channels(feature_map):

    fig, axs = plt.subplots(4, 8, figsize=(10, 5))
    for i in range(4):
        for j in range(8):
            axs[i, j].matshow(feature_map[i * 4 + j].cpu(), vmin=0)
            axs[i, j].axis("off")
    plt.show()

In [7]:
import tensorly as tl
import torch.nn.functional as F

tl.set_backend("pytorch")
def tucker(feature_map): #expects 4d
    batch_size, channels, height, width = feature_map.shape
    core, factors = tl.decomposition.tucker(feature_map, rank=[batch_size, 32, 8, 8])
    return core


def tucker_sample_split(feature_map):
    core_tensor_list = []
    batch_size, channels, height, width = feature_map.shape
    for i in range(batch_size):
        single_fmap = feature_map[i].unsqueeze(0)
        core = tucker(single_fmap)[0]
        core_tensor_list.append(core)
        
    combined_tensor = torch.stack(core_tensor_list, dim=0)
    return combined_tensor


In [9]:
untrained_model = ResNet56(100).to(device)
untrained_model.train()

ResNet56

In [None]:
current_output = None
for i in range(150):
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if batch_idx % 10 == 0:
            print(batch_idx)
            
        inputs, targets = inputs.to(device), targets.to(device)
        student_outputs = untrained_model(inputs)

        current_output = student_outputs[2] 


        tucker_fmap = tucker(student_outputs[2])

0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
2

KeyboardInterrupt: 