<a href="https://colab.research.google.com/github/huiliterry/CornSoybeanUSCanadaBrazil/blob/main/USGridMapping_StratifySample.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import ee
import geemap.core as geemap

In [None]:
# Authenticate with Earth Engine, then initialize the client for this session.
ee.Authenticate()
ee.Initialize(project='ee-huil7073') # Replace 'ee-huil7073' with your own project name.

In [None]:
# NOTE: this rebinds the name `geemap` to the top-level package; the earlier
# `import geemap.core as geemap` had bound the same name to `geemap.core`.
import geemap
# Create a geemap Map object.
m = geemap.Map()

## **1. S2 image collection**

In [None]:
def S2_cutCldSlw(start_date,end_date,boundary,CLD_PRB_THRESH,NIR_DRK_THRESH,CLD_PRJ_DIST,BUFFER):
  """Build a cloud/shadow-masked Sentinel-2 SR collection with index bands.

  Joins 'COPERNICUS/S2_SR_HARMONIZED' with 'COPERNICUS/S2_CLOUD_PROBABILITY'
  on 'system:index', masks cloud and cloud-shadow pixels, clips every image to
  `boundary`, appends NDVI / GCVI / NDWI / MSI bands, keeps only the bands named
  in the module-level `band` list, and stores a day offset in the
  'system:day_start' property of each image.

  Args:
    start_date, end_date: passed to filterDate() of both collections.
    boundary: geometry passed to filterBounds() and used to clip each image.
    CLD_PRB_THRESH: pixels whose s2cloudless 'probability' is greater than this
      value are flagged as cloud.
    NIR_DRK_THRESH: pixels with B8 below NIR_DRK_THRESH * 1e4 are flagged as
      dark (potential shadow).
    CLD_PRJ_DIST: multiplied by 10 and passed as the second argument of
      directionalDistanceTransform() when projecting shadows from clouds.
    BUFFER: controls the dilation of the combined mask; the focalMax radius is
      BUFFER * 2 / 20.

  Returns:
    The masked, clipped ee.ImageCollection with the bands listed in `band`.

  Note:
    Reads the module-level name `band`. The first image of the filtered SR
    collection is used as the time origin, so the function presumably fails
    when that collection is empty -- TODO confirm.
  """
  def addVariables(image):
    #// Parse the date from the first 8 characters of 'system:index'
    #// (assumes the index starts with YYYYMMDD).
    date = ee.String(image.get('system:index'))
    #// var days = ee.Date(date.slice(0,4).cat('-').cat(date.slice(4,6)).cat('-').cat(date.slice(6,8))).format('DDD');
    year = date.slice(0,4)
    month = date.slice(4,6)
    dateOfMonth = date.slice(6,8)
    #// Day offset relative to the first image of the collection:
    #// day-of-year + 365 * (year - startYear) - day-of-year of the first image.
    #// A fixed 365-day year is used, so leap years are not accounted for.
    days = (ee.Number.parse(ee.Date(year.cat('-').cat(month).cat('-').cat(dateOfMonth)).format('DDD'))
                        .add((ee.Number.parse(year).subtract(ee.Number.parse(startYear))).multiply(365))
                        .subtract(initialDays))
    #// Return the image with the added bands.
    return (image
    #// Add an NDVI band.
    .addBands(image.normalizedDifference(['B8', 'B4']).float().rename('NDVI'))
    #// Add a GCVI band (B8 / B3 - 1).
    .addBands(image.select('B8').divide(image.select('B3')).subtract(ee.Image(1)).float().rename('GCVI'))
    #// Add an NDWI band.
    .addBands(image.normalizedDifference(['B3','B8']).float().rename('NDWI'))
    #// Add an MSI band (B11 / B8).
    .addBands(image.select('B11').divide(image.select('B8')).float().rename('MSI'))
    #// Keep only the bands listed in the module-level `band`.
    .select(band)
    #// `days` is already an ee.Number; it is passed through ee.Number.parse again before being stored.
    .set('system:day_start',ee.Number.parse(days)))

  #// Remove clouds, add variables and filter to the area of interest.
  #// Cloud components
  def add_cloud_bands(img):
      #//Get s2cloudless image, subset the probability band.
      cld_prb = ee.Image(img.get('s2cloudless')).select('probability')
      #//Condition s2cloudless by the probability threshold value.
      is_cloud = cld_prb.gt(CLD_PRB_THRESH).rename('clouds')
      #//Add the cloud probability layer and cloud mask as image bands.
      return img.addBands(ee.Image([cld_prb, is_cloud]))


  #//Cloud shadow components
  def add_shadow_bands(img):
      #//Identify water pixels from the SCL band.
      #//var not_water = img.select('SCL').neq(6)
      #//Identify dark NIR pixels (potential cloud shadow pixels).
      #//NOTE: the water exclusion above is commented out, so water pixels are not excluded here.
      SR_BAND_SCALE = 1e4
      dark_pixels = img.select('B8').lt(NIR_DRK_THRESH*SR_BAND_SCALE).rename('dark_pixels')
      #//Determine the direction to project cloud shadow from clouds (assumes UTM projection).
      shadow_azimuth = ee.Number(90).subtract(ee.Number(img.get('MEAN_SOLAR_AZIMUTH_ANGLE')))
      #//Project shadows from clouds for the distance specified by the CLD_PRJ_DIST input.
      cld_proj = (img.select('clouds').directionalDistanceTransform(shadow_azimuth, CLD_PRJ_DIST*10)
          .reproject(crs= img.select(0).projection(), scale= 100)
          .select('distance')
          .mask()
          .rename('cloud_transform'))
      #//Identify the intersection of dark pixels with cloud shadow projection.
      shadows = cld_proj.multiply(dark_pixels).rename('shadows')
      #//Add dark pixels, cloud projection, and identified shadows as image bands.
      return img.addBands(ee.Image([dark_pixels, cld_proj, shadows]))


  #//Final cloud-shadow mask
  def add_cld_shdw_mask(img):
      #//Add cloud component bands.
      img_cloud = add_cloud_bands(img)
      #//Add cloud shadow component bands.
      img_cloud_shadow = add_shadow_bands(img_cloud)
      #//Combine cloud and shadow mask, set cloud and shadow as value 1, else 0.
      is_cld_shdw = img_cloud_shadow.select('clouds').add(img_cloud_shadow.select('shadows')).gt(0)
      #//Remove small cloud-shadow patches and dilate remaining pixels by BUFFER input.
      #//20 m scale is for speed, and assumes clouds don't require 10 m precision.
      is_cld_shdw = (is_cld_shdw.focalMin(2).focalMax(BUFFER*2/20)
          .reproject(crs = img.select([0]).projection(), scale = 20)
          .rename('cloudmask'))
      #//Add the final cloud-shadow mask to the image.
      return img_cloud_shadow.addBands(is_cld_shdw)


  def apply_cld_shdw_mask(img):
      #//Subset the cloudmask band and invert it so clouds/shadow are 0, else 1.
      not_cld_shdw = img.select('cloudmask').Not()
      #//Subset reflectance bands and update their masks, return the result.
      #//Only bands matching 'B.*' are kept from here on.
      return img.select('B.*').updateMask(not_cld_shdw)


  #//mask the water pixel
  #//NOTE(review): this helper is not called anywhere in this function, and
  #//`targetPixels` is computed but never used.
  def apply_scl_water_mask(img):
    scl = img.select('SCL')
    wantedPixels = scl.neq(6)
    targetPixels = scl.eq(4).Or(scl.eq(5))
    return img.updateMask(wantedPixels)


  #// Import and filter S2 SR.
  s2_sr_col = (ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
          .filterDate(start_date, end_date)
          .filterBounds(boundary)
          #// .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', CLOUD_FILTER))
          .filter(ee.Filter.eq('GENERAL_QUALITY','PASSED')))

  #// Import and filter s2cloudless.
  s2_cloudless_col = (ee.ImageCollection('COPERNICUS/S2_CLOUD_PROBABILITY')
          .filterBounds(boundary)
          .filterDate(start_date, end_date))

  #// Attach the matching s2cloudless image to each SR image as property 's2cloudless'.
  s2_sr_cld_col_eval = (ee.ImageCollection(ee.Join.saveFirst('s2cloudless').apply(
          primary = s2_sr_col,
          secondary = s2_cloudless_col,
          condition = ee.Filter.equals(
              leftField = 'system:index',
              rightField = 'system:index'
          )
      )))

  #//Year and day-of-year of the first image; used by addVariables as the time origin.
  initialDate = ee.String(s2_sr_col.first().get('system:index'))
  startYear = initialDate.slice(0,4)
  month = initialDate.slice(4,6)
  dateOfMonth = initialDate.slice(6,8)
  initialDays = ee.Number.parse(ee.Date(startYear.cat('-').cat(month).cat('-').cat(dateOfMonth)).format('DDD'))
  #// print('initialDays',initialDays)
  #//creating final imgcollection: mask -> clip -> add index bands and day offset
  s2_no_cld_shdw =  (s2_sr_cld_col_eval
                        .map(add_cld_shdw_mask)
                        .map(apply_cld_shdw_mask)
                        .map(lambda img: img.clip(boundary))
                        .map(addVariables))
  #// print('s2_no_cld_shdw',s2_no_cld_shdw)
  return s2_no_cld_shdw

