In [139]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

In [None]:
BATCH_SIZE = 128
PREVIEW_COUNT = 8  # number of training images shown in the sanity-check grid

# Normalize each RGB channel with mean 0.5 / std 0.5, i.e. x -> (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
# The test loader is only used to average a loss, so shuffling adds nothing but nondeterminism.
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

# Sanity check: show a handful of training images with their class names.
# (Showing the whole batch of 128 makes the title unreadable.)
images, labels = next(iter(trainloader))
preview_images = images[:PREVIEW_COUNT]
preview_labels = labels[:PREVIEW_COUNT].tolist()
# Invert the normalization above (x * 0.5 + 0.5) before handing pixels to imshow.
plt.imshow(torchvision.utils.make_grid(preview_images).permute(1, 2, 0) / 2 + 0.5)
plt.title(' '.join(trainset.classes[label] for label in preview_labels))
plt.axis('off')
plt.show()

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


 18%|███████▏                               | 31.3M/170M [00:02<00:10, 13.8MB/s]

In [None]:
class CIFAR10_NN(nn.Module):
    """Seven-layer fully connected classifier for CIFAR-10-sized images.

    Each input sample is flattened to 32*32*3 = 3072 features and passed through
    3072 -> 512 -> 256 -> 128 -> 64 -> 32 -> 16 -> 10 linear layers. The same
    activation module follows every hidden layer; the last layer returns raw
    logits (no activation), as expected by ``nn.CrossEntropyLoss``.

    Args:
        activation_name: one of ``"sigmoid"``, ``"relu"`` or ``"tanh"``.

    Raises:
        ValueError: if ``activation_name`` is not one of the supported names.
            (Previously any unknown name silently fell back to Tanh, so a typo
            such as ``"Relu"`` trained the wrong network without warning.)
    """

    _ACTIVATIONS = {"sigmoid": nn.Sigmoid, "relu": nn.ReLU, "tanh": nn.Tanh}

    def __init__(self, activation_name):
        super().__init__()
        # Attribute names are kept as fcLayer1..fcLayer7 so state_dict keys are unchanged.
        self.fcLayer1 = nn.Linear(32*32*3, 512)
        self.fcLayer2 = nn.Linear(512, 256)
        self.fcLayer3 = nn.Linear(256, 128)
        self.fcLayer4 = nn.Linear(128, 64)
        self.fcLayer5 = nn.Linear(64, 32)
        self.fcLayer6 = nn.Linear(32, 16)
        self.fcLayer7 = nn.Linear(16, 10)
        try:
            activation_cls = self._ACTIVATIONS[activation_name]
        except KeyError:
            raise ValueError(
                f"unknown activation_name {activation_name!r}; "
                f"expected one of {sorted(self._ACTIVATIONS)}"
            ) from None
        self.activation = activation_cls()

    def forward(self, x):
        """Return logits of shape (batch, 10) for a batch of images ``x``.

        Every dimension after the first is flattened, so ``x`` must hold
        32*32*3 values per sample (e.g. a (batch, 3, 32, 32) tensor).
        """
        # flatten (unlike view) also accepts non-contiguous inputs such as permuted tensors.
        x = x.flatten(1)
        hidden_layers = (self.fcLayer1, self.fcLayer2, self.fcLayer3,
                         self.fcLayer4, self.fcLayer5, self.fcLayer6)
        for layer in hidden_layers:
            x = self.activation(layer(x))
        return self.fcLayer7(x)

In [None]:
def _train_one_epoch(model, loader, criterion, optimizer, device, epoch, epochs, log_gradients):
    """Run one optimization pass over ``loader`` and return the mean batch loss."""
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")
    # Evaluation at the end of the previous epoch left the model in eval mode;
    # switch back before training (matters once dropout/batch-norm are added).
    model.train()
    running_loss = 0.0
    for batch_idx, (images, labels) in enumerate(loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        running_loss += batch_loss

        # Only report (and optionally dump gradients for) the last batch of the epoch.
        if batch_idx == num_batches - 1:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {batch_loss:.4f}")
            if log_gradients:
                check_gradients(model)
    return running_loss / num_batches


def _evaluate_loss(model, loader, criterion, device):
    """Return the mean batch loss of ``model`` over ``loader`` with gradients disabled."""
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError("test_loader yielded no batches")
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            total_loss += criterion(model(images), labels).item()
    return total_loss / num_batches


def train_and_test(activation, epochs=10, lr=0.001, device="cpu",
                   train_loader=None, test_loader=None, log_gradients=True):
    """Train a fresh ``CIFAR10_NN`` with plain SGD and record per-epoch losses.

    Args:
        activation: activation name forwarded to ``CIFAR10_NN``.
        epochs: number of passes over the training loader.
        lr: SGD learning rate.
        device: device (name or ``torch.device``) for the model and batches.
        train_loader: loader used for optimization; defaults to the
            notebook-level ``trainloader``.
        test_loader: loader used for evaluation; defaults to the
            notebook-level ``testloader``.
        log_gradients: if True, print per-parameter gradient norms after the
            last training batch of each epoch (via ``check_gradients``).

    Returns:
        ``(train_losses, test_losses)``: two lists with one entry per epoch,
        each the mean cross-entropy over that loader's batches.

    Raises:
        ValueError: if either loader yields no batches.
    """
    if train_loader is None:
        train_loader = trainloader
    if test_loader is None:
        test_loader = testloader

    model = CIFAR10_NN(activation).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)

    train_losses = []
    test_losses = []
    for epoch in range(epochs):
        average_loss = _train_one_epoch(model, train_loader, criterion, optimizer,
                                        device, epoch, epochs, log_gradients)
        test_losses.append(_evaluate_loss(model, test_loader, criterion, device))
        train_losses.append(average_loss)
        print(f"{epoch + 1} / {epochs}, activation = {activation}, loss = {average_loss: .4f}")
    return train_losses, test_losses

In [None]:
def check_gradients(model):
    """Print the L2 norm of every parameter gradient that has been populated.

    Parameters whose ``.grad`` is still ``None`` (never back-propagated into)
    are skipped. Output is one line per parameter, in ``named_parameters``
    order; nothing is returned.
    """
    populated = (
        (param_name, tensor.grad)
        for param_name, tensor in model.named_parameters()
        if tensor.grad is not None
    )
    for param_name, grad in populated:
        norm_value = grad.norm().item()
        print(f"{param_name} Gradient Norm: {norm_value}")

In [None]:
activation_fns = ["sigmoid", "tanh", "relu"]
SEED = 0  # fixed seed so the comparison is reproducible

results = {}
for activation in activation_fns:
    # Re-seed before each run: every activation then starts from identical
    # initial weights and the same training-set shuffle order, so differences
    # in the loss curves come from the activation, not from RNG luck.
    torch.manual_seed(SEED)
    results[activation] = train_and_test(activation)

In [None]:
# One figure per activation: mean train vs. test loss for each epoch.
for name, (train_curve, test_curve) in results.items():
    print(name, "train:", train_curve, "test:", test_curve)
    epoch_numbers = range(1, len(train_curve) + 1)  # label epochs 1..N, not 0..N-1
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(epoch_numbers, train_curve, label='Train Loss')
    ax.plot(epoch_numbers, test_curve, label='Test Loss')
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    # The original passed the set {name}, which rendered literally as "{'relu'}".
    ax.set_title(f"Activation: {name}")
    ax.legend()
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop