In [None]:
import pickle

import timm
import torch
from PIL import Image
from matplotlib import pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from tabulate import tabulate
from torch import nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

from constants import (
    TRAIN_DATA_CSV,
    TEST_DATA_CSV,
    MULTIMODAL_MODEL_PATH,
    FINE_TUNED_FASTVIT_MODEL_PATH,
    FINE_TUNED_BERT_MODEL_PATH,
    TARGET_SCALER_PATH,
    IMAGES_PATH,
)

In [ ]:
# import patoolib
# patoolib.extract_archive("resized_images.zip",outdir="/workspace")

In [ ]:
def compute_metrics(predictions, ground_truths):
    """Score regression predictions against ground truth.

    Args:
        predictions: Predicted values, aligned element-wise with ``ground_truths``.
        ground_truths: True values.

    Returns:
        dict with keys "MAE", "MSE", "RMSE" (square root of the MSE) and "R2".
    """
    squared_error = mean_squared_error(ground_truths, predictions)
    return {
        "MAE": mean_absolute_error(ground_truths, predictions),
        "MSE": squared_error,
        "RMSE": np.sqrt(squared_error),
        "R2": r2_score(ground_truths, predictions),
    }


def plot_loss_and_metrics(history, metrics_history, SLICE_START=10):
    """Show one train-vs-test figure per tracked quantity.

    Figures appear in this order: loss, MAE, RMSE, R2, MSE.

    Args:
        history: dict holding per-epoch "train_loss" and "test_loss" lists.
        metrics_history: dict holding per-epoch "train_{m}" / "test_{m}" lists
            for m in mae, rmse, r2, mse.
        SLICE_START: index of the first epoch drawn; earlier epochs are skipped.
    """
    # (source dict, key suffix, y-axis label) per figure, in display order.
    panels = [
        (history, "loss", "Loss"),
        (metrics_history, "mae", "MAE"),
        (metrics_history, "rmse", "RMSE"),
        (metrics_history, "r2", "R2"),
        (metrics_history, "mse", "MSE"),
    ]
    for source, suffix, y_label in panels:
        for split in ("train", "test"):
            plt.plot(source[f"{split}_{suffix}"][SLICE_START:], label=f"{split} {suffix}")
        plt.xlabel("Epoch")
        plt.ylabel(y_label)
        plt.legend()
        plt.show()


def print_metrics_table(metrics_history):
    """Print grid tables with the most recent train and test metrics.

    The "Epoch" column is the length of the MAE series minus one, i.e. the
    0-based index of the last recorded epoch (the training loop logs epochs 1-based).

    Args:
        metrics_history: dict of per-epoch lists keyed "train_{m}" / "test_{m}"
            for m in mae, rmse, r2, mse; every list must be non-empty.
    """
    headers = ["Epoch", "MAE", "RMSE", "R2", "MSE"]

    def latest_row(split):
        # Column order matches `headers`; the epoch index is derived from the MAE series.
        series = [metrics_history[f"{split}_{metric}"] for metric in ("mae", "rmse", "r2", "mse")]
        return [[len(series[0]) - 1] + [f"{values[-1]:.5f}" for values in series]]

    # Build both tables before printing anything.
    train_rows = latest_row("train")
    test_rows = latest_row("test")

    print("Train Metrics")
    print(tabulate(train_rows, headers=headers, tablefmt="grid"))

    print("\nTest Metrics")
    print(tabulate(test_rows, headers=headers, tablefmt="grid"))

In [None]:
# Row cap for quick experiments; None keeps every row.
SLICE: int | None = None

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()

# unique_id is read as str because its text is used verbatim as the image file-name stem.
# Alternative relative paths:
# df_train = pd.read_csv('train_data.csv', dtype={"unique_id": str})[:SLICE]
# df_test = pd.read_csv('test_data.csv', dtype={"unique_id": str})[:SLICE]
df_train = pd.read_csv(TRAIN_DATA_CSV, dtype={"unique_id": str}).iloc[:SLICE]
df_test = pd.read_csv(TEST_DATA_CSV, dtype={"unique_id": str}).iloc[:SLICE]

