In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torchvision.datasets as torch_datasets
import torchvision.transforms as torch_transforms

from jaxl.datasets.omniglot import MultitaskOmniglotNShotKWay, MultitaskOmniglotBursty, MultitaskOmniglotBurstyTF, MultitaskOmniglotNShotKWayTF
from jaxl.datasets.wrappers import ContextDataset, FixedLengthContextDataset

import jaxl.transforms as jaxl_transforms

In [None]:
# Root directory holding the raw torchvision Omniglot download.
# NOTE(review): absolute, user-specific path; point this at your local copy.
OMNIGLOT_ROOT = "/home/bryanpu1/projects/icl/data"


def _make_omniglot(background):
    """Build one torchvision Omniglot split from OMNIGLOT_ROOT without downloading.

    Args:
        background: forwarded to torchvision's Omniglot; True selects the
            "background" set and False the "evaluation" set.
    """
    return torch_datasets.Omniglot(
        OMNIGLOT_ROOT,
        background=background,
        download=False,
        transform=jaxl_transforms.DefaultPILToImageTransform(),
        target_transform=None,
    )


# Bursty Omniglot sequences: 20 sequences of length 16 drawn with p_bursty=0.5,
# presumably viewed as 15 context examples + 1 query (context_len=15); confirm
# against FixedLengthContextDataset.
dataset = FixedLengthContextDataset(
    MultitaskOmniglotBursty(
        train_dataset=_make_omniglot(background=True),
        test_dataset=_make_omniglot(background=False),
        num_holdout=10,
        train=True,
        num_sequences=20,
        sequence_length=16,
        p_bursty=0.5,
        min_num_per_class=3,
    ),
    context_len=15,
)

# dataset = FixedLengthContextDataset(
#     MultitaskOmniglotBurstyTF(
#         load_path="/home/bryanpu1/projects/icl/data/tf_omniglot-single_example-all.pkl",
#         train=True,
#         num_holdout=10,
#         num_sequences=50,
#         sequence_length=9,
#         p_bursty=0.0,
#         noise_scale=0.1
#     ),
#     context_len=8,
# )


# dataset = FixedLengthContextDataset(
#     MultitaskOmniglotNShotKWayTF(
#         load_path="/home/bryanpu1/projects/icl/data/tf_omniglot-single_example-all.pkl",
#         train=False,
#         num_holdout=10,
#         num_sequences=50,
#         sequence_length=9,
#         k_way=2,
#         noise_scale=0.1
#     ),
#     context_len=8,
# )

In [None]:
# Draw the sample at index 0. The names suggest (context inputs, context
# outputs, query, query output); confirm against FixedLengthContextDataset.
ci, co, q, o = dataset[0]

In [None]:
# Inspect the shape of `q` (the query input, judging by its name).
q.shape

In [None]:
# Index 0 again, so the next cell can check that indexing the dataset is deterministic.
ci_2, co_2, q_2, o_2 = dataset[0]

In [None]:
# Determinism check: both draws of dataset[0] should match component by
# component (context inputs, context outputs, query, query output).
for first, second in ((ci, ci_2), (co, co_2), (q, q_2), (o, o_2)):
    print(np.allclose(first, second))

In [None]:
# Class indices of the query output and of each context output
# (argmax over the last axis, presumably a one-hot label encoding).
tuple(np.argmax(encoded, axis=-1) for encoded in (o, co))

In [None]:
# A different sample (index 4), overwriting the *_2 copies of dataset[0].
ci_2, co_2, q_2, o_2 = dataset[4]

In [None]:
# Class indices for dataset[4]: query output first, then each context output.
tuple(np.argmax(encoded, axis=-1) for encoded in (o_2, co_2))

In [None]:
# Shapes of the four components of dataset[0], in unpacking order.
tuple(component.shape for component in (ci, co, q, o))

In [None]:
# dataset._data["is_bursty"]

In [None]:
# Publication figure for dataset[0]: a row of context panels (titled with
# their labels) next to a shaded "Query" panel whose label is hidden ("?").
num_context = len(ci)

fig = plt.figure(figsize=(15, 2), layout="constrained")
# Width ratio matches the panel counts, so context and query panels stay the same width.
subfigs = fig.subfigures(1, 2, width_ratios=[num_context, 1])

# One panel per context example, derived from the data rather than hard-coded
# (the dataset above uses context_len=15). squeeze=False keeps indexing valid
# even for a single context example.
ax = subfigs[0].subplots(1, num_context, squeeze=False)[0]
subfigs[0].suptitle("Context")
for idx, (img, output) in enumerate(zip(ci, co)):
    ax[idx].imshow(img)
    ax[idx].set_title(np.argmax(output))
    ax[idx].axis('off')

