# Imports and data

In [None]:




import pandas as pd
import numpy as np
import jsonlines
import seaborn as sns
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch_optimizer as optim
import random
from transformers import AutoModelWithLMHead, AutoTokenizer, AutoModel

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
from importlib import reload
pd.set_option('display.max_rows', 500)
pd.set_option('display.float_format', '{:0.3f}'.format)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
pd.options.display.width = 0
import warnings
import torchvision
warnings.filterwarnings('ignore')

from facebook_hateful_memes_detector.utils.globals import set_global, get_global
# Project-wide settings stored through the package's global registry.
# NOTE(review): "cache_dir" and "models_dir" are absolute, machine-specific
# paths; they must be adjusted before running this notebook on another host.
set_global("cache_dir", "/home/ahemf/cache/cache")
set_global("dataloader_workers", 8)
set_global("use_autocast", True)
set_global("models_dir", "/home/ahemf/cache/")

from facebook_hateful_memes_detector.utils import read_json_lines_into_df, in_notebook, set_device, random_word_mask, dict2sampleList, run_simclr, load_stored_params
# Read the setting back; the value is shown in the cell output because
# `ast_node_interactivity` is set to "all" above.
get_global("cache_dir")
from facebook_hateful_memes_detector.models import Fasttext1DCNNModel, MultiImageMultiTextAttentionEarlyFusionModel, LangFeaturesModel, AlbertClassifer
from facebook_hateful_memes_detector.preprocessing import TextImageDataset, get_datasets, get_image2torchvision_transforms, TextAugment
from facebook_hateful_memes_detector.preprocessing import DefinedRotation, QuadrantCut, ImageAugment, DefinedAffine, HalfSwap, get_image_transforms, get_transforms_for_bbox_methods
from facebook_hateful_memes_detector.preprocessing import NegativeSamplingDataset, ImageFolderDataset, ZipDatasets
from facebook_hateful_memes_detector.models.MultiModal.VilBertVisualBert import VilBertVisualBertModel
from facebook_hateful_memes_detector.training import *
import facebook_hateful_memes_detector
from facebook_hateful_memes_detector.utils import get_vgg_face_model, get_torchvision_classification_models, init_fc, my_collate, merge_sample_lists
reload(facebook_hateful_memes_detector)

# Use CUDA when torch reports it as available, otherwise the CPU, and pass the
# chosen device to the project's `set_device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_device(device)

# NOTE(review): called with no arguments and the result is later passed to the
# training functions as `scheduler_init_fn`, so this is presumably a project
# factory brought in by the star import from `training` rather than the
# `transformers` function of the same name -- confirm.
scheduler_init_fn = get_cosine_schedule_with_warmup()
# Use mixup in SSL training, Use UDA maybe


In [None]:
def get_preprocess_text():
    """Build the text-augmentation callable used for training text.

    Four ``TextAugment`` stages are constructed in this order: character level,
    word level, sentence level and gibberish/punctuation. The returned function
    applies them as sentence -> word -> character -> gibberish.

    The leading list and the per-operation table are handed to ``TextAugment``
    unchanged. Several tables do not sum to 1 (the word-level one sums to 1.2),
    so the values are presumably relative weights -- TODO confirm against
    ``TextAugment``.

    Returns:
        A one-argument function that takes ``text`` and returns the result of
        passing it through all four stages.
    """
    char_aug = TextAugment(
        [0.1, 0.8, 0.1],
        dict(
            keyboard=0.1,
            char_substitute=0.4,
            char_insert=0.2,
            char_swap=0.2,
            ocr=0.0,
            char_delete=0.1,
        ),
    )
    word_aug = TextAugment(
        [0.2, 0.7, 0.1],
        dict(
            fasttext=0.0,
            glove_twitter=0.0,
            glove_wiki=0.0,
            word2vec=0.0,
            split=0.2,
            stopword_insert=0.0,
            word_join=0.2,
            word_cutout=0.8,
            gibberish_insert=0.0,
        ),
    )
    sentence_aug = TextAugment(
        [0.75, 0.25],
        dict(
            text_rotate=0.0,
            sentence_shuffle=0.0,
            one_third_cut=0.3,
            half_cut=0.0,
            part_select=0.75,
        ),
    )
    noise_aug = TextAugment(
        [0.75, 0.25],
        dict(
            gibberish_insert=0.25,
            punctuation_insert=0.75,
        ),
    )

    # Application order differs from construction order: coarse to fine, then noise.
    pipeline = (sentence_aug, word_aug, char_aug, noise_aug)

    def process(text):
        # Thread the text through every stage in turn.
        for stage in pipeline:
            text = stage(text)
        return text

    return process


