In [221]:
from functools import lru_cache
from typing import Union

import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from pandas import Series

from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets

from einops import rearrange, repeat
from transformers import AutoTokenizer

In [222]:
import os

os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"

In [223]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, nhead, pdrop):
        super(MultiHeadAttention, self).__init__()

        self.nhead = nhead

        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.c_proj = nn.Linear(embed_dim, embed_dim)
        self.drop = nn.Dropout(pdrop)

    def forward(self, x, attn_mask=None):
        _, t, _ = x.shape

        qkv = self.qkv(x).chunk(3, dim=-1)

        q, k, v = map(
            lambda tsn: rearrange(tsn, 'b t (nh hd) -> b nh t hd', nh=self.nhead), qkv
        )

        wei = q @ k.transpose(-1, -2) * k.size(dim=-1) ** -0.5
        if attn_mask is not None:
            attn_mask = attn_mask.to(x.device)
            wei = wei.masked_fill(attn_mask[:t, :t] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)

        attn = wei @ v
        attn = rearrange(attn, 'b nh t hd -> b t (nh hd)')

        return self.drop(self.c_proj(attn))

In [224]:
class FFN(nn.Module):
    def __init__(self, embed_dim, pdrop):
        super(FFN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(pdrop),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    def __init__(self,
                 # Common
                 embed_dim,
                 # MHA
                 nhead,
                 attn_pdrop,
                 # FFN
                 ffn_pdrop
                 ):
        super(Block, self).__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, nhead, attn_pdrop)
        self.ffn = FFN(embed_dim, ffn_pdrop)

    def forward(self, x, attn_mask=None):
        x = x + self.attn(self.ln1(x), attn_mask=attn_mask)
        x = x + self.ffn(self.ln2(x))
        return x

In [225]:
class TransformerLanguageModel(nn.Module):
    def __init__(self,
                 # LM
                 vocab_size,
                 ctx_len,
                 # Common
                 embed_dim,
                 # Block
                 nlayer,
                 nhead,
                 attn_mask=None,
                 attn_pdrop=0.2,
                 ffn_pdrop=0.2,
                 # Output
                 output_dim: int=None
                 ):
        super(TransformerLanguageModel, self).__init__()
        # Properties
        self.ctx_len = ctx_len
        if attn_mask is None:
            attn_mask = torch.tril(torch.ones(ctx_len, ctx_len))
        self.attn_mask = attn_mask
        self.understanding = output_dim is not None

        # Comps
        self.tok_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(ctx_len, embed_dim)

        self.ln_pre = nn.LayerNorm(embed_dim)
        self.blocks = nn.ModuleList(
            [Block(embed_dim, nhead, attn_pdrop, ffn_pdrop) for _ in range(nlayer)]
        )
        self.ln_post = nn.LayerNorm(embed_dim)

        self.lm_head = nn.Linear(embed_dim, vocab_size if output_dim is None else output_dim)

    def generate(self, idx, attn_mask=None, max_new_tokens=100, temp=1.0, topk=None):
        if self.understanding:
            raise NotImplementedError('Not for generating')
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.ctx_len:]
            logits = self(idx_cond, attn_mask)
            logits_of_last_tick = logits[:, -1, :] / temp
            if topk is not None:
                vals, _ = torch.topk(logits_of_last_tick, topk)
                logits_of_last_tick[logits_of_last_tick < vals[-1]] = float('-inf')
            if temp == 0:
                idx_next = torch.argmax(logits_of_last_tick, dim=-1, keepdim=True)
            else:
                probs = logits_of_last_tick.softmax(dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1) # concat on 'time' dimension
        return idx

    def forward(self, idx, attn_mask=None):
        attn_mask = attn_mask if attn_mask is not None else self.attn_mask

        _, t = idx.shape

        tok_emb = self.tok_emb(idx)
        pos_emb = self.pos_emb(
            torch.arange(0, t, dtype=torch.long, device=idx.device)
        )

        x = tok_emb + pos_emb

        x = self.ln_pre(x)
        for block in self.blocks:
            x = block(x, attn_mask=attn_mask)
        x = self.ln_post(x)

        logits = self.lm_head(x)

        if self.understanding:
            logits = logits[torch.arange(logits.shape[0]), idx.argmax(dim=-1)]

        return logits

In [226]:
def pair(sz: Union[int, tuple]):
    if isinstance(sz, int):
        return sz, sz
    else:
        return sz