## **2. Cosine Regression**

In [None]:
import math

def ResCoesCosine(order,speed,imgCollection):
  """Fit a robust cosine-series regression per band and return the coefficients.

  For every band name in the module-level `band` list, the series in
  `imgCollection` is regressed on a constant plus cos(j * t) for j = 1..order,
  where t = (system:day_start / 365) * speed * 2 * pi.

  Args:
    order: highest cosine order; regressors cos_1 ... cos_<order> are built.
    speed: multiplier applied to the base angular frequency (2 * pi per 365 days).
    imgCollection: ee.ImageCollection whose images carry a 'system:day_start'
      property and the bands listed in `band`.

  Returns:
    ee.ImageCollection with one image per entry of `band`. Each image has the
    bands 'constant', 'cos_1', ... holding the fitted coefficients, and a
    'spectral' property naming the band that was fitted.
  """
  #// Response variables: one regression is fitted for each of these bands.
  responseBands = ee.List(band)
  #// Cosine orders 1..order; they double as the band-name suffixes.
  cosineOrders = ee.List.sequence(1, order)

  def cosineBandName(j):
    #// 'cos_' followed by the integer order.
    return ee.String('cos_').cat(ee.Number(j).int())

  cosNames = ee.List(cosineOrders).map(cosineBandName)
  #// Regressors, in the order the coefficients are flattened below.
  predictors = ee.List(['constant']).cat(cosNames)

  def withConstant(image):
    #// Intercept term.
    return image.addBands(ee.Image(1))

  def withTime(image):
    #// Day offset stored on the image, converted to radians.
    dayOffset = ee.String(image.get('system:day_start'))
    radians = ee.Number.parse(dayOffset).divide(365).multiply(speed*2*math.pi)
    return image.addBands(ee.Image(radians).rename('t').float())

  def withCosines(image):
    #// One constant band per cosine order, multiplied by the time band.
    orderImage = ee.Image.constant(cosineOrders)
    timeBand = ee.Image(image).select('t')
    cosineBands = timeBand.multiply(orderImage).cos().rename(cosNames)
    return image.addBands(cosineBands)

  harmonicSeries = imgCollection.map(withConstant).map(withTime).map(withCosines)

  def fitBand(dependent):
    #// Regressors first, response band last, as the reducer expects.
    regressionInput = harmonicSeries.select(predictors.add(dependent))
    fit = regressionInput.reduce(ee.Reducer.robustLinearRegression(predictors.length(), 1))
    #// Turn the coefficient array into one named band per regressor.
    coefficientBands = (fit.select('coefficients')
                           .arrayProject([0])
                           .arrayFlatten([predictors]))
    return coefficientBands.set('spectral',dependent)

  return ee.ImageCollection(responseBands.map(fitBand))


