In [None]:
from datasets import load_dataset

# Both subsets are carved out of the DGM4 "validation" split: the first 3% is
# used for training and the following 1% for validation.
DGM4_REPO = "rshaojimmy/DGM4"

train_dataset = load_dataset(DGM4_REPO, split="validation[:3%]")
validation_dataset = load_dataset(DGM4_REPO, split="validation[3%:4%]")

# Last expression: rendered as the cell output.
train_dataset, validation_dataset

In [None]:
import matplotlib.pyplot as plt
import cv2 as cv


# Preview a single training record: its image, caption and label.
entry = train_dataset[4]

# IMREAD_COLOR always yields a 3-channel BGR array, so the BGR->RGB conversion
# below also works for grayscale files (IMREAD_ANYCOLOR may return 1 channel).
img = cv.imread(entry['image'], cv.IMREAD_COLOR)
if img is None:
    # cv.imread reports an unreadable/missing file by returning None instead
    # of raising; fail here with the offending path rather than in cvtColor.
    raise FileNotFoundError(f"Could not read image file: {entry['image']!r}")
img = cv.cvtColor(img, cv.COLOR_BGR2RGB)  # matplotlib expects RGB channel order
img = cv.resize(img, (224, 224))

plt.imshow(img)
plt.axis('off')
plt.show()

print(entry['text'])
entry['fake_cls']

In [None]:
# BERT

from transformers import AutoTokenizer, AutoModelForMaskedLM

BERT_CHECKPOINT = "google-bert/bert-base-uncased"

bert_tokenizer = AutoTokenizer.from_pretrained(BERT_CHECKPOINT)

# The masked-LM wrapper is loaded only to reach its `.bert` encoder.
# Note: the name `model` is rebound by later cells to other objects.
model = AutoModelForMaskedLM.from_pretrained(BERT_CHECKPOINT)
bert = model.bert

In [4]:
# BLIP

from transformers import AutoProcessor, AutoModelForImageTextToText

BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"

# Processor and model are loaded from the same checkpoint.
blip_processor = AutoProcessor.from_pretrained(BLIP_CHECKPOINT)
blip = AutoModelForImageTextToText.from_pretrained(BLIP_CHECKPOINT)

In [None]:
# RESNET
from transformers import AutoImageProcessor, AutoModelForImageClassification

RESNET_CHECKPOINT = "microsoft/resnet-50"

resnet_processor = AutoImageProcessor.from_pretrained(RESNET_CHECKPOINT)

# The classification wrapper is loaded only to reach its `.resnet` backbone.
# Note: this rebinds the notebook-level name `model`.
model = AutoModelForImageClassification.from_pretrained(RESNET_CHECKPOINT)
resnet = model.resnet

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader


def process_func(batch, max_length=50):
    """Collate raw dataset records into the tensors consumed by the model.

    Args:
        batch: list of records, each providing an ``'image'`` file path, a
            ``'text'`` caption and a ``'fake_cls'`` label string.
        max_length: fixed number of tokens every caption is padded or
            truncated to. Defaults to 50, the window the BERT adapter's
            ``AvgPool1d(50, 50)`` pools over.

    Returns:
        Tuple ``(image, text, label, attn_mask)`` where

        * ``image`` is ``(B, 224, 224, 3)`` RGB, scaled to [0, 1] and passed
          through ``tanh``;
        * ``text`` is a long tensor ``(B, max_length)`` of token ids;
        * ``label`` is a float tensor ``(B,)``: 1.0 for ``'orig'``, else 0.0;
        * ``attn_mask`` is a long tensor ``(B, max_length)``.

    Raises:
        FileNotFoundError: if an image file cannot be read.
    """
    images, captions, labels = [], [], []

    for record in batch:
        # IMREAD_COLOR guarantees three BGR channels, so grayscale files do
        # not break the BGR->RGB conversion (IMREAD_ANYCOLOR may return one).
        img = cv.imread(record['image'], cv.IMREAD_COLOR)
        if img is None:
            # cv.imread returns None instead of raising on unreadable files.
            raise FileNotFoundError(f"Could not read image file: {record['image']!r}")
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
        img = cv.resize(img, (224, 224)) / 255.0
        img = torch.tensor(img).unsqueeze(0)
        images.append(F.tanh(img))

        captions.append(record['text'])
        labels.append(1 if record['fake_cls'] == 'orig' else 0)

    image = torch.vstack(images)

    # Let the tokenizer pad/truncate to the fixed length: unlike slicing the
    # id tensor by hand, its truncation keeps the closing special token.
    tokn = bert_tokenizer(
        captions,
        return_tensors='pt',
        padding='max_length',
        truncation=True,
        max_length=max_length,
    )
    text = tokn.input_ids.long()
    attn_mask = tokn.attention_mask.long()

    label = torch.tensor(labels).float()
    return image, text, label, attn_mask