class VisionTransformer(nn.Module):
    def __init__(self,
                 # Vision
                 img_size: Union[int, tuple],
                 patch_num: Union[int, tuple],
                 use_cls_tok: bool,
                 # Transformer
                 nlayer: int,
                 embed_dim: int,
                 nhead: int,
                 attn_pdrop: float,
                 ffn_pdrop: float,
                 # Output
                 num_classes: int
                 ):
        super(VisionTransformer, self).__init__()
        # Properties
        img_height, img_width = pair(img_size)
        patch_height_n, patch_width_n = pair(patch_num)
        scale = embed_dim ** -0.5

        self.patch_height, self.patch_width = img_height // patch_height_n, img_width // patch_width_n
        self.use_cls_tok = use_cls_tok

        # Comps
        self.to_patch = nn.Conv2d(
            in_channels=3,
            out_channels=embed_dim,
            kernel_size=(patch_height_n, patch_width_n),
            stride=(patch_height_n, patch_width_n),
        )

        self.pos_emb = nn.Parameter(scale * torch.randn(1, self.patch_height * self.patch_width + 1, embed_dim))
        self.cls_emb = nn.Parameter(scale *  torch.randn(embed_dim))

        self.ln_pre = nn.LayerNorm(embed_dim)
        self.blocks = nn.ModuleList(
            [Block(embed_dim, nhead, attn_pdrop, ffn_pdrop) for _ in range(nlayer)]
        )
        self.ln_post = nn.LayerNorm(embed_dim)

        self.proj = nn.Parameter(scale * torch.randn(embed_dim, num_classes))

    def forward(self, x):
        patches = self.to_patch(x)

        idx = rearrange(
            patches, 'b c (ph h) (pw w) -> b (ph pw) (c h w)', ph=self.patch_height, pw=self.patch_width
        )


        idx = torch.cat(
            (
                idx,
                repeat(self.cls_emb, 'e -> b 1 e', b=idx.shape[0])
            ), dim=1
        )
        x = idx + self.pos_emb


        x = self.ln_pre(x)
        for block in self.blocks:
            x = block(x)

        if self.use_cls_tok:
            x = x[:, 0, :]
        else:
            x = x[:, 1:, :].mean(dim=1)

        x = self.ln_post(x)

        return x @ self.proj

In [227]:
class CLIP(nn.Module):
    def __init__(self,
                 img_encoder,
                 text_encoder,
                 ):
        super(CLIP, self).__init__()
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.img_encoder = img_encoder
        self.text_encoder = text_encoder

    def encode_img(self, img):
        return self.img_encoder(img)

    def encode_text(self, text):
        return self.text_encoder(text)

    def forward(self, imgs, texts):
        img_features = self.encode_img(imgs)
        text_features = self.encode_text(texts)

        if img_features.shape != text_features.shape:
            raise ValueError(f'img_features shape {img_features.shape} != text_features shape {text_features.shape}')

        img_features = img_features / img_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        logits_per_img = img_features @ text_features.t() * self.logit_scale.exp()
        logits_per_text = logits_per_img.t()

        return logits_per_img, logits_per_text

In [228]:
ctx = Series()

ctx.train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomErasing(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
])

ctx.test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
])

train_cifar10_ds = datasets.CIFAR10(root='~/data', train=True, transform=ctx.train_transform)
test_cifar10_ds = datasets.CIFAR10(root='~/data', train=True, transform=ctx.test_transform)

In [229]:
ctx.device = 'cuda' if torch.cuda.is_available() else 'cpu'
ctx.tokenizer = AutoTokenizer.from_pretrained('gpt2')
ctx.tokenizer.pad_token = ctx.tokenizer.eos_token
ctx.label2cls_dict = train_cifar10_ds.classes
ctx.ctx_len = 5
ctx.batch_size = 256

def label2sentence(label, prefix='The photo of a '):
    return prefix + ctx.label2cls_dict[label]

def tokenize(s):
    return ctx.tokenizer(
        s,
        max_length=ctx.ctx_len,
        padding='max_length',
        return_tensors='pt',
    )['input_ids'].squeeze(dim=0)

In [230]:
class TextImgDataset(Dataset):
    def __init__(self, img_label_ds):
        self.img_label_ds = img_label_ds

    def __len__(self):
        return len(self.img_label_ds)

    def __getitem__(self, index):
        orig_img, orig_label = self.img_label_ds[index]
        sentence = label2sentence(orig_label)
        return orig_img, tokenize(sentence)

train_dl = DataLoader(
    TextImgDataset(train_cifar10_ds),
    batch_size=ctx.batch_size,
    shuffle=True,
)

test_dl = DataLoader(
    TextImgDataset(test_cifar10_ds),
    batch_size=ctx.batch_size,
    shuffle=False,
)

