In [None]:
import pandas as pd
import timm
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForMaskedLM
import numpy as np

from core.src.constants import (
    IMAGES_PATH,
    TRAIN_DATA_CSV,
    TEST_DATA_CSV,
    TRAIN_IMAGE_FEATURES_PATH,
    TEST_IMAGE_FEATURES_PATH,
    TRAIN_TEXT_FEATURES_PATH,
    TEST_TEXT_FEATURES_PATH,
    FINE_TUNED_BERT_MODEL_PATH,
)

In [None]:
# `unique_id` is forced to text on load; it is later joined into image folder
# paths, so it must not be parsed as a number.
df_train = pd.read_csv(TRAIN_DATA_CSV, dtype=dict(unique_id=str))
df_test = pd.read_csv(TEST_DATA_CSV, dtype=dict(unique_id=str))

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# MODELS

In [None]:
# Image encoder: pretrained FastViT created with num_classes=0, placed on the
# selected device and switched to eval mode in one chain (both `.to()` and
# `.eval()` return the module itself).
fastvit_model = (
    timm.create_model("fastvit_ma36.apple_in1k", pretrained=True, num_classes=0)
    .to(device)
    .eval()
)

# Inference-time preprocessing built from the model's own data configuration.
data_config = timm.data.resolve_model_data_config(fastvit_model)
transforms = timm.data.create_transform(is_training=False, **data_config)

# Text encoder: tokenizer from the base Romanian BERT checkpoint, weights from
# the fine-tuned checkpoint. Hidden states are switched on in the config
# because the text features are read from `outputs.hidden_states`.
tokenizer = AutoTokenizer.from_pretrained("dumitrescustefan/bert-base-romanian-cased-v1")
bert_model = AutoModelForMaskedLM.from_pretrained(FINE_TUNED_BERT_MODEL_PATH)
bert_model.config.output_hidden_states = True
bert_model = bert_model.to(device)

torch.cuda.empty_cache()

## Prepare inputs: image paths and text encodings

In [None]:
# Every sample contributes one image: the file "00.png" inside the folder
# named after the sample's unique_id.
train_images = [
    IMAGES_PATH / unique_id / "00.png" for unique_id in df_train["unique_id"].values
]
test_images = [
    IMAGES_PATH / unique_id / "00.png" for unique_id in df_test["unique_id"].values
]

In [None]:
# Same tokenizer settings for both splits: pad to the longest sequence of the
# call and truncate anything longer than 512 tokens.
tokenizer_kwargs = dict(padding=True, truncation=True, max_length=512)

train_encodings = tokenizer(df_train["input"].to_list(), **tokenizer_kwargs)
test_encodings = tokenizer(df_test["input"].to_list(), **tokenizer_kwargs)

In [None]:
def compute_image_features(images, model, transforms, batch_size=8):
    """Run an image model over a list of image files and collect its outputs.

    Args:
        images: Sequence of image file locations accepted by ``PIL.Image.open``.
            Must be non-empty, because ``torch.cat`` rejects an empty list.
        model: Module called on a stacked batch tensor. It is switched to eval
            mode here and is expected to already live on the notebook-level
            ``device``.
        transforms: Callable turning one PIL image into a tensor. All results
            must have the same shape so that they can be stacked.
        batch_size: Number of images per forward pass. Defaults to 8, the
            value that used to be hard-coded.

    Returns:
        numpy.ndarray: Model outputs concatenated along the first axis, one
        entry per input image, in input order.
    """
    # Use the model that was passed in; the previous version called eval() on
    # the global `fastvit_model` no matter which model the caller supplied.
    model.eval()

    features = []
    for start in tqdm(range(0, len(images), batch_size)):
        inputs = []
        for image_path in images[start : start + batch_size]:
            # The context manager closes the file handle as soon as the image
            # has been transformed instead of leaving it to the garbage
            # collector.
            with Image.open(image_path) as image:
                # Force three channels so palette, grayscale or RGBA PNGs give
                # the same tensor layout as plain RGB files.
                inputs.append(transforms(image.convert("RGB")))
        inputs = torch.stack(inputs).to(device)
        with torch.no_grad():
            features.append(model(inputs).cpu())

    features_numpy = torch.cat(features, dim=0).numpy()
    print(features_numpy.shape)
    return features_numpy


def compute_text_features(encodings, model, batch_size=8):
    """Mean-pool the model's last hidden state into one vector per sequence.

    Args:
        encodings: Mapping with ``"input_ids"`` and ``"attention_mask"``
            entries. Each entry is sliced in batches and converted with
            ``torch.tensor``, so the rows inside a slice must share a length.
            Must be non-empty, because ``torch.cat`` rejects an empty list.
        model: Module called as ``model(input_ids, attention_mask=...,
            output_hidden_states=True)`` whose output exposes
            ``hidden_states``. It is switched to eval mode here and is
            expected to already live on the notebook-level ``device``.
        batch_size: Number of sequences per forward pass. Defaults to 8, the
            value that used to be hard-coded.

    Returns:
        numpy.ndarray: One pooled vector per input sequence, in input order.
    """
    model.eval()

    features = []
    total = len(encodings["input_ids"])
    for start in tqdm(range(0, total, batch_size)):
        stop = start + batch_size
        input_ids = torch.tensor(encodings["input_ids"][start:stop]).to(device)
        attention_mask = torch.tensor(encodings["attention_mask"][start:stop]).to(device)
        with torch.no_grad():
            # Ask for hidden states explicitly so the function does not depend
            # on `config.output_hidden_states` having been set in another cell.
            outputs = model(
                input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )
            last_hidden = outputs.hidden_states[-1]

            # Average over real tokens only. A plain mean over the sequence
            # axis also averages the padding positions, which makes a text's
            # vector depend on how far it was padded (and train and test are
            # padded to different lengths).
            token_mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
            summed = (last_hidden * token_mask).sum(dim=1)
            # clamp avoids a division by zero for a row that is all padding.
            token_counts = token_mask.sum(dim=1).clamp(min=1)
            features.append((summed / token_counts).cpu())

    features_numpy = torch.cat(features, dim=0).numpy()
    print(features_numpy.shape)
    return features_numpy

In [None]:
# Train split: extract the image features, then persist them before moving on.
train_images_features = compute_image_features(
    images=train_images,
    model=fastvit_model,
    transforms=transforms,
)
np.save(file=TRAIN_IMAGE_FEATURES_PATH, arr=train_images_features)

# Test split: same extraction, stored at its own location.
test_image_features = compute_image_features(
    images=test_images,
    model=fastvit_model,
    transforms=transforms,
)
np.save(file=TEST_IMAGE_FEATURES_PATH, arr=test_image_features)

In [None]:
# Train split: pool the BERT hidden states, then persist them before moving on.
train_text_features = compute_text_features(
    encodings=train_encodings,
    model=bert_model,
)
np.save(file=TRAIN_TEXT_FEATURES_PATH, arr=train_text_features)

# Test split: same extraction, stored at its own location.
test_text_features = compute_text_features(
    encodings=test_encodings,
    model=bert_model,
)
np.save(file=TEST_TEXT_FEATURES_PATH, arr=test_text_features)