In [None]:
from TT_BLIP.batch_extractor import DatasetLoader

# Build the project's dataset wrapper. batch_size=1 and balance=True are passed
# straight through to DatasetLoader (presumably one sample per batch with class
# balancing enabled -- TODO confirm against TT_BLIP.batch_extractor).
ds_loader = DatasetLoader(
    batch_size=1,
    balance=True,
)

# get_dataloaders() returns a pair, unpacked here as (train, validation).
train_dl, val_dl = ds_loader.get_dataloaders()

In [None]:
ds = ds_loader.train_dataset

# One pass over the training set, bucketing samples by the truthiness of the
# label stored at index 1 (truthy is counted as "real" -- NOTE(review): confirm
# that label 1 really encodes the real class).
real = 0
fake = 0
for e in ds:
    if e[1]:
        real += 1
    else:
        fake += 1

# A naive classifier always predicts the majority class, so its accuracy is the
# share of the larger class -- not unconditionally the "fake" share. Guard the
# division so an empty dataset reports NaN instead of raising ZeroDivisionError.
total = real + fake
baseline_acc = max(real, fake) / total if total else float("nan")

print(f"Train set: Real {real} | Fake {fake} | Naive baseline acc: {baseline_acc}")

In [None]:
ds = ds_loader.test_dataset

# One pass over the test set, bucketing samples by the truthiness of the label
# stored at index 1 (truthy is counted as "real" -- NOTE(review): confirm that
# label 1 really encodes the real class).
real = 0
fake = 0
for e in ds:
    if e[1]:
        real += 1
    else:
        fake += 1

# A naive classifier always predicts the majority class, so its accuracy is the
# share of the larger class -- not unconditionally the "fake" share. Guard the
# division so an empty dataset reports NaN instead of raising ZeroDivisionError.
total = real + fake
baseline_acc = max(real, fake) / total if total else float("nan")

print(f"Test set: Real {real} | Fake {fake} | Naive baseline acc: {baseline_acc}")

In [None]:
from TT_BLIP.tt_blip_layers import TT_BLIP_Model
from lightning import Trainer
from lightning.pytorch.loggers import WandbLogger


# The loader's `dp` attribute supplies the "empty" placeholder inputs that the
# model constructor takes as its first three positional arguments.
dp = ds_loader.dp

model = TT_BLIP_Model(
    dp.empty_pixel_values,
    dp.empty_input_ids,
    dp.empty_attn_mask,
    256,  # NOTE(review): presumably a hidden/embedding size -- confirm in tt_blip_layers
    1,    # NOTE(review): presumably the number of output units -- confirm in tt_blip_layers
)

In [None]:
from lightning.pytorch.utilities.model_summary import ModelSummary

# Build the summary object for the model; as the cell's last expression it is
# rendered as the cell output.
model_summary = ModelSummary(model)
model_summary

In [None]:
# Log the run to Weights & Biases under the "Thesis_New" project.
logger = WandbLogger(
    'TT_BLIP_gossipcop_balanced_256',
    project="Thesis_New",
)

trainer = Trainer(
    max_epochs=50,
    logger=logger,
    log_every_n_steps=1,
    # Gradients are accumulated over 64 batches before each optimizer step; with
    # the loaders built at batch_size=1 this presumably gives an effective batch
    # of 64 samples.
    accumulate_grad_batches=64,
    # NOTE(review): newer Lightning releases prefer the string form "16-mixed";
    # the integer is kept because the installed version is not visible here.
    precision=16,
)

trainer.fit(model, train_dl, val_dl)

In [None]:
from tqdm.auto import tqdm 
import torch 

# Put the module in evaluation mode before scoring the validation loader.
model.eval()

# Parallel lists: labels[k] is the target tensor of the k-th validation batch,
# preds[k] the model output for that batch (moved to CPU).
labels, preds = [], []

for inputs, target in tqdm(val_dl):
    labels.append(target)
    # Gradients are not needed for evaluation, so skip autograd bookkeeping.
    with torch.no_grad():
        preds.append(model(inputs).cpu())

In [None]:
# 2x2 confusion matrix: rows are the true class, columns the predicted class.
cm = torch.zeros((2, 2))

# Flatten each batch so the tally works for any batch size. The previous
# per-entry .item() calls only worked when every batch held exactly one sample.
for batch_labels, batch_preds in zip(labels, preds):
    true_cls = batch_labels.reshape(-1).long()
    # A score above 0.5 is counted as class 1 (assumes the model emits
    # probabilities rather than raw logits -- TODO confirm).
    pred_cls = (batch_preds.reshape(-1) > 0.5).long()
    for i, j in zip(true_cls.tolist(), pred_cls.tolist()):
        cm[i, j] += 1

cm

In [None]:
# Row-normalize so each row gives the distribution of predictions for one true
# class. Counts are whole numbers, so clamping the row total to at least 1 only
# affects rows with no samples: they become all zeros instead of NaN (0 / 0).
row_totals = cm.sum(dim=1, keepdim=True).clamp(min=1)
cm_normalized = cm / row_totals
cm_normalized

In [None]:
import matplotlib.pyplot as plt


# Draw the normalized confusion matrix as a heat map on an explicit Axes.
fig, ax = plt.subplots()
ax.axis('off')  # hide ticks and frame; only the colored cells are shown
# Fix the color scale to [0, 1] so shades are comparable between runs.
ax.imshow(cm_normalized, cmap='Blues', vmin=0, vmax=1)