#### Note
Follow the authentication workflow in "satellite_imagery_exploration.ipynb" in order to initialize ee

Then load images if necessary using "grab_images.py"

#### Images should already be loaded in image folder.

In [1]:
import pandas as pd
from torch.utils.data import Dataset
from PIL import Image
import torch
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision import models

from sklearn.model_selection import train_test_split

  Referenced from: /Users/lucasrosenblatt/opt/miniconda3/envs/heuristic_fairness/lib/python3.8/site-packages/torchvision/image.so
  warn(


#### Note
Make sure you put data in folder and adjust "buffelgrass_one_time" path as necessary.

In [2]:
buffelgrass_one_time = 'data/buffelgrass_one_time.csv'
df = pd.read_csv(buffelgrass_one_time)
df_filtered = df[['Observation_ID', 'Observation_Date', 'Create_Date','Latitude', 'Longitude', 'Abundance_Name']]
df_filtered['Abundance_Binary'] = df_filtered['Abundance_Name'].apply(lambda x: 1 if x == '75-94%' or x == '50-74%' or x == '95% or more' else 0)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_filtered['Abundance_Binary'] = df_filtered['Abundance_Name'].apply(lambda x: 1 if x == '75-94%' or x == '50-74%' or x == '95% or more' else 0)


In [None]:
def preprocess_images(image_path, image_path_ir):
    # real color image
    with Image.open(image_path) as img:
        img = img.convert('RGB')
        real_color = transforms.ToTensor()(img)

    # infrared+ image
    with Image.open(image_path_ir) as img:
        img = img.convert('RGB')
        infrared = transforms.ToTensor()(img)

    # we stack the images vertically for passthrough
    concatenated = torch.cat((real_color, infrared), dim=1)  # we could use dim=2 for horizontal stacking

    # need to standardize the size and normalize for imagenet (because resnet trained on imagenet)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 
    ])
    
    return transform(concatenated)
        
class SatelliteImageDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe
        self.label_map = {category: i for i, category in enumerate(dataframe['Abundance_Binary'].unique())}

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        image_path = 'images/' + str(row['Observation_ID']) + '.png'
        image_path_ir = 'images/' + str(row['Observation_ID']) + '_ir.png'
        image = preprocess_images(image_path, image_path_ir)
        label = float(row['Abundance_Binary'])
        return image, torch.tensor(label, dtype=torch.float)
    
# train test split
df_train, df_val = train_test_split(df_filtered, test_size=0.2, random_state=42, stratify=df_filtered['Abundance_Binary'])

# dataset loaders for training and validation
dataset = SatelliteImageDataset(df_train)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
val_dataset = SatelliteImageDataset(df_val)
validation_dataloader = DataLoader(val_dataset, batch_size=2, shuffle=False) 

# pre-trained ResNet model
model = models.resnet18(pretrained=True)

# finetune resnet
model.fc = torch.nn.Linear(model.fc.in_features, 1)
criterion = torch.nn.BCEWithLogitsLoss() #CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=8, gamma=0.1)

num_epochs = 32

best_val_loss = float('inf')
best_val_accuracy = 0.0 

# set to true if we want to keep the most accurate model vs. least lossy
storing_val_accuracy = True 
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(dataloader):
        outputs = model(images)
        # reshaping labels to match output shape
        labels = labels.view(-1, 1) 
        loss = criterion(outputs, labels)

        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 20 == 0:
            print(i)
    
    scheduler.step()
    
    val_loss = 0
    with torch.no_grad():
        for images, labels in validation_dataloader:
            outputs = model(images)
            # reshaping labels to match output shape
            labels = labels.view(-1, 1) 
            loss = criterion(outputs, labels)
            val_loss += loss.item()

    # average loss for tracking
    val_loss /= len(validation_dataloader)

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in validation_dataloader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_accuracy = 100 * correct / total

    if storing_val_accuracy:
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"New best model saved with validation accuracy: {best_val_accuracy}%")
    else:
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"New best model saved with validation loss: {best_val_loss}")

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')

In [None]:
model = models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(model.fc.in_features, 1)
model.load_state_dict(torch.load('best_model.pth'))
model.eval()

In [None]:
correct = 0
total = 0
with torch.no_grad():
    for images, labels in validation_dataloader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the model on the validation images: {100 * correct / total}%')