# Misinformation detection model (MID) full structure

In [1]:
import os
import re
import copy

from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import datasets, models, transforms

import pandas as pd
import numpy as np
from transformers import AutoModel, AutoTokenizer, ViTFeatureExtractor, ViTModel, ViTForImageClassification
from tqdm import tqdm
import matplotlib.pyplot as plt
import pickle

In [2]:
HOME = os.path.expanduser('~')
TEXT_DATADIR = '~/Projects/Datasets/public_news_set'
IMAGE_DATADIR = 'Projects/Datasets/public_image_set'
# TRAIN_FILE = "train_1000.tsv"
# TEST_FILE = "test_100.tsv"
# VALID_FLIE = "valid_100.tsv"
TRAIN_FILE = "new_train_with_sentiment.tsv"
TEST_FILE = "new_test_with_sentiment.tsv"
VALID_FLIE = "new_valid_with_sentiment.tsv"

SUFFIX = '.jpg'

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TEXT_MODEL_CKPT = "distilbert-base-uncased"
IMAGE_MODEL_CKPT = "google/vit-base-patch16-224-in21k"

MAX_LENGTH = 32
BATCH_SIZE = 64

In [3]:
mid_train = pd.read_csv(os.path.join(TEXT_DATADIR, TRAIN_FILE), sep='\t')
mid_test = pd.read_csv(os.path.join(TEXT_DATADIR, TEST_FILE), sep='\t')
mid_val = pd.read_csv(os.path.join(TEXT_DATADIR, VALID_FLIE), sep='\t')

In [4]:
mid_train.head()

Unnamed: 0,author,clean_title,created_utc,domain,hasImage,id,image_url,linked_submission_id,num_comments,score,subreddit,title,upvote_ratio,2_way_label,3_way_label,6_way_label,pos,neu,neg
0,jnoble50,red skull,1553267000.0,,True,ej4e1lj,https://i.imgur.com/eD7QGRM.jpg,b44rhx,,58,psbattle_artwork,Red Skull,,0,2,4,0.25,0.5,0.25
1,Gtash,cafe in bangkok with the cutest employees ever...,1559911000.0,nynno.com,True,bxu2dd,https://external-preview.redd.it/MS7vkNibB3Yq1...,,0.0,34,upliftingnews,Cafe in Bangkok With the Cutest Employees Ever...,0.78,1,0,0,0.202655,0.357452,0.439893
2,RoyalPrinceSoldier,he betrayed him,1400820000.0,,True,chp14h4,http://i.imgur.com/9Q9CCDn.jpg,269qyi,,8,psbattle_artwork,He betrayed him!,,0,2,4,0.25,0.5,0.25
3,penguinseed,alderman wants to know exactly what bong shops...,1403114000.0,dnainfo.com,True,28h8p1,https://external-preview.redd.it/lwbRUIzyGF5sU...,,2.0,3,nottheonion,Alderman Wants to Know Exactly What 'Bong Shop...,0.71,1,0,0,0.202655,0.357452,0.439893
4,DM90,man accused of stalking scots police officer s...,1383750000.0,dailyrecord.co.uk,True,1q10us,https://external-preview.redd.it/_fNXvGtKcKn_U...,,2.0,23,nottheonion,Man accused of stalking Scots police officer s...,0.84,1,0,0,0.202655,0.357452,0.439893


## Mid Dataset

In [5]:
mid_train['clean_title'].iloc[1]

'cafe in bangkok with the cutest employees ever corgis'

In [6]:
class MidDataset(Dataset):
    """
    torch dataset for Mid Model
    """
    def __init__(self, dataframe) -> None:
        super().__init__()
        self.df = dataframe
        self.ids = self.df['id'].values
        self.labels = self.df['6_way_label'].values
        self.information = self.df['clean_title'].values
        self.imagepaths = self.df['id'].values
        self.sent_scores = self.df[['pos', 'neu', 'neg']]

    def __getitem__(self, idx):
        item_id = self.ids[idx]
        text = self.information[idx]
        label = self.labels[idx]
        imagepath = os.path.join(HOME, IMAGE_DATADIR, self.imagepaths[idx] + SUFFIX)
        sentiment_scores = torch.tensor(self.sent_scores.iloc[idx].values)
        return (item_id, text, imagepath, sentiment_scores), label

    def __len__(self):
        return len(self.df)