# SECURITY: pickle.load can execute arbitrary code; only load a scaler file you produced or trust.
# The scaler's inverse_transform maps scaled model outputs back to the original target scale.
# with open('target_scaler.pkl', "rb") as f:
#     target_scaler = pickle.load(f)
with open(TARGET_SCALER_PATH, "rb") as f:
    target_scaler = pickle.load(f)

In [None]:
# Rebuild FastViT-T8 with a single-output regression head; presumably this mirrors the head used
# during fine-tuning, since the strict load_state_dict below requires identical keys and shapes.
fastvit = timm.create_model("fastvit_t8.apple_in1k", pretrained=True, num_classes=0)

fastvit.head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(fastvit.num_features, 1))

# map_location remaps stored tensors onto `device`, so a checkpoint saved on a GPU still loads
# when the notebook falls back to CPU (and vice versa).
# fastvit.load_state_dict(torch.load('fine_tuned_fastvit_model.pth'))
fastvit.load_state_dict(torch.load(FINE_TUNED_FASTVIT_MODEL_PATH, map_location=device))
fastvit.to(device)

# Evaluation-time preprocessing resolved from the model's pretrained data config
# (is_training=False selects timm's non-augmenting transform).
data_config = timm.data.resolve_model_data_config(fastvit)
transforms = timm.data.create_transform(**data_config, is_training=False)


class FastViTEmbedding(nn.Module):
    """Expose a timm FastViT backbone as a fixed-length image embedding extractor.

    The wrapped model's own head is bypassed: the output of ``forward_features``
    is global-average-pooled and flattened to one vector per image.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # Collapses each channel's spatial map to a single value.
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Return pooled backbone features, one row per input image.

        Assumes ``forward_features`` yields a (batch, channels, H, W) map — TODO confirm
        if a different backbone is ever wrapped.
        """
        pooled = self.pool(self.model.forward_features(x))
        return pooled.view(pooled.shape[0], -1)


# FastViTEmbedding calls forward_features, so the regression head attached above is bypassed
# (its weights remain part of the module tree).
fastvit_model = FastViTEmbedding(model=fastvit)
fastvit_model = fastvit_model.to(device)

In [None]:
# The same Romanian uncased BERT checkpoint supplies both the tokenizer and the encoder weights.
# NOTE(review): padding / truncation / max_length / add_special_tokens are per-call encoding
# options in Hugging Face; passed to from_pretrained they likely do not affect encoding. The
# tokenizer calls later in this notebook pass padding/truncation/max_length explicitly.
tokenizer = AutoTokenizer.from_pretrained(
    "dumitrescustefan/bert-base-romanian-uncased-v1",
    do_lower_case=True,
    add_special_tokens=True,
    max_length=512,
    padding=True,
    truncation=True,
)
# Raw encoder; a later cell wraps it in BERTRegressor and rebinds `bert_model`.
bert_model = AutoModel.from_pretrained("dumitrescustefan/bert-base-romanian-uncased-v1").to(device)


class BERTRegressor(nn.Module):
    """Wrap a BERT encoder so ``forward`` returns its pooled text embedding.

    Despite the name, ``forward`` does not return a regression value: ``self.fc``
    is never called. It is kept so the module's parameters match the fine-tuned
    checkpoint (presumably saved with ``fc.*`` keys — a strict ``load_state_dict``
    would otherwise reject the file).

    Args:
        bert: Encoder to wrap. When omitted, the notebook-level ``bert_model``
            global is used (the original behaviour), looked up at construction time.
    """

    def __init__(self, bert=None):
        super().__init__()
        # Falling back to the global keeps `BERTRegressor()` working unchanged while
        # allowing an encoder to be injected explicitly.
        self.bert = bert_model if bert is None else bert
        self.fc = nn.Linear(768, 1)

    def forward(self, input_ids, attention_mask):
        """Return the encoder's pooled output for a batch of token IDs.

        Args:
            input_ids: Token ID tensor, presumably (batch, seq_len).
            attention_mask: Attention mask aligned with ``input_ids``, as produced by the tokenizer.

        Returns:
            Element 1 of the encoder output. For a Hugging Face ``BertModel`` this is
            ``pooler_output``: the [CLS] hidden state passed through BERT's pooler
            (dense + tanh) layer, not the raw [CLS] token state.
        """
        # Keyword argument avoids relying on the encoder's positional parameter order.
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        return outputs[1]


