# Importing the data and creating the data loaders

In [2]:
# Automatically re-import edited modules (e.g. the ../src helpers) before each cell executes.
%load_ext autoreload
%autoreload 2

In [3]:
# Configure PyTorch's CUDA caching allocator through its environment variable.
# It is set here, before torch is imported in the next cell, so it is already in
# place when PyTorch initialises CUDA.
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [4]:
# importing libraries

import glob
import sys
from zipfile import ZipFile 
import concurrent.futures
import gc
from time import time
import cv2

# Make the project modules imported in a later cell (data_processing, cvt,
# model_train) importable; presumably they live in ../src.
sys.path.insert(0,'../src/')

import PIL as pil

import pandas as pd
import numpy as np
# Seeds NumPy's legacy global RNG only. Python's `random` module and torch have
# their own generators and are not affected by this call.
np.random.seed(42)
import random

import matplotlib.pyplot as plt

import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
from tqdm import tqdm
import urllib.request
import wandb
# Authenticate with Weights & Biases (the output below shows an existing login being reused).
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33memadonev[0m ([33memadonev-xv-gimnazija[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [5]:
# Project helpers. These star imports supply names used below that are not
# defined in this notebook (create_file_list, data_setup, choose_class1 ...
# choose_class5); which module provides which name is not visible here.
# TODO(review): replace with explicit imports so name origins are traceable.
from data_processing import *
from cvt import *
from model_train import *

---

In [6]:
# objid -> asset_id mapping for the image files (columns previewed below).
reference_images = pd.read_csv(filepath_or_buffer='../input/filename_mapping.csv')
# Galaxy Zoo 2 catalogue: ids, coordinates, gz2class and per-question vote columns.
main_catalogue = pd.read_csv(filepath_or_buffer='../input/gz2_classes.csv')

In [7]:
reference_images.head(n=5)  # first rows of the objid -> asset_id mapping

Unnamed: 0,objid,sample,asset_id
0,587722981736120347,original,1
1,587722981736579107,original,2
2,587722981741363294,original,3
3,587722981741363323,original,4
4,587722981741559888,original,5


In [8]:
main_catalogue.head(n=5)  # wide table: pandas elides the middle columns in the preview

Unnamed: 0,specobjid,dr8objid,dr7objid,ra,dec,rastring,decstring,sample,gz2class,total_classifications,...,t11_arms_number_a36_more_than_4_fraction,t11_arms_number_a36_more_than_4_weighted_fraction,t11_arms_number_a36_more_than_4_debiased,t11_arms_number_a36_more_than_4_flag,t11_arms_number_a37_cant_tell_count,t11_arms_number_a37_cant_tell_weight,t11_arms_number_a37_cant_tell_fraction,t11_arms_number_a37_cant_tell_weighted_fraction,t11_arms_number_a37_cant_tell_debiased,t11_arms_number_a37_cant_tell_flag
0,1.802675e+18,,588017703996096547,160.9904,11.70379,10:43:57.70,+11:42:13.6,original,SBb?t,44,...,0.225,0.225,0.225,0,10,10.0,0.25,0.25,0.25,0
1,1.992984e+18,,587738569780428805,192.41083,15.164207,12:49:38.60,+15:09:51.1,original,Ser,45,...,0.0,0.0,0.0,0,0,0.0,0.0,0.0,0.0,0
2,1.489569e+18,,587735695913320507,210.8022,54.348953,14:03:12.53,+54:20:56.2,original,Sc+t,46,...,0.651,0.651,0.651,0,3,3.0,0.07,0.07,0.07,0
3,2.924084e+18,1.237668e+18,587742775634624545,185.30342,18.382704,12:21:12.82,+18:22:57.7,original,SBc(r),45,...,0.071,0.071,0.071,0,6,6.0,0.429,0.429,0.429,0
4,1.387165e+18,1.237658e+18,587732769983889439,187.36679,8.749928,12:29:28.03,+08:44:59.7,extra,Ser,49,...,0.0,0.0,0.0,0,1,1.0,1.0,1.0,1.0,0


In [9]:
# Reduced catalogue: SDSS DR7 object id plus the raw gz2class morphology string,
# sharing main_catalogue's row index.
model_01_catalogue = pd.DataFrame({
    'dr7ID': main_catalogue['dr7objid'],
    'class': main_catalogue['gz2class'],
})
print(model_01_catalogue.shape)
model_01_catalogue.head()

(243500, 2)


Unnamed: 0,dr7ID,class
0,588017703996096547,SBb?t
1,587738569780428805,Ser
2,587735695913320507,Sc+t
3,587742775634624545,SBc(r)
4,587732769983889439,Ser


In [10]:
# Remove rows whose gz2class is exactly 'A' (NOTE(review): presumably a
# non-galaxy/artefact label; confirm against the GZ2 documentation).
# Rebinding instead of drop(..., inplace=True) avoids silent in-place mutation of
# a frame an earlier cell displayed; NaN classes compare unequal and are kept.
is_class_a = model_01_catalogue['class'] == 'A'
model_01_catalogue = model_01_catalogue.drop(index=model_01_catalogue.index[is_class_a.to_numpy()])
model_01_catalogue.shape

(243253, 2)

In [11]:
# Connect each catalogued galaxy to its image asset_id (via the DR7 object id),
# order rows by asset_id, and normalise the class strings.
#
# An asset_id column left over from a previous run of this cell is dropped first.
# Without this, re-running the cell would merge a second asset_id, producing
# asset_id_x / asset_id_y columns, and sort_values(by=['asset_id']) would raise KeyError.
catalogue_before_merge = model_01_catalogue.drop(columns=['asset_id'], errors='ignore')
model_01_catalogue = (
    catalogue_before_merge
    .merge(
        reference_images[['objid', 'asset_id']],
        left_on='dr7ID',
        right_on='objid',
        how='left',
    )
    .drop(columns=['objid'])  # objid duplicates dr7ID after the merge
    .sort_values(by=['asset_id'])
    .reset_index(drop=True)
)
# A left merge duplicates a galaxy whenever its objid appears more than once in
# the mapping; fail loudly instead of silently inflating the dataset.
assert len(model_01_catalogue) == len(catalogue_before_merge), \
    "objid is duplicated in reference_images for some dr7ID"

# Strip parentheses and right-pad with '0' to at least 6 characters,
# e.g. 'SBc(r)' -> 'SBcr00', 'Ei' -> 'Ei0000'. Longer strings are not truncated.
model_01_catalogue['class'] = (
    model_01_catalogue['class']
    .str.replace('(', '', regex=False)
    .str.replace(')', '', regex=False)
    .str.ljust(6, fillchar='0')
)
model_01_catalogue.head()

Unnamed: 0,dr7ID,class,asset_id
0,587722981741363294,Ei0000,3
1,587722981741363323,Sc0000,4
2,587722981741559888,Er0000,5
3,587722981741625481,Er0000,6
4,587722981741625484,Ei0000,7


In [12]:
# Label-diagram table: one column per hierarchy level (r1..r5), each produced by
# the matching choose_classN helper from the normalised class string, followed by
# the image asset_id. Rows share model_01_catalogue's index.
level_choosers = (choose_class1, choose_class2, choose_class3, choose_class4, choose_class5)
label_diagram = pd.DataFrame({
    f'r{level}': model_01_catalogue['class'].apply(chooser)
    for level, chooser in enumerate(level_choosers, start=1)
})
label_diagram['asset_id'] = model_01_catalogue['asset_id']
label_diagram.head(10)

Unnamed: 0,r1,r2,r3,r4,r5,asset_id
0,E,is,0,0,0,3
1,S,c,0,0,0,4
2,E,rs,0,0,0,5
3,E,rs,0,0,0,6
4,E,is,0,0,0,7
5,E,is,0,0,0,8
6,E,rs,0,0,0,9
7,E,rs,0,0,0,11
8,S,c,0,0,0,12
9,E,cs,0,0,0,13


In [13]:
# Persist label_diagram (without the pandas index) next to the other inputs.
label_diagram.to_csv(path_or_buf="../input/label_diagram.csv", index=False)

In [14]:
# Benchmark label per galaxy: the first two hierarchy levels concatenated
# (e.g. 'E' + 'is' -> 'Eis'), in label_diagram row order.
# Vectorised Series concatenation replaces a per-row label lookup
# (label_diagram["r1"][x]). That lookup was slow on ~243k rows and only correct
# while the index happened to be a 0..n-1 RangeIndex.
labels_bench = (label_diagram["r1"] + label_diagram["r2"]).tolist()
labels_bench[:10]

['Eis', 'Sc', 'Ers', 'Ers', 'Eis', 'Eis', 'Ers', 'Ers', 'Sc', 'Ecs']

In [15]:
# Number of distinct r1+r2 benchmark labels.
distinct_labels = {*labels_bench}
unique_count = len(distinct_labels)
print(unique_count)

14


In [16]:
# asset_id -> benchmark label. Relies on labels_bench being in label_diagram row
# order; keys stay NumPy scalars because the ids are read from the NumPy array.
label_mapping = dict(zip(label_diagram['asset_id'].to_numpy(), labels_bench))

---

In [16]:
# Directory holding the Galaxy Zoo 2 JPEGs, and the image geometry constants.
imgs_path = '../input/images_gz2/images/'
W = H = 224  # width / height in pixels (presumably the model input size)
C = 4        # NOTE(review): JPEGs normally decode to 3 (RGB) channels; confirm 4 is intended

In [17]:
# Build the list of image files for the catalogued galaxies.
# NOTE(review): create_file_list comes from a star import; from this call it
# presumably matches files under imgs_path against label_diagram. Confirm.
file_list = create_file_list(imgs_path, label_diagram)
print('file list loaded')

file list loaded


In [23]:
n = 10_000  # per-class sample size for the balanced subset below; also passed to data_setup

In [24]:
# Produce parallel lists of image paths and their labels (paired in the next cell).
# NOTE(review): the role of `n` inside data_setup is not visible here. The class
# counts below sum to all 243253 images, so it does not cap the number returned.
images_orig, labels_orig = data_setup(file_list, label_diagram, n)

243253 243253
['../input/images_gz2/images/100.jpg', '../input/images_gz2/images/1000.jpg']
['S', 'E']


In [25]:
# Pair every image path with its label. The two lists must be index-aligned.
# Previously, extra labels were silently ignored (and missing ones raised
# IndexError halfway through), so a misalignment could go unnoticed.
assert len(images_orig) == len(labels_orig), \
    "images_orig and labels_orig differ in length"
pairs = list(zip(images_orig, labels_orig))
pairs[:5]

[('../input/images_gz2/images/100.jpg', 'S'),
 ('../input/images_gz2/images/1000.jpg', 'E'),
 ('../input/images_gz2/images/10000.jpg', 'E'),
 ('../input/images_gz2/images/100000.jpg', 'Se'),
 ('../input/images_gz2/images/100001.jpg', 'E')]

In [28]:
# One list of (path, label) pairs per coarse class:
# label0 = 'E', label1 = 'S', label2 = 'SB', label3 = 'Se'.
coarse_classes = ('E', 'S', 'SB', 'Se')
label0, label1, label2, label3 = (
    [pair for pair in pairs if pair[1] == cls] for cls in coarse_classes
)
for group in (label0, label1, label2, label3):
    print(len(group))

103515
94332
21402
24004


In [31]:
# Draw n pairs per coarse class, without replacement, to build a balanced subset.
# np.random.seed(...) does not seed Python's `random` module, so random.sample
# would return a different subset after every kernel restart. A dedicated, seeded
# Random instance makes the selection reproducible and this cell idempotent.
# random.Random.sample raises ValueError if a class has fewer than n pairs.
sampler = random.Random(42)
label0_selection, label1_selection, label2_selection, label3_selection = (
    sampler.sample(group, n) for group in (label0, label1, label2, label3)
)
for selection in (label0_selection, label1_selection, label2_selection, label3_selection):
    print(len(selection), selection[:3])

10000 [('../input/images_gz2/images/259734.jpg', 'E'), ('../input/images_gz2/images/19826.jpg', 'E'), ('../input/images_gz2/images/65764.jpg', 'E')]
10000 [('../input/images_gz2/images/245872.jpg', 'S'), ('../input/images_gz2/images/53148.jpg', 'S'), ('../input/images_gz2/images/64301.jpg', 'S')]
10000 [('../input/images_gz2/images/71807.jpg', 'SB'), ('../input/images_gz2/images/91754.jpg', 'SB'), ('../input/images_gz2/images/134536.jpg', 'SB')]
10000 [('../input/images_gz2/images/198321.jpg', 'Se'), ('../input/images_gz2/images/255406.jpg', 'Se'), ('../input/images_gz2/images/1404.jpg', 'Se')]


In [32]:
# Concatenate the per-class selections. The result is grouped by class
# ('E' block first, then 'S', 'SB', 'Se') and is not shuffled.
pairs_rand = [*label0_selection, *label1_selection, *label2_selection, *label3_selection]
print(len(pairs_rand), pairs_rand[:5])

40000 [('../input/images_gz2/images/259734.jpg', 'E'), ('../input/images_gz2/images/19826.jpg', 'E'), ('../input/images_gz2/images/65764.jpg', 'E'), ('../input/images_gz2/images/21116.jpg', 'E'), ('../input/images_gz2/images/167749.jpg', 'E')]


In [33]:
# Unzip the balanced (path, label) pairs back into parallel lists.
# Note: this rebinds images_orig / labels_orig from the data_setup cell, so
# re-running earlier cells afterwards would operate on the balanced subset.
images_orig, labels_orig = [], []
for pair in pairs_rand:
    images_orig.append(pair[0])
    labels_orig.append(pair[1])

print(images_orig[:2], labels_orig[:2])

['../input/images_gz2/images/259734.jpg', '../input/images_gz2/images/19826.jpg'] ['E', 'E']


In [21]:
# FIXME(review): `labels` is not assigned in any cell shown here, and this cell's
# execution count (In [21]) is out of sequence. It presumably leaked from a
# deleted cell, an earlier kernel session, or a star import, so Restart & Run All
# may raise NameError here. Define it explicitly or remove this cell.
labels[1]

array([1., 1., 8., 4., 5., 9., 1., 3., 1., 5., 1., 4., 9., 3., 8., 1., 9.,
       1., 4., 0., 8., 5., 4., 4., 9., 7., 8., 5., 4., 4., 5., 4., 3., 1.,
       8., 4., 8., 5., 9., 1., 1., 9., 5., 5., 5., 8., 8., 1., 6., 8., 5.,
       8., 4., 5., 4., 7., 1., 8., 5., 4., 5., 9., 5., 5., 8., 8., 4., 8.,
       7., 5., 8., 3., 4., 5., 8., 9., 4., 8., 1., 4., 5., 1., 8., 4., 1.,
       5., 5., 9., 9., 1., 5., 1., 6., 9., 1., 5., 9., 1., 2., 5., 4., 4.,
       1., 1., 5., 1., 3., 1., 1., 1., 8., 1., 8., 5., 4., 1., 5., 1., 4.,
       7., 5., 4., 1., 4., 4., 8., 9., 3., 8., 4., 1., 7., 3., 1., 4., 5.,
       9., 5., 4., 4., 1., 4., 8., 9., 1., 1., 1., 8., 5., 8., 7., 6., 8.,
       5., 5., 1., 4., 5., 1., 7., 4., 1., 7., 8., 4., 4., 4., 1., 0., 5.,
       8., 5., 4., 1., 5., 5., 4., 1., 5., 3., 5., 9., 3., 1., 4., 4., 5.,
       1., 4., 4., 8., 4., 9., 9., 1., 4., 5., 9., 1., 9., 4., 1., 1., 1.,
       5., 3., 1., 5., 5., 8., 4., 1., 5., 5., 4., 8., 4., 4., 3., 5., 5.,
       5., 1., 5., 4., 5.

In [17]:
# Load saved model outputs (presumably test-set predictions, going by the filename).
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python objects,
# which can execute code. Only load files this project wrote. If the file holds a
# plain numeric array, allow_pickle=False is the safer setting.
outputs = np.load('../output/outputs_test.npy', allow_pickle=True)

In [18]:
# NOTE(review): the saved outputs below are almost all 2.0 (a handful of 0.0),
# which may indicate a collapsed classifier. Prefer a compact summary such as
# np.unique(outputs, return_counts=True) over dumping the whole array.
outputs

array([[2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 0., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 0., 2., 2., 2., 2., 2., 2., 0., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 0., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 0., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 

In [19]:
# Benchmark label -> integer class id, with ids assigned in sorted label order.
# NOTE(review): the keys are the 14 r1+r2 labels from labels_bench, not the four
# coarse labels ('E', 'S', 'SB', 'Se') sampled above. Confirm which label set
# the model is trained on before using this mapping to encode labels_orig.
sorted_labels = sorted(set(labels_bench))
class_mapping = dict(zip(sorted_labels, range(len(sorted_labels))))

In [20]:
# Display the label -> class-id mapping.
class_mapping

{'Ecs': 0,
 'Eis': 1,
 'Ers': 2,
 'SBa': 3,
 'SBb': 4,
 'SBc': 5,
 'SBd': 6,
 'Sa': 7,
 'Sb': 8,
 'Sc': 9,
 'Sd': 10,
 'Sebb': 11,
 'Sebn': 12,
 'Sebr': 13}