In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from tqdm import tqdm
import timm
import detectors

In [5]:
# Run configuration.
TRAIN = True            # fine-tune the model (and save the result) in this run
LOAD_FROM_FILE = False  # start from the local checkpoint instead of hub weights

# Seed PyTorch's RNGs (CPU and CUDA) so data shuffling and any random ops in
# training are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

In [6]:
# Shared preprocessing for both splits: resize to 32x32, convert to a float
# tensor, then shift/scale each of the three channels with mean 0.5, std 0.5.
# NOTE(review): confirm these constants match the normalization the pretrained
# 'resnet18_cifar10' checkpoint was trained with.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

batch_size = 64
data_root = './data'

# Training split: reshuffled every epoch.
trainset = datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

# Test split: fixed order so evaluation is deterministic.
testset = datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
num_classes: int = 10  # number of CIFAR-10 target classes; not referenced by the cells shown here

In [7]:
# List every timm-registered architecture whose name matches the wildcard
# pattern '*cifar*' (the candidates for the model cell below).
model_names = timm.list_models(filter='*cifar*')
print(model_names)

['densenet121_cifar10', 'densenet121_cifar100', 'resnet18_cifar10', 'resnet18_cifar100', 'resnet34_cifar10', 'resnet34_cifar100', 'resnet34_simclr_cifar10', 'resnet34_simclr_cifar100', 'resnet34_supcon_cifar10', 'resnet34_supcon_cifar100', 'resnet50_cifar10', 'resnet50_cifar100', 'resnet50_simclr_cifar10', 'resnet50_simclr_cifar100', 'resnet50_supcon_cifar10', 'resnet50_supcon_cifar100', 'vgg16_bn_cifar10', 'vgg16_bn_cifar100', 'vit_base_patch16_224_in21k_ft_cifar10', 'vit_base_patch16_224_in21k_ft_cifar100']


In [15]:
# Build the CIFAR-10 ResNet-18 from timm. When LOAD_FROM_FILE is set, the hub
# weights would be overwritten straight away, so don't download them.
model = timm.create_model('resnet18_cifar10', pretrained=not LOAD_FROM_FILE)
if LOAD_FROM_FILE:
    # map_location="cpu" lets a checkpoint saved from a CUDA run load on a
    # CPU-only machine; the model is moved to `device` in a later cell.
    # NOTE: torch.load unpickles the file -- only load checkpoints you produced.
    model.load_state_dict(torch.load("cifar10_model.pth", map_location="cpu"))
    print("State loaded from file")

In [13]:
# Use the GPU when PyTorch can see one, otherwise run on the CPU, and move
# the model's parameters there.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)
model = model.to(device)

cuda


In [10]:
# Cross-entropy over the model's class logits (default mean reduction).
criterion = nn.CrossEntropyLoss()
# Adam over all model parameters.
# NOTE(review): lr=1e-3 is fairly aggressive for fine-tuning an already
# pretrained network -- consider a smaller value if test accuracy drops.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [8]:
# Fine-tune for a fixed number of epochs; skipped entirely when TRAIN is False.
if TRAIN:
    num_epochs = 10
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0  # sum of per-sample losses over this epoch
        seen = 0            # samples processed so far this epoch

        # Wrap trainloader with tqdm for a per-batch progress bar.
        with tqdm(trainloader, unit="batch") as tepoch:
            tepoch.set_description(f"Epoch [{epoch+1}/{num_epochs}]")
            for inputs, labels in tepoch:
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # criterion uses mean reduction, so weight each batch by its size.
                # Otherwise the smaller final batch would skew the epoch average.
                # .item() forces a device sync, so call it only once per batch.
                batch_loss = loss.item()
                running_loss += batch_loss * labels.size(0)
                seen += labels.size(0)
                tepoch.set_postfix(loss=batch_loss)

        epoch_loss = running_loss / max(seen, 1)
        print(f"Epoch [{epoch+1}/{num_epochs}] completed. Average Loss: {epoch_loss:.4f}")

Epoch [1/10]: 100%|██████████| 782/782 [00:35<00:00, 22.20batch/s, loss=0.13]  


Epoch [1/10] completed. Average Loss: 0.1079


Epoch [2/10]: 100%|██████████| 782/782 [00:38<00:00, 20.16batch/s, loss=0.124]  


Epoch [2/10] completed. Average Loss: 0.0924


Epoch [3/10]: 100%|██████████| 782/782 [00:39<00:00, 19.92batch/s, loss=0.467]  


Epoch [3/10] completed. Average Loss: 0.0826


Epoch [4/10]: 100%|██████████| 782/782 [00:38<00:00, 20.16batch/s, loss=0.271]  


Epoch [4/10] completed. Average Loss: 0.0768


Epoch [5/10]: 100%|██████████| 782/782 [00:38<00:00, 20.50batch/s, loss=0.459]  


Epoch [5/10] completed. Average Loss: 0.0764


Epoch [6/10]: 100%|██████████| 782/782 [00:39<00:00, 19.76batch/s, loss=0.15]   


Epoch [6/10] completed. Average Loss: 0.0651


Epoch [7/10]: 100%|██████████| 782/782 [00:39<00:00, 20.05batch/s, loss=0.0154] 


Epoch [7/10] completed. Average Loss: 0.0607


Epoch [8/10]: 100%|██████████| 782/782 [00:39<00:00, 19.94batch/s, loss=0.18]   


Epoch [8/10] completed. Average Loss: 0.0607


Epoch [9/10]: 100%|██████████| 782/782 [00:38<00:00, 20.33batch/s, loss=0.368]  


Epoch [9/10] completed. Average Loss: 0.0608


Epoch [10/10]: 100%|██████████| 782/782 [00:38<00:00, 20.32batch/s, loss=0.995]  

Epoch [10/10] completed. Average Loss: 0.0574





In [9]:
# Persist the weights only when this run actually trained them. Saving
# unconditionally would overwrite an earlier fine-tuned checkpoint with the
# hub-pretrained weights when TRAIN=False and LOAD_FROM_FILE=False.
if TRAIN:
    torch.save(model.state_dict(), "cifar10_model.pth")

In [14]:
# Top-1 accuracy on the held-out test split, in eval mode with gradient
# tracking disabled.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch_inputs, batch_labels in testloader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)
        predicted = model(batch_inputs).argmax(dim=1)  # index of the highest logit
        hits = (predicted == batch_labels).sum().item()
        correct += hits
        total += len(batch_labels)

print(f"Accuracy on the test set: {100 * correct / total:.2f}%")

Accuracy on the test set: 89.24%
