
class BinarySearchOptimizer:
    """Per-element sign-based optimizer that bisects on gradient sign flips.

    Every parameter element carries its own non-negative step size ("hop").
    While an element's gradient keeps the sign it had on the previous step,
    the element moves by its hop against the gradient sign.  When the sign
    flips, the element jumps to the midpoint between the position where the
    current gradient was evaluated and the position where the previous
    gradient was evaluated, and its hop is halved.  Elements whose current or
    previous gradient is exactly zero are left where they are.

    Args:
        params: Iterable of tensors exposing ``.data`` and ``.grad``
            (e.g. ``model.parameters()``).  It is materialised into a list.
        eps: Scale factor for the first step.  Without ``initial_hop`` the
            first hop is ``eps * |grad|`` per element, so the first update is
            ``data -= eps * grad``.
        initial_hop: Optional constant; when given, the first hop of every
            element is ``eps * initial_hop`` instead of being gradient-based.
    """

    def __init__(self, params, eps=0.1, initial_hop=None):
        self.params = list(params)
        # Per-parameter [position, gradient] recorded on the previous step;
        # [None, None] marks a parameter that has not been stepped yet.
        self.last_data_grads = [[None, None] for _ in self.params]
        # Placeholder hops; each entry is replaced on that parameter's first step.
        self.hops = [torch.full_like(p.data, eps) for p in self.params]
        self.eps = eps
        self.initial_hop = initial_hop

    def step(self):
        """Update every parameter that has a gradient, in place.

        Parameters whose ``.grad`` is ``None`` are skipped entirely.
        """
        for i, p in enumerate(self.params):
            if p.grad is None:
                continue

            grad = p.grad
            data = p.data
            # Snapshot of the point where `grad` was evaluated.  This (not the
            # post-update position) is what must be remembered, otherwise the
            # midpoint taken on a later sign flip collapses onto the current
            # position and the element never moves.
            position = data.clone()

            last_data, last_grad = self.last_data_grads[i]

            if last_data is None:
                # First step for this parameter: choose the initial hop.
                if self.initial_hop is not None:
                    self.hops[i] = self.eps * torch.full_like(data, self.initial_hop)
                else:
                    # The hop is a magnitude; the direction comes only from
                    # sign(grad) below.  A signed hop would cancel that sign
                    # and push every element the same way.
                    self.hops[i] = self.eps * grad.detach().abs()
                data -= torch.sign(grad) * self.hops[i]
            else:
                hop = self.hops[i]

                # > 0: same sign as last step, < 0: sign flipped, 0: a zero gradient.
                same_sign = torch.sign(grad) * torch.sign(last_grad)
                mask_same = same_sign > 0
                mask_flip = same_sign < 0

                # Same sign: keep walking against the gradient.
                data[mask_same] -= torch.sign(grad[mask_same]) * hop[mask_same]

                # Sign flipped: go to the middle of the two bracketing points
                # and halve the hop for those elements only.
                midpoint = (position + last_data) / 2
                data[mask_flip] = midpoint[mask_flip]
                hop[mask_flip] /= 2

            # Remember where this gradient was measured, together with it.
            self.last_data_grads[i] = [position, grad.clone()]

    def zero_grad(self):
        """Zero, in place, the gradient of every parameter that has one."""
        for p in self.params:
            if p.grad is not None:
                p.grad.zero_()


# In [None]:
import torch
import torch.nn as nn
import torch.optim as adam
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Model architecture matching paper specifications
class MLP(nn.Module):
    """Three-layer fully connected network: 784 -> 1000 -> 1000 -> 10.

    ReLU is applied after the first two linear layers; the last layer's
    output is returned as-is.
    """

    def __init__(self):
        super().__init__()
        in_features = 28 * 28
        hidden_units = 1000
        num_outputs = 10
        self.fc1 = nn.Linear(in_features, hidden_units)   # input -> hidden 1
        self.fc2 = nn.Linear(hidden_units, hidden_units)  # hidden 1 -> hidden 2
        self.fc3 = nn.Linear(hidden_units, num_outputs)   # hidden 2 -> output
        self.relu = nn.ReLU()

    def forward(self, x):
        """Flatten ``x`` to rows of 784 values and return the fc3 output."""
        flat = x.view(-1, 28 * 28)
        hidden = self.relu(self.fc1(flat))
        hidden = self.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return logits

