In [1]:
# !find ../images-legacy-dr9 -type f -print0 | xargs -0 ls -l | awk '{size[int(log($5)/log(2))]++}END{for (i in size) printf("%10d %3d\n", 2^i, size[i])}' | sort -n

         0 1810110
       512   6
      1024 22624
      2048 778
      4096 2983
      8192 66819
     16384 4363249
     32768  45


# Test with hdxresnet34

In [2]:
from fastai2.basics import *
from fastai2.vision.all import *
from pathlib import Path

from mish_cuda import MishCuda
import cmasher as cmr
import gc
import scipy.cluster.hierarchy as hcluster

from sklearn.model_selection import KFold
from sklearn.cluster import KMeans
from sklearn.metrics import confusion_matrix, roc_curve

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

import tqdm.notebook as tqdm

seed = 256

%matplotlib inline

In [4]:
import sys
# Project root: the parent of the current working directory, made absolute.
PATH = Path('..').resolve()

# Make the project's `src/` folder importable, then pull in its helpers.
sys.path.append(f'{PATH}/src')
from utils import *

# Metric objects for a Learner.
# NOTE(review): not passed to the Learner built later in this notebook — confirm whether still needed.
learner_metrics = [accuracy, F1Score(), Recall(), Precision()]

# NOTE(review): absolute, machine-specific path to a sibling repository; this cell
# will not import `xresnet_deconv` on another machine unless the same layout exists.
# Presumably this module provides `xresnet34_hybrid` used below — confirm.
sys.path.append('/home/jupyter/morphological-spectra/src')
from xresnet_deconv import *

In [5]:

# https://github.com/fastai/fastai/blob/master/fastai/losses.py#L48
class FocalLossFlat(CrossEntropyLossFlat):
    """
    Focal-loss variant of CrossEntropyLossFlat, parameterised by the focal parameter `gamma`.

    Focal loss was introduced by Lin et al., https://arxiv.org/pdf/1708.02002.pdf. The
    unreduced cross entropy is scaled by `(1 - pt)**gamma` with `pt = exp(-cross_entropy)`.
    The class weighting factor from the paper, alpha, can be supplied through the pytorch
    `weight` argument of nn.CrossEntropyLoss.
    """
    y_int = True

    @use_kwargs_dict(keep=True, weight=None, ignore_index=-100, reduction='mean')
    def __init__(self, *args, gamma=2, axis=-1, **kwargs):
        self.gamma = gamma
        # Keep the caller's reduction for ourselves: the wrapped loss must stay
        # unreduced so the focal factor can be applied element-wise first.
        self.reduce = kwargs.pop('reduction', 'mean')
        super().__init__(*args, reduction='none', axis=axis, **kwargs)

    def __call__(self, inp, targ, **kwargs):
        per_item_ce = super().__call__(inp, targ, **kwargs)
        prob_true = torch.exp(-per_item_ce)
        focal = (1 - prob_true) ** self.gamma * per_item_ce
        # Apply the reduction requested at construction time; anything other
        # than 'mean' or 'sum' returns the unreduced loss.
        if self.reduce == 'mean':
            return focal.mean()
        if self.reduce == 'sum':
            return focal.sum()
        return focal


In [6]:
# Load the SAGA redshift catalogue, reading OBJID as a string.
saga = pd.read_csv(f'{PATH}/data/saga_redshifts_2021-02-19.csv', dtype={'OBJID': str})

# Shuffle all rows reproducibly, then attach the binary target and a constant flag.
df = saga.sample(frac=1, random_state=seed).assign(
    low_z=lambda shuffled: shuffled.SPEC_Z < 0.03,
    # in order to use with previous utils
    SPEC_FLAG=1,
)

In [7]:
# Name of the boolean target column added to `df` (SPEC_Z < 0.03).
label_column = 'low_z'

# focal loss weight params
# `gamma` is the exponent FocalLossFlat applies as (1 - pt)**gamma.
gamma = 2
loss_func = FocalLossFlat(gamma=gamma)

In [8]:
# Image side length handed to Resize, and the DataLoaders batch size.
sz = 144
bs = 128

# NOTE(review): the DataBlock cell hard-codes this same folder name instead of reading `img_dir`.
img_dir = 'images-legacy_saga-2021-02-19'

# Two length-3 arrays unpacked into Normalize.from_stats below
# (presumably per-channel mean and std of the Legacy Survey cutouts — confirm).
legacy_image_stats = [np.array([0.14814416, 0.14217226, 0.13984123]), np.array([0.0881476 , 0.07823102, 0.07676626])]

# Per-item resize, then batch-level augmentation followed by normalization.
item_tfms = [Resize(sz)]
batch_tfms = (
    aug_transforms(max_zoom=1., flip_vert=True, max_lighting=0., max_warp=0.) + 
    [Normalize.from_stats(*legacy_image_stats)]
)

# NOTE(review): `seed` is already set to 256 in the imports cell; this repeats it.
seed = 256  

In [9]:

# DataBlock over the SAGA table: the image path is built from OBJID, the label is
# read from `label_column`.
dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    # Reuse `img_dir` / `label_column` from the config cells rather than repeating
    # their literal values, so changing the config actually takes effect here.
    get_x=ColReader(['OBJID'], pref=f'{PATH}/{img_dir}/', suff='.jpg'),
    get_y=ColReader(label_column),
    # valid_pct=0: no rows are held out for validation.
    splitter=RandomSplitter(valid_pct=0),
    item_tfms=item_tfms,
    batch_tfms=batch_tfms,
)

