In [None]:
%pip install -U lightning
%pip install -U wandb
%pip install -U transformers

In [1]:
import wandb
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33manhlt250102[0m ([33mmy_computer_vision_team[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
from typing import Any
from lightning import LightningDataModule, LightningModule, Trainer
from torch.utils.data import DataLoader, random_split, Dataset
from sklearn.model_selection import train_test_split
from lightning.pytorch.callbacks import ModelSummary, EarlyStopping
from lightning.pytorch.loggers import WandbLogger
from torchvision import transforms
from transformers import ViTForImageClassification, BitForImageClassification, ViTHybridForImageClassification
from torchvision.datasets import MNIST, ImageNet, CIFAR100
from torchvision.models import vit_b_16, vit_b_32, vit_l_16, vit_l_32, vit_h_14
import torch
import torch.nn as nn
import numpy as np


def patchify(batch: torch.Tensor, patch_size: tuple = (16, 16)):
    """
    Patchify the batch of images

    Shape:
        batch: (b, h, w, c)
        output: (b, nh, nw, ph, pw, c)
    """
    b, h, w, c = batch.shape # (n, 224, 224, 3)
    ph, pw = patch_size # (16, 16)
    nh, nw = h // ph, w // pw # (14, 14)

    patches = torch.zeros(b, nh*nw, ph*pw*c).to(batch.device) # (n, nh*nw, ph*pw*c) = (n, 196, 768)

    for idx, image in enumerate(batch):
        for i in range(nh):
            for j in range(nw):
                patch = image[i*ph: (i+1)*ph, j*pw: (j+1)*pw, :]
                patches[idx, i*nh + j] = patch.flatten()
    return patches # (n, nh*nw, ph*pw*c) = (n, 196, 768)

def get_mlp(in_features, hidden_units, out_features):
    """
    Returns a MLP head
    """
    dims = [in_features] + hidden_units + [out_features]
    layers = []
    for dim1, dim2 in zip(dims[:-2], dims[1:-1]):
        layers.append(nn.Linear(dim1, dim2))
        layers.append(nn.ReLU())
    layers.append(nn.Linear(dims[-2], dims[-1]))
    return nn.Sequential(*layers)

def get_positional_embeddings(sequence_length, d):
    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):
        for j in range(d):
            result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
    return result # (s, d)

class ImgDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if image.shape[0] == 1 or len(image.shape) == 2:
            image = image.repeat(3, 1, 1)

        if self.transform:
            image = self.transform(image)

        return image, label

class MyDataModule(LightningDataModule):
    def __init__(self, dataset_to_down, img_size: tuple, data_dir: str = './data', batch_size: int = 64):
        super().__init__()
        self.dataset_to_down = dataset_to_down
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.255]),
        ])

    def prepare_data(self):
        print("Downloading dataset...")
        self.train_dataset = self.dataset_to_down(root="data/", train=True, transform=None, download=True)
        self.test_dataset = self.dataset_to_down(root="data/", train=False, transform=None, download=True)

    def n_classes(self):
        return np.unique(self.train_dataset.targets).reshape(-1).shape[0]

    def setup(self, stage=None):
        train_images, val_images, train_labels, val_labels = train_test_split(self.train_dataset.data, self.train_dataset.targets, test_size=0.2, random_state=42)
        test_images, test_labels = self.test_dataset.data, self.test_dataset.targets
        if stage == 'fit' or stage is None:
            self.train_ds = ImgDataset(train_images, train_labels, transform=self.transform)
            self.val_ds = ImgDataset(val_images, val_labels, transform=self.transform)
        if stage == 'test' or stage is None:
            self.test_ds = ImgDataset(test_images, test_labels, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True, num_workers=0)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size, shuffle=False, num_workers=0)

    def test_dataloader(self):
        return DataLoader(self.test_ds, batch_size=self.batch_size, shuffle=False, num_workers=0)

