In [None]:
%%capture
!pip install transformers pytorch-lightning --quiet
!sudo apt -qq install git-lfs
!git config --global credential.helper store

In [None]:
import math
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from google.colab import drive
from pathlib import Path
from transformers import ViTFeatureExtractor, ViTForImageClassification
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from torchvision.datasets import ImageFolder

In [None]:
# Extraction target for the image archive. Any earlier copy is deleted first so
# re-running the notebook does not mix old and new files.
data_dir = Path("images")
if data_dir.exists():
    shutil.rmtree(path=data_dir)  # recursive delete of the whole folder tree

In [None]:
drive.mount("/content/gdrive")
!unzip /content/gdrive/MyDrive/images.zip -d images

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: images/chicken_bistek/3AANd9GcSjAsKTFdksRlmZE05ztb6wOUiUtyOAXGJ1sraSkm17SuAa-Xy9.jpg  
  inflating: images/chicken_bistek/3AANd9GcSLpqAD9cZaIuGKYR4LoGBolixU4XdxC22MajHAyhehh_ClEQhx.jpg  
  inflating: images/chicken_bistek/3AANd9GcSMMc8qXvuvI9a1G0iPofsquCIqbULQGoPips3xdF1YZSXGwKAG.jpg  
  inflating: images/chicken_bistek/3AANd9GcSOCh2C8mgUzkBMaruM_iW8MrGT-JOP7I9oDpwQ3QfL-gBGSWxe.jpg  
  inflating: images/chicken_bistek/3AANd9GcSP22sw3xXkEsYhOzDASS5mtq9sA85hUZ1S55W5uyj66neC2cxl.jpg  
  inflating: images/chicken_bistek/3AANd9GcSpKI5SRCxdbwluqjIuhp_JFAquHgSWbLciRrZrHjwQJWSrV16q.jpg  
  inflating: images/chicken_bistek/3AANd9GcSq0-pIOrylPmmU86Y1X2S2zJoSHLL_V3FEmi-YriLtbRBL3yCm.jpg  
  inflating: images/chicken_bistek/3AANd9GcSqAOnkrtbXOnC4AleEGpcyQCSZQ7xTof-XsgY0E8Rh8ZjWTKJn.jpg  
  inflating: images/chicken_bistek/3AANd9GcSQUZeK7awr_AUagkFjo3LhzDdmmDw9bXmcWXLp5yTezk1FGxoo.jpg  
  inflating: images/chicken_bistek/

In [None]:
ds = ImageFolder(data_dir)

# Hold out a fraction of the images for validation. A dedicated, seeded generator
# makes the split reproducible on Restart & Run All; the global seed is only set
# later, right before training.
VAL_FRACTION = 0.15
SPLIT_SEED = 1
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
indices = torch.randperm(len(ds), generator=split_generator).tolist()
n_val = math.floor(len(indices) * VAL_FRACTION)
n_train = len(indices) - n_val
# Slice at n_train rather than at -n_val: with n_val == 0, indices[:-0] would be
# empty and indices[-0:] would be the whole dataset.
train_ds = torch.utils.data.Subset(ds, indices[:n_train])  # 5738 images, ~85%
val_ds = torch.utils.data.Subset(ds, indices[n_train:])    # 1012 images, ~15%
len(train_ds), len(val_ds)  # 6750 images in total

(5738, 1012)

In [None]:
# Map each class-folder name to a string id and back; ids follow ds.classes order.
label2id = {class_name: str(idx) for idx, class_name in enumerate(ds.classes)}
id2label = {str(idx): class_name for idx, class_name in enumerate(ds.classes)}
class ImageClassificationCollator:
    """DataLoader ``collate_fn`` for (image, label) samples.

    Runs the wrapped feature extractor over the batch's images (asking for
    PyTorch tensors) and adds the labels under the ``'labels'`` key as a long tensor.
    """

    def __init__(self, feature_extractor):
        self.feature_extractor = feature_extractor

    def __call__(self, batch):
        images = [sample[0] for sample in batch]
        targets = [sample[1] for sample in batch]
        encodings = self.feature_extractor(images, return_tensors="pt")
        encodings['labels'] = torch.tensor(targets, dtype=torch.long)
        return encodings
