In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

In [2]:
class SimpleModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [3]:
# Initialize main model (large model M) and proxy models (M- and M+)
input_size = 10
main_hidden_size = 50
proxy_hidden_size = 10
output_size = 3
alpha = 0.5

# Large model
main_model = SimpleModel(input_size, main_hidden_size, output_size)

# Proxy model (untuned version M-)
proxy_model_untuned = SimpleModel(input_size, proxy_hidden_size, output_size)

# Proxy model (tuned version M+)
proxy_model_tuned = SimpleModel(input_size, proxy_hidden_size, output_size)

# Create a small random dataset for training
data = torch.randn(100, input_size)  
labels = torch.randint(0, output_size, (100,)) 

In [4]:
# Create DataLoader
dataset = TensorDataset(data, labels)
train_loader = DataLoader(dataset, batch_size=10, shuffle=True)

In [5]:
# Define loss and optimizer for training proxy_model_tuned
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(proxy_model_tuned.parameters(), lr=0.01)

In [6]:
# Train proxy_model_tuned for a few epochs
print("Training proxy_model_tuned...")
for epoch in range(5):
    for inputs, target_labels in train_loader:
        optimizer.zero_grad()
        output = proxy_model_tuned(inputs)
        loss = criterion(output, target_labels)
        loss.backward()
        optimizer.step()
print("Training complete for proxy_model_tuned.\n")


Training proxy_model_tuned...
Training complete for proxy_model_tuned.



In [7]:
# Testing proxy tuning on a new input
input_data = torch.randn((1, input_size)) 

In [8]:
# Forward pass through each model to get logits
logits_main = main_model(input_data)
logits_proxy_untuned = proxy_model_untuned(input_data)
logits_proxy_tuned = proxy_model_tuned(input_data)

# Define scaling factor alpha
alpha = 0.5

# Apply proxy tuning: adjust main model's logits using the difference between proxy logits
adjusted_logits = logits_main + alpha * (logits_proxy_tuned - logits_proxy_untuned)

#Convert original logits to probabilites
original_probs = F.softmax(logits_main, dim=1)

# Convert adjusted logits to probabilities
adjusted_probs = F.softmax(adjusted_logits, dim=1)

In [9]:
print("Adjusted probabilities:", adjusted_probs)
print("Original probabilities:", original_probs)

Adjusted probabilities: tensor([[0.2641, 0.3728, 0.3632]], grad_fn=<SoftmaxBackward0>)
Original probabilities: tensor([[0.2450, 0.3441, 0.4109]], grad_fn=<SoftmaxBackward0>)