for imgs, texts in train_dl:
    print(imgs.shape, texts.shape)
    break

torch.Size([256, 3, 32, 32]) torch.Size([256, 5])


In [231]:
def clip_loss(logits_per_image, logits_per_text):
    b = logits_per_text.shape[0]
    labels = torch.arange(0, b, device=logits_per_text.device)
    image_loss = F.cross_entropy(logits_per_image, labels)
    text_loss = F.cross_entropy(logits_per_text, labels)
    return (image_loss + text_loss) / 2

ctx.epochs = 50
ctx.lr = 1e-4
ctx.weight_decay = 1e-5
ctx.eval_interval = 3
ctx.saved_w = 'clip.pt'

model = CLIP(
    img_encoder=VisionTransformer(
        img_size=32,
        patch_num=8,
        use_cls_tok=False,
        nlayer=4,
        embed_dim=768,
        nhead=12,
        attn_pdrop=0.2,
        ffn_pdrop=0.2,
        num_classes=128,
    ),
    text_encoder=TransformerLanguageModel(
        vocab_size=ctx.tokenizer.vocab_size,
        ctx_len=ctx.ctx_len,
        embed_dim=768,
        nlayer=4,
        nhead=12,
        attn_mask=None, # Default: Casual Attention Mask
        attn_pdrop=0.2,
        ffn_pdrop=0.2,
        output_dim=128
    )
).to(ctx.device)

if os.path.isfile(ctx.saved_w):
    model.load_state_dict(torch.load(ctx.saved_w, weights_only=True))

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=ctx.lr,
    weight_decay=ctx.weight_decay,
)

@lru_cache
def tokenized_sentences():
    labels = range(len(ctx.label2cls_dict))
    return torch.stack([tokenize(label2sentence(label)) for label in labels])

@torch.no_grad()
def eval_accuracy():
    def eq_cnt(l1, l2):
        n = len(l1)
        cnt = 0
        for i in range(n):
            if l1[i] == l2[i]:
                cnt += 1
        return cnt
    model.eval()
    text_features = model.encode_text(
        tokenized_sentences().to(ctx.device)
    )
    text_features = text_features / text_features.norm(dim=1, keepdim=True)
    total = 0
    correct = 0
    pbar = tqdm(test_dl)
    pbar.set_description_str('Evaling accuracy on test set')
    for imgs, texts in pbar:
        imgs = imgs.to(ctx.device)
        texts = texts.to(ctx.device)
        image_features = model.encode_img(imgs)
        logits_per_image = image_features @ text_features.t()
        labels = logits_per_image.argmax(dim=1)
        pred = [label2sentence(label) for label in labels]
        gt = [ctx.tokenizer.decode(text, skipskip_special_tokens=True) for text in texts]
        total += len(imgs)
        correct += eq_cnt(pred, gt)
    return correct / total

len_train_dl = len(train_dl)
for epoch in range(1, ctx.epochs + 1):
    train_avg_loss = 0.
    pbar = tqdm(train_dl)
    pbar.set_description_str(f'Training one epoch {epoch}/{ctx.epochs}')
    for imgs, texts in pbar:
        imgs = imgs.to(ctx.device)
        texts = texts.to(ctx.device)
        optimizer.zero_grad()
        logits_per_image, logits_per_text = model(imgs, texts)
        loss = clip_loss(logits_per_image, logits_per_text)
        train_avg_loss += loss.item() / len_train_dl
        loss.backward()
        optimizer.step()
    if epoch % ctx.eval_interval == 0:
        print(f'{epoch}/{ctx.epochs}, train_avg_loss: {train_avg_loss:.3f}, test_acc: {eval_accuracy()}')
    else:
        print(f'{epoch}/{ctx.epochs}, train_avg_loss: {train_avg_loss:.3f}')

Training one epoch 1/50: 100%|██████████| 196/196 [00:36<00:00,  5.41it/s]


1/50, train_avg_loss: 5.389


Training one epoch 2/50: 100%|██████████| 196/196 [00:36<00:00,  5.42it/s]


2/50, train_avg_loss: 5.152


Training one epoch 3/50: 100%|██████████| 196/196 [00:35<00:00,  5.47it/s]
Evaling accuracy on test set: 100%|██████████| 196/196 [00:22<00:00,  8.69it/s]


3/50, train_avg_loss: 5.016, test_acc: 0.39684


Training one epoch 4/50: 100%|██████████| 196/196 [00:35<00:00,  5.52it/s]


4/50, train_avg_loss: 4.909


