# Misinformation detection model (MID) full structure

In [1]:
import os
import re
import copy

from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import datasets, models, transforms

import pandas as pd
import numpy as np
from transformers import AutoModel, AutoTokenizer, ViTFeatureExtractor, ViTModel, ViTForImageClassification
from tqdm import tqdm
import matplotlib.pyplot as plt
import pickle

In [2]:
def same_seed(seed):
    '''
    Fixes random number generator seeds for reproducibility
    '''
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


same_seed(123)

In [3]:
HOME = os.path.expanduser('~')
TEXT_DATADIR = '~/Projects/Datasets/public_news_set'
IMAGE_DATADIR = 'Projects/Datasets/public_image_set'
# TRAIN_FILE = "train_1000.tsv"
# TEST_FILE = "test_100.tsv"
# VALID_FLIE = "valid_100.tsv"
TRAIN_FILE = "new_train_with_sentiment.tsv"
TEST_FILE = "new_test_with_sentiment.tsv"
VALID_FLIE = "new_valid_with_sentiment.tsv"

SUFFIX = '.jpg'

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# TEXT_MODEL_CKPT = "distilbert-base-uncased"
# TEXT_MODEL_CKPT = "bert-base-uncased"
TEXT_MODEL_CKPT = "roberta-base"
IMAGE_MODEL_CKPT = "google/vit-base-patch16-224-in21k"

MAX_LENGTH = 32
BATCH_SIZE = 64

In [4]:
mid_train = pd.read_csv(os.path.join(TEXT_DATADIR, TRAIN_FILE), sep='\t')
mid_test = pd.read_csv(os.path.join(TEXT_DATADIR, TEST_FILE), sep='\t')
mid_val = pd.read_csv(os.path.join(TEXT_DATADIR, VALID_FLIE), sep='\t')

In [5]:
mid_train.head()

Unnamed: 0,author,clean_title,created_utc,domain,hasImage,id,image_url,linked_submission_id,num_comments,score,subreddit,title,upvote_ratio,2_way_label,3_way_label,6_way_label,pos,neu,neg
0,jnoble50,red skull,1553267000.0,,True,ej4e1lj,https://i.imgur.com/eD7QGRM.jpg,b44rhx,,58,psbattle_artwork,Red Skull,,0,2,4,0.25,0.5,0.25
1,Gtash,cafe in bangkok with the cutest employees ever...,1559911000.0,nynno.com,True,bxu2dd,https://external-preview.redd.it/MS7vkNibB3Yq1...,,0.0,34,upliftingnews,Cafe in Bangkok With the Cutest Employees Ever...,0.78,1,0,0,0.202655,0.357452,0.439893
2,RoyalPrinceSoldier,he betrayed him,1400820000.0,,True,chp14h4,http://i.imgur.com/9Q9CCDn.jpg,269qyi,,8,psbattle_artwork,He betrayed him!,,0,2,4,0.25,0.5,0.25
3,penguinseed,alderman wants to know exactly what bong shops...,1403114000.0,dnainfo.com,True,28h8p1,https://external-preview.redd.it/lwbRUIzyGF5sU...,,2.0,3,nottheonion,Alderman Wants to Know Exactly What 'Bong Shop...,0.71,1,0,0,0.202655,0.357452,0.439893
4,DM90,man accused of stalking scots police officer s...,1383750000.0,dailyrecord.co.uk,True,1q10us,https://external-preview.redd.it/_fNXvGtKcKn_U...,,2.0,23,nottheonion,Man accused of stalking Scots police officer s...,0.84,1,0,0,0.202655,0.357452,0.439893


## Mid Dataset

In [6]:
mid_train['clean_title'].iloc[1]

'cafe in bangkok with the cutest employees ever corgis'

In [7]:
class MidDataset(Dataset):
    """
    torch dataset for Mid Model
    """
    def __init__(self, dataframe) -> None:
        super().__init__()
        self.df = dataframe
        self.ids = self.df['id'].values
        self.labels = self.df['6_way_label'].values
        self.information = self.df['clean_title'].values
        self.imagepaths = self.df['id'].values
        self.sent_scores = self.df[['pos', 'neu', 'neg']]

    def __getitem__(self, idx):
        item_id = self.ids[idx]
        text = self.information[idx]
        label = self.labels[idx]
        imagepath = os.path.join(HOME, IMAGE_DATADIR, self.imagepaths[idx] + SUFFIX)
        sentiment_scores = torch.tensor(self.sent_scores.iloc[idx].values)
        return (item_id, text, imagepath, sentiment_scores), label

    def __len__(self):
        return len(self.df)

