In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPModel, CLIPTokenizer, CLIPTextModel, CLIPVisionModel, get_cosine_schedule_with_warmup
from torch.utils.data import DataLoader, Dataset
import PIL
import os
from tqdm import tqdm
import pandas as pd
import torchvision.transforms as transforms
import regex
import numpy as np
from sklearn.model_selection import train_test_split
import random

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
np.random.seed(0)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
elco_df = pd.read_csv('../../data/ELCo.csv')
device

'cpu'

In [4]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    os.environ["PYTHONHASHSEED"] = str(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# set_seed()

def comp_type_map(comp_type):
    type_to_label = {'Direct': 0, 'Metaphorical': 1, 'Semantic list': 2, 'Reduplication': 3, 'Single': 4}
    return type_to_label[comp_type]

def label_to_comp_type(label):
    label_to_type = {0: 'Direct', 1: 'Metaphorical', 2: 'Semantic list', 3: 'Reduplication', 4: 'Single'}
    return label_to_type[label]

In [5]:
class EmojiCLIP(nn.Module):
    def __init__(self, clip_model='openai/clip-vit-base-patch32'):
        super().__init__()
        self.vit = CLIPVisionModel.from_pretrained(clip_model)

        # freeze ViT in early training
        for param in self.vit.parameters():
            param.requires_grad = False
    def extract_embedding(self, image):
        with torch.no_grad():
            return self.vit(pixel_values=image).pooler_output
    
    def forward(self, image):
        image_embedding = self.extract_embedding(image)
        return image_embedding

In [6]:
class EmojiImageDataset(Dataset):
    def __init__(self, image_dir, transform=transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])):
        super().__init__()
        self.image_dir = image_dir
        self.images = os.listdir(image_dir)
        self.transform = transform
    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.images[idx])
        image = PIL.Image.open(image_path)
        if self.transform:
            image = self.transform(image)
        return image, self.images[idx].removesuffix('.png')

In [7]:
emoji_image_dataset = EmojiImageDataset('images')
emoji_image_dataloader = DataLoader(emoji_image_dataset, batch_size=64, shuffle=False)

model = EmojiCLIP().eval().to(device)
all_embeddings = []
all_emoji_descs = []
with torch.no_grad():
    for batch_images, batch_descs in tqdm(emoji_image_dataloader, desc='Extracting embeddings'):
        batch_images = batch_images.to(device)
        batch_embeddings = model(batch_images)
        all_embeddings.append(batch_embeddings)
        all_emoji_descs.extend(batch_descs)

all_embeddings = torch.cat(all_embeddings, dim=0) # (N, 768)
emoji_embed_dict = {desc: embed for desc, embed in zip(all_emoji_descs, all_embeddings)}

Extracting embeddings: 100%|██████████| 14/14 [00:54<00:00,  3.89s/it]


In [8]:
class TypeClassifier(nn.Module):
    def __init__(self, input_dim=1024, hidden_dim=512, num_classes=5, dropout=0.1):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes)
        )
        
    def forward(self, x):
        return self.model(x)

