In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from vit_pytorch import SimpleViT
from fastprogress.fastprogress import master_bar, progress_bar
from torchvision.transforms import RandomResizedCrop, CenterCrop, Resize, RandomHorizontalFlip, Compose, ToTensor
from torch.cuda.amp import autocast, GradScaler

    ImportError: cannot import name 'builder' from 'google.protobuf.internal' (/home/dgj335/.local/lib/python3.10/site-packages/google/protobuf/internal/__init__.py)
  warn(message, cls)
    ImportError: cannot import name 'builder' from 'google.protobuf.internal' (/home/dgj335/.local/lib/python3.10/site-packages/google/protobuf/internal/__init__.py)
  warn(message, cls)


In [2]:
# Load the train and validation splits of the "imagenet-1k" dataset.
IMAGENET_NAME = "imagenet-1k"
train_dataset = load_dataset(IMAGENET_NAME, split="train", trust_remote_code=True)
valid_dataset = load_dataset(IMAGENET_NAME, split="validation", trust_remote_code=True)

# Keep only the samples whose image reports mode "RGB"; the later cells call
# permute() with three axes on every image (presumably non-RGB images would
# lack the channel axis -- confirm). Afterwards switch the output format so
# samples come back as torch tensors.
rgb_train = train_dataset.filter(lambda s: s["image"].mode == "RGB")
rgb_train = rgb_train.with_format("torch")

rgb_valid = valid_dataset.filter(lambda s: s["image"].mode == "RGB")
rgb_valid = rgb_valid.with_format("torch")

Loading dataset shards:   0%|          | 0/257 [00:00<?, ?it/s]

In [3]:
# Architecture hyperparameters for the SimpleViT classifier, kept in one dict
# so they are easy to inspect or log.
vit_kwargs = dict(
    image_size=224,
    patch_size=14,      # 224 / 14 = 16 patches along each side
    num_classes=1000,
    dim=384,
    depth=12,
    heads=6,
    mlp_dim=1536,
)
vit = SimpleViT(**vit_kwargs).cuda()

# Objective and optimisation state used by the training cell.
loss_function = nn.CrossEntropyLoss()
optimizer = optim.AdamW(vit.parameters(), lr=0.001, weight_decay=0.0001)
# T_max=90 equals the number of epochs the training cell iterates over, so one
# scheduler step per epoch anneals the learning rate from 1e-3 down to eta_min.
scheduler = CosineAnnealingLR(optimizer, T_max=90, eta_min=1e-6)
# Gradient scaler paired with the autocast context in the training cell.
scaler = GradScaler()

In [4]:
def custom_collate_fn(batch):
    """Turn a list of dataset samples into one augmented training batch.

    Args:
        batch: sequence of samples, each indexable by ``'image'`` and
            ``'label'``. Each image is permuted with ``(2, 0, 1)``, so it
            must be a 3-axis tensor (assumed height x width x channel with
            8-bit pixel values -- confirm against the dataset format).

    Returns:
        ``(images, labels)``: the augmented images stacked along a new first
        axis as float tensors, and the labels as an int64 tensor.
    """
    augment = Compose(
        [
            Resize(224,antialias=True),
            RandomResizedCrop(224),
            RandomHorizontalFlip(),
        ]
    )

    augmented = []
    for sample in batch:
        # Move the last axis to the front and rescale by 1/255 before augmenting.
        pixels = sample['image'].permute(2,0,1).to(torch.float)/255
        augmented.append(augment(pixels))
    images = torch.stack(augmented)

    class_ids = [sample['label'] for sample in batch]
    labels = torch.tensor(class_ids, dtype=torch.int64)
    return images, labels

In [None]:
# --- Run configuration and per-epoch history ---------------------------------
epochs = 90
batch_size = 256
mb = master_bar(range(1, epochs + 1))
train_loss = []     # mean training loss, one entry per epoch
test_loss = []      # mean validation loss, one entry per epoch
val_accuracy = []   # validation top-1 accuracy, one entry per epoch
val_transforms = Compose(
    [
        Resize(256,antialias=True),
        CenterCrop(224),
    ]
)

