# Imports and data

In [None]:


import pandas as pd
import numpy as np
import jsonlines
import seaborn as sns
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch_optimizer as optim
import random
from transformers import AutoModelWithLMHead, AutoTokenizer, AutoModel

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
from importlib import reload
pd.set_option('display.max_rows', 500)
pd.set_option('display.float_format', '{:0.3f}'.format)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
pd.options.display.width = 0
import warnings
import torchvision
warnings.filterwarnings('ignore')

from facebook_hateful_memes_detector.utils.globals import set_global, get_global
# Project-wide settings stored through set_global; they are read back with
# get_global (e.g. get_global("cache_dir") further down in this cell).
set_global("cache_dir", "/home/ahemf/cache/cache")
# Alternative cache location, kept for reference:
# set_global("cache_dir", "/Users/ahemf/mygit/facebook-hateful-memes/cache")
set_global("dataloader_workers", 8)
set_global("use_autocast", True)
set_global("models_dir", "/home/ahemf/cache/")

from facebook_hateful_memes_detector.utils import read_json_lines_into_df, in_notebook, set_device, random_word_mask, dict2sampleList, run_simclr, load_stored_params
get_global("cache_dir")
from facebook_hateful_memes_detector.models import Fasttext1DCNNModel, MultiImageMultiTextAttentionEarlyFusionModel, LangFeaturesModel, AlbertClassifer
from facebook_hateful_memes_detector.preprocessing import TextImageDataset, get_datasets, get_image2torchvision_transforms, TextAugment
from facebook_hateful_memes_detector.preprocessing import DefinedRotation, QuadrantCut, ImageAugment, DefinedAffine, HalfSwap, get_transforms_for_bbox_methods
from facebook_hateful_memes_detector.preprocessing import get_transforms_for_multiview
from facebook_hateful_memes_detector.preprocessing import NegativeSamplingDataset, ImageFolderDataset, ZipDatasets
from facebook_hateful_memes_detector.models.MultiModal.VilBertVisualBert import VilBertVisualBertModel
from facebook_hateful_memes_detector.models.MultiModal import VilBertVisualBertModelV2, MLMSimCLR, MLMOnlyV2
from facebook_hateful_memes_detector.training import *
import facebook_hateful_memes_detector
from facebook_hateful_memes_detector.utils import get_vgg_face_model, get_torchvision_classification_models, init_fc, my_collate, merge_sample_lists
reload(facebook_hateful_memes_detector)

# Use the GPU when CUDA is available, otherwise fall back to the CPU, and
# register the chosen device with the project through set_device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_device(device)

# Passed to train() in the training cell as the scheduler argument.
# NOTE(review): get_cosine_schedule_with_warmup is not imported by name in this
# notebook; presumably it comes from the star import of
# facebook_hateful_memes_detector.training. It is called with no arguments, so
# it presumably returns a scheduler factory rather than a scheduler -- confirm.
scheduler_init_fn = get_cosine_schedule_with_warmup()
# TODO: use mixup in SSL training; maybe use UDA.


In [None]:
import random
def get_preprocess_text():
    """Build the text-augmentation callable used for training samples.

    Four ``TextAugment`` instances are created once (character, word, sentence
    and gibberish/punctuation level) and captured by the returned closure.

    Returns:
        A function ``process(text, **kwargs)`` that returns the augmented text.
        On each call the sentence-level augmenter runs with probability 0.25;
        the word, character and gibberish augmenters then always run, in that
        order. ``kwargs`` are forwarded unchanged to every augmenter.
    """
    # NOTE(review): the first positional list given to TextAugment is presumably
    # a distribution over how many augmentations to apply -- confirm against
    # the TextAugment implementation.
    char_weights = {
        "keyboard": 0.1,
        "char_substitute": 0.4,
        "char_insert": 0.2,
        "char_swap": 0.2,
        "ocr": 0.0,
        "char_delete": 0.1,
    }
    char_augmenter = TextAugment([0.1, 0.4, 0.5], char_weights)

    word_weights = {
        "split": 0.2,
        "stopword_insert": 0.0,
        "word_join": 0.2,
        "punctuation_continue": 0.5,
    }
    word_augmenter = TextAugment(
        [0.1, 0.4, 0.5],
        word_weights,
        fasttext_file="wiki-news-300d-1M-subword.bin",
    )

    # "glove_twitter" / "word_cutout" and the idf_file option are intentionally
    # left out of the sentence-level configuration.
    sentence_weights = {
        "text_rotate": 0.0,
        "sentence_shuffle": 0.0,
        "one_third_cut": 0.25,
        "half_cut": 0.0,
        "part_select": 0.25,
    }
    sentence_augmenter = TextAugment([0.1, 0.9], sentence_weights)

    gibberish_weights = {
        "gibberish_insert": 0.25,
        "punctuation_insert": 0.75,
        "punctuation_replace": 0.25,
        "punctuation_strip": 0.5,
    }
    gibberish_augmenter = TextAugment([0.25, 0.75], gibberish_weights)

    # A translation ("dab") augmenter was previously tried here and is disabled.

    def process(text, **kwargs):
        # The random draw happens first, before any augmenter is invoked.
        stages = [word_augmenter, char_augmenter, gibberish_augmenter]
        if random.random() < 0.25:
            stages.insert(0, sentence_augmenter)
        for stage in stages:
            text = stage(text, **kwargs)
        return text

    return process


