In [None]:
!pip install torch torchvision transformers pillow

In [None]:
class BertCrossAttention(nn.Module):
    """Cross-attention between two modalities built from transformer decoder layers.

    ``query`` plays the role of the decoder target and ``key`` the role of the
    decoder memory, so every layer lets ``query`` attend to ``key``.

    All tensors are batch-first, i.e. ``(batch, seq_len, hidden_size)``, which
    is the layout of the BERT hidden states and expanded image features that
    are passed in by the model in this file.
    """

    def __init__(self, config, num_layers):
        """Build ``num_layers`` decoder layers sized from ``config``.

        Args:
            config: Object exposing ``hidden_size`` and ``num_attention_heads``.
            num_layers: Number of stacked decoder layers.
        """
        super().__init__()
        self.layers = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                # The PyTorch default is (seq, batch, hidden). Without this flag
                # batch-first inputs are attended across the batch dimension,
                # mixing unrelated samples with each other.
                batch_first=True,
            )
            for _ in range(num_layers)
        ])

    def forward(self, query, key, mask=None, key_mask=None):
        """Run ``query`` through every layer while attending to ``key``.

        Args:
            query: Tensor of shape ``(batch, query_len, hidden_size)``.
            key: Tensor of shape ``(batch, key_len, hidden_size)``.
            mask: Optional padding mask for ``query`` of shape
                ``(batch, query_len)``; forwarded as ``tgt_key_padding_mask``
                (in PyTorch, ``True`` marks positions to ignore).
            key_mask: Optional padding mask for ``key`` of shape
                ``(batch, key_len)``; forwarded as ``memory_key_padding_mask``.

        Returns:
            Tensor with the same shape as ``query``.
        """
        output = query
        for layer in self.layers:
            output = layer(
                output,
                key,
                tgt_key_padding_mask=mask,
                memory_key_padding_mask=key_mask,
            )
        return output


class MTCCMBertForMMTokenClassificationCRF(BertPreTrainedModel):
    """Multimodal token classifier: BERT text encoder, ResNet-152 image encoder,
    two cross-attention blocks and a CRF output layer.

    ``forward`` returns the CRF loss when ``labels`` is given and the decoded
    tag sequences otherwise.
    """

    def __init__(self, config, num_labels):
        """Create all sub-modules.

        Args:
            config: BERT configuration; ``hidden_size`` and
                ``hidden_dropout_prob`` are read here.
            num_labels: Number of tag classes for the classifier and the CRF.
        """
        super().__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        self.resnet = resnet152(pretrained=True)
        # Drop the classification head so the network outputs its pooled
        # features, which are projected from 2048 dims below.
        self.resnet.fc = nn.Identity()

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.vismap2text = nn.Linear(2048, config.hidden_size)

        self.txt2img_attention = BertCrossAttention(config, num_layers=1)
        self.img2txt_attention = BertCrossAttention(config, num_layers=1)

        # Both attended streams are concatenated, hence hidden_size * 2.
        self.classifier = nn.Linear(config.hidden_size * 2, num_labels)
        self.crf = CRF(num_labels, batch_first=True)
        self.init_weights()

    def forward(self, input_ids, attention_mask, visual_embeds, labels=None):
        """Compute the CRF loss (training) or the best tag paths (inference).

        Args:
            input_ids: Token ids passed to BERT.
            attention_mask: Mask passed to BERT and reused as the CRF mask
                (non-zero = real token).
            visual_embeds: Image batch fed to ResNet-152; presumably
                ``(batch, 3, 224, 224)`` -- TODO confirm against the caller.
            labels: Optional tag ids of shape ``(batch, label_len)``. Entries
                equal to ``-100`` are replaced by tag ``0``.

        Returns:
            The negated CRF log-likelihood (``reduction='mean'``) when
            ``labels`` is given, otherwise the result of ``self.crf.decode``.
        """
        # Text feature extraction
        text_outputs = self.bert(input_ids, attention_mask=attention_mask)
        text_features = self.dropout(text_outputs.last_hidden_state)

        # Image feature extraction with ResNet-152
        visual_features = self.resnet(visual_embeds)
        visual_features = visual_features.view(visual_features.size(0), -1)
        visual_features = self.vismap2text(visual_features)  # project to BERT hidden size
        # Repeat the single image vector along the text sequence length.
        visual_features = visual_features.unsqueeze(1).expand(-1, text_features.size(1), -1)

        # Cross-modal attention
        txt_attended_visuals = self.txt2img_attention(text_features, visual_features)
        img_attended_text = self.img2txt_attention(visual_features, text_features)

        # Combine and classify
        combined_features = torch.cat([txt_attended_visuals, img_attended_text], dim=-1)
        logits = self.classifier(combined_features)

        # A boolean mask avoids the deprecated uint8-condition path in torch.where.
        crf_mask = attention_mask.bool()

        if labels is None:
            return self.crf.decode(logits, mask=crf_mask)

        # -100 is not a valid tag index for the CRF; map it to tag 0.
        labels = torch.where(labels == -100, torch.zeros_like(labels), labels)

        seq_length = logits.size(1)
        label_length = labels.size(1)
        if label_length < seq_length:
            # Pad labels to the logits length with a valid tag index (0) ...
            labels_padded = torch.full(
                (labels.size(0), seq_length - label_length),
                fill_value=0,
                dtype=torch.long,
                device=labels.device,
            )
            labels = torch.cat([labels, labels_padded], dim=1)

            # ... and exclude the padded slots from the loss so they are not
            # scored as real "tag 0" tokens. Clone first: ``.bool()`` returns
            # the same tensor when the caller already passed a bool mask.
            crf_mask = crf_mask.clone()
            crf_mask[:, label_length:] = False

        # CRF loss calculation (negative log-likelihood)
        loss = -self.crf(logits, labels, mask=crf_mask, reduction='mean')
        return loss