Training one epoch 5/50: 100%|██████████| 196/196 [00:35<00:00,  5.51it/s]


5/50, train_avg_loss: 4.826


Training one epoch 6/50: 100%|██████████| 196/196 [00:38<00:00,  5.09it/s]
Evaling accuracy on test set: 100%|██████████| 196/196 [00:24<00:00,  8.01it/s]


6/50, train_avg_loss: 4.762, test_acc: 0.48972


Training one epoch 7/50: 100%|██████████| 196/196 [00:37<00:00,  5.17it/s]


7/50, train_avg_loss: 4.707


Training one epoch 8/50: 100%|██████████| 196/196 [00:37<00:00,  5.17it/s]


8/50, train_avg_loss: 4.657


Training one epoch 9/50: 100%|██████████| 196/196 [00:35<00:00,  5.45it/s]
Evaling accuracy on test set: 100%|██████████| 196/196 [00:22<00:00,  8.58it/s]


9/50, train_avg_loss: 4.616, test_acc: 0.52602


Training one epoch 10/50: 100%|██████████| 196/196 [00:35<00:00,  5.53it/s]


10/50, train_avg_loss: 4.578


Training one epoch 11/50: 100%|██████████| 196/196 [00:35<00:00,  5.50it/s]


11/50, train_avg_loss: 4.548


Training one epoch 12/50: 100%|██████████| 196/196 [00:35<00:00,  5.51it/s]
Evaling accuracy on test set: 100%|██████████| 196/196 [00:22<00:00,  8.66it/s]


12/50, train_avg_loss: 4.519, test_acc: 0.55914


Training one epoch 13/50: 100%|██████████| 196/196 [00:35<00:00,  5.49it/s]


13/50, train_avg_loss: 4.490


Training one epoch 14/50: 100%|██████████| 196/196 [00:35<00:00,  5.50it/s]


14/50, train_avg_loss: 4.462


Training one epoch 15/50: 100%|██████████| 196/196 [00:35<00:00,  5.52it/s]
Evaling accuracy on test set: 100%|██████████| 196/196 [00:21<00:00,  9.05it/s]


15/50, train_avg_loss: 4.438, test_acc: 0.58668


Training one epoch 16/50: 100%|██████████| 196/196 [00:35<00:00,  5.52it/s]


16/50, train_avg_loss: 4.410


Training one epoch 17/50: 100%|██████████| 196/196 [00:35<00:00,  5.51it/s]


17/50, train_avg_loss: 4.384


Training one epoch 18/50: 100%|██████████| 196/196 [00:35<00:00,  5.50it/s]
Evaling accuracy on test set: 100%|██████████| 196/196 [00:22<00:00,  8.90it/s]


18/50, train_avg_loss: 4.375, test_acc: 0.60834


Training one epoch 19/50: 100%|██████████| 196/196 [00:35<00:00,  5.52it/s]


19/50, train_avg_loss: 4.351


Training one epoch 20/50: 100%|██████████| 196/196 [00:35<00:00,  5.51it/s]


20/50, train_avg_loss: 4.335


Training one epoch 21/50: 100%|██████████| 196/196 [00:35<00:00,  5.51it/s]
Evaling accuracy on test set: 100%|██████████| 196/196 [00:21<00:00,  8.99it/s]


21/50, train_avg_loss: 4.320, test_acc: 0.62612


Training one epoch 22/50: 100%|██████████| 196/196 [00:35<00:00,  5.52it/s]


22/50, train_avg_loss: 4.302


Training one epoch 23/50: 100%|██████████| 196/196 [00:35<00:00,  5.52it/s]


23/50, train_avg_loss: 4.290


Training one epoch 24/50: 100%|██████████| 196/196 [00:35<00:00,  5.50it/s]
Evaling accuracy on test set: 100%|██████████| 196/196 [00:21<00:00,  9.00it/s]


24/50, train_avg_loss: 4.275, test_acc: 0.64294


Training one epoch 25/50: 100%|██████████| 196/196 [00:35<00:00,  5.51it/s]


25/50, train_avg_loss: 4.257


Training one epoch 26/50: 100%|██████████| 196/196 [00:35<00:00,  5.52it/s]


26/50, train_avg_loss: 4.247


Training one epoch 27/50: 100%|██████████| 196/196 [00:35<00:00,  5.51it/s]
Evaling accuracy on test set: 100%|██████████| 196/196 [00:21<00:00,  9.00it/s]


27/50, train_avg_loss: 4.231, test_acc: 0.6546


