In [1]:
from datasets import load_dataset

train_dataset = load_dataset("rshaojimmy/DGM4", split="train[:5%]")
validation_dataset = load_dataset("rshaojimmy/DGM4", split="validation[:5%]")
train_dataset, validation_dataset

  from .autonotebook import tqdm as notebook_tqdm


(Dataset({
     features: ['id', 'text', 'image', 'fake_cls', 'fake_text_pos', 'mtcnn_boxes', 'fake_image_box'],
     num_rows: 10409
 }),
 Dataset({
     features: ['id', 'text', 'image', 'fake_cls', 'fake_text_pos', 'mtcnn_boxes', 'fake_image_box'],
     num_rows: 1106
 }))

In [2]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import cv2 as cv
from transformers import BertTokenizerFast, AutoImageProcessor, AutoTokenizer

bert_tokenizer = BertTokenizerFast.from_pretrained("google-bert/bert-base-uncased")
vit_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
blip_txt_processor = AutoTokenizer.from_pretrained("Salesforce/blip-image-captioning-base")
blip_img_processor = AutoImageProcessor.from_pretrained("Salesforce/blip-image-captioning-base")


def process_func(batch):
    image = []
    text = []
    label = []

    for b in batch:
        img = cv.imread(b['image'], cv.IMREAD_ANYCOLOR)
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
        image.append(img)
        text.append(b['text'])
        label.append(1 if b['fake_cls'] == 'orig' else 0)

    vit_img = vit_processor(image, return_tensors='pt').pixel_values
    blip_img = blip_img_processor(image, return_tensors='pt').pixel_values
    blip_txt = blip_txt_processor(text, return_tensors='pt', padding=True)
    bert_txt = bert_tokenizer(text, return_tensors='pt', padding=True)

    blip_tokens = blip_txt.input_ids
    blip_attn = blip_txt.attention_mask

    bert_tokens = bert_txt.input_ids
    bert_attn = bert_txt.attention_mask


    label = torch.tensor(label).unsqueeze(-1).float()
    return vit_img, blip_img, (blip_tokens, blip_attn), (bert_tokens, bert_attn), label

train_dl = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=process_func, drop_last=True)
val_dl = DataLoader(validation_dataset, batch_size=8, shuffle=True, collate_fn=process_func, drop_last=True)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [3]:
ri, bi, (bt, ba), (bet, bea), lab = next(iter(train_dl))

ri.shape, bi.shape, bt.shape, bea.shape, lab.shape

(torch.Size([8, 3, 224, 224]),
 torch.Size([8, 3, 384, 384]),
 torch.Size([8, 32]),
 torch.Size([8, 32]),
 torch.Size([8, 1]))

In [4]:
from torch import nn 
import lightning as L 
from transformers import BertModel, ViTModel, BlipForConditionalGeneration


