In [None]:
import numpy as np
import os
import pandas as pd
import torch
from PIL import Image
from functools import cache
from matplotlib import pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer, StandardScaler
from skopt import BayesSearchCV
from skopt.space import Real
from torchvision.transforms import functional as F, transforms
from tqdm import tqdm

from src.datasets.coco_org import Coco2017Dataset
from src.models.image_encoder import get_image_embedding_module
from src.models.text_encoder import get_text_embedding_module


# Use the bright seaborn theme for every figure drawn in this notebook.
plt.style.use('seaborn-v0_8-bright')
# Register tqdm's progress-aware helpers on pandas objects.
tqdm.pandas()

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Both encoders are created on the same device: "beit" for images, "clip" for text.
image_model = get_image_embedding_module("beit", device=device)
text_model = get_text_embedding_module("clip", device=device)
print(f"Using device {device}")

In [None]:
# Machine-specific locations. Each default below can be overridden through an
# environment variable so the notebook can run on another machine without
# editing this cell; when the variable is unset the original path is used.
PROJECT_DIR = os.path.abspath(os.environ.get(
    "FEW_SHOT_PROJECT_DIR",
    "/Users/xihaochen/Documents/National University of Singapore/Modules/2223 Sem 2/CP4101 B. Comp. Dissertation/Project/main-project",
))

# Directory that receives the checkpoints written by the experiments below.
SAVE_DIR = os.path.join(PROJECT_DIR, "checkpoints")

# Root directory handed to Coco2017Dataset.
COCO_DIR = os.path.abspath(os.environ.get(
    "FEW_SHOT_COCO_DIR",
    "/Users/xihaochen/Documents/National University of Singapore/Modules/2223 Sem 2/CP4101 B. Comp. Dissertation/Project/coco-org/coco2017",
))

In [None]:
# Load the dataset wrapper, then read its dataframe and its list of topics.
coco = Coco2017Dataset(COCO_DIR)
raw_frame, target_classes = coco.dataframe, coco.topics
# Expand the "labels" column so every row carries exactly one label.
dataset = raw_frame.explode("labels")

In [None]:
# Fit the binarizer on every label present in the exploded dataset
# (`fit` returns the fitted binarizer itself), then display the classes it learned.
lb = LabelBinarizer().fit(dataset["labels"].tolist())
lb.classes_

In [None]:
def get_k_samples_per_class(df: pd.DataFrame, k: int, seed: int = 42) -> pd.DataFrame:
    """Draw ``k`` rows for every concept in the module-level ``target_classes``.

    Args:
        df: Frame with a ``"labels"`` column holding one concept per row.
        k: Number of rows to draw per concept.
        seed: Random state passed to every per-concept draw.

    Returns:
        The concatenated draws, in ``target_classes`` order, with the previous
        index moved into a column by ``reset_index``.
    """
    print(f"Sampling {k} points per concept")

    def _draw(concept) -> pd.DataFrame:
        rows = df.loc[df["labels"] == concept]
        # Sample with replacement only when the concept has fewer than k rows.
        with_replacement = len(rows) < k
        return rows.sample(n=k, replace=with_replacement, random_state=seed)

    per_concept = [_draw(concept) for concept in target_classes]
    return pd.concat(per_concept, axis=0).reset_index()


# Every image is resized to 224x224 before it is passed to the image encoder.
img_resize = transforms.Resize((224, 224))


@cache
def process_image(path: str) -> np.ndarray:
    """Encode the image stored at ``path`` with the module-level ``image_model``.

    Results are memoised per path by ``functools.cache``, so each file is read
    and encoded at most once per kernel session.

    Args:
        path: Location of the image file on disk.

    Returns:
        The encoder output, moved to the CPU and converted to a NumPy array.
    """
    # The context manager releases the file handle even if decoding fails.
    with Image.open(path) as raw:
        pixels = F.pil_to_tensor(raw.convert("RGB"))
    pixels = img_resize(pixels).float()
    # Disable autograd locally so the function also works when the caller has
    # not wrapped the call in `torch.no_grad()`.
    # NOTE(review): the tensor is passed without a batch dimension — confirm
    # that `image_model` expects that.
    with torch.no_grad():
        encoded = image_model(pixels)
    return encoded.cpu().numpy()


