In [1]:
import torch
from torchvision import transforms

import torchvision.datasets as datasets
# Using CIFAR-10 dataset
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

# Define the transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# Load the CIFAR-10 dataset
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:02<00:00, 71.2MB/s] 


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [24]:
from torchvision import models

resnet_pretrained = models.resnet18(pretrained=True)
resnet_random = models.resnet18(pretrained=False)



In [3]:
def compute_expectation_variance(K, M, T=1.0):
    if K.shape != M.shape:
        raise ValueError("K and M must have the same shape.")

    sigma = torch.sigmoid(K / T)              # σ(K_{ij}/T)
    expectation = (M * sigma).sum()           # Σ M σ
    variance = ((M**2) * sigma * (1 - sigma)).sum()  # Σ M² σ(1−σ)

    return expectation, variance  

In [None]:
def linear_probe_accuracy(model, n_epoch = 3):
    model.eval()
    
    linear = torch.nn.Linear(1000, 10).cuda()
    
    # Define loss function and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(linear.parameters(), lr=0.001)
    
    # Training phase
    model.cuda()
    model.eval()
    for epoch in range(n_epoch):  # Train for 5 epochs
        running_loss = 0.0
        for images, labels in test_loader:
            # Resize images to fit the model
            images = torch.nn.functional.interpolate(images, size=(224, 224), mode='bilinear', align_corners=False)
            # alernatively, you can use the following line to resize images
            images, labels = images.cuda(), labels.cuda() if torch.cuda.is_available() else (images, labels)
            
            # Zero the parameter gradients
            optimizer.zero_grad()
            
            # Forward + backward + optimize
            outputs = linear(model(images))
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        print(f'Linear Probe Epoch {epoch+1}, Loss: {running_loss/len(test_loader):.4f}')
    
    # Evaluation phase
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            # Resize images to fit the model
            images = torch.nn.functional.interpolate(images, size=(224, 224), mode='bilinear', align_corners=False)
            images, labels = images.cuda(), labels.cuda() if torch.cuda.is_available() else (images, labels)
            
            outputs = linear(model(images))
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    print(f'Accuracy on test set: {accuracy:.2f}%')
    
    return accuracy

In [9]:
linear_probe_accuracy(resnet_pretrained)

Epoch 1, Loss: 0.9260
Epoch 2, Loss: 0.7065
Epoch 3, Loss: 0.6724
Accuracy on test set: 78.96%


78.96

In [10]:
linear_probe_accuracy(resnet_random)

Epoch 1, Loss: 2.1384
Epoch 2, Loss: 1.9875
Epoch 3, Loss: 1.9263
Accuracy on test set: 32.93%


32.93

In [25]:
beta = .5
T = 1

n_epoch = 20

resnet_random.train()
resnet_pretrained.eval()

optimizer = torch.optim.Adam(resnet_random.parameters(), lr=0.001)

resnet_random.cuda()
resnet_pretrained.cuda()

for epoch in range(n_epoch):  
    running_loss = 0.0
    running_exp = 0.0
    running_var = 0.0
    n = 0

    for images, labels in train_loader:
        images = torch.nn.functional.interpolate(images, size=(224, 224), mode='bilinear', align_corners=False)
        images = images.cuda()

        with torch.no_grad():
            outputs = resnet_pretrained(images)

            norm_outputs = outputs / outputs.norm(dim=1, keepdim=True)
            K = norm_outputs @ norm_outputs.T
            
        outputs = resnet_random(images)

        norm_outputs = outputs / outputs.norm(dim=1, keepdim=True)
        M = norm_outputs @ norm_outputs.T

        exp, var = compute_expectation_variance(K, M, T)

        loss = var - beta * exp

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Print statistics  

        running_loss += loss.item()
        running_exp += exp.item()
        running_var += var.item()

        n += 1

    print(f'Epoch {epoch+1}, Loss: {running_loss/n:.4f}, Expectation: {running_exp/n:.4f}, Variance: {running_var/n:.4f}')
    if epoch % 5 == 0:
        resnet_random.eval()
        linear_probe_accuracy(resnet_random, n_epoch=1)
        resnet_random.train()

Epoch 1, Loss: -440.0636, Expectation: 1839.2235, Variance: 479.5482
Linear Probe Epoch 1, Loss: 2.0240
Accuracy on test set: 35.69%
Epoch 2, Loss: -444.1681, Expectation: 1812.7230, Variance: 462.1934
Epoch 3, Loss: -445.2079, Expectation: 1808.6774, Variance: 459.1308
Epoch 4, Loss: -445.7682, Expectation: 1806.2496, Variance: 457.3566
Epoch 5, Loss: -446.3699, Expectation: 1806.2420, Variance: 456.7511
Epoch 6, Loss: -446.7475, Expectation: 1805.4800, Variance: 455.9925
Linear Probe Epoch 1, Loss: 1.4958
Accuracy on test set: 60.38%
Epoch 7, Loss: -447.0115, Expectation: 1804.9412, Variance: 455.4591
Epoch 8, Loss: -447.1502, Expectation: 1804.6099, Variance: 455.1547
Epoch 9, Loss: -447.3201, Expectation: 1804.3300, Variance: 454.8449
Epoch 10, Loss: -447.4683, Expectation: 1804.1418, Variance: 454.6026
Epoch 11, Loss: -447.5839, Expectation: 1803.7309, Variance: 454.2816
Linear Probe Epoch 1, Loss: 1.3909
Accuracy on test set: 66.16%
Epoch 12, Loss: -447.6556, Expectation: 1803.41