In [3]:
import csv
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
from torchvision import transforms
import sys
import math
from PIL import Image
from pathlib import Path
from torch.utils.data import Dataset
from collections import defaultdict
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from transformers import PretrainedConfig
import open_clip
import clip
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
from taxabind import TaxaBind


# Run on the GPU when PyTorch reports one as available, otherwise fall back to the CPU.
# (Assign device = "cpu" here instead to force CPU execution.)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

# Input locations: the CSV table and the directory holding the image files.
CSV_PATH = "/scratch/cher/Sat2Habitat/data/gridkey2text.csv"
IMAGE_PATH = "/scratch/cher/Sat2Habitat/data/naip"

cuda


In [22]:
# Earlier CLIP-based alternative, kept for reference only:
# model, preprocss = clip.load("ViT-B/16", device=device)

# Fetch the configuration published as "MVRL/taxabind-config" and build a
# TaxaBind object from it.
config = PretrainedConfig.from_pretrained("MVRL/taxabind-config")
taxabind = TaxaBind(config)
# Pull out the three encoders that ContrastiveModel is later constructed with.
sat_encoder = taxabind.get_sat_encoder()
location_encoder = taxabind.get_location_encoder()
# NOTE(review): the getter is called get_image_text_encoder but its result is
# used purely as the text encoder below — confirm what inputs it expects.
text_encoder = taxabind.get_image_text_encoder()
# Make float32 the default floating-point dtype for tensors created from here on.
torch.set_default_dtype(torch.float32)

In [None]:
# Load the table at CSV_PATH for ad-hoc inspection.
# NOTE(review): MultiData reads the same CSV itself from the path it is given,
# so `data` is not referenced by the later cells visible in this notebook.
data = pd.read_csv(CSV_PATH)

# Disabled one-off filtering step: it collected ids from "remove.txt" (the part
# of each line before the first "_"), dropped rows whose "inat_id" was in that
# set, and wrote the result to "filtered_data.csv".
# remove_ids = set()
# with open ("remove.txt", "r") as f:
#     for line in f:
#         inat_id = line.split('_')[0]
#         remove_ids.add(int(inat_id))

# filtered_data = data[~data["inat_id"].isin(remove_ids)]
# filtered_data.to_csv("filtered_data.csv", index=False)

# sat_id = data['key']

## Dataset

