In [4]:
# import
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from models.vit import VisionTransformer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cpu


In [5]:
#Define Data Transforms:

transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010))
])

## Download data

In [6]:
train_dataset = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)

test_dataset = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
batch_size = 128

train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

test_loader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

Model

In [8]:
model = VisionTransformer(
    img_size=32,
    patch_size=4,
    in_channels=3,
    num_classes=10,
    embed_dim=256,
    num_heads=8,
    hidden_dim=512,
    num_layers=6,
    dropout=0.1
).to(device)



In [9]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

In [10]:
num_epochs = 1
for epoch in tqdm(range(num_epochs), desc="Epochs"):
    model.train()
    for images, labels in tqdm(train_loader, desc=f"Training Epoch {epoch+1}", leave=False):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    # Validation and logging code here

Epochs: 100%|██████████| 1/1 [06:22<00:00, 382.70s/it]
