In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.models import maxvit_t, MaxVit_T_Weights
from torch.utils.data import DataLoader, random_split

# 1. Settings
device       = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size   = 64
num_epochs   = 7
print_every  = 1   # print summary every N epochs
best_acc     = 0.0

# 2. Preprocessing: resize → center-crop → to-tensor → normalize (ImageNet stats)
weights   = MaxVit_T_Weights.IMAGENET1K_V1
preproc   = weights.transforms()

# 3. Prepare CIFAR-10, but limit train → 5 000 images
full_train_ds = datasets.CIFAR10(
    root="./data", train=True, download=True, transform=preproc
)
train_ds, _ = random_split(
    full_train_ds,
    [4000, len(full_train_ds) - 4000],
    generator=torch.Generator().manual_seed(42)
)
train_ld = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=4, pin_memory=True)

full_val_ds = datasets.CIFAR10(
    root="./data", train=False, download=True, transform=preproc
)
val_ds, _ = random_split(
    full_val_ds,
    [800, len(full_val_ds) - 800],
    generator=torch.Generator().manual_seed(42)  # for reproducibility
)
val_ld = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

print("Training on classes:", train_ds.dataset.classes)

# 4. Model: load pretrained MaxViT-T, freeze everything, then append a new 1000→10 head
model = maxvit_t(weights=weights).to(device)

# 4a. Freeze all existing params
for param in model.parameters():
    param.requires_grad = False

# 4b. Extract and rebuild classifier
orig_classifier = model.classifier                  # Sequential([pool, flatten, Linear(512->1000)])
layers = list(orig_classifier.children())
old_linear = layers[-1]
num_features = old_linear.out_features              # should be 1000

# 4c. Append new head
layers.append(nn.Linear(num_features, 10))
model.classifier = nn.Sequential(*layers).to(device)

# 4d. Unfreeze only the new head
for param in model.classifier[-1].parameters():
    param.requires_grad = True

# 5. Loss & optimizer (only new head’s params)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.classifier[-1].parameters(),
    lr=1e-2,
    momentum=0.9,
    weight_decay=1e-4
)

# 6. Training + validation loop
for epoch in range(1, num_epochs + 1):
    # — Training —
    model.train()
    running_loss = correct = total = 0
    for imgs, labels in train_ld:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        out = model(imgs)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * imgs.size(0)
        preds = out.argmax(dim=1)
        correct   += (preds == labels).sum().item()
        total     += labels.size(0)

    train_loss = running_loss / total
    train_acc  = 100. * correct / total

    # — Validation —
    model.eval()
    val_loss = val_correct = val_total = 0
    with torch.no_grad():
        for imgs, labels in val_ld:
            imgs, labels = imgs.to(device), labels.to(device)
            out = model(imgs)
            val_loss    += criterion(out, labels).item() * imgs.size(0)
            preds        = out.argmax(dim=1)
            val_correct += (preds == labels).sum().item()
            val_total   += labels.size(0)

    val_loss = val_loss / val_total
    val_acc  = 100. * val_correct / val_total

    # — Print stats —
    if epoch % print_every == 0:
        print(f"Epoch {epoch}/{num_epochs} — "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f},   Val Acc: {val_acc:.2f}%")

    # — Save best model —
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "best_maxvit_cifar10.pth")

print(f"\nDone! Best validation accuracy: {best_acc:.2f}%")
print("Best model weights saved to best_maxvit_cifar10.pth")


Training on classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
Epoch 1/7 — Train Loss: 0.7896, Train Acc: 75.03% | Val Loss: 0.3620,   Val Acc: 88.38%
Epoch 2/7 — Train Loss: 0.4981, Train Acc: 83.75% | Val Loss: 0.3522,   Val Acc: 88.38%
Epoch 3/7 — Train Loss: 0.4692, Train Acc: 84.20% | Val Loss: 0.4044,   Val Acc: 87.50%
Epoch 4/7 — Train Loss: 0.4490, Train Acc: 85.45% | Val Loss: 0.3634,   Val Acc: 89.62%
Epoch 5/7 — Train Loss: 0.3958, Train Acc: 86.80% | Val Loss: 0.4321,   Val Acc: 86.62%
Epoch 6/7 — Train Loss: 0.3930, Train Acc: 86.80% | Val Loss: 0.4608,   Val Acc: 84.00%
Epoch 7/7 — Train Loss: 0.3940, Train Acc: 86.35% | Val Loss: 0.4489,   Val Acc: 85.75%

Done! Best validation accuracy: 89.62%
Best model weights saved to best_maxvit_cifar10.pth


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torch.utils.data import DataLoader, random_split