def convert_image_to_X_y(data: pd.DataFrame, show_progress=False) -> tuple[np.ndarray, list]:
    """Encode every image referenced in ``data`` and return features with labels.

    Args:
        data: Frame with a ``"filepath"`` column (image locations) and a
            ``"labels"`` column.
        show_progress: When true, wrap the iteration in a tqdm progress bar.

    Returns:
        ``(features, labels)``: ``features`` is the vertical stack of the
        per-image encodings in row order, and ``labels`` is the ``"labels"``
        column as a plain list.

    Raises:
        ValueError: If ``data`` has no rows.
    """
    if len(data) == 0:
        raise ValueError("Cannot encode an empty dataframe")

    labels = data["labels"].tolist()

    # Only the file path is needed, so iterate over that column directly
    # instead of materialising a Series per row with `iterrows`.
    paths = data["filepath"]
    if show_progress:
        paths = tqdm(paths, total=len(data))

    with torch.no_grad():
        embeddings = [process_image(path) for path in paths]

    return np.vstack(embeddings), labels


@cache
def process_text(text) -> np.ndarray:
    """Encode ``text`` with the module-level ``text_model``.

    The result is memoised per distinct ``text`` by ``functools.cache``.

    Returns:
        The encoder output, moved to the CPU and converted to a NumPy array.
    """
    embedding = text_model(text)
    return embedding.cpu().numpy()


def convert_text_to_X_y(data: pd.DataFrame, show_progress=False) -> tuple[np.ndarray, list]:
    """Encode every text in ``data`` and return features with labels.

    Args:
        data: Frame with a ``"text"`` column and a ``"labels"`` column.
        show_progress: When true, wrap the iteration in a tqdm progress bar.

    Returns:
        ``(features, labels)``: ``features`` is the vertical stack of the
        per-text encodings in row order, and ``labels`` is the ``"labels"``
        column as a plain list.

    Raises:
        ValueError: If ``data`` has no rows.
    """
    if len(data) == 0:
        raise ValueError("Cannot encode an empty dataframe")

    labels = data["labels"].tolist()

    # Only the text is needed, so iterate over that column directly instead of
    # materialising a Series per row with `iterrows`.
    texts = data["text"]
    if show_progress:
        texts = tqdm(texts, total=len(data))

    with torch.no_grad():
        embeddings = [process_text(text) for text in texts]

    return np.vstack(embeddings), labels

In [None]:
# 60% of the rows go to the training pool; the remainder is the held-out pool.
# NOTE(review): `dataset` holds one row per label after `explode`, so rows that
# originate from the same sample may land on both sides of this split —
# confirm that is acceptable for the evaluation.
train_df, holdout_df = train_test_split(dataset, train_size=0.6, random_state=42)

# The test set is a fixed draw of 50 rows per concept from the held-out pool.
test_df = get_k_samples_per_class(holdout_df, 50, seed=42)

print(f"Train size: {len(train_df)}, Test size: {len(test_df)}")

In [None]:
# Encode the test set once per modality; images are processed first, then texts.
test_image_features, test_image_labels = convert_image_to_X_y(data=test_df, show_progress=True)
test_text_features, test_text_labels = convert_text_to_X_y(data=test_df, show_progress=True)

In [None]:
# Display the shape of each modality's label list side by side.
tuple(np.array(labels).shape for labels in (test_image_labels, test_text_labels))

In [None]:
# Both label lists come from the same `test_df`, so they must be identical
# before a single shared `test_labels` is used for the cross-modal report.
labels_match = np.array_equal(test_image_labels, test_text_labels)
assert labels_match, "Labels are not the same"
test_labels = test_image_labels