## **3. Generate Regression Coefficients**

In [None]:
def coefficientsImg(startDate,endDate, region, MAX_CLOUD_PROBABILITY, order, speed):
  """Return the cosine-regression coefficients for a region as one multi-band image.

  Builds the masked Sentinel-2 series with S2_cutCldSlw (fixed shadow and
  buffer settings), fits it with ResCoesCosine, and flattens the resulting
  collection with toBands().

  Args:
    startDate, endDate: date range forwarded to S2_cutCldSlw.
    region: geometry forwarded to S2_cutCldSlw.
    MAX_CLOUD_PROBABILITY: cloud probability threshold forwarded to S2_cutCldSlw.
    order, speed: forwarded to ResCoesCosine.

  Returns:
    The image produced by calling toBands() on the coefficient collection.
  """
  # Fixed cloud-shadow settings passed to S2_cutCldSlw.
  nirDarkThreshold = 0.25
  cloudProjectionDistance = 1
  maskBuffer = 20

  maskedSeries = S2_cutCldSlw(
    startDate, endDate, region, MAX_CLOUD_PROBABILITY,
    nirDarkThreshold, cloudProjectionDistance, maskBuffer)
  perBandFits = ResCoesCosine(order, speed, maskedSeries)
  return perBandFits.toBands()