In [9]:
class EmojisDataset(Dataset):
    def __init__(self, emoji_embed_dict, elco_df, em_max_len=3, text_max_len=4, tokenizer='openai/clip-vit-base-patch32'):
        super().__init__()
        self.elco_df = elco_df
        self.emoji_descriptions = [self.preprocess_emoji_description(desc) for desc in elco_df["Description"]]
        self.raw_emoji_descriptions = elco_df["Description"].values
        self.emoji_embed_dict = emoji_embed_dict
        self.em_max_len = em_max_len
        self.text_max_len = text_max_len
        self.clip_tokenizer = CLIPTokenizer.from_pretrained(tokenizer)
        self.composition_type = [comp_type_map(t) for t in elco_df['Composition strategy'].values]
    
    def preprocess_emoji_description(self, text):
        text = text.replace('\'\'', '').lower()
        split_text = regex.findall(r'\':?(.*?):?\'', text)
        return split_text
    def preprocess_en(self, text):
        return text.lower().strip()
    
    def __len__(self):
        return len(self.emoji_descriptions)
    
    def __getitem__(self, index):
        emoji_descs = self.emoji_descriptions[index]
        emoji_embeds = [self.emoji_embed_dict[desc] for desc in emoji_descs]
        en_text = self.preprocess_en(self.elco_df['EN'].values[index])
        composition_type = self.composition_type[index]
        assert len(emoji_embeds) == len(emoji_descs)
        if len(emoji_embeds) < self.em_max_len:
            emoji_embeds += [torch.zeros_like(emoji_embeds[0]) for _ in range(self.em_max_len - len(emoji_embeds))]
        else:
            emoji_embeds = emoji_embeds[:self.em_max_len]
        # emoji_embeds = torch.stack(emoji_embeds, dim=0)
        emoji_embeds = torch.concatenate(emoji_embeds, dim=-1)
        
        en_tokens = self.clip_tokenizer(en_text, 
                                  truncation=True, 
                                  padding='max_length',
                                  return_tensors='pt',
                                  max_length=self.text_max_len)
        en_tokens = {k: v.squeeze(0) for k, v in en_tokens.items()}
        return (emoji_embeds, en_tokens), composition_type
        

In [10]:
generator = torch.Generator()
# generator.manual_seed(42)
train_df, validate_df = train_test_split(elco_df, test_size=0.1, random_state=42)
# validate_df, test_df = train_test_split(validate_df, test_size=0.3, random_state=42)
emojis_train_dataset = EmojisDataset(emoji_embed_dict, train_df)
emojis_train_dataloader = DataLoader(emojis_train_dataset, batch_size=32, shuffle=True, drop_last=True, generator=generator)
emojis_validate_dataset = EmojisDataset(emoji_embed_dict, validate_df)
emojis_validate_dataloader = DataLoader(emojis_validate_dataset, batch_size=32, shuffle=False, drop_last=True)
# emojis_test_dataset = EmojisDataset(emoji_embed_dict, test_df)
# emojis_test_dataloader = DataLoader(emojis_test_dataset, batch_size=32, shuffle=False, drop_last=True)

In [11]:
for em_en_pair, ctype in emojis_train_dataloader:
    print(em_en_pair[0].shape, em_en_pair[1]['input_ids'].shape)
    print(ctype[0])
    break

torch.Size([32, 2304]) torch.Size([32, 4])
tensor(1)


In [19]:
class EmojiCompositionModel(nn.Module):
    def __init__(self, text_encoder='openai/clip-vit-base-patch32',
                 image_embed_dim=768, projection_dim=512, num_layers=2, freeze_clip=False):
        super().__init__()


        # self.emoji_transformer = ImageEmbedTransformer(embedding_dim=image_embed_dim, num_layers=num_layers)

        self.text_encoder = CLIPTextModel.from_pretrained(text_encoder).text_model
        self.text_hidden_dim = self.text_encoder.config.hidden_size
        
        if freeze_clip:
            for name, param in self.text_encoder.named_parameters():
                    # if "encoder.layers.10" in name or "encoder.layers.11" in name:
                    #     param.requires_grad = True
                    # else:
                    param.requires_grad = False
                    
        
        self.text_proj = nn.Sequential(
            nn.Linear(self.text_hidden_dim, projection_dim),
            nn.ReLU(),
            nn.LayerNorm(projection_dim)
            )
        self.image_proj = nn.Sequential(
            nn.Linear(image_embed_dim, projection_dim),
            nn.ReLU(),
            nn.LayerNorm(projection_dim)
            )
        
        self.classifier = TypeClassifier(input_dim=projection_dim * 4, hidden_dim=projection_dim * 2, num_classes=5)
        
    def forward(self, img_seq, text_input):
        # z_image = self.emoji_transformer(img_seq)[:, 0, :]
        z_image = self.image_proj(img_seq)
        text_out = self.text_encoder(**text_input).pooler_output
        z_text = self.text_proj(text_out)

        z_image = F.normalize(z_image, dim=-1)
        z_text = F.normalize(z_text, dim=-1)
        
        z_concat = torch.cat([z_image, z_text, torch.abs(z_image - z_text), z_image*z_text], dim=-1)
        output = self.classifier(z_concat)

        return (z_image, z_text), output

