In [None]:
import io
import gc
import json
import math
import timm
import torch
import random
import torchvision

import numpy as np
import polars as pl
import albumentations as A

from PIL import Image
from tqdm.notebook import tqdm
from transformers import BartModel, BartTokenizer

In [None]:
# Pick the compute device up front; later cells move models and batches onto it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Parquet shards (glob patterns) for the training split and the held-out test split.
TRAIN_PARQUET_GLOB = 'Defactify4_Train/data/train-*.parquet'
TEST_PARQUET_GLOB = 'Final_defactify_test_new/data/train-*.parquet'

df_train = pl.read_parquet(TRAIN_PARQUET_GLOB)
df_test = pl.read_parquet(TEST_PARQUET_GLOB, low_memory=True)
gc.collect()

In [None]:
class DefactifyDataset(torch.utils.data.Dataset):
    """Training dataset that expands every row into one sample per image column.

    Sample ``idx`` reads row ``idx // len(image_keys)`` and image column
    ``image_keys[idx % len(image_keys)]``; its label is a one-hot vector whose hot
    index is that column's position in ``image_keys`` (so the column order defines
    the class ids).

    Args:
        data: polars DataFrame with a ``caption`` column and one column per entry of
            ``image_keys``; each image cell is read as a struct with a ``bytes`` field.
        img_size: side length of the square, letterboxed output image.
    """

    def __init__(self, data, img_size=256):
        self.data = data
        self.img_size = img_size
        self.image_keys = ['coco_image', 'sd21_image', 'sdxl_image', 'sd3_image', 'dalle_image', 'midjourney_image']
        # Random augmentations, then resize the longest side to img_size and zero-pad to a square.
        self.augmentation = A.Compose([A.HorizontalFlip(p=0.5),
                                       A.GaussNoise(p=0.3),
                                       A.ImageCompression(quality_range=(96, 100), p=0.3),
                                       A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
                                       A.LongestMaxSize(max_size=self.img_size, interpolation=1),
                                       A.PadIfNeeded(min_height=self.img_size, min_width=self.img_size, border_mode=0, value=(0,0,0))
                                      ])
        # HWC uint8 -> CHW float tensor, normalised with ImageNet channel statistics.
        self.transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                                         torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    def __len__(self):
        return len(self.data)*len(self.image_keys)

    def __getitem__(self, idx):
        """Return a dict with ``text``, ``metadata`` (original H, W, C), ``image`` and one-hot ``label``."""
        n_keys = len(self.image_keys)
        data_idx, key_idx = divmod(idx, n_keys)
        row = self.data[data_idx]
        # convert("RGB") so grayscale / RGBA / palette images all become (H, W, 3): the
        # 3-channel Normalize and the fixed-length (3,) metadata vector both require it.
        with Image.open(io.BytesIO(row[self.image_keys[key_idx]][0]['bytes'])) as pil_image:
            image = np.asarray(pil_image.convert("RGB"))
        metadata = np.array(image.shape)
        image = self.transform(self.augmentation(image=image)['image'])
        labels = torch.zeros(n_keys)
        labels[key_idx] = 1.0
        return {"text": row["caption"][0], "metadata": metadata, "image": image, "label": labels}

