In [None]:
import random
from pathlib import Path
from typing import List, Tuple

import pandas as pd

In [None]:
import plotly.express as px
import sklearn.pipeline
import torch
from nn_core.serialization import load_model, NNCheckpointIO
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from transformers import AutoModel, PreTrainedModel, PreTrainedTokenizer, AutoTokenizer
from pytorch_lightning import seed_everything

In [None]:
from rae.data.text import TREC
from rae.modules.attention import RelativeAttention, AttentionOutput
from rae.pl_modules.pl_text_classifier import LightningTextClassifier
from rae import PROJECT_ROOT

In [None]:
def load_ckpt(ckpt_path: Path):
    """Load a ``LightningTextClassifier`` checkpoint and switch it to eval mode.

    Args:
        ckpt_path: location of the checkpoint file handed to ``load_model``.

    Returns:
        Whatever ``.eval()`` returns on the loaded module (the module itself for
        ``nn.Module`` subclasses).
    """
    # strict=False: tolerate missing/unexpected keys when restoring the state dict.
    model = load_model(
        module_class=LightningTextClassifier,
        checkpoint_path=ckpt_path,
        strict=False,
    )
    return model.eval()

In [None]:
CODE_VERSION = 0.1

# Prefer the GPU, but fall back to CPU so "Restart & Run All" still works
# (slowly) on machines without CUDA instead of crashing on the first `.to(device)`.
device: str = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
from datasets import load_dataset, ClassLabel

# NOTE(review): absolute, machine-specific path — consider moving it to a config constant.
dataset_path: Path = Path("/mnt/data/projects/N24News/nytimes_dataset_full.json")
dataset_splits = load_dataset("json", data_files=str(dataset_path))
dataset = dataset_splits["train"]

# Duplicate `section` into a new `label` column, then cast it to a ClassLabel whose
# names are the sorted distinct section values (label ids follow that sorted order).
section_values = dataset["section"]
all_labels = sorted(set(section_values))
dataset = dataset.add_column(name="label", column=section_values)
dataset = dataset.cast_column("label", ClassLabel(names=all_labels))
dataset

In [None]:
# Root folder for every processed HF dataset produced by this notebook; created on demand.
datasets_dir: Path = PROJECT_ROOT.joinpath("data", "hf_datasets")
datasets_dir.mkdir(parents=True, exist_ok=True)

In [None]:
def encode_field(batch, src_field: str, tgt_field: str, transformation):
    """Transform one column of a batch and return it under a new column name.

    Intended as a batched ``Dataset.map`` callback.

    Args:
        batch: mapping from column name to the batch's values for that column.
        src_field: name of the column to read.
        tgt_field: name of the column to emit.
        transformation: callable applied to the whole ``batch[src_field]`` at once.

    Returns:
        A single-entry dict ``{tgt_field: transformation(batch[src_field])}``.
    """
    return {tgt_field: transformation(batch[src_field])}

In [None]:
from typing import *
import torch


def tokenize(texts: Sequence[str], tokenizer):
    """Placeholder for a standalone tokenization step; not implemented.

    NOTE(review): not called in the visible cells — ``text_encode`` invokes the
    tokenizer directly. Always returns ``None``; consider deleting it.
    """
    return None


@torch.no_grad()
def text_encode(texts: Sequence[str], tokenizer, transformer):
    """Embed each text as the mean of its last-layer token states.

    Tokens are pooled only if they are real (``attention_mask == 1``) and not
    special (``special_tokens_mask == 0``). Texts with no such token get a zero
    vector instead of NaN.

    Args:
        texts: batch of raw strings.
        tokenizer: HF-style tokenizer callable returning a ``BatchEncoding``.
        transformer: model called with ``**encoding``. Its output must expose
            ``"hidden_states"``; the last entry is pooled.

    Returns:
        A list with one numpy array per input text.
    """
    encoding = tokenizer(
        texts,
        return_tensors="pt",
        return_special_tokens_mask=True,
        truncation=True,
        padding=True,
    ).to(device)
    # Build a *boolean* mask. Multiplying the int attention mask by a bool tensor
    # yields int64, and indexing with an int64 tensor gathers rows 0/1 instead of
    # masking, which silently produces wrong embeddings.
    keep = encoding["attention_mask"].bool() & ~encoding["special_tokens_mask"].bool()
    # The model does not accept this key as an input.
    del encoding["special_tokens_mask"]

    hidden = transformer(**encoding)["hidden_states"][-1]

    # Vectorized masked mean over the token axis; clamp avoids 0/0 for empty masks.
    weights = keep.unsqueeze(-1).to(hidden.dtype)
    summed = (hidden * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1)
    pooled = summed / counts

    return list(pooled.cpu().numpy())

In [None]:
def load_transformer(transformer_name):
    """Load a frozen, eval-mode HF encoder together with its tokenizer.

    The model is built with ``output_hidden_states=True`` so callers can read
    ``hidden_states`` from its output.

    Args:
        transformer_name: model identifier passed to both ``from_pretrained`` calls.

    Returns:
        ``(model, tokenizer)``.
    """
    model = AutoModel.from_pretrained(
        transformer_name,
        output_hidden_states=True,
        return_dict=True,
    )
    # Inference only: freeze weights and disable dropout.
    model.requires_grad_(False)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(transformer_name)
    return model, tokenizer

In [None]:
text_encoded_dir: Path = datasets_dir.joinpath("N24News", "text_encoded")  # cache of text embeddings

In [None]:
import itertools
import functools

