In [None]:
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.utils.data import DataLoader
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import Trainer
from torch import nn

In [11]:
# Convert PIL images to tensors and standardize each RGB channel.
# NOTE(review): the mean/std constants are presumably the CIFAR-10 training-set
# statistics -- confirm against the source they were copied from.
transform = Compose([ToTensor(), Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])

batch_size = 128
val_size = 10000   # number of training images held out for validation
split_seed = 42    # fixes train/val membership across kernel restarts

full_trainset = CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = CIFAR10(root='./data', train=False, download=True, transform=transform)

# Without an explicit generator the split depends on the global RNG state, so
# every Restart & Run All would validate on a different subset. A dedicated,
# seeded generator makes the split reproducible without touching global RNG.
trainset, valset = torch.utils.data.random_split(
    full_trainset,
    [len(full_trainset) - val_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Only the training loader shuffles; validation/test order stays fixed.
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2, persistent_workers=True)
valloader = DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=2, persistent_workers=True)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

# Human-readable label names, indexed by class id.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [12]:
class PatchEmbeddings(pl.LightningModule):
    """Cut an image batch into non-overlapping patches and embed each one.

    Args:
        in_channels: number of channels of the input images.
        patch_size: side length of each square patch; also used as the stride,
            so patches do not overlap.
        embedding_dim: size of the vector each flattened patch is projected to.
    """

    def __init__(self, in_channels=3, patch_size=8, embedding_dim=128):
        super().__init__()
        # Length of one flattened patch: C * P * P values.
        flat_patch_dim = in_channels * patch_size ** 2
        self.unfolding = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.projection = nn.Linear(flat_patch_dim, embedding_dim)

    def forward(self, x):
        """Return one embedding per patch.

        Shapes (B = batch, N = number of patches, E = embedding_dim):
            input            (B, C, H, W)
            after unfold     (B, C * P * P, N)
            after transpose  (B, N, C * P * P)
            output           (B, N, E)
        """
        patches = self.unfolding(x)
        # Move the patch axis ahead of the feature axis so the linear layer
        # acts on each flattened patch independently.
        patches = patches.transpose(1, 2)
        return self.projection(patches)

