In [None]:
pip install transformers datasets



In [None]:
from transformers import AutoModelForImageClassification
from datasets import load_dataset


In [None]:
# CIFAR-10 class names in label-id order (the HF "cifar10" ClassLabel order —
# TODO confirm against train_main.features["label"].names).
CIFAR10_LABELS = ["airplane", "automobile", "bird", "cat", "deer",
                  "dog", "frog", "horse", "ship", "truck"]

# The checkpoint's classification head and id2label describe its pretraining label set
# (ImageNet-1k), not CIFAR-10. Ask for a fresh 10-way head so logits and the saved config's
# id2label/label2id match this task. ignore_mismatched_sizes is required because the
# classifier weight shapes differ from the checkpoint's.
model = AutoModelForImageClassification.from_pretrained(
    "facebook/deit-tiny-patch16-224",
    num_labels=len(CIFAR10_LABELS),
    id2label=dict(enumerate(CIFAR10_LABELS)),
    label2id={name: idx for idx, name in enumerate(CIFAR10_LABELS)},
    ignore_mismatched_sizes=True,
)

In [None]:
# Hold out 10% of the training split as a validation set (used for per-epoch evaluation
# and best-checkpoint selection); the separate `test` split is kept aside.
train_main = load_dataset("cifar10", split="train")
test = load_dataset("cifar10", split="test")
# Fixed seed so the train/validation partition, and therefore the reported validation
# accuracy, is identical across kernel restarts.
split = train_main.train_test_split(test_size=0.1, seed=42)
train = split['train']
val = split['test']

In [None]:
from transformers import AutoImageProcessor

In [None]:
pip install transformers[torch]



In [None]:
import accelerate

In [None]:
# Load the checkpoint's preprocessing config. Its normalisation statistics
# (image_mean/image_std) and size feed the torchvision pipelines below, and the object is
# also handed to the Trainer.
processor = AutoImageProcessor.from_pretrained("facebook/deit-tiny-patch16-224")

In [None]:
from torchvision.transforms import (CenterCrop,
                                    Compose,
                                    Normalize,
                                    RandomHorizontalFlip,
                                    RandomResizedCrop,
                                    Resize,
                                    ToTensor)

# Normalisation statistics and target resolution are taken from the checkpoint's processor.
image_mean = processor.image_mean
image_std = processor.image_std
# NOTE(review): confirm processor.size["height"] equals the model's expected input resolution
# (224 for this checkpoint name); if the processor also defines a separate crop_size, that is
# presumably the value the crops below should use.
size = processor.size["height"]

normalize = Normalize(mean=image_mean, std=image_std)

# Training pipeline: random resized crop + horizontal flip for augmentation,
# then PIL -> tensor and channel-wise normalisation.
_train_transforms = Compose(
    [RandomResizedCrop(size), RandomHorizontalFlip(), ToTensor(), normalize]
)

# Evaluation pipeline: deterministic resize + centre crop, no augmentation.
_val_transforms = Compose(
    [Resize(size), CenterCrop(size), ToTensor(), normalize]
)

def _apply_pipeline(pipeline, examples):
    """Add a 'pixel_values' list to a batch by running ``pipeline`` over each RGB image in ``examples['img']``."""
    examples['pixel_values'] = [pipeline(picture.convert("RGB")) for picture in examples['img']]
    return examples


def train_transforms(examples):
    """Batch transform for the training split (random crop + flip augmentation)."""
    return _apply_pipeline(_train_transforms, examples)


def val_transforms(examples):
    """Batch transform for validation/test splits (deterministic resize + centre crop)."""
    return _apply_pipeline(_val_transforms, examples)

In [None]:
# Attach on-the-fly preprocessing: augmentation for training, the deterministic
# pipeline for both held-out splits.
for dataset_split, transform_fn in (
    (train, train_transforms),
    (val, val_transforms),
    (test, val_transforms),
):
    dataset_split.set_transform(transform_fn)

In [None]:
from transformers import TrainingArguments, Trainer

metric_name = "accuracy"

