In [1]:
import math
from typing import Optional, Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

In [3]:
class PatchEmbeddings(nn.Module):
    """Turn a batch of square images into a sequence of patch embeddings.

    Each image is cut into non-overlapping ``patch_size x patch_size`` patches.
    Every patch is flattened and linearly projected to ``embed_dim``; a
    learnable class token is prepended and a learnable positional embedding is
    added to the whole sequence.

    Args:
        in_channels: Number of channels of the input images.
        patch_size: Side length, in pixels, of each square patch.
        embed_dim: Size of every output embedding vector.
        img_size: Side length, in pixels, of the square input images. It fixes
            the number of patches and therefore the length of ``pos_embed``.

    Raises:
        ValueError: If ``img_size`` is not a positive multiple of ``patch_size``.
    """

    def __init__(self, in_channels: int = 1, patch_size: int = 16, embed_dim: int = 768,
                 img_size: int = 224):
        super().__init__()
        if patch_size <= 0 or img_size <= 0 or img_size % patch_size != 0:
            raise ValueError(
                f"img_size ({img_size}) must be a positive multiple of patch_size ({patch_size})"
            )
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.img_size = img_size
        self.patches_per_side = img_size // patch_size
        self.num_patches = self.patches_per_side ** 2
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One position per patch plus one for the class token.
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.num_patches, embed_dim))
        self.proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim)

    @staticmethod
    def patchify(images: Tensor, n_patches: int) -> Tensor:
        """Split square images into flattened, non-overlapping patches.

        Args:
            images: Tensor of shape ``(n, c, h, w)`` with ``h == w``.
            n_patches: Number of patches along each side of the image.

        Returns:
            Tensor of shape ``(n, n_patches ** 2, c * (h // n_patches) ** 2)``.
            Patches are ordered row-major over the patch grid and each patch is
            flattened in ``(channel, row, column)`` order.

        Raises:
            ValueError: If ``h`` is not divisible by ``n_patches``.
        """
        n, c, h, w = images.shape
        assert h == w, "patchify expects square images"
        if n_patches <= 0 or h % n_patches != 0:
            raise ValueError(
                f"image side ({h}) must be divisible by n_patches ({n_patches})"
            )
        patch_size = h // n_patches

        # (n, c, grid_row, p, grid_col, p) -> (n, grid_row, grid_col, c, p, p)
        grid = images.reshape(n, c, n_patches, patch_size, n_patches, patch_size)
        grid = grid.permute(0, 2, 4, 1, 3, 5)
        return grid.reshape(n, n_patches ** 2, c * patch_size * patch_size)

    def forward(self, x: Tensor) -> Tensor:
        """Embed a batch of images.

        Args:
            x: Images of shape ``(batch, in_channels, img_size, img_size)``.

        Returns:
            Tensor of shape ``(batch, 1 + num_patches, embed_dim)``; index 0 on
            the sequence axis is the class token.

        Raises:
            ValueError: If ``x`` does not match the configured channel count or
                image size.
        """
        if x.dim() != 4:
            raise ValueError(f"expected a 4-D (batch, channels, height, width) tensor, got {x.dim()}-D")
        if x.shape[1] != self.in_channels:
            raise ValueError(f"expected {self.in_channels} channels, got {x.shape[1]}")
        if x.shape[2] != self.img_size or x.shape[3] != self.img_size:
            raise ValueError(
                f"expected {self.img_size}x{self.img_size} images, got {x.shape[2]}x{x.shape[3]}"
            )

        x = self.patchify(x, self.patches_per_side)
        x = self.proj(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = x + self.pos_embed
        return x

In [None]:
class TransformerModel(nn.Module):
    """Transformer encoder over image patch embeddings.

    Images are converted to a token sequence by a patch-embedding module, run
    through a stack of ``TransformerEncoderLayer`` blocks, and every resulting
    token is mapped to ``ntoken`` outputs by a final linear layer.

    Args:
        ntoken: Output size of the final linear layer.
        d_model: Embedding size used throughout the encoder.
        nhead: Number of attention heads per encoder layer.
        d_hid: Hidden size of the feed-forward sub-layer.
        nlayers: Number of stacked encoder layers.
        patch_embed: Optional module mapping images to a
            ``(batch, seq, d_model)`` tensor. When omitted, a
            ``PatchEmbeddings`` with ``embed_dim=d_model`` and otherwise
            default settings is created.
    """

    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
                    nlayers: int, patch_embed: Optional[nn.Module] = None):
        super().__init__()
        self.model_type = 'Transformer'
        self.d_model = d_model
        # The embedder must emit d_model-sized vectors to match the encoder.
        if patch_embed is None:
            patch_embed = PatchEmbeddings(embed_dim=d_model)
        self.patch_embed = patch_embed
        # batch_first=True because the patch embedder emits (batch, seq, dim).
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, batch_first=True)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.decoder = nn.Linear(d_model, ntoken)
        self.init_weights()

    def init_weights(self) -> None:
        """Initialise the output layer: uniform weights in [-0.1, 0.1], zero bias."""
        init_range = 0.1
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -init_range, init_range)

    def forward(self, images: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
        """Encode a batch of images.

        Args:
            images: Input accepted by ``self.patch_embed``.
            src_mask: Optional attention mask forwarded to the encoder, e.g.
                the output of ``generate_square_subsequent_mask``. ``None``
                lets every token attend to every other token.

        Returns:
            Tensor of shape ``(batch, seq, ntoken)``, where ``seq`` is the
            sequence length produced by ``self.patch_embed``.
        """
        # Positional information is expected to be added by the patch embedder,
        # so no separate positional encoder is applied here.
        src = self.patch_embed(images)
        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)
        return output

    @staticmethod
    def generate_square_subsequent_mask(sz: int) -> Tensor:
        """Generates an upper-triangular matrix of -inf, with zeros on diag."""
        return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
    

