In [None]:
import os
import random
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import CocoCaptions
from transformers import AutoTokenizer
from tqdm import tqdm
from einops import repeat


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: Feature width of the inputs and of the output; must be
            divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.head_dim = d_model // num_heads
        self.num_heads = num_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Query tensor of shape ``(B, T_q, d_model)``.
            k: Key tensor of shape ``(B, T_k, d_model)``.
            v: Value tensor of shape ``(B, T_k, d_model)``.
            mask: Optional mask; positions equal to 0 are hidden. Accepted
                shapes are ``(B, T_k)`` (key padding mask), ``(B, T_q, T_k)``
                (per-query mask, shared by all heads) or any 4-D tensor
                broadcastable to ``(B, num_heads, T_q, T_k)``.

        Returns:
            Tensor of shape ``(B, T_q, d_model)``.

        Raises:
            ValueError: If ``mask`` is not 2-, 3- or 4-dimensional.
        """
        B, T_q, _ = q.size()
        T_k = k.size(1)
        # Project, then split the feature axis into heads: (B, heads, T, head_dim).
        Q = self.q_proj(q).view(B, T_q, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(k).view(B, T_k, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(v).view(B, T_k, self.num_heads, self.head_dim).transpose(1, 2)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            if mask.dim() == 2:
                # (B, T_k) -> (B, 1, 1, T_k)
                attn_mask = mask.unsqueeze(1).unsqueeze(2)
            elif mask.dim() == 3:
                # (B, T_q, T_k) -> (B, 1, T_q, T_k)
                attn_mask = mask.unsqueeze(1)
            elif mask.dim() == 4:
                attn_mask = mask
            else:
                raise ValueError(
                    f"mask must have 2, 3 or 4 dimensions, got {mask.dim()}")
            # Use the dtype's most negative finite value rather than -inf:
            # hidden positions still get zero weight, but a row whose keys are
            # all hidden yields a uniform distribution instead of NaN.
            scores = scores.masked_fill(attn_mask == 0, torch.finfo(scores.dtype).min)
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        # Merge the heads back into a single feature axis.
        ctx = (attn @ V).transpose(1, 2).contiguous().view(B, T_q, self.d_model)
        return self.out_proj(ctx)

class AddNorm(nn.Module):
    """Residual add followed by layer normalisation.

    Args:
        d_model: Size of the last (feature) dimension that is normalised.
        dropout: Dropout probability applied to the sub-layer output before
            it is added to the residual.
        eps: Epsilon passed to ``nn.LayerNorm``.
    """

    def __init__(self, d_model, dropout=0.1, eps=1e-5):
        super().__init__()
        self.norm = nn.LayerNorm(d_model, eps=eps)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sub):
        """Return ``LayerNorm(x + Dropout(sub))``."""
        residual_sum = x + self.dropout(sub)
        return self.norm(residual_sum)

class PositionwiseFFN(nn.Module):
    """Two-layer feed-forward network applied along the last dimension.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden layer size.
        dropout: Dropout probability applied after the ReLU activation.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Compute ``fc2(Dropout(ReLU(fc1(x))))``."""
        hidden = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(hidden)

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a sequence of embeddings.

    Even feature indices hold ``sin(pos * freq)`` and odd indices hold
    ``cos(pos * freq)``. The table is precomputed for ``max_len`` positions
    and stored in the buffer ``pe`` with shape ``(1, max_len, d_model)``.

    Args:
        d_model: Embedding size (odd sizes are supported).
        max_len: Number of positions to precompute.
        dropout: Dropout probability applied after adding the encoding.
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.) / d_model))
        angles = pos * div
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # drop the surplus frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the positional table to ``x`` and apply dropout.

        Args:
            x: Tensor of shape ``(B, T, d_model)``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``T`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}")
        x = x + self.pe[:, :seq_len]
        return self.dropout(x)

class EncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer: self-attention then feed-forward.

    Args:
        d_model: Feature size of the layer input and output.
        num_heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward sub-layer.
        dropout: Dropout probability passed to every sub-module.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.sa = MultiHeadAttention(d_model, num_heads, dropout)
        self.an1 = AddNorm(d_model, dropout)
        self.ff = PositionwiseFFN(d_model, d_ff, dropout)
        self.an2 = AddNorm(d_model, dropout)

    def forward(self, x, mask=None):
        """Run self-attention and the feed-forward net, each wrapped in AddNorm.

        ``mask`` is forwarded unchanged to the self-attention module.
        """
        attended = self.sa(x, x, x, mask)
        x = self.an1(x, attended)
        transformed = self.ff(x)
        return self.an2(x, transformed)


class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    Args:
        in_channels: Number of image channels.
        patch_size: Side length of each square patch (also the stride).
        emb_size: Output embedding size per patch.
    """

    def __init__(self, in_channels=3, patch_size=4, emb_size=32):
        super().__init__()
        self.patch_size = patch_size
        flat_patch_dim = patch_size * patch_size * in_channels
        self.projection = nn.Sequential(
            nn.Unfold(kernel_size=patch_size, stride=patch_size),
            nn.Linear(flat_patch_dim, emb_size),
        )

    def forward(self, x):
        """Return patch embeddings of shape ``(B, num_patches, emb_size)``."""
        # The two stages are applied by hand because the unfolded columns must
        # be transposed to (B, num_patches, patch_dim) before the linear layer.
        unfold, linear = self.projection
        columns = unfold(x)
        tokens = columns.transpose(1, 2)
        return linear(tokens)

