In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_epochs = 30
batch_size = 128
learning_rate = 0.001

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 47684233.85it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [3]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [4]:
def tanh(x):
    x_tanh = torch.tanh(x)
    x_tanh = torch.where(x_tanh > 0.666, torch.tensor(1.0, device=x.device), x_tanh)
    x_tanh = torch.where(x_tanh < -0.666, torch.tensor(-1.0, device=x.device), x_tanh)
    x_tanh = torch.where((x_tanh >= -0.666) & (x_tanh <= 0.666), torch.tensor(0.0, device=x.device), x_tanh)
    return x_tanh

def apply_ternery_transformation(net):
    with torch.no_grad():
        for param in net.parameters():
            if param.requires_grad:
                gamma = torch.mean(torch.abs(param))
                param.data = tanh(param.data/(gamma + 1e-7))

def apply_weight_transformation(net):
    apply_ternery_transformation(net)
    
def calculate_accuracy(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

def train_model():
    model = SimpleCNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, labels in tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}')
    
    return model

In [5]:
model = train_model()

accuracy_before = calculate_accuracy(model, test_loader)
print(f'Accuracy before applying weight transformation: {accuracy_before:.2f}%')

apply_weight_transformation(model)

accuracy_after = calculate_accuracy(model, test_loader)
print(f'Accuracy after applying weight transformation: {accuracy_after:.2f}%')

accuracy_change = accuracy_after - accuracy_before
print(f'Change in accuracy after applying transformation: {accuracy_change:.2f}%')

100%|██████████| 391/391 [00:59<00:00,  6.60it/s]


Epoch [1/30], Loss: 1.3628


100%|██████████| 391/391 [01:01<00:00,  6.32it/s]


Epoch [2/30], Loss: 0.9752


100%|██████████| 391/391 [01:03<00:00,  6.16it/s]


Epoch [3/30], Loss: 0.7971


100%|██████████| 391/391 [01:03<00:00,  6.19it/s]


Epoch [4/30], Loss: 0.6573


100%|██████████| 391/391 [01:03<00:00,  6.21it/s]


Epoch [5/30], Loss: 0.5224


100%|██████████| 391/391 [01:03<00:00,  6.20it/s]


Epoch [6/30], Loss: 0.4011


100%|██████████| 391/391 [01:02<00:00,  6.25it/s]


Epoch [7/30], Loss: 0.2870


100%|██████████| 391/391 [01:02<00:00,  6.22it/s]


Epoch [8/30], Loss: 0.1844


100%|██████████| 391/391 [01:03<00:00,  6.20it/s]


Epoch [9/30], Loss: 0.1213


100%|██████████| 391/391 [01:02<00:00,  6.27it/s]


Epoch [10/30], Loss: 0.0736


100%|██████████| 391/391 [01:01<00:00,  6.31it/s]


Epoch [11/30], Loss: 0.0618


100%|██████████| 391/391 [01:02<00:00,  6.26it/s]


Epoch [12/30], Loss: 0.0580


100%|██████████| 391/391 [01:02<00:00,  6.23it/s]


Epoch [13/30], Loss: 0.0483


100%|██████████| 391/391 [01:03<00:00,  6.18it/s]


Epoch [14/30], Loss: 0.0404


100%|██████████| 391/391 [01:02<00:00,  6.21it/s]


Epoch [15/30], Loss: 0.0353


100%|██████████| 391/391 [01:02<00:00,  6.30it/s]


Epoch [16/30], Loss: 0.0505


100%|██████████| 391/391 [01:02<00:00,  6.21it/s]


Epoch [17/30], Loss: 0.0362


100%|██████████| 391/391 [01:03<00:00,  6.18it/s]


Epoch [18/30], Loss: 0.0341


100%|██████████| 391/391 [01:02<00:00,  6.22it/s]


Epoch [19/30], Loss: 0.0445


100%|██████████| 391/391 [01:05<00:00,  5.98it/s]


Epoch [20/30], Loss: 0.0349


100%|██████████| 391/391 [01:02<00:00,  6.25it/s]


Epoch [21/30], Loss: 0.0262


100%|██████████| 391/391 [01:01<00:00,  6.31it/s]


Epoch [22/30], Loss: 0.0304


100%|██████████| 391/391 [01:02<00:00,  6.24it/s]


Epoch [23/30], Loss: 0.0289


100%|██████████| 391/391 [01:01<00:00,  6.33it/s]


Epoch [24/30], Loss: 0.0329


100%|██████████| 391/391 [01:01<00:00,  6.39it/s]


Epoch [25/30], Loss: 0.0210


100%|██████████| 391/391 [01:02<00:00,  6.24it/s]


Epoch [26/30], Loss: 0.0262


100%|██████████| 391/391 [01:02<00:00,  6.22it/s]


Epoch [27/30], Loss: 0.0354


100%|██████████| 391/391 [01:02<00:00,  6.23it/s]


Epoch [28/30], Loss: 0.0251


100%|██████████| 391/391 [01:02<00:00,  6.24it/s]


Epoch [29/30], Loss: 0.0241


100%|██████████| 391/391 [01:02<00:00,  6.29it/s]


Epoch [30/30], Loss: 0.0231
Accuracy before applying weight transformation: 72.59%
Accuracy after applying weight transformation: 56.05%
Change in accuracy after applying transformation: -16.54%
