In [3]:
import os
import ee
import random
import torch
import timm
import pickle
import numpy as np
import torch.nn as nn
import torchgeo
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchgeo.datasets import RasterDataset, Landsat8
from torchgeo.models import ResNet18_Weights
from torchgeo.samplers import RandomGeoSampler
from torchvision import transforms
from datetime import datetime

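Use the Google Earth Engine (GEE) Python API to find the least-cloudy Landsat 8 images for each month and export them to Cloud Storage.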

In [None]:
# Authenticate with Google Earth Engine (may prompt for credentials), then open a
# session for the API calls in the rest of this section.
ee.Authenticate()
ee.Initialize()

# Arbitrary region of interest in upstate New York, used to capture seasonality.
# Corner order is [xMin, yMin, xMax, yMax], i.e. [west, south, east, north] in degrees.
roi_corners = [-74.8, 43.5, -74, 44]
roi = ee.Geometry.Rectangle(roi_corners)

def get_image_collection(month, year, start_day, end_day):
    """Return the two least-cloudy Landsat 8 images over ``roi`` for a date window.

    Args:
        month: Calendar month (1-12).
        year: Four-digit year.
        start_day: First day of the window (inclusive).
        end_day: Day that closes the window. Earth Engine's ``filterDate``
            treats the end date as exclusive, so images from this day itself
            are not included.

    Returns:
        An ``ee.ImageCollection`` drawn from ``LANDSAT/LC08/C02/T1``, filtered to
        ``roi`` and the window, sorted by ascending ``CLOUD_COVER_LAND`` and
        limited to two images.
    """
    window_start = f'{year}-{month:02d}-{start_day:02d}'
    window_end = f'{year}-{month:02d}-{end_day:02d}'

    collection = ee.ImageCollection('LANDSAT/LC08/C02/T1').filterBounds(roi)
    collection = collection.filterDate(window_start, window_end)
    return collection.sort('CLOUD_COVER_LAND').limit(2)

# For each calendar month: (year, end_day) of the window, hand-picked for low
# cloud cover. Every window starts on day 1.
_month_windows = {
    1: (2018, 30),
    2: (2019, 28),
    3: (2021, 28),
    4: (2020, 28),
    5: (2020, 28),
    6: (2020, 30),
    7: (2020, 30),
    8: (2020, 30),
    9: (2019, 30),
    10: (2021, 30),
    11: (2019, 30),
    12: (2019, 30),
}

# Flattened to (month, year, start_day, end_day) tuples, ordered January..December.
months_years = [
    (month, year, 1, end_day)
    for month, (year, end_day) in _month_windows.items()
]

def export_image_to_cloud(image, description):
    """Start an Earth Engine batch task exporting ``image`` to Cloud Storage.

    The GeoTIFF goes to bucket ``dlcv_finalproj_data1`` under
    ``imgcollect/<description>``. It is exported at scale 30 over the bounding
    box of the image's footprint. The task is started and not waited on.

    Args:
        image: The ``ee.Image`` to export.
        description: Task description, also used as the output file name.
    """
    export_options = dict(
        image=image,
        description=description,
        bucket='dlcv_finalproj_data1',
        fileNamePrefix=f'imgcollect/{description}',
        scale=30,
        region=image.geometry().bounds(),
        fileFormat='GeoTIFF',
    )
    ee.batch.Export.image.toCloudStorage(**export_options).start()

def get_and_export_images(image_collection):
    """Export every image in ``image_collection``, keeping only bands matching ``B.+``.

    Each image is passed to ``export_image_to_cloud`` with its ``system:index``
    as the task description (and therefore the output file name).

    Args:
        image_collection: An ``ee.ImageCollection``, e.g. from ``get_image_collection``.
    """
    # Earth Engine objects are immutable: ``select`` returns a NEW collection.
    # Its result must be kept, otherwise every band of each image is exported,
    # not just the ``B*`` spectral bands.
    image_collection = image_collection.select('B.+')
    image_ids = image_collection.aggregate_array('system:index').getInfo()

    for image_id in image_ids:
        image = image_collection.filter(ee.Filter.eq('system:index', image_id)).first()
        export_image_to_cloud(image, image_id)

# Queue the exports for every hand-picked monthly window, one window at a time.
for window in months_years:
    month, year, start_day, end_day = window
    image_collection = get_image_collection(month, year, start_day, end_day)
    get_and_export_images(image_collection)

Download the preprocessed (randomly sampled, mixed) image patches from Google Cloud Storage. For a larger dataset we would need to stream batches directly from the bucket instead of copying the whole file locally.

In [None]:
# Cloud Storage location of the preprocessed, pickled dataset.
bucket_name = 'dlcv_finalproj_data1'
folder_path = 'imgcollect'  # NOTE(review): not used by the download below
file_path = 'combined_data.pkl'

# Copy the pickle into /content (presumably a Colab runtime). -m runs the transfer
# in parallel; the -o option enables sliced (parallel) download for objects larger
# than 150 MB.
!gsutil -m -o 'gsutil:sliced_object_download_threshold=150M' cp -r gs://{bucket_name}/{file_path} /content

In [None]:
# Load the preprocessed patches. The training loops unpack batches as
# (inputs, labels), so each item is presumably an (image, label) pair; confirm.
# NOTE: pickle.load can execute arbitrary code, so only load files from a bucket you control.
with open("/content/combined_data.pkl", 'rb') as file:
    data = pickle.load(file)

# Shuffle with a fixed seed so the train/val/test split that follows is the same
# after every kernel restart. An unseeded shuffle yields a different test set each run.
SHUFFLE_SEED = 42
random.Random(SHUFFLE_SEED).shuffle(data)

In [None]:
# Contiguous split of the shuffled data: the first 10,000 items for training,
# the next 1,250 for validation, and everything after that for testing.
# The validation slice starts at 10000 so no item is dropped between train and val.
train_data = data[:10000]
val_data = data[10000:11250]
test_data = data[11250:]
assert len(train_data) + len(val_data) + len(test_data) == len(data), "split lost items"

# Reshuffle the training set every epoch; evaluation order does not matter.
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64)
test_loader = DataLoader(test_data, batch_size=64)

In [None]:
# Use the first CUDA GPU when available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

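Load ResNet-18 with torchgeo's Landsat-pretrained weights (`LANDSAT_OLI_SR_MOCO`) and freeze the backbone for the first round of training, so that only the new classification head is trained.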

In [None]:
# ResNet-18 initialised from torchgeo's Landsat OLI SR MoCo weights, with a new
# 12-way classification head.
# NOTE(review): the weights are for surface-reflectance (SR) imagery, while the export
# cell pulls from LANDSAT/LC08/C02/T1. Confirm the preprocessing yields `in_chans`
# channels in the band order / value range these weights expect (see weights.meta).
weights = ResNet18_Weights.LANDSAT_OLI_SR_MOCO
in_chans = weights.meta["in_chans"]
model = timm.create_model("resnet18", in_chans=in_chans, num_classes=12)

# strict=False lets the freshly initialised head (model.fc) stay as-is. On its own it
# would also silently accept a checkpoint whose backbone keys don't match. So check
# that only the head is missing and nothing unexpected was supplied (the same checks
# torchgeo's own resnet18 builder performs).
missing_keys, unexpected_keys = model.load_state_dict(
    weights.get_state_dict(progress=True), strict=False
)
assert set(missing_keys) <= {"fc.weight", "fc.bias"}, f"missing backbone keys: {missing_keys}"
assert not unexpected_keys, f"unexpected keys: {unexpected_keys}"
model = model.to(device)

# Round 1: freeze the backbone, train only the classification head.
for param in model.parameters():
    param.requires_grad = False

for param in model.fc.parameters():
    param.requires_grad = True

criterion = nn.CrossEntropyLoss()
# Give the optimizer only the trainable (head) parameters. It is rebuilt when the
# backbone is unfrozen.
optimizer = optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.1)

In [None]:
# Round 1 training (head only): train until validation accuracy reaches the target.
# The epoch cap stops the cell from looping forever if the target is never reached.
TARGET_VAL_ACCURACY_ROUND1 = 96.0  # percent
MAX_EPOCHS_ROUND1 = 100

epoch = 0
val_accuracy = 0.0

while val_accuracy < TARGET_VAL_ACCURACY_ROUND1 and epoch < MAX_EPOCHS_ROUND1:
    epoch += 1

    # --- Train ---
    # NOTE(review): model.train() also lets BatchNorm running statistics in the
    # "frozen" backbone keep updating. Put those BN layers in eval mode if they
    # are meant to stay fixed this round.
    model.train()
    running_loss = 0.0

    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        # Stored labels appear to be 1-based (presumably months 1-12, matching
        # num_classes=12). CrossEntropyLoss needs 0-based class indices. The shift
        # builds a new tensor instead of mutating the batch in place.
        labels = labels.to(device) - 1

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # --- Validate: no autograd graph is needed, which saves memory and time ---
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device) - 1
            outputs = model(inputs)
            val_loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # --- Report ---
    val_loss /= len(val_loader)
    val_accuracy = 100 * (correct / total)

    print(f'Epoch {epoch}, Train Loss: {running_loss / len(train_loader)}, '
          f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

if val_accuracy < TARGET_VAL_ACCURACY_ROUND1:
    print(f'Stopped after {epoch} epochs without reaching '
          f'{TARGET_VAL_ACCURACY_ROUND1}% validation accuracy.')

# Weights at the end of round 1: the first epoch that met the target, or the last
# epoch if the cap was hit. state_dict() returns live tensor references, so save now.
best_weights = model.state_dict()
torch.save(best_weights, 'first_round_train.pth')

Unfreeze the backbone and fine-tune the whole network in a second round of training.

In [None]:
# Round 2: make every parameter (backbone included) trainable again.
model.requires_grad_(True)

# Fresh loss and an optimizer over ALL parameters, at the same lr as round 1.
# NOTE(review): lr=0.1 is high for fine-tuning pretrained weights; consider lowering it.
criterion = nn.CrossEntropyLoss()
all_params = list(model.parameters())
optimizer = optim.SGD(all_params, lr=0.1)

In [None]:
# Round 2 training (full network): train until validation accuracy reaches the target.
# The epoch cap stops the cell from looping forever if the target is never reached.
TARGET_VAL_ACCURACY_ROUND2 = 98.0  # percent
MAX_EPOCHS_ROUND2 = 100

epoch = 0
val_accuracy = 0.0

while val_accuracy < TARGET_VAL_ACCURACY_ROUND2 and epoch < MAX_EPOCHS_ROUND2:
    epoch += 1

    # --- Train ---
    model.train()
    running_loss = 0.0

    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        # Stored labels appear to be 1-based (presumably months 1-12, matching
        # num_classes=12). CrossEntropyLoss needs 0-based class indices. The shift
        # builds a new tensor instead of mutating the batch in place.
        labels = labels.to(device) - 1

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # --- Validate: no autograd graph is needed, which saves memory and time ---
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device) - 1
            outputs = model(inputs)
            val_loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # --- Report ---
    val_loss /= len(val_loader)
    val_accuracy = 100 * (correct / total)

    print(f'Epoch {epoch}, Train Loss: {running_loss / len(train_loader)}, '
          f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

if val_accuracy < TARGET_VAL_ACCURACY_ROUND2:
    print(f'Stopped after {epoch} epochs without reaching '
          f'{TARGET_VAL_ACCURACY_ROUND2}% validation accuracy.')

# Weights at the end of round 2: the first epoch that met the target, or the last
# epoch if the cap was hit. state_dict() returns live tensor references, so save now.
best_weights = model.state_dict()
torch.save(best_weights, 'best_model.pth')

In [None]:
# Final evaluation on the held-out test split.
model.eval()

correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        # Shift the (presumably 1-based) stored labels to the model's 0-based classes.
        labels = labels.to(device) - 1
        outputs = model(inputs)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * (correct / total)
print(f'Test Accuracy: {accuracy:.2f}%')