In [44]:
# 0_43.83486_-71.22231.jpg
class MultiData(Dataset):
    """Dataset yielding ``(image, text, lat_lon)`` triples, one per CSV row.

    Each row of the CSV is matched to a ``*.png`` file in ``image_path`` whose
    file name (without extension) equals the row's ``key`` column.

    Args:
        image_path: Directory that is scanned (non-recursively) for ``*.png`` files.
        csv_path: CSV file providing the ``key``, ``lat``, ``lon``, ``habitat``
            and ``*_wiki`` columns read by this class.
        transform: Optional callable applied to the RGB PIL image. When omitted,
            the default augmentation from ``_default_transform`` is used.

    Samples are returned on the CPU; moving them to an accelerator is left to
    the training loop so the dataset does not depend on a global device and
    stays usable from DataLoader worker processes.
    """

    def __init__(self, image_path, csv_path, transform=None):
        self.image_path = Path(image_path)
        self.csv_path = csv_path
        self.image_dict = self._build_image_dict()
        self.data = pd.read_csv(self.csv_path)

        # Build the pipeline once instead of once per sample, and honour a
        # caller-supplied transform (it used to be silently ignored).
        self.transform = transform if transform is not None else self._default_transform()

        # text params
        self.hab_desc = 'habitat'
        self.alt_cols = ['habitat_wiki', 'distribution and habitat_wiki', 'description_wiki', 'ecology_wiki', 'distribution_wiki', 'header_wiki']
        # Probability of returning the primary 'habitat' text rather than an alternative.
        self.random_prob = 0.9

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.data)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            tuple: ``(image, text, lat_lon)`` where ``image`` is the transformed
            image, ``text`` is the value chosen by ``_get_text_randomized`` and
            ``lat_lon`` is a float32 tensor ``[lat, lon]``.

        Raises:
            FileNotFoundError: If no PNG matching the row's ``key`` was found.
        """
        row = self.data.iloc[index]
        sat_id = row["key"]

        # Dictionary keys are file stems (strings); stringify the CSV value so a
        # numerically-parsed key column still matches.
        image_file = self.image_dict.get(str(sat_id))
        if image_file is None:
            raise FileNotFoundError(f"No image found for sat_id: {sat_id}")

        # Close the file handle deterministically and force three channels so
        # the 3-channel Normalize in the default transform always applies.
        with Image.open(image_file) as raw_image:
            image = self.transform(raw_image.convert("RGB"))

        # Get the text description (habitat or randomized)
        text = self._get_text_randomized(row)

        lat_lon = torch.tensor([float(row["lat"]), float(row["lon"])], dtype=torch.float32)
        return image, text, lat_lon

    @staticmethod
    def _default_transform():
        """Default augmentation: resize, random crop/flip, blur, tensor, normalize."""
        return transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomCrop((224, 224)),
            transforms.RandomHorizontalFlip(0.5),
            transforms.GaussianBlur(5, (0.01, 1.0)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def _build_image_dict(self):
        """Map each PNG's file stem to its path for every ``*.png`` in ``image_path``."""
        return {image_file.stem: image_file for image_file in self.image_path.glob("*.png")}

    def _get_text_randomized(self, row):
        """Pick the text for ``row``.

        With probability ``random_prob`` the ``habitat`` column is returned.
        Otherwise one of the non-empty string values among ``alt_cols`` is
        chosen uniformly at random; if none exists, ``habitat`` is returned.
        """
        if np.random.rand() < self.random_prob:
            return row[self.hab_desc]

        # The alternative columns hold strings or NaN. np.isnan cannot be
        # applied to such object arrays (it raises TypeError), so keep the
        # usable entries by type instead.
        candidates = [
            value for value in row[self.alt_cols].tolist()
            if isinstance(value, str) and value.strip()
        ]
        if candidates:
            return candidates[np.random.randint(len(candidates))]

        # If all alternatives are missing, return 'habitat' as a fallback.
        # NOTE(review): the habitat value itself is returned unchecked — confirm
        # the column has no missing entries.
        return row[self.hab_desc]

## Model

In [45]:
class ContrastiveModel(nn.Module):
    """Wraps three encoders and emits two L2-normalised embeddings.

    ``forward`` adds the satellite encoder's ``image_embeds`` to the location
    encoder's output and returns that sum alongside the text encoder's output,
    each normalised along the last dimension.
    """

    def __init__(self, sat_encoder, location_encoder, text_encoder):
        super().__init__()
        self.sat_encoder = sat_encoder
        self.location_encoder = location_encoder
        self.text_encoder = text_encoder

    def forward(self, image, lat_long, text):
        """Encode one batch.

        Args:
            image: Input passed unchanged to ``sat_encoder``; its output must
                expose an ``image_embeds`` attribute.
            lat_long: Tensor cast to float before being given to ``location_encoder``.
            text: Passed unchanged to ``text_encoder``.
                NOTE(review): confirm the encoder accepts the form the caller supplies.

        Returns:
            tuple: ``(normalised image+location embedding, normalised text embedding)``.
        """
        sat_output = self.sat_encoder(image)
        location_embedding = self.location_encoder(lat_long.float())
        text_embedding = self.text_encoder(text)

        fused_embedding = sat_output.image_embeds + location_embedding

        fused_unit = nn.functional.normalize(fused_embedding, dim=-1)
        text_unit = nn.functional.normalize(text_embedding, dim=-1)
        return fused_unit, text_unit
 
        

## Train CLIP

### Load Data:

In [47]:

