In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!mkdir /content/data

In [23]:
!mkdir /content/twitter2015_model

In [None]:
!mkdir /content/twitter2017_model

In [2]:
!cp -r /content/drive/MyDrive/twitter2015 /content/data/

In [3]:
!cp -r /content/drive/MyDrive/twitter2015_images /content/data/

In [4]:
!cp -r /content/drive/MyDrive/twitter2017 /content/data/

In [5]:
!cp -r /content/drive/MyDrive/twitter2017_images /content/data/

In [6]:
!pip install torch torchvision transformers pillow pytorch-crf

Collecting pytorch-crf
  Downloading pytorch_crf-0.7.2-py3-none-any.whl (9.5 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)
  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)
Collecting nvidia-curand-cu12==10.3.2.106 (from torch)
  Using cached

In [7]:
!pip install pytorch-crf



In [8]:
!cp /content/drive/MyDrive/nlp-files/twitter2015/model.pth /content/twitter2015_model/

In [9]:
!cp /content/drive/MyDrive/nlp-files/twitter2017/model.pth /content/twitter2017_model/

In [10]:
import os
import json
import torch
from torch import nn, optim
from torchvision.models import resnet152
from transformers import BertModel, BertTokenizer
from torch.utils.data import DataLoader, Dataset
import numpy as np
from PIL import Image
import requests
from io import BytesIO
import torchvision.transforms
from transformers import BertTokenizer, BertConfig, BertPreTrainedModel
from sklearn.metrics import f1_score
from torch.optim import Adam
from transformers import get_linear_schedule_with_warmup
from torch.nn.utils.rnn import pad_sequence
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchcrf import CRF
from sklearn.metrics import classification_report

In [11]:
class TwitterDataset(Dataset):
    def __init__(self, data_folder, img_folder, tokenizer, transform, file_name, label2id, id2label):
        super().__init__()
        self.data_lines = []
        self.img_folder = img_folder
        self.tokenizer = tokenizer
        self.transform = transform
        self.label2id = label2id
        self.id2label = id2label

        # Load data from the specified file
        file_path = os.path.join(data_folder, file_name)
        with open(file_path, 'r', encoding="utf8") as file:
            img_id = None
            text = []
            labels = []

            # counter = 0
            for line in file:
                # if counter == 100:
                # break
                if line.strip() == '' and img_id is not None:  # save previous instance
                    try:
                        image_path = os.path.join(self.img_folder, img_id)

                        test_image = Image.open(image_path).convert("RGB")
                        self.data_lines.append((img_id, text, labels))
                    except:
                        print("Skipping corrupted image")

                    finally:
                        img_id = None  # Reset for the next image
                        text = []
                        labels = []

                elif line.startswith('IMGID:'):
                    img_id = line.strip().split(':')[1] + '.jpg'  # New image id
                else:
                    parts = line.strip().split('\t')
                    if len(parts) == 2:
                        text.append(parts[0])
                        labels.append(parts[1])

                # counter+=1

            # Save last instance if not empty
            if img_id is not None:
                img_path = os.path.join(self.img_folder, img_id)
                try:
                    test_image = Image.open(image_path).convert("RGB")
                    self.data_lines.append((img_id, text, labels))

                except:
                    print("Skipping again corrupted images !!")

    def __len__(self):
        return len(self.data_lines)

    def __getitem__(self, idx):
        img_id, text, labels = self.data_lines[idx]
        image_path = os.path.join(self.img_folder, img_id)
        image = Image.open(image_path).convert('RGB')
        text = ' '.join(text)
        labels = [self.label_to_idx(label, self.label2id) for label in labels]  # Convert labels to indices

        inputs = self.tokenizer(text, padding='max_length', max_length=MAX_LENGTH, truncation=True, return_tensors="pt")
        image = self.transform(image)
        labels = torch.tensor(labels, dtype=torch.long)

        return inputs.input_ids.squeeze(0), inputs.attention_mask.squeeze(0), image, labels

    @staticmethod
    def label_to_idx(label, label_map):
        # Define your label to index mapping based on your dataset's labels
        try:
            result = label_map[label]  # Convert unrecognized labels to 'O'
        except:
            result = 0

        return result