In [7]:
# Test on the text model features
train_dataset = MidDataset(mid_train)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
valid_dataset = MidDataset(mid_val)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE)
test_dataset = MidDataset(mid_test)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

## Mid Model

In [8]:
class TextFeatureExtractor(nn.Module):
    """
        Text feature extractor
    """
    def __init__(self) -> None:
        super().__init__()
        self.text_model = AutoModel.from_pretrained(TEXT_MODEL_CKPT).to(DEVICE)
        self.text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_CKPT)

    def forward(self, texts):
        texts = [self._text_preprocessing(text) for text in texts]
        encode_sent = self.text_tokenizer(texts, truncation=True, padding='max_length', max_length=MAX_LENGTH)
        encode_sent['input_ids'] = torch.tensor(encode_sent['input_ids']).to(DEVICE)
        encode_sent['attention_mask'] = torch.tensor(encode_sent['attention_mask']).to(DEVICE)
        with torch.no_grad():
            outputs = self.text_model(**encode_sent)
            last_hidden_state = outputs.last_hidden_state[:, 0]
        return last_hidden_state

    def _text_preprocessing(self, text):
        """
        - Lowercase
        - Remove entity name (e.g. @name)
        @param text (str): a string to be processed
        @return text (str): the processed string
        """
        text = re.sub(r'(@.*?)[\s]', ' ', text)
        text = re.sub(r'&amp;', '&', text)
        text = re.sub(r'\s+', ' ', text).strip()

        return text


class ImageFeatureExtractor(nn.Module):
    """
    Image feature extractor
    """
    # ViT model is temporarily selected as a feature extractor
    def __init__(self) -> None:
        super().__init__()
        self.feature_extractor = ViTFeatureExtractor(IMAGE_MODEL_CKPT)
        self.feature_model = ViTModel.from_pretrained(IMAGE_MODEL_CKPT).to(DEVICE)
    def forward(self, imagefiles):
        ims = [Image.open(imagefile) for imagefile in imagefiles]
        ims = list(map(self._mode_convert, ims))
        im_trans = self.feature_extractor(ims, return_tensors='pt').to(DEVICE)
        with torch.no_grad():
            features = self.feature_model(**im_trans)
            last_hidden_state = features.last_hidden_state[:,0]
        return last_hidden_state

    def _mode_convert(self, im):
        if im.mode != 'RGB':
            im = im.convert(mode="RGB")
        return im


class MultiFeatureClassifier(nn.Module):
    # TODO: change the classifier to LSTM or transformer
    def __init__(self) -> None:
        super().__init__()
        # self.classifier = nn.Sequential(
        #     nn.Linear(1536, 768),
        #     nn.ReLU(),
        #     nn.Linear(768, 512),
        #     nn.ReLU(),
        #     nn.Linear(512, 6)
        # )
        self.classifier = nn.Sequential(
            nn.Linear(1536, 768),
            nn.ReLU(),
            nn.Linear(768, 6)
        )

    def forward(self, term):
        logits = self.classifier(term)
        return logits


class MidModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.text_feature_extractor = TextFeatureExtractor()
        self.image_feature_extractor = ImageFeatureExtractor()
        self.feature_classifier = MultiFeatureClassifier().to(DEVICE)

    def forward(self, item):
        ids, texts, imagepaths, scores = item
        text_features = self._textfeature(texts)
        image_features = self._imagefeature(imagepaths)
        fusion_features = self._combinefeature(text_features, image_features)
        outputs = self.feature_classifier(fusion_features)
        return outputs

    def _textfeature(self, texts):
        text_features = self.text_feature_extractor(texts)
        return text_features

    def _imagefeature(self, imagepaths):

        image_features = self.image_feature_extractor(imagepaths)
        return image_features

    def _combinefeature(self, text_feature, image_feature, sentiment_features=None):
        # TODO: combine text image sentiment feature
        fusion_features = torch.cat((text_feature, image_feature), axis=1)
        return fusion_features

In [9]:
classifier = MultiFeatureClassifier()
a = torch.randn(3, 1536)
classifier(a)