# Wrap the raw encoder and restore the fine-tuned weights (encoder + the unused fc layer).
# NOTE: this rebinds `bert_model`, and BERTRegressor() reads that same global — re-run the
# tokenizer/encoder cell first, or re-running this cell alone would nest two wrappers.
bert_model = BERTRegressor().to(device)

# map_location remaps stored tensors onto `device`, so a GPU-saved checkpoint loads on CPU too.
# bert_model.load_state_dict(torch.load('fine_tuned_bert_model.pth'))
bert_model.load_state_dict(torch.load(FINE_TUNED_BERT_MODEL_PATH, map_location=device))

In [None]:
# One PNG per row, named after its unique_id.
train_images = [IMAGES_PATH / f"{uid}.png" for uid in df_train["unique_id"]]
test_images = [IMAGES_PATH / f"{uid}.png" for uid in df_test["unique_id"]]
# Alternate layout for another environment (apply to the raw unique_id values):
# train_images = [f"resized_images/{path}.png" for path in train_images]
# test_images = [f"resized_images/{path}.png" for path in test_images]

# Each split is tokenized in one call: padded to its longest sequence, truncated at 512 tokens.
train_encodings = tokenizer(df_train["input"].tolist(), padding=True, truncation=True, max_length=512)
test_encodings = tokenizer(df_test["input"].tolist(), padding=True, truncation=True, max_length=512)

# Tabular features; their count must match the structured input width of MultiModalModel (10).
STRUCTURED_COLUMNS = [
    "km",
    "putere",
    "capacitate cilindrica",
    "anul producției",
    "marca",
    "model",
    "combustibil",
    "tip caroserie",
    "firma",
    "is_automatic",
]

# Cast to float32 up front: pandas returns an object-dtype array when bool and numeric columns
# are mixed, which torch.tensor cannot convert later. Non-numeric values now fail here, early.
train_structured_data = df_train[STRUCTURED_COLUMNS].to_numpy(dtype=np.float32)
test_structured_data = df_test[STRUCTURED_COLUMNS].to_numpy(dtype=np.float32)

# Scaled regression target; target_scaler maps it back to the original scale.
train_targets = df_train["price_std"].to_numpy()
test_targets = df_test["price_std"].to_numpy()

print(f"Train images: {len(train_images)}")
print(f"Train encodings: {len(train_encodings['input_ids'])}")
print(f"Train structured data: {train_structured_data.shape}")
print(f"Train targets: {train_targets.shape}")

print(f"Test images: {len(test_images)}")
print(f"Test encodings: {len(test_encodings['input_ids'])}")
print(f"Test structured data: {test_structured_data.shape}")
print(f"Test targets: {test_targets.shape}")

# MultimodalDataset takes its length from the image list and indexes every source by the same
# position, so a length mismatch would silently misalign samples.
assert (
    len(train_images) == len(train_encodings["input_ids"]) == len(train_structured_data) == len(train_targets)
), "train inputs have mismatched lengths"
assert (
    len(test_images) == len(test_encodings["input_ids"]) == len(test_structured_data) == len(test_targets)
), "test inputs have mismatched lengths"

In [None]:
class MultimodalDataset(Dataset):
    """Per-sample (image, input_ids, attention_mask, structured_data, target) tuples.

    All four sources are indexed by the same position ``idx``; the dataset length
    comes from ``images_paths``, so the sources must be aligned and equally long.

    Args:
        images_paths: Sequence of image file paths.
        encodings: Tokenizer output with per-sample "input_ids" and "attention_mask".
        structured_data: Row-indexable numeric features (e.g. a 2-D NumPy array).
        targets: Per-sample regression targets.
        transform: Image transform to apply. When omitted, the notebook-level
            ``transforms`` global is used (the original behaviour), looked up per item.
    """

    def __init__(self, images_paths, encodings, structured_data, targets, transform=None):
        self.images_paths = images_paths
        self.encodings = encodings
        self.structured_data = structured_data
        self.targets = targets
        self.transform = transform

    def __len__(self):
        return len(self.images_paths)

    def __getitem__(self, idx):
        # The context manager closes the file handle; convert("RGB") returns an independent,
        # loaded copy and guarantees 3 channels for grayscale, palette or RGBA PNGs, which
        # would otherwise break the 3-channel normalization / backbone.
        with Image.open(self.images_paths[idx]) as img:
            rgb_image = img.convert("RGB")
        transform = transforms if self.transform is None else self.transform
        image = transform(rgb_image)

        input_ids = torch.tensor(self.encodings["input_ids"][idx])
        attention_mask = torch.tensor(self.encodings["attention_mask"][idx])
        # Cast through NumPy so object-dtype rows (from mixed-dtype DataFrame columns) still
        # convert; torch.tensor copies, so the dataset's array is never shared.
        structured_data = torch.tensor(np.asarray(self.structured_data[idx], dtype=np.float32))
        target = torch.tensor(self.targets[idx]).float()

        return image, input_ids, attention_mask, structured_data, target

