In [2]:
import os
# from PIL import Image
import pandas as pd

import torch
from torch import nn, optim
import torchvision.models as models
from torchvision import transforms

from PIL import Image
import cv2

from transformers import BertTokenizer, BertModel

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

In [3]:
def load_data(path):
    
    '''
    Load image and text data from the specified directory structure.

    Parameters:
    - path (str): Path to the root directory containing subdirectories for images and OCR text.

    Returns:
    - pd.DataFrame: A DataFrame containing image paths, text file paths, and corresponding class labels.
    
    '''
    
    image_data = []
    text_data = []
    labels = []

    for class_dir in os.listdir(os.path.join(path, 'images')):
        if os.path.isdir(os.path.join(path, 'images', class_dir)):
            
            # get the image and text directory paths
            image_files = os.listdir(os.path.join(path, 'images', class_dir))
            text_files = os.listdir(os.path.join(path, 'ocr', class_dir))
            label = int(class_dir)
            
            # get the images and text paths
            for image_file, text_file in zip(image_files, text_files):
                image_path = os.path.join(path, 'images', class_dir, image_file)
                text_path = os.path.join(path, 'ocr', class_dir, text_file)
                
                image_data.append(image_path)
                text_data.append(text_path)
                labels.append(label)

    return pd.DataFrame(data=[image_data, text_data, labels]).T.rename({0:"img_path", 1:"text_path", 2:"target"}, axis=1)

In [4]:
data = load_data("./data")
data.head()

Unnamed: 0,img_path,text_path,target
0,./data/images/0/466c6bc2-3196-499e-b506-8f0bda...,./data/ocr/0/412056de-bfed-4053-90b0-8fb3f6b9d...,0
1,./data/images/0/88e2faad-f4c7-4331-beda-8a5f8a...,./data/ocr/0/7d92b943-8506-427e-ac1c-f8888d38b...,0
2,./data/images/0/005a7f78-9e3f-44b3-beb8-3fe834...,./data/ocr/0/005a7f78-9e3f-44b3-beb8-3fe834af4...,0
3,./data/images/0/00b8e9c9-76c8-431d-9584-16f37f...,./data/ocr/0/00b8e9c9-76c8-431d-9584-16f37f138...,0
4,./data/images/0/00d62096-c544-44d6-84a4-3fd3df...,./data/ocr/0/00d62096-c544-44d6-84a4-3fd3df471...,0


In [5]:
class DocumentDataset:
    
    '''
    Dataset class for processing document images and text data.

    This class loads document images and corresponding OCR text data, preprocesses them,
    and prepares them for input into a model. It applies image transformations and tokenizes
    text using the BERT tokenizer.

    Parameters:
    - data (pd.DataFrame): DataFrame containing image paths and text file paths.
    - target (array-like): Array containing class labels.

    Attributes:
    - data (pd.DataFrame): DataFrame containing image paths and text file paths.
    - target (array-like): Array containing class labels.
    - tokenizer (BertTokenizer): BERT tokenizer for tokenizing text.
    - image_transform (torchvision.transforms.Compose): Image transformation pipeline.

    '''
    
    def __init__(self, data, target):
        self.data = data
        self.target = target
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        
        self.image_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])
    
    def __getitem__(self, idx):
        img_path = self.data["img_path"][idx]
        text_path = self.data["text_path"][idx]
        
        output = self.target[idx]
        
        # loading image
        img_data = cv2.imread(img_path)
        img_data = Image.fromarray(cv2.cvtColor(img_data, cv2.COLOR_BGR2RGB))
        
        img_data = self.image_transform(img_data)
        
        # loading text
        with open(text_path, 'r') as f:
            text = f.read()
            
        tokens = self.tokenizer(text, padding='max_length', truncation=True, max_length=512, return_tensors='pt')
        
        tokens.pop("token_type_ids")
        
        return [img_data, tokens], torch.tensor(output)
    
    def __len__(self):
        
        return len(self.target)