## **4. Training sample processing**

In [None]:
def labelExtractEven(year,region,labelNum):
  """Sample labelled points from the USDA CDL layer for one year.

  Only pixels whose CDL 'confidence' is above 85 are eligible. Points drawn
  from pixels that are neither corn (1) nor soybean (5) get croptype 0; corn
  and soybean points keep their CDL value in 'croptype'.

  Args:
    year: four-digit year (string or ee.String); the CDL image is the first one
      in the collection between this year and the next.
    region: geometry in which points are sampled.
    labelNum: numPoints requested for the non-corn/soybean sample.

  Returns:
    ee.FeatureCollection merging the croptype-0 points with the corn/soybean
    points; the latter stratified sample requests as many points as the
    croptype-0 sample actually returned.
  """
  windowStart = ee.Date(year)
  windowEnd = ee.Date(ee.String(ee.Number.parse(year).add(1)))
  cdlCollection = ee.ImageCollection('USDA/NASS/CDL')
  cdlImage = cdlCollection.filterDate(windowStart,windowEnd).first()
  cdl = cdlImage.select('cropland')
  confidence = cdlImage.select('confidence')

  cornValue = 1  # Corn
  soybeanValue = 5 # Soybeans

  # Confident pixels that are neither corn nor soybean.
  notCornOrSoy = cdl.neq(cornValue).And(cdl.neq(soybeanValue))
  otherMask = notCornOrSoy.And(confidence.gt(85)).rename('mask')
  # Confident pixels that are corn or soybean.
  cornOrSoy = cdl.eq(cornValue).Or(cdl.eq(soybeanValue))
  csmask = cornOrSoy.And(confidence.gt(85)).rename('mask')

  # Keep the CDL class value only where the corn/soybean mask holds.
  csFiltered = cdl.updateMask(csmask).rename('croptype')

  def markAsOther(point):
    return point.set('croptype',0)

  otherLabel = otherMask.selfMask().stratifiedSample(
    numPoints = labelNum,
    classBand = "mask",
    region = region,
    scale = 30,
    geometries = True
  ).map(markAsOther)

  cornSoybeanLabel = csFiltered.stratifiedSample(
    numPoints = otherLabel.size(),
    classBand = "croptype",
    region = region,
    scale = 30,
    geometries = True
  )

  return otherLabel.merge(cornSoybeanLabel)

#//extract training samples in a district
def trainingSample(coefficiency, labelPoints, cropLabel):
  """Sample the coefficient image at the label points.

  Args:
    coefficiency: image exposing sampleRegions() (the regression coefficients).
    labelPoints: feature collection of labelled points.
    cropLabel: name of the label property to copy onto each sample.

  Returns:
    Whatever sampleRegions() returns for a 10 m scale, tileScale 4, with
    geometries kept.
  """
  samplingOptions = {
    'collection': labelPoints,
    'properties': [cropLabel],
    'scale': 10,
    'tileScale': 4,
    'geometries': True,
  }
  return coefficiency.sampleRegions(**samplingOptions)


