In [1]:
import torch.nn as nn

class VLPForFeatureExtraction(nn.Module):

    def __init__(self,
                 model_dim,
                 num_layers,
                 num_heads,
                 ff_dim,
                 feature_map_size,
                 dropout=0.0,
                 dec_div=2):
        super(VLPForFeatureExtraction, self).__init__()

        self.vlp = VLP(num_layers=num_layers,
                       model_dim=model_dim,
                       ff_dim=ff_dim,
                       num_heads=num_heads,
                       feature_map_size=feature_map_size,
                       vocab_size=None,
                       dropout=dropout,
                       transformer_type="encoder",
                       dec_div=dec_div)

    def forward(self, images):
        out = self.vlp(images)
        return out.mean(dim=1)

In [2]:
import sys
sys.path.append("../")
import torch
from transformers import AutoTokenizer

from model.vlp import VLP, model_config_factory

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_tokenizer():

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

    special_tokens_dict = {'additional_special_tokens': ['[START]', '[END]']}
    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
    vocab_size = tokenizer.vocab_size + num_added_toks

    return tokenizer, vocab_size

tokenizer, vocab_size = get_tokenizer()

config = model_config_factory("encoder_decoder_lg")
model = VLPForFeatureExtraction(model_dim=config["model_dim"],
                      num_layers=config["num_layers"],
                      num_heads=config["num_heads"],
                      ff_dim=config["ff_dim"],
                      feature_map_size=config["feature_map_size"],
                      dropout=0.0)
model_state = torch.load("../best_model_wiki_enc_sota.pth")["model_state"]
model.load_state_dict(model_state, strict=False)
model.eval()
model.to(device)