class ViT(nn.Module):
    def __init__(
        self,
        nhead: int = 4,
        dim_feedforward: int = 1024,
        blocks: int = 4,
        mlp_head_units: list = [1024, 512],
        n_classes: int = 1,
        img_size: tuple = (224, 224),
        patch_size: tuple = (16, 16),
        n_channels: int = 3,
        d_model: int = 512,
    ):
        super().__init__()
        """
        Args:
            img_size: Size of the image
            patch_size: Size of the patch
            n_channels: Number of image channels
            d_model: The number of features in the transformer encoder
            nhead: The number of heads in the multiheadattention models
            dim_feedforward: The dimension of the feedforward network model in the encoder
            blocks: The number of sub-encoder-layers in the encoder
            mlp_head_units: The hidden units of mlp_head
            n_classes: The number of output classes
        """
        # self.img2seq = Img2Seq(img_size, patch_size, n_channels, d_model)
        # self.patch_size = patch_size # (16, 16)
        # nh, nw = img_size[0] // patch_size[0], img_size[1] // patch_size[1] # (14, 14)
        # n_tokens = nh * nw # 196
        # token_dim = patch_size[0] * patch_size[1] * n_channels # 768
        # self.first_linear = nn.Linear(token_dim, d_model) # (768, 512)
        # self.cls_token = nn.Parameter(torch.randn(1, d_model)) # (1, 512)
        # self.pos_emb = nn.Parameter(get_positional_embeddings(n_tokens, d_model)) # (196, 512)

        self.patch_emb = nn.Conv2d(n_channels, d_model, kernel_size=patch_size, stride=patch_size) # (3, 512, (16, 16), (16, 16))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, activation="gelu", batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, blocks
        )
        self.mlp = get_mlp(d_model, mlp_head_units, n_classes) # (512, [1024, 512], n_classes)

        # self.classifer = nn.Sigmoid() if n_classes == 1 else nn.Softmax()

    def forward(self, batch):
        """
        Shape:
            input: (b, c, h, w)
            output: (b, n_classes)
        """
        # batch = torch.permute(batch, (0, 2, 3, 1)) # (b, h, w, c) = (b, 224, 224, 3)
        # batch = patchify(batch, self.patch_size) # (b, nh*nw, ph*pw*c) = (b, 196, 768)
        # b = batch.shape[0]
        # batch = self.first_linear(batch) # (b, nh*nw, d_model) = (b, 196, 512)
        # cls = self.cls_token.expand([b, -1, -1]) # (b, 1, d_model) = (b, 1, 512)
        # emb = batch + self.pos_emb # (b, nh*nw, d_model) = (b, 196, 512)
        # batch = torch.cat([cls, emb], axis=1) # (b, nh*nw+1, d_model) = (b, 197, 512)

        batch = self.patch_emb(batch) # (b, d_model, nh, nw) = (b, 512, 14, 14)
        batch = batch.flatten(2).transpose(1, 2) # (b, nh*nw, d_model) = (b, 196, 512)

        batch = self.transformer_encoder(batch) # (b, s, d)
        batch = batch[:, 0, :] # (b, d)
        output = self.mlp(batch) # (b, n_classes)
        # output = self.classifer(batch) # (b, n_classes)
        return output