BATCH_SIZE = 4

# Training: reshuffle every epoch and drop the final incomplete batch.
train_dl = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=process_func,
    drop_last=True,
)

# Validation: fixed order and every sample kept, so the logged validation loss
# is computed on the same, complete set of examples each epoch.
val_dl = DataLoader(
    validation_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=process_func,
    drop_last=False,
)

# One sample batch, reused by the smoke tests in the following cells.
image, text, label, mask = next(iter(train_dl))
image.shape, text.shape, label.shape, mask.shape

In [None]:
from torch import nn 
import lightning as L
from lightning.pytorch.utilities.model_summary import ModelSummary

class FeatureExtractionLayer(L.LightningModule):
    """Frozen BERT / ResNet / BLIP backbones followed by trainable adapters.

    ``forward`` returns three feature tensors per batch: image features,
    text features and joint image-text features. With ``embed_dim = d`` the
    image and text features are each the concatenation of two ``d``-wide
    adapter outputs, and the joint features come from one ``2 * d``-wide
    adapter, so all three are ``2 * d`` wide.
    """

    def __init__(self, blip_preprocessor, resnet_preprocessor, bert, resnet, blip, embed_dim=512):
        """
        Args:
            blip_preprocessor: processor called on (images, texts) to build
                the BLIP ``pixel_values``.
            resnet_preprocessor: processor called on images to build the
                ResNet ``pixel_values``.
            bert: text encoder whose output exposes ``last_hidden_state``.
            resnet: image encoder whose output exposes ``last_hidden_state``.
            blip: image-text model whose output exposes ``last_hidden_state``.
            embed_dim: output width of each single-modality adapter.
        """
        super().__init__()
        self.blip_pr = blip_preprocessor
        self.resnet_pr = resnet_preprocessor

        self.bert = bert
        self.resnet = resnet
        self.blip = blip

        # The backbones are fixed feature extractors; only the adapters train.
        # NOTE(review): requires_grad_(False) freezes the weights but does not
        # switch the modules to eval mode — confirm whether dropout/batch-norm
        # in the backbones should stay active during training.
        self.bert.requires_grad_(False)
        self.resnet.requires_grad_(False)
        self.blip.requires_grad_(False)

        # Fixed placeholder inputs for the BLIP calls that should see only one
        # modality. Stored as non-trainable parameters so they follow the
        # module across devices.
        self.dummy_text = nn.Parameter(torch.tensor([101, 202]).unsqueeze(0), requires_grad=False)
        self.dummy_img = nn.Parameter(torch.zeros((1, 3, 384, 384)), requires_grad=False)

        # Projectors to adapt feature dimensions
        # ResNet map (BSZ x 2048 x 7 x 7): one 7x7 conv collapses the grid.
        self.resnet_adapter = nn.Sequential(
            nn.Conv2d(2048, embed_dim, 7, 7),
            nn.Flatten(),
            nn.ReLU()
        )

        # BERT sequence (BSZ x 768 x 50 after transpose): average over the 50
        # token positions, then a 1x1 conv projects the channels.
        self.bert_adapter = nn.Sequential(
            nn.AvgPool1d(50, 50),
            nn.Conv1d(768, embed_dim, 1, 1),
            nn.Flatten(),
            nn.ReLU()
        )

        # BLIP sequences (BSZ x 768 x 577 after transpose): same scheme with a
        # 577-position average.
        self.blip_img_adapter = nn.Sequential(
            nn.AvgPool1d(577, 577),
            nn.Conv1d(768, embed_dim, 1, 1),
            nn.Flatten(),
            nn.ReLU()
        )

        self.blip_txt_adapter = nn.Sequential(
            nn.AvgPool1d(577, 577),
            nn.Conv1d(768, embed_dim, 1, 1),
            nn.Flatten(),
            nn.ReLU()
        )

        # The joint branch is twice as wide so it matches the concatenated
        # image and text feature widths.
        self.blip_img_txt_adapter = nn.Sequential(
            nn.AvgPool1d(577, 577),
            nn.Conv1d(768, embed_dim * 2, 1, 1),
            nn.Flatten(),
            nn.ReLU()
        )

    def forward(self, img, txt, mask):
        """Encode one batch.

        Args:
            img: batch of images; the first dimension is the batch size.
            txt: token ids, passed to BERT and BLIP.
            mask: attention mask matching ``txt``.

        Returns:
            Tuple ``(img_features, txt_features, img_txt_features)``.
        """
        BSZ, *_ = img.shape

        # Preprocess the `img` argument — not the notebook-level `image`
        # variable — so every call encodes the batch it was actually given.
        # NOTE(review): the processors presumably convert their input to NumPy,
        # so they are handed a detached CPU tensor; outputs go back to the
        # module's device.
        host_img = img.detach().cpu()
        blip_img_preprocessing = self.blip_pr(host_img, ["" for _ in range(BSZ)], return_tensors="pt", do_rescale=False).pixel_values.to(self.device)
        resnet_img_preprocessing = self.resnet_pr(host_img, return_tensors="pt", do_rescale=False).pixel_values.to(self.device)

        bert_encodings = self.bert(txt, attention_mask=mask).last_hidden_state.to(self.device)   # BSZ x 50 x 768
        resnet_encodings = self.resnet(resnet_img_preprocessing).last_hidden_state.to(self.device) # BSZ x 2048 x 7 x 7

        # NOTE(review): `blip_text` is computed from the real image with the
        # dummy caption, and `blip_image` from the dummy image with the real
        # caption, yet they feed the text and image branches respectively —
        # confirm this pairing is intended.
        blip_text = self.blip(blip_img_preprocessing, self.dummy_text.repeat(BSZ, 1)).last_hidden_state.to(self.device)   # BSZ x 577 x 768
        blip_image = self.blip(self.dummy_img.repeat(BSZ, 1, 1, 1), txt, attention_mask=mask).last_hidden_state.to(self.device)
        blip_image_text = self.blip(blip_img_preprocessing, txt, attention_mask=mask).last_hidden_state.to(self.device)

        # Sequence outputs are transposed to channels-first for the 1-D adapters.
        img_features = torch.cat([self.resnet_adapter(resnet_encodings), self.blip_img_adapter(blip_image.transpose(-2, -1))], dim=1)
        txt_features = torch.cat([self.bert_adapter(bert_encodings.transpose(-2, -1)), self.blip_txt_adapter(blip_text.transpose(-2, -1))], dim=1)
        img_txt_features = self.blip_img_txt_adapter(blip_image_text.transpose(-2, -1))
        return img_features, txt_features, img_txt_features
    
