In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_epochs = 10
batch_size = 128
learning_rate = 0.001
lambda_kl = 0.01  # Regularization strength for KL divergence

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [13]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [14]:
def tanh_ternary(x):
    x_tanh = torch.tanh(x)
    x_tanh = torch.where(x_tanh > 0.666, torch.tensor(1.0, device=x.device), x_tanh)
    x_tanh = torch.where(x_tanh < -0.666, torch.tensor(-1.0, device=x.device), x_tanh)
    x_tanh = torch.where((x_tanh >= -0.666) & (x_tanh <= 0.666), torch.tensor(0.0, device=x.device), x_tanh)
    return x_tanh

def ternarize_weights(param):
    gamma = torch.mean(torch.abs(param)) + 1e-7  # Scaling factor (to avoid division by zero)
    return tanh_ternary(param / gamma) * gamma  # Scale the ternary weights back

def kl_divergence(original, ternary):
    original_prob = torch.softmax(original.flatten(), dim=0)
    ternary_prob = torch.softmax(ternary.flatten(), dim=0)
    kl_div = F.kl_div(ternary_prob.log(), original_prob, reduction='batchmean')
    return kl_div

def forward_with_ternarized_weights(model, x):
    original_params = []
    with torch.no_grad():
        for param in model.parameters():
            if param.requires_grad:
                original_params.append(param.clone())
                param.data = ternarize_weights(param.data)
    
    output = model(x)
    
    # Restore original weights
    with torch.no_grad():
        for param, original_param in zip(model.parameters(), original_params):
            if param.requires_grad:
                param.data = original_param
    
    return output

def calculate_accuracy(model, loader, ternary=False):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            if ternary:
                outputs = forward_with_ternarized_weights(model, images)
            else:
                outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

def train_model(ternary=False):
    model = SimpleCNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Train the model for 10 epochs
    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, labels in tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            if ternary:
                outputs = forward_with_ternarized_weights(model, images)
            else:
                outputs = model(images)

            loss = criterion(outputs, labels)

            if ternary:
                # KL Divergence Regularization
                kl_loss = 0.0
                for param in model.parameters():
                    if param.requires_grad:
                        original_weights = param.clone().detach()
                        ternary_weights = ternarize_weights(original_weights)
                        kl_loss += kl_divergence(original_weights, ternary_weights)
                
                loss += lambda_kl * kl_loss

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}')

    return model

In [15]:
# Train the first model with regular weights
print("Training the first model (regular weights)...")
model_regular = train_model(ternary=False)
accuracy_regular = calculate_accuracy(model_regular, test_loader)
print(f'Accuracy of the model with regular weights: {accuracy_regular:.2f}%')

# Train the second model with ternarized weights in the forward pass
print("\nTraining the second model (ternarized weights)...")
model_ternary = train_model(ternary=True)
accuracy_ternary = calculate_accuracy(model_ternary, test_loader, ternary=True)
print(f'Accuracy of the model with ternarized weights: {accuracy_ternary:.2f}%')

Training the first model (regular weights)...


100%|██████████| 391/391 [00:12<00:00, 30.44it/s]


Epoch [1/10], Loss: 1.3499


100%|██████████| 391/391 [00:12<00:00, 30.21it/s]


Epoch [2/10], Loss: 0.9445


100%|██████████| 391/391 [00:13<00:00, 30.02it/s]


Epoch [3/10], Loss: 0.7801


100%|██████████| 391/391 [00:12<00:00, 30.54it/s]


Epoch [4/10], Loss: 0.6446


100%|██████████| 391/391 [00:12<00:00, 30.08it/s]


Epoch [5/10], Loss: 0.5181


100%|██████████| 391/391 [00:13<00:00, 30.07it/s]


Epoch [6/10], Loss: 0.3983


100%|██████████| 391/391 [00:13<00:00, 29.96it/s]


Epoch [7/10], Loss: 0.2764


100%|██████████| 391/391 [00:12<00:00, 30.21it/s]


Epoch [8/10], Loss: 0.1905


100%|██████████| 391/391 [00:12<00:00, 30.48it/s]


Epoch [9/10], Loss: 0.1145


100%|██████████| 391/391 [00:13<00:00, 29.98it/s]


Epoch [10/10], Loss: 0.0754
Accuracy of the model with regular weights: 72.96%

Training the second model (ternarized weights)...


100%|██████████| 391/391 [00:17<00:00, 22.60it/s]


Epoch [1/10], Loss: 1.5867


100%|██████████| 391/391 [00:17<00:00, 22.84it/s]


Epoch [2/10], Loss: 1.2574


100%|██████████| 391/391 [00:17<00:00, 22.74it/s]


Epoch [3/10], Loss: 1.1050


100%|██████████| 391/391 [00:17<00:00, 22.94it/s]


Epoch [4/10], Loss: 0.9927


100%|██████████| 391/391 [00:17<00:00, 22.58it/s]


Epoch [5/10], Loss: 0.9028


100%|██████████| 391/391 [00:17<00:00, 22.92it/s]


Epoch [6/10], Loss: 0.8320


100%|██████████| 391/391 [00:17<00:00, 22.59it/s]


Epoch [7/10], Loss: 0.7678


100%|██████████| 391/391 [00:17<00:00, 22.83it/s]


Epoch [8/10], Loss: 0.6998


100%|██████████| 391/391 [00:17<00:00, 22.71it/s]


Epoch [9/10], Loss: 0.6326


100%|██████████| 391/391 [00:17<00:00, 22.63it/s]


Epoch [10/10], Loss: 0.5719
Accuracy of the model with ternarized weights: 70.65%