Training one epoch 28/50: 100%|██████████| 196/196 [00:35<00:00,  5.53it/s]


28/50, train_avg_loss: 4.215


Training one epoch 29/50: 100%|██████████| 196/196 [00:35<00:00,  5.49it/s]


29/50, train_avg_loss: 4.207


Training one epoch 30/50: 100%|██████████| 196/196 [00:35<00:00,  5.51it/s]
Evaling accuracy on test set: 100%|██████████| 196/196 [00:21<00:00,  9.02it/s]


30/50, train_avg_loss: 4.193, test_acc: 0.6665


Training one epoch 31/50: 100%|██████████| 196/196 [00:35<00:00,  5.52it/s]


31/50, train_avg_loss: 4.186


Training one epoch 32/50: 100%|██████████| 196/196 [00:35<00:00,  5.50it/s]


32/50, train_avg_loss: 4.174


Training one epoch 33/50: 100%|██████████| 196/196 [00:35<00:00,  5.51it/s]
Evaling accuracy on test set: 100%|██████████| 196/196 [00:21<00:00,  9.01it/s]


33/50, train_avg_loss: 4.158, test_acc: 0.67598


Training one epoch 34/50: 100%|██████████| 196/196 [00:35<00:00,  5.52it/s]


34/50, train_avg_loss: 4.153


Training one epoch 35/50: 100%|██████████| 196/196 [00:35<00:00,  5.51it/s]


35/50, train_avg_loss: 4.135


Training one epoch 36/50: 100%|██████████| 196/196 [00:35<00:00,  5.51it/s]
Evaling accuracy on test set: 100%|██████████| 196/196 [00:21<00:00,  9.04it/s]


36/50, train_avg_loss: 4.128, test_acc: 0.6816


Training one epoch 37/50: 100%|██████████| 196/196 [00:35<00:00,  5.51it/s]


37/50, train_avg_loss: 4.118


Training one epoch 38/50: 100%|██████████| 196/196 [00:35<00:00,  5.51it/s]


38/50, train_avg_loss: 4.112


Training one epoch 39/50: 100%|██████████| 196/196 [00:35<00:00,  5.50it/s]
Evaling accuracy on test set: 100%|██████████| 196/196 [00:21<00:00,  9.07it/s]


39/50, train_avg_loss: 4.100, test_acc: 0.69454


Training one epoch 40/50: 100%|██████████| 196/196 [00:35<00:00,  5.49it/s]


40/50, train_avg_loss: 4.095


Training one epoch 41/50: 100%|██████████| 196/196 [00:35<00:00,  5.51it/s]


41/50, train_avg_loss: 4.090


Training one epoch 42/50: 100%|██████████| 196/196 [00:35<00:00,  5.51it/s]
Evaling accuracy on test set: 100%|██████████| 196/196 [00:21<00:00,  9.04it/s]


42/50, train_avg_loss: 4.080, test_acc: 0.69682


Training one epoch 43/50: 100%|██████████| 196/196 [00:35<00:00,  5.51it/s]


43/50, train_avg_loss: 4.064


Training one epoch 44/50: 100%|██████████| 196/196 [00:35<00:00,  5.51it/s]


44/50, train_avg_loss: 4.063


Training one epoch 45/50: 100%|██████████| 196/196 [00:35<00:00,  5.51it/s]
Evaling accuracy on test set: 100%|██████████| 196/196 [00:21<00:00,  9.03it/s]


45/50, train_avg_loss: 4.053, test_acc: 0.70694


Training one epoch 46/50: 100%|██████████| 196/196 [00:35<00:00,  5.52it/s]


46/50, train_avg_loss: 4.041


Training one epoch 47/50: 100%|██████████| 196/196 [00:35<00:00,  5.52it/s]


47/50, train_avg_loss: 4.038


Training one epoch 48/50: 100%|██████████| 196/196 [00:35<00:00,  5.51it/s]
Evaling accuracy on test set: 100%|██████████| 196/196 [00:21<00:00,  9.04it/s]


48/50, train_avg_loss: 4.027, test_acc: 0.71234


Training one epoch 49/50: 100%|██████████| 196/196 [00:35<00:00,  5.52it/s]


49/50, train_avg_loss: 4.016


Training one epoch 50/50: 100%|██████████| 196/196 [00:35<00:00,  5.51it/s]

50/50, train_avg_loss: 4.014





In [233]:
model.eval()

torch.save(
    model.state_dict(), ctx.saved_w
)

eval_accuracy()

Evaling accuracy on test set: 100%|██████████| 196/196 [00:22<00:00,  8.83it/s]


0.71722