class ViTModule(LightningModule):
    def __init__(self, learning_rate: float = 1e-4,
                 nhead: int = 4,
                 dim_feedforward: int = 1024,
                 blocks: int = 4,
                 mlp_head_units: list = [1024, 512],
                 n_classes: int = 1,
                 img_size: tuple = (224, 224),
                 patch_size: tuple = (16, 16),
                 n_channels: int = 3,
                 d_model: int = 512) -> None:
        '''
        Args:
            img_size: Size of the image
            patch_size: Size of the patch
            n_channels: Number of image channels
            d_model: The number of features in the transformer encoder
            nhead: The number of heads in the multiheadattention models
            dim_feedforward: The dimension of the feedforward network model in the encoder
            blocks: The number of sub-encoder-layers in the encoder
            mlp_head_units: The hidden units of mlp_head
            n_classes: The number of output classes

        Shape:
            input: (b, c, h, w)
            output: (b, n_classes)
        '''
        super().__init__()
        self.learing_rate = learning_rate
        self.patch_size = patch_size # (16, 16)
        nh, nw = img_size[0] // patch_size[0], img_size[1] // patch_size[1] # (14, 14)
        n_tokens = nh * nw # 196
        token_dim = patch_size[0] * patch_size[1] * n_channels # 768
        self.first_linear = nn.Linear(token_dim, d_model) # (768, 512)
        self.cls_token = nn.Parameter(torch.randn(1, d_model)) # (1, 512)
        self.pos_emb = nn.Parameter(get_positional_embeddings(n_tokens, d_model)) # (196, 512)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, activation="gelu", batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, blocks
        )
        self.mlp = get_mlp(d_model, mlp_head_units, n_classes) # (512, [1024, 512], n_classes)

        self.classifer = nn.Sigmoid() if n_classes == 1 else nn.Softmax()
        # self.criteria = nn.CrossEntropyLoss()

        self.train_accuracy = []
        self.val_accuracy = []
        self.test_accuracy = []
        self.train_loss = []
        self.val_loss = []
        self.test_loss = []

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """
        Shape:
            input: (b, c, h, w)
            output: (b, n_classes)
        """
        # batch = self.img2seq(batch) # (b, s, d)
        batch = torch.permute(batch, (0, 2, 3, 1)) # (b, h, w, c) = (b, 224, 224, 3)
        batch = patchify(batch, self.patch_size) # (b, nh*nw, ph*pw*c) = (b, 196, 768)
        b = batch.shape[0]
        batch = self.first_linear(batch) # (b, nh*nw, d_model) = (b, 196, 512)
        cls = self.cls_token.expand([b, -1, -1]) # (b, 1, d_model) = (b, 1, 512)
        emb = batch + self.pos_emb # (b, nh*nw, d_model) = (b, 196, 512)
        batch = torch.cat([cls, emb], axis=1) # (b, nh*nw+1, d_model) = (b, 197, 512)

        batch = self.transformer_encoder(batch) # (b, s, d)
        batch = batch[:, 0, :] # (b, d)
        batch = self.mlp(batch) # (b, n_classes)
        output = self.classifer(batch) # (b, n_classes)
        return output

    def training_step(self, batch, batch_idx: int):
        x, y = batch
        logits = self.forward(x)
        loss = nn.functional.cross_entropy(logits, y)
        accuracy = (logits.argmax(dim=1) == y).float().mean()
        self.log('train_accuracy', accuracy, prog_bar=True)
        self.log('train_loss', loss, prog_bar=True)
        self.train_accuracy.append(accuracy)
        self.train_loss.append(loss)
        return loss

    def validation_step(self, batch, batch_idx: int):
        x, y = batch
        logits = self.forward(x)
        loss = nn.functional.cross_entropy(logits, y)
        accuracy = (logits.argmax(dim=1) == y).float().mean()
        self.log('val_accuracy', accuracy, prog_bar=True)
        self.log('val_loss', loss, prog_bar=True)
        self.val_accuracy.append(accuracy)
        self.val_loss.append(loss)
        return loss

    def test_step(self, batch, batch_idx: int):
        x, y = batch
        logits = self.forward(x)
        loss = nn.functional.cross_entropy(logits, y)
        accuracy = (logits.argmax(dim=1) == y).float().mean()
        self.log('test_accuracy', accuracy, prog_bar=True)
        self.log('test_loss', loss, prog_bar=True)
        self.test_accuracy.append(accuracy)
        self.test_loss.append(loss)
        return loss

    def on_train_epoch_end(self) -> None:
        self.log('train_accuracy_epoch', torch.stack(self.train_accuracy).mean())
        self.log('train_loss_epoch', torch.stack(self.train_loss).mean())
        self.train_accuracy = []
        self.train_loss = []

    def on_validation_epoch_end(self) -> None:
        self.log('val_accuracy_epoch', torch.stack(self.val_accuracy).mean())
        self.log('val_loss_epoch', torch.stack(self.val_loss).mean())
        self.val_accuracy = []
        self.val_loss = []

    def on_test_epoch_end(self) -> None:
        self.log('test_accuracy_epoch', torch.stack(self.test_accuracy).mean())
        self.log('test_loss_epoch', torch.stack(self.test_loss).mean())
        self.test_accuracy = []
        self.test_loss = []

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learing_rate)