In [8]:
# Test on the text model features
train_dataset = MidDataset(mid_train)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_dataset = MidDataset(mid_val)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE)
test_dataset = MidDataset(mid_test)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

## Mid Model

In [9]:
class TextFeatureExtractor(nn.Module):
    """
        Text feature extractor (Bert Series)
    """
    def __init__(self) -> None:
        super().__init__()
        self.text_model = AutoModel.from_pretrained(TEXT_MODEL_CKPT).to(DEVICE)
        self.text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_CKPT)

    def forward(self, texts):
        texts = [self._text_preprocessing(text) for text in texts]
        encode_sent = self.text_tokenizer(texts, truncation=True, padding='max_length', max_length=MAX_LENGTH)
        encode_sent['input_ids'] = torch.tensor(encode_sent['input_ids']).to(DEVICE)
        encode_sent['attention_mask'] = torch.tensor(encode_sent['attention_mask']).to(DEVICE)
        with torch.no_grad():
            outputs = self.text_model(**encode_sent)
            last_hidden_state = outputs.last_hidden_state[:, 0]
        return last_hidden_state

    def _text_preprocessing(self, text):
        """
        - Lowercase
        - Remove entity name (e.g. @name)
        @param text (str): a string to be processed
        @return text (str): the processed string
        """
        text = re.sub(r'(@.*?)[\s]', ' ', text)
        text = re.sub(r'&amp;', '&', text)
        text = re.sub(r'\s+', ' ', text).strip()

        return text


class ViTImageFeatureExtractor(nn.Module):
    """
    Image feature extractor with ViT model
    """
    def __init__(self) -> None:
        super().__init__()
        self.feature_extractor = ViTFeatureExtractor(IMAGE_MODEL_CKPT)
        self.feature_model = ViTModel.from_pretrained(IMAGE_MODEL_CKPT).to(DEVICE)
    def forward(self, imagefiles):
        ims = [Image.open(imagefile) for imagefile in imagefiles]
        ims = list(map(self._mode_convert, ims))
        im_trans = self.feature_extractor(ims, return_tensors='pt').to(DEVICE)
        with torch.no_grad():
            features = self.feature_model(**im_trans)
            last_hidden_state = features.last_hidden_state[:,0]
        return last_hidden_state

    def _mode_convert(self, im):
        if im.mode != 'RGB':
            im = im.convert(mode="RGB")
        return im


class ImageFeatureExtractor(nn.Module):
    """
    Image feature extractor with ResNet50 / VGG16
    """
    def __init__(self) -> None:
        super().__init__()

        self.feature_model = models.resnet50(pretrained=True)
        self.feature_model.fc = nn.Linear(2048, 768)
        self.feature_model = self.feature_model.to(DEVICE)
        self.training = True
        
    def forward(self, imagefiles):
        ims = [Image.open(imagefile) for imagefile in imagefiles]
        im_trans = torch.stack(list(map(self._im_transform, ims))).to(DEVICE)
        with torch.no_grad():
            features = self.feature_model(im_trans)
        return features

    def _im_transform(self, im, train=True):
        train = self.training
        if im.mode != "RGB":
            im = im.convert(mode="RGB")
        # transform the train data
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # transform the test and validate data
        transform_val = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        im_trans = transform_train(im) if train else transform_val(im)

        return im_trans


class MultiFeatureClassifier(nn.Module):
    """
    LSTM
    """
    def __init__(self) -> None:
        super().__init__()
        # common ANN for classifier
        self.classifier = nn.Sequential(
            nn.Linear(1539, 768),
            nn.ReLU(),
            nn.Linear(768, 6)
        )

        # basic LSTM for classification
        self.lstm = nn.LSTM(1539, 768, bidirectional=True, batch_first =True)
        self.fc1 = nn.Linear(1536, 768)
        self.bn = nn.BatchNorm1d(512)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(512, 6)


    def forward(self, term):
        # outputs = self.transformer_encoder(term)
        outputs, _ = self.lstm(term)
        outputs = self.fc1(outputs)
        outputs = self.bn(outputs)
        outputs = self.relu(outputs)
        logits = self.fc2(outputs)
        return logits


