In [None]:
import torch
from vit_transformer import VitTransformer
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize, RandomHorizontalFlip,RandomResizedCrop, ToTensor

In [None]:
from transformers import ViTFeatureExtractor
from datasets import load_dataset

In [None]:
# Model architecture hyper-parameters (passed to VitTransformer below).
encoder_layers = 3
embed_dim = 512
num_heads = 8

# Training hyper-parameters.
epochs = 2
smoothing_rate = 0.1  # label-smoothing factor for CrossEntropyLoss
batch_size = 1000

# Seed torch's RNG (all devices) so weight init, the torchvision random
# crop/flip augmentations and DataLoader shuffling repeat across
# Restart & Run All. This alone does not make every CUDA kernel deterministic.
seed = 42
torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device:{device}")

In [None]:
# Load CIFAR-10 from the Hugging Face Hub as a (train, test) pair of splits.
train_set, test_set = load_dataset('cifar10', split=['train', 'test'])
# Take the class count from the dataset's ``label`` ClassLabel feature rather
# than hard-coding it, so the classifier head always matches the data.
num_classes = train_set.features['label'].num_classes

In [None]:
# Input geometry handed to VitTransformer: image side length, channel count
# and (presumably) patch side length.
image_h_w, img_channels, patch_h_w = 224, 3, 16

# Borrow the per-channel normalization statistics of the pretrained ViT checkpoint.
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
normalize = Normalize(
    mean=feature_extractor.image_mean,
    std=feature_extractor.image_std,
)

In [None]:
# Training-time augmentation pipeline: random-resized crop, random horizontal
# flip, conversion to a tensor, then ViT normalization.
# The crop size comes from ``image_h_w`` (the size VitTransformer is built with)
# instead of ``feature_extractor.size``: newer `transformers` releases expose
# ``size`` as a dict such as {"height": 224, "width": 224}, which
# RandomResizedCrop cannot use. Using ``image_h_w`` also keeps the crop and the
# model input size in sync.
_train_transform_steps = Compose(
    [
        RandomResizedCrop(image_h_w),
        RandomHorizontalFlip(),
        ToTensor(),
        normalize,
    ]
)


def train_transforms(images):
    """Augment a batch of raw dataset records for training.

    Args:
        images: iterable of dataset rows; each row's ``'img'`` entry must support
            ``.convert("RGB")`` (presumably PIL images from the CIFAR-10 ``img``
            column -- confirm against the dataset).

    Returns:
        A list with one normalized tensor per record, in input order. Each
        tensor is 3 x image_h_w x image_h_w: RGB conversion gives 3 channels and
        the square crop gives the spatial size.
    """
    return [_train_transform_steps(record['img'].convert("RGB")) for record in images]


In [None]:
def collate_func(images):
    """Group raw dataset records into a batch dict without transforming them.

    The incoming list of records is passed through unchanged under ``"images"``;
    augmentation happens later, in the training loop. Each record's ``"label"``
    value is gathered into one tensor under ``"labels"``.
    """
    label_list = [record["label"] for record in images]
    return {"images": images, "labels": torch.tensor(label_list)}

# Shuffled loader over the raw training split; collate_func leaves records untransformed.
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_func,
    pin_memory=True,
)

In [None]:
# Loss: cross-entropy with label smoothing.
ce_loss = torch.nn.CrossEntropyLoss(label_smoothing=smoothing_rate)

# Classifier, moved to the selected device before its parameters go to the optimizer.
transformer = VitTransformer(
    image_h_w, img_channels, patch_h_w,
    embed_dim, num_heads, encoder_layers,
    num_classes,
).to(device)

adam_opt = torch.optim.Adam(
    transformer.parameters(),
    lr=0.001,
    betas=(0.9, 0.98),
    eps=1e-9,
)

In [None]:
def train(transformer, train_loader, optimizer=None, loss_fn=None, num_epochs=None, log_every=100):
    """Train ``transformer`` on ``train_loader`` and checkpoint it after every epoch.

    Each batch from ``train_loader`` is expected to be the dict built by
    ``collate_func``: ``'images'`` holds raw dataset records (augmented here via
    ``train_transforms``) and ``'labels'`` holds a tensor of class indices.

    Args:
        transformer: model to train; it is switched to training mode.
        train_loader: iterable of batch dicts as described above (``len()`` is
            used for progress output).
        optimizer: optimizer to step. Defaults to the notebook-level ``adam_opt``.
        loss_fn: criterion called as ``loss_fn(preds, labels)``. Defaults to
            the notebook-level ``ce_loss``.
        num_epochs: number of passes over ``train_loader``. Defaults to the
            notebook-level ``epochs``.
        log_every: print the running mean loss every ``log_every`` batches
            (always including the first batch). Must be positive.

    Returns:
        A list with the sample-weighted mean training loss of each epoch
        (``nan`` for an epoch that saw no batches).

    Raises:
        ValueError: if ``log_every`` is not positive.

    Side effects:
        Writes ``vit_model_epoch_<epoch>.pth`` (0-based epoch) after each epoch.
    """
    # Fall back to the notebook globals so the original two-argument call keeps working.
    optimizer = adam_opt if optimizer is None else optimizer
    loss_fn = ce_loss if loss_fn is None else loss_fn
    num_epochs = epochs if num_epochs is None else num_epochs
    if log_every <= 0:
        raise ValueError("log_every must be a positive integer")

    transformer.train()
    epoch_losses = []

    for epoch in range(num_epochs):
        running_loss = 0.0
        running_count = 0

        for i, batch_data in enumerate(train_loader):
            images = torch.stack(train_transforms(batch_data['images'])).to(device)
            # Not the global ``batch_size`` setting: the final batch of an epoch may be smaller.
            curr_batch_size = images.shape[0]

            labels = batch_data['labels'].to(device).contiguous().view(-1)  # dims: [batch_size]

            preds = transformer(images)
            loss = loss_fn(preds, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # ``loss`` is a per-batch mean; weight it by batch size for an exact epoch mean.
            running_loss += loss.item() * curr_batch_size
            running_count += curr_batch_size

            if i % log_every == 0:
                print(f"Epoch: [{epoch}] Batch:[{i}/{len(train_loader)}]\tLoss: {running_loss/running_count:.3f}")

        epoch_loss = running_loss / running_count if running_count else float('nan')
        epoch_losses.append(epoch_loss)
        print(f"Epoch: [{epoch}] done\tMean loss: {epoch_loss:.3f}")

        # The checkpoint pickles the whole model and optimizer objects. Loading it
        # requires VitTransformer to be importable, and on recent PyTorch versions
        # torch.load(..., weights_only=False).
        state = {'epoch': epoch, 'model': transformer, 'optimizer': optimizer}
        torch.save(state, 'vit_model_epoch_' + str(epoch) + '.pth')
        print("saved model on epoch: " + str(epoch))

    return epoch_losses
        


In [None]:
# Run the full training loop; a checkpoint is written at the end of every epoch.
train(transformer=transformer, train_loader=train_loader)