In [None]:
def _build_few_shot_classifier(k, seed):
    """Return an unfitted classifier for a training set with ``k`` samples per class."""
    if k < 5:
        # Plain logistic regression for very small k (presumably because there
        # are too few samples per class for the cross-validated search).
        return LogisticRegression(verbose=False, max_iter=10_000, multi_class="ovr")
    return BayesSearchCV(
        LogisticRegression(verbose=False, max_iter=100_000, multi_class="ovr"),
        {
            "C": Real(1e-6, 1e+6, prior="log-uniform"),
        },
        n_jobs=-1, verbose=0, random_state=seed,
    )


def _classification_reports(y_true, y_pred, zero_division):
    """Return the text and dict forms of the same classification report."""
    as_text = classification_report(y_true, y_pred, zero_division=zero_division)
    as_dict = classification_report(y_true, y_pred, output_dict=True, zero_division=zero_division)
    return as_text, as_dict


def _fit_and_predict(classifier, modality, train_features, train_labels, test_features):
    """Standardise features, fit ``classifier`` and predict on the test features.

    Returns ``(params, predictions)``, where ``params`` is the tuned
    hyper-parameters when the classifier exposes ``best_params_`` and the
    estimator's own parameters otherwise.
    """
    # The scaler is fitted on the few-shot training features only.
    scaler = StandardScaler()
    train_scaled = scaler.fit_transform(train_features)
    test_scaled = scaler.transform(test_features)

    print(f"Training {modality} classifier")
    classifier.fit(train_scaled, train_labels)
    print(f"Done training {modality} classifier")

    if hasattr(classifier, "best_params_"):
        # Search objects: record what the search selected, not the search's
        # own constructor arguments.
        params = dict(classifier.best_params_)
    else:
        params = classifier.get_params(deep=True)

    return params, classifier.predict(test_scaled)


def few_shot(ks, all_train_df, print_results=True, show_sample_progress=False, zero_division=1, seed=42):
    """Train and evaluate image, text and cross-modal classifiers for each ``k``.

    For every ``k`` the function samples ``k`` training rows per concept,
    encodes them, fits one classifier per modality, evaluates each on the
    module-level test features/labels, and combines the two predictions into a
    cross-modal prediction. After every ``k`` the results accumulated so far
    are saved to ``SAVE_DIR/few_shot_results-bayes-{k}.pt``.

    Relies on module-level state: ``test_image_features``,
    ``test_text_features``, ``test_image_labels``, ``test_text_labels``,
    ``test_labels``, ``lb`` and ``SAVE_DIR``.

    Args:
        ks: Samples-per-concept values to evaluate (array, list or single int).
        all_train_df: Pool of training rows to sample from.
        print_results: When true, print the three text reports after each ``k``.
        show_sample_progress: When true, show progress bars while encoding.
        zero_division: Value passed to every ``classification_report`` call so
            that the text and dict reports always agree.
        seed: Random state for the per-concept sampling and the Bayesian search.

    Returns:
        ``(best_params, reports_str, reports_dict)``, each a dict keyed by
        ``k`` and then by ``"image"`` / ``"text"`` (and ``"cross"`` for the
        two report dicts).
    """
    # Create the checkpoint directory up front so a missing folder fails fast
    # instead of after the first (potentially long) training round.
    os.makedirs(SAVE_DIR, exist_ok=True)

    ks = np.atleast_1d(ks)
    _best_params = {}
    _reports_str = {}
    _reports_dict = {}

    print(f"ks:{ks.tolist()}")
    for k in tqdm(ks, desc="k"):
        _best_params[k] = {}
        _reports_str[k] = {}
        _reports_dict[k] = {}

        _samples = get_k_samples_per_class(all_train_df, k, seed=seed)
        _train_image_features, _train_image_labels = convert_image_to_X_y(_samples, show_progress=show_sample_progress)
        _train_text_features, _train_text_labels = convert_text_to_X_y(_samples, show_progress=show_sample_progress)
        print("Encoded all images and texts")

        # --- Image classifier ---
        _best_params[k]["image"], _image_pred = _fit_and_predict(
            _build_few_shot_classifier(k, seed), "image",
            _train_image_features, _train_image_labels, test_image_features,
        )
        _reports_str[k]["image"], _reports_dict[k]["image"] = _classification_reports(
            test_image_labels, _image_pred, zero_division)

        # --- Text classifier ---
        _best_params[k]["text"], _text_pred = _fit_and_predict(
            _build_few_shot_classifier(k, seed), "text",
            _train_text_features, _train_text_labels, test_text_features,
        )
        _reports_str[k]["text"], _reports_dict[k]["text"] = _classification_reports(
            test_text_labels, _text_pred, zero_division)

        # --- Cross-modal prediction ---
        # Element-wise product of the two binarised predictions: a class
        # survives only where both modalities predicted it.
        # NOTE(review): rows where the modalities disagree become all zeros;
        # `LabelBinarizer.inverse_transform` appears to resolve those to the
        # first class via argmax — confirm this is the intended behaviour.
        _image_onehot = lb.transform(_image_pred)
        _text_onehot = lb.transform(_text_pred)
        _cross_pred = lb.inverse_transform(np.multiply(_image_onehot, _text_onehot))
        _reports_str[k]["cross"], _reports_dict[k]["cross"] = _classification_reports(
            test_labels, _cross_pred, zero_division)

        if print_results:
            print(f"\nk={k}")
            for key in ("image", "text", "cross"):
                print(f"{key.upper()}")
                print(_reports_str[k][key])

        # Each checkpoint holds the results of every k processed so far.
        torch.save({
            "best_params": _best_params,
            "reports_str": _reports_str,
            "reports_dict": _reports_dict,
        }, os.path.join(SAVE_DIR, f"few_shot_results-bayes-{k}.pt"))

    return _best_params, _reports_str, _reports_dict