dls = ImageDataLoaders.from_dblock(dblock, df, path=PATH, bs=bs)

# Rebuild the network with the same constructor arguments, wrap it in a Learner,
# and load previously saved weights for image size `sz`.
model = xresnet34_hybrid(n_out=2, sa=True, act_cls=MishCuda, groups=64, reduction=8)
learn = Learner(
    dls, model,
    opt_func=ranger,
    loss_func=loss_func,
)
learn.load(f'{PATH}/models/desi-sv_FL-hdxresnet34-sz{sz}');


In [10]:
# Recursively collect the DR9 JPEGs, keeping only files of at least 4096 bytes
# (the size histogram at the top of the notebook shows many zero-byte files).
filenames = [
    jpg_path
    for jpg_path in (PATH/'images-legacy-dr9').rglob('*.jpg')
    if jpg_path.stat().st_size >= 4096
]

In [27]:
# Build a test DataLoader over the DR9 files from `dls`, using a smaller batch size
# (32) and 8 worker processes.
test_dl = dls.test_dl(filenames, num_workers=8, bs=32)

**Note:** the inference cell below takes more than 4 hours to run over the full DR9 image set.

In [30]:
# Switch the network to evaluation mode before running inference.
m = learn.model.eval()

# Run every test batch through the model without tracking gradients, moving each
# batch of raw outputs to the CPU.
with torch.no_grad():
    outputs = [
        m(xb).cpu()
        for (xb,) in tqdm.tqdm(test_dl, total=len(test_dl))
    ]

# Stack the per-batch outputs into a single tensor.
outs = torch.cat(outputs)

HBox(children=(FloatProgress(value=0.0, max=138535.0), HTML(value='')))




## Save

In [38]:
filenames[:5]

[Path('/home/jupyter/xSAGA/images-legacy-dr9/1237658629697765559.jpg'),
 Path('/home/jupyter/xSAGA/images-legacy-dr9/1237662336799277285.jpg'),
 Path('/home/jupyter/xSAGA/images-legacy-dr9/1237654870528753879.jpg'),
 Path('/home/jupyter/xSAGA/images-legacy-dr9/1237655124480230147.jpg'),
 Path('/home/jupyter/xSAGA/images-legacy-dr9/1237664094510711200.jpg')]

In [None]:
# Catalogue table; the first CSV column becomes the index.
xsaga = pd.read_csv(PATH/'data/xSAGA_SDSS_all.csv', index_col=0)

# only keep relevant ones
# Each file stem is parsed as a 64-bit integer object ID.
objIDs = np.array([image_path.stem for image_path in filenames], dtype=np.int64)


In [43]:
# Turn the raw model outputs into per-class probabilities (softmax over dim 1).
ps = torch.softmax(outs, dim=1)

In [44]:
# One row per image, indexed by object ID; `p_CNN` is column 1 of the softmax output.
# NOTE(review): presumably class index 1 corresponds to low_z == True — confirm against dls.vocab.
preds = pd.DataFrame({'p_CNN': ps[:, 1]}, index=objIDs)

In [45]:
preds

Unnamed: 0,p_CNN
1237658629697765559,0.032666
1237662336799277285,0.062886
1237654870528753879,0.120395
1237655124480230147,0.081715
1237664094510711200,0.071647
...,...
1237667783910621435,0.064162
1237665584860954694,0.308611
1237653618544476213,0.076890
1237662225679123024,0.120305


In [65]:
# Save the raw CNN probabilities. At this point `preds` is indexed by the object IDs
# but the index has no name, so label it explicitly; otherwise the first CSV column
# is written with an empty header.
preds.to_csv(PATH/'results/predictions-dr9_only-preds.csv', index_label='objID')

In [51]:
# Move the object IDs out of the index into a regular column named `objID`.
preds = preds.reset_index().rename(columns={'index': 'objID'})

In [61]:
# Join the catalogue columns from `xsaga` onto the predictions: `on='objID'` matches
# the `objID` column against `xsaga`'s index (the first column of its CSV).
# The extra `reset_index()` adds a positional `index` column, which the next cell drops.
combined = preds.reset_index().join(xsaga, on='objID')

In [62]:
# Discard the positional `index` column and key the table by `objID`.
combined = combined.drop(columns='index').set_index('objID')

In [63]:
combined

Unnamed: 0_level_0,p_CNN,ra,dec,g0,r0,R_eff
objID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
1237658629697765559,0.032666,192.754583,12.250387,21.51018,20.93348,1.033137
1237662336799277285,0.062886,254.266444,21.339830,19.22336,18.58278,2.300709
1237654870528753879,0.120395,137.319777,-2.669010,19.43149,18.83373,1.551386
1237655124480230147,0.081715,202.515145,4.021077,21.30742,20.96608,1.072581
1237664094510711200,0.071647,131.218050,27.566439,20.69109,20.29162,1.135022
...,...,...,...,...,...,...
1237667783910621435,0.064162,190.218533,23.268112,21.18684,20.60037,1.422956
1237665584860954694,0.308611,246.568907,58.398485,16.87466,16.33039,4.533244
1237653618544476213,0.076890,158.390940,60.321651,20.23501,19.67139,1.498932
1237662225679123024,0.120305,188.081897,40.583701,20.62520,20.48272,1.916257


In [64]:
# Write the predictions joined with the catalogue columns to disk.
combined.to_csv(PATH/'results/predictions-dr9.csv')