In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import numpy as np

# Define the MRL components from the code base
class MRL_Linear_Layer(nn.Module):
    def __init__(self, nesting_list, num_classes=1000, efficient=False, **kwargs):
        super(MRL_Linear_Layer, self).__init__()
        self.nesting_list = nesting_list
        self.num_classes = num_classes
        self.efficient = efficient
        
        if self.efficient:
            setattr(self, f"nesting_classifier_{0}", nn.Linear(nesting_list[-1], self.num_classes, **kwargs))
        else:    
            for i, num_feat in enumerate(self.nesting_list):
                setattr(self, f"nesting_classifier_{i}", nn.Linear(num_feat, self.num_classes, **kwargs))    

    def reset_parameters(self):
        if self.efficient:
            self.nesting_classifier_0.reset_parameters()
        else:
            for i in range(len(self.nesting_list)):
                getattr(self, f"nesting_classifier_{i}").reset_parameters()

    def forward(self, x):
        nesting_logits = ()
        for i, num_feat in enumerate(self.nesting_list):
            if self.efficient:
                if self.nesting_classifier_0.bias is None:
                    nesting_logits += (torch.matmul(x[:, :num_feat], (self.nesting_classifier_0.weight[:, :num_feat]).t()), )
                else:
                    nesting_logits += (torch.matmul(x[:, :num_feat], (self.nesting_classifier_0.weight[:, :num_feat]).t()) + self.nesting_classifier_0.bias, )
            else:
                nesting_logits += (getattr(self, f"nesting_classifier_{i}")(x[:, :num_feat]),)

        return nesting_logits

class FixedFeatureLayer(nn.Linear):
    def __init__(self, in_features, out_features, **kwargs):
        super(FixedFeatureLayer, self).__init__(in_features, out_features, **kwargs)

    def forward(self, x):
        if not (self.bias is None):
            out = torch.matmul(x[:, :self.in_features], self.weight.t()) + self.bias
        else:
            out = torch.matmul(x[:, :self.in_features], self.weight.t())
        return out

class BlurPoolConv2d(nn.Module):
    def __init__(self, conv):
        super().__init__()
        default_filter = torch.tensor([[[[1, 2, 1], [2, 4, 2], [1, 2, 1]]]]) / 16.0
        filt = default_filter.repeat(conv.in_channels, 1, 1, 1)
        self.conv = conv
        self.register_buffer('blur_filter', filt)

    def forward(self, x):
        blurred = F.conv2d(x, self.blur_filter, stride=1, padding=(1, 1),
                           groups=self.conv.in_channels, bias=None)
        return self.conv.forward(blurred)

def apply_blurpool(mod: nn.Module):
    for (name, child) in mod.named_children():
        if isinstance(child, nn.Conv2d) and (np.max(child.stride) > 1 and child.in_channels >= 16):
            setattr(mod, name, BlurPoolConv2d(child))
        else: 
            apply_blurpool(child)

In [8]:
nesting_start = 3
efficient = True
nesting_list = [2**i for i in range(nesting_start, 12)]  # 8, 16, 32, 64, 128, 256, 512, 1024, 2048
num_classes = 1000
model_path = "/home/mmkuznecov/SkolCourses/DL/FINAL_PROJECT/MRL/train/logs/98819cd7-62aa-479e-8642-f4333540615e/final_weights.pt"  # Update with your actual path

# Create the model
print("Initializing ResNet50 model...")
model = models.resnet50(pretrained=True)
model.fc = MRL_Linear_Layer(nesting_list, num_classes=num_classes, efficient=efficient)

# Apply BlurPool to the model
print("Applying BlurPool...")
apply_blurpool(model)

# Convert to channels-last memory format for better performance
model = model.to(memory_format=torch.channels_last)

Initializing ResNet50 model...




Applying BlurPool...


In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [11]:
print("\nGenerating random input tensor...")
batch_size = 4
random_input = torch.randn(batch_size, 3, 224, 224, device=device)


