In [1]:
import torch 
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# 1. Dataloading
# ToTensor converts images to float tensors in [0, 1]; Normalize with mean 0.5 and std 0.5
# on each of the three channels then maps the values to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

batch_size = 50
data_root = './data'

# Training batches are reshuffled each epoch; test batches keep a fixed order.
trainset = torchvision.datasets.CIFAR10(
    root=data_root, train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2
)

testset = torchvision.datasets.CIFAR10(
    root=data_root, train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2
)

# Class names, indexed by integer label.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# 2. Define network parameters
# A single linear layer mapping a flattened image to K class scores: y = x @ w + b.

# Seed torch's global RNG so the random weight initialization is reproducible.
torch.manual_seed(0)

Din = 3*32*32  # input size (flattened 3x32x32 image)
K = 10         # number of classes
std = 1e-5     # scale of the random weight initialization

# Initialize weights and biases.
# float32 matches the image tensors produced by ToTensor; torch.mm does not promote dtypes,
# so float64 weights would make x_train.mm(w) raise a dtype-mismatch error.
w = torch.randn(Din, K, dtype=torch.float32) * std
b = torch.zeros(K, dtype=torch.float32)

# Hyperparameters
iterations = 20            # number of full passes (epochs) over the training set
learning_rate = 2e-6
learning_rate_decay = 0.9  # multiplicative decay applied to learning_rate after every epoch
reg = 0                    # L2 regularization strength on w and b
loss_history = []          # training loss recorded once per mini-batch

# 3. Training
# Mini-batch gradient descent on the squared error between the linear scores and one-hot
# targets. Gradients are derived by hand (no autograd), so they must match `loss` exactly.
for t in range(iterations):
    running_loss = 0.0

    for inputs, labels in trainloader:
        Ntr = inputs.shape[0]
        # Flatten each image to a row vector and match the parameter dtype so mm() cannot
        # fail on a float32/float64 mismatch.
        x_train = inputs.reshape(Ntr, -1).to(w.dtype)
        y_train_onehot = nn.functional.one_hot(labels, K).to(w.dtype)

        # Forward pass
        y_pred = x_train.mm(w) + b

        # Loss: squared error summed over classes and averaged over the batch,
        # plus an L2 penalty on both w and b.
        loss = (1/Ntr)*torch.sum((y_pred - y_train_onehot)**2) + reg*(torch.sum(w**2) + torch.sum(b**2))
        loss_history.append(loss.item())
        running_loss += loss.item()

        # Backpropagation: analytic gradients of `loss` above.
        dy_pred = 2.0/Ntr*(y_pred - y_train_onehot)
        # d/dw of reg*sum(w**2) is 2*reg*w (likewise for b).
        dw = x_train.t().mm(dy_pred) + 2*reg*w
        db = dy_pred.sum(dim=0) + 2*reg*b

        # Gradient-descent update, applied in place to the plain (non-autograd) tensors.
        w -= learning_rate * dw
        b -= learning_rate * db

    # Report the mean mini-batch loss for this epoch.
    print(f"Iteration {t}, loss = {running_loss/len(trainloader)}")

    # Decay learning rate once per epoch.
    learning_rate *= learning_rate_decay

# 4. Plot loss
# loss_history holds one value per mini-batch update, accumulated across every epoch.
plt.plot(loss_history)
plt.gca().set(xlabel='Iteration', ylabel='Loss', title='Training Loss history')
plt.show()

# Shared evaluation helper for sections 5 and 6 (previously two copy-pasted loops).
def evaluate_accuracy(loader, w, b):
    """Count correct top-1 predictions of the linear classifier y = x @ w + b.

    Args:
        loader: iterable of (inputs, labels) batches; each input is flattened per sample.
        w: weight matrix of shape (Din, K).
        b: bias vector of shape (K,).

    Returns:
        (correct, total): number of correctly classified samples and number of samples seen.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            # Flatten and match the parameter dtype so mm() cannot fail on a dtype mismatch.
            x = inputs.reshape(inputs.shape[0], -1).to(w.dtype)
            scores = x.mm(w) + b
            # Predicted class = highest-scoring column. Compare directly with the integer
            # labels; the argmax of a one-hot encoding is just the label itself.
            predicted = scores.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct, total


# 5. Calculate Accuracy on training set
correct_train, total_train = evaluate_accuracy(trainloader, w, b)
train_accuracy = 100 * correct_train / total_train
print(f"Training accuracy = {train_accuracy}")

# 6. Calculate Accuracy on test set
correct_test, total_test = evaluate_accuracy(testloader, w, b)
test_accuracy = 100 * correct_test / total_test
print(f"Test accuracy = {test_accuracy}")

# 7. Visualize weights
# Column k of w is a linear "template" for class k over the flattened (C, H, W) image.
# Reshape it to (K, 32, 32, 3), the channels-last layout plt.imshow expects for RGB data.
# Use a new name so the trained `w` is not overwritten and this cell stays re-runnable.
w_templates = w.detach().cpu().reshape(3, 32, 32, K).permute(3, 1, 2, 0).numpy()
w_min, w_max = w_templates.min(), w_templates.max()
# Guard against a zero range (all-equal weights) so the rescale cannot divide by zero.
w_range = (w_max - w_min) or 1.0

for i in range(K):
    plt.subplot(2, 5, i + 1)

    # Rescale all templates with the same min/max to [0, 255] so they are comparable.
    wimg = 255.0 * (w_templates[i] - w_min) / w_range
    plt.imshow(wimg.astype('uint8'))
    plt.axis('off')
    plt.title(classes[i])

plt.show()





    

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


  1%|          | 1310720/170498071 [01:42<3:40:56, 12763.04it/s]


KeyboardInterrupt: 