In [1]:
from typing import Optional, Tuple
from dataclasses import dataclass
from typing import List
from torch import flatten, randn, cat
from torch import Tensor
from torch.nn import Parameter
from torch.nn import Module
from torch.nn import Conv2d
from torch.nn import Linear
from torch.nn import GELU
from torch.nn import Dropout
from torch.nn import Sequential
from torch.nn import LayerNorm
from torch.nn import MultiheadAttention
from torch.nn import ModuleList


class ImageEmbeddings(Module):
    """Split images into non-overlapping square patches and embed each patch.

    A ``Conv2d`` whose kernel size and stride both equal ``patch_size`` applies
    the same linear projection to every patch independently.

    Args:
        image_width: width of the input images.
        image_height: height of the input images.
        input_channels: number of channels in the input images.
        patch_size: side length of each square patch.
        model_dimension: size of the embedding produced for each patch.
    """

    def __init__(self, image_width: int, image_height: int, input_channels: int, patch_size: int, model_dimension: int):
        super().__init__()
        # One patch per full tile along each axis. The strided convolution drops
        # any remainder pixels, hence floor division per axis (not on the area).
        self.number_of_patches = (image_width // patch_size) * (image_height // patch_size)
        self.projector = Conv2d(input_channels, model_dimension, kernel_size=patch_size, stride=patch_size)

    def forward(self, input: Tensor) -> Tensor:
        """Map ``(batch, channels, height, width)`` to ``(batch, number_of_patches, model_dimension)``."""
        output = self.projector(input)
        # (batch, model_dimension, rows, cols) -> (batch, rows * cols, model_dimension)
        return flatten(output, 2).transpose(1, 2)


class CLSToken(Module):
    """Prepend one learnable classification token to every sequence in a batch."""

    def __init__(self, model_dimension: int):
        super().__init__()
        # Stored once with shape (1, 1, model_dimension) and broadcast per batch in forward().
        self.token = Parameter(randn(1, 1, model_dimension))

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` with the token inserted at sequence position 0."""
        tokens = self.token.expand(input.size(0), -1, -1)
        return cat((tokens, input), dim=1)


class LearnablePositionalEncoding(Module):
    """Add a learned embedding to every sequence position.

    The table holds ``number_of_patches + 1`` positions. The extra slot covers
    the token that ``CLSToken`` prepends before this module runs in ``ViTClassifier``.
    """

    def __init__(self, model_dimension: int, number_of_patches: int):
        super().__init__()
        sequence_length = number_of_patches + 1
        self.position_embeddings = Parameter(randn(1, sequence_length, model_dimension))

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` plus the (batch-broadcast) position embeddings."""
        return input + self.position_embeddings
    

class Encoder(Module):
    """One pre-norm Transformer encoder block.

    Computes ``x = x + Attn(LN1(x))`` followed by ``x = x + MLP(LN2(x))``. The
    residual stream ``x`` itself is never replaced by its normalised copy.
    Inputs are laid out as ``(batch, sequence, model_dimension)``, which is the
    layout ``ImageEmbeddings`` produces.

    Args:
        model_dimension: feature size of each token.
        hidden_dimension: width of the MLP's hidden layer.
        number_of_heads: attention heads (must divide ``model_dimension``).
        dropout: dropout probability used in attention and in the MLP.
    """

    def __init__(self, model_dimension: int, hidden_dimension: int, number_of_heads: int, dropout: float):
        super().__init__()
        # batch_first=True: inputs are (batch, sequence, features). The default
        # (sequence, batch, features) layout would attend across the batch axis.
        self.attention = MultiheadAttention(model_dimension, number_of_heads, dropout=dropout, batch_first=True)
        self.first_layer_normalization = LayerNorm(model_dimension)
        self.second_layer_normalization = LayerNorm(model_dimension)
        self.mlp = Sequential(
            Linear(model_dimension, hidden_dimension),
            GELU(),
            Dropout(dropout),
            Linear(hidden_dimension, model_dimension),
            Dropout(dropout)
        )

    def forward(self, input: Tensor, need_weights: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
        """Run the block on ``input``.

        Returns:
            ``(output, attention_weights)``. ``output`` has the same shape as
            ``input``. ``attention_weights`` is ``(batch, sequence, sequence)``,
            averaged over heads, when ``need_weights`` is True; otherwise ``None``.
        """
        normalized = self.first_layer_normalization(input)
        attention, attention_weights = self.attention(normalized, normalized, normalized, need_weights=need_weights)
        # Residuals are added to the un-normalised stream (pre-norm Transformer).
        output = input + attention
        output = output + self.mlp(self.second_layer_normalization(output))
        return output, attention_weights
    
@dataclass
class Settings:
    """Hyperparameters consumed by ``ViTClassifier``."""

    image_width: int  # width of the input images, passed to ImageEmbeddings
    image_height: int  # height of the input images, passed to ImageEmbeddings
    patch_size: int  # side of each square patch; used as Conv2d kernel_size and stride
    input_channels: int  # channels of the input images (Conv2d in_channels)
    model_dimension: int  # token/embedding size used throughout the model
    hidden_dimension: int  # hidden width of each encoder's MLP
    number_of_heads: int  # attention heads per encoder; MultiheadAttention requires model_dimension % heads == 0
    number_of_layers: int  # number of stacked Encoder blocks
    dropout: float  # dropout probability in attention and MLP layers

class ViTClassifier(Module):
    """Vision Transformer that classifies images from the final CLS-token state.

    Pipeline: patch embedding -> prepend CLS token -> add learned positions ->
    ``settings.number_of_layers`` encoder blocks -> linear head applied to
    sequence position 0.

    Args:
        settings: model hyperparameters.
        output_classes: number of logits produced per image.
    """

    def __init__(self, settings: Settings, output_classes: int):
        super().__init__()
        s = settings
        # Submodule creation order fixes the order in which parameters are
        # initialised under a given seed; keep it stable.
        self.image_embeddings = ImageEmbeddings(s.image_width, s.image_height, s.input_channels, s.patch_size, s.model_dimension)
        self.cls_token = CLSToken(s.model_dimension)
        self.positional_encoding = LearnablePositionalEncoding(s.model_dimension, self.image_embeddings.number_of_patches)
        self.encoders = ModuleList(
            Encoder(s.model_dimension, s.hidden_dimension, s.number_of_heads, s.dropout)
            for _ in range(s.number_of_layers)
        )
        self.classification_head = Linear(s.model_dimension, output_classes)

    def forward(self, input: Tensor, need_weights: bool = False) -> Tuple[Tensor, Optional[List[Tensor]]]:
        """Return ``(logits, attention_weights)``.

        ``attention_weights`` has one entry per encoder when ``need_weights`` is
        True, and is an empty list otherwise.
        """
        tokens = self.image_embeddings(input)
        tokens = self.positional_encoding(self.cls_token(tokens))

        per_layer_weights = []
        for block in self.encoders:
            tokens, weights = block(tokens, need_weights=need_weights)
            if need_weights:
                per_layer_weights.append(weights)

        cls_state = tokens[:, 0]
        return self.classification_head(cls_state), per_layer_weights

In [None]:
import os
from matplotlib.pyplot import figure, show, savefig
from matplotlib.axes import Axes
from uuid import UUID, uuid4
from torch.utils.tensorboard import SummaryWriter
from typing import Protocol, Optional, Dict
from csv import DictWriter
from torch.optim import Optimizer
from utils import train, test, Criterion, Data


class Writer(Protocol):
    """Structural type for the scalar sink that ``Metrics`` logs to.

    Any object exposing a compatible ``add_scalar`` method satisfies it.
    """

    def add_scalar(self, tag: str, scalar_value: float, global_step: int):
        """Record ``scalar_value`` under ``tag`` at step ``global_step``."""
        ...
    
class Metrics:
    """Accumulate per-batch loss/accuracy and record per-epoch averages.

    Per epoch, call ``start(mode)``, then ``update(...)`` once per batch, then
    ``stop()``. ``stop`` appends the averages to ``history``. When a writer is
    attached, it also logs them as ``{mode}/loss`` and ``{mode}/accuracy``,
    using the epoch counter as the step.

    Args:
        writer: optional scalar sink with an ``add_scalar(tag, value, step)`` method.
    """

    def __init__(self, writer: Optional["Writer"] = None):
        self.writer = writer
        self.history = {
            'loss': [],
            'accuracy': [],
        }
        self.epoch = 0
        # Initialised here so readers (e.g. an end-of-run summary) never hit an
        # AttributeError when no epoch has been started yet.
        self.mode = ''
        self.batch = 0
        self.updates = 0
        self.loss = 0.0
        self.accuracy = 0.0

    def start(self, mode: str):
        """Begin a new epoch for ``mode`` and reset the running sums."""
        self.mode = mode
        self.epoch += 1
        self.batch = 0
        self.updates = 0
        self.loss = 0
        self.accuracy = 0

    def update(self, batch: int, loss: float, accuracy: float):
        """Add one batch's ``loss`` and ``accuracy`` to the running sums.

        ``batch`` is stored as the most recent batch index. Averages are taken
        over the number of ``update`` calls, so they are correct whether the
        caller numbers batches from 0 or from 1.
        """
        self.batch = batch
        self.updates += 1
        self.loss += loss
        self.accuracy += accuracy

    def stop(self):
        """Average the sums, append them to ``history``, then print and log them.

        If ``update`` was never called since ``start``, the averages are NaN
        instead of raising ZeroDivisionError.
        """
        if self.updates:
            self.loss /= self.updates
            self.accuracy /= self.updates
        else:
            self.loss = float('nan')
            self.accuracy = float('nan')
        self.history['loss'].append(self.loss)
        self.history['accuracy'].append(self.accuracy)
        print(f'Processed {self.updates} batches, average loss: {self.loss:.4f}, average accuracy: {self.accuracy:.4f}, in epoch {self.epoch} for {self.mode} mode')

        if self.writer is not None:
            self.writer.add_scalar(f'{self.mode}/loss', self.loss, self.epoch)
            self.writer.add_scalar(f'{self.mode}/accuracy', self.accuracy, self.epoch)

    def write_to_csv(self, filename: str):
        """Write ``history`` to ``filename`` as rows of ``epoch,loss,accuracy`` (epochs start at 1)."""
        with open(filename, 'w', newline='') as csvfile:
            fieldnames = ['epoch', 'loss', 'accuracy']
            writer = DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for epoch, (loss, accuracy) in enumerate(zip(self.history['loss'], self.history['accuracy']), start=1):
                writer.writerow({'epoch': epoch, 'loss': loss, 'accuracy': accuracy})


class Summary:
    """Bookkeeping for one experiment run: logging, metrics and output files.

    The run is identified by ``{name}-{id}``, which names:
      - the TensorBoard log directory ``logs/{name}-{id}`` (created by ``open``),
      - ``parameters/{name}-{id}.txt``, appended to by ``add_text``,
      - ``results/{name}-{id}-train.csv`` and ``-test.csv``, written by ``close``.

    Args:
        name: human-readable run name; defaults to ``'model'``.
        id: run identifier; a fresh ``uuid4()`` when omitted.
    """

    def __init__(self, name: Optional[str] = None, id: Optional[UUID] = None) -> None:
        self.id = id or uuid4()
        self.name = name or 'model'
        self.metrics = {
            'train': Metrics(),
            'test': Metrics()
        }
        # Assigned by open(); None until then so add_text()/close() can check it.
        self.writer = None

    def open(self):
        """Create the TensorBoard writer, attach it to both Metrics, and print run info."""
        self.writer = SummaryWriter(log_dir=f'logs/{self.name}-{self.id}')
        self.metrics['train'].writer = self.writer
        self.metrics['test'].writer = self.writer
        print(f"Running experiment {self.name} with id {self.id}")
        print(f"Tensorboard logs are saved in logs/{self.name}-{self.id}")
        print(f"Run tensorboard with: tensorboard --logdir=logs/")
        print(f"Open browser and go to: http://localhost:6006/")
        print(f"----------------------------------------------------------------")

    def close(self):
        """Print the last epoch's averages, write per-epoch CSVs, and close the writer."""
        print(f"Experiment {self.name} with id {self.id} completed")
        print(f"#### Results for {self.name}:")
        print(f"- Average loss: {self.metrics['train'].loss:.4f} (train), {self.metrics['test'].loss:.4f} (test)")
        print(f"- Average accuracy: {self.metrics['train'].accuracy:.4f} (train), {self.metrics['test'].accuracy:.4f} (test)")
        print(f"----------------------------------------------------------------")

        # Nothing in this file creates the results directory, so ensure it exists.
        os.makedirs('results', exist_ok=True)
        self.metrics['train'].write_to_csv(f'results/{self.name}-{self.id}-train.csv')
        self.metrics['test'].write_to_csv(f'results/{self.name}-{self.id}-test.csv')

        if self.writer is not None:
            self.writer.close()
            self.writer = None

    def add_text(self, tag: str, text: str):
        """Append ``tag: text`` to the run's parameter file, log it if a writer is open, and print it."""
        os.makedirs('parameters', exist_ok=True)
        with open(f'parameters/{self.name}-{self.id}.txt', 'a', encoding='utf-8') as f:
            f.write(f'{tag}: {text}\n')

        if self.writer is not None:
            self.writer.add_text(tag, text)

        print(f'{tag}: {text}')
        print(f"----------------------------------------------------------------")

def plot(metrics: Dict[str, Metrics], metric: str, ax: Optional[Axes] = None):
    """Draw one line per entry of ``metrics`` showing ``metric`` per epoch.

    Args:
        metrics: mapping of legend label (e.g. ``'train'``/``'test'``) to a
            ``Metrics`` object; ``history[metric]`` is plotted.
        metric: history key to plot (``Metrics`` records ``'loss'`` and ``'accuracy'``).
        ax: axes to draw on. When omitted, a new figure is created and shown;
            when given, the caller is responsible for displaying it.
    """
    # Remember whether this call owns the figure; ``ax`` is reassigned below.
    owns_figure = ax is None
    if owns_figure:
        fig = figure()
        ax = fig.add_subplot()

    for label, tracked in metrics.items():
        values = tracked.history[metric]
        # Epochs are numbered from 1 (matching Metrics.epoch and the CSV output).
        ax.plot(range(1, len(values) + 1), values, label=label)

    ax.legend()
    ax.set_title(metric)
    ax.set_xlabel('epoch')
    ax.set_ylabel(metric)

    if owns_figure:
        show()


def run(model: Module, optimizer: Optimizer, criterion: Criterion, device: str, data: Dict[str, Data], summary: Summary, epochs: int = 30):
    """Train and evaluate ``model`` for ``epochs`` epochs, then save and show a metrics plot.

    Args:
        model: network to train.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss function, forwarded to the project ``train``/``test`` helpers.
        device: device identifier forwarded to those helpers.
        data: must contain ``'train'`` and ``'test'`` entries.
        summary: run bookkeeping. It is opened here and closed after training;
            its metrics are plotted to ``plots/{name}-{id}.png``.
        epochs: number of train+test passes.
    """
    summary.open()
    summary.add_text('model', str(model))
    summary.add_text('optimizer', str(optimizer))
    summary.add_text('criterion', str(criterion))

    for _ in range(epochs):
        train(model, criterion, optimizer, data['train'], summary.metrics['train'], device)
        test(model, criterion, data['test'], summary.metrics['test'], device)

    summary.close()

    metrics_plot = figure(figsize=(10, 5))
    metrics_plot.suptitle(f'{summary.name}')
    plot(summary.metrics, 'loss', metrics_plot.add_subplot(1, 2, 1))
    plot(summary.metrics, 'accuracy', metrics_plot.add_subplot(1, 2, 2))

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs('plots', exist_ok=True)
    # Save this specific figure rather than whatever pyplot considers "current".
    metrics_plot.savefig(f'plots/{summary.name}-{summary.id}.png', bbox_inches="tight")
    show()