tensor([[-0.1826, -0.3802,  0.1636, -0.1328, -0.1821, -0.1924],
        [-0.1285, -0.4832,  0.3914, -0.2239, -0.0998,  0.0439],
        [-0.1852, -0.4853, -0.0740,  0.1178, -0.3811, -0.1920]],
       grad_fn=<AddmmBackward0>)

In [10]:
# text_feature_extractor = TextFeatureExtractor()
# image_feature_extractor = ImageFeatureExtractor().to(DEVICE)

# all_features = []
# for item in tqdm(train_loader, total=len(train_loader)):
#     ids, texts, imagepaths, scores = item
#     image_features = image_feature_extractor(imagepaths)
#     text_features = text_feature_extractor(texts)
#     image_features = image_features.to('cpu')
#     text_features = text_features.to('cpu')
#     all_features.extend(list(zip(ids, text_features, image_features)))

In [11]:
# torch.cat((text_features, image_features))

In [12]:
from sklearn.metrics import accuracy_score, f1_score, ConfusionMatrixDisplay, confusion_matrix

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    f1 = f1_score(labels, preds, average="weighted")
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1}


def plot_confusion_matrix(y_preds, y_true):
    cm = confusion_matrix(y_true, y_preds, normalize="true")
    fig, ax = plt.subplots(figsize=(6, 6))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(cmap="Blues", values_format=".2f", ax=ax, colorbar=False)
    plt.title("Normalized confusion matrix")
    plt.show()

In [13]:
## Train model
mid_model = MidModel()

critieron = nn.CrossEntropyLoss()
optimizer = optim.Adam(mid_model.feature_classifier.parameters(), lr=0.02)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

def train_model(model, trainloader, validloader, criterion, optimizer, scheduler, num_epochs = 25, valid=True):
    model.train()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')

        pbar = tqdm(train_loader, total=len(trainloader))
        for items in pbar:
            
            optimizer.zero_grad()
            item, labels = items
            labels = labels.to(DEVICE)
            outputs = model(item)
            preds = torch.argmax(outputs, 1)
            loss = criterion(outputs, labels)
            pbar.set_postfix({'loss':loss})
            loss.backward()
            optimizer.step()
        scheduler.step()

        if valid == True:
            print(validate_model(model, validloader))

def validate_model(model, valid_loader):
    model.eval()

    val_accuracy = []
    val_loss = []

    for items in tqdm(valid_loader, total=len(valid_loader)):
        model.eval()
        item, labels = items
        with torch.no_grad():
            outputs = model(item)
        labels = labels.to("cuda")
        loss = critieron(outputs, labels)
        val_loss.append(loss.item())
        preds = torch.argmax(outputs, 1).to("cpu")
        accuracy = (preds == labels.to('cpu')).numpy().mean() * 100
        val_accuracy.append(accuracy)
    
    val_loss = np.mean(val_loss)
    val_accuracy = np.mean(val_accuracy)
    return val_loss, val_accuracy

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [14]:
train_model(mid_model, train_loader, valid_loader, critieron, optimizer, exp_lr_scheduler)


Epoch 0/24


100%|██████████| 1706/1706 [23:43<00:00,  1.20it/s, loss=tensor(0.6960, device='cuda:0', grad_fn=<NllLossBackward0>)]
100%|██████████| 214/214 [02:59<00:00,  1.19it/s]


(0.5564075091453357, 80.05013629283488)
Epoch 1/24


100%|██████████| 1706/1706 [23:47<00:00,  1.19it/s, loss=tensor(0.5566, device='cuda:0', grad_fn=<NllLossBackward0>)]
100%|██████████| 214/214 [02:56<00:00,  1.21it/s]


(0.5558396768625652, 80.10611370716512)
Epoch 2/24


100%|██████████| 1706/1706 [23:21<00:00,  1.22it/s, loss=tensor(0.5311, device='cuda:0', grad_fn=<NllLossBackward0>)]
100%|██████████| 214/214 [02:59<00:00,  1.19it/s]


(0.5465067443307315, 80.3129867601246)
Epoch 3/24


  8%|▊         | 139/1706 [01:56<21:53,  1.19it/s, loss=tensor(0.4564, device='cuda:0', grad_fn=<NllLossBackward0>)]


KeyboardInterrupt: 

In [16]:
validate_model(mid_model, test_loader)

100%|██████████| 214/214 [03:08<00:00,  1.13it/s]


(0.6007503149943931, 79.04497663551402)