# 1. Settings
device       = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size   = 64
num_epochs   = 7
print_every  = 1   # print summary every N epochs
best_acc     = 0.0

# 2. Preprocessing: resize → center-crop → to-tensor → normalize (ImageNet stats)
weights   = ViT_B_16_Weights.IMAGENET1K_V1  # or ViT_B_16_Weights.DEFAULT, or "DEFAULT"
preproc   = weights.transforms()

# 3. Prepare CIFAR-10, but limit train → 4 000 images
full_train_ds = datasets.CIFAR10(
    root="./data", train=True, download=True, transform=preproc
)
train_ds, _ = random_split(
    full_train_ds,
    [4000, len(full_train_ds) - 4000],
    generator=torch.Generator().manual_seed(42)
)
train_ld = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

#    Limit validation → 800 images
full_val_ds = datasets.CIFAR10(
    root="./data", train=False, download=True, transform=preproc
)
val_ds, _ = random_split(
    full_val_ds,
    [800, len(full_val_ds) - 800],
    generator=torch.Generator().manual_seed(42)
)
val_ld = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

print("Training on classes:", train_ds.dataset.classes)

# 4. Model: load pretrained ViT-B_16, freeze everything, then append a new 1000→10 head
model = vit_b_16(weights=weights).to(device)

# 4a. Freeze all existing params
for param in model.parameters():
    param.requires_grad = False

# 4b. Extract and rebuild the classification head
#     VisionTransformer stores its head in `model.heads`, a nn.Sequential
orig_heads = model.heads                    # e.g. Sequential([LayerNorm, Linear(768→1000)])
layers = list(orig_heads.children())
old_linear = layers[-1]
num_features = old_linear.out_features      # should be 1000

# 4c. Append new 1000→10 Linear layer
layers.append(nn.Linear(num_features, 10))
model.heads = nn.Sequential(*layers).to(device)

# 4d. Un-freeze only the new head
for param in model.heads[-1].parameters():
    param.requires_grad = True

# 5. Loss & optimizer (only new head’s parameters)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.heads[-1].parameters(),
    lr=1e-2,
    momentum=0.9,
    weight_decay=1e-4
)

# 6. Training + validation loop
for epoch in range(1, num_epochs + 1):
    # — Training —
    model.train()
    running_loss = correct = total = 0
    for imgs, labels in train_ld:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        out = model(imgs)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * imgs.size(0)
        preds = out.argmax(dim=1)
        correct   += (preds == labels).sum().item()
        total     += labels.size(0)

    train_loss = running_loss / total
    train_acc  = 100. * correct / total

    # — Validation —
    model.eval()
    val_loss = val_correct = val_total = 0
    with torch.no_grad():
        for imgs, labels in val_ld:
            imgs, labels = imgs.to(device), labels.to(device)
            out = model(imgs)
            val_loss    += criterion(out, labels).item() * imgs.size(0)
            preds        = out.argmax(dim=1)
            val_correct += (preds == labels).sum().item()
            val_total   += labels.size(0)

    val_loss = val_loss / val_total
    val_acc  = 100. * val_correct / val_total

    # — Print stats —
    if epoch % print_every == 0:
        print(f"Epoch {epoch}/{num_epochs} — "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f},   Val Acc: {val_acc:.2f}%")

    # — Save best model —
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "best_vit_cifar10.pth")

print(f"\nDone! Best validation accuracy: {best_acc:.2f}%")
print("Best model weights saved to best_vit_cifar10.pth")


100%|██████████| 170M/170M [00:13<00:00, 13.0MB/s]


Training on classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:01<00:00, 193MB/s]


Epoch 1/7 — Train Loss: 0.4493, Train Acc: 86.17% | Val Loss: 0.2667,   Val Acc: 93.00%
Epoch 2/7 — Train Loss: 0.1839, Train Acc: 94.03% | Val Loss: 0.2407,   Val Acc: 92.62%
Epoch 3/7 — Train Loss: 0.1455, Train Acc: 95.08% | Val Loss: 0.2353,   Val Acc: 92.25%
Epoch 4/7 — Train Loss: 0.1184, Train Acc: 96.20% | Val Loss: 0.2316,   Val Acc: 92.50%
Epoch 5/7 — Train Loss: 0.0991, Train Acc: 97.00% | Val Loss: 0.2290,   Val Acc: 93.12%


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f0c847f7100>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1582, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/usr/lib/python3.11/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/connection.py", line 948, in wait
    ready = selector.select(timeout)
            ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
     

KeyboardInterrupt: 