<a href="https://colab.research.google.com/github/nmcardoso/galmorpho/blob/master/sdss_stamps.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# SkyServer ImgCutout `opt` codes: the letter at position i of `p` is the
# code for the overlay/option named at position i of `f`.
# p = ["G", "L", "P", "S", "O", "B", "F", "M", "Q", "I", "A", "X"]
# f = ["Grid", "Label", "PhotoObjs", "SpecObjs", "Outline",
#     "BoundingBox", "Fields", "Masks", "Plates", "InvertImage", "APOGEE", "2MASS Images"]

In [0]:
# Example ImgCutout request URL: http://skyserver.sdss.org/dr15/SkyServerWS/ImgCutout/getjpeg?TaskName=Skyserver.Chart.Navi&ra=229.525575753922&dec=42.7458537608544&scale=0.3&width=300&height=300&opt=X

In [0]:
import requests
from IPython.display import Image, display
import pandas as pd
import shutil
import os
import tarfile
from google.colab import drive
import copy
import multiprocessing as mp
import progressbar

In [0]:
# Galaxy Zoo 2 catalogue (gz2_hart16); later cells use its `ra`, `dec` and
# `gz2_class` columns.
GZ2_CATALOG_URL = 'http://gz2hart.s3.amazonaws.com/gz2_hart16.csv.gz'
data = pd.read_csv(GZ2_CATALOG_URL, compression='gzip')

In [0]:
# Number of galaxies per GZ2 morphology class, most frequent first.
(
    data[['ra', 'dec', 'gz2_class']]
    .groupby('gz2_class')
    .agg(['count'])
    .sort_values(by=('ra', 'count'), ascending=False)
)

Unnamed: 0_level_0,ra,dec
Unnamed: 0_level_1,count,count
gz2_class,Unnamed: 1_level_2,Unnamed: 2_level_2
Ei,44038,44038
Er,36764,36764
Ser,14009,14009
Sc?t,13509,13509
Ec,10149,10149
...,...,...
SBd1l(d),1,1
SBa(d),1,1
SBd1l(m),1,1
SBd1l(o),1,1


In [0]:
# Show every class, not just the head/tail of the table. The row limit is
# raised only inside this `with` block (instead of a global pd.set_option),
# so later DataFrame displays such as `final_data` keep the default
# truncation instead of rendering every row.
with pd.option_context('display.max_rows', data.shape[0] + 1):
  display(
      data[['ra', 'gz2_class']]
      .groupby('gz2_class')
      .agg(['count'])
      .sort_values(('ra', 'count'), ascending=False)
  )

Unnamed: 0_level_0,ra
Unnamed: 0_level_1,count
gz2_class,Unnamed: 1_level_2
Ei,44038
Er,36764
Ser,14009
Sc?t,13509
Ec,10149
Sb,6932
SBc2m,5862
Sc,5776
Sb?t,5431
Sen,4450


In [0]:
# Keep only the five target morphologies, plus the coordinates needed to
# request a cutout for each galaxy.
valid_classes = ['Ei', 'Er', 'Ec', 'Ser', 'Sc2m']
is_valid = data['gz2_class'].isin(valid_classes)
final_data = data.loc[is_valid, ['ra', 'dec', 'gz2_class']]
final_data

Unnamed: 0,ra,dec,gz2_class
2,183.371979,50.741508,Ei
4,161.086395,14.084465,Er
5,246.921387,40.926968,Ei
6,249.474640,36.073040,Ei
8,195.278030,39.841473,Ei
...,...,...,...
239688,173.478195,28.623381,Ec
239689,125.736557,21.344851,Ec
239690,167.542648,28.991867,Ec
239692,21.690212,-0.546427,Ei


In [0]:
# Mount Google Drive; the dataset CSV and the final tarball live under it.
DRIVE_MOUNT_POINT = '/gdrive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount("/gdrive", force_remount=True).


In [0]:
# SkyServer DR15 JPEG cutout endpoint.
base_url = 'http://skyserver.sdss.org/dr15/SkyServerWS/ImgCutout/getjpeg'
# Local scratch directory for downloaded cutouts (one sub-folder per class).
downloads_path = '/content/downloads'
# CSV that load_dataset() reads into `data`.
dataset_path = '/gdrive/My Drive/ml_datasets/splus_crossmatch.csv'

