In [None]:
# Fetch the codelatam repository; data_dir below points into its Data/ folder
# (assumes the notebook's working directory is /content, as on Colab).
!git clone https://github.com/juanserrano90/codelatam.git

Cloning into 'codelatam'...
remote: Enumerating objects: 75745, done.[K
remote: Counting objects: 100% (3928/3928), done.[K
remote: Compressing objects: 100% (3923/3923), done.[K
remote: Total 75745 (delta 5), reused 3923 (delta 5), pack-reused 71817 (from 2)[K
Receiving objects: 100% (75745/75745), 693.30 MiB | 24.19 MiB/s, done.
Resolving deltas: 100% (1265/1265), done.
Updating files: 100% (90957/90957), done.


In [None]:
# Mount Google Drive; working_dir and load_split_drive read/write under /content/drive/MyDrive/...
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Global definitions ---------------
# Data folder of the cloned codelatam repository; load_split reads {data_dir}/Splits/*.pkl
data_dir = "/content/codelatam/Data"
# Project folder on the mounted Drive (the commented-out save cells write to {working_dir}/Splits)
working_dir = "/content/drive/MyDrive/Doctorado/Codelatam/Files_codelatam"
# Coarse classes; keys are the class ids returned by subtype_to_class_mapping
num_classes = 3
inv_dict_mapping_classes = {0:'Ia-norm', 1:'Ia-pec', 2:'Others'}
# NOTE(review): dataset_folder is not read by any visible cell
dataset_folder = 'Dataset_augmented_images'

def subtype_to_class_mapping(a):
  """Map a numeric subtype id to its coarse class id.

  Subtype 0 -> class 0 (Ia-norm), subtypes 1-5 -> class 1 (Ia-pec),
  subtypes 6-16 -> class 2 (Others). Any other key raises KeyError.
  """
  lookup = {0: 0}
  lookup.update({subtype_id: 1 for subtype_id in range(1, 6)})
  lookup.update({subtype_id: 2 for subtype_id in range(6, 17)})
  return lookup[a]

def id_to_subtype_mapping(a):
  """Return the subtype label for a numeric subtype id (0-16); other keys raise KeyError."""
  labels = ('Ia-norm', 'Ia-91T', 'Ia-91bg', 'Ia-csm', 'Iax', 'Ia-pec',
            'Ib-norm', 'Ibn', 'IIb', 'Ib-pec', 'Ic-norm', 'Ic-broad',
            'Ic-pec', 'IIP', 'IIL', 'IIn', 'II-pec')
  # A dict (not tuple indexing) so negative ids still raise KeyError.
  return dict(enumerate(labels))[a]

# Target ratios for Train, Val, Test
# NOTE(review): not read by any visible cell; intelligent_split takes a `ratios`
# dict with these same keys, so this was presumably its argument.
SPLIT_RATIOS = {'train': 0.8, 'val': 0.1, 'test': 0.1}

In [None]:
def load_split(n):
  """Unpickle saved split `n` from {data_dir}/Splits/saved_train_val_test_split_{n}.pkl.

  NOTE: pickle.load can execute arbitrary code; only load trusted files.
  """
  split_path = f"{data_dir}/Splits/saved_train_val_test_split_{n}.pkl"
  with open(split_path, 'rb') as handle:
    return pickle.load(handle)
def load_split_drive(n):
  """Unpickle saved split `n` from the Drive copy under {working_dir}/Splits.

  Reuses the global working_dir instead of repeating the absolute Drive path,
  so the location is configured in one place. The resulting path is the same
  as before.
  NOTE: pickle.load can execute arbitrary code; only load trusted files.
  """
  split_path = f"{working_dir}/Splits/saved_train_val_test_split_{n}.pkl"
  with open(split_path, 'rb') as f:
    return pickle.load(f)
def verify_split_stats(split_results):
    """Print per-class, per-set statistics of a train/val/test split.

    Args:
        split_results: dict class key -> {'train': [...], 'val': [...], 'test': [...]}
            of image filenames (the structure load_split / materialize_split produce).

    For each class and set, prints the image count, that set's share of the
    class's images, the number of distinct SNs and subtypes, and how many
    filenames contain "COPY". Filenames that get_sn_info cannot parse are
    counted as images but not as SNs or subtypes.
    """
    print(
        f"{'Class':<6} | {'Set':<6} | {'Images':<6} | {'%':<5} | "
        f"{'Unique SNs':<10} | {'Subtypes':<8} | {'COPY imgs'}"
    )
    print("-" * 90)

    for cls in split_results:
        total_imgs = sum(len(imgs) for imgs in split_results[cls].values())

        for s in ['train', 'val', 'test']:
            imgs = split_results[cls][s]
            n_imgs = len(imgs)
            pct = (n_imgs / total_imgs * 100) if total_imgs > 0 else 0

            # Parse each filename once; (None, None) marks an unparseable name
            # and must not be counted as an SN.
            parsed = [get_sn_info(i) for i in imgs]
            unique_sns = {sn for sn, _ in parsed if sn is not None}
            unique_subtypes = {st for _, st in parsed if st is not None}

            n_copy = sum("COPY" in img for img in imgs)

            print(
                f"{cls:<6} | {s:<6} | {n_imgs:<6} | {pct:4.1f}% | "
                f"{len(unique_sns):<10} | {len(unique_subtypes):<8} | {n_copy}"
            )
        print("-" * 90)

def get_sn_info(filename):
    """Parse '<sn_name>_<subtype>_...' filenames.

    Returns:
        (sn_name, subtype) with subtype as int, or (None, None) when the name
        has no second '_' field, that field is not an integer, or filename is
        not a str.
    """
    try:
        parts = filename.split('_')
        return parts[0], int(parts[1])
    except (AttributeError, TypeError, IndexError, ValueError):
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
        return None, None

# Legacy greedy splitter: used to build earlier splits, not called anywhere in
# this notebook any more. Kept for reference.
def intelligent_split(all_images_dict, ratios):
    """Greedy, SN-grouped train/val/test split.

    Properties:
    1. No SN leakage: all images of an SN go to exactly one set.
    2. Subtypes are processed one at a time so each subtype's SNs are spread
       over the sets. This is a heuristic: set counts are tracked per class,
       not per subtype, so per-subtype ratios are not guaranteed.
    3. Each SN goes to the set minimising (current_count + sn_images) / ratio,
       which steers the class toward the target ratios.
    4. SNs are visited largest-first, so heavy SNs tend to land in the set with
       the largest ratio (usually train).

    Args:
        all_images_dict: dict class key -> list of image filenames.
        ratios: dict with 'train', 'val', 'test' target fractions; each is used
            as a divisor, so all must be non-zero.

    Returns:
        dict class key -> {'train': [...], 'val': [...], 'test': [...]}.
        Filenames get_sn_info cannot parse are dropped.
    """
    final_split = {}

    # 1. ORGANIZE DATA
    # structured_data[class][subtype] = [{'sn_name': sn, 'images': [...], 'count': n}, ...]
    structured_data = {}

    for cls, img_list in all_images_dict.items():
        # Group images by SN; an SN's subtype is taken from its first image.
        sn_groups = {}
        for img in img_list:
            sn_name, subtype = get_sn_info(img)
            if sn_name is None:
                continue
            if sn_name not in sn_groups:
                sn_groups[sn_name] = {'subtype': subtype, 'images': []}
            sn_groups[sn_name]['images'].append(img)

        # Regroup the SNs by subtype.
        by_subtype = {}
        for sn, data in sn_groups.items():
            by_subtype.setdefault(data['subtype'], []).append({
                'sn_name': sn,
                'images': data['images'],
                'count': len(data['images']),
            })
        structured_data[cls] = by_subtype

    # 2. PERFORM SPLIT
    for cls, by_subtype in structured_data.items():
        final_split[cls] = {'train': [], 'val': [], 'test': []}
        # Running image counts per set for this class (shared across subtypes).
        set_counts = {'train': 0, 'val': 0, 'test': 0}

        for sn_objects in by_subtype.values():
            # Heaviest SNs first (see docstring, point 4).
            sn_objects.sort(key=lambda x: x['count'], reverse=True)

            for sn_obj in sn_objects:
                count = sn_obj['count']

                # Lowest projected fullness relative to target wins; ties go to
                # the first of train/val/test.
                best_set = None
                best_score = float('inf')
                for s_name in ('train', 'val', 'test'):
                    score = (set_counts[s_name] + count) / ratios[s_name]
                    if score < best_score:
                        best_score = score
                        best_set = s_name

                final_split[cls][best_set].extend(sn_obj['images'])
                set_counts[best_set] += count

    return final_split

import pickle
import numpy as np

# Rebuild the full per-class image pool by concatenating train/val/test of
# saved split 5 (the variable is named split0 but holds split 5).
split0 = load_split(5)
all_images = {
    cls: split0[cls]['train'] + split0[cls]['val'] + split0[cls]['test']
    for cls in ('0', '1', '2')
}

In [None]:
# Total images in augmented_dataset_v2.0 (originals + COPY images), summed over classes
total = [len(images) for images in all_images.values()]
print('total images:', sum(total))

total images: 7159


In [None]:
# How many SNs (filename basenames) have at least one COPY image?
with_copies = {
    image.split('_')[0]
    for images in all_images.values()
    for image in images
    if "COPY" in image
}
len(with_copies)

95

In [None]:
# How many originals (non-COPY images) per class?
originals = {key: [image for image in value if "COPY" not in image]
             for key, value in all_images.items()}

total_originals = [len(value) for value in originals.values()]
for key, count in zip(originals, total_originals):
    print(f'originals class {key}: {count}')
print('original images:', sum(total_originals))

originals class 0: 2387
originals class 1: 901
originals class 2: 1416
original images: 4704


In [None]:
# Images of SNs that have no COPY image at all (so every image is an original).
# These are the ones eligible for the test set.
originals_eligible = {key: [image for image in value if image.split('_')[0] not in with_copies]
                      for key, value in all_images.items()}

total_originals_eligible = [len(value) for value in originals_eligible.values()]
for key, count in zip(originals_eligible, total_originals_eligible):
    print(f'originals eligible class {key}: {count}')
print('original eligible images:', sum(total_originals_eligible))

originals eligible class 0: 2387
originals eligible class 1: 62
originals eligible class 2: 1084
original eligible images: 3533


In [None]:
# With the current eligible images it is impossible to build test and
# validation sets without using copies: compare the subtype counts of class c
# in all originals vs. in the eligible originals.
c = 1
for pool in (originals, originals_eligible):
    subty = np.array([get_sn_info(image)[1] for image in pool[f"{c}"]])
    print(np.unique(subty, return_counts=True))

(array([1, 2, 3, 4, 5]), array([398, 264,  30,  68, 141]))
(array([1, 2]), array([51, 11]))


In [None]:
# However, the test set should not include augmented data,
# and it must represent the original class distribution.
# We will keep the class imbalance in test/val,
# while training will be balanced via data augmentation.
# We can therefore drop copies that we will not need for training;
# this makes more SNs (and their images) eligible for the test and val splits.

# Working out the numbers (using ratios 0.8/0.1/0.1 on the original images),
# we must drop 294 copies (or spectra) from class 1 and 193 from class 2,
# i.e. 58 per subtype for class 1 (294/5 subtypes, rounded down) and
# 17 per subtype for class 2 (193/11 subtypes, rounded down).

In [None]:
# Subtype lookup for a given SN basename
def get_subtype_from_basename(basename, dataset=None):
  """Return the integer subtype of the first image whose basename matches.

  Args:
    basename: SN name (filename part before the first '_').
    dataset: dict class key -> list of filenames to search. Defaults to the
      global all_images (the previous, hard-wired behaviour).

  Returns:
    int subtype, or None when no image has that basename.
  """
  if dataset is None:
    dataset = all_images
  for images in dataset.values():
    for image in images:
      parts = image.split('_')
      if parts[0] == basename:
        return int(parts[1])
  return None

In [None]:
# Store the basenames that have copies for each class
# NOTE(review): this reads `reduced_all_images`, which is only created by the
# build_reduced_dataset cell further down; on a fresh kernel this cell raises
# NameError, so run it after that cell. Its result equals `copies2` computed later.
copies = {'0': [], '1': [], '2': []}
for key, value in reduced_all_images.items():
  for image in value:
    if 'COPY' in image:
      copies[key].append(image.split('_')[0])

  # De-duplicate: one entry per SN basename
  copies[key] = set(copies[key])

# For a given basename that has copies, give number of originals and copies
def originals_and_copies(basename, dataset):
  """Count the original and COPY images of one SN in `dataset` and print a summary.

  Args:
    basename: SN name (filename part before the first '_').
    dataset: dict class key -> list of image filenames.

  Returns:
    (n_originals, n_copies, subtype, basename). The subtype is looked up with
    get_subtype_from_basename(basename), i.e. not restricted to `dataset`.
  """
  matches = [img for imgs in dataset.values() for img in imgs
             if img.split('_')[0] == basename]
  n_copies = sum(1 for img in matches if "COPY" in img)
  n_orig = len(matches) - n_copies
  subtype = get_subtype_from_basename(basename)
  print(f"SN {basename} (subtype {subtype}) has {n_orig} originals and {n_copies} copies")
  return n_orig, n_copies, subtype, basename

In [None]:
# Get data from filenames
def parse_filename(fname):
    """Split a '<sn>_<subtype>_....png' filename.

    Returns (sn, subtype as int, is_copy), where is_copy means the name ends
    with '_COPY.png'. Raises IndexError/ValueError when there is no integer
    second field.
    """
    tokens = fname.replace('.png', '').split('_')
    return tokens[0], int(tokens[1]), fname.endswith('_COPY.png')


In [None]:
from collections import defaultdict
import random

# Reduce dataset size to new constraints
def drop_images_evenly(images, n_drop, seed=None):
    """Choose `n_drop` images to remove, preferring COPY images spread across SNs.

    COPY images (names ending in '_COPY.png') are taken round-robin over the
    SNs, one per SN per pass (the last COPY listed for that SN first), so no
    single SN is stripped of all its copies first. Only when copies run out
    are original images sampled at random.

    Args:
        images: list of filenames (same class & subtype).
        n_drop: total number of images to drop; values <= 0 drop nothing.
        seed: optional seed for the original-image fallback. None keeps the
            previous behaviour of using the global `random` module.

    Returns:
        set of filenames to drop. It can be smaller than `n_drop` if `images`
        has fewer (distinct) names than requested.
    """
    dropped = set()
    # Guard: with no images the round-robin below would divide by zero.
    if n_drop <= 0 or not images:
        return dropped

    # COPY images per SN, SNs in first-appearance order, copies in list order
    sn_copies = {}
    for img in images:
        sn, _, _ = parse_filename(img)
        sn_copies.setdefault(sn, [])
        if img.endswith('_COPY.png'):
            sn_copies[sn].append(img)

    # Round-robin removal from COPY images. A running counter replaces the
    # per-iteration all(...) scan over every SN.
    remaining_copies = sum(len(v) for v in sn_copies.values())
    sn_order = list(sn_copies)
    idx = 0
    while len(dropped) < n_drop and remaining_copies > 0:
        pool = sn_copies[sn_order[idx % len(sn_order)]]
        if pool:
            dropped.add(pool.pop())
            remaining_copies -= 1
        idx += 1

    # If still need to drop, fall back to originals (last resort)
    if len(dropped) < n_drop:
        remaining = [img for img in images if img not in dropped]
        # Clamp so random.sample cannot raise when too few images remain.
        needed = min(n_drop - len(dropped), len(remaining))
        rng = random if seed is None else random.Random(seed)
        dropped.update(rng.sample(remaining, needed))

    return dropped

def build_reduced_dataset(all_images, drop_rules=None):
    """Return a copy of `all_images` with images dropped per subtype.

    Args:
        all_images: dict class key -> list of filenames.
        drop_rules: dict class key -> number of images to drop from EACH subtype
            of that class (chosen by drop_images_evenly). Classes not listed are
            copied unchanged. Defaults to {'1': 58, '2': 17}.

    Returns:
        New dict class key -> list of remaining filenames, in original order.
        Every occurrence of a dropped filename is removed, so duplicated names
        can make a subtype lose more images than its rule says.
    """
    if drop_rules is None:
        drop_rules = {'1': 58, '2': 17}

    reduced = {}
    for cls, images in all_images.items():
        if cls not in drop_rules:
            reduced[cls] = images.copy()
            continue

        # Group by subtype
        by_subtype = defaultdict(list)
        for img in images:
            _, subtype, _ = parse_filename(img)
            by_subtype[subtype].append(img)

        to_drop = set()
        for imgs in by_subtype.values():
            to_drop.update(drop_images_evenly(imgs, drop_rules[cls]))

        reduced[cls] = [img for img in images if img not in to_drop]

    return reduced


In [None]:
# Drop images per subtype (class 1: 58, class 2: 17 by default) to free SNs for val/test.
# NOTE(review): drop_images_evenly falls back to the unseeded global `random` when a
# subtype lacks enough copies, so the result may not be reproducible across runs.
reduced_all_images = build_reduced_dataset(all_images)

In [None]:
# Sanity check the dropped images
def check_drops(original, reduced, cls):
    """Print, per subtype of class `cls`, how many images `reduced` lost vs `original`."""
    from collections import Counter

    def subtype_counts(dataset):
        return Counter(parse_filename(name)[1] for name in dataset[cls])

    before = subtype_counts(original)
    after = subtype_counts(reduced)
    for st in sorted(before):
        print(f"class {cls}, subtype {st}: dropped = {before[st] - after[st]}")


In [None]:
for cls in ('1', '2'):
    check_drops(all_images, reduced_all_images, cls)

class 1, subtype 1: dropped = 58
class 1, subtype 2: dropped = 58
class 1, subtype 3: dropped = 59
class 1, subtype 4: dropped = 58
class 1, subtype 5: dropped = 58
class 2, subtype 6: dropped = 17
class 2, subtype 7: dropped = 17
class 2, subtype 8: dropped = 17
class 2, subtype 9: dropped = 17
class 2, subtype 10: dropped = 17
class 2, subtype 11: dropped = 17
class 2, subtype 12: dropped = 17
class 2, subtype 13: dropped = 17
class 2, subtype 14: dropped = 17
class 2, subtype 15: dropped = 17
class 2, subtype 16: dropped = 17


In [None]:
# Which SN basenames still have copies after the reduction? (per class)
copies2 = {
    key: {image.split('_')[0] for image in value if 'COPY' in image}
    for key, value in reduced_all_images.items()
}

In [None]:
# All SN basenames (any class) that still have copies after the reduction
with_copies2 = {
    image.split('_')[0]
    for images in reduced_all_images.values()
    for image in images
    if "COPY" in image
}
len(with_copies2)

64

In [None]:
# Original (non-COPY) images per class in the reduced dataset
originals2 = {key: [image for image in value if "COPY" not in image]
              for key, value in reduced_all_images.items()}

total_originals2 = [len(value) for value in originals2.values()]
for key, count in zip(originals2, total_originals2):
    print(f'originals class {key}: {count}')
print('original images:', sum(total_originals2))

originals class 0: 2387
originals class 1: 901
originals class 2: 1331
original images: 4619


In [None]:
# Eligible for val/test in the reduced dataset: images of SNs with no COPY left
originals_eligible2 = {key: [image for image in value if image.split('_')[0] not in with_copies2]
                       for key, value in reduced_all_images.items()}

total_originals_eligible2 = [len(value) for value in originals_eligible2.values()]
for key, count in zip(originals_eligible2, total_originals_eligible2):
    print(f'originals eligible class {key}: {count}')
print('original eligible images:', sum(total_originals_eligible2))

originals eligible class 0: 2387
originals eligible class 1: 262
originals eligible class 2: 1031
original eligible images: 3680


In [None]:
# For class c, subtype counts in originals2 vs. in the eligible originals2
c = 1
for pool in (originals2, originals_eligible2):
    subty = np.array([get_sn_info(image)[1] for image in pool[f"{c}"]])
    print(np.unique(subty, return_counts=True))

(array([1, 2, 3, 4, 5]), array([398, 264,  30,  68, 141]))
(array([1, 2, 5]), array([235,  26,   1]))


In [None]:
# Summarise originals/copies for each class-1 SN that still has copies.
# originals_and_copies returns 4 values; unpacking into `copies_` (not `copies`)
# keeps the global `copies` dict intact. Sorted for a stable print order.
for basename in sorted(copies2['1']):
  orig, copies_, subtype, _ = originals_and_copies(basename, reduced_all_images)

SN sn02cx (subtype 4) has 8 originals and 73 copies
SN sn2002dl (subtype 2) has 4 originals and 1 copies
SN sn02fb (subtype 2) has 2 originals and 2 copies
SN sn06ke (subtype 2) has 1 originals and 1 copies
SN sn2005gj (subtype 3) has 22 originals and 182 copies
SN sn91T (subtype 1) has 21 originals and 1 copies
SN sn2000cn (subtype 2) has 10 originals and 4 copies
SN sn2008ae (subtype 5) has 5 originals and 13 copies
SN sn99by (subtype 2) has 15 originals and 12 copies
SN sn00cx (subtype 5) has 26 originals and 82 copies
SN sn05ke (subtype 2) has 3 originals and 1 copies
SN sn1998es (subtype 1) has 26 originals and 1 copies
SN sn2005ke (subtype 2) has 10 originals and 1 copies
SN sn2006hb (subtype 2) has 8 originals and 4 copies
SN sn03gq (subtype 4) has 1 originals and 4 copies
SN sn2003Y (subtype 2) has 3 originals and 2 copies
SN sn2008A (subtype 4) has 14 originals and 133 copies
SN sn2006oa (subtype 1) has 7 originals and 1 copies
SN sn2007al (subtype 2) has 8 originals and 1 cop

In [None]:
from collections import defaultdict

def group_by_sn(images):
    """Group image filenames by SN.

    Returns a defaultdict sn -> {'images': [...], 'subtype': int}. The subtype
    is that of the last image seen for the SN; unparseable names are skipped.
    NOTE(review): not called in the visible cells (select_test_set_fast inlines
    the same grouping).
    """
    groups = defaultdict(lambda: {'images': [], 'subtype': None})
    for img in images:
        sn, subtype = get_sn_info(img)
        if sn is not None:
            entry = groups[sn]
            entry['images'].append(img)
            entry['subtype'] = subtype
    return groups


In [None]:
import random
from collections import defaultdict

# Select a subset of whole SNs whose image count is close to a target
# (used for the test set and, later, for the validation sets).
def select_test_set_fast(
    images,
    target_images,
    target_sns,
    required_subtypes=None,
    max_len_sn=0.2,
    img_tol=5,
    sn_tol=3,
    max_tries=2000,
    seed=42
):
    """Randomized greedy selection of whole SNs.

    Each try shuffles the SNs and adds them in order while the running image
    total stays <= target_images + img_tol, skipping any SN with
    >= max_len_sn * target_images images (so no single SN dominates).
    A try is accepted immediately when |images - target_images| <= img_tol
    and |SNs - target_sns| <= sn_tol. Otherwise the try with the lowest
    10 * image_error + sn_error is kept.

    Args:
        images: list of filenames; unparseable ones (get_sn_info -> None) are ignored.
        target_images, target_sns: desired image and SN totals.
        required_subtypes: optional set; SNs of other subtypes are excluded.
        max_len_sn: max fraction of target_images a single SN may contribute.
        img_tol, sn_tol: acceptance tolerances.
        max_tries: number of random restarts.
        seed: seed for a private random.Random; the global `random` state is
            left untouched.

    Returns:
        Flat list of the chosen SNs' images (empty if nothing fits or max_tries < 1).
    """
    # Private RNG: same shuffle sequence as random.seed(seed) + random.shuffle,
    # without reseeding the global generator behind the caller's back.
    rng = random.Random(seed)

    # Group by SN (an SN's subtype is that of its last image)
    sn_groups = defaultdict(lambda: {'images': [], 'subtype': None})
    for img in images:
        sn, subtype = get_sn_info(img)
        if sn is None:
            continue
        sn_groups[sn]['images'].append(img)
        sn_groups[sn]['subtype'] = subtype

    # Filter by subtype
    if required_subtypes is not None:
        sn_groups = {
            sn: d for sn, d in sn_groups.items()
            if d['subtype'] in required_subtypes
        }

    sn_items = list(sn_groups.items())
    max_sn_images = max_len_sn * target_images  # loop-invariant per-SN cap
    fill_limit = target_images + img_tol

    best = []  # stays empty (not None) if max_tries < 1
    best_score = float('inf')

    for _ in range(max_tries):
        rng.shuffle(sn_items)

        chosen = []
        img_sum = 0
        for sn, data in sn_items:
            n = len(data['images'])
            if img_sum + n <= fill_limit and n < max_sn_images:
                chosen.append(sn)
                img_sum += n

        img_error = abs(img_sum - target_images)
        sn_error = abs(len(chosen) - target_sns)

        if img_error <= img_tol and sn_error <= sn_tol:
            best = chosen
            break

        score = img_error * 10 + sn_error  # prioritize image count
        if score < best_score:
            best_score = score
            best = chosen

    # Flatten images
    return [img for sn in best for img in sn_groups[sn]['images']]



In [None]:
def create_test_set(seed):
  """Draw a test set from originals_eligible2 (SNs without any COPY image).

  Per-class targets (roughly 10% of each class's originals) and allowed
  subtypes are fixed below. Class '0', '1', '2' use seeds seed+10, +20, +30.

  Returns:
    dict class key -> list of test image filenames.
  """
  per_class_kwargs = {
      '0': dict(target_images=238, target_sns=30, max_len_sn=0.2,
                img_tol=5, sn_tol=4, seed=seed + 10),
      '1': dict(target_images=90, target_sns=9, max_len_sn=0.3,
                required_subtypes={1, 2, 5}, img_tol=4, sn_tol=2,
                seed=seed + 20),
      '2': dict(target_images=141, target_sns=10, max_len_sn=0.2,
                required_subtypes={6, 8, 10, 11, 13, 16}, img_tol=5, sn_tol=3,
                seed=seed + 30),
  }
  return {
      cls: select_test_set_fast(originals_eligible2[cls], **kwargs)
      for cls, kwargs in per_class_kwargs.items()
  }

In [None]:
def check_test_set(test_set):
    """Print image/SN/subtype counts per class and return the subtype counts.

    Args:
        test_set: dict class key -> list of image filenames.

    Returns:
        dict class key -> [number of distinct subtypes]. It always has keys
        '0', '1', '2' (empty list when that class is absent from test_set).
        Unparseable filenames (get_sn_info -> (None, None)) are not counted as
        SNs or subtypes, so sorting the subtypes cannot hit None.
    """
    subt = {'0': [], '1': [], '2': []}
    for cls, imgs in test_set.items():
        parsed = [get_sn_info(img) for img in imgs]
        sns = {sn for sn, _ in parsed if sn is not None}
        subtypes = {st for _, st in parsed if st is not None}
        print(f"\nClass {cls}")
        print(" images:", len(imgs))
        print(" unique SNs:", len(sns))
        print(" subtypes:", sorted(subtypes))
        # setdefault: class keys other than '0'/'1'/'2' no longer raise KeyError
        subt.setdefault(cls, []).append(len(subtypes))
    return subt

In [None]:
# Re-draw the test set with increasing seeds until class 1 covers 3 subtypes and
# class 2 covers 5 (the maximum subtype representation possible, per the original
# note). Bounded, so an unreachable target fails loudly instead of looping forever.
MAX_TEST_ATTEMPTS = 1000
for k in range(MAX_TEST_ATTEMPTS):
  test_set2 = create_test_set(seed=145+k)
  subt = check_test_set(test_set2)
  if subt['2'][0] == 5 and subt['1'][0] == 3:
    break
else:
  raise RuntimeError(f"No test set with the required subtype coverage after {MAX_TEST_ATTEMPTS} attempts")


Class 0
 images: 243
 unique SNs: 31
 subtypes: [0]

Class 1
 images: 94
 unique SNs: 11
 subtypes: [1, 2]

Class 2
 images: 146
 unique SNs: 13
 subtypes: [6, 8, 10, 11, 13]

Class 0
 images: 243
 unique SNs: 34
 subtypes: [0]

Class 1
 images: 94
 unique SNs: 10
 subtypes: [1, 2]

Class 2
 images: 146
 unique SNs: 13
 subtypes: [6, 8, 10, 11, 13]

Class 0
 images: 243
 unique SNs: 31
 subtypes: [0]

Class 1
 images: 94
 unique SNs: 10
 subtypes: [1, 2, 5]

Class 2
 images: 146
 unique SNs: 12
 subtypes: [6, 8, 10, 11, 13]


In [None]:
# Re-display the accepted test set's stats (the returned dict holds per-class subtype counts)
check_test_set(test_set2)


Class 0
 images: 243
 unique SNs: 31
 subtypes: [0]

Class 1
 images: 94
 unique SNs: 10
 subtypes: [1, 2, 5]

Class 2
 images: 146
 unique SNs: 12
 subtypes: [6, 8, 10, 11, 13]


{'0': [1], '1': [3], '2': [5]}

In [None]:
from collections import Counter, defaultdict

def images_per_sn(split_dict):
    """Count images per SN for every class.

    split_dict: dict with keys ['0','1','2'] and values = list of image names.
    Returns dict class -> Counter(sn -> number of images); names get_sn_info
    cannot parse are skipped.
    """
    summary = {}
    for cls, images in split_dict.items():
        sn_names = (get_sn_info(img)[0] for img in images)
        summary[cls] = Counter(sn for sn in sn_names if sn is not None)
    return summary

In [None]:
def print_images_per_sn(sn_summary, top_n=10):
    """Print per-class SN statistics from images_per_sn's output.

    Args:
        sn_summary: dict class -> Counter(sn -> image count).
        top_n: number of most frequent SNs to list per class.

    A class with no SNs prints its zero totals and skips the min/max/mean and
    top-SN lines; previously min()/max() on an empty list raised ValueError.
    """
    for cls, counter in sn_summary.items():
        print(f"\nClass {cls}")
        print(f" unique SNs: {len(counter)}")
        print(f" total images: {sum(counter.values())}")

        counts = list(counter.values())
        if not counts:
            continue

        print(f" images per SN: min={min(counts)}, "
              f"max={max(counts)}, "
              f"mean={sum(counts)/len(counts):.2f}")

        print(" top SNs by image count:")
        for sn, n in counter.most_common(top_n):
            print(f"   {sn}: {n}")

In [None]:
# Per-SN image counts of the test set (checks that no single SN dominates a class)
sn_summary = images_per_sn(test_set2)
print_images_per_sn(sn_summary)


Class 0
 unique SNs: 31
 total images: 243
 images per SN: min=1, max=26, mean=7.84
 top SNs by image count:
   sn2002er: 26
   sn2003kf: 26
   sn89B: 24
   sn2002dj: 23
   sn2004eo: 19
   sn1994ae: 18
   sn2003W: 16
   sn2000fa: 15
   sn2001fe: 12
   sn2001N: 8

Class 1
 unique SNs: 10
 total images: 94
 images per SN: min=1, max=21, mean=9.40
 top SNs by image count:
   sn1997br: 21
   sn1999dq: 19
   sn2001eh: 18
   sn1998ab: 12
   sn08ds: 9
   sn2002hu: 5
   sn2003hu: 4
   sn99da: 3
   sn1999cw: 2
   snls03D3bb: 1

Class 2
 unique SNs: 12
 total images: 146
 images per SN: min=1, max=26, mean=12.17
 top SNs by image count:
   13ge: 26
   sn2005bf: 23
   sn2006bp: 21
   sn98bw: 13
   sn2007gr: 13
   sn94I: 12
   sn2003bg: 11
   sn2004gt: 10
   sn2011ei: 8
   sn2011fu: 5


In [None]:
def drop_images(all_images, images_to_drop):
    """Return a new dict with images_to_drop[cls] removed from all_images[cls].

    Classes missing from images_to_drop are copied unchanged. Order is kept
    and neither input is modified.
    """
    remaining = {}
    for cls in all_images:
        excluded = frozenset(images_to_drop.get(cls, ()))
        remaining[cls] = list(filter(lambda img: img not in excluded, all_images[cls]))
    return remaining

In [None]:
# Drop the test set images so that we can now select the validation sets.
# The pool is originals_eligible2 (SNs without copies) minus the test images.
remaining_images = drop_images(originals_eligible2, test_set2)

In [None]:
# Size of the class-1 validation pool
len(remaining_images['1'])

169

In [None]:
# Helper to see the overlap between two validation sets
def overlap_fraction(set_a, set_b):
    """Fraction of the distinct items of set_a that also occur in set_b.

    Asymmetric: normalised by the size of set(set_a). Returns 0.0 when set_a is empty.
    """
    left, right = set(set_a), set(set_b)
    if not left:
        return 0.0
    return len(left.intersection(right)) / len(left)

In [None]:
# Create 4 validation sets from remaining_images.
# Every pair of sets may share at most `thresh` (30%) of a class's images, and a single
# SN may contribute at most 20% (25% for class 1) of a class's target size.
# seed_sum starts at 0 so the cell runs on a fresh kernel; the search is bounded, and
# val_sets is only assigned a candidate that actually passes the overlap check.
# NOTE(review): remaining_images is extended with move1/move2 only in a later cell, so
# on a top-to-bottom run those SNs are not in this pool. Confirm the intended order.
N_VAL_SETS = 4
MAX_VAL_ATTEMPTS = 200
thresh = 0.3
seed_sum = 0

VAL_SET_KWARGS = {
    '0': dict(target_images=238, target_sns=30, max_len_sn=0.2,
              img_tol=5, sn_tol=3),
    '1': dict(target_images=90, target_sns=9, max_len_sn=0.25,
              required_subtypes={1, 2, 4, 5}, img_tol=5, sn_tol=3),
    '2': dict(target_images=141, target_sns=10, max_len_sn=0.2,
              required_subtypes={6, 8, 10, 11, 13, 16}, img_tol=5, sn_tol=3),
}


def _draw_val_sets(seed_offset):
    """Draw N_VAL_SETS candidate validation sets; set k uses seed 32 + seed_offset + k."""
    return [
        {cls: select_test_set_fast(remaining_images[cls], seed=32 + seed_offset + k, **kwargs)
         for cls, kwargs in VAL_SET_KWARGS.items()}
        for k in range(N_VAL_SETS)
    ]


def _overlaps_ok(candidate):
    """True if, in every class, each pair i < j of sets overlaps by at most `thresh`."""
    return all(
        overlap_fraction(candidate[i][cls], candidate[j][cls]) <= thresh
        for cls in ('0', '1', '2')
        for i in range(len(candidate))
        for j in range(i + 1, len(candidate))
    )


for _ in range(MAX_VAL_ATTEMPTS):
    candidate = _draw_val_sets(seed_sum)
    if _overlaps_ok(candidate):
        val_sets = candidate
        break
    seed_sum += 1
else:
    raise RuntimeError(
        f"No {N_VAL_SETS} validation sets with pairwise overlap <= {thresh} "
        f"after {MAX_VAL_ATTEMPTS} attempts; relax thresh or enlarge the pool."
    )

KeyboardInterrupt: 

In [None]:
# Unpack the four accepted validation sets
val_set1, val_set2, val_set3, val_set4 = val_sets[:4]

In [None]:
# check overlap: per class, fraction of set i's images also in set j, for every pair i < j
val_set_list = [val_set1, val_set2, val_set3, val_set4]
for cls in ['0', '1', '2']:
  for i in range(len(val_set_list)):
    for j in range(i + 1, len(val_set_list)):
      print(overlap_fraction(val_set_list[i][cls], val_set_list[j][cls]))

0.0411522633744856
0.18106995884773663
0.0823045267489712
0.07407407407407407
0.32510288065843623
0.024691358024691357
0.3157894736842105
0.10526315789473684
0.3157894736842105
0.28421052631578947
0.2631578947368421
0.25263157894736843
0.1232876712328767
0.23972602739726026
0.273972602739726
0.3356164383561644
0.2876712328767123
0.273972602739726


In [None]:
# Per-SN image counts of the test set again.
# NOTE(review): the saved output here (class 1: 13 SNs / 93 images) differs from the
# earlier display (10 SNs / 94 images), so test_set2 was regenerated between runs.
sn_summary = images_per_sn(test_set2)
print_images_per_sn(sn_summary)


Class 0
 unique SNs: 31
 total images: 243
 images per SN: min=1, max=26, mean=7.84
 top SNs by image count:
   sn2002er: 26
   sn2003kf: 26
   sn89B: 24
   sn2002dj: 23
   sn2004eo: 19
   sn1994ae: 18
   sn2003W: 16
   sn2000fa: 15
   sn2001fe: 12
   sn2001N: 8

Class 1
 unique SNs: 13
 total images: 93
 images per SN: min=1, max=21, mean=7.15
 top SNs by image count:
   sn1997br: 21
   sn1999dq: 19
   sn2001eh: 18
   sn08ds: 9
   sn2002hu: 5
   sn2003hu: 4
   sn99da: 3
   sn2006gt: 3
   sn2006bz: 3
   sn2007ba: 3

Class 2
 unique SNs: 12
 total images: 146
 images per SN: min=1, max=26, mean=12.17
 top SNs by image count:
   13ge: 26
   sn2005bf: 23
   sn2006bp: 21
   sn98bw: 13
   sn2007gr: 13
   sn94I: 12
   sn2003bg: 11
   sn2004gt: 10
   sn2011ei: 8
   sn2011fu: 5


In [None]:
# Image/SN/subtype counts of the first validation set
check_test_set(val_set1)


Class 0
 images: 243
 unique SNs: 29
 subtypes: [0]

Class 1
 images: 94
 unique SNs: 11
 subtypes: [1, 2]

Class 2
 images: 146
 unique SNs: 13
 subtypes: [6, 8, 10, 11, 13]


{'0': [1], '1': [2], '2': [5]}

In [None]:
# See what remains for validation sets: subtype counts of class c in the pool
c = 1
subty = np.array([get_sn_info(image)[1] for image in remaining_images[f"{c}"]])
print(np.unique(subty, return_counts=True))

(array([1, 2, 4, 5]), array([305, 186,  13,   5]))


In [None]:
from collections import defaultdict

# See what remains for validation sets: per-SN image counts and subtypes for class 2
by_sn = defaultdict(lambda: {'subtypes': set(), 'n_images': 0})
for img in remaining_images['2']:
    sn, subtype = get_sn_info(img)
    if sn is not None:
        record = by_sn[sn]
        record['subtypes'].add(subtype)
        record['n_images'] += 1

# Print summary
print(f"{'SN name':<20} | {'#images':<8} | subtypes")
print("-" * 45)
for sn_name in sorted(by_sn):
    record = by_sn[sn_name]
    print(f"{sn_name:<20} | {record['n_images']:<8} | {sorted(record['subtypes'])}")


SN name              | #images  | subtypes
---------------------------------------------
10as                 | 16       | [8]
11hs                 | 15       | [8]
12au                 | 6        | [6]
15dtg                | 7        | [10]
16coi                | 13       | [11]
16gkg                | 9        | [8]
17ein                | 5        | [10]
LSQ14efd             | 14       | [10]
PTF10bzf             | 2        | [11]
PTF10qts             | 6        | [11]
PTF10vgv             | 4        | [11]
PTF12gzk             | 6        | [10]
iPTF13bvn            | 13       | [6]
sn1983N              | 2        | [6]
sn1983V              | 10       | [10]
sn1984L              | 7        | [6]
sn1987A              | 200      | [16]
sn1990B              | 17       | [10]
sn1990I              | 7        | [6]
sn1990U              | 5        | [6]
sn1992ar             | 1        | [10]
sn1993J              | 47       | [8]
sn1994I              | 30       | [10]
sn1996cb             | 1

In [None]:
# Class-1 SNs that have copies but few images in total (originals + copies < 18):
# accept them, copies included, into the validation pool to gain subtype coverage.
# The basenames are recomputed from reduced_all_images (the same set as copies['1'])
# so this cell does not depend on the global `copies`; sorted for a stable print order.
MAX_SN_IMAGES_C1 = 18
to_add_c1 = []
for basename in sorted({img.split('_')[0] for img in reduced_all_images['1'] if 'COPY' in img}):
  orig, copies_, subtype, base = originals_and_copies(basename, reduced_all_images)
  if copies_ + orig < MAX_SN_IMAGES_C1:
    to_add_c1.append(basename)

SN sn02cx (subtype 4) has 8 originals and 73 copies
SN sn2002dl (subtype 2) has 4 originals and 1 copies
SN sn02fb (subtype 2) has 2 originals and 2 copies
SN sn06ke (subtype 2) has 1 originals and 1 copies
SN sn2005gj (subtype 3) has 22 originals and 182 copies
SN sn91T (subtype 1) has 21 originals and 1 copies
SN sn2000cn (subtype 2) has 10 originals and 4 copies
SN sn2008ae (subtype 5) has 5 originals and 13 copies
SN sn99by (subtype 2) has 15 originals and 12 copies
SN sn00cx (subtype 5) has 26 originals and 82 copies
SN sn05ke (subtype 2) has 3 originals and 1 copies
SN sn1998es (subtype 1) has 26 originals and 1 copies
SN sn2005ke (subtype 2) has 10 originals and 1 copies
SN sn2006hb (subtype 2) has 8 originals and 4 copies
SN sn03gq (subtype 4) has 1 originals and 4 copies
SN sn2003Y (subtype 2) has 3 originals and 2 copies
SN sn2008A (subtype 4) has 14 originals and 133 copies
SN sn2006oa (subtype 1) has 7 originals and 1 copies
SN sn2007al (subtype 2) has 8 originals and 1 cop

In [None]:
# We can compromise by allowing some copies in val sets to gain subtype representation:
# take class-2 SNs with copies whose total (originals + copies) is below 28.
# The basenames are recomputed from reduced_all_images (the same set as copies['2'])
# so this cell does not depend on the global `copies`; sorted for a stable print order.
MAX_SN_IMAGES_C2 = 28
to_add_c2 = []
for basename in sorted({img.split('_')[0] for img in reduced_all_images['2'] if 'COPY' in img}):
  orig, copies_, subtype, base = originals_and_copies(basename, reduced_all_images)
  if copies_ + orig < MAX_SN_IMAGES_C2:
    to_add_c2.append(basename)

SN sn1980K (subtype 14) has 8 originals and 122 copies
SN sn2002ao (subtype 7) has 4 originals and 16 copies
SN sn2004aw (subtype 12) has 27 originals and 133 copies
SN sn2005cs (subtype 13) has 41 originals and 10 copies
SN sn1999em (subtype 13) has 45 originals and 7 copies
SN sn1992H (subtype 13) has 10 originals and 3 copies
SN sn1998S (subtype 15) has 57 originals and 129 copies
SN sn1979C (subtype 14) has 5 originals and 65 copies
SN sn2005la (subtype 9) has 5 originals and 47 copies
SN sn2004et (subtype 13) has 48 originals and 4 copies
SN sn2005ek (subtype 12) has 6 originals and 34 copies
SN sn1996L (subtype 15) has 7 originals and 7 copies
SN sn2006jc (subtype 7) has 22 originals and 113 copies
SN sn2007uy (subtype 9) has 10 originals and 138 copies
SN sn2000er (subtype 7) has 5 originals and 40 copies


In [None]:
# Images (originals and copies) of the selected SNs, to include in the pool for validation
move1 = [image for image in reduced_all_images['1'] if get_sn_info(image)[0] in to_add_c1]
move2 = [image for image in reduced_all_images['2'] if get_sn_info(image)[0] in to_add_c2]


In [None]:
# Add the selected SNs (copies included) to the validation pools.
# Idempotent: images already in a pool are not appended again, so re-running this
# cell no longer duplicates entries. The first run behaves exactly as before.
# NOTE(review): the validation sets are drawn in an earlier cell; on a top-to-bottom
# run this extension happens after they were built. Confirm the intended order.
for cls, extra in (('2', move2), ('1', move1)):
    already_in_pool = set(remaining_images[cls])
    remaining_images[cls] = remaining_images[cls] + [img for img in extra if img not in already_in_pool]

In [None]:
# Having test_set and val_sets, materialize the complete splits
def _sn_names(imgs):
    """SN basenames (filename part before the first '_') of a collection of images."""
    return {img.split('_')[0] for img in imgs}


def materialize_split(reduced_all_images, test_set, val_set, classes=('0', '1', '2')):
    """Build per-class train/val/test lists; train is everything not in val or test.

    Args:
        reduced_all_images: dict class -> list of all image filenames for that class.
        test_set, val_set: dict class -> list of filenames (a missing class means empty).
        classes: class keys to materialize (default '0', '1', '2').

    Returns:
        dict class -> {'train': [...], 'val': [...], 'test': [...]}, each list
        sorted and de-duplicated.

    Raises:
        AssertionError if test and val overlap, if either contains images that
        are not in the class pool, or if an SN appears in more than one of
        train/val/test (SN leakage).
    """
    split = {}

    for cls in classes:
        all_imgs = set(reduced_all_images[cls])
        test_imgs = set(test_set.get(cls, []))
        val_imgs = set(val_set.get(cls, []))

        # Safety checks
        assert test_imgs.isdisjoint(val_imgs), f"Overlap between test and val in class {cls}"
        assert test_imgs <= all_imgs, f"Test images missing from the pool in class {cls}"
        assert val_imgs <= all_imgs, f"Val images missing from the pool in class {cls}"

        train_imgs = all_imgs - test_imgs - val_imgs

        # The same SN in two sets would leak information between them.
        train_sns, val_sns, test_sns = _sn_names(train_imgs), _sn_names(val_imgs), _sn_names(test_imgs)
        assert (train_sns.isdisjoint(val_sns) and train_sns.isdisjoint(test_sns)
                and val_sns.isdisjoint(test_sns)), f"SN leakage between train/val/test in class {cls}"

        split[cls] = {
            'train': sorted(train_imgs),
            'val': sorted(val_imgs),
            'test': sorted(test_imgs)
        }

    return split


# ---- build split 1 ----
# All four splits share test_set2; only the validation set (and hence train) differs.
split1 = materialize_split(
    reduced_all_images=reduced_all_images,
    test_set=test_set2,
    val_set=val_set1
)

In [None]:
# ---- build split 2 ----
# Same test set as split 1, second validation set
split2 = materialize_split(
    reduced_all_images=reduced_all_images,
    test_set=test_set2,
    val_set=val_set2
)

In [None]:
# ---- build split 3 ----
# Same test set as split 1, third validation set
split3 = materialize_split(
    reduced_all_images=reduced_all_images,
    test_set=test_set2,
    val_set=val_set3
)

In [None]:
# ---- build split 4 ----
# Same test set as split 1, fourth validation set
split4 = materialize_split(
    reduced_all_images=reduced_all_images,
    test_set=test_set2,
    val_set=val_set4
)

In [None]:
# save train/val/test split 4
# with open(f'{working_dir}/Splits/tvt_split4.pkl', 'wb') as f:
#     pickle.dump(split4, f)


In [None]:
# save test split
# with open(f'{working_dir}/Splits/new_test_set.pkl', 'wb') as f:
#     pickle.dump(test_set2, f)
#

In [None]:
# save test split
# with open(f'{working_dir}/Splits/new_test_set.pkl', 'wb') as f:
#     pickle.dump(test_set2, f)
#

In [None]:
# save test split
# with open(f'{working_dir}/Splits/new_test_set.pkl', 'wb') as f:
#     pickle.dump(test_set2, f)
#