# Checkpoint and evaluate once per epoch. The two strategies must match for
# load_best_model_at_end, which restores the checkpoint with the best `metric_name` on `val`.
args = TrainingArguments(
    output_dir="test-cifar-10",
    save_strategy="epoch",
    # `evaluation_strategy` was deprecated in favour of `eval_strategy` and later removed
    # from transformers, so the unpinned install above needs the new spelling.
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=4,
    num_train_epochs=10,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    logging_dir='logs',
    # Keep the raw 'img'/'label' columns: the set_transform functions and collate_fn read them.
    remove_unused_columns=False,
)

In [None]:
from torch.utils.data import DataLoader
import torch

def collate_fn(examples):
    """Batch a list of transformed dataset rows for the model.

    Each row must provide a ``"pixel_values"`` tensor (all rows the same shape) and an
    integer ``"label"``. Returns ``{"pixel_values": stacked tensor, "labels": 1-D tensor}``;
    note the plural ``"labels"`` key used for the model input.
    """
    images = []
    targets = []
    for row in examples:
        images.append(row["pixel_values"])
        targets.append(row["label"])
    return {"pixel_values": torch.stack(images), "labels": torch.tensor(targets)}

# Trainer builds its own loaders from train_dataset + data_collator, so this loader is not
# used for training. It serves as a quick smoke test that set_transform + collate_fn
# produce a well-formed batch before the long training run starts.
train_dataloader = DataLoader(train, collate_fn=collate_fn, batch_size=4)
sample_batch = next(iter(train_dataloader))
assert sample_batch["pixel_values"].ndim == 4, sample_batch["pixel_values"].shape
assert sample_batch["pixel_values"].shape[0] == sample_batch["labels"].shape[0] == 4

In [None]:
from sklearn.metrics import accuracy_score
import numpy as np

