In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

%pip install tiktoken verovio accelerate -q

import os
import sys
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import re

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, ConcatDataset


from datasets import Dataset as DS
from datasets import load_metric

import seaborn as sn

import matplotlib.pyplot as plt

import random

import time

from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import MinMaxScaler, StandardScaler


import gc

import inspect

try:
    from transformers import (
        AutoTokenizer,
        AutoModelForSeq2SeqLM,
        DataCollatorForSeq2Seq,
        Seq2SeqTrainer,
        Seq2SeqTrainingArguments,
        TrainingArguments,
        Trainer,
        pipeline,
        AutoModelForSequenceClassification
    )
except:
    %pip install transformers
    from transformers import (
        AutoTokenizer,
        AutoModelForSeq2SeqLM,
        DataCollatorForSeq2Seq,
        Seq2SeqTrainer,
        Seq2SeqTrainingArguments,
        pipeline
    )
    
from transformers import AutoConfig, AutoModel, AutoModelForSequenceClassification

print("All libraries have been installed successfully!", end="\r")

In [None]:
import pandas as pd

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
data = pd.read_csv("dataset.csv")

In [None]:
data

In [None]:
data.reset_index(drop=True, inplace=True)

In [None]:
data.columns

In [None]:
def sanity_check(df=None, image_dir="./images"):
    """Report rows whose image file does not exist on disk.

    Args:
        df: DataFrame with an ``image`` column of filenames. Defaults to the
            notebook-global ``data`` so the original zero-argument call still works.
        image_dir: directory the filenames are relative to.

    Returns:
        List of missing image filenames, in row order. Each one is also printed.
    """
    frame = data if df is None else df
    missing = []
    # Iterate the single column directly; iterrows() would build a Series per row.
    for name in frame["image"]:
        if not os.path.exists(f"{image_dir}/{name}"):
            print(name, "does not exist")
            missing.append(name)
    return missing
        

In [None]:
# sanity_check()

In [None]:
train_df, val_df = train_test_split(data, test_size=0.2, shuffle=True)
train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)

In [None]:
val_df, test_df = train_test_split(val_df, test_size=0.5, shuffle=True)

In [None]:
test_df = test_df.reset_index(drop=True)

In [None]:
len(train_df), len(val_df), len(test_df)

In [None]:
torch.cuda.empty_cache()

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

In [None]:
from transformers import AutoModel, AutoTokenizer, BeitFeatureExtractor, BeitForMaskedImageModeling, AutoModelForPreTraining, BertTokenizerFast, VisualBertModel

In [None]:
from transformers import ViltProcessor, ViltModel

In [None]:
OCR_MODEL_NAME = "ucaslcl/GOT-OCR2_0"
MODEL_NAME = "dandelin/vilt-b32-mlm"

# trust_remote_code=True lets transformers execute modelling code shipped in the model repo.
# NOTE(review): that code runs locally — consider pinning a `revision`.
ocr_tokenizer = AutoTokenizer.from_pretrained(OCR_MODEL_NAME, trust_remote_code=True)
ocr_model = AutoModel.from_pretrained(OCR_MODEL_NAME, trust_remote_code=True, 
                                      low_cpu_mem_usage=True, device_map=device, use_safetensors=True, pad_token_id=ocr_tokenizer.eos_token_id)

# Use `.to(device)` rather than `.cuda()`, which raises on CPU-only machines even though `device` was chosen above.
# NOTE(review): GOT-OCR's remote `chat` implementation may itself assume CUDA — confirm before running on CPU.
ocr_model = ocr_model.eval().to(device)
# NOTE(review): this replaces the eos-based pad_token_id passed to from_pretrained above
# with the tokenizer's pad id, which may be None — confirm which one is intended.
ocr_model.generation_config.pad_token_id = ocr_tokenizer.pad_token_id

processor = ViltProcessor.from_pretrained(MODEL_NAME)
model = ViltModel.from_pretrained(MODEL_NAME)

In [None]:
import torchvision.transforms as transforms

In [None]:
# res = ocr_model.chat(ocr_tokenizer, "/kaggle/input/muslim-hate-memes/images/0desqkyb24r51.jpeg", ocr_type='ocr')

In [None]:
# res

In [None]:
from datasets import Dataset as DS
from datasets import load_metric

In [None]:
from PIL import Image

