In [2]:
import math
from typing import Tuple 

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.nn import TransformerDecoder, TransformerDecoderLayer
from torch.utils.data import dataset

In [3]:
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 1, img_size: int = 32, patch_size: int = 16, embed_dim: int = 768):
        super().__init__()
        self.patch_size = patch_size
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros((img_size // patch_size) **2 + 1, embed_dim))
        self.proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim)
        
    def patchify(self, images: Tensor, patch_size: int):
        n, h, w = images.shape
        assert h == w 
        n_patches = h // patch_size
        patches = torch.zeros(n, n_patches ** 2, patch_size ** 2)
#         print(f'patches  size is {patch_size}')
        for idx, image in enumerate(images):
            for i in range(n_patches):
                for j in range(n_patches):
                    patch = image[i * patch_size: (i + 1) * patch_size, j * patch_size: (j + 1) * patch_size]
#                     print(f'patch shape is: {patch.shape}')
#                     print(f'flattened patch shape is: {patch.flatten().shape}')
                    patches[idx, i * n_patches + j] = patch.flatten()
        return patches
        
    def forward(self, x: Tensor):
        x = self.patchify(x, self.patch_size)
#         print(x.shape)
        x = self.proj(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = x + self.pos_embed
        return x

In [4]:
class TransformerModel(nn.Module):
    
    def __init__(self, d_model: int, nhead: int, d_hid: int,
                    nlayers: int, num_tokens: int, in_channels: int = 1,
                    img_size: int = 384, patch_size: int = 16, embed_dim: int = 384):
        super().__init__()
        self.d_model = d_model
        self.model_type = 'Transformer'
        self.patch_embed = PatchEmbedding(in_channels=1, img_size=img_size, 
                                          patch_size=patch_size, 
                                          embed_dim=embed_dim)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        decoder_layers = TransformerDecoderLayer(d_model, nhead, d_hid)
        self.transformer_decoder = TransformerDecoder(decoder_layers, nlayers)
        self.linear = nn.Linear(d_model, num_tokens)
        self.out = nn.Softmax(dim=2)
        self.init_weights()
        
    def init_weights(self) -> None:
        init_range = 0.1
#         self.encoder.weight.data.uniform_(-init_range, init_range)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-init_range, init_range)
        
    def forward(self, x: Tensor, masked_y: Tensor) -> Tensor:
        x = self.patch_embed(x)
        x = self.transformer_encoder(x)
        x = self.transformer_decoder(masked_y, x)
        x = self.linear(x)
        print(ff x.shape)
        x = self.out(x)
        return x

def generate_square_subsequent_mask(sz: int) -> Tensor:
    """Generates an upper-triangular matrix of -inf, with zeros on diag."""
    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

In [6]:
model = TransformerModel(d_model=384, nhead=6, d_hid=8, 
                         nlayers=12, num_tokens=10,
                         in_channels=1, img_size=384,
                         patch_size=16, embed_dim=384)
x, y = torch.randn(7, 384, 384), torch.randn(7, 36)
model(x, y).shape

AssertionError: was expecting embedding dimension of 384, but got 36