# Reuse the filtered GZ2 frame built above as the working dataset.
data = final_data

In [0]:
# Module-level progress state that batch_download and its pool callback
# read and update while downloads run.
TOTAL_IMAGES_DOWNLOAD = CURRENT_IMAGE_INDEX = DOWNLOAD_ERROR_COUNT = 0
PROGRESSBAR = None

def load_dataset(force=False):
  """Read `dataset_path` into the global `data` unless it is already set.

  Args:
    force: re-read the CSV even when `data` already holds a frame.
  """
  global data
  # `not data` raises "truth value of a DataFrame is ambiguous" as soon as
  # `data` holds a DataFrame, so compare against None explicitly.
  if data is None or force:
    data = pd.read_csv(dataset_path)

def get_params(ra, dec, scale=0.2, width=200, height=200, opt=''):
  """Build the query-string parameters for one SkyServer ImgCutout request.

  Args:
    ra, dec: coordinates of the cutout centre (presumably degrees — TODO confirm).
    scale: value sent as the `scale` parameter (SkyServer pixel scale; units
      per SkyServer docs — TODO confirm).
    width, height: cutout size sent as `width` / `height`.
    opt: string of SkyServer option letters (empty for none).

  Returns:
    dict with keys TaskName, ra, dec, scale, width, height, opt (in that order).
  """
  params = dict(TaskName='Skyserver.Chart.Navi')
  params.update(ra=ra, dec=dec, scale=scale, width=width, height=height, opt=opt)
  return params

def prepare_downloads_dir(classes):
  """Recreate `downloads_path` empty, with one sub-directory per class.

  Anything already under `downloads_path` is deleted first.

  Args:
    classes: iterable of class labels; each becomes `{downloads_path}/{label}`.
  """
  root = downloads_path
  if os.path.exists(root):
    shutil.rmtree(root)
  os.mkdir(root)
  for class_name in classes:
    os.mkdir(f'{root}/{class_name}')

def download_image(url, params, filename, output_path=downloads_path, timeout=60):
  """Fetch one cutout and save it as `{output_path}/{filename}`.

  Network errors are reported in the returned dict rather than raised, so the
  multiprocessing callback in batch_download sees (and can count) every failure.

  Args:
    url: endpoint to GET.
    params: query-string parameters (see get_params).
    filename: name of the file written inside `output_path`.
    output_path: directory to write into; must already exist.
    timeout: seconds to wait on the server before giving up. Without it a
      stalled request can block a pool worker, and pool.join(), indefinitely.

  Returns:
    {'success': True, 'filename': filename} on HTTP 200, otherwise
    {'success': False, 'filename': filename, 'message': <HTTP status or error text>}.
  """
  try:
    resp = requests.get(url, params=params, timeout=timeout)
  except requests.RequestException as exc:
    return {'success': False, 'filename': filename, 'message': str(exc)}

  if resp.status_code != 200:
    return {'success': False, 'filename': filename, 'message': resp.status_code}

  with open(f'{output_path}/{filename}', 'wb') as f:
    f.write(resp.content)
  return {'success': True, 'filename': filename}

def make_tarfile(source, output):
  """Write `source` (recursively) into the gzip-compressed tarball `output`.

  Inside the archive the tree is rooted at basename(source). Prints a
  confirmation line when done.
  """
  top_level_name = os.path.basename(source)
  archive = tarfile.open(output, mode='w:gz')
  with archive:
    archive.add(source, arcname=top_level_name)
  print(f'Tarfile created successfully [{output}]')