# The preprocessor and the backbone are loaded from the same checkpoint, so the
# preprocessing settings come from that checkpoint's own config. (`from_pt` was
# dropped: it is a weight-format flag for model loading, not for preprocessors.)
MODEL_CHECKPOINT = "google/vit-base-patch16-224-in21k"
BATCH_SIZE = 64
NUM_WORKERS = 2

feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_CHECKPOINT)
# A fresh classification head sized to the number of image folders is attached.
model = ViTForImageClassification.from_pretrained(
    MODEL_CHECKPOINT,
    num_labels=len(label2id),
    label2id=label2id,
    id2label=id2label,
)
collator = ImageClassificationCollator(feature_extractor)
train_loader = DataLoader(train_ds,
                          batch_size=BATCH_SIZE,
                          collate_fn=collator,
                          num_workers=NUM_WORKERS,
                          shuffle=True)
val_loader = DataLoader(val_ds,
                        batch_size=BATCH_SIZE,
                        collate_fn=collator,
                        num_workers=NUM_WORKERS)

In [None]:
def determine_vals(w, max):
    """Greedy fractional-knapsack fill: how much of each item fits in ``max``.

    Items are considered in the order given. For the greedy result to be
    optimal, callers should pre-sort them, e.g. by value/weight ratio. Each item
    is taken whole while it still fits. The first item that does not fit is
    taken fractionally to fill the remaining capacity, and every later item
    gets 0.

    Args:
        w: item weights, as a 1-D tensor (anything with ``.tolist()``) or any
            sequence of non-negative numbers.
        max: knapsack capacity. The name shadows the builtin but is kept for
            backward compatibility with keyword callers.

    Returns:
        A list with one entry per item: 1 if taken whole, a fraction in (0, 1)
        for the single partially-taken item, otherwise 0.
    """
    weights = w.tolist() if hasattr(w, "tolist") else list(w)
    fractions = [0] * len(weights)
    weight = 0
    for i, item_weight in enumerate(weights):
        if weight >= max:
            break  # knapsack is full; remaining items stay at 0
        if weight + item_weight <= max:
            fractions[i] = 1
            weight += item_weight
        else:
            # item_weight > max - weight >= 0 here, so the division is safe.
            fractions[i] = (max - weight) / item_weight
            weight = max
    return fractions