# Text models and source columns to embed; each pair produces a column "<field>_<model>".
transformers = ("roberta-base",)
fields = ("body",)
# Set True to recompute embeddings even if the target column already exists.
FORCE_COMPUTE: bool = False

for transformer_name, src_field in itertools.product(transformers, fields):
    tgt_field: str = f"{src_field}_{transformer_name}"
    # Check the schema, not the Dataset object: `name in dataset` falls back to
    # iterating every row (dicts), so it is always False and very slow.
    if tgt_field not in dataset.column_names or FORCE_COMPUTE:
        transformer, tokenizer = load_transformer(transformer_name=transformer_name)
        transformer = transformer.to(device)
        dataset = dataset.map(
            functools.partial(
                encode_field,
                src_field=src_field,
                tgt_field=tgt_field,
                transformation=functools.partial(
                    text_encode,
                    transformer=transformer,
                    tokenizer=tokenizer,
                ),
            ),
            num_proc=1,
            batched=True,
            batch_size=32,
            desc=f"text_encoding field <{src_field}> with <{transformer_name}>",
        )
        # Release GPU memory held by the model once the column is computed.
        transformer = transformer.cpu()
        dataset.set_format(type="torch", columns=[tgt_field], output_all_columns=True)
dataset.save_to_disk(str(text_encoded_dir))

In [None]:
from datasets import load_from_disk

# str() matches the save_to_disk(str(...)) call that wrote this directory; some
# `datasets` releases do not accept pathlib.Path here.
dataset = load_from_disk(str(text_encoded_dir))
dataset

In [None]:
image_encoded_dir: Path = datasets_dir.joinpath("N24News", "image_encoded")  # cache of text+image embeddings

In [None]:
from typing import *
import torch

from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# NOTE(review): absolute, machine-specific path — consider moving it to a config constant.
base_path: Path = Path("/mnt/data/projects/N24News/images")


@torch.no_grad()
def image_encode(images: Sequence[str], transform, encoder):
    """Embed a batch of images, identified by id, with ``encoder``.

    Each id is resolved to ``base_path / f"{id}.jpg"``, converted to RGB, passed
    through ``transform``, stacked into one batch on ``device``, and encoded.

    Args:
        images: image ids (file stems under ``base_path``).
        transform: callable mapping a PIL image to a tensor; outputs are stacked
            along a new dim 0, so they must share a shape.
        encoder: model applied to the stacked batch.

    Returns:
        A list with one numpy array per image (rows of the encoder output).
    """
    tensors: List[torch.Tensor] = []
    for image_id in images:
        # The context manager closes the file handle deterministically instead of
        # relying on garbage collection while thousands of files are processed.
        with Image.open(str(base_path / f"{image_id}.jpg")) as img:
            tensors.append(transform(img.convert("RGB")))

    batch: torch.Tensor = torch.stack(tensors, dim=0).to(device)
    encoding = encoder(batch)

    return list(encoding.cpu().numpy())

In [None]:
import itertools
import functools

import timm

# timm image backbones to embed with; each produces a column "image_<encoder>".
encoders = ("vit_base_patch16_224",)
# Set True to recompute embeddings even if the target column already exists.
FORCE_COMPUTE: bool = False
# Column holding the image ids consumed by `image_encode`.
image_src_field: str = "image_id"

for encoder_name in encoders:
    tgt_field: str = f"image_{encoder_name}"
    # Check the schema, not the Dataset object: `name in dataset` iterates rows.
    if tgt_field not in dataset.column_names or FORCE_COMPUTE:
        # num_classes=0 asks timm for the backbone without its classification head.
        encoder = timm.create_model(encoder_name, pretrained=True, num_classes=0).to(device)
        config = resolve_data_config({}, model=encoder)
        transform = create_transform(**config)
        encoder.eval()
        dataset = dataset.map(
            functools.partial(
                encode_field,
                src_field=image_src_field,
                tgt_field=tgt_field,
                transformation=functools.partial(
                    image_encode,
                    transform=transform,
                    encoder=encoder,
                ),
            ),
            num_proc=1,
            batched=True,
            batch_size=64,
            # Use this cell's source column, not `src_field` leaked from the text loop.
            desc=f"image_encoding field <{image_src_field}> with <{encoder_name}>",
        )
        # Release GPU memory held by the model once the column is computed.
        encoder = encoder.cpu()
        dataset.set_format(type="torch", columns=[tgt_field], output_all_columns=True)
dataset.save_to_disk(str(image_encoded_dir))

In [None]:
# str() matches the save_to_disk(str(...)) call that wrote this directory; some
# `datasets` releases do not accept pathlib.Path here.
dataset = load_from_disk(str(image_encoded_dir))
dataset

In [None]:
# Hold out 10% for testing, stratified on the ClassLabel `label` column so section
# proportions are preserved across splits; fixed seed for reproducibility.
dataset = dataset.train_test_split(
    test_size=0.1,
    seed=42,
    stratify_by_column="label",
)
dataset

In [None]:
encoded_dir = datasets_dir.joinpath("N24News", "encoded")  # final train/test split with embeddings
dataset.save_to_disk(str(encoded_dir))

In [None]:
dataset = load_from_disk(str(datasets_dir.joinpath("N24News", "encoded")))
# Expose only the two precomputed embedding columns, as torch tensors.
# NOTE(review): `label` is not among the formatted columns — add it if targets are needed.
dataset.set_format(
    type="torch",
    columns=["body_roberta-base", "image_vit_base_patch16_224"],
)