In [None]:
BATCH_SIZE = 4

train_dataset = MultimodalDataset(
    images_paths=train_images,
    encodings=train_encodings,
    structured_data=train_structured_data,
    targets=train_targets,
)
test_dataset = MultimodalDataset(
    images_paths=test_images,
    encodings=test_encodings,
    structured_data=test_structured_data,
    targets=test_targets,
)

# Only the training split is shuffled; evaluation keeps a fixed order across epochs.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE)

In [None]:
class MultiModalModel(nn.Module):
    """Late-fusion regressor over image, text and tabular features.

    The image and text embeddings are concatenated with the raw structured
    features, passed through dropout, and mapped to a single output by a linear layer.

    Args:
        fastvit_model: Module mapping an image batch to (batch, image_dim) embeddings.
        bert_model: Module mapping (input_ids, attention_mask) to (batch, text_dim) embeddings.
        image_dim: Width of the image embedding (default 768, the previously hard-coded value).
        text_dim: Width of the text embedding (default 768).
        structured_dim: Number of structured features per sample (default 10).
        dropout: Dropout probability applied to the fused vector (default 0.2).
    """

    def __init__(self, fastvit_model, bert_model, image_dim=768, text_dim=768, structured_dim=10, dropout=0.2):
        super().__init__()
        # Registration order (encoders, fc, dropout) is unchanged so state_dict keys and
        # optimizer parameter order stay the same as before.
        self.fastvit_model = fastvit_model
        self.bert_model = bert_model
        self.fc = nn.Linear(image_dim + text_dim + structured_dim, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, image, input_ids, attention_mask, structured_data):
        """Return one prediction per sample, shape (batch, 1)."""
        fastvit_embedding = self.fastvit_model(image)
        bert_embedding = self.bert_model(input_ids, attention_mask)
        # Fuse along the feature axis; all three inputs must share the batch dimension.
        x = torch.cat([fastvit_embedding, bert_embedding, structured_data], dim=1)
        x = self.dropout(x)
        return self.fc(x)

In [None]:
multimodal_model = MultiModalModel(fastvit_model, bert_model)
multimodal_model = multimodal_model.to(device)

# Loss on the scaled target; parameters() includes both encoders, so they are fine-tuned too.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(multimodal_model.parameters(), lr=1e-5)

# Per-epoch logs appended by train(); keys are "{split}_{quantity}".
history = {f"{split}_loss": [] for split in ("train", "test")}
metrics_history = {
    f"{split}_{metric}": [] for metric in ("mae", "rmse", "r2", "mse") for split in ("train", "test")
}

In [None]:
def print_model_statuses():
    """Print the ``training`` flag of the multimodal model and each wrapped sub-model.

    Debug helper for checking that train()/eval() reach the inner encoders.
    Reads the notebook-level ``multimodal_model``.
    """
    # Getters are lazy so each attribute is resolved only when its line is printed.
    entries = (
        ("Core FastViT model", lambda: multimodal_model.fastvit_model.model),
        ("Core BERT model", lambda: multimodal_model.bert_model.bert),
        ("FastVitEmbedding model", lambda: multimodal_model.fastvit_model),
        ("BERTEmbedding model", lambda: multimodal_model.bert_model),
        ("Multimodal model", lambda: multimodal_model),
    )
    for position, (label, get_module) in enumerate(entries):
        if position == 2:
            print()
        print(f"{label}: {get_module().training}")