def cropSamples(roi,year,coefficiency,number,classBand):
  """Draw label points for `roi`/`year` and sample `coefficiency` at them.

  Args:
    roi: region passed to labelExtractEven.
    year: year passed to labelExtractEven.
    coefficiency: coefficient image passed to trainingSample.
    number: point count passed to labelExtractEven.
    classBand: label property name passed to trainingSample.

  Returns:
    The result of trainingSample for the extracted label points.
  """
  labelPoints = labelExtractEven(year, roi, number)
  return trainingSample(coefficiency, labelPoints, classBand)


#//sample type setting
def setType(feature):
  """Recode 'croptype' to 0 unless it equals 1 or 5; return the updated feature."""
  currentType = feature.get('croptype')
  isOne = ee.Algorithms.IsEqual(currentType, 1)
  isFive = ee.Algorithms.IsEqual(currentType, 5)
  # The two flags cannot both be true, so they are equal only when the
  # value is neither 1 nor 5.
  isNeither = ee.Algorithms.IsEqual(isOne, isFive)
  recodedType = ee.Algorithms.If(isNeither, 0, currentType)
  return feature.set('croptype', recodedType)

## **5. Generate classifier for each given grid**

In [None]:
def classificationByGrid(region):
  """Train a random forest on past years and classify `region` for predictYear.

  Reads the module-level settings `predictYear`, `season`,
  `MAX_CLOUD_PROBABILITY`, `order`, `k`, `singleTypeNum` and `cropLabel`.

  Training samples are collected for every year from 2019 up to (but not
  including) `predictYear`, using the same May-1-based window as the
  prediction year. The window end depends on `season`: 'July' -> Aug 1,
  'August' -> Sep 1, anything else -> Jul 1.

  Args:
    region: geometry of the grid cell to classify.

  Returns:
    The result of ee.Algorithms.If: the classified image (values remapped
    through [1, 5] -> [1, 5], property type='classification') when training
    samples exist and contain more than one distinct croptype; otherwise a
    constant 0 image with property type='null'. Both are clipped to `region`.
  """
  # season -> (end-date suffix for the prediction year,
  #            number of days from May 1 to that end date).
  seasonWindows = {
    'July': ('-08-01', 92),     # 31 + 30 + 31
    'August': ('-09-01', 123),  # 31 + 30 + 31 + 31
  }
  startDay = '-05-01'
  # Default (June) window ends Jul 1, which is 31 + 30 = 61 days after May 1.
  # The previous value of 62 made the training window one day longer than the
  # prediction window.
  endDay, diff = seasonWindows.get(season, ('-07-01', 61))

  # Training years.
  yearList = [str(year) for year in range(2019, int(predictYear))]

  startPredictDay = predictYear + startDay
  endPredictDay = predictYear + endDay

  startDayList = ee.List([year+startDay for year in yearList])

  def sampleList(trainStart):
    # Runs server-side for each training year's start date.
    trainStart = ee.Date(trainStart)
    trainEnd = trainStart.advance(diff,'day')
    year = ee.String(trainStart.format('YYYY'))
    coefficiencyImages = coefficientsImg(trainStart,trainEnd,region,MAX_CLOUD_PROBABILITY,order,k)
    return cropSamples(region,year,coefficiencyImages,singleTypeNum,cropLabel)

  trainingSamples = ee.FeatureCollection(startDayList.map(sampleList)).flatten()

  predictCoefficiencyImages = coefficientsImg(startPredictDay,endPredictDay,region,MAX_CLOUD_PROBABILITY,order,k)

  def imgClassified():
    classifier = ee.Classifier.smileRandomForest(100).train(
      features = trainingSamples,
      classProperty = 'croptype',
      inputProperties = predictCoefficiencyImages.bandNames()
    )
    return (predictCoefficiencyImages.classify(classifier)
            .remap([1,5],[1,5])
            .clip(region)
            .set('type','classification'))

  def imgNull():
    return ee.Image(0).clip(region).set('type','null')

  # Classify only when there are samples and more than one class is present.
  hasUsableSamples = (trainingSamples.size().neq(0)
                      .And(trainingSamples.aggregate_count_distinct("croptype").neq(1)))
  return ee.Algorithms.If(hasUsableSamples,imgClassified(),imgNull())



