In [None]:
import random
import numpy as np
import torch

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # For full reproducibility (slightly slower on GPU)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_seed(42)

In [1]:
import torch

class Adam:
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        self.params = list(params)
        self.lr = lr
        self.beta1 = betas[0]
        self.beta2 = betas[1]
        self.eps = eps
        self.t = 0  # time step

        # Initialize moment vectors
        self.m = [torch.zeros_like(p.data) for p in self.params]  # 1st moment
        self.v = [torch.zeros_like(p.data) for p in self.params]  # 2nd moment

    def step(self):
        self.t += 1  # Increase timestep

        for i, param in enumerate(self.params):
            if param.grad is None:
                continue

            grad = param.grad.data

            # Update biased moments
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * grad
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * (grad * grad)

            # Compute bias-corrected moments
            m_hat = self.m[i] / (1 - self.beta1 ** self.t)
            v_hat = self.v[i] / (1 - self.beta2 ** self.t)

            # Parameter update
            param.data -= self.lr * m_hat / (v_hat.sqrt() + self.eps)

    def zero_grad(self):
        for param in self.params:
            if param.grad is not None:
                param.grad.detach_()
                param.grad.zero_()


In [12]:
import torch

class HybridAdamBinarySearch:
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, alpha=0.5, log_sign_flips=False):
        self.params = list(params)
        self.lr = lr
        self.beta1 = betas[0]
        self.beta2 = betas[1]
        self.eps = eps
        self.alpha = alpha  # blend factor (0.0 = pure Adam)
        self.t = 0  # time step

        self.log_sign_flips = log_sign_flips

        # Adam moment vectors
        self.m = [torch.zeros_like(p.data) for p in self.params]
        self.v = [torch.zeros_like(p.data) for p in self.params]

        # For sign-check logic
        self.last_data_grads = [[None, None] for _ in self.params]

        if self.log_sign_flips:
            self.sign_flip_counts = [0 for _ in self.params]

    def step(self):
        self.t += 1

        for i, param in enumerate(self.params):
            if param.grad is None:
                continue

            grad = param.grad.data
            data = param.data

            # Adam moment update
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * grad
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * (grad * grad)

            m_hat = self.m[i] / (1 - self.beta1 ** self.t)
            v_hat = self.v[i] / (1 - self.beta2 ** self.t)

            update = self.lr * m_hat / (v_hat.sqrt() + self.eps)
            adam_next = data - update  # candidate new value without correction

            last_data, last_grad = self.last_data_grads[i]
            if last_grad is not None:
                sign_flip_mask = torch.sign(grad) * torch.sign(last_grad) < 0

                if sign_flip_mask.any():
                    if self.log_sign_flips:
                        self.sign_flip_counts[i] += sign_flip_mask.sum().item()

                    midpoint = (data + last_data) / 2
                    # Blend midpoint with Adam update using alpha
                    blended = self.alpha * midpoint + (1 - self.alpha) * adam_next
                    data[sign_flip_mask] = blended[sign_flip_mask]

                # For params without flip: normal Adam update
                data[~sign_flip_mask] = adam_next[~sign_flip_mask]
            else:
                # First iteration — just do Adam update
                data[:] = adam_next

            # Save current state
            self.last_data_grads[i] = [data.clone(), grad.clone()]

        if self.log_sign_flips:
            flip_summary = ', '.join(
                f'p{i}: {count}' for i, count in enumerate(self.sign_flip_counts)
            )
            print(f'[Step {self.t}] Total Sign Flips: {flip_summary}')

    def zero_grad(self):
        for p in self.params:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()


In [3]:
import torch
import torch.nn as nn
import torch.optim as adam
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Model architecture matching paper specifications
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(28*28, 1000)  # First hidden layer
        self.fc2 = nn.Linear(1000, 1000)   # Second hidden layer
        self.fc3 = nn.Linear(1000, 10)     # Output layer
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(-1, 28*28)  # Flatten input
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)

