In [1]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import sys
import gc
from time import time
from collections import Counter
import pandas as pd
import numpy as np
np.random.seed(42)

import matplotlib.pyplot as plt

import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
from tqdm import tqdm
import wandb
import glob
import random

wandb.login()

sys.path.insert(0,'../src/')

# -----
from data_processing import *
from model_train import *
from labeling_system import *
import cvt as cvt
import cvt_benchmark as cvtb

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33memadonev[0m ([33memadonev-xv-gimnazija[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
from data_processing import *
from model_train import *
import cvt as cvt
import cvt_benchmark as cvtb

---

In [3]:
reference_images = pd.read_csv('../input/filename_mapping.csv')


In [4]:
main_catalogue = pd.read_csv('../input/gz2_classes.csv')


In [5]:
reference_images.head()

Unnamed: 0,objid,sample,asset_id
0,587722981736120347,original,1
1,587722981736579107,original,2
2,587722981741363294,original,3
3,587722981741363323,original,4
4,587722981741559888,original,5


In [6]:
main_catalogue.head()

Unnamed: 0,specobjid,dr8objid,dr7objid,ra,dec,rastring,decstring,sample,gz2class,total_classifications,...,t11_arms_number_a36_more_than_4_fraction,t11_arms_number_a36_more_than_4_weighted_fraction,t11_arms_number_a36_more_than_4_debiased,t11_arms_number_a36_more_than_4_flag,t11_arms_number_a37_cant_tell_count,t11_arms_number_a37_cant_tell_weight,t11_arms_number_a37_cant_tell_fraction,t11_arms_number_a37_cant_tell_weighted_fraction,t11_arms_number_a37_cant_tell_debiased,t11_arms_number_a37_cant_tell_flag
0,1.802675e+18,,588017703996096547,160.9904,11.70379,10:43:57.70,+11:42:13.6,original,SBb?t,44,...,0.225,0.225,0.225,0,10,10.0,0.25,0.25,0.25,0
1,1.992984e+18,,587738569780428805,192.41083,15.164207,12:49:38.60,+15:09:51.1,original,Ser,45,...,0.0,0.0,0.0,0,0,0.0,0.0,0.0,0.0,0
2,1.489569e+18,,587735695913320507,210.8022,54.348953,14:03:12.53,+54:20:56.2,original,Sc+t,46,...,0.651,0.651,0.651,0,3,3.0,0.07,0.07,0.07,0
3,2.924084e+18,1.237668e+18,587742775634624545,185.30342,18.382704,12:21:12.82,+18:22:57.7,original,SBc(r),45,...,0.071,0.071,0.071,0,6,6.0,0.429,0.429,0.429,0
4,1.387165e+18,1.237658e+18,587732769983889439,187.36679,8.749928,12:29:28.03,+08:44:59.7,extra,Ser,49,...,0.0,0.0,0.0,0,1,1.0,1.0,1.0,1.0,0


In [None]:
list(main_catalogue.columns)

['specobjid',
 'dr8objid',
 'dr7objid',
 'ra',
 'dec',
 'rastring',
 'decstring',
 'sample',
 'gz2class',
 'total_classifications',
 'total_votes',
 't01_smooth_or_features_a01_smooth_count',
 't01_smooth_or_features_a01_smooth_weight',
 't01_smooth_or_features_a01_smooth_fraction',
 't01_smooth_or_features_a01_smooth_weighted_fraction',
 't01_smooth_or_features_a01_smooth_debiased',
 't01_smooth_or_features_a01_smooth_flag',
 't01_smooth_or_features_a02_features_or_disk_count',
 't01_smooth_or_features_a02_features_or_disk_weight',
 't01_smooth_or_features_a02_features_or_disk_fraction',
 't01_smooth_or_features_a02_features_or_disk_weighted_fraction',
 't01_smooth_or_features_a02_features_or_disk_debiased',
 't01_smooth_or_features_a02_features_or_disk_flag',
 't01_smooth_or_features_a03_star_or_artifact_count',
 't01_smooth_or_features_a03_star_or_artifact_weight',
 't01_smooth_or_features_a03_star_or_artifact_fraction',
 't01_smooth_or_features_a03_star_or_artifact_weighted_fractio

In [8]:
main_catalogue.shape

(243500, 233)

In [9]:
main_catalogue_merged = main_catalogue.merge(
    reference_images[['objid', 'asset_id']], 
    left_on='dr7objid', 
    right_on='objid', 
    how='left'
).drop(columns=['objid'])  # Drop extra 'objid' column after merging
main_catalogue_merged = main_catalogue_merged.sort_values(by=['asset_id']).reset_index(drop=True)

In [10]:
main_catalogue_merged.shape

(243500, 234)

In [11]:
main_merged_nA = main_catalogue_merged[main_catalogue_merged['gz2class'] != 'A'].copy()
main_merged_nA.shape

(243253, 234)

In [None]:
main_runs = pd.DataFrame()
main_runs['asset_id'] = main_merged_nA['asset_id']
main_runs['class_reference'] = main_merged_nA['gz2class']
main_runs['a01_smooth'] = main_merged_nA['t01_smooth_or_features_a01_smooth_debiased']
main_runs['a02_features_disk'] = main_merged_nA['t01_smooth_or_features_a02_features_or_disk_debiased']
main_runs['a08_spiral'] = main_merged_nA['t04_spiral_a08_spiral_debiased']
main_runs['a04_edgeon_yes'] = main_merged_nA['t02_edgeon_a04_yes_debiased']
main_runs['a04_edgeon_no'] = main_merged_nA['t02_edgeon_a05_no_debiased']


main_runs['a16_completely_round'] = main_merged_nA['t07_rounded_a16_completely_round_debiased']
main_runs['a17_in_between'] = main_merged_nA['t07_rounded_a17_in_between_debiased']
main_runs['a18_cigar_shaped'] = main_merged_nA['t07_rounded_a18_cigar_shaped_debiased']

main_runs['a06_bar'] = main_merged_nA['t03_bar_a06_bar_debiased']
main_runs['a07_no_bar'] = main_merged_nA['t03_bar_a07_no_bar_debiased']

main_runs['a25_round_bulge'] = main_merged_nA['t09_bulge_shape_a25_rounded_debiased']
main_runs['a26_boxy_bulge'] = main_merged_nA['t09_bulge_shape_a26_boxy_debiased']
main_runs['a27_no_bulge'] = main_merged_nA['t09_bulge_shape_a27_no_bulge_debiased']

main_runs.head()

Unnamed: 0,asset_id,class_reference,a01_smooth,a02_features_disk,a08_spiral,a04_edgeon_yes,a04_edgeon_no,a16_completely_round,a17_in_between,a18_cigar_shaped,a06_bar,a07_no_bar,a25_round_bulge,a26_boxy_bulge,a27_no_bulge
0,3,Ei,0.539134,0.518162,0.0,0.095926,0.902379,0.313218,0.621107,0.0,0.0,1.0,1.0,0.0,0.0
1,4,Sc,0.51996,0.601431,0.166973,0.0,1.0,0.318003,0.73608,0.0,0.13372,0.863893,0.0,0.0,0.0
2,5,Er,0.816,0.038,0.0,0.0,1.0,0.93,0.07,0.0,0.0,1.0,0.0,0.0,0.0
3,6,Er,0.621229,0.447836,0.077496,0.0,1.0,0.910455,0.088848,0.0,0.0,1.0,0.0,0.0,0.0
4,7,Ei,0.573432,0.485706,0.0,0.0,1.0,0.371883,0.595642,0.0,0.0,1.0,0.0,0.0,0.0


In [13]:
def run1_soft_labels(row):
    # E 
    p_e = row["a01_smooth"]

    # S
    p_s = (row["a02_features_disk"] * row["a04_edgeon_no"] * row['a08_spiral'])

    # Se
    p_se = (row["a02_features_disk"] * row['a04_edgeon_yes'])

    # Normalize
    total = p_e + p_s + p_se
    if total == 0:
        return np.array([1.0, 0.0, 0.0])  # fallback: assume elliptical

    return np.array([p_e, p_s, p_se]) / total

In [14]:
main_runs.head()

Unnamed: 0,asset_id,class_reference,a01_smooth,a02_features_disk,a08_spiral,a04_edgeon_yes,a04_edgeon_no,a16_completely_round,a17_in_between,a18_cigar_shaped,a06_bar,a07_no_bar,a25_round_bulge,a26_boxy_bulge,a27_no_bulge
0,3,Ei,0.539134,0.518162,0.0,0.095926,0.902379,0.313218,0.621107,0.0,0.0,1.0,1.0,0.0,0.0
1,4,Sc,0.51996,0.601431,0.166973,0.0,1.0,0.318003,0.73608,0.0,0.13372,0.863893,0.0,0.0,0.0
2,5,Er,0.816,0.038,0.0,0.0,1.0,0.93,0.07,0.0,0.0,1.0,0.0,0.0,0.0
3,6,Er,0.621229,0.447836,0.077496,0.0,1.0,0.910455,0.088848,0.0,0.0,1.0,0.0,0.0,0.0
4,7,Ei,0.573432,0.485706,0.0,0.0,1.0,0.371883,0.595642,0.0,0.0,1.0,0.0,0.0,0.0


In [13]:
main_runs.to_csv('../input/main_runs.csv', index=False)

In [15]:
soft_label_dict1 = {
    int(row["asset_id"]): run1_soft_labels(row)
    for _, row in main_runs.iterrows()
    }

In [16]:
def run2_soft_labels(row):

    pr = (row["a01_smooth"] * row['a16_completely_round'])

    pi = (row["a01_smooth"] * row['a17_in_between'])

    pc = (row["a01_smooth"] * row['a18_cigar_shaped'])

    # -----

    pBar = (row["a02_features_disk"] * row["a04_edgeon_no"] * row['a08_spiral'] * row['a06_bar'])

    pnoBar = (row["a02_features_disk"] * row["a04_edgeon_no"] * row['a08_spiral'] * row['a07_no_bar'])

    # -----

    pSeBulge = (row["a02_features_disk"] * row['a04_edgeon_yes'] * row['a25_round_bulge'] + row["a02_features_disk"] * row['a04_edgeon_yes'] * row["a26_boxy_bulge"])

    pSenoB = (row["a02_features_disk"] * row['a04_edgeon_yes'] * row["a27_no_bulge"])

    # Normalize
    total = pr + pi + pc + pBar + pnoBar + pSeBulge + pSenoB
    if total == 0:
        return np.array([0.0,0.0,0.0,0.0,0.0])

    return np.array([pr, pi, pc, pBar, pnoBar, pSeBulge, pSenoB]) / total

In [17]:
soft_label_dict2 = {
    int(row["asset_id"]): run2_soft_labels(row)
    for _, row in main_runs.iterrows()
    }

In [18]:
main_runs.head()

Unnamed: 0,asset_id,class_reference,a01_smooth,a02_features_disk,a08_spiral,a04_edgeon_yes,a04_edgeon_no,a16_completely_round,a17_in_between,a18_cigar_shaped,a06_bar,a07_no_bar,a25_round_bulge,a26_boxy_bulge,a27_no_bulge
0,3,Ei,0.539134,0.518162,0.0,0.095926,0.902379,0.313218,0.621107,0.0,0.0,1.0,1.0,0.0,0.0
1,4,Sc,0.51996,0.601431,0.166973,0.0,1.0,0.318003,0.73608,0.0,0.13372,0.863893,0.0,0.0,0.0
2,5,Er,0.816,0.038,0.0,0.0,1.0,0.93,0.07,0.0,0.0,1.0,0.0,0.0,0.0
3,6,Er,0.621229,0.447836,0.077496,0.0,1.0,0.910455,0.088848,0.0,0.0,1.0,0.0,0.0,0.0
4,7,Ei,0.573432,0.485706,0.0,0.0,1.0,0.371883,0.595642,0.0,0.0,1.0,0.0,0.0,0.0


In [19]:
from scipy.stats import entropy

In [20]:
def get_label_entropy(soft_label, num):
    max_entropy = np.log2(num)
    return (entropy(soft_label, base=2)) / max_entropy if max_entropy > 0 else 0

In [21]:
def section_spurious(soft_label_dict, num, entropy_threshold=0.7):
    confident = {}
    spurious = {}

    for asset_id, label in soft_label_dict.items():
        if get_label_entropy(label, num) > entropy_threshold:
            spurious[asset_id] = label
        else:
            confident[asset_id] = label

    return confident, spurious

In [22]:
soft_run1_conf, soft_run1_spur = section_spurious(soft_label_dict1, num=3)

In [23]:
len(soft_run1_conf)

204762

In [24]:
def create_hard_labels(labels_dict):
    hard_labels = {}
    for asset_id, label in labels_dict.items():
        hard_labels[asset_id] = int(np.argmax(label))
    return hard_labels

In [25]:
hard1 = create_hard_labels(soft_run1_conf)
len(hard1)

204762

In [26]:
print(Counter(list(hard1.values())))

Counter({0: 130557, 1: 56602, 2: 17603})


In [27]:
soft_run2_conf, soft_run2_spur = section_spurious(soft_label_dict2, num=7)

In [28]:
len(soft_run2_conf)

224287

In [29]:
hard2 = create_hard_labels(soft_run2_conf)
len(hard2)

224287

In [30]:
print(Counter(list(hard2.values())))

Counter({1: 69767, 0: 58485, 4: 51702, 5: 14903, 3: 13821, 2: 9566, 6: 6043})


In [31]:
def create_file_list(imgs_path, label_dict1, label_dict2):
    file_list = glob.glob(os.path.join(imgs_path, '*.jpg'))
    file_list = sorted(file_list)

    file_list = [(f, int(f.split('/')[-1].split('.')[0])) for f in file_list if (int(f.split('/')[-1].split('.')[0]) in label_dict1) and (int(f.split('/')[-1].split('.')[0]) in label_dict2)]

    return file_list

In [32]:
imgs_path = '../input/images_gz2/images/'


In [33]:
conf_file_list = create_file_list(imgs_path, soft_run1_conf, soft_run2_conf)
len(conf_file_list)

203488

In [34]:
conf_file_list

[('../input/images_gz2/images/100.jpg', 100),
 ('../input/images_gz2/images/1000.jpg', 1000),
 ('../input/images_gz2/images/10000.jpg', 10000),
 ('../input/images_gz2/images/100000.jpg', 100000),
 ('../input/images_gz2/images/100002.jpg', 100002),
 ('../input/images_gz2/images/100004.jpg', 100004),
 ('../input/images_gz2/images/100006.jpg', 100006),
 ('../input/images_gz2/images/100007.jpg', 100007),
 ('../input/images_gz2/images/100010.jpg', 100010),
 ('../input/images_gz2/images/100011.jpg', 100011),
 ('../input/images_gz2/images/100012.jpg', 100012),
 ('../input/images_gz2/images/100013.jpg', 100013),
 ('../input/images_gz2/images/100016.jpg', 100016),
 ('../input/images_gz2/images/100020.jpg', 100020),
 ('../input/images_gz2/images/100021.jpg', 100021),
 ('../input/images_gz2/images/100022.jpg', 100022),
 ('../input/images_gz2/images/100023.jpg', 100023),
 ('../input/images_gz2/images/100025.jpg', 100025),
 ('../input/images_gz2/images/100026.jpg', 100026),
 ('../input/images_gz2/i

In [49]:
print(Counter(list(runs.values())))

In [56]:
def data_setup(file_list, labels_dict, n):
    runs = {}

    for f in file_list:
        asset_id = f[1]
        label_val = labels_dict.get(asset_id, None) # get the label value
        runs[f[0]] = label_val # connect the filename and the label value

    print(Counter(list(runs.values())))

    images_orig = [x for x in runs]
    labels_orig = [runs[x] for x in runs]
    
    pairs = [(images_orig[x],labels_orig[x]) for x in range(len(images_orig))]

    print(pairs[:4])

    label0 = [x for x in pairs if x[1]==0]
    label1 = [x for x in pairs if x[1]==1]
    label2 = [x for x in pairs if x[1]==2]
    label3 = [x for x in pairs if x[1]==3]
    label4 = [x for x in pairs if x[1]==4]
    label5 = [x for x in pairs if x[1]==5]
    label6 = [x for x in pairs if x[1]==6]

    print(len(label0), len(label1), len(label2), len(label3), len(label4), len(label5), len(label6))

    label0_selection = random.sample(label0, n-500)
    label1_selection = random.sample(label1, n-500)
    label2_selection = random.sample(label2, n-500)
    label3_selection = random.sample(label3, n)
    label4_selection = random.sample(label4, n)
    label5_selection = random.sample(label5, n)
    label6_selection = random.sample(label6, n)

    pairs_rand = label0_selection + label1_selection + label2_selection + label3_selection + label4_selection + label5_selection + label6_selection

    images_orig = [x[0] for x in pairs_rand]
    labels_orig = [x[1] for x in pairs_rand]

    return images_orig, labels_orig

In [59]:
n = 5000

In [60]:
images_orig, labels_orig = data_setup(conf_file_list, hard2, n)

Counter({1: 63971, 0: 57249, 4: 44374, 3: 12825, 5: 12442, 2: 7602, 6: 5025})
[('../input/images_gz2/images/100.jpg', 0), ('../input/images_gz2/images/1000.jpg', 0), ('../input/images_gz2/images/10000.jpg', 1), ('../input/images_gz2/images/100000.jpg', 5)]
57249 63971 7602 12825 44374 12442 5025


In [61]:
print(Counter(labels_orig))

Counter({3: 5000, 4: 5000, 5: 5000, 6: 5000, 0: 4500, 1: 4500, 2: 4500})


In [62]:
def split_data(x, y):
    x_train, x_rem, y_train, y_rem = train_test_split(x, y, train_size=0.7, random_state=42, 
    stratify=y, 
    shuffle=True)

    x_valid, x_test, y_valid, y_test = train_test_split(x_rem, y_rem, test_size=0.34, random_state=42, 
    stratify=y_rem, 
    shuffle=True)

    print(len(x_train), len(x_valid), len(x_test))

    print(x_train[:5], y_train[:5])

    return x_train, x_valid, x_test, y_train, y_valid, y_test

In [63]:
traino, valido, testo, y_traino, y_valido, y_testo = split_data(images_orig, labels_orig)

23450 6632 3418
['../input/images_gz2/images/78735.jpg', '../input/images_gz2/images/181010.jpg', '../input/images_gz2/images/13038.jpg', '../input/images_gz2/images/40631.jpg', '../input/images_gz2/images/261941.jpg'] [6, 6, 1, 5, 6]


In [64]:
23450+6632+3418

33500

In [65]:
print(Counter(y_traino))

Counter({6: 3500, 5: 3500, 3: 3500, 4: 3500, 1: 3150, 2: 3150, 0: 3150})


In [67]:
bs = 32

In [None]:
#train_dl, valid_dl, test_dl, y_train, y_valid, y_test = create_data_loaders(traino, valido, testo, hard1, soft_run1_conf, bs, aux_train=None, aux_valid=None, aux_test=None)

: 