VLPForFeatureExtraction(
  (vlp): VLP(
    (feature_extractor): Sequential(
      (0): ResNetFeatureExtractor(
        (feature_extractor): Sequential(
          (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
          (4): Sequential(
            (0): Bottleneck(
              (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
       

In [3]:
with open("../fonts/ubuntu_fonts.txt", "r") as f:
    fonts_lines = f.readlines()

invalid_font_families = [
    "kacst",
    "lohit",
    "samyak",
]

invalid_fonts = [
    '/usr/share/fonts/type1/urw-base35/D050000L.t1',
    '/usr/share/fonts/opentype/urw-base35/D050000L.otf',
    '/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf',
    '/usr/share/fonts/truetype/Gubbi/Gubbi.ttf',
    '/usr/share/fonts/truetype/fonts-kalapi/Kalapi.ttf',
    '/usr/share/fonts/truetype/sinhala/lklug.ttf',
    '/usr/share/fonts/truetype/Navilu/Navilu.ttf',
    '/usr/share/fonts/truetype/openoffice/opens___.ttf',
    '/usr/share/fonts/truetype/fonts-gujr-extra/padmaa-Medium-0.5.ttf',
    '/usr/share/fonts/truetype/fonts-telu-extra/Pothana2000.ttf',
    '/usr/share/fonts/truetype/malayalam/RaghuMalayalamSans-Regular.ttf',
    '/usr/share/fonts/truetype/fonts-guru-extra/Saab.ttf',
    '/usr/share/fonts/type1/urw-base35/StandardSymbolsPS.t1',
    '/usr/share/fonts/opentype/urw-base35/StandardSymbolsPS.otf',
    '/usr/share/fonts/truetype/fonts-orya-extra/utkal.ttf',
    '/usr/share/fonts/truetype/fonts-telu-extra/vemana2000.ttf',
    '/usrà¦¤à¦¿',
    '/usr/share/fonts/truetype/noto/NotoColorEmoji.ttf'
]

processed_fonts = []
for f_l in fonts_lines:
    font_path = f_l.split(":")[0]
    if font_path in invalid_fonts or any(i_f in font_path.lower() for i_f in invalid_font_families):
            continue
    processed_fonts.append(font_path)

In [4]:
# large (up to 140 words)
font_category_distributions = {
    "small": (30, 30, 55),
    "medium": (23, 30, 65), # 30
    "large": (20, 30, 75)
}


In [5]:
def get_text_coords(im_dims, text_dims, align):
    im_width, im_height = im_dims
    text_width, text_height = text_dims
    if align=="center":
        x_text = (im_width - text_width) / 2
        y_text = (im_height - text_height) / 2
    elif align=="right":
        x_text = im_width - text_width - 10
        y_text = 10
    elif align=="down":
        x_text = (im_width - text_width) / 2
        y_text = im_height - text_height - 10
    elif align=="up":
        x_text = (im_width - text_width) / 2
        y_text = 10
    # align left
    else:
        x_text = 10
        y_text = 10
        
    x_text, y_text = min(4, x_text), min(4, y_text)
    return x_text, y_text


In [6]:
from collections import defaultdict
import random
from tqdm.auto import tqdm
import textwrap

ALIGNMENTS = ["center", "left", "right", "down", "up"]

def get_text_category(num_tokens):
    if num_tokens <= 50:
        return "small"
    elif num_tokens <= 100:
        return "medium"
    else:
        return "large"
    
def sample_random_font():
    return random.choice(processed_fonts)

def sample_random_alignment():
    return random.choices(ALIGNMENTS, weights=[0.25, 0.2, 0.1, 0.2, 0.25])[0]

def sample_fs_and_tw(text_category):
    font_size, min_text_width, max_text_width = font_category_distributions[text_category]
    return font_size, random.randrange(min_text_width, max_text_width)

In [7]:
def create_image(text,
                 filename,
                 image_size=512,
                 font="arial.ttf",
                 font_size=20,
                 bg_color=(255, 255, 255),
                 bg='white',
                 align="left",
                 width=636,
                 height=636,
                 text_width=40,
                 resize=True,
                 save_image=True):
    font_size = int(font_size / 1.5)

    # text width for wrapping
    text = '\n'.join(textwrap.wrap(text, width=text_width))  #, width=18
    font = ImageFont.truetype(font, font_size)

    image = Image.new(mode="RGB", size=(width, height),
                      color="white")  # (700, 620)
    draw = ImageDraw.Draw(image)

    l, t, r, b = draw.multiline_textbbox((0, 0), text, font=font)
    l_offset = abs(l) + 10
    t_offset = abs(t) + 10
    l = l_offset
    t = t_offset
    r += l_offset
    b += t_offset
    draw.text((l, t), text, font=font, fill="black")
    image = image.crop((0, 0, r + 10, b + 10))

    if resize:
        image = image.resize((image_size, image_size), Image.Resampling.LANCZOS)
    if save_image:
        image.save(filename)
    return image

In [8]:
from datasets import load_dataset


dataset = load_dataset("imdb")

Found cached dataset imdb (/home/wavelet/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)


  0%|          | 0/3 [00:00<?, ?it/s]

In [9]:
import torch
import PIL
import torchvision.transforms.functional as F
from torchvision.transforms import Resize

import torch.nn as nn
import copy

from PIL import Image, ImageFont, ImageDraw
from nltk import word_tokenize


def prep_image(text, tokenizer, max_text_len, device=None, image_size=512):
    if not device:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    num_tokens = len(word_tokenize(text))
    text_category = get_text_category(num_tokens)

    font_path = sample_random_font()
    font_size, text_width = sample_fs_and_tw(text_category)
    alignment = sample_random_alignment()
    
    image = create_image(
                text,
                font=font_path,
                filename=None,
                bg_color=(255, 255, 255),
                bg='white',
                align=alignment,
                font_size=font_size,
                text_width=text_width,
                resize=True,
                save_image=False)
    image = Resize((image_size, image_size))(image)
    return F.to_tensor(image).to(device).unsqueeze(0)


def predict_text(
    model,
    text,
    tokenizer,
    max_text_len,
    image_size=512,
    device=None,
):
    image = prep_image(text, tokenizer, max_text_len)
    out = model(image)
    return out

def predict(
    model,
    images
):
    out = model(images)
    return out

In [10]:
import pandas as pd

data = []

for instance in tqdm(dataset["train"]):
    if len(word_tokenize(instance["text"])) > 144:
        continue
    data.append({"text": instance["text"], "label": instance["label"]})

df_train = pd.DataFrame(data)

  0%|          | 0/25000 [00:00<?, ?it/s]

In [12]:
df_train

Unnamed: 0,text,label
0,If only to avoid making this type of film in t...,0
1,I would put this at the top of my list of film...,0
2,Whoever wrote the screenplay for this movie ob...,0
3,My interest in Dorothy Stratten caused me to p...,0
4,I think I will make a movie next weekend. Oh w...,0
...,...,...
5208,This movie really kicked some ass. I watched i...,1
5209,With the mixed reviews this got I wasn't expec...,1
5210,"Very smart, sometimes shocking, I just love it...",1
5211,A hit at the time but now better categorised a...,1


In [30]:
import spacy
from tqdm.notebook import tqdm
tqdm.pandas()

def extract_embedding(text):
    emb = predict_text(model, text, tokenizer, 144)
    return emb[0].cpu().detach().tolist()

df_train["embedding"] = df_train.text.progress_apply(extract_embedding)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


  0%|          | 0/5213 [00:00<?, ?it/s]

In [31]:
import numpy as np

embeddings = np.array(df_train.embedding.tolist())
labels = df_train.label.astype("category")

In [25]:
from sklearn.svm import LinearSVC, SVC
from sklearn.calibration import CalibratedClassifierCV
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
 
X_train, X_test, y_train, y_test = train_test_split(embeddings, labels, test_size=0.2, random_state=42)

print("train", X_train.shape[0], len(y_train))
print("test", X_test.shape[0], len(y_test))
 
pca = PCA(n_components=300)
# clf = LinearSVC(multi_class="ovr", C=0.1, max_iter=1000, random_state=42)
clf = LogisticRegression()
clf = CalibratedClassifierCV(clf)
# pipe = Pipeline(steps=[("pca", pca), ("clf", clf)])
pipe = Pipeline(steps=[("clf", clf)])
 
# Traing a classifer
pipe.fit(X_train, y_train)
 
# Evaluate the classifier
y_pred = pipe.predict(X_test)
 
acc = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
class_report = classification_report(y_test, y_pred)
 
print("Accuracy: ", acc)
print("Confusion matrix: \n", conf_matrix)
print("Classification report: \n", class_report)
 
y_pred_train = pipe.predict(X_train)
class_report = classification_report(y_train, y_pred_train)
print("Classification report train: \n", class_report)

train 4170 4170
test 1043 1043


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

Accuracy:  0.7785234899328859
Confusion matrix: 
 [[386 120]
 [111 426]]
Classification report: 
               precision    recall  f1-score   support

           0       0.78      0.76      0.77       506
           1       0.78      0.79      0.79       537

    accuracy                           0.78      1043
   macro avg       0.78      0.78      0.78      1043
weighted avg       0.78      0.78      0.78      1043

Classification report train: 
               precision    recall  f1-score   support

           0       0.79      0.77      0.78      1920
           1       0.81      0.83      0.82      2250

    accuracy                           0.80      4170
   macro avg       0.80      0.80      0.80      4170
weighted avg       0.80      0.80      0.80      4170



In [None]:
# VLP

Accuracy:  0.7785234899328859
Confusion matrix: 
 [[386 120]
 [111 426]]
Classification report: 
               precision    recall  f1-score   support

           0       0.78      0.76      0.77       506
           1       0.78      0.79      0.79       537

    accuracy                           0.78      1043
   macro avg       0.78      0.78      0.78      1043
weighted avg       0.78      0.78      0.78      1043

Classification report train: 
               precision    recall  f1-score   support

           0       0.79      0.77      0.78      1920
           1       0.81      0.83      0.82      2250

    accuracy                           0.80      4170
   macro avg       0.80      0.80      0.80      4170
weighted avg       0.80      0.80      0.80      4170

In [None]:
# BERT

Accuracy:  0.8207094918504314
Confusion matrix: 
 [[420  86]
 [101 436]]
Classification report: 
               precision    recall  f1-score   support

           0       0.81      0.83      0.82       506
           1       0.84      0.81      0.82       537

    accuracy                           0.82      1043
   macro avg       0.82      0.82      0.82      1043
weighted avg       0.82      0.82      0.82      1043

Classification report train: 
               precision    recall  f1-score   support

           0       0.82      0.83      0.83      1920
           1       0.86      0.85      0.85      2250

    accuracy                           0.84      4170
   macro avg       0.84      0.84      0.84      4170
weighted avg       0.84      0.84      0.84      4170

In [None]:
# spacy glove

Accuracy:  0.840843720038351
Confusion matrix: 
 [[417  89]
 [ 77 460]]
Classification report: 
               precision    recall  f1-score   support

           0       0.84      0.82      0.83       506
           1       0.84      0.86      0.85       537

    accuracy                           0.84      1043
   macro avg       0.84      0.84      0.84      1043
weighted avg       0.84      0.84      0.84      1043

Classification report train: 
               precision    recall  f1-score   support

           0       0.86      0.84      0.85      1920
           1       0.87      0.88      0.88      2250

    accuracy                           0.86      4170
   macro avg       0.86      0.86      0.86      4170
weighted avg       0.86      0.86      0.86      4170