# The loaders do not depend on the epoch, so build them once instead of on
# every iteration; shuffle=True still draws a new order each time the training
# loader is iterated.
dataloader_train = DataLoader(rgb_train, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=6, collate_fn=custom_collate_fn)
# NOTE(review): batch_size=1 because nothing resizes validation images before
# the default collate stacks them (presumably they differ in size -- confirm).
dataloader_valid = DataLoader(rgb_valid, batch_size=1, shuffle=False, drop_last=False)

for epoch in mb:

    # ---- Training ------------------------------------------------------------
    vit.train()
    running_loss = 0.0
    for x, target in progress_bar(dataloader_train, parent=mb):
        # Inputs stay float32: the model weights are float32 and autocast picks
        # the reduced-precision compute dtype itself. Casting the inputs to
        # bfloat16 by hand only discarded input precision.
        x = x.cuda()
        target = target.cuda()

        optimizer.zero_grad()
        with autocast():
            outputs = vit(x)
            loss = loss_function(outputs, target)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # One .item() call -> one host/device sync per step.
        step_loss = loss.item()
        print(step_loss, end='\r')
        running_loss += step_loss

    train_epoch_loss = running_loss / len(dataloader_train)
    train_loss.append(train_epoch_loss)

    # Advance the cosine schedule once per epoch (T_max is expressed in epochs).
    # Without this call the learning rate never leaves its initial value.
    scheduler.step()

    # ---- Validation ----------------------------------------------------------
    vit.eval()
    running_loss_val = 0.0
    correct_predictions = 0
    total_predictions = 0
    counted_batches = 0  # batches that produced a non-NaN loss
    with torch.no_grad():
        for inputs in dataloader_valid:
            x = inputs['image'].permute(0,3,1,2).to(torch.float)/255
            x = val_transforms(x).cuda()
            targets = inputs['label'].to(torch.int64).cuda()
            # Same autocast context as training. Feeding bfloat16 inputs to the
            # float32 model outside autocast is a dtype mismatch.
            with autocast():
                outputs = vit(x)
                loss = loss_function(outputs, targets)
            if torch.isnan(loss):
                print(f"Loss is nan, skipping batch")
                continue
            counted_batches += 1
            running_loss_val += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct_predictions += (predicted == targets).sum().item()
            total_predictions += targets.size(0)

    # Average only over the batches that were actually counted, and do not
    # divide by zero if every batch was skipped.
    val_epoch_loss = running_loss_val / counted_batches if counted_batches else float('nan')
    test_loss.append(val_epoch_loss)
    epoch_accuracy = correct_predictions / total_predictions if total_predictions else float('nan')
    print(f'Epoch: {epoch}, Accuracy: {epoch_accuracy:.4f}')
    val_accuracy.append(epoch_accuracy)

    # Save the weights before touching the plot so a plotting problem cannot
    # cost the epoch's checkpoint.
    checkpoint_filename = f'simplevit_imagenet1k_checkpoint{epoch}.pt'
    torch.save(vit.state_dict(), checkpoint_filename)

    # ---- Live loss curves ----------------------------------------------------
    graphs = [
        [range(1, epoch + 1), train_loss],  # Training Loss
        [range(1, epoch + 1), test_loss],  # Validation Loss
    ]
    x_bounds = [1, epoch]
    # Axis limits must be finite, so ignore NaN/inf losses when computing them.
    all_losses = torch.tensor(train_loss + test_loss)
    finite_losses = all_losses[torch.isfinite(all_losses)]
    if finite_losses.numel() > 0:
        lowest = finite_losses.min().item()
        highest = finite_losses.max().item()
    else:
        lowest, highest = 0.0, 1.0
    y_bounds = [min(lowest - 0.05, 0), max(highest + 0.05, 1)]
    mb.update_graph(graphs, x_bounds, y_bounds)

6.7879867553710945