In [6]:
class ResNet18(nn.Module):
    
    '''
    ResNet18 with modified final dense layer.

    This class loads a pretrained ResNet18 model and modifies it by freezing
    all layers except the final dense layer. The final dense layer is replaced
    with a new linear layer with 256 output features.

    Attributes:
    - resnet (torchvision.models.ResNet): The pretrained ResNet18 model.
    - fc (torch.nn.Linear): The modified final dense layer.

    '''
    
    def __init__(self):
        super(ResNet18, self).__init__()
        
        # loading pretrained resnet
        self.resnet = models.resnet18(pretrained=True)
        
        # freezeing all layers except the final dense layer
        for param in self.resnet.parameters():
            param.requires_grad = False
        
        # setting the final dense layer to trainable
        for param in self.resnet.fc.parameters():
            param.requires_grad = True
            
        # modifying the final layer
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(in_features=num_features, out_features=256)
        
    def forward(self, data):
        
        '''
        Performs a forward pass through the network.

        Parameters:
        - data (torch.Tensor): Input data with shape (batch, channels, width, height).

        Returns:
        - torch.Tensor: Output tensor from the network.
        '''
        
        # data -> (batch, channels, width, height)
        
        out = self.resnet(data)
        
        return out
        

In [7]:
class Bert(nn.Module):
    
    '''
    BERT model with a modified final dense layer.

    This class loads a pretrained BERT model and adds a final dense layer for
    downstream task fine-tuning. By default, it uses the 'bert-base-uncased'
    pretrained model from Hugging Face's transformers library.

    Parameters:
    - model_name (str): The name of the pretrained BERT model to load. Default is 'bert-base-uncased'.

    Attributes:
    - model_name (str): The name of the pretrained BERT model being used.
    - bert_model (transformers.BertModel): The pretrained BERT model.
    - fc (torch.nn.Linear): The final dense layer added for fine-tuning.
    '''
    
    def __init__(self, model_name="bert-base-uncased"):
        super(Bert, self).__init__()
        self.model_name = model_name
        
        self.bert_model = BertModel.from_pretrained(model_name)

        # freezing the layers
        for param in self.bert_model.parameters():
            param.requires_grad = False
            
        # adding a final dense layer
        self.fc = nn.Linear(in_features=self.bert_model.config.hidden_size, out_features=256)
        
    def forward(self, text):
        
        '''
        Performs a forward pass through the BERT model.

        Parameters:
        - text (dict): Dictionary containing 'input_ids' and 'attention_mask' tensors representing tokenized input text.

        Returns:
        - torch.Tensor: Output tensor from the final dense layer.

        '''
        
        # print(text['input_ids'].shape, text["attention_mask"].shape)
        
        output = self.bert_model(text["input_ids"].squeeze(1), text["attention_mask"].squeeze(1))
        
        output = self.fc(output.pooler_output)
        
        return output

