In [1]:
import sys

# Make the sibling `src` directory importable from this notebook.
# The membership check keeps the cell idempotent: re-running it must not
# pile up duplicate entries on sys.path.
if "../src" not in sys.path:
    sys.path.append("../src")

In [2]:
from typing import List
from typing import Tuple

import tempfile
from pathlib import Path

import torch
from IPython.display import display
from IPython.display import Markdown

from health_multimodal.common.visualization import plot_phrase_grounding_similarity_map
from health_multimodal.text import get_bert_inference
from health_multimodal.text.utils import BertEncoderType
from health_multimodal.image import get_image_inference
from health_multimodal.image.utils import ImageModelType
from health_multimodal.vlp import ImageTextInferenceEngine

  from .autonotebook import tqdm as notebook_tqdm


## Load BioViL model

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BertTokenizer'. 
The class this function is called from is 'CXRBertTokenizer'.
You are using a model of type bert to instantiate a model of type cxr-bert. This is not supported for all configurations of models and can yield errors.


Using downloaded and verified file: /var/folders/54/s690rsnj4qz9x9cqtk0ky7w40000gn/T/biovil_t_image_model_proj_size_128.pt


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


## Training

In [None]:
# Bounding box described by four floats.
# NOTE(review): presumably (x, y, width, height) as expected by
# plot_phrase_grounding_similarity_map -- confirm against the library.
TypeBox = Tuple[float, float, float, float]


def plot_phrase_grounding(image_path: Path, text_prompt: str, bboxes: List[TypeBox]) -> None:
    """Compute the image/text similarity map and plot it with the given boxes.

    Relies on the module-level ``BioViL`` inference engine, so that name must
    be bound before this function is *called* (defining it earlier is fine).

    Args:
        image_path: Path of the image to ground the phrase in.
        text_prompt: Phrase whose location should be highlighted.
        bboxes: Reference boxes forwarded unchanged to the plotting helper.
    """
    similarity_map = BioViL.get_similarity_map_from_raw_data(
        image_path=image_path,
        query_text=text_prompt,
        interpolation="bilinear",
    )
    # The map was previously computed and then dropped; actually render it.
    plot_phrase_grounding_similarity_map(
        image_path=image_path,
        similarity_map=similarity_map,
        bboxes=bboxes,
    )

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# train_loader = DataLoader(MyDataset(), batch_size=32, shuffle=True)

import torch.nn as nn

class BoundingBoxPredictor(nn.Module):
    """Three-layer MLP that regresses four box coordinates from a similarity map.

    Args:
        input_dim: Number of features per sample after flattening, i.e. the
            product of all non-batch dimensions of the input.
        hidden_dim: Width of the two hidden layers.
    """

    def __init__(self, input_dim, hidden_dim=512):
        super().__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 4)

        self.relu = nn.ReLU()

    def forward(self, similarity_map):
        """Return a ``(batch, 4)`` tensor of raw (unbounded) coordinates.

        Args:
            similarity_map: Tensor whose first dimension is the batch; every
                remaining dimension is flattened into the feature axis.
        """
        # reshape (unlike view) also accepts non-contiguous inputs such as
        # transposed or sliced tensors, copying only when it has to.
        x = similarity_map.reshape(similarity_map.size(0), -1)

        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)


class CombinedModel(nn.Module):
    """Chain a BioViL-style inference engine with a trainable box-regression head.

    Args:
        BioViL: Object exposing ``get_similarity_map_from_raw_data(image_path,
            query_text, interpolation)``.
        box_predictor: Module mapping a ``(batch, H, W)`` tensor to box
            coordinates.
    """

    def __init__(self, BioViL, box_predictor):
        super().__init__()
        self.BioViL = BioViL
        self.box_predictor = box_predictor

    def forward(self, image_path, text_prompt):
        """Predict boxes for one (path, prompt) pair or a batch of them.

        Args:
            image_path: A single path (``str``/``Path``) or a sequence of paths,
                e.g. what a ``DataLoader`` yields for a batch.
            text_prompt: A single prompt or a sequence of prompts, one per path.

        Returns:
            Output of ``box_predictor`` for the stacked maps; a single pair is
            treated as a batch of one.

        Raises:
            ValueError: If the number of paths and prompts differ.
            RuntimeError: From ``torch.stack`` if the similarity maps of a batch
                do not all share one shape.
        """
        image_paths = [image_path] if isinstance(image_path, (str, Path)) else list(image_path)
        text_prompts = [text_prompt] if isinstance(text_prompt, str) else list(text_prompt)
        if len(image_paths) != len(text_prompts):
            raise ValueError(
                f"Got {len(image_paths)} image paths but {len(text_prompts)} text prompts"
            )

        # Feed the head on whatever device its weights live on.
        first_param = next(self.box_predictor.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

        similarity_maps = []
        for path, prompt in zip(image_paths, text_prompts):
            # The engine is queried one pair at a time.
            # "bilinear" matches plot_phrase_grounding; "linear" is a 1-D mode.
            raw_map = self.BioViL.get_similarity_map_from_raw_data(
                image_path=Path(path),
                query_text=prompt,
                interpolation="bilinear",
            )
            similarity = torch.as_tensor(raw_map, dtype=torch.float32)
            # NOTE(review): the upstream engine appears to pad cropped borders
            # with NaN -- confirm. Zero them so the loss cannot become NaN.
            similarity = torch.where(
                torch.isnan(similarity), torch.zeros_like(similarity), similarity
            )
            similarity_maps.append(similarity)

        batch = torch.stack(similarity_maps).to(device)
        return self.box_predictor(batch)

# Fail fast, before the expensive model loading below, when the data loader
# has not been created yet.
if "train_loader" not in globals():
    raise NameError(
        "train_loader is not defined; create it first, e.g. "
        "train_loader = DataLoader(MyDataset(), batch_size=32, shuffle=True)"
    )

# Load BioViL Model
text_inference = get_bert_inference(BertEncoderType.BIOVIL_T_BERT)
image_inference = get_image_inference(ImageModelType.BIOVIL_T)

BioViL = ImageTextInferenceEngine(
    image_inference_engine=image_inference,
    text_inference_engine=text_inference,
)

# BoundingBoxPredictor requires input_dim. Derive it from one real similarity
# map instead of hard-coding a number.
# NOTE(review): assumes each batch is (paths, prompts, boxes) with indexable
# paths/prompts, and that every image yields a map of the same size -- confirm.
sample_paths, sample_prompts, _ = next(iter(train_loader))
sample_map = BioViL.get_similarity_map_from_raw_data(
    image_path=Path(sample_paths[0]),
    query_text=sample_prompts[0],
    interpolation="bilinear",
)
input_dim = torch.as_tensor(sample_map).numel()

bbox_predictor = BoundingBoxPredictor(input_dim=input_dim)
model = CombinedModel(BioViL, bbox_predictor)

n_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model first so the optimizer is built over the on-device parameters.
model.to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(n_epochs):
    model.train()

    for batch_idx, (image_path, text_prompt, ground_truth_boxes) in enumerate(train_loader):
        # Targets must share device and dtype with the predictions.
        ground_truth_boxes = torch.as_tensor(ground_truth_boxes, dtype=torch.float32, device=device)

        predicted_boxes = model(image_path, text_prompt)
        loss = criterion(predicted_boxes, ground_truth_boxes)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 10 == 0:
            print(f"Epoch {epoch+1}/{n_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item()}")
    