Generating random input tensor...


In [12]:
random_input = random_input.to(memory_format=torch.channels_last)

In [15]:
print("Performing forward pass...")
with torch.no_grad():
    outputs = model(random_input)

# Print the shapes of all outputs
print("\nOutput shapes from all embedding dimensions:")
print(f"{'Embedding Dim':<15} {'Output Shape':<20} {'Parameters'}")
print("-" * 50)

total_params = 0
for i, dim in enumerate(nesting_list):
    logits = outputs[i]
    num_params = dim * num_classes
    if efficient:
        # In efficient mode, we count parameters up to this dimension
        if i == len(nesting_list) - 1:
            params_to_add = num_params
        else:
            params_to_add = 0
    else:
        # In non-efficient mode, each dimension has its own parameters
        params_to_add = num_params
        
    total_params += params_to_add
    print(f"{dim:<15} {str(logits.shape):<20} {num_params:,}")

print("-" * 50)
if efficient:
    print(f"Total classifier parameters (MRL-E): {nesting_list[-1] * num_classes:,}")
else:
    print(f"Total classifier parameters (MRL): {total_params:,}")

# Output example values from the first embedding
print(f"\nExample output values (first 5) from embedding dim {nesting_list[0]}:")
print(outputs[0][0, :5])

Performing forward pass...

Output shapes from all embedding dimensions:
Embedding Dim   Output Shape         Parameters
--------------------------------------------------
8               torch.Size([4, 1000]) 8,000
16              torch.Size([4, 1000]) 16,000
32              torch.Size([4, 1000]) 32,000
64              torch.Size([4, 1000]) 64,000
128             torch.Size([4, 1000]) 128,000
256             torch.Size([4, 1000]) 256,000
512             torch.Size([4, 1000]) 512,000
1024            torch.Size([4, 1000]) 1,024,000
2048            torch.Size([4, 1000]) 2,048,000
--------------------------------------------------
Total classifier parameters (MRL-E): 2,048,000

Example output values (first 5) from embedding dim 8:
tensor([-0.0172, -0.0201,  0.0015,  0.0222, -0.0080], device='cuda:0')


In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
import numpy as np
from tqdm import tqdm
from datasets import load_from_disk
from PIL import Image

# Define the MRL components from the code base
class MRL_Linear_Layer(nn.Module):
    def __init__(self, nesting_list, num_classes=1000, efficient=False, **kwargs):
        super(MRL_Linear_Layer, self).__init__()
        self.nesting_list = nesting_list
        self.num_classes = num_classes
        self.efficient = efficient
        
        if self.efficient:
            setattr(self, f"nesting_classifier_{0}", nn.Linear(nesting_list[-1], self.num_classes, **kwargs))
        else:    
            for i, num_feat in enumerate(self.nesting_list):
                setattr(self, f"nesting_classifier_{i}", nn.Linear(num_feat, self.num_classes, **kwargs))    

    def reset_parameters(self):
        if self.efficient:
            self.nesting_classifier_0.reset_parameters()
        else:
            for i in range(len(self.nesting_list)):
                getattr(self, f"nesting_classifier_{i}").reset_parameters()

    def forward(self, x):
        nesting_logits = ()
        for i, num_feat in enumerate(self.nesting_list):
            if self.efficient:
                if self.nesting_classifier_0.bias is None:
                    nesting_logits += (torch.matmul(x[:, :num_feat], (self.nesting_classifier_0.weight[:, :num_feat]).t()), )
                else:
                    nesting_logits += (torch.matmul(x[:, :num_feat], (self.nesting_classifier_0.weight[:, :num_feat]).t()) + self.nesting_classifier_0.bias, )
            else:
                nesting_logits += (getattr(self, f"nesting_classifier_{i}")(x[:, :num_feat]),)

        return nesting_logits

