In [1]:
# Starts the autoreload extension, which allows editing the .py files with the notebook running and automatically imports the latest changes

%load_ext autoreload
%autoreload 2

import trim_duplicates, model, network, gradcam, plots
from dataset import Dataset
import jax.numpy as jnp
import numpy as np
import jax
import sklearn
import wandb
from trim_duplicates import DuplicatesData
import matplotlib.pyplot as plt

#assert jax.local_device_count() >= 8

# Presumably the number of label classes; not referenced in the cells shown here (confirm against the model code).
NUM_CLASSES = 4
# Seed handed to jax.random.PRNGKey when a dataset is loaded.
SEED = 14
# Presumably the batch size used by the training/inference modules; not referenced in the cells shown here.
BATCH_SIZE = 128

def basemodel_process(x):
    """Identity hook: hand the input back unchanged.

    Args:
        x: Any value.

    Returns:
        The very same object that was passed in.
    """
    return x

I0000 00:00:1653958818.577289  549031 tpu_initializer_helper.cc:165] libtpu.so already in use by another process probably owned by another user. Run "$ sudo lsof -w /dev/accel0" to figure out which process is using the TPU. Not attempting to load libtpu.so in this process.
  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# NOTE(review): this cell used to build `jax.random.PRNGKey(SEED)` first and
# then immediately overwrite it with the key below, so SEED never influenced
# these loads. The dead assignment was dropped (behaviour is unchanged);
# confirm that the first half of the split of PRNGKey(1) is really the key the
# datasets are meant to be loaded with.
rng = jax.random.split(jax.random.PRNGKey(1))[0]

# Both datasets are loaded with the same key and official_split=True.
dataset_mendeley = Dataset.load("mendeley", rng=rng, official_split=True)
dataset_tawsifur = Dataset.load("tawsifur", rng=rng, official_split=True)
print("Loaded mendeley", dataset_mendeley.classnames)
print("Loaded tawsifur", dataset_tawsifur.classnames)

tcmalloc: large alloc 7241465856 bytes == 0x950c94000 @  0x7f7a6f1fe680 0x7f7a6f21f824 0x7f7a6498df34 0x7f7a6498e64f 0x7f7a649ec4be 0x7f7a649ed121 0x7f7a64a90c96 0x5f2fb9 0x5f3446 0x56fb02 0x56822a 0x5f6033 0x56b115 0x56822a 0x5f6033 0x56ef97 0x56822a 0x5f6033 0x56b115 0x56822a 0x68c1e7 0x5ff1f4 0x5c3cb0 0x569f5e 0x5002e8 0x56b95e 0x5002e8 0x56b95e 0x5002e8 0x503f46 0x56a136
tcmalloc: large alloc 7241465856 bytes == 0x1113814000 @  0x7f7a6f1fe680 0x7f7a6f21f824 0x7f7a6498df34 0x7f7a6498e64f 0x7f7a649eb4ba 0x7f7a649eb598 0x7f7a64a7e38a 0x7f7a64a7edeb 0x50fc9c 0x56a3a0 0x56822a 0x5f6033 0x56b115 0x56822a 0x68c1e7 0x5ff1f4 0x5c3cb0 0x569f5e 0x5002e8 0x56b95e 0x5002e8 0x56b95e 0x5002e8 0x503f46 0x56a136 0x5f5e56 0x569f5e 0x5f5e56 0x56a136 0x56822a 0x5f6033
tcmalloc: large alloc 7241465856 bytes == 0x12c3214000 @  0x7f7a6f1fe680 0x7f7a6f21f824 0x7f7a6498df34 0x7f7a6498e64f 0x7f7a649eb4ba 0x7f7a64a917f9 0x7f7a64a91f47 0x7f7a64a9209c 0x6b224d 0x7f7a649d4574 0x5f305a 0x5f3446 0x56f1ca 0x56822a

In [5]:
def report_dups(title, count, ds):
    """Print ``count`` next to the share of the dataset it represents.

    Args:
        title: Label printed in front of the numbers.
        count: Number of duplicates to report.
        ds: Object whose ``x_all.shape[0]`` gives the dataset size.

    The percentage is rounded to one decimal place.
    """
    total = ds.x_all.shape[0]
    percent = round(count / total * 1000) / 10
    print(title + ":", count, f"({percent}%)")