In [None]:
def load_data(file_path, image_dir):
    """Read a whitespace-separated annotation file.

    Every line that splits into exactly three fields is interpreted as
    ``<text> <label> <image id>``; all other lines (blank or malformed) are
    skipped.

    Args:
        file_path: Path of the annotation file, read as UTF-8.
        image_dir: Directory prefix used to build the image paths.

    Returns:
        A tuple ``(texts, labels, image_paths)`` of three parallel lists, where
        each image path is ``"{image_dir}/{img_id}.jpg"``.
    """
    texts, labels, image_paths = [], [], []
    # Explicit encoding: the platform default is locale dependent and may fail
    # on non-ASCII tokens.
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            parts = line.split()
            if len(parts) != 3:
                continue
            text, label, img_id = parts
            texts.append(text)
            labels.append(label)
            image_paths.append(f"{image_dir}/{img_id}.jpg")
    return texts, labels, image_paths

# Evaluate on a test dataset
def evaluate_model(model, texts, labels, image_paths):
    """Run the model on every (text, image) pair and collect the decoded tags.

    Args:
        model: Model whose forward pass returns CRF-decoded tag sequences when
            called without labels.
        texts: Input texts.
        labels: Gold labels; currently unused, kept for future metric code.
        image_paths: Image paths aligned with ``texts``.

    Returns:
        A list with one entry per example, each being the model output for
        that example.
    """
    predictions = []
    for text, img_path in zip(texts, image_paths):
        input_ids, attention_mask, image = prepare_input(text, img_path)
        with torch.no_grad():
            # The forward pass already runs crf.decode when no labels are
            # given, so its result is used directly; decoding it a second time
            # would hand a Python list to the CRF.
            prediction = model(input_ids, attention_mask, image)
        predictions.append(prediction)
    # Here you might want to compare predictions with true labels, compute accuracy, etc.
    return predictions

def demo_prediction(text, image_path):
    """Predict tags for one text/image pair and show them above the image.

    Relies on the module-level ``model``, ``id2label`` and ``prepare_input``.

    Args:
        text: Input text.
        image_path: Path of the image that accompanies the text.
    """
    input_ids, attention_mask, image = prepare_input(text, image_path)
    with torch.no_grad():
        # The forward pass returns CRF-decoded tag indices when no labels are
        # given; it must not be passed through crf.decode again.
        prediction_indices = model(input_ids, attention_mask, image)
    predicted_tags = [id2label[idx] for idx in prediction_indices[0]]  # Convert indices to labels

    # Show the image and prediction; the context manager closes the file handle.
    with Image.open(image_path) as img:
        plt.imshow(img)
        plt.title(f"Predicted Tags: {predicted_tags}")
        plt.show()

In [None]:
import torch
from transformers import BertTokenizer
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms

# Load the tokenizer and model
# NOTE(review): `config` is not defined in the visible cells -- it must be
# created (e.g. a BERT config) before this cell runs.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = MTCCMBertForMMTokenClassificationCRF(config=config, num_labels=7)
# map_location='cpu' lets a checkpoint that was saved on a GPU load on a
# CPU-only machine; the model is never moved off the CPU in this notebook.
model.load_state_dict(torch.load('path_to_model.pth', map_location='cpu'))
model.eval()  # inference mode: disables dropout, uses BatchNorm running stats

# Image transformations: resize to 224x224 and convert to a tensor
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Dummy paths
texts, true_labels, image_paths = load_data("test.txt", "path_to_images")
predictions = evaluate_model(model, texts, true_labels, image_paths)

In [None]:
with torch.no_grad():
    # Without labels the forward pass already returns CRF-decoded tag indices,
    # so the output is used as-is instead of being decoded a second time.
    predictions = model(input_ids, attention_mask, image)  # Get tag indices

def decode_tags(tag_indices, id2label):
    """Translate the first sequence of tag indices into tag names.

    Only ``tag_indices[0]`` is converted, i.e. a batch size of 1 is assumed.
    """
    lookup = id2label.__getitem__
    first_sequence = tag_indices[0]
    return list(map(lookup, first_sequence))

# Convert the decoded indices of the first sequence to tag names and print them.
# NOTE(review): `id2label` is not defined in the visible cells -- it must be
# provided (index -> tag name mapping) before this cell runs.
predicted_tags = decode_tags(predictions, id2label)
print(predicted_tags)
