# Imports and constants


In [None]:
import torch
from torch import nn
import torchvision
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import transformers

  from .autonotebook import tqdm as notebook_tqdm


# CV model


In [3]:
# torch.hub._validate_not_a_forked_repo = lambda a, b, c: True

class CNNBackbone(nn.Module):
    """ResNet-style feature extractor that emits one token per spatial position.

    The stem (``conv1``/``bn1``/``relu``/``maxpool``) and the four residual
    stages of the supplied network are reused as-is; no other submodule of
    ``resnet`` is referenced. The final feature map is flattened into a
    sequence of per-position feature vectors, and every vector is projected
    by a two-layer MLP (``bottleneck``).

    Args:
        resnet: Module exposing ``conv1``, ``bn1``, ``relu``, ``maxpool`` and
            ``layer1`` .. ``layer4``. Its submodules are shared with this
            module, not copied.
        in_features: Number of channels produced by ``resnet.layer4``.
        hidden_features: Width of the hidden layer of the bottleneck MLP.
        out_features: Size of each output token.
    """

    def __init__(self, resnet, *, in_features: int = 512,
                 hidden_features: int = 1024, out_features: int = 768):
        super().__init__()

        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        # Freshly initialised projection applied independently to every token.
        self.bottleneck = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features),
            nn.ReLU()
        )

    def forward(self, x) -> torch.Tensor:
        """Map an image batch to a ``B x (H*W) x out_features`` token tensor.

        Args:
            x: Input batch fed to ``conv1``.

        Returns:
            One projected feature vector per spatial position of the
            ``layer4`` feature map, in row-major (h, w) order.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # B x C x H x W -> B x (H*W) x C.
        # The channel axis must be moved last explicitly: a bare reshape keeps
        # the memory order and would mix channels and positions in each token.
        x = x.flatten(2).transpose(1, 2)
        x = self.bottleneck(x)

        return x

In [9]:
# Smoke test: feed CNN image tokens to Flan-T5's decoder in place of the
# text-encoder output. The image is random noise and the backbone's bottleneck
# MLP is untrained, so the decoded text is not expected to be meaningful.

# 1. Load the pretrained pieces
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
# NOTE: newer torchvision releases deprecate `pretrained=` in favour of
# `weights=`; kept as-is so the cell still runs on older installs.
resnet = torchvision.models.resnet18(pretrained=True)
flant5: transformers.models.t5.modeling_t5.T5ForConditionalGeneration = AutoModelForSeq2SeqLM.from_pretrained(
    "google/flan-t5-base")

# 2. Encode a dummy image into a token sequence
image = torch.randn(1, 3, 224, 224)
backbone = CNNBackbone(resnet)
# Inference only: eval() makes BatchNorm use its stored running statistics
# instead of the statistics of this single image (and stops it from updating
# them); no_grad() avoids building an autograd graph.
backbone.eval()
with torch.no_grad():
    output = backbone(image)

# 3. Decoder start token (not passed to generate() below)
start_token_id = tokenizer.pad_token_id  # T5 uses pad token as decoder start
input_ids = torch.full((1, 1), start_token_id, dtype=torch.long)

# 4. Generate text
output_ids = flant5.generate(
    None,
    encoder_outputs=transformers.modeling_outputs.BaseModelOutput(
        last_hidden_state=output),
    max_length=30,
    num_beams=3,
    early_stopping=True
)

# 5. Decode
decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
print(decoded)



['bildung 1ntontototototoTOtoTOTOTOTOTOTOTOTOTOTOTOTOTOTOTOTOTOTOTO']


# Text model