# Text augmentation callable; passed as train_text_transform to get_datasets.
preprocess_text = get_preprocess_text()

def get_views():
    """Build the three view-transform callables handed to the model.

    Each view pairs one image transform from ``get_transforms_for_multiview()``
    with an optional text augmenter. Views 1 and 2 augment the text with their
    own ``TextAugment`` configuration; view 3 leaves the original text as is.

    Returns:
        A list ``[view_1, view_2, view_3]``. Each entry is a function
        ``aug(sampleList)`` that converts its argument with ``dict2sampleList``,
        copies it, fills ``image``, ``text`` and ``mixup`` from the
        ``original_image`` / ``original_text`` fields, moves the result to
        ``get_device()`` and returns it.
    """
    image_views = get_transforms_for_multiview()

    def build_view(imtrans, text_augs=None):
        # Shared factory for all three views, which previously were three
        # copy-pasted closures differing only in transform and augmenter.
        # np.vectorize applies the augmenter element-wise over the batch of
        # texts; with no augmenter the original texts are used untouched.
        vtp = np.vectorize(text_augs) if text_augs is not None else None

        def vip(images):
            return [imtrans(i) for i in images]

        def aug(sampleList):
            sampleList = dict2sampleList(sampleList, device=get_device())
            # Work on a copy so the incoming sample list is not modified.
            sampleList = sampleList.copy()
            sampleList.image = vip(sampleList.original_image)
            if vtp is None:
                sampleList.text = sampleList.original_text
            else:
                sampleList.text = vtp(sampleList.original_text)
            sampleList.mixup = [False] * len(sampleList.text)
            sampleList = sampleList.to(get_device())
            return sampleList

        return aug

    view_1_augs = {
        "keyboard": 0.4, "char_substitute": 0.4, "char_insert": 0.2, "char_swap": 0.2,
        "ocr": 0.0, "char_delete": 0.1, "gibberish_insert": 0.1, "punctuation_insert": 0.75,
        "punctuation_replace": 0.25, "punctuation_strip": 0.5, "word_join": 0.2,
        "punctuation_continue": 0.5,
    }
    view_1 = build_view(image_views[0], TextAugment([0.1, 0.9], view_1_augs))

    view_2_augs = {
        "keyboard": 0.4, "char_substitute": 0.2, "char_insert": 0.2, "char_swap": 0.1,
        "ocr": 0.0, "char_delete": 0.1, "gibberish_insert": 0.0, "punctuation_insert": 0.75,
        "punctuation_replace": 0.5, "punctuation_strip": 0.5, "word_join": 0.3,
        "punctuation_continue": 0.5,
    }
    view_2 = build_view(image_views[1], TextAugment([0.2, 0.8], view_2_augs))

    # Third view: image transform only, text left as the original.
    view_3 = build_view(image_views[2])

    return [view_1, view_2, view_3]

# Load the dataset splits. Training samples get the text augmentation built
# above and the bbox-method image transforms; test-time transforms are
# disabled. Original text/images are kept alongside the processed ones because
# the view transforms from get_views() read original_text / original_image.
data = get_datasets(data_dir="../data/",
                    train_text_transform=preprocess_text,
                    train_image_transform=get_transforms_for_bbox_methods(),
                    test_text_transform=None,
                    test_image_transform=None,
                    train_torchvision_pre_image_transform=None,
                    test_torchvision_pre_image_transform=None,
                    cache_images=True,
                    use_images=True,
                    dev=False,
                    test_dev=True,
                    keep_original_text=True,
                    keep_original_image=True,
                    keep_processed_image=True,
                    keep_torchvision_image=False,
                    train_mixup_config=None)


# Give every test row a placeholder label of -1 so the three splits share a
# "label" column before being stacked.
# NOTE(review): presumably the test split has no ground-truth labels -- confirm.
data["test"]["label"] = -1