In [None]:
def for_adjust(n: int, W, weights=None, values=None):
    """Fill the 0/1-knapsack dynamic-programming table.

    ``table[i][c]`` is the best total value achievable using only the first
    ``i`` items under capacity ``c``. The original implementation tried to
    build this table but had no item weights or values to work with. They are
    now optional keyword parameters, so the two-argument call signature is
    unchanged.

    Args:
        n: number of items to consider (the first ``n`` entries of ``weights``).
        W: integer knapsack capacity.
        weights: per-item integer weights, with at least ``n`` entries. Required.
        values: per-item values. Defaults to ``weights``, which maximizes the
            total packed weight without exceeding ``W``.

    Returns:
        An (n + 1) x (W + 1) list of lists. ``table[n][W]`` is the optimum.

    Raises:
        ValueError: if ``weights`` is missing, or if ``weights``/``values`` has
            fewer than ``n`` entries.
    """
    if weights is None:
        raise ValueError("for_adjust needs the per-item `weights`")
    weights = weights.tolist() if hasattr(weights, "tolist") else list(weights)
    if values is None:
        values = weights
    values = values.tolist() if hasattr(values, "tolist") else list(values)
    if len(weights) < n or len(values) < n:
        raise ValueError("weights and values must have at least n entries")

    capacity = int(W)
    # Row 0 (no items) and column 0 (no capacity) stay 0.
    table = [[0] * (capacity + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        item_w, item_v = int(weights[i - 1]), values[i - 1]
        prev, row = table[i - 1], table[i]
        for c in range(capacity + 1):
            # Take item i only if it fits and beats leaving it out.
            if item_w <= c and item_v + prev[c - item_w] > prev[c]:
                row[c] = item_v + prev[c - item_w]
            else:
                row[c] = prev[c]
    return table

In [None]:
class Classifier(pl.LightningModule):
    """Lightning wrapper that fine-tunes a Hugging Face image-classification model.

    Batches come from ``ImageClassificationCollator``, which puts ``labels`` in
    the batch dict. The wrapped model is therefore called with labels, and its
    output carries both ``.loss`` and ``.logits``.

    Args:
        model: the wrapped classifier (here ``ViTForImageClassification``). It
            must expose ``config.num_labels``.
        lr: Adam learning rate.
        **kwargs: extra hyperparameters recorded via ``save_hyperparameters``.
    """

    def __init__(self, model, lr: float=2e-5, **kwargs):
        super().__init__()
        self.save_hyperparameters('lr', *list(kwargs))
        self.model = model
        self.val_acc = Accuracy(
            task='multiclass',
            num_classes=model.config.num_labels
        )

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped model."""
        return self.model(*args, **kwargs)

    def training_step(self, batch, batch_idx):
        # The wrapped model already computes the classification loss from
        # batch['labels']. The previous version piped its output through ad-hoc
        # ReLU/MaxPool/flatten layers recreated every step, and through the
        # knapsack helpers; that crashed and contributed nothing to the loss.
        outputs = self(**batch)
        self.log('train_loss', outputs.loss)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        outputs = self(**batch)
        self.log('val_loss', outputs.loss)
        # Update the metric state and log the Metric object, so Lightning
        # reports the accuracy over the whole validation epoch.
        self.val_acc(outputs.logits.argmax(1), batch['labels'])
        self.log('val_acc', self.val_acc, prog_bar=True)
        return outputs.loss

    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.lr
        )

In [None]:
# Fix every RNG (python, numpy, torch) so training runs are repeatable.
# NOTE(review): the saved log below reports "Seed set to 42", so that output
# predates this code; re-run to refresh it.
pl.seed_everything(seed=1)
classifier = Classifier(model=model, lr=2e-5)
trainer = pl.Trainer(
    min_epochs=10,
    max_epochs=50,
    accelerator='gpu',
    devices=1,
    precision='16-mixed',
)
trainer.fit(
    model=classifier,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

INFO:lightning_fabric.utilities.seed:Seed set to 42
INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type                      | Params
------------------------------------------------------
0 | model   | ViTForImageClassification | 85.8 M
1 | val_acc | MulticlassAccuracy        | 0     
------------------------------------------------------
85.8 M    Trainable params
0         Non-trainable params
85.8 M    Total params
343.382   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.


In [None]:
# Spot-check one validation batch: inference mode, no autograd graph.
model.eval()
val_batch = next(iter(val_loader))
with torch.no_grad():
    outputs = model(**val_batch)
# softmax is monotonic, so argmax over the raw logits gives the same classes.
preds = outputs.logits.argmax(dim=1)
print("Preds: ", preds)
print("Labels:", val_batch['labels'])

Preds:  tensor([ 0, 38, 25, 34, 16, 50, 44, 13, 60, 27, 10, 34, 39, 42, 14, 29, 55, 46,
         7, 44, 28, 51,  9, 60, 54, 25, 28, 19, 19, 37, 46,  3, 19, 20, 50, 53,
        43, 52, 29, 11, 27, 11, 10,  3, 44, 20, 33, 52,  2, 16,  9, 24, 16, 53,
        22, 38, 22, 52, 26, 39, 28, 50, 58, 40])
Labels: tensor([ 0, 38, 25, 34, 14, 50, 44, 19, 60, 27,  5, 34, 39, 42,  2, 29, 46,  2,
        16, 44, 28, 51,  9, 60, 54, 25, 28,  5, 19, 37,  0,  3, 19, 20, 55, 53,
        43, 52, 29, 11, 27, 11, 20,  3, 24, 20, 33, 52, 14, 29,  9, 24,  7, 53,
        22, 38, 22, 52, 26, 39, 28, 48, 58, 49])