subfigs[1].suptitle("Query")
subfigs[1].set_facecolor('0.75')
ax = subfigs[1].subplots(1, 1)
ax.imshow(q[0])
ax.set_title("?")
ax.axis('off')

# NOTE(review): the file name mentions "non_bursty" and "noise_scale_0.1",
# which match the commented-out TF configuration rather than the live dataset
# (p_bursty=0.5, no noise_scale); rename if this is re-exported.
fig.savefig("non_bursty_sequence-noise_scale_0.1.pdf", format="pdf", bbox_inches="tight", dpi=600)

In [None]:
# One row for dataset[4]: its context examples (titled with their labels)
# followed by its OWN query, titled with the query's label. The original cell
# mixed this context with dataset[0]'s query (q, o).
fig, ax = plt.subplots(1, len(ci_2) + 1, squeeze=False)
ax = ax[0]  # one panel per context example plus one for the query

for idx, (img, output) in enumerate(zip(ci_2, co_2)):
    ax[idx].imshow(img)
    ax[idx].set_title(np.argmax(output))
    ax[idx].axis('off')
ax[-1].imshow(q_2[0])
ax[-1].set_title(np.argmax(o_2, axis=-1))
ax[-1].axis('off')
plt.show()

In [None]:
# Debug: pick a random class, look up the example that the "query_idxes"
# entry of sequence `idx` points at, and fetch it straight from the
# underlying torchvision split to cross-check the dataset's bookkeeping.
# NOTE(review): relies on private attributes being reachable through the
# FixedLengthContextDataset wrapper; confirm.
idx = 0

label = np.random.choice(dataset._classes)
# Apparently a flat index over the train examples followed by the test examples.
base_idx = dataset._label_to_idx[label, dataset._data["query_idxes"][idx]]

in_train = label < dataset._train_size
print("train" if in_train else "test")

# Test-split positions are shifted past the train examples. The 20 is
# presumably the number of drawings per Omniglot character; TODO confirm.
split_dataset = dataset._train_dataset if in_train else dataset._test_dataset
split_idx = base_idx if in_train else base_idx - dataset._train_size * 20
query, l = split_dataset[split_idx]