def is_dup(indices, i, dont_count):
    """Tell whether sample ``i`` should be counted as a duplicate.

    Scans ``indices`` (an iterable of index groups) for the first group that
    has more than one member and contains ``i``. When one is found and ``i``
    is not already a key of ``dont_count``, every other member of that group
    is recorded in ``dont_count`` and ``True`` is returned.

    Args:
        indices: Iterable of sized, iterable groups of sample indices.
        i: Sample index being tested.
        dont_count: Dict mutated in place; its keys are indices that must not
            be counted again.

    Returns:
        ``True`` if ``i`` was counted, ``False`` otherwise.
    """
    for group in indices:
        # Skip singletons, samples already accounted for, and unrelated groups.
        if len(group) <= 1 or i in dont_count or i not in group:
            continue
        dont_count.update((member, True) for member in group if member != i)
        return True
    return False

def count_dups(groups, ds):
    """Count the samples of ``ds`` that ``is_dup`` flags for ``groups``.

    Every index from 0 to ``ds.x_all.shape[0] - 1`` is tested in ascending
    order against a single shared bookkeeping dict, so indices that
    ``is_dup`` marks while counting an earlier sample are not counted again.

    Args:
        groups: Iterable of index groups, forwarded to ``is_dup``.
        ds: Object whose ``x_all.shape[0]`` gives the number of samples.

    Returns:
        The number of indices for which ``is_dup`` returned ``True``.
    """
    already_seen = {}
    n_samples = ds.x_all.shape[0]
    return sum(1 for idx in range(n_samples) if is_dup(groups, idx, already_seen))

def show_diff(ds, global_set, max_rows=4):
    """Plot duplicate groups that one detector found but the two do not share.

    Loads the duplicate groups stored in ``dup_data/cv_<name>_0.pickle``
    (reported as "embed") and in ``dup_data/<name>_pix.pickle`` (reported as
    "pix"), where ``<name>`` is the last ``/``-separated component of
    ``ds.name``. It prints the duplicate totals of both sets, then prints the
    index and path of, and plots, every multi-member group of the selected set
    that is not present in both sets. An image is coloured black when its index
    also belongs to a multi-member group of the other set, red otherwise.

    Args:
        ds: Dataset exposing ``name``, ``x_all`` and ``paths_all``.
        global_set: ``"embed"`` or ``"pix"``; selects which set's exclusive
            groups are shown.
        max_rows: Forwarded to ``plots.compare_n_images`` as ``rows``.

    Raises:
        AssertionError: If ``global_set`` is neither ``"embed"`` nor ``"pix"``.
    """
    # Validate before touching the disk so a typo fails immediately.
    assert global_set == "embed" or global_set == "pix"

    actual_name = ds.name.split("/")[-1]
    dups = DuplicatesData.load("dup_data/cv_" + actual_name + "_0.pickle")
    dups_pix = DuplicatesData.load("dup_data/" + actual_name + "_pix.pickle")

    # Keep the string parameter intact; bind the chosen group sets to new names.
    if global_set == "embed":
        selected_groups, other_groups = dups.indices, dups_pix.indices
    else:
        selected_groups, other_groups = dups_pix.indices, dups.indices

    diff = selected_groups - dups.indices.intersection(dups_pix.indices)

    # NOTE: an "intersection" counter used to be computed here. Its result was
    # never reported, and it passed the DuplicatesData objects (not their
    # `.indices`) to is_dup while sharing one bookkeeping dict between both
    # sets, so it was removed rather than kept as dead, misleading work.
    report_dups("Total pix dups", count_dups(dups_pix.indices, ds), ds)
    report_dups("Total embed dups", count_dups(dups.indices, ds), ds)

    # Indices that belong to a multi-member group of the *other* set.
    in_other_set = {i for group in other_groups if len(group) > 1 for i in group}

    imgs = []
    for group in diff:
        # Every group of the difference gets a row; singleton groups leave it empty.
        row = []
        imgs.append(row)
        if len(group) > 1:
            for i in group:
                print(i, ds.paths_all[i])
                row.append({
                    "img": ds.x_all[i, :, :, :],
                    "color": "black" if i in in_other_set else "red",
                })

    plots.compare_n_images(imgs, rows=max_rows)

# Plot the mendeley groups found by the embedding set but not shared with the pixel set.
show_diff(dataset_mendeley, "embed")

### COUNTS DUPS INTERSECTION ACROSS CROSS VALIDATION SETS ###