In [None]:
# Full run over k = 1, 2, 4, ..., 128 samples per concept.
# For a quick smoke test, pass `2 ** np.arange(0, 1)` (k = 1 only) instead.
best_params, reports_str, reports_dict = few_shot(
    ks=2 ** np.arange(0, 8),
    all_train_df=train_df,
    print_results=False,
)

In [None]:
 for k, reports in sorted(reports_str.items()):
    print("=" * 20)
    print(f"k={k}")
    for task, report in reports.items():
        print(f"{task.upper()}")
        print(report)
        print("=" * 20)

In [None]:
# Figures are written under the project's notebooks/plots folder. The path is
# derived from PROJECT_DIR so the project location is defined in one place only.
plot_dir = os.path.abspath(os.path.join(PROJECT_DIR, "notebooks", "plots"))

In [None]:
# Samples-per-concept values on the x axis, and the hard-coded score for each
# of them (entered as percentages, converted to fractions).
ks = 2 ** np.arange(0, 8)
results_i2i = np.array([20, 29, 36, 39, 43, 46, 44, 47]) / 100
results_t2t = np.array([41, 46, 53, 56, 57, 57, 59, 59]) / 100
results_i2t = np.array([39, 43, 45, 49, 54, 55, 53, 57]) / 100
results_t2i = np.array([37, 39, 44, 49, 53, 54, 52, 54]) / 100


fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(10, 8))

# All four tables below share the 2x2 layout of the axes grid.
titles = [
    ["Image-to-Image", "Text-to-Text"],
    ["Image-to-Text", "Text-to-Image"]
]
y_lims = [
    [[0.201, 0.591], [0.379, 1.05]],
    [[0.359, 0.691], [0.349, 0.691]]  # TODO
]
fs_scores = [
    [results_i2i, results_t2t],
    [results_i2t, results_t2i]  # TODO
]
# Per panel: [y of the dashed "All data" line, (x, y) anchor of its text label].
full_scores = [
    [[0.563, (120, 0.560)], [0.999, (120, 0.989)]],
    [[0.664, (120, 0.660)], [0.662, (120, 0.660)]]  # TODO
]

