In [1]:
import os
import numpy as np
from tqdm import tqdm
from PIL import Image
import pickle

In [2]:
# PASCAL VOC foreground classes whose replay data is audited in this notebook.
# A label id k in a VOC mask maps to classes[k - 1] (see count_images).
# Only the first ten VOC classes are listed, presumably the classes of the first
# step (train-0) of the 10-10 task. The other ten VOC classes (dining_table, dog,
# horse, motorbike, person, potted_plant, sheep, sofa, train, tv_monitor) are
# intentionally left out.
classes = "aeroplane bicycle bird boat bottle bus car cat chair cow".split()

In [3]:
def get_prompt(mode, class_name):
    """Return the text-to-image prompt for generating a replay image of ``class_name``.

    Args:
        mode: prompt variant, ``"baseline"`` or ``"mrte"``. This matches the last
            ``_``-separated token of the ``replay_data_<mode>`` directory names.
        class_name: class name to insert into the prompt, e.g. an entry of ``classes``.

    Returns:
        The prompt string.

    Raises:
        ValueError: if ``mode`` is not a known prompt variant. Previously the
            function silently returned ``None`` here.
    """
    templates = {
        "baseline": "photo of a {}",
        "mrte": "realistic and modern photo of a {} in its typical environment",
    }
    if mode not in templates:
        raise ValueError(f"unknown prompt mode {mode!r}; expected one of {sorted(templates)}")
    return templates[mode].format(class_name)

In [4]:
def _principal_class(label_path):
    """Return the entry of ``classes`` covering the most pixels in a label mask, or None.

    Only label ids ``1..len(classes)`` are considered; id ``k`` maps to
    ``classes[k - 1]``. Background (0) and any id above ``len(classes)`` are
    ignored. Returns None when no such id is present.
    """
    # Context manager closes the file handle; this runs once per training image.
    with Image.open(label_path) as im:
        labels = np.array(im)
    ids, counts = np.unique(labels, return_counts=True)
    keep = (ids >= 1) & (ids <= len(classes))
    if not keep.any():
        return None
    ids, counts = ids[keep], counts[keep]
    # np.unique sorts ids, so ties resolve to the lowest class id (same as before).
    return classes[ids[np.argmax(counts)] - 1]


def count_images(replay_root, task, wilson_root="/home/thesis/marx/wilson_gen/WILSON", write_refill=False):
    """Compare generated replay images against VOC principal-class counts.

    For both the disjoint (``task``) and overlapped (``task + "-ov"``) variants:

    * load the indices in ``data/voc/<task>[-ov]/train-0.npy``;
    * count, per class in ``classes``, the selected VOC images whose principal
      class (largest-area foreground class) it is;
    * count the files in ``<replay_root>/<task>[-ov]/<class>/images/`` and print
      whether this equals ``multiplier`` x the VOC count (3 disjoint, 2 overlapped);
    * pickle the VOC counts to ``<replay_root>/<task>[-ov]/class_counts.pkl``.

    Finally, print per class the largest shortfall over both variants.

    Args:
        replay_root: replay directory name relative to ``wilson_root``. Its last
            ``_``-separated token is used as the prompt mode.
        task: task name such as ``"10-10"``.
        wilson_root: WILSON checkout containing ``data/voc`` and the replay dirs.
        write_refill: if True and images are missing, write one prompt line per
            missing image to ``<parent of wilson_root>/voc_<mode>_prompts_refill.txt``.
    """
    voc_root = f"{wilson_root}/data/voc"
    # The split list is identical for both variants, so read it once.
    with open(f"{voc_root}/splits/train_aug.txt") as f:
        # Second space-separated field, without its leading "/", is the mask path
        # (relative to voc_root) whose pixel values are class ids.
        all_label_names = np.array([line.strip().split(" ")[1][1:] for line in f.readlines()])

    overall_needed_imgs = {c: 0 for c in classes}
    for ov in [False, True]:
        task_and_ov = task + ("-ov" if ov else "")
        multiplier = 2 if ov else 3
        task_dir = f"{wilson_root}/{replay_root}/{task_and_ov}"
        print(f"Checking {replay_root}/{task_and_ov}")
        idxs = np.load(f"{voc_root}/{task_and_ov}/train-0.npy")
        label_names = all_label_names[idxs]

        print("    counting voc images")
        img_counts = {c: [0, 0] for c in classes}  # class -> [voc count, generated count]
        skipped = 0
        for label_name in label_names:
            principal_class = _principal_class(f"{voc_root}/{label_name}")
            if principal_class is None:
                skipped += 1
                continue
            img_counts[principal_class][0] += 1
        if skipped:
            print(f"    skipped {skipped} masks without a foreground class in 1..{len(classes)}")

        print("    counting gen images")
        for c in classes:
            img_counts[c][1] = len(os.listdir(f"{task_dir}/{c}/images/"))
        for c in classes:
            voc_count, gen_count = img_counts[c]
            print(f"    {(c+':').ljust(10)} {(str(img_counts[c]) + ',').ljust(13)} gen = {multiplier}xvoc: {gen_count == multiplier * voc_count}")
            overall_needed_imgs[c] = max(overall_needed_imgs[c], voc_count * multiplier - gen_count)

        class_counts = {c: img_counts[c][0] for c in classes}
        with open(f"{task_dir}/class_counts.pkl", "wb") as f:
            pickle.dump(class_counts, f)

    for c in classes:
        print(f"Needed {(c+':').ljust(10)} {overall_needed_imgs[c]}")

    if write_refill:
        mode = replay_root.split("_")[-1]
        print("    generating refill prompt file")
        if sum(overall_needed_imgs.values()) != 0:
            refill_path = os.path.join(os.path.dirname(os.path.normpath(wilson_root)), f"voc_{mode}_prompts_refill.txt")
            with open(refill_path, "w") as f:
                for c in classes:
                    p = get_prompt(mode, c)
                    for _ in range(overall_needed_imgs[c]):
                        f.write(p + "\n")
    print("")