# MNIST data loading
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_data = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

test_data = datasets.MNIST(
    root='./data',
    train=False,
    transform=transform
)

# Data loaders with paper's batch size
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1000, shuffle=False)


100%|██████████| 9.91M/9.91M [00:00<00:00, 57.7MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.66MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.7MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.67MB/s]


In [14]:

# Initialize components
model = MLP()
criterion = nn.CrossEntropyLoss()
optimizer = HybridAdamBinarySearch(model.parameters(), alpha = 0)

# Training loop
def train(epochs):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Print training progress
        avg_loss = total_loss / len(train_loader)
        print(f'Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}')

# Evaluation
def test():
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()

    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'Test Accuracy: {accuracy:.2f}%')

# Run experiment
train(epochs=10)  # Match paper's training duration
test()


Epoch 1/10 - Loss: 0.1957
Epoch 2/10 - Loss: 0.0841
Epoch 3/10 - Loss: 0.0568
Epoch 4/10 - Loss: 0.0436
Epoch 5/10 - Loss: 0.0342
Epoch 6/10 - Loss: 0.0286
Epoch 7/10 - Loss: 0.0260
Epoch 8/10 - Loss: 0.0257
Epoch 9/10 - Loss: 0.0199
Epoch 10/10 - Loss: 0.0213
Test Accuracy: 97.96%


In [None]:
'''
Epoch 1/10 - Loss: 0.1957
Epoch 2/10 - Loss: 0.0841
Epoch 3/10 - Loss: 0.0568
Epoch 4/10 - Loss: 0.0436
Epoch 5/10 - Loss: 0.0342
Epoch 6/10 - Loss: 0.0286
Epoch 7/10 - Loss: 0.0260
Epoch 8/10 - Loss: 0.0257
Epoch 9/10 - Loss: 0.0199
Epoch 10/10 - Loss: 0.0213
Test Accuracy: 97.96%
'''

In [13]:
import torch
import torch.nn as nn
import torch.optim as adam
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Model architecture matching paper specifications
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(28*28, 1000)  # First hidden layer
        self.fc2 = nn.Linear(1000, 1000)   # Second hidden layer
        self.fc3 = nn.Linear(1000, 10)     # Output layer
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(-1, 28*28)  # Flatten input
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)

# Initialize components
model = MLP()
criterion = nn.CrossEntropyLoss()
optimizer = HybridAdamBinarySearch(model.parameters(), alpha = 1)

# Training loop
def train(epochs):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Print training progress
        avg_loss = total_loss / len(train_loader)
        print(f'Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}')

# Evaluation
def test():
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()

    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'Test Accuracy: {accuracy:.2f}%')

# Run experiment
train(epochs=10)  # Match paper's training duration
test()


Epoch 1/10 - Loss: 0.1931
Epoch 2/10 - Loss: 0.0746
Epoch 3/10 - Loss: 0.0502
Epoch 4/10 - Loss: 0.0361
Epoch 5/10 - Loss: 0.0287
Epoch 6/10 - Loss: 0.0252
Epoch 7/10 - Loss: 0.0187
Epoch 8/10 - Loss: 0.0212
Epoch 9/10 - Loss: 0.0190
Epoch 10/10 - Loss: 0.0142
Test Accuracy: 97.85%


In [None]:
'''
Epoch 1/10 - Loss: 0.1931
Epoch 2/10 - Loss: 0.0746
Epoch 3/10 - Loss: 0.0502
Epoch 4/10 - Loss: 0.0361
Epoch 5/10 - Loss: 0.0287
Epoch 6/10 - Loss: 0.0252
Epoch 7/10 - Loss: 0.0187
Epoch 8/10 - Loss: 0.0212
Epoch 9/10 - Loss: 0.0190
Epoch 10/10 - Loss: 0.0142
Test Accuracy: 97.85%
'''