# MNIST input pipeline: convert each image to a tensor, then normalise it
# with mean 0.1307 and standard deviation 0.3081.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Both splits live under ./data; only the training split requests a download.
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, transform=transform)

# Shuffled mini-batches of 128 for training, fixed-order batches of 1000 for evaluation.
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1000, shuffle=False)


### Checking performance with the baseline `torch.optim` optimizer (the code below builds SGD; `adam` is only the import alias)

# In [None]:

# Initialize components
model = MLP()
criterion = nn.CrossEntropyLoss()
# NOTE(review): `adam` is just the alias given to `torch.optim` at import time;
# this line constructs plain SGD, not Adam.
optimizer = adam.SGD(model.parameters(), lr=0.001)  # Paper's base learning rate

# Training loop
def train(epochs):
    """Train the module-level ``model`` for ``epochs`` passes over ``train_loader``.

    Uses the module-level ``optimizer`` and ``criterion``.  After every pass
    the mean of the per-batch losses is printed.
    """
    model.train()
    for epoch in range(epochs):
        running_loss = 0
        for inputs, labels in train_loader:
            # Clear stale gradients, then forward / backward / update.
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Mean loss per batch for this epoch.
        avg_loss = running_loss / len(train_loader)
        print(f'Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}')

# Evaluation
def test():
    """Print the accuracy (in percent) of the module-level ``model`` on ``test_loader``."""
    model.eval()
    hits = 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for inputs, labels in test_loader:
            predictions = model(inputs).argmax(dim=1)
            hits += (predictions == labels).sum().item()

    accuracy = 100. * hits / len(test_loader.dataset)
    print(f'Test Accuracy: {accuracy:.2f}%')

# Run experiment
# train() and test() operate on the module-level model, optimizer and data
# loaders defined above; test() prints the accuracy after training finishes.
train(epochs=10)  # Match paper's training duration
test()


### Checking performance with binary search optimizer

# In [None]:
import torch
import torch.nn as nn
import torch.optim as adam
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Model architecture matching paper specifications
class MLP(nn.Module):
    """Fully connected classifier with two 1000-unit hidden layers.

    Input is flattened to 784 features; ReLU follows the first two linear
    layers and the third layer produces 10 outputs.
    """

    def __init__(self):
        super().__init__()
        in_features = 28 * 28
        hidden_units = 1000
        num_outputs = 10
        self.fc1 = nn.Linear(in_features, hidden_units)   # input -> hidden 1
        self.fc2 = nn.Linear(hidden_units, hidden_units)  # hidden 1 -> hidden 2
        self.fc3 = nn.Linear(hidden_units, num_outputs)   # hidden 2 -> output
        self.relu = nn.ReLU()

    def forward(self, x):
        """Reshape ``x`` to ``(-1, 784)`` and pass it through the three layers."""
        flat = x.view(-1, 28 * 28)
        hidden = self.relu(self.fc1(flat))
        hidden = self.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return logits

# Initialize components
# A new MLP instance is created here, so this run does not reuse the model
# object trained in the previous cell.
model = MLP()
criterion = nn.CrossEntropyLoss()
# eps scales the first-step hop inside BinarySearchOptimizer; initial_hop is
# left at its default (None), which selects the gradient-based first hop.
optimizer = BinarySearchOptimizer(model.parameters(), eps=0.001)

# Training loop
def train(epochs):
    """Run ``epochs`` training passes over ``train_loader`` with the module-level
    ``model``, ``criterion`` and ``optimizer``, printing the mean batch loss
    after each pass.
    """
    model.train()
    for epoch in range(epochs):
        running_loss = 0
        for inputs, labels in train_loader:
            # Reset gradients before computing the ones for this batch.
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Average of the per-batch losses seen in this epoch.
        avg_loss = running_loss / len(train_loader)
        print(f'Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}')

# Evaluation
def test():
    """Evaluate the module-level ``model`` on ``test_loader`` and print its accuracy in percent."""
    model.eval()
    hits = 0
    # Evaluation only: disable gradient tracking.
    with torch.no_grad():
        for inputs, labels in test_loader:
            predictions = model(inputs).argmax(dim=1)
            hits += (predictions == labels).sum().item()

    accuracy = 100. * hits / len(test_loader.dataset)
    print(f'Test Accuracy: {accuracy:.2f}%')

# Run experiment
# Uses the train()/test() redefined in this cell together with the model and
# BinarySearchOptimizer instance created above.
train(epochs=10)  # Match paper's training duration
test()