# embed_dim=256 makes each of the three returned feature tensors 512 wide.
extraction_layer = FeatureExtractionLayer(
    blip_preprocessor=blip_processor,
    resnet_preprocessor=resnet_processor,
    bert=bert,
    resnet=resnet,
    blip=blip,
    embed_dim=256,
)

# Smoke test on the sample batch; no gradients are needed for a shape check.
with torch.no_grad():
    x_i, x_t, x_it = extraction_layer(image, text, mask)

print(ModelSummary(extraction_layer))
print(x_i.shape, x_t.shape, x_it.shape)

In [None]:
class FusionLayer(L.LightningModule):
    """Fuse image, text and image-text features with text-queried attention.

    Three multi-head attention blocks use the text features as the query and
    the image, text and image-text features as key/value respectively. Each
    attention output goes through its own two-layer MLP and the three results
    are concatenated along dimension 1.
    """

    def __init__(self, embedding_dim=512, num_heads=8):
        """
        Args:
            embedding_dim: width of every input feature tensor.
            num_heads: number of attention heads in each attention block.
        """
        super().__init__()
        self.attn_img = nn.MultiheadAttention(embedding_dim, num_heads, batch_first=True)
        self.attn_txt = nn.MultiheadAttention(embedding_dim, num_heads, batch_first=True)
        self.attn_txt_img = nn.MultiheadAttention(embedding_dim, num_heads, batch_first=True)

        self.mlp_img = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU()
        )

        self.mlp_txt = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
        )

        self.mlp_txt_img = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
        )

    def forward(self, zi, zt, zit):
        """Fuse the three feature tensors.

        Args:
            zi: image features, ``(B, E)`` or ``(B, L, E)``.
            zt: text features (used as the attention query), same rank as ``zi``.
            zit: image-text features, same rank as ``zi``.

        Returns:
            The three MLP outputs concatenated along dimension 1:
            ``(B, 3 * E)`` for 2-D inputs, ``(B, 3 * L, E)`` for 3-D inputs.
        """
        # nn.MultiheadAttention interprets a 2-D input as ONE unbatched
        # sequence, which would treat the batch axis as sequence length and
        # let samples attend to each other. Give per-sample vectors an
        # explicit length-1 sequence axis so every sample is fused on its own.
        # With a single key/value position the softmax weight is 1, so each
        # block then reduces to the value/output projection of its key input.
        per_sample_vectors = zt.dim() == 2
        if per_sample_vectors:
            zi, zt, zit = zi.unsqueeze(1), zt.unsqueeze(1), zit.unsqueeze(1)

        attn_i, _ = self.attn_img(zt, zi, zi)
        attn_t, _ = self.attn_txt(zt, zt, zt)
        attn_it, _ = self.attn_txt_img(zt, zit, zit)

        attn_i = self.mlp_img(attn_i)
        attn_t = self.mlp_txt(attn_t)
        attn_it = self.mlp_txt_img(attn_it)

        if per_sample_vectors:
            # Drop the helper axis again so the output stays (B, 3 * E).
            attn_i, attn_t, attn_it = attn_i.squeeze(1), attn_t.squeeze(1), attn_it.squeeze(1)

        z = torch.cat([attn_i, attn_t, attn_it], 1)
        return z
    