In [None]:
def _run_epoch(loader, training):
    """Run one pass of ``multimodal_model`` over ``loader``.

    With ``training=True`` the model is in train mode and updated batch by batch via
    ``optimizer``; otherwise it runs in eval mode with gradients disabled.

    Returns:
        (mean per-batch loss, predictions, ground truths); predictions and ground truths
        are mapped back to the original target scale with ``target_scaler``.
    """
    multimodal_model.train(training)  # train(False) is the same as eval()
    batch_losses, predictions, ground_truths = [], [], []
    batches = tqdm(loader) if training else loader  # progress bar only for the training pass

    with torch.set_grad_enabled(training):
        for batch in batches:
            images, input_ids, attention_mask, structured_data, targets = [b.to(device) for b in batch]
            targets = targets.view(-1, 1)  # match the model's (batch, 1) output for MSELoss

            outputs = multimodal_model(images, input_ids, attention_mask, structured_data)
            loss = criterion(outputs, targets)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch_losses.append(loss.item())

            # Metrics use the original target scale, not the scaled space the loss is computed in.
            predictions.extend(target_scaler.inverse_transform(outputs.detach().cpu().numpy()))
            ground_truths.extend(target_scaler.inverse_transform(targets.detach().cpu().numpy()))

    # Unweighted mean of batch losses (a smaller final batch counts as much as a full one).
    return np.sum(batch_losses) / len(loader), predictions, ground_truths


def _record_metrics(split, metrics):
    """Append one epoch of compute_metrics output to metrics_history under "{split}_{name}" keys."""
    for name in ("MAE", "MSE", "RMSE", "R2"):
        metrics_history[f"{split}_{name.lower()}"].append(metrics[name])


def train(EPOCHS=100):
    """Train ``multimodal_model`` for ``EPOCHS`` epochs, evaluating on the test split after each.

    Appends per-epoch losses to ``history`` and metrics to ``metrics_history`` (notebook-level
    globals), prints progress, and saves the state_dict to ``MULTIMODAL_MODEL_PATH`` whenever
    the test loss beats the best seen during this call.

    Train metrics are accumulated while weights are updating and dropout is active, so they are
    not directly comparable to the test metrics.

    NOTE(review): the checkpoint is selected on the test split, so the saved model's test metrics
    are optimistically biased; a separate validation split would avoid this. ``best_val_loss``
    also restarts at infinity on every call, so a second call overwrites the checkpoint after its
    first epoch even if that model is worse than one from an earlier run.

    Args:
        EPOCHS: Number of passes over ``train_loader``.
    """
    best_val_loss = float("inf")
    for epoch in range(EPOCHS):
        avg_train_loss, train_predictions, train_ground_truths = _run_epoch(train_loader, training=True)
        train_metrics = compute_metrics(train_predictions, train_ground_truths)
        _record_metrics("train", train_metrics)
        history["train_loss"].append(avg_train_loss)
        print(f"Epoch {epoch + 1}, Train Loss: {avg_train_loss}")
        print(f"Epoch {epoch + 1}, Train Metrics: {train_metrics}")

        avg_val_loss, test_predictions, test_ground_truths = _run_epoch(test_loader, training=False)
        history["test_loss"].append(avg_val_loss)
        test_metrics = compute_metrics(test_predictions, test_ground_truths)
        _record_metrics("test", test_metrics)
        print(f"Epoch {epoch + 1}, Validation Loss: {avg_val_loss}")
        print(f"Epoch {epoch + 1}, Test Metrics: {test_metrics}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # torch.save(multimodal_model.state_dict(), 'multimodal_model.pth')
            torch.save(multimodal_model.state_dict(), MULTIMODAL_MODEL_PATH)
            print(f"Epoch {epoch + 1}: New best test loss: {best_val_loss}")

In [None]:
# Short run: ten epochs over the training split, evaluating on the test split after each.
train(EPOCHS=10)

In [None]:
# Plot every recorded epoch (start index 0), then tabulate the final epoch's metrics.
plot_loss_and_metrics(history, metrics_history, 0)
print_metrics_table(metrics_history)