class MultiFeatureClassifier(nn.Module):
    """
    LSTM + attention
    """
    def __init__(self, image_model='vit') -> None:
        super().__init__()
        
        self.input_size = {'vit': 1539, 'resnet': 1539, 'vgg': 4867}
        self.classifier = nn.Sequential(
            nn.Linear(self.input_size[image_model], 768),
            nn.ReLU(),
            nn.Linear(768, 6)
        )
        self.lstm = nn.LSTM(self.input_size[image_model], 768, 2)
        self.fc1 = nn.Linear(768, 512)
        self.bn = nn.BatchNorm1d(512)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(512, 6)

    def attention_net(self, lstm_output, final_state):
        lstm_output = lstm_output.permute(1, 0, 2)
        merged_state = torch.mean(torch.cat([s for s in final_state]), 0)
        hidden = merged_state.unsqueeze(1).unsqueeze(0)
        attn_weights = torch.bmm(lstm_output, hidden)
        soft_attn_weights = F.softmax(attn_weights, dim=1)

        new_hidden_state = (lstm_output.transpose(1,2).squeeze() * soft_attn_weights.squeeze()).transpose(0,1)
        return new_hidden_state

    def forward(self, term):
        # term = term.unsqueeze(1)
        # outputs, (hidden, cell) = self.lstm(term)
        # attn_output = self.attention_net(outputs, hidden)
        # attn_output = self.fc1(attn_output)
        # attn_output = self.bn(attn_output)
        # attn_output = self.relu(attn_output)
        # logits = self.fc2(attn_output)
        logits = self.classifier(term)
        return logits




class MidModel(nn.Module):
    def __init__(self, image_model='vit'):
        super().__init__()
        self.image_model = image_model
        self.training = True
        self.text_feature_extractor = TextFeatureExtractor()
        self.image_feature_extractor = ViTImageFeatureExtractor() if image_model == 'vit' else ImageFeatureExtractor()
        self.feature_classifier = MultiFeatureClassifier(image_model).to(DEVICE)

    def forward(self, item):
        ids, texts, imagepaths, scores = item
        text_features = self._textfeature(texts)
        image_features = self._imagefeature(imagepaths)
        fusion_features = self._combinefeature(text_features, image_features, sentiment_features=scores)
        outputs = self.feature_classifier(fusion_features)
        return outputs

    def _textfeature(self, texts):
        text_features = self.text_feature_extractor(texts)
        return text_features

    def _imagefeature(self, imagepaths):
        if self.image_model != 'vit':
            self.image_feature_extractor.training = self.training
        image_features = self.image_feature_extractor(imagepaths)
        return image_features

    def _combinefeature(self, text_feature, image_feature, sentiment_features=None):
        if sentiment_features != None:
            sentiment_features = sentiment_features.float().to(DEVICE)
        
        fusion_features = torch.cat((text_feature, image_feature), axis=1) if sentiment_features == None else torch.cat((text_feature, sentiment_features, image_feature), axis=1)
        return fusion_features

In [10]:
# Test final classifier
classifier = MultiFeatureClassifier()
a = torch.randn(64, 1539)
classifier(a).shape

torch.Size([64, 6])

In [11]:
# Test image feature extractor
imagefeature = ImageFeatureExtractor()

for i in train_loader:
    item, labels = i
    ids, texts, imagepaths, scores = item
    features = imagefeature(imagepaths)
    print(features)
    break

tensor([[-0.1046,  0.4809,  0.2576,  ...,  0.3444,  0.3518,  0.2544],
        [-0.2915,  0.4736,  0.0614,  ...,  0.0478,  0.6280, -0.3919],
        [-0.3044,  0.3286,  0.2695,  ...,  0.0605,  0.4982, -0.0940],
        ...,
        [-0.2487,  0.1355,  0.0941,  ...,  0.5524,  0.5215,  0.4005],
        [ 0.1414,  0.0319,  0.0988,  ..., -0.0450,  0.4359,  0.3888],
        [-0.1296,  0.1772,  0.2801,  ..., -0.1226,  0.2787,  0.0326]],
       device='cuda:0')