actual_name = "tawsifur"
ds = dataset_tawsifur

# Five cross-validation duplicate files, followed by the pixel-based one.
cvs = [
    DuplicatesData.load("dup_data/cv_" + actual_name + "_" + str(fold) + ".pickle")
    for fold in range(5)
]
cvs.append(DuplicatesData.load("dup_data/" + actual_name + "_pix.pickle"))

# One bookkeeping dict is shared by every set and every sample.
dont_count = {}
dups_in_all = 0

for i in range(ds.x_all.shape[0]):
    # `all` short-circuits at the first set that does not flag sample i.
    if all(is_dup(dup_set.indices, i, dont_count) for dup_set in cvs):
        dups_in_all += 1

report_dups("INTERSECAO DE TODOS CV E DOS EMBEDDINGS", dups_in_all, ds)

FileNotFoundError: [Errno 2] No such file or directory: 'dup_data/cv_mendeley_0.pickle'

In [6]:
def custom_show_diff(ds, global_set, max_rows=4):
    """Compare the "custom" duplicate groups of ``ds`` with the pixel-based ones.

    Reads ``dup_data/<name>_custom.pickle`` and ``dup_data/<name>_pix.pickle``
    (``<name>`` being the last ``/``-separated component of ``ds.name``), prints
    how many groups each file holds and how many groups of the selected set are
    not shared by both, then prints the index and path of, and plots, the
    multi-member groups among those. An image is coloured black when its index
    also sits in a multi-member group of the other set, red otherwise.

    Args:
        ds: Dataset exposing ``name``, ``x_all`` and ``paths_all``.
        global_set: ``"embed"`` selects the custom set, ``"pix"`` the pixel set.
        max_rows: Forwarded to ``plots.compare_n_images`` as ``rows``.

    Raises:
        AssertionError: If ``global_set`` is neither ``"embed"`` nor ``"pix"``.
    """
    dataset_key = ds.name.split("/")[-1]
    dups = DuplicatesData.load("dup_data/" + dataset_key + "_custom.pickle")
    dups_pix = DuplicatesData.load("dup_data/" + dataset_key + "_pix.pickle")

    assert global_set in ("embed", "pix")

    picked_embed = global_set == "embed"
    chosen_groups = dups.indices if picked_embed else dups_pix.indices
    other_groups = dups_pix.indices if picked_embed else dups.indices

    shared_groups = dups.indices.intersection(dups_pix.indices)
    diff = chosen_groups - shared_groups

    print("Total embed dups:", len(dups.indices))
    print("Total pix dups:", len(dups_pix.indices))
    print("Complement of difference:", len(diff))

    # Indices appearing in any multi-member group of the set that was not chosen.
    flagged_by_other = {
        idx for members in other_groups if len(members) > 1 for idx in members
    }

    imgs = []
    for members in diff:
        row = []
        imgs.append(row)
        if len(members) <= 1:
            # Singleton groups still occupy an (empty) row.
            continue
        for idx in members:
            print(idx, ds.paths_all[idx])
            colour = "black" if idx in flagged_by_other else "red"
            row.append({"img": ds.x_all[idx, :, :, :], "color": colour})

    plots.compare_n_images(imgs, rows=max_rows)

# Plot the mendeley groups present in the "_custom" duplicate file but not shared with the pixel-based file.
custom_show_diff(dataset_mendeley, "embed")

FileNotFoundError: [Errno 2] No such file or directory: 'dup_data/mendeley_custom.pickle'

In [77]:
# Reload the mendeley dataset using a key built directly from SEED.
# NOTE: this rebinds the notebook-level `rng` and `ds`; later cells
# (e.g. show_stats) read `ds` as a global.
ds_name = "mendeley"
rng = jax.random.PRNGKey(SEED)

ds = Dataset.load(ds_name, rng=rng, official_split=True)