for (r, c), panel in np.ndenumerate(ax):
    panel_title = titles[r][c]
    all_data_score, (label_x, label_y) = full_scores[r][c]

    # Few-shot curve plus the dashed reference line and its value label.
    panel.plot(ks, fs_scores[r][c], label=panel_title, marker="x", color="tab:cyan")
    panel.axhline(y=all_data_score, linestyle="dashed", label="All data", color="tab:red")
    panel.text(
        label_x, label_y, all_data_score,
        horizontalalignment='right', verticalalignment='top',
        color="tab:red", fontsize=10,
    )

    # Axis decoration.
    panel.set_xticks(ks)
    panel.set_xticklabels(ks, rotation=45, horizontalalignment="center", fontsize=8)
    panel.set_xlabel("S (no. of samples per concept)", fontsize=8)
    panel.set_ylabel("MAP score", fontsize=8)
    panel.set_title(panel_title, fontsize=10)
    panel.legend(loc="lower right", fontsize=8)
    panel.grid(which="major")
    panel.set_ylim(y_lims[r][c])

fig.tight_layout(pad=1.5)
# Leave room at the top for the figure-level title.
plt.subplots_adjust(top=0.92)
fig.suptitle("MAP few-shot trends on test set", fontsize=12)
trends_path = os.path.join(plot_dir, "few-shot-all-trends.png")
plt.savefig(trends_path, dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Hard-coded training durations per k, divided by 1000 to match the axis unit.
ks = 2 ** np.arange(0, 8)
durations = np.array([5, 535, 1729, 3048, 5841, 11615, 21762, 61167]) / 1000

timing_fig, timing_ax = plt.subplots(figsize=(10, 6))
timing_ax.plot(ks, durations, label="Time to train (1000 sec)", marker="^", color="tab:green")

# One tick per evaluated k, with the tick labels rotated.
timing_ax.set_xticks(ks)
plt.setp(timing_ax.get_xticklabels(), rotation=45, horizontalalignment="center")

timing_ax.set_xlabel("S (no. of samples per concept)")
timing_ax.set_ylabel("Time (1000 sec)")
timing_ax.grid(which="major")
timing_ax.legend(loc="lower right")
timing_ax.set_title("Average time taken to train")

timing_fig.savefig(os.path.join(plot_dir, "few-shot-timed.png"), dpi=300, bbox_inches='tight')
plt.show()

In [None]:
ks = 2 ** np.arange(0, 8)
# `few_shot` saves the results accumulated so far after every k, so the
# checkpoint written for the largest k contains the reports for all of them.
# The path is built from SAVE_DIR and the same file-name pattern `few_shot`
# uses, instead of repeating the absolute path.
# NOTE: `torch.load` unpickles the file; only load checkpoints you created.
# NOTE(review): on PyTorch versions where `weights_only` defaults to True this
# call may need `weights_only=False`, because the checkpoint stores plain
# Python/NumPy objects — confirm against the installed version.
results_dict = torch.load(os.path.join(SAVE_DIR, f"few_shot_results-bayes-{ks[-1]}.pt"))
results_dict.keys()

In [None]:
# Weighted-average precision of the image classifier for each k, rounded to 2 decimals.
np.around(
    [
        per_task["image"]["weighted avg"]["precision"]
        for per_task in (results_dict["reports_dict"][k] for k in ks)
    ],
    decimals=2,
)

In [None]:
# Weighted-average precision of the text classifier for each k, rounded to 2 decimals.
np.around(
    [
        per_task["text"]["weighted avg"]["precision"]
        for per_task in (results_dict["reports_dict"][k] for k in ks)
    ],
    decimals=2,
)

In [None]:
# Weighted-average precision of the cross-modal prediction for each k, rounded to 2 decimals.
np.around(
    [
        per_task["cross"]["weighted avg"]["precision"]
        for per_task in (results_dict["reports_dict"][k] for k in ks)
    ],
    decimals=2,
)