class FixedFeatureLayer(nn.Linear):
    def __init__(self, in_features, out_features, **kwargs):
        super(FixedFeatureLayer, self).__init__(in_features, out_features, **kwargs)

    def forward(self, x):
        if not (self.bias is None):
            out = torch.matmul(x[:, :self.in_features], self.weight.t()) + self.bias
        else:
            out = torch.matmul(x[:, :self.in_features], self.weight.t())
        return out

class BlurPoolConv2d(nn.Module):
    def __init__(self, conv):
        super().__init__()
        default_filter = torch.tensor([[[[1, 2, 1], [2, 4, 2], [1, 2, 1]]]]) / 16.0
        filt = default_filter.repeat(conv.in_channels, 1, 1, 1)
        self.conv = conv
        self.register_buffer('blur_filter', filt)

    def forward(self, x):
        blurred = F.conv2d(x, self.blur_filter, stride=1, padding=(1, 1),
                           groups=self.conv.in_channels, bias=None)
        return self.conv.forward(blurred)

def apply_blurpool(mod: nn.Module):
    for (name, child) in mod.named_children():
        if isinstance(child, nn.Conv2d) and (np.max(child.stride) > 1 and child.in_channels >= 16):
            setattr(mod, name, BlurPoolConv2d(child))
        else: 
            apply_blurpool(child)

def load_model(model_path):
    # Model configuration
    nesting_start = 3
    efficient = True
    nesting_list = [2**i for i in range(nesting_start, 12)]  # 8, 16, 32, 64, 128, 256, 512, 1024, 2048
    num_classes = 1000
    
    # Create the model
    print("Initializing ResNet50 model...")
    model = models.resnet50(pretrained=True)
    model.fc = MRL_Linear_Layer(nesting_list, num_classes=num_classes, efficient=efficient)
    
    # Apply BlurPool to the model
    print("Applying BlurPool...")
    apply_blurpool(model)
    
    # Load the pretrained weights
    try:
        print(f"Loading weights from {model_path}...")
        checkpoint = torch.load(model_path, map_location='cpu')
        
        # Handle case where weights were saved in DataParallel format
        if list(checkpoint.keys())[0].startswith('module.'):
            # Remove 'module.' prefix
            clean_ckpt = {}
            for k, v in checkpoint.items():
                clean_ckpt[k[7:] if k.startswith('module.') else k] = v
            checkpoint = clean_ckpt
        
        model.load_state_dict(checkpoint)
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading weights: {e}")
        print("Continuing with pretrained weights only.")
    
    # Move model to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model = model.to(memory_format=torch.channels_last)
    model.eval()
    
    return model, nesting_list, device

def evaluate_dataset(model, dataset, nesting_list, device, batch_size=64, num_samples=None):
    # Standard ImageNet normalization
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
    
    # Create transformation pipeline
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    
    # Initialize accuracy trackers for each embedding dimension
    correct_top1 = {dim: 0 for dim in nesting_list}
    correct_top5 = {dim: 0 for dim in nesting_list}
    
    # Limit samples if specified
    if num_samples is not None:
        total_samples = min(num_samples, len(dataset))
    else:
        total_samples = len(dataset)
    
    # Process in batches
    batches = []
    labels = []
    
    print(f"Processing {total_samples} images...")
    for i in tqdm(range(0, total_samples, batch_size)):
        batch_indices = range(i, min(i + batch_size, total_samples))
        batch_images = []
        batch_labels = []
        
        # Process each image in the batch
        for idx in batch_indices:
            sample = dataset[idx]
            image = sample['image']
            label = sample['label']
            
            # Apply transformations
            img_tensor = transform(image)
            batch_images.append(img_tensor)
            batch_labels.append(label)
        
        # Stack into a batch tensor
        images_tensor = torch.stack(batch_images).to(device)
        labels_tensor = torch.tensor(batch_labels).to(device)
        
        # Convert to channels last format for better performance
        images_tensor = images_tensor.to(memory_format=torch.channels_last)
        
        # Forward pass
        with torch.no_grad():
            outputs = model(images_tensor)
        
        # Calculate accuracy for each embedding dimension
        for i, dim in enumerate(nesting_list):
            # Top-1 accuracy
            _, predicted = outputs[i].max(1)
            correct_top1[dim] += (predicted == labels_tensor).sum().item()
            
            # Top-5 accuracy
            _, top5_indices = outputs[i].topk(5, dim=1)
            labels_expanded = labels_tensor.view(-1, 1).expand_as(top5_indices)
            correct_top5[dim] += (top5_indices == labels_expanded).sum().item()
    
    # Calculate final accuracy for each dimension
    accuracy_top1 = {dim: correct_top1[dim] / total_samples * 100 for dim in nesting_list}
    accuracy_top5 = {dim: correct_top5[dim] / total_samples * 100 for dim in nesting_list}
    
    return accuracy_top1, accuracy_top5, total_samples