# Defaults spelled out: 512-wide features, 8 attention heads.
fusion_layer = FusionLayer(embedding_dim=512, num_heads=8)

# Disabled smoke test: z = fusion_layer(x_i, x_t, x_it); z.shape
fusion_summary = ModelSummary(fusion_layer)
print(fusion_summary)


In [None]:
class ClassificationLayer(L.LightningModule):
    """Three-layer MLP head ending in a sigmoid.

    Maps a ``3 * embed_dim`` wide input to a single value in (0, 1).
    """

    def __init__(self, embed_dim):
        """
        Args:
            embed_dim: base width; the head expects ``3 * embed_dim`` inputs.
        """
        super().__init__()
        in_width = embed_dim * 3
        hidden_width = embed_dim * 2
        # Layers are created in the same order as they are applied.
        stages = [
            nn.Linear(in_width, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, 1),
            nn.Sigmoid(),
        ]
        self.clsf = nn.Sequential(*stages)

    def forward(self, z):
        """Return the sigmoid output of the head for the fused features ``z``."""
        head = self.clsf
        return head(z)
    
# embed_dim=512 means the head takes 3 * 512 = 1536 input features.
clsf_layer = ClassificationLayer(embed_dim=512)

# Disabled smoke test: y = clsf_layer(z); y.shape
clsf_summary = ModelSummary(clsf_layer)
print(clsf_summary)

In [None]:
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger


class TT_Blip(L.LightningModule):
    """End-to-end model: feature extraction -> fusion -> classification.

    Trained with binary cross-entropy on the classifier's sigmoid output.
    """

    def __init__(self, feature_extraction_layer:FeatureExtractionLayer, fusion_layer:FusionLayer, clsf_layer:ClassificationLayer):
        """Store the three stages and the loss function."""
        super().__init__()

        self.feature_extraction_layer = feature_extraction_layer
        self.fusion_layer = fusion_layer
        self.clsf_layer = clsf_layer

        self.loss_fn = nn.BCELoss()

    def configure_optimizers(self):
        """Adam over all module parameters with learning rate 1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), 1e-3)
        return optimizer

    def forward(self, img, txt, attn_mask):
        """Run the three stages in sequence and return the classifier output."""
        features = self.feature_extraction_layer(img, txt, attn_mask)
        fused = self.fusion_layer(*features)
        return self.clsf_layer(fused)

    def _batch_loss(self, batch):
        """Unpack ``(img, txt, y, mask)`` and return the BCE loss of the prediction."""
        img, txt, y, mask = batch
        pred = self.forward(img, txt, mask)
        # Targets get a trailing axis to match the classifier's output shape.
        return self.loss_fn(pred, y.unsqueeze(-1))

    def training_step(self, batch):
        """Compute and log the training loss for one batch."""
        loss = self._batch_loss(batch)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch):
        """Compute the validation loss for one batch, logged per epoch."""
        loss = self._batch_loss(batch)
        self.log('valid_loss', loss, prog_bar=True, on_epoch=True, on_step=False)
        return loss
    
# Assemble the full model from the three stages built in the cells above.
model = TT_Blip(
    feature_extraction_layer=extraction_layer,
    fusion_layer=fusion_layer,
    clsf_layer=clsf_layer,
)
# Disabled smoke test: y = model(image, text, mask); y.shape
ModelSummary(model)

In [None]:
# Run "train_TT_Blip" in the "Thesis" project; metrics are logged every step.
logger = WandbLogger(name="train_TT_Blip", project="Thesis")
trainer = Trainer(
    max_epochs=50,
    logger=logger,
    log_every_n_steps=1,
)

trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)