def compute_metrics(eval_pred):
    """Compute top-1 accuracy for ``Trainer`` evaluation.

    Args:
        eval_pred: ``(logits, labels)`` pair (e.g. a ``transformers.EvalPrediction``).
            Logits are assumed to carry classes on the last axis — TODO confirm shape.

    Returns:
        ``{"accuracy": float}``: fraction of examples whose arg-max class equals the label.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    # Same value as sklearn.metrics.accuracy_score(labels, predictions). The original call
    # passed the arguments in (y_pred, y_true) order. Returning a plain float keeps the
    # metric JSON-serialisable in Trainer logs.
    return {"accuracy": float(np.mean(predictions == np.asarray(labels)))}


In [None]:
import torch

# Wire model, hyper-parameters, data, batching and metrics together. Training runs on
# `train`; per-epoch evaluation and best-checkpoint selection use `val`.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train,
    eval_dataset=val,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    # NOTE(review): newer transformers versions deprecate `tokenizer=` in favour of
    # `processing_class=`; confirm against the installed version.
    tokenizer=processor,
)


In [None]:
# Fine-tune for `num_train_epochs` epochs, evaluating and checkpointing on `val` after each
# epoch (see `args`). With load_best_model_at_end=True, the best-accuracy checkpoint is
# restored into `model` when training finishes.
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.5507,0.2049,0.9354
2,0.4526,0.164853,0.9464
3,0.4231,0.158358,0.9504
4,0.396,0.15858,0.9488
5,0.3886,0.151114,0.9554
6,0.3513,0.145507,0.9576
7,0.3349,0.134586,0.9628
8,0.3161,0.128973,0.9642
9,0.2879,0.124801,0.9658
10,0.2697,0.122716,0.967


TrainOutput(global_step=45000, training_loss=0.40487948811848956, metrics={'train_runtime': 3248.1619, 'train_samples_per_second': 138.54, 'train_steps_per_second': 13.854, 'total_flos': 2.3237042282496e+18, 'train_loss': 0.40487948811848956, 'epoch': 10.0})

In [None]:
# Scores `val` (the trainer's eval_dataset), which is the same split used to pick the best
# checkpoint. The number is therefore optimistic and is not a held-out test score.
trainer.evaluate()

{'eval_loss': 0.12271645665168762,
 'eval_accuracy': 0.967,
 'eval_runtime': 46.3411,
 'eval_samples_per_second': 107.896,
 'eval_steps_per_second': 26.974,
 'epoch': 10.0}

In [None]:
# Pickles the entire nn.Module (class references included), so loading it needs the same
# transformers/torch code importable.
# NOTE(review): model.save_pretrained(...) or saving model.state_dict() is more portable;
# recent torch.load defaults (weights_only=True) may refuse a fully pickled module — confirm.
torch.save(model,'deit_tiny_10.pt')

In [None]:
# Colab-only: trigger a browser download of the checkpoint saved in the previous cell.
from google.colab import files
files.download('deit_tiny_10.pt')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# NOTE(review): second call to trainer.train() on the same Trainer/model. On a fresh kernel
# the notebook therefore trains twice, and every later cell (quantization, size comparison,
# deit_t_model.pth) reflects this second run. Presumably it starts from the current weights
# with a fresh optimizer/LR schedule. The recorded epoch-1 accuracy (0.9274, below the first
# run's final 0.967) suggests these outputs may come from a different session — verify.
# Consider deleting this cell or making the extra training explicit.
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.5623,0.211401,0.9274
2,0.4778,0.165889,0.9456
3,0.4098,0.157482,0.9484
4,0.3887,0.14939,0.9512
5,0.3775,0.142511,0.9588
6,0.3344,0.14126,0.9584
7,0.316,0.139086,0.9612
8,0.3092,0.127307,0.964
9,0.2881,0.125203,0.9662
10,0.2885,0.119803,0.9686


TrainOutput(global_step=45000, training_loss=0.40555804612901475, metrics={'train_runtime': 3358.5746, 'train_samples_per_second': 133.985, 'train_steps_per_second': 13.399, 'total_flos': 2.3237042282496e+18, 'train_loss': 0.40555804612901475, 'epoch': 10.0})

In [None]:
# Final report. `trainer.evaluate()` alone re-scores `val`, which also selected the best
# checkpoint (load_best_model_at_end), so it is optimistically biased. The `test` split was
# given a transform but never scored, so it is evaluated here too, with metrics prefixed "test_".
val_metrics = trainer.evaluate()
test_metrics = trainer.evaluate(eval_dataset=test, metric_key_prefix="test")
val_metrics, test_metrics

{'eval_loss': 0.11980312317609787,
 'eval_accuracy': 0.9686,
 'eval_runtime': 51.8887,
 'eval_samples_per_second': 96.36,
 'eval_steps_per_second': 24.09,
 'epoch': 10.0}

In [None]:
import copy

# Dynamic int8 quantization of every nn.Linear: weights are stored as int8 and activations
# are quantized on the fly. PyTorch's quantized kernels target CPU, and after training the
# fine-tuned `model` may still live on the GPU. So quantize a CPU copy in eval mode and
# leave `model` itself on its device for the cells below.
model_cpu = copy.deepcopy(model).to("cpu").eval()
quantized_model = torch.quantization.quantize_dynamic(
    model_cpu, {torch.nn.Linear}, dtype=torch.qint8)
print(quantized_model)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): DynamicQuantizedLinear(in_features=192, out_features=192, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
              (key): DynamicQuantizedLinear(in_features=192, out_features=192, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
              (value): DynamicQuantizedLinear(in_features=192, out_features=192, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): DynamicQuantizedLinear(in_features=192, out_fea

In [None]:
import os
import tempfile


def print_size_of_model(model):
    """Print the serialized size of ``model``'s state_dict in megabytes (10**6 bytes).

    The state_dict is written into a private temporary directory rather than a fixed
    ``temp.p`` in the working directory. An existing file with that name is therefore never
    overwritten or deleted, and the scratch file is cleaned up even if saving fails.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        path = os.path.join(scratch_dir, "state_dict.pt")
        torch.save(model.state_dict(), path)
        print('Size (MB):', os.path.getsize(path) / 1e6)

# Compare serialized state_dict sizes: the original model first, then the dynamically
# int8-quantized copy.
for candidate in (model, quantized_model):
    print_size_of_model(candidate)

Size (MB): 22.936702
Size (MB): 6.492388


In [None]:
# Saves the non-quantized `model` as a fully pickled module. After the second training run
# this differs from deit_tiny_10.pt. Note that `quantized_model` itself is not saved anywhere.
torch.save(model, "deit_t_model.pth")

In [None]:
# Colab-only: trigger a browser download of the model saved in the previous cell.
from google.colab import files
files.download('deit_t_model.pth')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>