class FeatureExtractionLayer(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained("google-bert/bert-base-uncased")
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224")
        self.blip = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

        self.bert.eval()
        self.vit.eval()
        self.blip.eval()

        self.bert.requires_grad_(False)
        self.vit.requires_grad_(False)
        self.blip.requires_grad_(False)

        self.vit.encoder.layer[9:].requires_grad_(True)
        self.vit.encoder.layer[9:].train()
        self.vit.pooler.requires_grad_(True)
        self.vit.pooler.train()

        self.bert.encoder.layer[9:].requires_grad_(True)
        self.bert.encoder.layer[9:].train()
        self.bert.pooler.requires_grad_(True)
        self.bert.pooler.train()

        self.dummy_img = nn.Parameter(torch.zeros((1, 3, 384, 384)), requires_grad=False)
        self.dummy_txt = nn.Parameter(torch.zeros((1, 1), dtype=torch.int), requires_grad=False)
        self.dummy_attn = nn.Parameter(torch.zeros((1, 1), dtype=torch.int), requires_grad=False)

    def forward(self, vit_img, blip_img, blip_txt, blip_attn, bert_txt, bert_attn):
        BSZ, *_ = vit_img.shape

        # IMAGES
        vit_encodings = self.vit(vit_img).last_hidden_state # 197 x 768
        blip_img_encodings = self.blip(blip_img, self.dummy_txt.repeat(BSZ, 1), attention_mask=self.dummy_attn.repeat(BSZ, 1)).last_hidden_state    # 577 x 768
        zi = torch.cat([vit_encodings, blip_img_encodings], dim=1)  # 197 + 577 x 768


        # TEXT
        bert_encodings = self.bert(bert_txt, bert_attn).last_hidden_state   # ?? x 768
        blip_txt_encodings = self.blip(self.dummy_img.repeat(BSZ, 1, 1, 1), blip_txt, attention_mask=blip_attn).last_hidden_state   # 577 x 768
        zt = torch.cat([bert_encodings, blip_txt_encodings], dim=1) # 577 + ?? x 768


        # IMAGE / TEXT
        blip_img_txt_encodings = self.blip(blip_img, blip_txt, attention_mask=blip_attn).last_hidden_state  # 577 x 768
        zit = blip_img_txt_encodings

        return zi, zt, zit 

In [5]:
feature_extraction_layer = FeatureExtractionLayer()
with torch.no_grad():
    zi, zt, zit = feature_extraction_layer(ri, bi, bt, ba, bet, bea)

zi.shape, zt.shape, zit.shape

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


(torch.Size([8, 774, 768]),
 torch.Size([8, 609, 768]),
 torch.Size([8, 577, 768]))

In [6]:
from lightning.pytorch.utilities.model_summary import ModelSummary

ModelSummary(feature_extraction_layer)

  | Name         | Type                         | Params | Mode
---------------------------------------------------------------------
0 | bert         | BertModel                    | 109 M  | eval
1 | vit          | ViTModel                     | 86.4 M | eval
2 | blip         | BlipForConditionalGeneration | 247 M  | eval
  | other params | n/a                          | 442 K  | n/a 
---------------------------------------------------------------------
43.7 M    Trainable params
400 M     Non-trainable params
443 M     Total params
1,774.912 Total estimated model params size (MB)
114       Modules in train mode
832       Modules in eval mode

In [7]:
class FusionLayer(L.LightningModule):
    def __init__(self, h=4):
        super().__init__()

        self.ca_img = nn.MultiheadAttention(768, h, 0.1, batch_first=True)
        self.ca_img_txt = nn.MultiheadAttention(768, h, 0.1, batch_first=True)
        self.sa_txt = nn.MultiheadAttention(768, h, 0.1, batch_first=True)

        self.mlp_img = nn.Sequential(
            nn.Linear(768, 768 * 2),
            nn.ReLU(),
            nn.Linear(768 * 2, 768),
            nn.ReLU()
        )

        self.mlp_txt = nn.Sequential(
            nn.Linear(768, 768 * 2),
            nn.ReLU(),
            nn.Linear(768 * 2, 768),
            nn.ReLU()
        )

        self.mlp_img_txt = nn.Sequential(
            nn.Linear(768, 768 * 2),
            nn.ReLU(),
            nn.Linear(768 * 2, 768),
            nn.ReLU()
        )
    
    def forward(self, zi, zt, zit):
        zi, _ = self.ca_img(zt, zi, zi)
        zit, _ = self.ca_img_txt(zt, zit, zit)
        zt, _ = self.sa_txt(zt, zt, zt)

        zi = self.mlp_img(zi)
        zt = self.mlp_txt(zt)
        zit = self.mlp_img(zit)

        z = torch.cat([zi, zt, zit], dim=1)
        return z 

In [8]:
fusion_layer = FusionLayer()

z = fusion_layer(zi, zt, zit)
z.shape

torch.Size([8, 1827, 768])

In [9]:
ModelSummary(fusion_layer)

  | Name        | Type               | Params | Mode 
-----------------------------------------------------------
0 | ca_img      | MultiheadAttention | 2.4 M  | train
1 | ca_img_txt  | MultiheadAttention | 2.4 M  | train
2 | sa_txt      | MultiheadAttention | 2.4 M  | train
3 | mlp_img     | Sequential         | 2.4 M  | train
4 | mlp_txt     | Sequential         | 2.4 M  | train
5 | mlp_img_txt | Sequential         | 2.4 M  | train
-----------------------------------------------------------
14.2 M    Trainable params
0         Non-trainable params
14.2 M    Total params
56.688    Total estimated model params size (MB)
21        Modules in train mode
0         Modules in eval mode

In [10]:
import torch.nn.functional as F 

class ClsfLayer(L.LightningModule):
    def __init__(self):
        super().__init__()

        self.clsf = nn.Sequential(
            nn.Linear(768, 768 * 2),
            nn.ReLU(),
            nn.Linear(768 * 2, 768 * 2),
            nn.ReLU(),
            nn.Linear(768 * 2, 768 * 2),
            nn.ReLU(),
            nn.Linear(768 * 2, 1),
            nn.Flatten()
        )
    
    def forward(self, z):
        z = self.clsf(z)
        _, N = z.shape 
        y = F.avg_pool1d(z, N, N)
        return y 

In [11]:
clsf_layer = ClsfLayer()

y = clsf_layer(z)
y.shape

torch.Size([8, 1])

In [12]:
ModelSummary(clsf_layer)

  | Name | Type       | Params | Mode 
--------------------------------------------
0 | clsf | Sequential | 5.9 M  | train
--------------------------------------------
5.9 M     Trainable params
0         Non-trainable params
5.9 M     Total params
23.618    Total estimated model params size (MB)
9         Modules in train mode
0         Modules in eval mode

In [13]:
from torchmetrics import Accuracy


class TT_Blip(L.LightningModule):
    def __init__(self, feature_extraction_layer, fusion_layer, clsf_layer):
        super().__init__()

        self.feature_extraction_layer = feature_extraction_layer
        self.fusion_layer = fusion_layer
        self.clsf_layer = clsf_layer

        self.loss_fn = nn.BCEWithLogitsLoss()
        self.accuracy = Accuracy('binary')

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), 1e-3)
    
    def forward(self, vit_img, blip_img, blip_txt, blip_attn, bert_txt, bert_attn):
        zi, zt, zit = self.feature_extraction_layer(vit_img, blip_img, blip_txt, blip_attn, bert_txt, bert_attn)
        z = self.fusion_layer(zi, zt, zit)
        y = self.clsf_layer(z)
        return y 
    
    def training_step(self, batch):
        vi, bi, (bt, ba), (bet, bea), y = batch
        pred = self.forward(vi, bi, bt, ba, bet, bea)

        loss = self.loss_fn(pred, y)
        acc = self.accuracy(pred, y)

        self.log("train_loss", loss, prog_bar=True)
        self.log('train_acc', acc, prog_bar=True, on_step=False, on_epoch=True)
        return loss 
    
    def validation_step(self, batch):
        vi, bi, (bt, ba), (bet, bea), y = batch
        pred = self.forward(vi, bi, bt, ba, bet, bea)

        loss = self.loss_fn(pred, y)
        acc = self.accuracy(pred, y)

        self.log("val_loss", loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss 


model = TT_Blip(feature_extraction_layer, fusion_layer, clsf_layer)

ModelSummary(model)

  | Name                     | Type                   | Params | Mode 
----------------------------------------------------------------------------
0 | feature_extraction_layer | FeatureExtractionLayer | 443 M  | train
1 | fusion_layer             | FusionLayer            | 14.2 M | train
2 | clsf_layer               | ClsfLayer              | 5.9 M  | train
3 | loss_fn                  | BCEWithLogitsLoss      | 0      | train
4 | accuracy                 | BinaryAccuracy         | 0      | train
----------------------------------------------------------------------------
63.8 M    Trainable params
400 M     Non-trainable params
463 M     Total params
1,855.217 Total estimated model params size (MB)
149       Modules in train mode
832       Modules in eval mode

In [None]:
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger

logger = WandbLogger(name='ViT_TT_Blip_training', project="Thesis")
trainer = Trainer(max_epochs=50, logger=logger, log_every_n_steps=1, accumulate_grad_batches=8)

trainer.fit(model, train_dl, val_dl)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: Currently logged in as: [33mosusume[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                     | Type                   | Params | Mode 
----------------------------------------------------------------------------
0 | feature_extraction_layer | FeatureExtractionLayer | 443 M  | train
1 | fusion_layer             | FusionLayer            | 14.2 M | train
2 | clsf_layer               | ClsfLayer              | 5.9 M  | train
3 | loss_fn                  | BCEWithLogitsLoss      | 0      | train
4 | accuracy                 | BinaryAccuracy         | 0      | train
----------------------------------------------------------------------------
63.8 M    Trainable params
400 M     Non-trainable params
463 M     Total params
1,855.217 Total estimated model params size (MB)
149       Modules in train mode
832       Modules in eval mode


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/home/daniele/Scrivania/tt_blip_implementation/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:476: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/home/daniele/Scrivania/tt_blip_implementation/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


                                                                           

/home/daniele/Scrivania/tt_blip_implementation/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 0:  62%|██████▏   | 810/1301 [25:23<15:23,  0.53it/s, v_num=34tn, train_loss=0.712]