def batch_download(data, classes_column, concurrency=12, limit=None,
                   width=80, height=80, scale=1.35):
  """Download one SDSS cutout per row of `data` using a process pool.

  The downloads directory is wiped and recreated first. Each image is saved
  as `{downloads_path}/{row[classes_column]}/{index}.jpg`, where `index` is
  the row's DataFrame index label.

  Args:
    data: DataFrame with `ra`, `dec` and `classes_column` columns.
    classes_column: column holding the class label (one sub-folder per value).
    concurrency: number of worker processes.
    limit: if truthy, only the first `limit` rows are downloaded.
    width, height, scale: cutout geometry forwarded to get_params.
  """
  global TOTAL_IMAGES_DOWNLOAD, CURRENT_IMAGE_INDEX, DOWNLOAD_ERROR_COUNT, PROGRESSBAR

  # Reset the counters so a second call does not report cumulative totals.
  CURRENT_IMAGE_INDEX = 0
  DOWNLOAD_ERROR_COUNT = 0

  rows = data.head(limit) if limit else data
  TOTAL_IMAGES_DOWNLOAD = rows.shape[0]

  def _record(failed):
    # Runs in the parent process when a task finishes (successfully or not).
    global CURRENT_IMAGE_INDEX, DOWNLOAD_ERROR_COUNT
    CURRENT_IMAGE_INDEX += 1
    if failed:
      DOWNLOAD_ERROR_COUNT += 1
    if CURRENT_IMAGE_INDEX <= TOTAL_IMAGES_DOWNLOAD:
      PROGRESSBAR.update(CURRENT_IMAGE_INDEX)

  def mp_callback(results):
    _record(not results['success'])

  def mp_error_callback(exc):
    # download_image raised inside the worker; count it instead of losing it.
    _record(True)

  # Class folders come from the frame passed in, not from a notebook global.
  classes = list(data.groupby(classes_column).indices.keys())
  prepare_downloads_dir(classes)

  PROGRESSBAR = progressbar.ProgressBar(max_value=TOTAL_IMAGES_DOWNLOAD)
  PROGRESSBAR.start()

  with mp.Pool(processes=concurrency) as pool:
    for index, row in rows.iterrows():
      params = get_params(row['ra'], row['dec'], width=width, height=height, scale=scale)
      pool.apply_async(
          download_image,
          args=(base_url, params, f'{index}.jpg', f'{downloads_path}/{row[classes_column]}'),
          callback=mp_callback,
          error_callback=mp_error_callback,
      )
    pool.close()
    pool.join()

  PROGRESSBAR.finish()
  downloaded = CURRENT_IMAGE_INDEX - DOWNLOAD_ERROR_COUNT
  print(f'\n{downloaded} images downloaded')
  print(f'Proccess finished with {DOWNLOAD_ERROR_COUNT} failed downloads')

In [0]:
load_dataset()

# List the dataset's columns, one per line.
print('\n'.join(str(column) for column in data.columns))

In [0]:
# Count galaxies per GZ2 class and list the 50 most frequent.
# The frame in `data` (built from the GZ2 catalogue above) names the column
# 'gz2_class'; 'gz2class' raised KeyError. NOTE(review): if `data` is ever
# reloaded from dataset_path, confirm that CSV uses the same column name.
class_counts = data['gz2_class'].value_counts()  # already sorted, descending
kmap = class_counts.to_dict()
sorted_kmap = list(class_counts.items())
print(*[f'{klass}: {count}' for klass, count in sorted_kmap[:50]], sep='\n')

In [0]:
# Keep only the five target morphologies. Filter on 'gz2_class', the same
# column the batch_download call below is given (the frame in `data` has no
# 'gz2class' column). NOTE(review): if `data` is reloaded from dataset_path,
# confirm that CSV also names the column 'gz2_class'.
valid_classes = ['Ei', 'Er', 'Ec', 'Ser', 'Sc2m']

final_data = data[data['gz2_class'].isin(valid_classes)]
final_data.describe()

In [0]:
# Fetch a cutout for every selected galaxy using 35 worker processes.
batch_download(data=final_data, classes_column='gz2_class', concurrency=35)

100% (108286 of 108286) |################| Elapsed Time: 0:35:33 Time:  0:35:33



108286 images downloaded
Proccess finished with 0 failed downloads


In [0]:
# Archive the downloaded cutouts onto the mounted Google Drive.
archive_path = '/gdrive/My Drive/ml_datasets/sdss_full_80px.tar.gz'
make_tarfile(source=downloads_path, output=archive_path)

Tarfile created successfully [/gdrive/My Drive/ml_datasets/sdss_full_80px.tar.gz]