# Text augmenter and the image transform handed to get_datasets below as the
# train-time transforms.
preprocess_text = get_preprocess_text()
transforms_for_bbox_methods = get_transforms_for_bbox_methods()

# Two image pipelines: the project's image transforms in "easy" or "hard" mode,
# each followed by the project's image -> torchvision conversion step.
# NOTE(review): `preprocess_easy` is not referenced again in the visible notebook.
preprocess_easy = transforms.Compose([
    get_image_transforms(mode="easy"),
    get_image2torchvision_transforms(),
])

# The "hard" pipeline is the one wrapped by `preprocess_torchvision_image` later on.
preprocess = transforms.Compose([
    get_image_transforms(mode="hard"),
    get_image2torchvision_transforms(),
])

# Load the dataset splits from ../data/. Only the train-side text and image
# transforms are set; every test-side and torchvision transform is None, and
# mixup is switched off (proba=0.0).
# The result is indexed below with "train", "dev", "test" and "metadata".
data = get_datasets(data_dir="../data/",
                    train_text_transform=preprocess_text,
                    train_image_transform=transforms_for_bbox_methods,
                    test_text_transform=None,
                    test_image_transform=None,
                    train_torchvision_image_transform=None,
                    test_torchvision_image_transform=None,
                    train_torchvision_pre_image_transform=None,
                    test_torchvision_pre_image_transform=None,
                    cache_images=True,
                    use_images=True,
                    dev=False,
                    test_dev=True,
                    keep_original_text=True,
                    keep_original_image=True,
                    keep_processed_image=True,
                    keep_torchvision_image=False,
                    train_mixup_config=dict(proba=0.0))



# Pool every split into one frame for self-supervised training: the "label"
# column is dropped from train and dev, and the test split is appended as-is.
df = pd.concat(
    [data[split].drop(columns=["label"]) for split in ("train", "dev")] + [data["test"]]
)

In [None]:
# Turn the pooled frame into the dataset object used by every training stage below.
# NOTE(review): the meaning of the third positional argument (True) is defined by
# convert_dataframe_to_dataset, which is not visible here -- confirm.
dataset = convert_dataframe_to_dataset(df, data["metadata"], True)

In [None]:
# MLM for TextImage, AugSim for All

# data = get_datasets(data_dir="../data/",
#                     train_text_transform=None,
#                     train_image_transform=None,
#                     test_text_transform=None,
#                     test_image_transform=None,
#                     train_torchvision_image_transform=None,
#                     test_torchvision_image_transform=None,
#                     train_torchvision_pre_image_transform=None,
#                     test_torchvision_pre_image_transform=None,
#                     cache_images=True,
#                     use_images=True,
#                     dev=False,
#                     test_dev=True,
#                     keep_original_text=True,
#                     keep_original_image=True,
#                     keep_processed_image=True,
#                     keep_torchvision_image=False,
#                     train_mixup_config=dict(proba=0.0))
# dataset = convert_dataframe_to_dataset(df, data["metadata"], True)



In [None]:
# Element-wise wrapper so the text augmenter can be applied to a whole batch of
# strings in a single call (used in augment_method).
vectorized_text_processor = np.vectorize(preprocess_text)

def torch_vectorize(fn):
    """Lift a per-item tensor function to one that works on a sequence.

    Args:
        fn: Callable applied to each element; every result must be a tensor of
            the same shape so the results can be stacked.

    Returns:
        A function that maps ``fn`` over its iterable argument and returns the
        results stacked along a new leading dimension via ``torch.stack``.
    """
    def vfn(elements):
        return torch.stack(list(map(fn, elements)))  # .type(torch.cuda.HalfTensor)
    return vfn

