<a href="https://colab.research.google.com/github/kenjitee/KenjiTee/blob/master/ViT_Tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the Hugging Face and Lightning packages that the import cell relies on.
! pip install transformers pytorch-lightning --quiet
# System-wide git-lfs install via apt.
# NOTE(review): presumably meant for pushing the model to the Hub (HfApi/Repository
# are imported but not used in the visible cells) -- confirm it is still needed.
! sudo apt -qq install git-lfs

[K     |████████████████████████████████| 2.6 MB 7.1 MB/s 
[K     |████████████████████████████████| 916 kB 54.0 MB/s 
[K     |████████████████████████████████| 636 kB 55.8 MB/s 
[K     |████████████████████████████████| 895 kB 65.5 MB/s 
[K     |████████████████████████████████| 3.3 MB 45.8 MB/s 
[K     |████████████████████████████████| 118 kB 66.0 MB/s 
[K     |████████████████████████████████| 272 kB 72.4 MB/s 
[K     |████████████████████████████████| 829 kB 62.6 MB/s 
[K     |████████████████████████████████| 1.3 MB 60.2 MB/s 
[K     |████████████████████████████████| 294 kB 69.8 MB/s 
[K     |████████████████████████████████| 142 kB 70.6 MB/s 
[?25h  Building wheel for future (setup.py) ... [?25l[?25hdone
The following NEW packages will be installed:
  git-lfs
0 upgraded, 1 newly installed, 0 to remove and 40 not upgraded.
Need to get 2,129 kB of archives.
After this operation, 7,662 kB of additional disk space will be used.
debconf: unable to initialize frontend: 

In [None]:
import math
import matplotlib.pyplot as plt

import numpy as np
from PIL import Image, UnidentifiedImageError
from pathlib import Path
import torch
import glob
import pytorch_lightning as pl
from huggingface_hub import HfApi, Repository
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchmetrics import Accuracy
from transformers import ViTFeatureExtractor, ViTForImageClassification
from pytorch_lightning.callbacks import ModelCheckpoint

In [None]:
# Use the GPU when the runtime exposes one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

In [None]:
from google.colab import files

# Opens Colab's upload widget; choose kaggle.json here.
# upload() returns a dict mapping each file name to its raw bytes. Bind it to
# a name instead of leaving it as the cell's last expression, otherwise the
# file contents (the Kaggle API key) are echoed into the saved notebook output.
uploaded = files.upload()
print("Uploaded:", ", ".join(sorted(uploaded)))

In [None]:
# Kaggle command-line client, used by the dataset download cells that follow.
!pip install -q kaggle

In [None]:
!mkdir~/.kaggle

In [None]:
# Copy the uploaded credentials file from the working directory into ~/.kaggle/.
! cp kaggle.json ~/.kaggle/

In [None]:
# Mode 600: only the owning user may read or write the credentials file.
! chmod 600 ~/.kaggle/kaggle.json


In [None]:
# Lists datasets through the Kaggle CLI.
# NOTE(review): looks like a credentials smoke test; nothing later reads this output.
! kaggle datasets list 

In [None]:
# Fetch the grassknoted/asl-alphabet dataset; the next cell expects the archive
# as asl-alphabet.zip in the working directory.
!kaggle datasets download -d grassknoted/asl-alphabet

In [None]:
!unzip asl-alphabet.zip

In [None]:
# Training images sit one level down, in a folder named like its parent.
data_dir = Path("/content") / "asl_alphabet_train" / "asl_alphabet_train"

In [None]:
# One class per sub-folder of data_dir.
ds = ImageFolder(data_dir)

# Shuffle with a dedicated, seeded generator: this cell runs before
# pl.seed_everything(), so an unseeded permutation would yield a different
# train/validation split on every run.
split_rng = torch.Generator().manual_seed(42)
indices = torch.randperm(len(ds), generator=split_rng).tolist()

# Hold out 15% for validation. Slicing at an explicit boundary (instead of
# [:-n_val] / [-n_val:]) stays correct when n_val rounds down to 0, where the
# negative-index form would leave the training set empty.
n_val = math.floor(len(indices) * .15)
n_train = len(indices) - n_val
train_ds = torch.utils.data.Subset(ds, indices[:n_train])
val_ds = torch.utils.data.Subset(ds, indices[n_train:])

In [None]:
# Preview grid: one row per class, num_examples_per_class images per row.
plt.figure(figsize=(100,50))
num_examples_per_class = 1
i = 1  # 1-based subplot slot, advanced once per drawn image
for class_name in ds.classes:
    folder = Path(ds.root) / class_name
    shown = 0  # images drawn for this class so far
    for image_path in sorted(folder.glob('*')):
        # Skip anything that is not an image; compare in lower case so
        # files such as "A1.JPG" are not dropped.
        if image_path.suffix.lower() not in ds.extensions:
            continue

        # The context manager releases the file handle once it is drawn.
        with Image.open(image_path) as image:
            plt.subplot(len(ds.classes), num_examples_per_class, i)
            ax = plt.gca()
            ax.set_title(
                class_name,
                size='xx-large',
                pad=5,
                loc='left',
                y=0,
                backgroundcolor='white'
            )
            ax.axis('off')
            plt.imshow(image)
        i += 1
        shown += 1

        # Count drawn images, not directory entries: a stray non-image file
        # must neither use up the quota nor prevent the loop from stopping.
        if shown == num_examples_per_class:
            break

In [None]:
# Two-way lookup between class names and their indices; indices are stored as strings.
id2label = {str(idx): name for idx, name in enumerate(ds.classes)}
label2id = {name: str(idx) for idx, name in enumerate(ds.classes)}

In [None]:
  class ImageClassificationCollator:
      """Collate ``(image, label)`` samples into one batch for the model.

      The wrapped ``feature_extractor`` is called on the list of images with
      ``return_tensors='pt'``; the labels are added to its result under the
      ``'labels'`` key as a ``torch.long`` tensor.
      """

      def __init__(self, feature_extractor):
          self.feature_extractor = feature_extractor

      def __call__(self, batch):
          # Split the samples into parallel image and label lists.
          images = []
          labels = []
          for sample in batch:
              images.append(sample[0])
              labels.append(sample[1])

          batch_encoding = self.feature_extractor(images, return_tensors='pt')
          batch_encoding['labels'] = torch.tensor(labels, dtype=torch.long)
          return batch_encoding

In [None]:
# Feature extractor for the same checkpoint the classifier is loaded from,
# wrapped in the collator that both loaders share.
feature_extractor = ViTFeatureExtractor.from_pretrained(
    'google/vit-base-patch16-224-in21k'
)
collator = ImageClassificationCollator(feature_extractor)

# Only the training loader shuffles its samples.
train_loader = DataLoader(
    train_ds,
    batch_size=32,
    shuffle=True,
    num_workers=2,
    collate_fn=collator,
)
val_loader = DataLoader(
    val_ds,
    batch_size=32,
    num_workers=2,
    collate_fn=collator,
)

# One output label per entry in the label maps.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
)

In [None]:
class Classifier(pl.LightningModule):
    """Lightning module that fine-tunes a wrapped classification model.

    ``forward`` is bound to the wrapped model's own ``forward``, so calling
    the module with a collated batch returns that model's output; the step
    methods read its ``loss`` and ``logits`` attributes.
    """

    def __init__(self, model, lr: float = 2e-5, **kwargs):
        super().__init__()
        # Record lr plus the names of any extra keyword arguments.
        self.save_hyperparameters('lr', *kwargs)
        self.model = model
        self.forward = self.model.forward
        self.val_acc = Accuracy()
        self.train_acc = Accuracy()

    def _run_step(self, batch, stage, metric):
        """Forward one batch, log ``<stage>_loss`` / ``<stage>_acc``, return the loss."""
        outputs = self(**batch)
        self.log(stage + "_loss", outputs.loss)
        predictions = outputs.logits.argmax(1)
        batch_accuracy = metric(predictions, batch['labels'])
        self.log(stage + "_acc", batch_accuracy, prog_bar=True)
        return outputs.loss

    def training_step(self, batch, batch_idx):
        return self._run_step(batch, "train", self.train_acc)

    def validation_step(self, batch, batch_idx):
        return self._run_step(batch, "val", self.val_acc)

    def configure_optimizers(self):
        # Adam over every parameter, learning rate taken from the saved hyperparameters.
        learning_rate = self.hparams.lr
        return torch.optim.Adam(
            self.parameters(),
            lr=learning_rate,
            weight_decay=0.0025,
        )

In [None]:
pl.seed_everything(42)
classifier = Classifier(model, lr=2e-5)

# Checkpoints are written under dirpath, selected by the logged val_loss.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='/content/trainmebby',
    filename='ViT-{epoch:02d}-{val_loss:.2f}',
)

# Request a GPU and 16-bit precision only when a GPU is actually present.
# The hard-coded gpus=1 / precision=16 made this cell fail on a CPU-only
# runtime even though the notebook detects the device earlier on.
use_gpu = torch.cuda.is_available()
trainer = pl.Trainer(
    callbacks=[checkpoint_callback],
    gpus=1 if use_gpu else 0,
    precision=16 if use_gpu else 32,
    max_epochs=3,
)
trainer.fit(classifier, train_loader, val_loader)


In [None]:
# Test set location, one hand-picked sample, and the list of all .jpg files in that folder.
test_data_path = '/content/asl_alphabet_test/asl_alphabet_test'
image_path1 = test_data_path + '/O_test.jpg'
images_path = glob.glob(test_data_path + '/' + '*.jpg')

In [None]:
def prediction(img_path):
  """Classify a single image file with the notebook's ViT model.

  Args:
    img_path: Location of the image; anything ``Image.open`` accepts.

  Returns:
    The class name stored in ``id2label`` for the highest-scoring logit.
  """
  # Force three colour channels and release the file handle right away.
  with Image.open(img_path) as im:
    im = im.convert('RGB')
  encoding = feature_extractor(images=im, return_tensors="pt")

  # The model may sit on the GPU after training; the input has to be on the
  # same device or the forward pass raises a device-mismatch error.
  pixel_values = encoding['pixel_values'].to(model.device)

  # Inference only: disable dropout and skip building the autograd graph.
  model.eval()
  with torch.no_grad():
    outputs = model(pixel_values)

  # The arg-max of the logits equals the arg-max of their softmax, so the
  # softmax is not needed to pick the winning class.
  predicted_idx = outputs.logits.argmax(1).item()
  return id2label[str(predicted_idx)]


In [None]:
 def process_image(image_path):
     ''' Scales, crops, and normalizes an image file for a PyTorch model.

         Args:
             image_path: location of the image; anything Image.open accepts.

         Returns:
             A float Numpy array with the colour channel first,
             i.e. shape (3, 224, 224).
     '''

     # Force three colour channels (the per-channel mean/std below assume
     # exactly three) and close the file as soon as the pixels are loaded.
     with Image.open(image_path) as opened:
         pil_image = opened.convert('RGB')

     # Resize: bound the shorter side to 256; the 5000 limit is large enough
     # not to constrain the longer side. thumbnail() works in place.
     if pil_image.width > pil_image.height:
         pil_image.thumbnail((5000, 256))
     else:
         pil_image.thumbnail((256, 5000))

     # Centre crop to 224x224. PIL boxes are (left, upper, right, lower).
     # round() matches the rounding PIL applied to the former float margins,
     # and deriving right/lower from left/upper guarantees a 224-pixel window.
     left = round((pil_image.width - 224) / 2)
     upper = round((pil_image.height - 224) / 2)
     pil_image = pil_image.crop((left, upper, left + 224, upper + 224))

     # Normalize: scale to [0, 1], then standardise each channel.
     np_image = np.array(pil_image) / 255
     mean = np.array([0.485, 0.456, 0.406])
     std = np.array([0.229, 0.224, 0.225])
     np_image = (np_image - mean) / std

     # PyTorch expects the colour channel first, but it is the last axis of
     # the PIL-derived array; move it while keeping height/width order.
     return np_image.transpose((2, 0, 1))

In [None]:
def imshow(image, ax=None, title=None):
    """Draw a normalised, channel-first image array on a matplotlib axis.

    Args:
        image: Numpy array with the colour channel as the first axis.
        ax: Axis to draw on; a new figure and axis are created when omitted.
        title: Optional axis title.

    Returns:
        The axis the image was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()

    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])

    # matplotlib wants the colour channel last; move it, then undo the
    # per-channel normalisation.
    pixels = channel_std * image.transpose((1, 2, 0)) + channel_mean

    if title is not None:
        ax.set_title(title)

    # Values outside [0, 1] would render as noise, so clamp before drawing.
    ax.imshow(np.clip(pixels, 0, 1))
    return ax

In [None]:
def display_image(image_dir):
    """Show one image with its file name as title and the predicted sign below it.

    Args:
        image_dir: Location of the image *file* (``str`` or ``pathlib.Path``);
            despite the parameter name it is not a directory.
    """
    # The picture goes into the upper cell of a 2x1 grid.
    plt.figure(figsize = (6,10))
    plot_1 = plt.subplot(2,1,1)

    image = process_image(image_dir)

    # Use the bare file name (e.g. "O_test.jpg") as the title. Path.name
    # also works for Path inputs, which string slicing with rfind did not.
    asl_sign = Path(image_dir).name

    pred = prediction(image_dir)
    plot_1.set_xlabel("The predicted sign: "+pred)

    imshow(image, plot_1, title=asl_sign)


  

In [None]:
# Render every collected test image together with its predicted sign.
for test_image_path in images_path:
  display_image(test_image_path)

In [None]:
# Single-image run on the hard-coded O_test.jpg sample.
display_image(image_path1)