In [12]:
# Test feature extractors
# text_feature_extractor = TextFeatureExtractor()
# image_feature_extractor = ImageFeatureExtractor().to(DEVICE)

# all_features = []
# for item in tqdm(train_loader, total=len(train_loader)):
#     ids, texts, imagepaths, scores = item
#     image_features = image_feature_extractor(imagepaths)
#     text_features = text_feature_extractor(texts)
#     image_features = image_features.to('cpu')
#     text_features = text_features.to('cpu')
#     all_features.extend(list(zip(ids, text_features, image_features)))

In [13]:
# Test feature fusion
# torch.cat((text_features, image_features))

In [14]:
## Validation metrics for model
from sklearn.metrics import accuracy_score, f1_score, ConfusionMatrixDisplay, confusion_matrix

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    f1 = f1_score(labels, preds, average="weighted")
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1}


def plot_confusion_matrix(y_preds, y_true):
    cm = confusion_matrix(y_true, y_preds, normalize="true")
    fig, ax = plt.subplots(figsize=(6, 6))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(cmap="Blues", values_format=".2f", ax=ax, colorbar=False)
    plt.title("Normalized confusion matrix")
    plt.show()

In [15]:
## Model training and evaluation
def train_model(model, trainloader, validloader, criterion, optimizer, scheduler, num_epochs = 5, valid=True):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    image_model = model.image_model
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')

        pbar = tqdm(train_loader, total=len(trainloader))
        for items in pbar:
            model.train()
            optimizer.zero_grad()
            item, labels = items
            labels = labels.to(DEVICE)
            outputs = model(item)
            preds = torch.argmax(outputs, 1)
            loss = criterion(outputs, labels)
            pbar.set_postfix({'loss':loss})
            loss.backward()
            optimizer.step()
        scheduler.step()

        if valid == True:
            print(validate_model(model, criterion, validloader, image_model))

def validate_model(model, critierion, valid_loader, image_model='vit'):
    model.eval()

    val_accuracy = []
    val_loss = []

    for items in tqdm(valid_loader, total=len(valid_loader)):
        model.eval()
        item, labels = items
        with torch.no_grad():
            if image_model != 'vit':
                model.training = False
                outputs = model(item)
            else:
                outputs = model(item)
        labels = labels.to("cuda")
        loss = critierion(outputs, labels)
        val_loss.append(loss.item())
        preds = torch.argmax(outputs, 1).to("cpu")
        accuracy = (preds == labels.to('cpu')).numpy().mean() * 100
        val_accuracy.append(accuracy)
    
    val_loss = np.mean(val_loss)
    val_accuracy = np.mean(val_accuracy)
    return val_loss, val_accuracy

In [16]:
mid_model = MidModel(image_model='resnet')
critierion = nn.CrossEntropyLoss()
optimizer = optim.Adam(mid_model.feature_classifier.parameters(), lr=0.0005)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
train_model(mid_model, train_loader, valid_loader, critierion, optimizer, exp_lr_scheduler)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Epoch 0/4


100%|██████████| 1706/1706 [17:14<00:00,  1.65it/s, loss=tensor(0.6256, device='cuda:0', grad_fn=<NllLossBackward0>)]
100%|██████████| 214/214 [02:21<00:00,  1.52it/s]


(0.7172805863284619, 73.48130841121495)
Epoch 1/4


100%|██████████| 1706/1706 [16:52<00:00,  1.68it/s, loss=tensor(0.8590, device='cuda:0', grad_fn=<NllLossBackward0>)]
100%|██████████| 214/214 [02:21<00:00,  1.52it/s]


(0.6317783666548328, 77.27803738317758)
Epoch 2/4


  6%|▋         | 108/1706 [01:06<16:26,  1.62it/s, loss=tensor(0.5746, device='cuda:0', grad_fn=<NllLossBackward0>)]


KeyboardInterrupt: 

In [None]:
# validate_model(mid_model, critierion, test_loader)

In [None]:
# torch.save(mid_model, 'roberta_vit_lstm_attn.pth')