<a href="https://colab.research.google.com/github/ell-hol/stonks-wid-codex/blob/main/codex_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.autograd import Variable


class Attention(nn.Module):
    def __init__(self, dim: int) -> None:
        super(Attention, self).__init__()

        self._dim = dim

    def forward(
        self,
        query: Tensor,
        keys: Tensor,
        values: Tensor,
        mask: Tensor = None
    ) -> Tensor:

        # score = Softmax(Q.K / sqrt(dim)) * V
        score = torch.matmul(query, keys.transpose(1, 2))
        score = score / (self._dim ** 0.5)

        if mask is not None:
            score = score.masked_fill(mask == 0, -1e9)

        score = F.softmax(score, dim=-1)

        return torch.matmul(score, values)


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        in_dim: int,
        head_count: int,
        dim_per_head: int,
        dropout: float = 0.0
    ) -> None:

        super(MultiHeadAttention, self).__init__()

        self._query_linear = nn.Linear(in_dim, head_count * dim_per_head)
        self._keys_linear = nn.Linear(in_dim, head_count * dim_per_head)
        self._value_linear = nn.Linear(in_dim, head_count * dim_per_head)

        self._multi_head_attention = Attention(dim_per_head)
        self._linear_final = nn.Linear(head_count * dim_per_head, in_dim)
        self._dropout = nn.Dropout(dropout)
        self._layer_norm = nn.LayerNorm(in_dim)

    def forward(
        self,
        query: Tensor,
        keys: Tensor,
        values: Tensor,
        mask: Tensor = None
    ) -> Tensor:

        batch_size = query.size(0)

        query = self._query_linear(query).view(batch_size, -1, self._multi_head_attention._dim)
        keys = self._keys_linear(keys).view(batch_size, -1, self._multi_head_attention._dim)
        values = self._value_linear(values).view(batch_size, -1, self._multi_head_attention._dim)

        query = query.transpose(0, 1)
        keys = keys.transpose(0, 1)
        values = values.transpose(0, 1)

        if mask is not None:
            mask = mask.unsqueeze(1)

        context = self._multi_head_attention(query, keys, values, mask)
        context = context.transpose(0, 1)

        output = self._linear_final(context)
        output = self._dropout(output)

        # residual connection
        output = self._layer_norm(query + output)

        return output