# Batch version of the "hard" image pipeline: applies `preprocess` to each image
# and stacks the per-image results into one tensor.
preprocess_torchvision_image = torch_vectorize(preprocess)

def vectorized_image_processor(images):
    """Apply ``transforms_for_bbox_methods`` to every image.

    Args:
        images: Iterable of images accepted by ``transforms_for_bbox_methods``.

    Returns:
        A list with one transformed image per input, in the same order.
    """
    return list(map(transforms_for_bbox_methods, images))

def augment_method(sampleList):
    """Produce an augmented copy of a batch from its original text and images.

    The input is converted with ``dict2sampleList`` and copied, then:

    * ``torchvision_image`` is recomputed from ``original_image`` only when the
      batch already carries that field;
    * ``image`` is rebuilt from ``original_image`` with the bbox-method transforms;
    * ``text`` is rebuilt from ``original_text`` with the text augmenter;
    * ``mixup`` is set to ``False`` for every sample.

    Args:
        sampleList: Batch accepted by ``dict2sampleList``.

    Returns:
        The augmented batch after ``.to(get_device())``.
    """
    batch = dict2sampleList(sampleList, device=get_device()).copy()

    if "torchvision_image" in batch:
        batch.torchvision_image = preprocess_torchvision_image(batch.original_image)

    batch.image = vectorized_image_processor(batch.original_image)
    batch.text = vectorized_text_processor(batch.original_text)
    # One flag per sample; augmented batches are never marked as mixup.
    batch.mixup = [False for _ in range(len(batch.text))]

    return batch.to(get_device())


In [None]:
# Visual spot-check of a single dataset entry (index 5000; 191 is another index
# noted by the author).
# NOTE(review): what exactly is rendered is defined by `show_mixup` -- not visible here.
dataset.show_mixup(5000) # 191

In [None]:
# Constructor arguments for the multimodal model. `model_name` carries a single
# "lxmert" entry with its own dropout / gaussian-noise settings (presumably
# selecting the LXMERT backbone -- confirm), while the top-level noise, dropout
# and auxiliary-loss coefficients are all zero and `finetune` is False.
model_params = dict(
    model_name={"lxmert": dict(dropout=0.05, gaussian_noise=0.01)},
    num_classes=2,
    gaussian_noise=0.0,
    dropout=0.0,
    word_masking_proba=0.1,
    featurizer="pass",
    final_layer_builder=fb_1d_loss_builder,
    internal_dims=768,
    classifier_dims=768,
    n_tokens_in=96,
    n_tokens_out=96,
    n_layers=0,
    attention_drop_proba=0.0,
    loss="focal",
    dice_loss_coef=0.0,
    auc_loss_coef=0.0,
    bbox_swaps=1,
    bbox_copies=1,
    bbox_gaussian_noise=0.0,
    finetune=False)

model_class = VilBertVisualBertModel
model = model_class(**model_params)
# As a bare expression this is also shown in the cell output because
# `ast_node_interactivity` is set to "all".
model.to(get_device())



# Unimodal MLM

In [None]:


# NOTE(review): `optimizer_params` and `optimizer` assigned here are both
# reassigned in the first MLM training cell before anything reads them, and
# `adam`, `adam_params` and `params` are not referenced again in the visible
# notebook.
adam = torch.optim.Adam
adam_params = params = dict(lr=1e-4, weight_decay=1e-2)
optimizer = adam
optimizer_params = adam_params


from facebook_hateful_memes_detector.utils import MLMPretraining
# Wrap the base model for masked-language-model pretraining, reusing the
# tokenizer held by the model's text processor.
# NOTE(review): 768 and 96 equal `internal_dims` and `n_tokens_in` in
# `model_params`; presumably hidden size and token count -- confirm against
# MLMPretraining's signature.
mlm_model = MLMPretraining(model, model.text_processor._tokenizer, 768, "relu", 96)
mlm_model = mlm_model.to(get_device())


In [None]:

# MLM stage 1: the "model" group has finetune=False and the "mlm" group has
# finetune=True, i.e. only the MLM-specific part is meant to be trained.
# NOTE(review): the keys presumably name sub-modules of `mlm_model` -- confirm
# against group_wise_finetune / group_wise_lr.
lr_strategy = {
    "model": {
        "finetune": False,
    },
    "mlm": {
        "finetune": True
    }
}
epochs = 3
batch_size = 128
optimizer_class = torch.optim.AdamW
optimizer_params = dict(lr=1e-4, betas=(0.9, 0.98), eps=1e-08, weight_decay=1e-3)