In [17]:
class EarlyStopping:
    def __init__(self, patience=3, delta=0.0):
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss, accuracy):
        if accuracy > 0.7:
            self.early_stop = True
        if val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

In [23]:
# reconstruct text dataloader -> each bacth is a list of all possible en for the emojis in given batch
en_em_model = EmojiCompositionModel(image_embed_dim=2304, freeze_clip=True).to(device)
optimizer = torch.optim.AdamW(en_em_model.parameters(), lr=2e-5)
early_stopping = EarlyStopping(patience=3)
criterion = nn.CrossEntropyLoss()
num_epochs = 70

for epoch in range(num_epochs):
    total_loss = 0.0
    en_em_model.train()
    for (emoji_embeds, en_input), ctype in emojis_train_dataloader:
        (z_emojis, z_text), output = en_em_model(emoji_embeds, en_input)
        
        classifier_loss = criterion(output, torch.tensor(ctype))
        
        match_labels = torch.arange(emojis_train_dataloader.batch_size).to(device)
        loss_per_emojis = z_emojis @ z_text.T
        loss_per_text = loss_per_emojis.T
        loss_em = F.cross_entropy(loss_per_emojis, match_labels)
        loss_text = F.cross_entropy(loss_per_text, match_labels)

        loss = classifier_loss + (loss_em + loss_text) / 2
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    total_loss /= len(emojis_train_dataloader)
        
    en_em_model.eval()
    correct_count = 0
    with torch.no_grad():
        for (emoji_embeds, en_input), ctype in emojis_validate_dataloader:
            _, output = en_em_model(emoji_embeds, en_input)
            prediction = torch.argmax(output, dim=-1)
            correct_count += torch.sum(prediction == ctype).item()
        en_accuracy = correct_count / len(emojis_validate_dataset)
    
    print(f'Epoch {epoch} - Loss: {total_loss}, Accuracy: {en_accuracy}')
    
    early_stopping(total_loss, en_accuracy)
    if early_stopping.early_stop:
        print(f'Early stopping at epoch {epoch}')
        break

  classifier_loss = criterion(output, torch.tensor(ctype))


Epoch 0 - Loss: 5.015411843424258, Accuracy: 0.43373493975903615
Epoch 1 - Loss: 4.861016470453014, Accuracy: 0.43373493975903615
Epoch 2 - Loss: 4.65356069025786, Accuracy: 0.43373493975903615
Epoch 3 - Loss: 4.537453889846802, Accuracy: 0.43373493975903615
Epoch 4 - Loss: 4.4455646950265635, Accuracy: 0.4578313253012048
Epoch 5 - Loss: 4.356900650521983, Accuracy: 0.5240963855421686
Epoch 6 - Loss: 4.270261961480846, Accuracy: 0.5421686746987951
Epoch 7 - Loss: 4.172721396321836, Accuracy: 0.5903614457831325
Epoch 8 - Loss: 4.089020371437073, Accuracy: 0.6204819277108434
Epoch 9 - Loss: 4.004605225894762, Accuracy: 0.6445783132530121
Epoch 10 - Loss: 3.922036824019059, Accuracy: 0.6144578313253012
Epoch 11 - Loss: 3.8473214273867398, Accuracy: 0.6385542168674698
Epoch 12 - Loss: 3.7773318342540576, Accuracy: 0.6265060240963856
Epoch 13 - Loss: 3.716186487156412, Accuracy: 0.6445783132530121
Epoch 14 - Loss: 3.645752435145171, Accuracy: 0.6506024096385542
Epoch 15 - Loss: 3.5882493361