In [12]:
"""
Define a Transformer model that uses the defined multi-headed attention module
"""
class Transformer(nn.Module):
    def __init__(
        self,
        n_vocab: int,
        dim_embedding: int,
        dim_model: int,
        dim_ff: int,
        head_count: int,
        n_layers: int,
        max_len: int,
        dropout: float = 0.0
    ) -> None:

        super(Transformer, self).__init__()

        self._n_vocab = n_vocab
        self._dim_embedding = dim_embedding
        self._dim_model = dim_model
        self._dim_ff = dim_ff
        self._head_count = head_count
        self._n_layers = n_layers
        self._max_len = max_len

        self._embedding = nn.Embedding(n_vocab, dim_embedding)

        self._postion_embedding = nn.Embedding(self._max_len, dim_model)
        self._postion_embedding.weight.data = self.position_encoding_init(self._max_len, self._dim_model)

        self._encoder = nn.Sequential(
            *[
                Encoder(dim_model, dim_ff, head_count, dropout, dim_per_head=16)
                for _ in range(self._n_layers)
            ]
        )

        self._decoder = nn.Sequential(
            *[
                Decoder(dim_model, dim_ff, head_count, dropout)
                for _ in range(self._n_layers)
            ]
        )

        self._linear = nn.Linear(dim_model, n_vocab)

        self._layer_norm = nn.LayerNorm(dim_model)

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    @staticmethod
    def position_encoding_init(n_position: int, dim_pos_embed: int) -> Tensor:

        encoding = np.array([
            [pos / np.power(10000, 2.0 * (j // 2) / dim_pos_embed) for j in range(dim_pos_embed)]
            if pos != 0 else np.zeros(dim_pos_embed) for pos in range(n_position)
        ])

        encoding[1:, 0::2] = np.sin(encoding[1:, 0::2])
        encoding[1:, 1::2] = np.cos(encoding[1:, 1::2])

        return torch.tensor(encoding.astype('float32'))

    def mask_pad(self, x: Tensor) -> Tensor:

        return (x != 0).unsqueeze(-2)

    def forward(self, x: Tensor, mask: Tensor) -> Tensor:

        batch_size, _ = x.size()

        pos = torch.arange(0, self._max_len).repeat(batch_size, 1)
        mask_source = mask.clone().fill_(0)

        x = self._embedding(x) + self._postion_embedding(pos)

        for encoder in self._encoder:
            x = encoder(x, mask, mask_source)

        mask = mask_source

        for decoder in self._decoder:
            x = decoder(x, mask)

        return self._linear(self._layer_norm(x))

In [13]:
"""
Define the Encoder classe used in the Transformer class.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class Encoder(nn.Module):
    def __init__(
        self,
        dim_model: int,
        dim_per_head: int,
        dim_ff: int,
        head_count: int,
        dropout: float = 0.0
    ) -> None:

        super(Encoder, self).__init__()

        self._self_attention = MultiHeadAttention(dim_model, head_count, dim_per_head=dim_per_head)
        self._feed_forward = PositionwiseFeedForward(dim_model, dim_ff, dropout)
        self._layer_norm1 = nn.LayerNorm(dim_model)
        self._layer_norm2 = nn.LayerNorm(dim_model)
        self._dropout1 = nn.Dropout(dropout)
        self._dropout2 = nn.Dropout(dropout)

    def forward(
        self,
        x: Tensor,
        mask: Tensor,
        mask_source: Tensor = None
    ) -> Tensor:

        query = self._layer_norm1(x)
        query = self._self_attention(query, query, query, mask)
        query = self._dropout1(query)
        query += x

        query_ff = self._layer_norm2(query)
        query_ff = self._feed_forward(query_ff)
        query_ff = self._dropout2(query_ff)
        query += query_ff

        return query
"""
Do the same for the Decoder class.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class Decoder(nn.Module):
    def __init__(
        self,
        dim_model: int,
        dim_ff: int,
        head_count: int,
        dropout: float = 0.0
    ) -> None:

        super(Decoder, self).__init__()

        self._self_attention = MultiHeadAttention(dim_model, head_count)
        self._encoder_decoder_attention = MultiHeadAttention(dim_model, head_count)
        self._feed_forward = PositionwiseFeedForward(dim_model, dim_ff, dropout)
        self._layer_norm1 = nn.LayerNorm(dim_model)
        self._layer_norm2 = nn.LayerNorm(dim_model)
        self._layer_norm3 = nn.LayerNorm(dim_model)
        self._dropout1 = nn.Dropout(dropout)
        self._dropout2 = nn.Dropout(dropout)
        self._dropout3 = nn.Dropout(dropout)

    def forward(
        self,
        x: Tensor,
        mask: Tensor,
        mask_source: Tensor = None
    ) -> Tensor:

        query = self._layer_norm1(x)
        query = self._self_attention(query, query, query, mask)
        query = self._dropout1(query)
        query += x

        x = self._layer_norm2(query)
        x = self._encoder_decoder_attention(x, mask_source, mask_source)
        x = self._dropout2(x)
        x += query

        query_ff = self._layer_norm3(x)
        query_ff = self._feed_forward(query_ff)
        query_ff = self._dropout3(query_ff)
        x += query_ff

        return x


In [14]:
import torch
import torchvision
from torch import Tensor
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm


def accuracy(y_true: Tensor, y_pred: Tensor) -> float:

    y_true = y_true.cpu().numpy()
    y_pred = y_pred.cpu().numpy()

    return np.sum(y_true == y_pred) / len(y_true)


def precision(y_true: Tensor, y_pred: Tensor, average: str = 'macro') -> float:

    y_true = y_true.cpu().numpy()
    y_pred = y_pred.cpu().numpy()

    return sklearn.metrics.precision_score(y_true, y_pred, average=average)


def recall(y_true: Tensor, y_pred: Tensor, average: str = 'macro') -> float:

    y_true = y_true.cpu().numpy()
    y_pred = y_pred.cpu().numpy()

    return sklearn.metrics.recall_score(y_true, y_pred, average=average)


def f1(y_true: Tensor, y_pred: Tensor, average: str = 'macro') -> float:

    p = precision(y_true, y_pred, average=average)
    r = recall(y_true, y_pred, average=average)

    return 2 * p * r / (p + r)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor()
])

dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    transform=transforms,
    download=True
)

loader = DataLoader(
    dataset=dataset,
    batch_size=32,
    num_workers=4,
    shuffle=True
)

model = Transformer(
    n_vocab=10,
    dim_embedding=512,
    dim_model=512,
    dim_ff=2048,
    head_count=8,
    n_layers=6,
    max_len=32,
    dropout=0.1
).to(device)

criterion = CrossEntropyLoss()
optimizer = Adam(params=model.parameters())

epochs = 100

for epoch in tqdm(range(epochs)):

    avg_loss = 0
    avg_acc = 0
    avg_precision = 0
    avg_recall = 0
    avg_f1 = 0

    for batch in loader:

        optimizer.zero_grad()

        x = batch[0].to(device)
        y = batch[1].to(device)

        mask = model.mask_pad(x)
        y_pred = model(x, mask)

        loss = criterion(y_pred.view(-1, y_pred.size(-1)), y.view(-1))
        loss.backward()

        optimizer.step()

        avg_loss += loss.item() / len(loader)

        with torch.no_grad():

            y = y.cpu().numpy()
            y_pred = torch.argmax(y_pred, axis=-1).cpu().numpy()

            avg_acc += accuracy(y, y_pred) / len(loader)
            avg_precision += precision(y, y_pred) / len(loader)
            avg_recall += recall(y, y_pred) / len(loader)
            avg_f1 += f1(y, y_pred) / len(loader)

    print('[%d / %d] loss: %.4f acc: %.4f p: %.4f r: %.4f f1: %.4f' % (epoch + 1, epochs, avg_loss, avg_acc, avg_precision, avg_recall, avg_f1))


Files already downloaded and verified


  cpuset_checked))


TypeError: ignored