# Apply the strategy to the module, derive per-group optimizer parameters from
# it, then build the optimizer from those groups.
_ = group_wise_finetune(mlm_model, lr_strategy)
params_conf, _ = group_wise_lr(mlm_model, lr_strategy)
optimizer = optimizer_class(params_conf, **optimizer_params)
train_losses, learning_rates, _ = train(mlm_model, optimizer, scheduler_init_fn, batch_size, epochs, dataset,
                                     model_call_back=None, accumulation_steps=4, plot=True,
                                     sampling_policy=None, class_weights=None)



# Post-training diagnostics provided by the MLM wrapper.
mlm_model.plot_loss_acc_hist()
mlm_model.test_accuracy(batch_size, dataset)




In [None]:
# torch.save(mlm_model.state_dict(), "lxmert-mlm-init.pth")
# mlm_model.load_state_dict(torch.load("lxmert-mlm-init.pth"))


In [None]:

# MLM stage 2: both groups now have finetune=True, with a 10x smaller learning
# rate (1e-5) and a smaller batch size than stage 1.
epochs = 5
batch_size = 48
optimizer_class = torch.optim.AdamW
optimizer_params = dict(lr=1e-5, betas=(0.9, 0.98), eps=1e-08, weight_decay=1e-4)

# The "model" group gets an explicit lr equal to the optimizer default; the
# "mlm" group specifies none.
lr_strategy = {
    "model": {
        "finetune": True,
        "lr": optimizer_params["lr"]
    },
    "mlm": {
        "finetune": True
    }
}

_ = group_wise_finetune(mlm_model, lr_strategy)
params_conf, _ = group_wise_lr(mlm_model, lr_strategy)
optimizer = optimizer_class(params_conf, **optimizer_params)
train_losses, learning_rates, _ = train(mlm_model, optimizer, scheduler_init_fn, batch_size, epochs, dataset,
                                     model_call_back=None, accumulation_steps=5, plot=True,
                                     sampling_policy=None, class_weights=None)



# Post-training diagnostics; the accuracy result is kept in `acc` this time.
mlm_model.plot_loss_acc_hist()
acc = mlm_model.test_accuracy(batch_size, dataset)


In [None]:
# Same call as at the end of the previous cell, kept in its own cell so the
# accuracy check can be re-run without retraining.
acc = mlm_model.test_accuracy(batch_size, dataset)

In [None]:
# Save only the state dict of the wrapper's `.model` attribute, not the whole
# MLM wrapper. This file is loaded back into `model` at the start of the AugSim
# section.
torch.save(mlm_model.model.state_dict(), "lxmert-mlm.pth")


# AugSim
- Combine both unimodal and bimodal AugSim using `random.random`
- Take hints from SimCLR
- We can do Text x Image (TODO: CrissCrossDataset for AugSim)


In [None]:
# Start AugSim from the weights saved at the end of the MLM section.
load_stored_params(model, "lxmert-mlm.pth")
# NOTE(review): presumably allows the project's caches to be written during
# this stage -- confirm what "cache_allow_writes" controls.
set_global("cache_allow_writes", True)


In [None]:
# Optimizer class and keyword arguments used by the AugSim training cell below.
adamw = torch.optim.AdamW
adamw_params = dict(lr=1e-5, betas=(0.9, 0.98), eps=1e-08, weight_decay=1e-3)
optimizer_class = adamw
optimizer_params = adamw_params


In [None]:
# AugSim stage: train `model` with train_for_augment_similarity, using
# `augment_method` to produce the augmented view of each batch.
epochs = 10
batch_size = 64

# Defined here so the cell runs on a fresh kernel (Restart & Run All); no
# earlier cell defines `lr_strategy_model`. The whole model is fine-tuned,
# mirroring the shape of `lr_strategy_post` in the SimCLR section.
# NOTE(review): confirm this matches the strategy originally used for AugSim.
lr_strategy_model = {
    "finetune": True,
}