In [13]:
class MultiHeadAttention(pl.LightningModule):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are all linear projections of the same input.

    Args:
        dim: feature size of the input and output tokens.
        n_heads: number of attention heads; must divide ``dim`` evenly.

    Raises:
        AssertionError: if ``dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, dim, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.dim = dim
        assert dim % n_heads == 0  # each head receives an equal slice of dim
        self.head_dim = dim // n_heads

        # Separate projections for keys, queries and values.
        self.K = nn.Linear(in_features=dim, out_features=dim)
        self.Q = nn.Linear(in_features=dim, out_features=dim)
        self.V = nn.Linear(in_features=dim, out_features=dim)

        # Mixes the concatenated head outputs back into a single dim-sized vector.
        self.out_proj = nn.Linear(in_features=dim, out_features=dim)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (batch, tokens, dim); the output has the same shape."""
        batch, tokens = x.shape[0], x.shape[1]

        def split_heads(projected):
            # (batch, tokens, dim) -> (batch, n_heads, tokens, head_dim)
            return projected.view(batch, tokens, self.n_heads, self.head_dim).transpose(1, 2)

        keys = split_heads(self.K(x))
        queries = split_heads(self.Q(x))
        values = split_heads(self.V(x))

        # Token-to-token similarity, scaled by sqrt(head_dim), then normalized
        # over the key axis.
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)
        weights = torch.nn.functional.softmax(scores, dim=-1)
        context = torch.matmul(weights, values)

        # (batch, n_heads, tokens, head_dim) -> (batch, tokens, dim)
        merged = context.transpose(1, 2).contiguous().view(batch, -1, self.dim)
        return self.out_proj(merged)

In [14]:
class MLP(pl.LightningModule):
    """Two-layer feed-forward network with dropout after each linear layer.

    Args:
        in_features: size of the input vectors.
        hidden_features: width of the hidden layer.
        out_features: size of the output vectors.
        dropout: drop probability shared by both dropout applications.
        activation: zero-argument factory for the hidden non-linearity
            (instantiated once in the constructor).
    """

    def __init__(self, in_features, hidden_features, out_features, dropout=0.1, activation=nn.GELU):
        super().__init__()
        self.fc1 = nn.Linear(in_features=in_features, out_features=hidden_features)
        self.fc2 = nn.Linear(in_features=hidden_features, out_features=out_features)
        # A single Dropout module is reused at both positions in forward().
        self.dropout = nn.Dropout(p=dropout)
        self.activation = activation()

    def forward(self, x):
        """Compute dropout(fc2(dropout(activation(fc1(x)))))."""
        hidden = self.dropout(self.activation(self.fc1(x)))
        return self.dropout(self.fc2(hidden))

In [15]:
class TransformerEncoder(pl.LightningModule):
    """One encoder block: self-attention and a feed-forward network, each
    preceded by LayerNorm and wrapped in a residual connection.

    Args:
        dim: token feature size.
        n_heads: number of attention heads.
        dropout: drop probability passed to the feed-forward MLP.
    """

    def __init__(self, dim, n_heads, dropout=0.1):
        super().__init__()
        self.ln_pre_attn = nn.LayerNorm(dim)
        self.attention = MultiHeadAttention(dim, n_heads)
        self.ln_pre_ffn = nn.LayerNorm(dim)
        # The hidden layer of the feed-forward network is 4x the token width.
        self.ffn = MLP(dim, dim * 4, dim, dropout)

    def forward(self, x):
        """Return the tokens after the attention and feed-forward residual updates."""
        attended = self.attention(self.ln_pre_attn(x))
        x = x + attended
        transformed = self.ffn(self.ln_pre_ffn(x))
        return x + transformed

In [16]:
class ViT(pl.LightningModule):
    """Vision Transformer classifier together with its Lightning train/val/test logic.

    Forward pipeline: patch embedding -> prepend a learnable class token ->
    add learnable positional embeddings -> ``n_blocks`` encoder blocks ->
    MLP head applied to the class-token output.

    Args:
        in_channels: number of channels of the input images.
        patch_size: side length of the square patches.
        embedding_dim: token feature size used throughout the model.
        n_blocks: number of stacked encoder blocks.
        n_heads: attention heads per encoder block.
        out_dim: number of output classes (size of the logits vector).
        dropout: drop probability used in the encoder MLPs and in the hidden
            layer of the classification head.
        image_size: side length of the (square) input images; determines how
            many positional embeddings are allocated. Defaults to 32, the
            value that was previously hard-coded.
    """

    # Label smoothing used by the cross-entropy loss in every step.
    LABEL_SMOOTHING = 0.1

    def __init__(self, in_channels=3, patch_size=4, embedding_dim=256, n_blocks=6, n_heads=8, out_dim=10,
                 dropout=0.1, image_size=32):
        super().__init__()

        # Patch Embeddings
        self.patch_embeddings = PatchEmbeddings(in_channels=in_channels, patch_size=patch_size, embedding_dim=embedding_dim)

        # Positional Embeddings: one per patch plus one for the class token.
        # Floor division matches the number of patches a stride-`patch_size`
        # unfold produces along each side.
        num_patches = (image_size // patch_size) ** 2
        self.positional_embeddings = nn.Parameter(torch.randn(1, 1 + num_patches, embedding_dim))

        # Class Token
        self.class_token = nn.Parameter(torch.randn(1, 1, embedding_dim))

        # Transformer Encoder
        self.transformer_encoder = nn.Sequential(*[TransformerEncoder(embedding_dim, n_heads, dropout) for _ in range(n_blocks)])

        # MLP Head (now receives the configured dropout instead of silently
        # falling back to the MLP default).
        self.mlp_head = MLP(embedding_dim, embedding_dim * 4, out_dim, dropout)

    def _head_logits(self, cls_repr):
        """Map the class-token representation to class logits.

        Runs the head's layers explicitly so that dropout is applied to the
        hidden activations only. Calling ``self.mlp_head(...)`` directly would
        also apply dropout to the logits themselves during training, randomly
        zeroing class scores right before the loss. Parameters and eval-mode
        outputs are identical to calling the head module.
        """
        head = self.mlp_head
        hidden = head.dropout(head.activation(head.fc1(cls_repr)))
        return head.fc2(hidden)

    def forward(self, x):
        """Return class logits of shape (B, out_dim).

        ``x`` is an image batch, assumed to be (B, in_channels, image_size,
        image_size); other spatial sizes produce a patch count that does not
        match the positional embeddings.
        """
        batch = x.shape[0]
        tokens = self.patch_embeddings(x)
        # Broadcast the single learnable class token across the batch.
        cls_tokens = self.class_token.expand(batch, -1, -1)
        tokens = torch.cat([cls_tokens, tokens], dim=1)
        tokens = tokens + self.positional_embeddings
        tokens = self.transformer_encoder(tokens)
        # Only the class token (position 0) feeds the classifier.
        return self._head_logits(tokens[:, 0])

    def _shared_step(self, batch, batch_idx):
        """Run the model on one batch and return ``(loss, logits, targets)``."""
        x, y = batch
        y_hat = self(x)
        # Functional form avoids building a new loss module on every step.
        loss = nn.functional.cross_entropy(y_hat, y, label_smoothing=self.LABEL_SMOOTHING)
        return loss, y_hat, y

    @staticmethod
    def _accuracy(y_hat, y):
        """Fraction of samples whose arg-max logit equals the target label."""
        return (torch.argmax(y_hat, dim=1) == y).float().mean()

    def training_step(self, batch, batch_idx):
        """Log training loss/accuracy and return the loss for backpropagation."""
        loss, y_hat, y = self._shared_step(batch, batch_idx)
        acc = self._accuracy(y_hat, y)
        self.log('train_loss', loss, prog_bar=True, on_epoch=True)
        self.log('train_accuracy', acc, prog_bar=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy and return both in a dict."""
        loss, y_hat, y = self._shared_step(batch, batch_idx)
        acc = self._accuracy(y_hat, y)
        self.log('val_loss', loss, prog_bar=True, on_epoch=True)
        self.log('val_accuracy', acc, prog_bar=True, on_epoch=True)
        return {'val_loss': loss, 'val_accuracy': acc}

    def test_step(self, batch, batch_idx):
        """Log test loss/accuracy and return both in a dict."""
        loss, y_hat, y = self._shared_step(batch, batch_idx)
        acc = self._accuracy(y_hat, y)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_accuracy', acc, prog_bar=True)
        return {'test_loss': loss, 'test_accuracy': acc}

    def configure_optimizers(self):
        """AdamW with a StepLR schedule that multiplies the LR by 0.1 every 10 scheduler steps."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=0.001, weight_decay=0.0001)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        return [optimizer], [scheduler]
        

In [17]:
# Instantiate the network with its hyperparameters spelled out (these are the
# same values as the constructor defaults) so the configuration is visible here.
model = ViT(
    in_channels=3,
    patch_size=4,
    embedding_dim=256,
    n_blocks=6,
    n_heads=8,
    out_dim=10,
    dropout=0.1,
)

In [18]:
# TensorBoard logger with save directory "lightning_logs" and run name "ViT".
logger = TensorBoardLogger(save_dir="lightning_logs", name="ViT")

# Train for at most 50 epochs, reporting metrics through the logger above.
trainer = Trainer(logger=logger, max_epochs=50)

# Fit on the training loader; the validation loader supplies the val_* metrics.
trainer.fit(model, train_dataloaders=trainloader, val_dataloaders=valloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name                | Type            | Params
--------------------------------------------------------
0 | patch_embeddings    | PatchEmbeddings | 12.5 K
1 | transformer_encoder | Sequential      | 4.7 M 
2 | mlp_head            | MLP             | 273 K 
  | other params        | n/a             | 16.9 K
--------------------------------------------------------
5.0 M     Trainable params
0         Non-trainable params
5.0 M     Total params
20.166    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

/Users/eithannakache/Desktop/ViT/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