In [None]:
class MuslimHateMemes(Dataset):
    """Map-style dataset that turns one meme image into ViLT inputs.

    Each item runs the notebook-global OCR model (``ocr_model`` with
    ``ocr_tokenizer``) on the image file. It then feeds the resized RGB image
    and the OCR text through the global ``processor``.

    Args:
        df: DataFrame with ``image`` (filename relative to ``image_dir``) and
            ``label`` columns. Rows are accessed positionally via ``iloc``.
        image_dir: directory containing the image files.
        image_size: (width, height) passed to ``PIL.Image.Image.resize``.
        max_length: padding/truncation length for the text passed to ``processor``.
    """

    def __init__(self, df, image_dir="./images", image_size=(252, 252), max_length=40):
        self.df = df
        self.image_dir = image_dir
        self.image_size = image_size
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(inputs, label, image_path)`` for row ``index``.

        ``inputs`` is the processor output with a size-1 leading dimension
        squeezed from every entry, so a DataLoader can re-batch it.
        ``label`` is an ``int``.
        """
        row = self.df.iloc[index]
        image_path = f"{self.image_dir}/{row['image']}"
        label = int(row['label'])
        # OCR is recomputed on every access (nothing is cached).
        ocr_text = ocr_model.chat(ocr_tokenizer, image_path, ocr_type='ocr')
        # Open inside a context manager so the file handle is released right away.
        # resize() returns an in-memory copy, so `image` stays valid after the block.
        # The resize-then-convert order is kept to match the existing preprocessing.
        with Image.open(image_path) as img:
            image = img.resize(self.image_size).convert("RGB")
        inputs = processor(image, ocr_text, max_length=self.max_length, padding='max_length', truncation=True, return_tensors="pt")

        for key in inputs:
            inputs[key] = inputs[key].squeeze(0)

        return inputs, label, image_path

In [None]:
trainds = MuslimHateMemes(train_df)
valds = MuslimHateMemes(val_df)
testds = MuslimHateMemes(test_df)

In [None]:
trainloader = DataLoader(trainds, batch_size=16, shuffle=True)
valloader = DataLoader(valds, batch_size=2, shuffle=False)
testloader = DataLoader(testds, batch_size=2, shuffle=False)

In [None]:
# example = next(iter(trainloader))

In [None]:
# model

In [None]:
class MemeClassifer(nn.Module):
    """Meme classifier: a ViLT-style backbone, an MLP head and a final sigmoid.

    Only the backbone's ``encoder.layer`` modules have ``requires_grad`` set
    here. Embeddings, pooler and the head keep whatever trainability they had.

    Args:
        finetune: ``"all"`` makes every encoder layer trainable. Any other
            value freezes all encoder layers except the last ``enc_finetune_limit``.
        enc_finetune_limit: number of trailing encoder layers left trainable
            when ``finetune != "all"``. Ignored otherwise.
        output: width of the head's final linear layer.
        backbone: module to wrap. It must expose ``encoder.layer``, and its
            ``forward(**inputs)`` must return an object with ``pooler_output``.
            Defaults to the notebook-global ``model`` (the ViLT model loaded earlier).
    """

    def __init__(
        self,
        finetune: str = "limit",
        enc_finetune_limit: int = 2, 
        output: int = 1,
        backbone: "nn.Module | None" = None,
    ):
        super().__init__()
        self.vilt = model if backbone is None else backbone
        hidden_size = self._backbone_hidden_size(self.vilt)
        self.clf = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Dropout(0.3),
            nn.Linear(hidden_size, hidden_size, bias=True),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_size, output, bias=True),
        )
        self.act = nn.Sigmoid()
        self._set_encoder_trainability(finetune, enc_finetune_limit)

    @staticmethod
    def _backbone_hidden_size(backbone, default: int = 768) -> int:
        """Return ``backbone.config.hidden_size`` when available, else ``default``."""
        return getattr(getattr(backbone, "config", None), "hidden_size", default)

    def _set_encoder_trainability(self, finetune: str, enc_finetune_limit: int) -> None:
        """Set ``requires_grad`` on every encoder layer according to the fine-tuning policy."""
        layers = list(self.vilt.encoder.layer.children())
        # "all": start at layer 0. Otherwise keep only the trailing `enc_finetune_limit` layers trainable.
        # A limit larger than the layer count makes every layer trainable.
        first_trainable = 0 if finetune == "all" else len(layers) - enc_finetune_limit
        for idx, layer in enumerate(layers):
            for param in layer.parameters():
                param.requires_grad = idx >= first_trainable

    def forward(self, inputs):
        """Run the backbone on ``inputs`` (unpacked as keyword arguments) and return sigmoid scores."""
        outputs = self.vilt(**inputs)
        return self.act(self.clf(outputs.pooler_output))

In [None]:
memeclsmodel = MemeClassifer(finetune="all", enc_finetune_limit=3)

In [None]:
memeclsmodel = memeclsmodel.to(device)

In [None]:
# memeclsmodel

In [None]:
def reset_weights(m):
    '''
    Re-initialise each *direct* child of ``m`` that defines
    ``reset_parameters``, to avoid weight leakage between runs.

    Only ``m.children()`` is scanned, not deeper descendants. Presumably this
    is meant to be used as ``model.apply(reset_weights)``, which visits every
    submodule — TODO confirm with the caller.
    '''
    resettable = [child for child in m.children() if hasattr(child, 'reset_parameters')]
    for child in resettable:
        print(f'Reset trainable parameters of layer = {child}')
        child.reset_parameters()

In [None]:
%pip install evaluate -q

In [None]:
import evaluate

# Metric objects used by compute_metrics_v2.
# The second positional argument of evaluate.load is the config name.
accuracy = evaluate.load('accuracy')
f1_metric = evaluate.load("f1", "binary")
precision_metric = evaluate.load("precision", "binary")
# Bound once, without the stray `f1_macro` alias that pointed a global named f1_macro at the recall metric.
recall_metric = evaluate.load("recall", "binary")
# clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])

In [None]:
def compute_metrics_v2(preds, labels):
    """Compute classification metrics for hard 0/1 predictions.

    Uses the notebook-level ``accuracy``, ``f1_metric``, ``precision_metric``
    and ``recall_metric`` objects.

    Args:
        preds: predicted labels, either a torch tensor (on any device) or an array-like.
        labels: ground-truth labels, in the same form as ``preds``.

    Returns:
        Dict with ``accuracy``, ``f1_macro``, ``f1_micro``, ``f1_weighted``,
        ``precision`` and ``recall``. Precision and recall use
        ``average='weighted'`` with ``zero_division=1``.
    """
    def _as_int_array(values):
        # Tensors may live on the GPU and carry autograd history; detach and move them first.
        if isinstance(values, torch.Tensor):
            values = values.detach().cpu().numpy()
        # Rounded float predictions (0.0/1.0) become integer class ids.
        return np.asarray(values).astype(int)

    preds = _as_int_array(preds)
    labels = _as_int_array(labels)

    acc = accuracy.compute(predictions=preds, references=labels)['accuracy']
    f1_macro = f1_metric.compute(predictions=preds, references=labels, average="macro")['f1']
    f1_micro = f1_metric.compute(predictions=preds, references=labels, average="micro")['f1']
    f1_weighted = f1_metric.compute(predictions=preds, references=labels, average="weighted")['f1']

    precision = precision_metric.compute(predictions=preds, references=labels, average='weighted', zero_division=1)['precision']

    recall = recall_metric.compute(predictions=preds, references=labels, average='weighted', zero_division=1)['recall']

    return {
        "accuracy": acc,
        "f1_macro": f1_macro,
        "f1_micro": f1_micro,
        "f1_weighted": f1_weighted,
        "precision": precision,
        "recall": recall,
    }

In [None]:
# BCELoss expects probabilities, which MemeClassifer produces through its final sigmoid.
loss_function = torch.nn.BCELoss()
# Adam over every classifier parameter; frozen ones get no gradient and so are not updated.
optimizer = torch.optim.Adam(params=memeclsmodel.parameters(), lr=3e-4)

In [None]:
def training(model, training_loader):
    """Train ``model`` for one epoch and report loss and classification metrics.

    Uses the notebook-level ``optimizer``, ``loss_function``, ``device`` and
    ``compute_metrics_v2``.

    Args:
        model: classifier whose output, after ``squeeze(-1)``, holds one
            probability per example. It is rounded to get a 0/1 prediction.
        training_loader: iterable of ``(inputs, labels, img_paths)`` batches.

    Returns:
        ``(epoch_loss, metrs)``:
        - ``epoch_loss`` is the mean batch loss.
        - ``metrs`` maps ``"train/<metric>"`` to a value computed once over all
          predictions of the epoch. Averaging per-batch F1/precision/recall
          would be biased for small or uneven batches.

    Raises:
        ValueError: if ``training_loader`` yields no batches.
    """
    model.train()

    try:
        wandb.watch(model, log=None, log_freq=10)
    except Exception:
        # No `wandb` import is visible in this notebook, so this normally ends up here.
        print("WANDB not logging")

    tr_loss = 0.0
    nb_tr_steps = 0
    epoch_preds, epoch_labels = [], []

    for step, (batch, labels, _img_paths) in enumerate(training_loader):
        batch = batch.to(device)
        labels = labels.to(device, dtype=torch.float)

        outputs = model(batch).squeeze(-1).float()
        loss = loss_function(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tr_loss += loss.item()
        nb_tr_steps += 1

        preds = torch.round(outputs.detach())
        epoch_preds.append(preds.cpu())
        epoch_labels.append(labels.cpu())

        if step % 1000 == 0:
            # Running mean loss over the steps seen so far, plus metrics of the current batch.
            print(f"Training Loss per 1000 steps: {tr_loss / nb_tr_steps}")
            print(compute_metrics_v2(preds, labels))

        # Release cached GPU memory between steps.
        torch.cuda.empty_cache()

    if nb_tr_steps == 0:
        raise ValueError("training_loader yielded no batches")

    epoch_loss = tr_loss / nb_tr_steps
    epoch_metrics = compute_metrics_v2(torch.cat(epoch_preds), torch.cat(epoch_labels))
    metrs = {f"train/{name}": value for name, value in epoch_metrics.items()}

    print(f"Total Training: {epoch_loss}")
    print(f"Total Training Metrics:")

    print(metrs)

    return epoch_loss, metrs

In [None]:
def testing(model, testing_loader):
    """Evaluate ``model`` on ``testing_loader`` without gradient tracking.

    Uses the notebook-level ``loss_function``, ``device`` and ``compute_metrics_v2``.

    Args:
        model: classifier whose output, after ``squeeze(-1)``, holds one
            probability per example. It is rounded to get a 0/1 prediction.
        testing_loader: iterable of ``(inputs, labels, img_paths)`` batches.

    Returns:
        ``(epoch_loss, metrs)``:
        - ``epoch_loss`` is the mean batch loss.
        - ``metrs`` maps ``"eval/<metric>"`` to a value computed once over all
          predictions. Per-batch averages are unreliable with the tiny
          evaluation batches used in this notebook.

    Raises:
        ValueError: if ``testing_loader`` yields no batches.
    """
    model.eval()

    try:
        wandb.watch(model, log=None, log_freq=10)
    except Exception:
        # No `wandb` import is visible in this notebook, so this normally ends up here.
        print("WANDB not logging")

    tr_loss = 0.0
    nb_tr_steps = 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for step, (batch, labels, _img_paths) in enumerate(testing_loader):
            batch = batch.to(device)
            labels = labels.to(device, dtype=torch.float)

            outputs = model(batch).squeeze(-1).float()
            loss = loss_function(outputs, labels)
            tr_loss += loss.item()
            nb_tr_steps += 1

            preds = torch.round(outputs)
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())

            if step % 1000 == 0:
                # Running mean loss over the steps seen so far, plus metrics of the current batch.
                print(f"Validation Loss per 1000 steps: {tr_loss / nb_tr_steps}")
                print(compute_metrics_v2(preds, labels))

            torch.cuda.empty_cache()

    if nb_tr_steps == 0:
        raise ValueError("testing_loader yielded no batches")

    epoch_loss = tr_loss / nb_tr_steps
    epoch_metrics = compute_metrics_v2(torch.cat(all_preds), torch.cat(all_labels))
    metrs = {f"eval/{name}": value for name, value in epoch_metrics.items()}

    print(f"Total Eval Loss: {epoch_loss}")
    print(f"Total Eval Metrics:")
    print(metrs)

    return epoch_loss, metrs

In [None]:
best_loss = 1e9

In [None]:
epochs = 10

In [None]:
memeclsmodel.load_state_dict(torch.load("./best_weights.pth"))

In [None]:
test_loss, test_metrs = testing(memeclsmodel, testloader)

In [None]:
print("Test Loss:",test_loss)

In [None]:
print(test_metrs)