In [5]:
# Audit every (replay variant, task) pair. Results are printed, and a
# class_counts.pkl is written into each replay split directory.
for replay_root, task in [
    (root, t)
    for root in ("replay_data_baseline", "replay_data_mrte")
    for t in ("10-10",)
]:
    count_images(replay_root, task)

Checking replay_data_baseline/10-10
    counting voc images
    counting gen images
    aeroplane: [527, 1581],  gen = 3xvoc: True
    bicycle:   [176, 528],   gen = 3xvoc: True
    bird:      [642, 1926],  gen = 3xvoc: True
    boat:      [298, 894],   gen = 3xvoc: True
    bottle:    [159, 477],   gen = 3xvoc: True
    bus:       [204, 612],   gen = 3xvoc: True
    car:       [469, 1760],  gen = 3xvoc: False
    cat:       [812, 2436],  gen = 3xvoc: True
    chair:     [204, 1672],  gen = 3xvoc: False
    cow:       [215, 645],   gen = 3xvoc: True
Checking replay_data_baseline/10-10-ov
    counting voc images
    counting gen images
    aeroplane: [578, 1581],  gen = 2xvoc: False
    bicycle:   [443, 886],   gen = 2xvoc: True
    bird:      [678, 1926],  gen = 2xvoc: False
    boat:      [448, 896],   gen = 2xvoc: True
    bottle:    [495, 990],   gen = 2xvoc: True
    bus:       [368, 736],   gen = 2xvoc: True
    car:       [880, 1760],  gen = 2xvoc: True
    cat:       [958, 2436]

In [9]:
# REMOVE UNNECESSARY IMAGES
# Trim each class's generated replay images to multiplier x its VOC count
# (class_counts.pkl, written by count_images). The first names in sorted order
# are kept; surplus images are deleted, together with a same-named pseudolabel
# if one exists.
# Set DRY_RUN = False to actually delete files. With the default, nothing is
# touched and the counts are only reported.
DRY_RUN = True
WILSON_ROOT = "/home/thesis/marx/wilson_gen/WILSON"
task = "10-10"
for replay_root in ["replay_data_baseline", "replay_data_mrte"]:
    for ov in [False, True]:
        task_and_ov = task + ("-ov" if ov else "")
        multiplier = 2 if ov else 3
        task_dir = f"{WILSON_ROOT}/{replay_root}/{task_and_ov}"
        print(f"Working on {replay_root}/{task_and_ov}")
        # class_counts.pkl is produced by count_images in this notebook (trusted input).
        with open(f"{task_dir}/class_counts.pkl", "rb") as f:
            class_counts = pickle.load(f)
        for c in classes:
            needed = class_counts[c] * multiplier
            images_dir = f"{task_dir}/{c}/images/"
            # Listed before any deletion, so "preexisting" really is the pre-trim count.
            existing = sorted(os.listdir(images_dir))
            image_names = existing[needed:]
            if not DRY_RUN:
                for img in image_names:
                    os.remove(f"{images_dir}{img}")
                    try:
                        os.remove(f"{task_dir}/{c}/pseudolabels/{img}")
                    except FileNotFoundError:
                        pass  # a pseudolabel may not exist for this image
            verb = "would remove" if DRY_RUN else "removed"
            print(f"    {verb} {len(image_names)} images for {c}".ljust(40) + f"preexisting: {len(existing)}".ljust(20) + f"needed: {needed}".ljust(13) + f"= {multiplier} x {class_counts[c]}")

Working on replay_data_baseline/10-10
    removed 0 images for aeroplane      preexisting: 1581   needed: 1581 = 3 x 527
    removed 0 images for bicycle        preexisting: 528    needed: 528  = 3 x 176
    removed 0 images for bird           preexisting: 1926   needed: 1926 = 3 x 642
    removed 0 images for boat           preexisting: 894    needed: 894  = 3 x 298
    removed 0 images for bottle         preexisting: 477    needed: 477  = 3 x 159
    removed 0 images for bus            preexisting: 612    needed: 612  = 3 x 204
    removed 352 images for car          preexisting: 1407   needed: 1407 = 3 x 469
    removed 0 images for cat            preexisting: 2436   needed: 2436 = 3 x 812
    removed 1060 images for chair       preexisting: 612    needed: 612  = 3 x 204
    removed 0 images for cow            preexisting: 645    needed: 645  = 3 x 215
Working on replay_data_baseline/10-10-ov
    removed 425 images for aeroplane    preexisting: 1156   needed: 1156 = 2 x 578
    remo