In [6]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchmetrics import Accuracy

transform = transforms.Compose(
    [
        transforms.ToTensor()
    ]
)

train_dataset = torchvision.datasets.ImageFolder("dataset/train", transform=transform)
test_dataset = torchvision.datasets.ImageFolder("dataset/val", transform=transform)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=4096, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=4096, shuffle=True)

In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [8]:
import torchvision.models as models

model = models.resnet18()
model.fc = torch.nn.Linear(model.fc.weight.shape[1], 10)
model = model.to(device)

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
epochs = 100

In [9]:
def train_one_epoch(epoch_index):
    running_loss = []
    last_loss = 0.

    model.train()
    for data in train_dataloader:
        inputs, labels = data

        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)

        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss.append(loss.item())
        last_loss = running_loss[-1]

    test_loss = 0.


    accuracy = Accuracy(task="multiclass", num_classes=10).to(device)
    model.eval()
    with torch.no_grad():
        for data in test_dataloader:
            inputs, labels = data
            
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)

            loss = loss_fn(outputs, labels)

            test_loss += loss.item()

            preds = torch.argmax(outputs, dim=1)
            accuracy.update(preds, labels)

    print(f"[Epoch {epoch_index}] Test loss: {test_loss} Test accuracy: {accuracy.compute()}")

    return running_loss


In [None]:
for i in range(epochs):
    losses = []
    losses.extend(train_one_epoch(i))

[Epoch 0] Test loss: 10.953638792037964 Test accuracy: 0.3158999979496002
[Epoch 1] Test loss: 10.801852941513062 Test accuracy: 0.26855000853538513
[Epoch 2] Test loss: 10.874414443969727 Test accuracy: 0.21404999494552612
[Epoch 3] Test loss: 10.955583572387695 Test accuracy: 0.21744999289512634
[Epoch 4] Test loss: 10.884592294692993 Test accuracy: 0.23420000076293945
[Epoch 5] Test loss: 10.745460033416748 Test accuracy: 0.25130000710487366