In [19]:
class ContextAwareGate(nn.Module):
    def __init__(self, text_dim, visual_dim):
        super().__init__()
        # network to calculate threshold
        self.threshold_network = nn.Sequential(
            nn.Linear(text_dim, 1),  # Averaging text features to a single value
            nn.Sigmoid()  # Ensure the threshold is between 0 and 1
        )
        self.gate = nn.Sequential(
            nn.Linear(visual_dim * 2, visual_dim),  # Combine visual and transformed text features
            nn.Tanh(),
            nn.Linear(visual_dim, visual_dim),
            nn.Sigmoid()
        )

    def forward(self, text_features, visual_features):
        combined_features = torch.cat([text_features, visual_features], dim=-1)
        # Compute gating values
        gate_values = self.gate(combined_features)
        # Apply the gate to the visual features only
        text_mean = torch.mean(text_features, dim=1)
        update_threshold = self.threshold_network(text_mean).squeeze()  # Ensuring scalar output per batch item

        # if the threshold is larger than certain value apply combined features to the visual features
        update_threshold_expanded = update_threshold.unsqueeze(-1).unsqueeze(-1)
        update_mask = (gate_values > update_threshold_expanded).float()

        # Apply the gate to the visual features conditionally
        updated_visual_features = visual_features * (1 - update_mask) + (visual_features * gate_values) * update_mask

        return updated_visual_features


class DynamicAttentionModule(nn.Module):
    def __init__(self, feature_dim):
        super().__init__()
        self.text_weight_predictor = nn.Sequential(
            nn.Linear(feature_dim, 1),
            nn.Sigmoid()
        )
        self.visual_weight_predictor = nn.Sequential(
            nn.Linear(feature_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, text_features, visual_features):
        text_weights = self.text_weight_predictor(text_features).expand_as(text_features)
        visual_weights = self.visual_weight_predictor(visual_features).expand_as(visual_features)
        attended_text = text_features * text_weights
        attended_visuals = visual_features * visual_weights
        return attended_text, attended_visuals


class BertCrossAttention(nn.Module):
    """Implements cross-attention between two different modalities using a decoder layer."""

    def __init__(self, config, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads)
            for _ in range(num_layers)
        ])

    def forward(self, query, key, mask=None):
        output = query
        for layer in self.layers:
            output = layer(output, key, tgt_key_padding_mask=mask)
        return output


class MTCCMBertForMMTokenClassificationCRF(BertPreTrainedModel):
    def __init__(self, config, num_labels, add_context_aware_gate=False, use_dynamic_cross_modal_fusion=False):
        super().__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.resnet = resnet152(pretrained=True)
        self.resnet.fc = nn.Identity()  # Adapt ResNet to remove the final fully connected layer

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.vismap2text = nn.Linear(2048, config.hidden_size)

        self.txt2img_attention = BertCrossAttention(config, num_layers=1)
        self.img2txt_attention = BertCrossAttention(config, num_layers=1)

        self.add_context_aware_gate = add_context_aware_gate
        # Initialize the visual filter gate
        if add_context_aware_gate:
            self.visual_gate = ContextAwareGate(config.hidden_size, 768)

        if use_dynamic_cross_modal_fusion:
            self.dynamic_attention = DynamicAttentionModule(config.hidden_size)

        self.classifier = nn.Linear(config.hidden_size * 2, num_labels)
        self.crf = CRF(num_labels, batch_first=True)
        self.init_weights()

        self.use_dynamic_cross_modal_fusion = use_dynamic_cross_modal_fusion

    def forward(self, input_ids, attention_mask, visual_embeds, labels=None):
        # Text feature extraction
        text_outputs = self.bert(input_ids, attention_mask=attention_mask)
        text_features = self.dropout(text_outputs.last_hidden_state)

        # Image feature extraction with ResNet-152
        visual_features = self.resnet(visual_embeds)  # Assuming visual_embeds is [batch_size, 3, 224, 224]
        visual_features = visual_features.view(visual_features.size(0), -1)  # Flatten the output of the ResNet
        visual_features = self.vismap2text(visual_features)  # Transform to match BERT hidden size
        visual_features = visual_features.unsqueeze(1).expand(-1, text_features.size(1),
                                                              -1)  # Expand to match text sequence length

        if self.add_context_aware_gate:
            visual_features = self.visual_gate(text_features, visual_features)

        # Cross-modal attention
        if self.use_dynamic_cross_modal_fusion:
            attended_text, attended_visuals = self.dynamic_attention(text_features, visual_features)
            combined_features = torch.cat([attended_text, attended_visuals], dim=-1)
        else:
            txt_attended_visuals = self.txt2img_attention(text_features, visual_features)
            img_attended_text = self.img2txt_attention(visual_features, text_features)
            combined_features = torch.cat([txt_attended_visuals, img_attended_text], dim=-1)

        logits = self.classifier(combined_features)

        # crf processing
        if labels is not None:
            # Ensure labels and logits have the same sequence length
            labels = torch.where(labels == -100, torch.zeros_like(labels), labels)

            seq_length = logits.size(1)
            if labels.size(1) < seq_length:
                padding_size = seq_length - labels.size(1)
                # Use a valid label index for padding, e.g., 0
                labels_padded = torch.full((labels.size(0), padding_size), fill_value=0, dtype=torch.long,
                                           device=labels.device)
                labels = torch.cat([labels, labels_padded], dim=1)

                # Adjust attention_mask to cover only the non-padded areas
                attention_mask_padded = torch.zeros((attention_mask.size(0), seq_length), dtype=torch.uint8,
                                                    device=attention_mask.device)
                attention_mask_padded[:, :attention_mask.size(1)] = attention_mask
                attention_mask = attention_mask_padded

            # CRF loss calculation
            loss = -self.crf(logits, labels, mask=attention_mask.byte(), reduction='mean')
            return loss
        else:
            return self.crf.decode(logits, mask=attention_mask.byte())


