In [5]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

class Logistic_Regression(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(32*32*3, 10)

    def forward(self, x):
        return self.linear(x.view(-1, 32*32*3))

    def evaluate(self, test_dataloader):
        with torch.no_grad():
            total = 0
            correct = 0
            for data in test_dataloader:
                inputs, targets = data
                outputs = self.forward(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        return 100 * (correct / total)

batch_size = 100
epochs = 10
lr = 0.001

model = Logistic_Regression()
criterion = nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(), lr=lr)

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

training_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)

training_dataloader = torch.utils.data.DataLoader(training_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

for epoch in range(epochs):
    print('Epoch', epoch, '- Test Accuracy', model.evaluate(test_dataloader))
    for data in training_dataloader:
        opt.zero_grad()
        inputs, targets = data
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        opt.step()

print("Final Accuracy", model.evaluate(test_dataloader))

torch.save(model.state_dict(), "./cifar10_logistic_regression.pt")


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


100%|██████████████████████████████████████████████████████████████| 170498071/170498071 [00:15<00:00, 11161684.99it/s]


Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Epoch 0 - Test Accuracy 10.61
Epoch 1 - Test Accuracy 22.15
Epoch 2 - Test Accuracy 26.229999999999997
Epoch 3 - Test Accuracy 28.09
Epoch 4 - Test Accuracy 29.38
Epoch 5 - Test Accuracy 30.36
Epoch 6 - Test Accuracy 31.11
Epoch 7 - Test Accuracy 31.919999999999998
Epoch 8 - Test Accuracy 32.42
Epoch 9 - Test Accuracy 32.9
Final Accuracy 33.36