In [8]:
class MultiModalClassification(nn.Module):
    
    '''
    A custom neural network model for multimodal classification tasks.

    This model combines image and text modalities for classification using
    separate pretrained models (ResNet18 for images and BERT for text). It
    then concatenates the output features from both modalities and passes
    them through a fully connected layer for classification.

    Parameters:
    - num_classes (int): The number of classes for classification. Default is 5.

    Attributes:
    - img_net (ResNet18): The ResNet18 model for processing image data.
    - text_net (Bert): The BERT model for processing text data.
    - fc (torch.nn.Sequential): The fully connected layer for final classification.
    '''
    
    def __init__(self, num_classes=5):
        
        super(MultiModalClassification, self).__init__()
        
        self.img_net = ResNet18()
        self.text_net = Bert()
        
        self.fc = nn.Sequential(
            nn.Linear(in_features=512, out_features=num_classes)
            )

    def forward(self, data):
        
        '''
        Performs a forward pass through the network.

        Parameters:
        - data (list): List containing image data and text data.

        Returns:
        - torch.Tensor: Output tensor from the fully connected layer.
        '''
        
        img_data = data[0]
        text_data = data[1]
        
        img_outupt = self.img_net(img_data)
        text_output = self.text_net(text_data)
        
        combined_output = torch.cat([img_outupt, text_output], dim=1)
        
        output = self.fc(combined_output)
        
        return output
    
    def fit(self, dataloader, epochs=5, loss_func=nn.CrossEntropyLoss(), optimizer=optim.Adam, lr=0.01):
        
        '''
        Trains the model using the provided dataloader.

        Parameters:
        - dataloader (torch.utils.data.DataLoader): Dataloader for training data.
        - epochs (int): Number of epochs for training. Default is 5.
        - loss_func: Loss function. Default is nn.CrossEntropyLoss().
        - optimizer: Optimizer for training. Default is optim.Adam.
        - lr (float): Learning rate for the optimizer. Default is 0.01.

        '''
        
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        
        print(device)
        
        self.to(device=device)
        
        optimizer = optimizer([param for param in self.parameters() if param.requires_grad], lr=lr)
        
        for epoch in range(epochs):
            
            correct_predictions = 0
            total_samples = 0
            epoch_loss = 0.0
            
            for batch_idx, (data, target) in enumerate(dataloader):
                
                img_data = data[0].to(device=device)
                text_data = data[1].to(device=device)
                target = target.to(device=device)
                
                output = self([img_data, text_data])
                
                loss = loss_func(output, target)
                
                loss.backward()
                
                optimizer.step()
                optimizer.zero_grad()
                
                epoch_loss += loss.item()
                
                _, predicted = torch.max(output, 1)
                
                total_samples += target.shape[0]
                correct_predictions += (predicted == target).sum().item()
                # print(f"batch {batch_idx}")
                
            print(f"Epochs: {epoch+1}/{epochs}\tLoss: {epoch_loss}\taccuracy: {correct_predictions / total_samples}")
            

In [10]:
m = MultiModalClassification()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [11]:
X_train, X_test, y_train, y_test = train_test_split(data[["img_path", "text_path"]], data["target"], test_size=0.2, random_state=42)

In [12]:
train_dataset = DocumentDataset({col: X_train[col].to_list() for col in X_train.columns}, y_train.to_list())

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

In [11]:
m.fit(train_dataloader)

cuda
Epochs: 1/5	Loss: 257.12941682338715	accuracy: 0.4465
Epochs: 2/5	Loss: 32.21738764643669	accuracy: 0.751
Epochs: 3/5	Loss: 20.08889175951481	accuracy: 0.8005
Epochs: 4/5	Loss: 14.328647837042809	accuracy: 0.849
Epochs: 5/5	Loss: 13.83181281387806	accuracy: 0.868


In [12]:
# torch.save(m, "./model_multimodel.pth")

In [13]:
def evaluate(model, dataloader):
    
    '''
    Evaluates the performance of a model on a given dataset.

    Args:
    - model (nn.Module): The model to be evaluated.
    - dataloader (DataLoader): DataLoader providing the evaluation dataset.

    Returns:
    - float: Accuracy of the model on the evaluation dataset.
    '''
    
    correct_predictions = 0
    total_samples = 0
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    model.eval()
    
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(dataloader):
            img_data = data[0].to(device=device)
            text_data = data[1].to(device=device)
            target = target.to(device=device)
            
            output = model([img_data, text_data])
                    
            _, predicted = torch.max(output, 1)
            
            total_samples += target.shape[0]
            correct_predictions += (predicted == target).sum().item()
            
    return correct_predictions / total_samples

In [14]:
model = torch.load("./model_multimodel.pth")

In [15]:
evaluate(model, train_dataloader)

0.902

In [16]:
test_dataset = DocumentDataset({col: X_test[col].to_list() for col in X_test.columns}, y_test.to_list())

test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=64)

evaluate(model ,test_dataloader)

0.916