# Settings for the train/validation split, the loaders and the optimiser.
VAL_FRACTION = 0.2
SPLIT_SEED = 42
BATCH_SIZE = 64
LEARNING_RATE = 1e-4

dataset = MultiData(image_path=IMAGE_PATH, csv_path=CSV_PATH)

# Split row indices (not the data itself) so both subsets share one dataset object.
all_indices = list(range(len(dataset)))
train_indices, val_indices = train_test_split(
    all_indices,
    test_size=VAL_FRACTION,
    random_state=SPLIT_SEED,
)
train_dataset, val_dataset = Subset(dataset, train_indices), Subset(dataset, val_indices)

# Only the training loader shuffles; validation keeps a fixed order.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
validate_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=BATCH_SIZE)

model = ContrastiveModel(
    sat_encoder=sat_encoder,
    location_encoder=location_encoder,
    text_encoder=text_encoder,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
contrastive_loss = nn.CosineEmbeddingLoss()

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 47.53 GiB of which 14.25 MiB is free. Process 3673919 has 284.00 MiB memory in use. Process 1558402 has 1.16 GiB memory in use. Process 1567082 has 1.16 GiB memory in use. Process 1641344 has 1.24 GiB memory in use. Process 1759632 has 1.71 GiB memory in use. Process 1991684 has 1.15 GiB memory in use. Process 2675952 has 1.15 GiB memory in use. Process 793765 has 38.97 GiB memory in use. Including non-PyTorch memory, this process has 682.00 MiB memory in use. Of the allocated memory 399.15 MiB is allocated by PyTorch, and 22.85 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
def _batch_loss(batch):
    """Run the model on one ``(image, text, lat_long)`` batch and return its loss.

    Every pair in the batch is treated as a positive pair (target of all ones)
    for ``contrastive_loss``.
    """
    # Order matches what the dataset yields: image, text, then location.
    image, text, lat_long = batch
    image, lat_long = image.to(device), lat_long.to(device)
    # Text batches arrive from the loader as a sequence of Python strings,
    # which has no .to(); only move it when it is already a tensor.
    if torch.is_tensor(text):
        text = text.to(device)

    combined_features, text_features = model(image, lat_long, text)
    # Create the target directly on the features' device instead of copying it over.
    target = torch.ones(combined_features.size(0), device=combined_features.device)
    return contrastive_loss(combined_features, text_features, target)


for epoch in range(5):
    model.train()
    train_loss = 0.0
    for i, batch in enumerate(train_dataloader):
        loss = _batch_loss(batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        if i % 10 == 0:
            print(f"Loss at {i}: {loss.item()}")

    train_loss /= len(train_dataloader)
    # train_loss is a plain Python float here, so it is formatted directly.
    print(f"Epoch {epoch+1}, Training Loss: {train_loss:.4f}")

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in validate_dataloader:
            val_loss += _batch_loss(batch).item()
    val_loss /= len(validate_dataloader)
    print(f"Epoch {epoch+1}, Validation Loss: {val_loss:.4f}")

OutOfMemoryError: CUDA out of memory. Tried to allocate 1.15 GiB. GPU 0 has a total capacity of 47.53 GiB of which 711.19 MiB is free. Process 3673919 has 284.00 MiB memory in use. Process 1558402 has 1.16 GiB memory in use. Process 1567082 has 1.16 GiB memory in use. Process 1641344 has 1.24 GiB memory in use. Process 1759632 has 1.71 GiB memory in use. Process 1991684 has 1.15 GiB memory in use. Process 2675952 has 1.15 GiB memory in use. Including non-PyTorch memory, this process has 38.96 GiB memory in use. Of the allocated memory 38.00 GiB is allocated by PyTorch, and 658.56 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Some images don't work, such as:

- /scratch/s.sastry/ecobind_satellite/taxabind_sentinel/images/sentinel/908287_-36.85098_145.9897.jpeg
- /scratch/s.sastry/ecobind_satellite/taxabind_sentinel/images/sentinel/979421_45.51304_9.07877.jpeg