class ViTPretrainedModule(LightningModule):
    def __init__(self, model, learning_rate: float, source: str = 'pytorch', n_classes: int = None, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.source = source
        self.learing_rate = learning_rate
        self.model = model
        if n_classes is not None and source == 'pytorch':
            self.model.heads = nn.Linear(self.model.heads.head.in_features, n_classes)
        self.criteria = nn.CrossEntropyLoss()

        self.train_accuracy = []
        self.val_accuracy = []
        self.test_accuracy = []
        self.train_loss = []
        self.val_loss = []
        self.test_loss = []

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x).logits if self.source == 'huggingface' else self.model(x)

    def training_step(self, batch, batch_idx: int):
        x, y = batch
        logits = self.forward(x)
        loss = self.criteria(logits, y)
        accuracy = (logits.argmax(dim=1) == y).float().mean()
        self.log('train_accuracy', accuracy, prog_bar=True)
        self.log('train_loss', loss, prog_bar=True)
        self.train_accuracy.append(accuracy)
        self.train_loss.append(loss)
        return loss

    def validation_step(self, batch, batch_idx: int):
        x, y = batch
        logits = self.forward(x)
        loss = self.criteria(logits, y)
        accuracy = (logits.argmax(dim=1) == y).float().mean()
        self.log('val_accuracy', accuracy, prog_bar=True)
        self.log('val_loss', loss, prog_bar=True)
        self.val_accuracy.append(accuracy)
        self.val_loss.append(loss)
        return loss

    def test_step(self, batch, batch_idx: int):
        x, y = batch
        logits = self.forward(x)
        loss = self.criteria(logits, y)
        accuracy = (logits.argmax(dim=1) == y).float().mean()
        self.log('test_accuracy', accuracy, prog_bar=True)
        self.log('test_loss', loss, prog_bar=True)
        self.test_accuracy.append(accuracy)
        self.test_loss.append(loss)
        return loss

    def on_train_epoch_end(self) -> None:
        self.log('train_accuracy_epoch', torch.stack(self.train_accuracy).mean())
        self.log('train_loss_epoch', torch.stack(self.train_loss).mean())
        self.train_accuracy = []
        self.train_loss = []

    def on_validation_epoch_end(self) -> None:
        self.log('val_accuracy_epoch', torch.stack(self.val_accuracy).mean())
        self.log('val_loss_epoch', torch.stack(self.val_loss).mean())
        self.val_accuracy = []
        self.val_loss = []

    def on_test_epoch_end(self) -> None:
        self.log('test_accuracy_epoch', torch.stack(self.test_accuracy).mean())
        self.log('test_loss_epoch', torch.stack(self.test_loss).mean())
        self.test_accuracy = []
        self.test_loss = []

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learing_rate)