In [None]:
class PredictionDataset(torch.utils.data.Dataset):
    """Inference dataset: one sample per row, letterboxed and normalised without random augmentation.

    Args:
        data: polars DataFrame with ``image`` (read as a struct with a ``bytes`` field)
            and ``caption`` columns.
        img_size: side length of the square, letterboxed output image.
    """

    def __init__(self, data, img_size=256):
        self.data = data
        self.img_size = img_size
        # Deterministic preprocessing only: resize longest side to img_size, zero-pad to a square.
        self.augmentation = A.Compose([
                                       A.LongestMaxSize(max_size=self.img_size, interpolation=1),
                                       A.PadIfNeeded(min_height=self.img_size, min_width=self.img_size, border_mode=0, value=(0,0,0))
                                      ])
        # HWC uint8 -> CHW float tensor, normalised with ImageNet channel statistics.
        self.transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                                         torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return a dict with the row index as ``id``, ``text``, ``metadata`` (original H, W, C) and ``image``."""
        row = self.data[idx]
        # convert("RGB") so grayscale / RGBA / palette images all become (H, W, 3): the
        # 3-channel Normalize and the fixed-length (3,) metadata vector both require it.
        with Image.open(io.BytesIO(row["image"][0]['bytes'])) as pil_image:
            image = np.asarray(pil_image.convert("RGB"))
        metadata = np.array(image.shape)
        image = self.transform(self.augmentation(image=image)['image'])
        return {"id": idx, "text": row["caption"][0], "metadata": metadata, "image": image}

In [None]:
class ArcMarginProduct(torch.nn.Module):
    """ArcFace-style additive angular margin classification head.

    Produces ``s * cos(theta)`` for every class, where ``cos(theta)`` is the cosine
    between the L2-normalised feature vector and the L2-normalised class weight.
    When labels are given, the target entries are blended towards
    ``s * cos(theta + m)``.

    Args:
        in_features: size of each input feature vector.
        out_features: number of classes.
        s: scale multiplied into every output.
        m: additive angular margin, in radians.
    """

    def __init__(self, in_features, out_features, s=1.0, m=0.2):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = torch.nn.Parameter(torch.FloatTensor(out_features, in_features))
        torch.nn.init.xavier_uniform_(self.weight)

        # Precomputed so cos(theta + m) = cos*cos_m - sin*sin_m needs no trig per call.
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # NOTE(review): th / mm are the usual ArcFace fallback constants for theta + m > pi
        # (phi = where(cos > th, phi, cos - mm)), but forward() does not apply that fallback.
        # Kept as attributes for compatibility; confirm whether the fallback is wanted.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, features, label):
        """Compute (margin-adjusted) scaled cosine logits.

        Args:
            features: (batch, in_features) tensor.
            label: None at inference time; otherwise either a float one-hot / soft target
                of shape (batch, out_features), used as per-entry blend weights between the
                margin cosine and the plain cosine, or a 1-D integer tensor of class indices.

        Returns:
            (batch, out_features) tensor of scaled logits.
        """
        cos = torch.nn.functional.linear(torch.nn.functional.normalize(features), torch.nn.functional.normalize(self.weight))
        if label is None:
            return cos * self.s
        if label.dim() == 1 and not label.is_floating_point():
            # Class indices -> one-hot so both label formats share the blend below.
            label = torch.nn.functional.one_hot(label.long(), self.out_features).to(cos.dtype)
        # A strictly positive lower bound keeps sqrt's backward finite when |cos| rounds to
        # exactly 1 (e.g. in fp16); otherwise 0/0 puts NaN into the gradients, even for
        # non-target entries whose blend weight is 0.
        sin = torch.sqrt((1.0 - torch.pow(cos, 2)).clamp(min=1e-7, max=1.0))
        cos_add = cos * self.cos_m - sin * self.sin_m
        output = (label * cos_add) + ((1.0 - label) * cos)
        return output * self.s

In [None]:
class ModelFactory(torch.nn.Module):
    """Multimodal classifier over caption text, image-shape metadata and image features.

    For each batch it concatenates, in this order, the text model's hidden state at
    position 0, ``metadata / 512`` and the pooled output of every timm image backbone.
    It projects the result to ``hashLength`` dims with a linear layer, scores it with an
    ``ArcMarginProduct`` head and returns softmax probabilities (not logits).

    Args:
        image_backbones: timm model names, each created with ``pretrained=True, num_classes=0``.
        text_model: model whose output exposes ``last_hidden_state`` (BART in this notebook).
        n_classes: number of output classes.
        hashLength: width of the projection fed to the margin head.
        feature_dim: width of the concatenated feature vector. It must equal the text
            hidden size + metadata length + the sum of the backbones' pooled feature widths.
            The default 3691 is the value previously hard-coded for the backbones used here.
    """

    def __init__(self, image_backbones, text_model, n_classes, hashLength=512, feature_dim=3691):
        super(ModelFactory, self).__init__()
        self.text_model = text_model
        # ModuleList (not a plain list) registers the backbones as submodules, so
        # parameters() (and therefore the optimizer), train()/eval(), .to() and
        # state_dict() all cover them. Multi-GPU is handled by wrapping the whole
        # ModelFactory in DataParallel; wrapping each backbone again here would nest
        # DataParallel inside the outer replicas.
        self.image_models = torch.nn.ModuleList(
            [timm.create_model(image_backbone, pretrained=True, num_classes=0).to(device)
             for image_backbone in image_backbones]
        )

        self.hash = torch.nn.Linear(feature_dim, hashLength)
        self.fc = ArcMarginProduct(hashLength, n_classes)
        self.output_layer = torch.nn.Softmax(dim=1)

    def forward(self, images, metadata, input_ids, attention_mask, labels=None):
        """Return (batch, n_classes) softmax probabilities.

        Args:
            images: image batch, fed unchanged to every backbone.
            metadata: per-sample original image shape values (scaled by 1/512 here).
            input_ids: tokenised captions for ``text_model``.
            attention_mask: padding mask matching ``input_ids``.
            labels: optional targets forwarded to the margin head; None at inference.
        """
        image_features = [model(images) for model in self.image_models]
        # Hidden state of the first token serves as the caption embedding.
        text_features = self.text_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        features = torch.cat([text_features, metadata/512] + image_features, dim=1)
        fhash = self.hash(features)
        logits = self.fc(fhash, labels)
        return self.output_layer(logits)

In [None]:
class Accuracy:
    """Running top-1 accuracy.

    Predictions and targets are both reduced with argmax over dim 1 and compared.
    Counts accumulate across ``update`` calls until ``reset``.
    """

    def __init__(self):
        self.reset()

    def update(self, predictions, targets):
        """Add one batch: count rows where argmax(predictions) equals argmax(targets)."""
        hits = predictions.argmax(dim=1).eq(targets.argmax(dim=1))
        self.correct_count += hits.sum().item()
        self.total_count += targets.size(0)

    def reset(self):
        """Zero both counters."""
        self.correct_count, self.total_count = 0, 0

    def compute(self):
        """Return correct / total, or 0.0 before any sample has been seen."""
        return self.correct_count / self.total_count if self.total_count != 0 else 0.0

In [None]:
# Seed Python, NumPy and torch RNGs so shuffling and head initialisation are repeatable
# (albumentations may use its own generator depending on its version — confirm if needed).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# NOTE: despite the names, these are DataLoaders (later cells refer to these names).
train_dataset = torch.utils.data.DataLoader(DefactifyDataset(df_train), batch_size=12, shuffle=True, num_workers=0)
test_dataset = torch.utils.data.DataLoader(PredictionDataset(df_test), batch_size=12, shuffle=False, num_workers=0)

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
text_model = BartModel.from_pretrained("facebook/bart-base")
model = ModelFactory(["efficientnetv2_rw_m.agc_in1k", "swinv2_tiny_window8_256.ms_in1k"], text_model, 6).to(device)

# Resume from the checkpoint written by the training loop, if one exists, so a fresh
# Restart & Run All does not crash on a missing file.
try:
    # weights_only=False: the training loop pickles the whole nn.Module, which torch>=2.6
    # refuses to unpickle by default. Unpickling can execute arbitrary code, so only load
    # checkpoints produced by this notebook.
    checkpoint = torch.load("model_meta.pth", map_location=device, weights_only=False)
except FileNotFoundError:
    checkpoint = None  # first run: keep the pretrained weights

if checkpoint is not None:
    # A DataParallel-wrapped save would prefix every key with "module."; use the inner module.
    checkpoint = getattr(checkpoint, "module", checkpoint)
    load_result = model.load_state_dict(checkpoint.state_dict(), strict=False)
    # Backbones held in a plain Python list are not part of a module's state_dict, so a
    # checkpoint may lack `image_models.*` entries; those backbones keep their pretrained
    # timm weights. Any other key mismatch is still an error, as with strict loading.
    skipped_backbone_keys = [k for k in load_result.missing_keys if k.startswith("image_models.")]
    other_missing_keys = [k for k in load_result.missing_keys if not k.startswith("image_models.")]
    if other_missing_keys or load_result.unexpected_keys:
        raise RuntimeError(f"Checkpoint does not match model: missing={other_missing_keys}, "
                           f"unexpected={load_result.unexpected_keys}")
    if skipped_backbone_keys:
        print(f"Checkpoint has no image-backbone weights; kept pretrained weights for "
              f"{len(skipped_backbone_keys)} tensors.")
    del checkpoint

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
loss_function = torch.nn.CrossEntropyLoss()
metrics = [Accuracy()]

# Loss scaling only matters for fp16 on CUDA; make the CPU case explicit instead of relying on a warning.
scaler = torch.amp.GradScaler(device.type, enabled=(device.type == "cuda"))
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)

In [None]:
# Each epoch: one full optimisation pass over train_dataset and a checkpoint.
# Every third epoch, test-set predictions are also written to answer_{epoch}.json.
for epoch in range(100):
    gc.collect()
    model.train()
    for metric in metrics:
        metric.reset()

    pbar = tqdm(train_dataset)
    for batch in pbar:
        texts = tokenizer(batch["text"], return_tensors="pt", padding=True, truncation=True)
        input_ids = texts["input_ids"].to(device)
        attention_mask = texts["attention_mask"].to(device)
        images = batch["image"].to(device)
        labels = batch["label"].to(device)
        metadata = batch["metadata"].to(device, dtype=torch.float16)
        with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=True):
            output = model(images, metadata, input_ids, attention_mask, labels)
            # The model returns softmax probabilities, but CrossEntropyLoss applies its own
            # log_softmax. Feeding log-probabilities makes that a no-op (log_softmax(log p) == log p
            # when p sums to 1), giving the intended NLL instead of a double softmax.
            loss = loss_function(torch.log(output.float().clamp_min(1e-12)), labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        for metric in metrics:
            metric.update(output.detach(), labels)
        scores = ", ".join(f"{type(metric).__name__}: {metric.compute():.5f}" for metric in metrics)
        pbar.set_description(f"Epoch: {epoch}, Loss: {loss.item():.5f}, {scores}")

    # Save the unwrapped module so its state_dict keys carry no DataParallel "module." prefix.
    torch.save(getattr(model, "module", model), "model_meta.pth")

    if epoch % 3 == 0:
        model.eval()
        predictions = []
        # Inference only: skip building autograd graphs to save memory and time.
        with torch.no_grad():
            for batch in tqdm(test_dataset):
                texts = tokenizer(batch["text"], return_tensors="pt", padding=True, truncation=True)
                input_ids = texts["input_ids"].to(device)
                attention_mask = texts["attention_mask"].to(device)
                images = batch["image"].to(device)
                metadata = batch["metadata"].to(device, dtype=torch.float16)
                with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=True):
                    probs = model(images, metadata, input_ids, attention_mask)
                probs = probs.float().cpu().numpy()
                for sample_id, caption, prob in zip(batch["id"], batch["text"], probs):
                    argmax_pred = int(prob.argmax())
                    predictions.append({'id': int(sample_id),
                                        'caption': caption,
                                        # Class 0 is the 'coco_image' column in training; Label_A flags any other class.
                                        'Label_A': int(argmax_pred > 0),
                                        'Label_B': argmax_pred,
                                        'Prob': prob.tolist(),
                                       })

        with open(f"answer_{epoch}.json", 'w') as f:
            json.dump(predictions, f, indent=4)