def main():
    # Configuration
    model_path = "/home/mmkuznecov/SkolCourses/DL/FINAL_PROJECT/MRL/train/logs/98819cd7-62aa-479e-8642-f4333540615e/final_weights.pt"
    dataset_path = "data/imagenet_1k_resized_256_val"
    batch_size = 64  # Adjust based on your GPU memory
    num_samples = None  # Set to None to process all images, or a number to limit samples
    
    # Load model
    model, nesting_list, device = load_model(model_path)
    
    # Load dataset
    print(f"Loading dataset from {dataset_path}...")
    dataset = load_from_disk(dataset_path)
    print(f"Dataset loaded with {len(dataset)} samples")
    
    # Evaluate the model on the dataset
    accuracy_top1, accuracy_top5, total_samples = evaluate_dataset(
        model, dataset, nesting_list, device, batch_size, num_samples
    )
    
    # Print results
    print(f"\nEvaluation completed on {total_samples} images")
    print("\nAccuracy results for each embedding dimension:")
    print(f"{'Embedding Dim':<15} {'Top-1 Accuracy (%)':<20} {'Top-5 Accuracy (%)'}")
    print("-" * 60)
    
    for dim in nesting_list:
        print(f"{dim:<15} {accuracy_top1[dim]:<20.2f} {accuracy_top5[dim]:<20.2f}")
    
    # Print a summary for the best dimension
    best_dim = max(accuracy_top1, key=accuracy_top1.get)
    print(f"\nBest performing dimension: {best_dim}")
    print(f"Top-1 Accuracy: {accuracy_top1[best_dim]:.2f}%")
    print(f"Top-5 Accuracy: {accuracy_top5[best_dim]:.2f}%")

if __name__ == "__main__":
    main()

  from .autonotebook import tqdm as notebook_tqdm


Initializing ResNet50 model...


  checkpoint = torch.load(model_path, map_location='cpu')


Applying BlurPool...
Loading weights from /home/mmkuznecov/SkolCourses/DL/FINAL_PROJECT/MRL/train/logs/98819cd7-62aa-479e-8642-f4333540615e/final_weights.pt...
Model loaded successfully!
Loading dataset from data/imagenet_1k_resized_256_val...
Dataset loaded with 50000 samples
Processing 50000 images...


100%|██████████| 782/782 [02:45<00:00,  4.73it/s]


Evaluation completed on 50000 images

Accuracy results for each embedding dimension:
Embedding Dim   Top-1 Accuracy (%)   Top-5 Accuracy (%)
------------------------------------------------------------
8               53.55                77.07               
16              67.61                85.14               
32              70.27                87.76               
64              71.12                88.80               
128             71.55                89.61               
256             71.85                90.15               
512             72.02                90.33               
1024            72.11                90.43               
2048            72.15                90.55               

Best performing dimension: 2048
Top-1 Accuracy: 72.15%
Top-5 Accuracy: 90.55%



