In [3]:
import yaml
from torch import nn
from torchsummary import summary
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from VIT_backbone.vit_transformer import PatchEmbedding, TransformerEncoder
from transformer_decoder.transformer_decoder import LanguageModel
from transformers import AutoTokenizer
import torch
from torch.utils.data import DataLoader
from dataset_structure.medical_datasets import RocoDataset

In [4]:
# ViT hyper-parameters; the ViT class below reads its constructor defaults from this dict.
with open('VIT_backbone/config_vit.yaml', 'r') as cfg_file:
    parameters = yaml.safe_load(cfg_file)

# Cased SciBERT tokenizer; its vocab size and pad id are used to build the caption decoder.
tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_cased')

In [5]:
class ViT(nn.Sequential):
    """Vision Transformer encoder expressed as a two-stage ``nn.Sequential``.

    Stage 0 is a ``PatchEmbedding`` and stage 1 is a ``TransformerEncoder``;
    as with any ``nn.Sequential``, the input flows through them in that order.

    Every default below is looked up in the ``parameters`` dict (loaded from
    ``VIT_backbone/config_vit.yaml``) once, when this class statement runs, so
    later edits to ``parameters`` do not change the defaults.

    Args:
        in_channels: number of channels of the input images.
        patch_size: patch size handed to ``PatchEmbedding`` (presumably the
            patch side length in pixels — confirm).
        emb_size: embedding width, shared by both stages.
        img_size: image size handed to ``PatchEmbedding``.
        depth: first positional argument of ``TransformerEncoder``
            (presumably the number of encoder layers — confirm).
        **kwargs: forwarded unchanged to ``TransformerEncoder``.
    """

    def __init__(self,
                 in_channels: int = parameters['in_channels'],
                 patch_size: int = parameters['patch_size'],
                 emb_size: int = parameters['emb_size'],
                 img_size: int = parameters['img_size'],
                 depth: int = parameters['depth'],
                 **kwargs):
        embedding = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
        encoder = TransformerEncoder(depth, emb_size=emb_size, **kwargs)
        super().__init__(embedding, encoder)

In [7]:
# Smoke-test the ViT encoder on one local image before moving on to the dataset.
with Image.open('img_test.jpg') as pil_img:
    # convert() loads the pixels into a new in-memory image, so the file can be
    # closed here; forcing RGB guards against grayscale/CMYK JPEGs, which would
    # otherwise produce a tensor with the wrong number of channels.
    img = pil_img.convert('RGB')

# Resize to the same img_size the ViT defaults were built with (was hard-coded 224).
transform = Compose([
    Resize((parameters['img_size'], parameters['img_size'])),
    ToTensor(),
])
x = transform(img).unsqueeze(0)  # add batch dim
print(x.shape)

vision_transformer = ViT()

# Call the module itself (not `.forward`) so registered hooks run; no_grad avoids
# building an autograd graph for this one-off check (the result is recomputed later).
with torch.no_grad():
    output_encoder = vision_transformer(x)

torch.Size([1, 3, 224, 224])


In [8]:
# Use the first CUDA GPU when one is visible to torch; otherwise run on the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)
print('Using {}'.format(device))

Using cuda:0


In [9]:
# Maximum sequence length passed to the caption decoder (presumably in tokens —
# confirm; consider moving this magic number into a config file).
max_seq_length = 9

# Vocabulary size and padding id are taken from the SciBERT tokenizer so the
# decoder agrees with the ids `tokenizer` produces.
# NOTE(review): dim=64 is hard-coded while the ViT width comes from
# parameters['emb_size']; if LanguageModel attends over the ViT output directly,
# these two widths presumably need to match — confirm.
model = LanguageModel(
    vocab_size=tokenizer.vocab_size,
    max_seq_length=max_seq_length,
    dim=64,
    pad_token_id=tokenizer.pad_token_id,
)
model = model.to(device)

In [10]:
roco_path = "roco-dataset/"
dataset = RocoDataset(roco_path=roco_path, mode="train")
sample_img = dataset[1]  # smoke-test single-item indexing before batching

roco_loader = DataLoader(dataset, batch_size=10, shuffle=True)
# Draw one shuffled batch. Note: this rebinds `img` (the PIL test image from the
# smoke-test cell) to the batched images (presumably a stacked tensor — confirm).
img, caption_input, caption_target, keywords_input, img_name = next(iter(roco_loader))

# Call the module itself rather than `.forward` so nn.Module hooks are honoured.
output_encoder = vision_transformer(img)


In [None]:
print(f"shape da saida: {output_encoder.shape}")

# Transfer the encoder output to the decoder's device once and reuse it:
# Tensor.to returns a fresh copy whenever the tensor is not already on `device`,
# so the previous two `.to(device)` calls made two separate copies.
encoder_memory = output_encoder.to(device)

# NOTE(review): the encoder output is passed for both of the last two arguments;
# confirm this matches LanguageModel's forward signature.
model(caption_input.to(device), encoder_memory, encoder_memory)

In [None]:
# Leftover post-mortem debugging: `%debug` opens the IPython debugger on the most
# recent uncaught exception. TODO: remove once the issue is resolved.
%debug