# Stack train, dev and test into a single frame. The original row indices are
# kept (no ignore_index), so index values may repeat across splits.
df = pd.concat((data["train"],
                data["dev"], 
                data["test"]))

In [None]:
# Wrap the combined frame and the dataset metadata into the dataset object
# that train() consumes.
# NOTE(review): the meaning of the third positional argument (True) is not
# visible here -- confirm against convert_dataframe_to_dataset.
dataset = convert_dataframe_to_dataset(df, data["metadata"], True)

In [None]:
# Keyword arguments for VilBertVisualBertModelV2.
# model_name maps each backbone to its own dropout / gaussian_noise settings;
# the top-level dropout and gaussian_noise values are set to 0.0.
# view_transforms receives the three augmentation views built by get_views().
model_params = dict(
    model_name={"lxmert": dict(dropout=0.05, gaussian_noise=0.01), 
                "vilbert": dict(dropout=0.1, gaussian_noise=0.05), 
                "visual_bert": dict(dropout=0.1, gaussian_noise=0.05), 
                "mmbt_region": dict(dropout=0.1, gaussian_noise=0.05)},
    num_classes=2,
    gaussian_noise=0.0,
    dropout=0.0,
    word_masking_proba=0.15,
    featurizer="pass",
    final_layer_builder=fb_1d_loss_builder,
    internal_dims=768,
    classifier_dims=768,
    n_tokens_in=128,
    n_tokens_out=128,
    n_layers=0,
    attention_drop_proba=0.0,
    loss="focal",
    dice_loss_coef=0.0,
    auc_loss_coef=0.0,
    bbox_swaps=1,
    bbox_copies=1,
    bbox_deletes=0,
    bbox_gaussian_noise=0.01,
    view_transforms=get_views(),
    finetune=False)

# Instantiate the multimodal model and move it to the active device.
model_class = VilBertVisualBertModelV2
model = model_class(**model_params)
model = model.to(get_device())




# Unimodal MLM

In [None]:

# NOTE(review): these two assignments are overwritten in the training cell
# (optimizer_params is redefined and optimizer is rebuilt from optimizer_class)
# before either value is used, so they currently have no effect.
optimizer = torch.optim.AdamW
optimizer_params = dict(lr=1e-4, weight_decay=1e-2)

from facebook_hateful_memes_detector.models.MultiModal.VilBertVisualBertV2 import positive, negative
# Wrap the multimodal model for MLM-only training and move it to the device.
# NOTE(review): the meaning of the positional arguments (0.1, the
# {1: negative, 0: positive} mapping and None) is not visible here -- confirm
# against the MLMOnlyV2 signature.
mlm_model = MLMOnlyV2(model, 0.1, {1: negative, 0: positive}, None)
mlm_model = mlm_model.to(get_device())


In [None]:

# Per-group fine-tuning configuration: every backbone under "model" is marked
# finetune=False, while the "mlms" group is marked finetune=True.
# NOTE(review): presumably this freezes the backbones and trains only the MLM
# heads -- confirm against group_wise_finetune / group_wise_lr.
lr_strategy = {
    "finetune": True,
    "model": {
        "vilbert": {"finetune": False,},
        "visual_bert": {"finetune": False,},
        "mmbt_region": {"finetune": False,},
        "lxmert": {"finetune": False,},
    },
    "mlms": {"finetune": True},
}
epochs = 1
batch_size = 4
optimizer_class = torch.optim.AdamW
optimizer_params = dict(lr=1e-4, betas=(0.9, 0.98), eps=1e-08, weight_decay=1e-3)

# Apply the strategy to the model, build the parameter groups from the same
# strategy, then create the optimizer over those groups.
_ = group_wise_finetune(mlm_model, lr_strategy)
params_conf, _ = group_wise_lr(mlm_model, lr_strategy)
optimizer = optimizer_class(params_conf, **optimizer_params)
# Train for `epochs` epochs on the combined dataset with accumulation_steps=4;
# the third returned value is discarded.
train_losses, learning_rates, _ = train(mlm_model, optimizer, scheduler_init_fn, batch_size, epochs, dataset,
                                     model_call_back=None, accumulation_steps=4, plot=True,
                                     sampling_policy=None, class_weights=None)




In [None]:
# torch.save(mlm_model.state_dict(), "lxmert-mlm-init.pth")
# mlm_model.load_state_dict(torch.load("lxmert-mlm-init.pth"))


In [None]:
# Inspect the cache statistics stored under the "cache_stats" global.
# Both bare expressions are displayed because ast_node_interactivity is set to
# "all" in the first cell.
cache_stats = get_global("cache_stats")
cache_stats['get_img_details']
cache_stats['get_lxmert_details']