In [317]:
from typing import Union

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

In [318]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    A single fused linear layer produces queries, keys and values. They are split
    into ``nhead`` heads, optionally masked, and merged back through an output
    projection followed by dropout.

    Args:
        embed_dim: Width of the input/output features; must be divisible by ``nhead``.
        nhead: Number of attention heads.
        pdrop: Dropout probability applied to the projected output.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``nhead``.
    """

    def __init__(self, embed_dim, nhead, pdrop):
        super(MultiHeadAttention, self).__init__()
        # Fail fast: otherwise the head split only breaks inside `rearrange` at forward time.
        if embed_dim % nhead != 0:
            raise ValueError(
                f'embed_dim ({embed_dim}) must be divisible by nhead ({nhead})'
            )

        self.nhead = nhead

        # Fused projection: q, k and v are laid out side by side on the last dim.
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.c_proj = nn.Linear(embed_dim, embed_dim)
        self.drop = nn.Dropout(pdrop)

    def forward(self, x, attn_mask=None):
        """Attend over the time axis of ``x``.

        Args:
            x: Tensor of shape (batch, time, embed_dim).
            attn_mask: Optional 2-D mask, sliced to ``[:time, :time]``; positions
                where it equals 0 are blocked (set to -inf before softmax).

        Returns:
            Tensor of shape (batch, time, embed_dim).
        """
        _, t, _ = x.shape

        qkv = self.qkv(x).chunk(3, dim=-1)
        q, k, v = (
            rearrange(tsn, 'b t (nh hd) -> b nh t hd', nh=self.nhead) for tsn in qkv
        )

        # Scaled dot-product scores, scaled by 1/sqrt(head_dim).
        wei = q @ k.transpose(-1, -2) * k.size(dim=-1) ** -0.5
        if attn_mask is not None:
            # Move the mask next to the scores; it is a no-op when it is already there.
            mask = attn_mask[:t, :t].to(wei.device)
            wei = wei.masked_fill(mask == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)

        attn = wei @ v
        attn = rearrange(attn, 'b nh t hd -> b t (nh hd)')

        return self.drop(self.c_proj(attn))

In [319]:
class FFN(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout.

    The hidden layer is four times as wide as ``embed_dim``.

    Args:
        embed_dim: Input and output feature width.
        pdrop: Dropout probability applied after the second linear layer.
    """

    def __init__(self, embed_dim, pdrop):
        super().__init__()
        hidden_dim = 4 * embed_dim
        # Layers are created in the same order as before, so seeded initialisation
        # and the state_dict keys (net.0.*, net.2.*) are unchanged.
        layers = [
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(pdrop),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the feed-forward stack over the last dimension of ``x``."""
        out = self.net(x)
        return out

class Block(nn.Module):
    """Pre-norm transformer layer.

    Applies LayerNorm -> attention and then LayerNorm -> feed-forward, each
    wrapped in a residual connection.

    Args:
        embed_dim: Feature width shared by every sub-layer.
        nhead: Number of attention heads.
        attn_pdrop: Dropout probability given to the attention sub-layer.
        ffn_pdrop: Dropout probability given to the feed-forward sub-layer.
    """

    def __init__(self,
                 # Common
                 embed_dim,
                 # MHA
                 nhead,
                 attn_pdrop,
                 # FFN
                 ffn_pdrop
                 ):
        super().__init__()
        # Registration order (ln1, ln2, attn, ffn) is kept so state_dict key order
        # and seeded parameter initialisation are unchanged.
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, nhead, attn_pdrop)
        self.ffn = FFN(embed_dim, ffn_pdrop)

    def forward(self, x, attn_mask=None):
        """Run attention then feed-forward, each as a residual update of ``x``."""
        attended = self.attn(self.ln1(x), attn_mask=attn_mask)
        hidden = x + attended
        transformed = self.ffn(self.ln2(hidden))
        return hidden + transformed

In [320]:
class TransformerLanguageModel(nn.Module):
    """Decoder-style transformer over token ids.

    With ``output_dim=None`` the model is a language model: ``forward`` returns
    per-position vocabulary logits, and ``generate`` can sample continuations.

    With an integer ``output_dim`` it acts as a sequence encoder ("understanding"
    mode). ``forward`` then returns one ``output_dim`` vector per sequence, taken
    at the position of the largest token id in each row.
    NOTE(review): this presumably relies on the end-of-text token having the
    highest id (CLIP convention); confirm against the tokenizer.

    Args:
        vocab_size: Number of token ids.
        ctx_len: Maximum sequence length (size of the positional table).
        embed_dim: Model width.
        nlayer: Number of transformer blocks.
        nhead: Attention heads per block.
        attn_mask: Default 2-D mask; entries equal to 0 are blocked. Defaults to a
            causal lower-triangular (ctx_len, ctx_len) mask.
        attn_pdrop: Dropout probability inside attention.
        ffn_pdrop: Dropout probability inside the feed-forward layers.
        output_dim: If given, switches to understanding mode (see above).
    """

    def __init__(self,
                 # LM
                 vocab_size,
                 ctx_len,
                 # Common
                 embed_dim,
                 # Block
                 nlayer,
                 nhead,
                 attn_mask=None,
                 attn_pdrop=0.2,
                 ffn_pdrop=0.2,
                 # Output
                 output_dim: int=None
                 ):
        super(TransformerLanguageModel, self).__init__()
        # Properties
        self.ctx_len = ctx_len
        if attn_mask is None:
            attn_mask = torch.tril(torch.ones(ctx_len, ctx_len))
        # A non-persistent buffer follows `.to(device)` / `.cuda()` with the
        # parameters, but adds no state_dict key, so old checkpoints still load.
        self.register_buffer('attn_mask', torch.as_tensor(attn_mask), persistent=False)
        self.understanding = output_dim is not None

        # Comps
        self.tok_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(ctx_len, embed_dim)

        self.ln_pre = nn.LayerNorm(embed_dim)
        self.blocks = nn.ModuleList(
            [Block(embed_dim, nhead, attn_pdrop, ffn_pdrop) for _ in range(nlayer)]
        )
        self.ln_post = nn.LayerNorm(embed_dim)

        self.lm_head = nn.Linear(embed_dim, vocab_size if output_dim is None else output_dim)

    @torch.no_grad()
    def generate(self, idx, attn_mask=None, max_new_tokens=100, temp=1.0, topk=None):
        """Autoregressively extend ``idx`` by ``max_new_tokens`` tokens.

        Args:
            idx: Long tensor of shape (batch, time) holding the prompt.
            attn_mask: Optional mask forwarded to ``forward``.
            max_new_tokens: Number of tokens to append.
            temp: Sampling temperature; ``0`` selects greedy (argmax) decoding.
            topk: If given, sample only among the ``topk`` highest logits of each
                row. The value is clamped to the vocabulary size.

        Returns:
            Tensor of shape (batch, time + max_new_tokens).

        Raises:
            NotImplementedError: In understanding mode (``output_dim`` was set).
        """
        if self.understanding:
            raise NotImplementedError('Not for generating')
        for _ in range(max_new_tokens):
            # Only the most recent ctx_len tokens fit the positional table.
            idx_cond = idx[:, -self.ctx_len:]
            logits = self(idx_cond, attn_mask)
            logits_of_last_tick = logits[:, -1, :]
            if temp == 0:
                # Greedy decoding. Dividing by temp here would produce inf/NaN.
                idx_next = torch.argmax(logits_of_last_tick, dim=-1, keepdim=True)
            else:
                logits_of_last_tick = logits_of_last_tick / temp
                if topk is not None:
                    k = min(topk, logits_of_last_tick.size(-1))
                    vals, _ = torch.topk(logits_of_last_tick, k)
                    # Per-row threshold: vals[:, [-1]] has shape (batch, 1) and
                    # broadcasts across the vocabulary dimension.
                    logits_of_last_tick = logits_of_last_tick.masked_fill(
                        logits_of_last_tick < vals[:, [-1]], float('-inf')
                    )
                probs = logits_of_last_tick.softmax(dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)  # cat on 'time' dimension
        return idx

    def forward(self, idx, attn_mask=None):
        """Run the transformer on token ids.

        Args:
            idx: Long tensor of shape (batch, time) with ``time <= ctx_len``.
            attn_mask: Optional mask that overrides the default one for this call.

        Returns:
            (batch, time, vocab_size) logits in LM mode, or
            (batch, output_dim) features in understanding mode.

        Raises:
            ValueError: If ``time`` exceeds ``ctx_len``.
        """
        attn_mask = attn_mask if attn_mask is not None else self.attn_mask

        _, t = idx.shape
        if t > self.ctx_len:
            raise ValueError(f'sequence length {t} exceeds ctx_len {self.ctx_len}')
        # Keep the mask next to the activations; a no-op when it already is.
        attn_mask = attn_mask.to(idx.device)

        tok_emb = self.tok_emb(idx)
        pos_emb = self.pos_emb(
            torch.arange(0, t, dtype=torch.long, device=idx.device)
        )

        x = tok_emb + pos_emb

        x = self.ln_pre(x)
        for block in self.blocks:
            x = block(x, attn_mask=attn_mask)
        x = self.ln_post(x)

        logits = self.lm_head(x)

        if self.understanding:
            # Pick, for each row, the output at the position of its largest token id.
            rows = torch.arange(logits.shape[0], device=logits.device)
            logits = logits[rows, idx.argmax(dim=-1)]

        return logits

In [321]:
def pair(sz: Union[int, tuple]):
    """Expand an int into ``(sz, sz)``; return any other value unchanged."""
    return (sz, sz) if isinstance(sz, int) else sz

class VisionTransformer(nn.Module):
    """ViT image encoder producing one ``num_classes``-wide vector per image.

    NOTE: despite its name, ``patch_num`` is used as the patch *size* in pixels,
    because it is the conv kernel and stride. The resulting grid has
    ``img_size // patch_num`` patches per side and is stored in
    ``self.patch_height`` / ``self.patch_width``.

    Args:
        img_size: Input image size, int or (height, width).
        patch_num: Patch size in pixels, int or (height, width); see the note above.
        use_cls_tok: If True, pool with the learned class token; otherwise
            mean-pool the patch tokens.
        nlayer: Number of transformer blocks.
        embed_dim: Token width.
        nhead: Attention heads per block.
        attn_pdrop: Dropout probability inside attention.
        ffn_pdrop: Dropout probability inside the feed-forward layers.
        num_classes: Width of the final projection.
        in_channels: Number of input image channels (defaults to 3, as before).
    """

    def __init__(self,
                 # Vision
                 img_size: Union[int, tuple],
                 patch_num: Union[int, tuple],
                 use_cls_tok: bool,
                 # Transformer
                 nlayer: int,
                 embed_dim: int,
                 nhead: int,
                 attn_pdrop: float,
                 ffn_pdrop: float,
                 # Output
                 num_classes: int,
                 in_channels: int = 3
                 ):
        super(VisionTransformer, self).__init__()
        # Properties
        img_height, img_width = pair(img_size)
        patch_height_n, patch_width_n = pair(patch_num)
        scale = embed_dim ** -0.5

        # Grid dimensions (number of patches per side); see the class note.
        self.patch_height, self.patch_width = img_height // patch_height_n, img_width // patch_width_n
        self.use_cls_tok = use_cls_tok

        # Comps
        self.to_patch = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=(patch_height_n, patch_width_n),
            stride=(patch_height_n, patch_width_n),
        )

        # One positional slot per patch plus one for the class token (index 0).
        self.pos_emb = nn.Parameter(scale * torch.randn(1, self.patch_height * self.patch_width + 1, embed_dim))
        self.cls_emb = nn.Parameter(scale * torch.randn(embed_dim))

        self.ln_pre = nn.LayerNorm(embed_dim)
        self.blocks = nn.ModuleList(
            [Block(embed_dim, nhead, attn_pdrop, ffn_pdrop) for _ in range(nlayer)]
        )
        self.ln_post = nn.LayerNorm(embed_dim)

        self.proj = nn.Parameter(scale * torch.randn(embed_dim, num_classes))

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Tensor of shape (batch, in_channels, height, width).
                NOTE(review): the spatial size must match ``img_size``, or the
                grid check below fails.

        Returns:
            Tensor of shape (batch, num_classes).
        """
        # (b, embed_dim, grid_h, grid_w): one embedding per patch.
        patches = self.to_patch(x)

        # Flatten the grid into a token sequence. Naming ph/pw makes einops check
        # that the grid matches the size the positional table was built for.
        tokens = rearrange(
            patches, 'b c ph pw -> b (ph pw) c', ph=self.patch_height, pw=self.patch_width
        )

        # Prepend the class token so it sits at index 0, which is where the
        # pooling below (x[:, 0] vs. x[:, 1:]) expects it.
        cls_tok = repeat(self.cls_emb, 'e -> b 1 e', b=tokens.shape[0])
        x = torch.cat((cls_tok, tokens), dim=1) + self.pos_emb

        x = self.ln_pre(x)
        for block in self.blocks:
            x = block(x)

        if self.use_cls_tok:
            x = x[:, 0, :]
        else:
            # Mean over the patch tokens only (excludes the class token).
            x = x[:, 1:, :].mean(dim=1)

        x = self.ln_post(x)

        return x @ self.proj

In [322]:
class CLIP(nn.Module):
    """Contrastive image-text model wrapping an image encoder and a text encoder.

    Both encoders must produce feature vectors of the same width. ``forward``
    returns scaled cosine-similarity logits between every image and every text.

    Args:
        img_encoder: Module mapping a batch of images to (n_img, d) features.
        text_encoder: Module mapping a batch of texts to (n_text, d) features.
    """

    def __init__(self,
                 img_encoder,
                 text_encoder,
                 ):
        super(CLIP, self).__init__()
        # Learnable logit scale stored in log space, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.img_encoder = img_encoder
        self.text_encoder = text_encoder

    def encode_img(self, img):
        """Return image features from the wrapped image encoder."""
        return self.img_encoder(img)

    def encode_text(self, text):
        """Return text features from the wrapped text encoder."""
        return self.text_encoder(text)

    def forward(self, imgs, texts):
        """Compute image-text similarity logits.

        Returns:
            ``(logits_per_img, logits_per_text)`` with shapes (n_img, n_text)
            and (n_text, n_img).

        Raises:
            ValueError: If the image and text feature widths differ.
        """
        img_features = self.encode_img(imgs)
        text_features = self.encode_text(texts)

        # Only the feature width has to agree. The batch sizes may differ, e.g.
        # when one image is scored against many candidate captions.
        if img_features.shape[-1] != text_features.shape[-1]:
            raise ValueError(f'img_features shape {img_features.shape} != text_features shape {text_features.shape}')

        # Unit-normalise so the dot product is cosine similarity; eps guards zero vectors.
        img_features = F.normalize(img_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)

        logits_per_img = img_features @ text_features.t() * self.logit_scale.exp()
        logits_per_text = logits_per_img.t()

        return logits_per_img, logits_per_text