print(label, l)
print(base_idx, base_idx // 20, base_idx % 20, idx)

In [None]:
# Candidate class labels; the debug cell above draws `label` from this.
dataset._classes

In [None]:
# Inspect the private `_test_size` attribute.
dataset._test_size

In [None]:
# Inspect `_train_size`; the debug cell treats labels below it as train classes.
dataset._train_size

In [None]:
# Drawn label vs. the label returned by the underlying torchvision split;
# presumably these should agree if the index bookkeeping is right.
label, l

In [None]:
# Row 9 from the end of `_label_to_idx` (negative indexing).
dataset._label_to_idx[-9]

In [None]:
# Show the image fetched by the debug cell.
# NOTE(review): imshow needs (H, W) or (H, W, C); if the transform yields
# channels-first images this will fail; confirm.
plt.imshow(query)

In [None]:
# NOTE(review): identical to the previous cell; safe to delete.
plt.imshow(query)

In [None]:
# Summary of the debug lookup: returned label, drawn label, and split sizes.
l, label, dataset._train_size, dataset._test_size

In [None]:
# Full label-to-index table (output may be large).
dataset._label_to_idx

In [None]:
# NOTE(review): repeats an earlier `dataset._classes` cell.
dataset._classes

In [None]:
# Number of sequences the wrapped dataset exposes.
print(len(dataset))
# print(dataset._data["is_bursty"])

In [None]:
# Render every sequence as one row: context examples (titled with their
# labels) followed by the query (titled with its label).
for ii in range(len(dataset)):
    ci, co, q, o = dataset[ii]

    # One panel per context example plus one for the query. This is derived
    # from the sample because the dataset above uses context_len=15, so a
    # fixed 9 panels would raise IndexError.
    fig, ax = plt.subplots(1, len(ci) + 1, squeeze=False)
    ax = ax[0]

    for idx, (img, output) in enumerate(zip(ci, co)):
        ax[idx].imshow(img)
        ax[idx].set_title(np.argmax(output))
        ax[idx].axis('off')
    ax[-1].imshow(q[0])
    ax[-1].set_title(np.argmax(o, axis=-1))
    ax[-1].axis('off')
    plt.show()
    # Release the figure so long loops don't accumulate open figures.
    plt.close(fig)

In [None]:
# Fully bursty sequences (p_bursty=1.0) of length 9 from the background split,
# presumably viewed as 8 context examples + 1 query, plotted one row each.
# NOTE(review): this passes `dataset=`, whereas the MultitaskOmniglotBursty
# call at the top of the notebook passes `train_dataset=`/`test_dataset=`/
# `num_holdout=`/`train=`. This cell may target an older constructor
# signature; confirm against jaxl.datasets.omniglot.
dataset = FixedLengthContextDataset(
    MultitaskOmniglotBursty(
        dataset=torch_datasets.Omniglot(
            "/home/bryanpu1/projects/icl/data",
            background=True,
            download=False,
            transform=torch_transforms.Compose([
                jaxl_transforms.DefaultPILToImageTransform(),
                # NOTE(review): axes=(1, 2, 0) suggests the previous transform
                # yields channels-first arrays; confirm.
                jaxl_transforms.Transpose(axes=(1, 2, 0)),
            ]),
            target_transform=None,
        ),
        num_sequences=100,
        sequence_length=9,
        p_bursty=1.0,
        min_num_per_class=1,
    ),
    context_len=8,
)

for ii in range(len(dataset)):
    ci, co, q, o = dataset[ii]

    # One panel per context example plus one for the query. The count is
    # derived from the sample rather than hard-coded, so context_len can change.
    fig, ax = plt.subplots(1, len(ci) + 1, squeeze=False)
    ax = ax[0]

    for idx, (img, output) in enumerate(zip(ci, co)):
        ax[idx].imshow(img)
        ax[idx].set_title(np.argmax(output))
        ax[idx].axis('off')
    ax[-1].imshow(q[0])
    ax[-1].set_title(np.argmax(o, axis=-1))
    ax[-1].axis('off')
    plt.show()
    # Release the figure; this loop produces one figure per sequence (100 here).
    plt.close(fig)

In [None]:
# Deliberate stop for "Run All": the cells below rebuild the pickled TFDS
# Omniglot file and should only be run on purpose. An explicit raise is used
# because an `assert` is stripped when Python runs with -O.
raise RuntimeError(
    "Stopping here: the cells below regenerate tf_omniglot-single_example-all.pkl; run them manually if needed."
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import torchvision.datasets as torch_datasets

import jaxl.transforms as jaxl_transforms

In [None]:
def _extract_image(image, label):
    """Convert an Omniglot example to a float32 single-channel image; pass the label through."""
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.rgb_to_grayscale(image)
    return image, label


# data[split][label] -> the FIRST image encountered for that class in that
# split, i.e. a single exemplar per class.
data = {
    "train": {},
    "test": {},
}

for split in ("train", "test"):
    ds = tfds.load(
        'omniglot', split=split, as_supervised=True, shuffle_files=False)

    for image, label in ds.map(_extract_image):
        label = label.numpy().astype(np.int32)
        # Only add an entry if that class isn't present yet. `image` itself is
        # not rebound, so it keeps a single type across iterations.
        if label not in data[split]:
            data[split][label] = image.numpy()

In [None]:
# Number of distinct classes collected per split (train, test).
tuple(len(data[split_name]) for split_name in ("train", "test"))

In [None]:
# Flatten {"train": {...}, "test": {...}} into a single {label: image} dict.
# A plain dict merge would let a test entry silently overwrite a train entry
# with the same label, so require the label sets to be disjoint. The
# `"train" in data` guard makes re-running this cell a no-op instead of a
# KeyError once `data` has already been flattened.
if "train" in data:
    shared_labels = data["train"].keys() & data["test"].keys()
    if shared_labels:
        raise ValueError(
            f"{len(shared_labels)} labels appear in both train and test splits; merging would drop images"
        )
    data = {**data["train"], **data["test"]}

In [None]:
import _pickle as pickle

In [None]:
# Persist the single-exemplar-per-class images under the same file name used
# by the `load_path=` of the MultitaskOmniglot*TF configurations in this notebook.
# NOTE: unpickling executes arbitrary code; only load this file from a trusted source.
payload = {
    "data": data,
    "targets": np.array(list(data.keys())),
    # Hard-coded Omniglot image shape (single grayscale channel).
    "img_shape": (105, 105, 1),
}
# `with` guarantees the file is flushed and closed (the original leaked the handle).
with open("tf_omniglot-single_example-all.pkl", "wb") as pickle_file:
    pickle.dump(payload, pickle_file)

In [None]:
# `image` is whatever the loading loop left behind: it may still be a
# tf.Tensor or may already be a NumPy array, so `.numpy()` can fail.
# np.asarray handles both.
type(label), type(np.asarray(image))