# **6. State grids classifier**

In [None]:
def stateClassificaiton(stateName):
  """Classify a state grid cell by grid cell and start one Drive export per cell.

  Reads the module-level settings `gridSize`, `predictYear`, `season` and
  `singleTypeNum`. The state boundary comes from 'TIGER/2018/States'
  (matched on NAME), is covered with a grid of `gridSize`, and every cell is
  classified with classificationByGrid, multiplied by the previous year's CDL
  'cultivated' layer (remapped [1, 2] -> [0, 1]) and exported at 10 m in
  EPSG:5070.

  Args:
    stateName: state name as stored in the NAME property; spaces are removed
      for the export description and folder name.

  Returns:
    None. Prints the grid count and one line per started export task.
  """
  stateboundary = ee.FeatureCollection("TIGER/2018/States")
  state = stateboundary.filter(ee.Filter.eq('NAME',stateName))
  stateGeometry = state.geometry()
  #// Construct the grid covering the state geometry.
  grid = (stateGeometry.coveringGrid(stateGeometry.projection(),gridSize)
          .map(lambda fea: fea.set('id',fea.id().replace('-', '').replace(',', ''))))

  # Fetch the grid count once; it drives the client-side loop below.
  gridNum = grid.size().getInfo()
  gridList = grid.toList(gridNum)
  print(stateName + ' GridNumber:', gridNum)

  # Loop-invariant pieces: export-safe state name and last year's cultivated layer.
  exportStateName = stateName.replace(" ", "")
  cultivated = (ee.Image('USDA/NASS/CDL/'+str(int(predictYear)-1))
                .select('cultivated'))

  for q in range(gridNum):
    # No per-cell getInfo() here: the cell's 'id' was never used, so fetching
    # it only added one blocking round-trip per grid cell.
    geo = ee.Feature(gridList.get(q)).geometry()

    CDL_cropland = cultivated.clip(geo).remap([1,2],[0,1])

    gridPrediction = ee.Image(classificationByGrid(geo)).multiply(CDL_cropland).toByte()

    predictionName = exportStateName + predictYear + season + str(q) + 'th_501_424_Equal' + str(singleTypeNum)
    task = ee.batch.Export.image.toDrive(
        image = gridPrediction,
        description = predictionName,
        folder = 'US_TS501_424_Equal' + exportStateName,
        region = geo,
        scale = 10,
        crs = 'EPSG:5070',
        maxPixels= 10000000000000
    )
    task.start()
    # Printed right after the task is started, not when it finishes.
    print(predictionName + ' Well Done')


# **7. Operation**

In [None]:
# Settings read as module-level globals by the functions defined in the cells above.
cropLabel = "croptype"
MAX_CLOUD_PROBABILITY = 70
order = 4
k = 2.4
band = ["B3","B4","B5","B8","B11","B12","NDVI","GCVI","MSI","NDWI"]

predictYear = '2023'
singleTypeNum  = 2000
gridSize = 120000
stateName = ['Iowa','Kansas','Ohio','Wisconsin','Minnesota','Michigan']
# stateName = ['South Dakota','Missouri','Illinois','Kentucky','Indiana','North Dakota']

# `season` selects the end of the May-1-based window used by classificationByGrid:
# 'August' -> Sep 1, 'July' -> Aug 1, any other value (e.g. 'June') -> Jul 1.
# Change it here and re-run the cell to process another season.
season = 'August'

# The loop variable is deliberately not called `m`, so the geemap Map object
# bound to `m` earlier in the notebook is not overwritten.
for state in stateName:
  stateClassificaiton(state)


Kansas GridNumber: 31
Kansas2023August24th_501_424_Equal300 Well Done
Kansas2023August18th_501_424_Equal300 Well Done
