In [6]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

class Logistic_Regression(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(32*32*3, 10)

    def forward(self, x):
        return self.linear(x.view(-1, 32*32*3))

    def evaluate(self, test_dataloader):
        with torch.no_grad():
            total = 0
            correct = 0
            for data in test_dataloader:
                inputs, targets = data
                outputs = self.forward(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        return 100 * (correct / total)

batch_size = 100
epochs = 10
lr = 0.001

model = Logistic_Regression()
criterion = nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(), lr=lr)

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

training_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)

training_dataloader = torch.utils.data.DataLoader(training_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

for epoch in range(epochs):
    print('Epoch', epoch, '- Test Accuracy', model.evaluate(test_dataloader))
    for data in training_dataloader:
        opt.zero_grad()
        inputs, targets = data
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        opt.step()

print("Final Accuracy", model.evaluate(test_dataloader))

torch.save(model.state_dict(), "./cifar10_logistic_regression.pt")


Files already downloaded and verified
Files already downloaded and verified
Epoch 0 - Test Accuracy 9.94
Epoch 1 - Test Accuracy 33.12
Epoch 2 - Test Accuracy 35.69
Epoch 3 - Test Accuracy 36.95
Epoch 4 - Test Accuracy 37.84
Epoch 5 - Test Accuracy 38.42
Epoch 6 - Test Accuracy 38.690000000000005
Epoch 7 - Test Accuracy 38.67
Epoch 8 - Test Accuracy 38.76
Epoch 9 - Test Accuracy 39.26
Final Accuracy 39.26
