<a href="https://colab.research.google.com/github/juan-villa02/medical_vqa_vlm/blob/main/code/notebooks/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VQA Model - BERT + ResNet18

## Libraries/Dependencies

In [10]:
# PyTorch framework
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision.models as models
# Pillow (image handling)
from PIL import Image
# Transformers (Hugginface)
from transformers import BertModel, BertTokenizer
# Matplotlib
import matplotlib.pyplot as plt
# Numpy
import numpy as np
# Extra dependencies
import json
import os
from tqdm import tqdm
import zipfile

In [11]:
# Train on the first CUDA device when PyTorch can see one, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
print(device)

cpu


## Paths & Data Extraction

In [1]:
#from google.colab import drive
#drive.mount('/content/drive')

In [12]:
# Source of the databases:
#   '.'       -> they are already extracted under `data_folder`; no unzipping happens.
#   otherwise -> path to the .zip archive to extract into `data_folder`, e.g. on Colab:
# path_dir = '/content/drive/MyDrive/TFG-VQA/data/bases de datos.zip'
path_dir = '.'

In [13]:
# Data folder: root that the database archive is extracted into and that the
# relative database paths below are joined onto with os.path.join.
data_folder = './data/'

In [14]:
# ISIC 2016 image folders, relative to `data_folder`.
# NOTE(review): training images come from Part 3 while test images come from
# Part 1 of the challenge release — confirm this mix of splits is intended.
train_ISIC_path = ('bases de datos/ISIC_2016/images/'
                   'ISBI2016_ISIC_Part3_Training_Data_orig')
test_ISIC_path = ('bases de datos/ISIC_2016/images/'
                  'ISBI2016_ISIC_Part1_Test_Data_orig')

In [24]:
# Pizarro melanoma image folder (first delivery, "Entrega1"), relative to `data_folder`.
pizarro_path = ('bases de datos/pizarro/images/'
                'Entrega1')

In [16]:
def extract_databases(zip_path, extract_path):
    """Unpack every member of the archive at ``zip_path`` into ``extract_path``.

    Args:
        zip_path: Path to the .zip file that bundles the databases.
        extract_path: Directory that receives the archive's contents.

    Returns:
        None. The archive is closed again before the function returns.
    """
    with zipfile.ZipFile(zip_path, mode='r') as archive:
        archive.extractall(path=extract_path)

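### ISIC 2016 (Fusion)

Images path: `bases de datos/ISIC_2016/images`, relative to `data_folder` (original Drive location: `drive/content/TFG Juan Villanueva/bases de datos/ISIC_2016/images`).

- Test -> `ISBI2016_ISIC_Part1_Test_Data_orig`
- Train -> `ISBI2016_ISIC_Part3_Training_Data_orig`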


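### Pizarro (Melanoma)

Images path: `bases de datos/pizarro/images`, relative to `data_folder` (original Drive location: `drive/content/bases de datos/pizarro/images`).

500 images in total:

- The first 46 (Entrega 1) are more complex from a diagnostic point of view.
- The rest (Entregas 2–5) are more balanced.

Only the `Entrega1` subfolder is loaded in this notebook (see `pizarro_path`).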


In [17]:
# Unzip the databases only when `path_dir` names an archive; '.' means the
# data is already available under `data_folder`.
if path_dir != '.':
    extract_databases(path_dir, data_folder)

In [18]:
# Resolve the ISIC 2016 train/test image folders against the data root.
database1_train_path, database1_test_path = (
    os.path.join(data_folder, split_path)
    for split_path in (train_ISIC_path, test_ISIC_path)
)

In [25]:
# Resolve the Pizarro image folder against the data root.
# NOTE(review): only the Entrega1 subfolder is referenced; confirm the QA JSON
# never points at images from the other deliveries (Entrega 2-5).
database2_path = os.path.join(
    data_folder,
    pizarro_path,
)

## VQA Dataset - Images & QA

In [43]:
# Dataset class to load images and their question/answer pairs
class VQADataset(Dataset):
    """Image/question/answer dataset backed by a JSON annotation file.

    The JSON file is expected to hold a list of records, each with an
    ``'image_id'`` (file name relative to ``data_dir``) and a ``'qa_pairs'``
    list whose entries carry ``'question'`` and ``'answer'`` strings.

    Args:
        data_dir: Folder containing the image files.
        json_file: Path to the UTF-8 encoded JSON annotation file.
        tokenizer: Stored on the instance but not applied by ``__getitem__``;
            questions are returned as raw strings.
        transform: Optional callable applied to the RGB PIL image.
        qa_index: Which entry of each record's ``'qa_pairs'`` list to return.
            Defaults to 5, the pair this notebook was originally written against.
    """

    def __init__(self, data_dir, json_file, tokenizer=None, transform=None, qa_index=5):
        self.data_dir = data_dir
        self.transform = transform
        # Kept for a future text pipeline; __getitem__ does not tokenize yet.
        self.tokenizer = tokenizer
        self.qa_index = qa_index

        # JSON text is UTF-8; don't depend on the platform's locale encoding
        # (the annotations may contain non-ASCII, e.g. Spanish, characters).
        with open(json_file, 'r', encoding='utf-8') as f:
            self.data = json.load(f)

    def __len__(self):
        """Number of annotated records loaded from the JSON file."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, question, answer)`` for record ``idx``.

        ``image`` is the RGB-converted image, passed through ``transform`` when
        one was given; ``question``/``answer`` come from
        ``qa_pairs[self.qa_index]`` of that record.
        """
        item = self.data[idx]
        image_path = os.path.join(self.data_dir, item['image_id'])
        # Close the file handle deterministically; convert() returns an
        # in-memory copy that stays valid after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        qa_pair = item['qa_pairs'][self.qa_index]
        return image, qa_pair['question'], qa_pair['answer']

In [44]:
# Pizarro melanoma QA dataset. The annotation JSON lives under `data_folder`,
# so derive its path from there instead of hard-coding './data/'.
pizarro_qa_json = os.path.join(data_folder, 'bases de datos/pizarro/df_melanoma_qa2.json')
pizarroDataset = VQADataset(database2_path, pizarro_qa_json)

In [46]:
# Inspect one sample: (RGB PIL image, question, answer).
pizarroDataset[10]

(<PIL.Image.Image image mode=RGB size=640x480>,
 'What is the histology diagnostic?',
 'The histology diagnostic indicates non-atypical stable mole (m).')