In [20]:
def extract_labels(data_folder):
    labels_set = set()

    for filename in os.listdir(data_folder):
        if filename.endswith(".txt"):
            file_path = os.path.join(data_folder, filename)
            with open(file_path, "r", encoding="utf8") as file:
                for line in file:
                    if line.strip() and not line.startswith("IMGID:") and line != "\n":
                        parts = line.strip().split('\t')
                        if len(parts) == 2:
                            label = parts[1]
                            labels_set.add(label)

    return labels_set


def create_labels_dict(labels_set):
    label2id = {label: idx for idx, label in enumerate(sorted(labels_set))}
    id2label = {idx: label for label, idx in label2id.items()}

    return label2id, id2label

def collate_fn(batch):
    input_ids, attention_masks, images, labels = zip(*batch)

    # Pad the sequences
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    attention_masks = pad_sequence(attention_masks, batch_first=True, padding_value=0)

    # Ensure the first timestep of each mask is on
    attention_masks[:, 0] = 1

    # Stack images and pad labels
    images = torch.stack(images)
    labels = pad_sequence(labels, batch_first=True, padding_value=-100)  # Assuming -100 is your ignore index for labels

    return input_ids, attention_masks, images, labels


def visualize_prediction(model, dataset, img_id, id2label):
    # Find the corresponding data entry
    for img_file_name, text, labels in dataset.data_lines:
        if img_file_name == img_id:
            break
    else:
        print(f"Image ID {img_id} not found in the dataset.")
        return

    # Process the image and text for model input
    image_path = os.path.join(dataset.img_folder, img_id)
    image = Image.open(image_path).convert('RGB')
    image = dataset.transform(image).unsqueeze(0)  # Add batch dimension and send to device

    inputs = tokenizer(' '.join(text), return_tensors="pt", padding='max_length', max_length=128, truncation=True)
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']

    # Model prediction
    model.eval()
    model.to(device)
    with torch.no_grad():
        predictions = model(input_ids.to(device), attention_mask.to(device), image.to(device))
        if len(predictions) > 0:
            predicted_indices = predictions[0]
        else:
            print(f"Error while inference with ID {img_id}.")
            return

    # Convert predicted indices to labels
    predicted_labels = [id2label[idx] for idx in predicted_indices]
    predicted_labels = predicted_labels[:len(labels)]

    image_demo = Image.open(image_path).convert('RGB')
    # Display image
    plt.figure(figsize=(10, 5))
    plt.imshow(image_demo)
    plt.title("Demo: ")
    plt.axis('off')

    # Prepare text for display
    formatted_text = 'Text: ' + ' '.join(text)
    formatted_preds = 'Pred: ' + ' '.join(predicted_labels)
    formatted_actuals = 'Actual: ' + ' '.join(labels)

    # Display text annotations closer to the image
    plt.gca().text(0.5, -0.04, formatted_text, transform=plt.gca().transAxes,
                   fontsize=8, verticalalignment='top', horizontalalignment='center', wrap=True)
    plt.gca().text(0.5, -0.07, formatted_preds, transform=plt.gca().transAxes,
                   fontsize=8, verticalalignment='top', horizontalalignment='center', wrap=True)
    plt.gca().text(0.5, -0.1, formatted_actuals, transform=plt.gca().transAxes,
                   fontsize=8, verticalalignment='top', horizontalalignment='center', wrap=True)

    plt.show()

    # Print text with predictions
    print("Text and Predicted Tags:")
    for word, label in zip(text, predicted_labels):
        print(f"{word} [{label}]")