class TextTransformer(nn.Module):
    """Transformer text encoder that returns the feature at sequence position 0.

    Args:
        vocab_size: Number of rows in the token embedding table.
        d_model: Embedding / hidden size.
        num_heads: Attention heads per encoder layer.
        d_ff: Feed-forward hidden size per encoder layer.
        num_layers: Number of stacked encoder layers.
        max_len: Maximum sequence length covered by the positional encoding.
        dropout: Dropout probability used throughout.
    """

    def __init__(self, vocab_size, d_model=256, num_heads=4, d_ff=1024,
                 num_layers=3, max_len=5000, dropout=0.1):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = PositionalEncoding(d_model, max_len, dropout)
        self.layers = nn.ModuleList(
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, tokens, mask=None):
        """Encode token ids and return the normalised first-position feature.

        Args:
            tokens: Integer token ids of shape ``(B, T)``.
            mask: Optional mask forwarded to every encoder layer.

        Returns:
            Tensor of shape ``(B, d_model)``.
        """
        hidden = self.pos_embed(self.token_embed(tokens))
        for block in self.layers:
            hidden = block(hidden, mask)
        normed = self.norm(hidden)
        return normed[:, 0, :]

class VisionTransformer(nn.Module):
    """ViT-style image encoder that returns the class-token feature.

    Args:
        ch: Number of image channels.
        img_size: Side length used to size the positional embedding table.
        patch_size: Side length of each square patch.
        emb_dim: Embedding / hidden size.
        n_layers: Number of stacked encoder layers.
        heads: Attention heads per encoder layer.
        d_ff: Feed-forward hidden size per encoder layer.
        dropout: Dropout probability used in the encoder layers.
    """

    def __init__(self, ch=3, img_size=144, patch_size=4, emb_dim=32,
                 n_layers=6, heads=2, d_ff=128, dropout=0.1):
        super().__init__()
        self.patch_embedding = PatchEmbedding(ch, patch_size, emb_dim)
        patches_per_side = img_size // patch_size
        num_patches = patches_per_side ** 2
        # One extra position for the class token prepended in forward().
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, emb_dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_dim))
        self.layers = nn.ModuleList(
            EncoderLayer(emb_dim, heads, d_ff, dropout)
            for _ in range(n_layers)
        )
        self.norm = nn.LayerNorm(emb_dim)

    def forward(self, img):
        """Encode a batch of images into ``(B, emb_dim)`` class-token features."""
        patches = self.patch_embedding(img)
        batch, n_patches, _ = patches.shape
        # Broadcast the single learned class token across the batch.
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        sequence = torch.cat([cls_tokens, patches], dim=1)
        sequence = sequence + self.pos_embedding[:, :n_patches + 1]
        for block in self.layers:
            sequence = block(sequence)
        normed = self.norm(sequence)
        return normed[:, 0, :]