_ = group_wise_finetune(model, lr_strategy_model)
params_conf, _ = group_wise_lr(model, lr_strategy_model)
# Bound to `optimizer`, not `optim`, so the `torch_optimizer as optim` module
# alias from the imports cell is not overwritten.
optimizer = optimizer_class(params_conf, **optimizer_params)

_ = train_for_augment_similarity(model,
                                 optimizer,
                                 scheduler_init_fn,
                                 batch_size,
                                 epochs,
                                 dataset,
                                 augment_method=augment_method,
                                 model_call_back=None,
                                 collate_fn=my_collate,
                                 accumulation_steps=4,
                                 plot=True)
# 0.001580, 0.000527
# Try Augsim with L2 normed / LayerNormed vectors


In [None]:
# Checkpoint the model after AugSim training; the commented line below restores it.
torch.save(model.state_dict(), "lxmert-augsim.pth")
# model.load_state_dict(torch.load("lxmert-augsim.pth"))

# SimCLR style or Differentiator
- Combine Unimodal and Bimodal with probability
- In unimodal differentiator we only change either text or image
- Ability to use non-overlapping image sections.


In [None]:
# NOTE(review): "lxmert-smclr.pth" is the file written by the last cell of this
# section, so on a first run it will not exist yet. If the intent is to continue
# from the AugSim stage, "lxmert-augsim.pth" (saved above) is the matching
# checkpoint; if this is a deliberate resume of an earlier SimCLR run, keep it.
load_stored_params(model, "lxmert-smclr.pth")


In [None]:
from facebook_hateful_memes_detector.utils import SimCLR

def simclr_aug(sampleList):
    """Augment a batch and append a copy of it with the text order reversed.

    The batch is first passed through ``augment_method``. A copy of the
    augmented batch then has its ``text`` entries reversed, so texts are
    re-paired with the other fields by position, and both batches are combined
    with ``merge_sample_lists``.

    Args:
        sampleList: Batch exposing ``.copy()`` and accepted by ``augment_method``.

    Returns:
        The result of merging the augmented batch with its text-reversed copy.
    """
    augmented = augment_method(sampleList.copy())
    text_flipped = augmented.copy()
    text_flipped.text = list(reversed(text_flipped.text))
    return merge_sample_lists(augmented, text_flipped)

# set_global("cache_allow_writes", False)


In [None]:
# Wrap the model for SimCLR-style training; `simclr_aug` (defined in the
# previous cell) is supplied for both augmentation arguments.
# NOTE(review): the meaning of the positional numbers (768, 256, 0.05) is
# defined by SimCLR's signature, which is not visible here -- confirm.
smclr = SimCLR(model, 768, 256, 0.05, simclr_aug, simclr_aug)
smclr = smclr.to(get_device())

# Strategy for the first phase: fine-tuning enabled at the top level but
# disabled for the "model" group.
lr_strategy_pre = {
    "finetune": True,
    "model": {
        "finetune": False,
    },
}

# Strategy for the second phase: fine-tune everything.
lr_strategy_post = {
    "finetune": True,
}

pre_lr, post_lr = 5e-5, 5e-5
pre_batch_size, post_batch_size = 256, 32
pre_epochs, full_epochs = 2, 5
collate_fn = my_collate

# `simclr_aug` is intentionally not redefined in this cell: a body without a
# `return` would rebind the name to a function returning None, and re-running
# the cell would then hand that broken augmenter to SimCLR above.


In [None]:
# Run SimCLR training with the pre/post strategies, learning rates, batch sizes
# and epoch counts configured above. The same `dataset` object is passed for
# both dataset arguments.
# NOTE(review): the roles of the two dataset arguments are defined by
# run_simclr, which is not visible here -- confirm.
res = run_simclr(smclr, dataset, dataset, lr_strategy_pre, lr_strategy_post, pre_lr, post_lr,
           pre_batch_size, post_batch_size, pre_epochs, full_epochs,
           collate_fn)

# Show the returned result in the cell output.
res

# 0.3268


In [None]:
# Checkpoint the base model (not the SimCLR wrapper) after SimCLR training.
torch.save(model.state_dict(), "lxmert-smclr.pth")