def prepare_input(text, image_path):
    # Tokenize text
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    # Load and transform image
    image = Image.open(image_path)
    image = transform(image).unsqueeze(0)  # Add batch dimension
    return inputs['input_ids'], inputs['attention_mask'], image


def load_data(file_path, image_dir):
    texts, labels, image_paths = [], [], []
    with open(file_path, 'r') as file:
        for line in file:
            parts = line.strip().split()
            if len(parts) == 3:
                text, label, img_id = parts
                texts.append(text)
                labels.append(label)
                image_paths.append(f"{image_dir}/{img_id}.jpg")
    return texts, labels, image_paths


def evaluate_test_data(model, test_loader, device, label2id, id2label):
    model.eval()  # Set the model to evaluation mode
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, masks, images, labels in test_loader:
            inputs, masks, images = inputs.to(device), masks.to(device), images.to(device)
            labels = labels.to(device)  # Ensure labels are on the same device for later comparison

            # Decoding without labels returns predictions
            predictions = model(inputs, masks, images)  # Assuming returns a list of lists for each batch

            # Process each batch item individually
            for idx, (pred, label) in enumerate(zip(predictions, labels)):
                label = label.cpu().numpy()
                valid_length = len(label[label != -100])  # Length without padding

                # Adjust predictions to match the valid length of labels
                pred = pred[:valid_length]  # Trim predictions to match the labels' valid length

                all_preds.extend(pred)
                all_labels.extend(label[:valid_length])  # Only consider valid label parts

    all_possible_labels = list(label2id.values())  # This should be [0, 1, 2, 3, 4, 5, 6]

    if len(all_labels) > len(all_preds):
        my_len = len(all_preds)
        all_labels = all_labels[:my_len]

    elif len(all_preds) > len(all_labels):
        my_len = len(all_labels)
        all_preds = all_preds[:my_len]

    report = classification_report(
        all_labels,
        all_preds,
        labels=all_possible_labels,  # Explicitly state which labels are expected
        target_names=[id2label[i] for i in all_possible_labels],  # Ensure this matches 'labels'
        output_dict=True
    )
    print("Classification Report:\n", report)

    return report

In [29]:
data_folder = 'data/twitter2015'  # Update accordingly
img_folder = 'data/twitter2015_images'  # Update accordingly
MAX_LENGTH = 128

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Image transformations
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
])

labels_set = extract_labels(data_folder)
label2id, id2label = create_labels_dict(labels_set)

print(labels_set)
print(label2id)
print(id2label)

config = BertConfig.from_pretrained('bert-base-uncased', num_labels=len(label2id.items()))
config.label2id = label2id
config.id2label = id2label

{'I-ORG', 'B-ORG', 'I-PER', 'I-OTHER', 'I-LOC', 'B-PER', 'B-LOC', 'O', 'B-OTHER'}
{'B-LOC': 0, 'B-ORG': 1, 'B-OTHER': 2, 'B-PER': 3, 'I-LOC': 4, 'I-ORG': 5, 'I-OTHER': 6, 'I-PER': 7, 'O': 8}
{0: 'B-LOC', 1: 'B-ORG', 2: 'B-OTHER', 3: 'B-PER', 4: 'I-LOC', 5: 'I-ORG', 6: 'I-OTHER', 7: 'I-PER', 8: 'O'}


In [26]:
config = BertConfig.from_pretrained('bert-base-uncased', num_labels=7)
config.label2id = label2id
config.id2label = id2label

model = MTCCMBertForMMTokenClassificationCRF(config=config, num_labels=len(label2id.items()),
                                                             add_context_aware_gate=True,
                                                             use_dynamic_cross_modal_fusion=True)


model.load_state_dict(torch.load('/content/twitter2015_model/model.pth', map_location=torch.device('cpu')))
model.eval()

test_dataset = TwitterDataset(data_folder, img_folder, tokenizer, transform, 'test.txt', label2id, id2label)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

# Evaluate on test data
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

MTCCMBertForMMTokenClassificationCRF(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps

In [None]:
test_report = evaluate_test_data(model, test_loader, device, label2id, id2label)

with open("test_classification_report.json", "w") as f:
    json.dump(test_report, f, indent=4)

  score = torch.where(mask[i].unsqueeze(1), next_score, score)


In [None]:
img_id = "62654.jpg"
visualize_prediction(model, test_dataset, img_id)