In [None]:
 class CLIP(nn.Module):
     """Dual-encoder contrastive model in the style of CLIP.

     Both encoders must expose a ``norm`` attribute with ``normalized_shape``
     (as ``nn.LayerNorm`` does); its first entry is used as the encoder's
     output width when building the projection heads.

     Args:
         vision_encoder: Module called as ``vision_encoder(images)``.
         text_encoder: Module called as ``text_encoder(tokens, mask)``.
         embed_dim: Size of the shared embedding space.
         temp: Initial softmax temperature; ``logit_scale`` is initialised to
             ``log(1 / temp)``.
         max_logit_scale: Upper bound on the multiplier applied to the cosine
             similarities. The learnable scale is otherwise unbounded and can
             grow until training becomes unstable.
     """

     def __init__(self, vision_encoder, text_encoder, embed_dim=512, temp=0.07,
                  max_logit_scale=100.0):
         super().__init__()
         self.vision = vision_encoder
         self.text   = text_encoder
         self.vision_proj = nn.Linear(vision_encoder.norm.normalized_shape[0], embed_dim)
         self.text_proj   = nn.Linear(text_encoder.norm.normalized_shape[0], embed_dim)
         self.logit_scale = nn.Parameter(torch.ones([])*math.log(1/temp))
         self.max_logit_scale = max_logit_scale

     def forward(self, images, tokens, mask=None):
         """Compute the scaled image-text similarity matrices.

         Returns:
             ``(logits_i, logits_t)`` where ``logits_i[a, b]`` is the scaled
             cosine similarity of image ``a`` and text ``b`` and ``logits_t``
             is its transpose.
         """
         img_feats = self.vision(images)
         txt_feats = self.text(tokens, mask)
         # L2-normalise so the matrix product below is a cosine similarity.
         img_emb = F.normalize(self.vision_proj(img_feats), dim=-1)
         txt_emb = F.normalize(self.text_proj(txt_feats), dim=-1)
         # Cap the learned scale so the logits cannot blow up.
         scale = self.logit_scale.exp().clamp(max=self.max_logit_scale)
         logits_i = scale * img_emb @ txt_emb.t()
         return logits_i, logits_i.t()

     def contrastive_loss(self, logits_i):
         """Symmetric cross-entropy with matching pairs on the diagonal.

         Args:
             logits_i: Square ``(B, B)`` image-to-text logit matrix.

         Returns:
             Scalar tensor: mean of the image-to-text and text-to-image losses.
         """
         B = logits_i.size(0)
         targets = torch.arange(B, device=logits_i.device)
         return (F.cross_entropy(logits_i, targets) + F.cross_entropy(logits_i.t(), targets))/2

class CocoCLIPDataset(Dataset):
    """COCO captions wrapped as (image tensor, token ids, attention mask) triples.

    Args:
        img_folder: Directory holding the images, passed to ``CocoCaptions``.
        ann_file: Caption annotation JSON, passed to ``CocoCaptions``.
        tokenizer: Callable tokenizer whose result exposes ``input_ids`` and
            ``attention_mask`` tensors (called with ``return_tensors='pt'``).
        max_length: Captions are padded / truncated to this many tokens.
        img_size: Images are resized to ``(img_size, img_size)``. This should
            match the image size the vision encoder was built for.
        random_caption: If True (default) a random caption is drawn on every
            access; if False the first caption is always used, which makes
            evaluation deterministic.
    """

    def __init__(self, img_folder, ann_file, tokenizer, max_length=16,
                 img_size=144, random_caption=True):
        self.coco = CocoCaptions(root=img_folder, annFile=ann_file)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.random_caption = random_caption
        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)), transforms.ToTensor(),
            # Channel-wise mean / std normalisation constants.
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])

    def __len__(self):
        """Number of images in the underlying COCO dataset."""
        return len(self.coco)

    def __getitem__(self, idx):
        """Return ``(image, input_ids, attention_mask)`` for item ``idx``.

        The token tensors have their leading batch dimension of size 1
        squeezed away.
        """
        img, caps = self.coco[idx]
        img = self.transform(img)
        cap = random.choice(caps) if self.random_caption else caps[0]
        toks = self.tokenizer(cap, padding='max_length', truncation=True,
                              max_length=self.max_length, return_tensors='pt')
        return img, toks.input_ids.squeeze(0), toks.attention_mask.squeeze(0)