class LinearModule(LightningModule):
    def __init__(self, img_size: tuple = (224, 224), n_channels: int = 3, n_classes: int = 10, learning_rate: float = 1e-4, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.learing_rate = learning_rate
        self.model = nn.Linear(img_size[0] * img_size[1] * n_channels, n_classes)
        self.criteria = nn.CrossEntropyLoss()

        self.train_accuracy = []
        self.val_accuracy = []
        self.test_accuracy = []
        self.train_loss = []
        self.val_loss = []
        self.test_loss = []

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Shape:
            input: (b, c, h, w)
            output: (b, n_classes) = (b, 10)
        '''
        x = x.view(x.size(0), -1) # (b, c*h*w)
        return self.model(x) # (b, 10)

    def training_step(self, batch, batch_idx: int):
        x, y = batch
        logits = self.forward(x)
        loss = self.criteria(logits, y)
        accuracy = (logits.argmax(dim=1) == y).float().mean()
        self.log('train_accuracy', accuracy, prog_bar=True)
        self.log('train_loss', loss, prog_bar=True)
        self.train_accuracy.append(accuracy)
        self.train_loss.append(loss)
        return loss

    def validation_step(self, batch, batch_idx: int):
        x, y = batch
        logits = self.forward(x)
        loss = self.criteria(logits, y)
        accuracy = (logits.argmax(dim=1) == y).float().mean()
        self.log('val_accuracy', accuracy, prog_bar=True)
        self.log('val_loss', loss, prog_bar=True)
        self.val_accuracy.append(accuracy)
        self.val_loss.append(loss)
        return loss

    def test_step(self, batch, batch_idx: int):
        x, y = batch
        logits = self.forward(x)
        loss = self.criteria(logits, y)
        accuracy = (logits.argmax(dim=1) == y).float().mean()
        self.log('test_accuracy', accuracy, prog_bar=True)
        self.log('test_loss', loss, prog_bar=True)
        self.test_accuracy.append(accuracy)
        self.test_loss.append(loss)
        return loss

    def on_train_epoch_end(self) -> None:
        self.log('train_accuracy_epoch', torch.stack(self.train_accuracy).mean())
        self.log('train_loss_epoch', torch.stack(self.train_loss).mean())
        self.train_accuracy = []
        self.train_loss = []

    def on_validation_epoch_end(self) -> None:
        self.log('val_accuracy_epoch', torch.stack(self.val_accuracy).mean())
        self.log('val_loss_epoch', torch.stack(self.val_loss).mean())
        self.val_accuracy = []
        self.val_loss = []

    def on_test_epoch_end(self) -> None:
        self.log('test_accuracy_epoch', torch.stack(self.test_accuracy).mean())
        self.log('test_loss_epoch', torch.stack(self.test_loss).mean())
        self.test_accuracy = []
        self.test_loss = []

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learing_rate)


In [3]:
ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=10)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=7

In [5]:
BATCH_SIZE = 64
EPOCHS = 10
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 1e-4
IMG_SIZE = (28, 28)
PATCH_SIZE = (2, 2)
PROJ_NAME = 'ViT-2-2_test_mnist'
BLOCKS = 4

def train(dataset, batch_size = BATCH_SIZE, epochs = EPOCHS, device = DEVICE, lr = LEARNING_RATE, img_size = IMG_SIZE, patch_size = PATCH_SIZE, proj_name = PROJ_NAME, blocks = BLOCKS):
    data_module = MyDataModule(dataset, img_size, batch_size=batch_size)
    data_module.prepare_data()

    model = ViTModule(
        img_size=img_size,
        patch_size=patch_size,
        n_channels=3,
        d_model=8,
        nhead=4,
        dim_feedforward=32,
        blocks=blocks,
        mlp_head_units=[32, 16],
        n_classes=data_module.n_classes(),
        learning_rate=lr,
    )

    logger = WandbLogger(project=proj_name,
                         config={'batch_size': BATCH_SIZE, 'epochs': EPOCHS, 'learning_rate': LEARNING_RATE, 'img_size': IMG_SIZE, 'patch_size': PATCH_SIZE})
    trainer = Trainer(
        default_root_dir='./models',
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        max_epochs=epochs,
        logger=logger,
        callbacks=[EarlyStopping(monitor='val_loss_epoch', patience=2)],
    )
    trainer.fit(model, data_module)
    trainer.test(model, data_module)

print("Using device:", DEVICE)

Using device: cuda


In [6]:
trans = transforms.Compose([
            # transforms.ToPILImage(),
            transforms.Resize(IMG_SIZE),
            transforms.ToTensor(),
            # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.255]),
        ])

ds = MNIST(root="data/", train=True, transform=trans, download=True)
# conv = nn.Conv2d(1, 768, kernel_size=(2, 2), stride=(2, 2))
# x = ds.__getitem__(0)[0].unsqueeze(0)
dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)

model = ViT(nhead=4, dim_feedforward=16, blocks=4, mlp_head_units=[32, 16], n_classes=10, img_size=IMG_SIZE, patch_size=PATCH_SIZE, n_channels=1, d_model=8)
model.to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for _ in range(EPOCHS):
    for x, y in dl:
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        logits = model(x)
        loss = nn.functional.cross_entropy(logits, y)
        accuracy = (logits.argmax(dim=1) == y).float().mean()
        print("Loss:", loss.item(), "Accuracy:", accuracy.item())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    print("Epoch:", _)


Loss: 2.323056221008301 Accuracy: 0.109375
Loss: 2.314821720123291 Accuracy: 0.15625
Loss: 2.31184720993042 Accuracy: 0.15625
Loss: 2.298128128051758 Accuracy: 0.15625
Loss: 2.3069207668304443 Accuracy: 0.140625
Loss: 2.314337968826294 Accuracy: 0.046875
Loss: 2.3000080585479736 Accuracy: 0.078125
Loss: 2.304555892944336 Accuracy: 0.125
Loss: 2.3200442790985107 Accuracy: 0.125
Loss: 2.319098472595215 Accuracy: 0.078125
Loss: 2.317134141921997 Accuracy: 0.109375
Loss: 2.295874834060669 Accuracy: 0.171875
Loss: 2.333125591278076 Accuracy: 0.09375
Loss: 2.3168447017669678 Accuracy: 0.09375
Loss: 2.309624195098877 Accuracy: 0.125
Loss: 2.317972421646118 Accuracy: 0.078125
Loss: 2.302276611328125 Accuracy: 0.15625
Loss: 2.3002095222473145 Accuracy: 0.09375
Loss: 2.303027629852295 Accuracy: 0.140625
Loss: 2.3104758262634277 Accuracy: 0.078125
Loss: 2.298619270324707 Accuracy: 0.046875
Loss: 2.3090264797210693 Accuracy: 0.078125
Loss: 2.3115570545196533 Accuracy: 0.09375
Loss: 2.2872774600982

In [19]:
ans = conv(x)

In [20]:
ans.shape

torch.Size([1, 768, 14, 14])

In [None]:
# cifar100_dm = MyDataModule(dataset_to_down=CIFAR100, img_size=IMG_SIZE, batch_size=BATCH_SIZE)
# cifar100_dm.prepare_data()
# n_classes = cifar100_dm.n_classes()
# wandb_logger = WandbLogger(project='ViTHybrid_test_cifar100',
#                            config={'batch_size': BATCH_SIZE, 'epochs': EPOCHS, 'learning_rate': LEARNING_RATE, 'img_size': IMG_SIZE, 'patch_size': PATCH_SIZE})
# vit_model = ViTModule(img_size=IMG_SIZE, patch_size=PATCH_SIZE, n_channels=3,
#                   n_classes=100, nhead=4, dim_feedforward=1024, blocks=BLOCKS,
#                   mlp_head_units=[1024, 512], d_model=512, learning_rate=LEARNING_RATE)
# vit_b_16_model = ViTPretrainedModule(model=vit_b_16(pretrained=True), learning_rate=LEARNING_RATE, n_classes=n_classes)
# hf_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=n_classes)
# hf_model = BitForImageClassification.from_pretrained('facebookresearch/bit-m-r101x1-ILSVRC2012')
# hf_model = BitForImageClassification.from_pretrained('google/bit-50')
# hf_model = ViTHybridForImageClassification.from_pretrained('google/vit-hybrid-base-bit-384')
# freeze all layers except the final layer
# for param in hf_model.bit.parameters():
#     param.requires_grad = False
# hf_model.classifier = nn.Linear(hf_model.classifier.in_features, n_classes)
# hf_vit_b_16_model = ViTPretrainedModule(model=hf_model, learning_rate=LEARNING_RATE, source='huggingface')
# linear_model = LinearModule(img_size=IMG_SIZE, n_channels=1, n_classes=10, learning_rate=LEARNING_RATE)

In [None]:
# train vit
# trainer = Trainer(max_epochs=EPOCHS,
#                   default_root_dir='./models',
#                   accelerator='gpu' if torch.cuda.is_available() else 'cpu',
#                   callbacks=[EarlyStopping(monitor='val_loss_epoch')],
#                   logger=wandb_logger)
# trainer.fit(hf_vit_b_16_model, cifar100_dm)

In [None]:
# trainer.test(hf_vit_b_16_model, cifar100_dm)

In [None]:
wandb.finish()