In [1]:
from TT_BLIP.batch_extractor import DatasetLoader

ds_loader = DatasetLoader(batch_size=1, balance=True)
train_dl, val_dl = ds_loader.get_dataloaders()

It looks like you are trying to rescale already rescaled images. If the input images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.


In [2]:
ds = ds_loader.train_dataset

real = 0
fake = 0
for e in ds:
    if e[1]:
        real += 1
    else:
        fake += 1

print(f"Train set: Real {real} | Fake {fake} | Naive baseline acc: {fake / (real + fake)}")

Train set: Real 556 | Fake 528 | Naive baseline acc: 0.4870848708487085


In [3]:
ds = ds_loader.test_dataset

real = 0
fake = 0
for e in ds:
    if e[1]:
        real += 1
    else:
        fake += 1

print(f"Test set: Real {real} | Fake {fake} | Naive baseline acc: {fake / (real + fake)}")

Test set: Real 122 | Fake 150 | Naive baseline acc: 0.5514705882352942


In [4]:
from TT_BLIP.tt_blip_layers import TT_BLIP_Model
from lightning import Trainer
from lightning.pytorch.loggers import WandbLogger


model = TT_BLIP_Model(
        ds_loader.dp.empty_pixel_values, 
        ds_loader.dp.empty_input_ids,
        ds_loader.dp.empty_attn_mask, 
        768, 
        8,
        trainable=-3
    )

In [5]:
from lightning.pytorch.utilities.model_summary import ModelSummary

ModelSummary(model)

  | Name                     | Type                   | Params | Mode 
----------------------------------------------------------------------------
0 | feature_extraction_layer | FeatureExtractionLayer | 866 M  | train
1 | fusion_layer             | FusionLayer            | 193 M  | train
2 | classification_layer     | ClassificationLayer    | 1.2 M  | train
3 | loss_fn                  | BCEWithLogitsLoss      | 0      | train
4 | acc_fn                   | BinaryAccuracy         | 0      | train
5 | f1_fn                    | BinaryF1Score          | 0      | train
6 | prec_fn                  | BinaryPrecision        | 0      | train
7 | recall_fn                | BinaryRecall           | 0      | train
----------------------------------------------------------------------------
559 M     Trainable params
502 M     Non-trainable params
1.1 B     Total params
4,247.737 Total estimated model params size (MB)
380       Modules in train mode
1910      Modules in eval mode

In [6]:
x, y = next(iter(train_dl))
y

tensor([1.])

In [7]:
import torch

model.eval()
with torch.no_grad():
    z = model.feature_extraction_layer(*x)
    print(z[0].shape)
    z = model.fusion_layer(z)
    print(z.shape)
    y = model.classification_layer(z)
    print(y)

torch.Size([1, 578, 768])
torch.Size([1, 3, 768])
tensor([-0.0524])


In [None]:
logger = WandbLogger('TT_BLIP_gossipcop_balanced_256', project="Thesis_New")
trainer = Trainer(max_epochs=50, logger=logger, log_every_n_steps=1, accumulate_grad_batches=64, precision=16)
trainer.fit(model, train_dl, val_dl)

In [None]:
from tqdm.auto import tqdm 
import torch 

model.eval()

labels = []
preds = []
for b in tqdm(val_dl):
    x, y = b
    labels.append(y)
    with torch.no_grad():
        y_pred = model(x).cpu()
        preds.append(y_pred) 

In [None]:
cm = torch.zeros((2, 2))

for idx in range(len(preds)):
    i = int(labels[idx].item())
    j = int(preds[idx].item() > 0.5)
    cm[i, j] += 1

cm

In [None]:
cm_normalized = cm / cm.sum(axis=1, keepdim=True)
cm_normalized

In [None]:
import matplotlib.pyplot as plt


plt.axis('off')
plt.imshow(cm_normalized, cmap='Blues', vmin=0, vmax=1)