def collate_fn(batch):
    """Stack a list of ``(image, ids, mask)`` samples into batched tensors.

    Returns:
        Tuple ``(images, ids, masks)``, each stacked along a new leading
        batch dimension.
    """
    columns = tuple(zip(*batch))
    images, token_ids, attn_masks = columns
    return tuple(torch.stack(col) for col in (images, token_ids, attn_masks))

def train(root=r'D:\coco_dataset\train2017', epochs=10, batch_size=64, lr=5e-5):
    """Train the CLIP model on COCO captions and keep the best checkpoint.

    After every epoch the model is evaluated on the validation split and, if
    the validation loss improved, its ``state_dict`` is written to
    ``best_clip_coco.pth`` in the current working directory.

    Args:
        root: Dataset root. Images are read from ``<root>/train2017`` and
            ``<root>/val2017`` and annotations from
            ``<root>/captions_{train,val}2017.json``.
            NOTE(review): the default root already ends in ``train2017`` —
            confirm this matches the on-disk layout.
        epochs: Number of training epochs.
        batch_size: Batch size for both loaders.
        lr: AdamW learning rate.

    Raises:
        RuntimeError: If the validation loader yields no samples.
    """
    train_img = os.path.join(root,'train2017')
    val_img   = os.path.join(root,'val2017')
    train_ann = os.path.join(root,'captions_train2017.json')
    val_ann   = os.path.join(root,'captions_val2017.json')

    # Resolve the device once and reuse it everywhere below.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Mixed precision only applies on CUDA; disabling it explicitly on CPU
    # avoids the "CUDA is not available" warnings on every step.
    use_amp = device.type == 'cuda'

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', use_fast=True)
    model = CLIP(
        vision_encoder=VisionTransformer(),
        text_encoder=TextTransformer(tokenizer.vocab_size),
        embed_dim=512
    )
    model.to(device)

    train_ds = CocoCLIPDataset(train_img, train_ann, tokenizer)
    val_ds   = CocoCLIPDataset(val_img,   val_ann,   tokenizer)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_fn)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False,
                              collate_fn=collate_fn)

    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    best_loss = float('inf')

    for epoch in range(1, epochs + 1):
        model.train()
        tbar = tqdm(train_loader, desc=f"Epoch {epoch} [Train]")
        run_loss, seen = 0.0, 0
        for imgs, ids, masks in tbar:
            imgs, ids, masks = imgs.to(device), ids.to(device), masks.to(device)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=use_amp):
                logits_i, _ = model(imgs, ids, mask=masks)
                loss = model.contrastive_loss(logits_i)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            bs = imgs.size(0)
            run_loss += loss.item() * bs
            seen += bs
            # Show the running sample-weighted mean loss on the progress bar.
            tbar.set_postfix(loss=f"{run_loss / seen:.4f}")
        if seen:
            print(f"==> Epoch {epoch} | Train Loss: {run_loss / seen:.4f}")

        model.eval()
        val_loss, total = 0.0, 0
        # The dataset draws a random caption per image through the global
        # `random` module. Pin the seed during validation so every epoch is
        # scored on the same captions (loaders run in the main process), then
        # restore the previous state so training randomness is unaffected.
        rng_state = random.getstate()
        random.seed(0)
        try:
            with torch.no_grad():
                for imgs, ids, masks in tqdm(val_loader, desc=f"Epoch {epoch} [Val]"):
                    imgs, ids, masks = imgs.to(device), ids.to(device), masks.to(device)
                    with torch.cuda.amp.autocast(enabled=use_amp):
                        logits_i, _ = model(imgs, ids, mask=masks)
                        l = model.contrastive_loss(logits_i).item()
                    bs = imgs.size(0)
                    val_loss += l * bs
                    total += bs
        finally:
            random.setstate(rng_state)
        if total == 0:
            raise RuntimeError("validation loader yielded no samples")
        avg = val_loss / total
        print(f"==> Epoch {epoch} | Val Loss: {avg:.4f}")
        if avg < best_loss:
            best_loss = avg
            torch.save(model.state_dict(), 'best_clip_coco.pth')
            print("Saved best model.")

 
# Run the training loop when this cell executes. With the defaults this reads
# the COCO data from the path configured in train() and writes
# 'best_clip_coco.pth' to the current working directory.
train()