tcmalloc: large alloc 7241465856 bytes == 0x6cae1a000 @  0x7fe52e6ab680 0x7fe52e6cc824 0x7fe523e3af34 0x7fe523e3b64f 0x7fe523e994be 0x7fe523e9a121 0x7fe523f3dc96 0x5f2fb9 0x5f3446 0x56fb02 0x56822a 0x5f6033 0x56b115 0x56822a 0x5f6033 0x5f2b87 0x56b7b0 0x56822a 0x5f6033 0x56ef97 0x56822a 0x5f6033 0x56b115 0x56822a 0x68c1e7 0x5ff1f4 0x5c3cb0 0x569f5e 0x5002e8 0x56b95e 0x5002e8
tcmalloc: large alloc 7241465856 bytes == 0xaac6000 @  0x7fe52e6ab680 0x7fe52e6cc824 0x7fe52e6ccb8a 0x7fe3d705c37c 0x7fe3d2af0520 0x7fe3d2b00228 0x7fe3d2b03aac 0x7fe3d2a40fb2 0x7fe3d2818c88 0x7fe3d2800991 0x5f2fb9 0x5f3446 0x50aa8b 0x56ef97 0x56822a 0x5f6033 0x5f5869 0x664d7d 0x5f2c0e 0x56b7b0 0x56822a 0x5f6033 0x5f2b87 0x56b7b0 0x56822a 0x5f6033 0x5f5869 0x664d7d 0x5f2c0e 0x56b7b0 0x5f5e56


In [None]:
import pickle

def sim(x, y):
    """Cosine similarity between two 1-D vectors.

    Args:
        x: First vector.
        y: Second vector, same length as ``x``.

    Returns:
        ``dot(x, y) / sqrt(dot(x, x) * dot(y, y))``.
    """
    numerator = jnp.dot(x, y)
    squared_norms = jnp.dot(x, x) * jnp.dot(y, y)
    return numerator / jnp.sqrt(squared_norms)

def show_stats(ds_name, dataset=None, max_shown=5):
    """Print and plot the first few multi-image duplicate groups of a dataset.

    Loads ``dup_data/<ds_name>.pickle``, prints how many groups it holds, and
    for up to ``max_shown`` groups with more than one member prints the cosine
    similarity between the group's first two images and plots the group.

    Args:
        ds_name: Base name of the pickle file under ``dup_data/``.
        dataset: Object whose ``x_all`` is indexed by the group indices.
            Defaults to the notebook-level ``ds`` so existing calls behave as
            before; pass it explicitly to avoid depending on whichever cell
            last assigned ``ds``.
        max_shown: Maximum number of groups to display (default 5).
    """
    if dataset is None:
        # Backward-compatible fallback to the notebook global.
        dataset = ds

    dups_pix = DuplicatesData.load("dup_data/" + ds_name + ".pickle")
    print("Total pix dups:", len(dups_pix.indices))

    # The per-sample argmax labels computed here before were never used and
    # have been removed.
    shown = 0
    for group in dups_pix.indices:
        if shown >= max_shown:
            break

        members = np.asarray(group)
        if len(members) <= 1:
            continue
        shown += 1

        # Similarity of the first two images of the group, flattened to vectors.
        first = dataset.x_all[members[0]].reshape(-1)
        second = dataset.x_all[members[1]].reshape(-1)
        print(sim(first, second))

        plots.compare_images(dataset.x_all[members], dataset.x_all[members], rows=3)
        plt.show()

    # Label kept unchanged for output compatibility; the value is the number
    # of groups displayed above.
    print("Not dups", shown)

# In this call show_stats reads images from the notebook-level `ds`; the
# argument only selects which pickle file under dup_data/ is loaded.
show_stats("mendeley")
#show_stats("tawsifur")
#show_stats("covidx")

Total pix dups: 6580
0.92664146


<Figure size 432x288 with 0 Axes>

In [None]:
# Point the notebook-level `ds` at the tawsifur dataset for the cells that follow.
ds = dataset_tawsifur

In [None]:
import pickle

# Load the pickled duplicates object for tawsifur.
# NOTE: pickle.load can execute arbitrary code; only open files you created or trust.
with open('dup_data/tawsifur.pickle', 'rb') as f:
    dups = pickle.load(f)

In [None]:
# Materialise the duplicate groups as a list so they can be indexed and iterated.
# `getattr` keeps the cell idempotent: on a re-run `dups` is already a list
# (which has no `.indices`) and is simply copied instead of raising AttributeError.
dups = list(getattr(dups, "indices", dups))

In [None]:
# Sanity check: report every duplicate group whose members do not all share
# the same argmax label (assumes `ds.y_all` holds one score/one-hot row per sample).
for group in dups:
    # Index with the current group; using dups[0] here re-checked only the
    # first group on every iteration.
    labels = ds.y_all[np.array(group)].argmax(1)
    if labels.std() != 0:
        print('deu ruim')