## CLIP using multiple habitat sentences per image

In [None]:
from torch.utils.data import Dataset
import torch
import pandas as pd
import numpy as np
from PIL import Image
from crop_image import getImages
from collections import OrderedDict
from transformers import AutoImageProcessor
from torchvision import transforms

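# TODO: Change to be dimensions of continental US.
# NOTE: MultiData.__getitem__ reads a module-level `bounds` laid out as
# [lon_min, lon_max, lat_min, lat_max]. The assignment below is commented out,
# so `bounds` must be defined before the dataset is indexed.
# bounds = [-90.6809899999999942, -90.0909899999996924, 38.4560099999999991, 38.8860099999999136]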

class MultiData(Dataset):
    """Species observations paired with imagery and location features.

    The CSV at ``csv_file`` must provide ``decimalLongitude``,
    ``decimalLatitude`` and ``species`` columns. Each item couples the
    processed image at the observed location with a second image sampled at a
    uniformly random location inside ``bounds``.

    NOTE(review): ``__getitem__`` reads a module-level ``bounds`` sequence laid
    out as ``[lon_min, lon_max, lat_min, lat_max]``. This class does not define
    it, so it must exist before the dataset is indexed.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """Load the observation table and the per-species statistics.

        Args:
            csv_file: Path of the observations CSV.
            root_dir: Stored on the instance; not used by this class itself.
            transform: Stored on the instance; not used by this class itself.
        """
        self.coords = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        # Sorted unique species names; a name's position is its class index.
        self.obs = sorted(self.coords.drop_duplicates(subset=["species"])["species"].tolist())
        # Name -> class index map so __getitem__ avoids a linear list.index()
        # scan on every sample.
        self._class_index = {name: position for position, name in enumerate(self.obs)}
        self.image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
        self.stl_coords = pd.read_csv("st_louis_coords.csv")
        # Relative frequency of every species in the table.
        self.spec_freqs = self.coords.value_counts(['species']) / self.coords.shape[0]

    def __len__(self):
        """Return the number of observation rows."""
        return len(self.coords)

    @staticmethod
    def _encode_location(lon, lat, box):
        """Rescale (lon, lat) to [-1, 1] inside ``box`` and return 4 trig features.

        Args:
            lon: Longitude of the point.
            lat: Latitude of the point.
            box: ``[lon_min, lon_max, lat_min, lat_max]``.

        Returns:
            A float32 tensor of shape ``(4,)``.
        """
        lon_unit = 2 * ((lon - box[0]) / (box[1] - box[0])) - 1
        lat_unit = 2 * ((lat - box[2]) / (box[3] - box[2])) - 1
        # NOTE(review): cos(x + pi/2) == -sin(x), so components 1 and 3 are the
        # negation of components 0 and 2. Kept as-is to preserve the features
        # existing models were trained on -- confirm this is intended.
        return torch.FloatTensor([
            np.sin(np.pi * lon_unit / 2),
            np.cos(np.pi * lon_unit / 2 + np.pi / 2),
            np.sin(np.pi * lat_unit / 2),
            np.cos(np.pi * lat_unit / 2 + np.pi / 2),
        ])

    def __getitem__(self, idx):
        """Return one observation together with a randomly located counterpart.

        Returns:
            A 6-tuple ``(image, species_class, feats, rand_image, rand_feats,
            species_weights)`` where ``image`` / ``rand_image`` are the outputs
            of the image processor called with ``return_tensors="pt"``,
            ``species_class`` is a 1-element ``LongTensor``, ``feats`` /
            ``rand_feats`` are 4-element float tensors and ``species_weights``
            is the inverse relative frequency of the species.
        """
        # Read the row once instead of re-indexing the frame per column.
        row = self.coords.iloc[idx]

        lon = float(row["decimalLongitude"])
        lat = float(row["decimalLatitude"])
        feats = self._encode_location(lon, lat, bounds)
        image = Image.fromarray(getImages(lon, lat, self.stl_coords))
        image = self.image_processor(image, return_tensors="pt")

        # Random location, kept 0.01 degrees away from every edge of the box.
        rand_lon = np.random.uniform(bounds[0]+0.01, bounds[1]-0.01)
        rand_lat = np.random.uniform(bounds[2]+0.01, bounds[3]-0.01)
        rand_feats = self._encode_location(rand_lon, rand_lat, bounds)
        rand_image = Image.fromarray(getImages(rand_lon, rand_lat, self.stl_coords))
        rand_image = self.image_processor(rand_image, return_tensors="pt")

        species = row["species"]
        species_class = self._class_index[species]
        # Inverse-frequency weight; the epsilon guards against division by zero.
        species_weights = 1 / (self.spec_freqs[species] + 1e-5)

        return image, torch.LongTensor([species_class]), feats, rand_image, rand_feats, species_weights

In [None]:
"""
File: crisp.py
------------------
Our implementation of CLIP classes and functions for CRISP. 
Uses openclip code. 
"""


import torch
import torch.nn as nn
import torchvision
import numpy as np
import torch.nn.functional as F
import pdb



# ----------------- Model class ----------------- #

VALID_ENCODERS = ["resnet50"]

class CrispModel(nn.Module):
    """
    Trainable PyTorch module for CLIP/CRISP pre-training. 
    Has two submodules:
    - sat2cap encoder 1 
    - CLIP text encoder 2
    User must pass in a PyTorch encoder module for each submodule. 
    """

    def __init__(self, encoder_name, embedding_dim=512, pretrained_weights=None):
        super().__init__()

        # assert encoder name is valid
        assert encoder_name in VALID_ENCODERS, f"encoder name {encoder_name} is not valid. Valid encoders are {VALID_ENCODERS}"
    
        # construct the encoder modules
        self.ground_description_encoder = self.construct_encoder(encoder_name, embedding_dim, pretrained_weights)
        self.remote_sensing_encoder = self.construct_encoder(encoder_name, embedding_dim, pretrained_weights)

        # extra 
        self.embedding_dim = embedding_dim

        # see https://github.com/mlfoundations/open_clip/blob/6ee59e10510ec9761b8b9871b9fd1eeb8e28627d/src/open_clip/model.py#L202
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))


    def construct_encoder(self, encoder_name, embedding_dim, pretrained_weights=None):
        """
        Construct the encoder module. 
        Specify the encoder name and optionally pass in pretrained weights
        name from torchvision docs. 

        scratched: 
            base_model = torchvision.models.resnet50(weights=pretrained_weights)
            base_layers = list(base_model.children())[:-1]
            projection_layer = torch.nn.Linear(base_model.fc.in_features, embedding_dim)
            modules = base_layers + [projection_layer]
            encoder = torch.nn.Sequential(*modules)
            return encoder
        """
        if encoder_name == "resnet50":
            model = torchvision.models.resnet50(weights=pretrained_weights)
            d = model.fc.in_features
            model.fc = nn.Linear(d, embedding_dim)
            return model
        

    def lock(self, encoder_to_lock):
        """
        Lock certain parameters of the model so that it cannot be trained.
        See CoCa paper and https://github.com/mlfoundations/open_clip/blob/
        6ee59e10510ec9761b8b9871b9fd1eeb8e28627d/src/open_clip/modified_resnet.py#L154.  
        """
        if encoder_to_lock == "ground_description":
            # freeze the ground image encoder parameters
            for param in self.ground_description_encoder.parameters():
                param.requires_grad = False
        elif encoder_to_lock == "remote_sensing":
            # freeze the remote sensing encoder parameters
            for param in self.remote_sensing_encoder.parameters():
                param.requires_grad = False
        else:
            raise ValueError(f"encoder_to_lock must be either 'ground_description' or 'remote_sensing'")
        

    def encode_remote_sensing_image(self, x):
        return self.remote_sensing_encoder(x)
    

    def encode_ground_description(self, x):
        return self.ground_description_encoder(x)


    def load_remote_sensing_encoder_weights(self, encoder_path):
        """
        If you have weights for the remote sensing encoder, load them here. 
        """
        self.remote_sensing_encoder.load_state_dict(torch.load(encoder_path))
        print(f"Successfully loaded remote sensing encoder weights from {encoder_path}")

    
    def load_ground_description_encoder_weights(self, encoder_path):
        """
        If you have weights for the ground image encoder, load them here. 
        """
        self.ground_description_encoder.load_state_dict(torch.load(encoder_path))
        print(f"Successfully loaded ground image encoder weights from {encoder_path}")


    def cosine_similarity_logits(self, ground_description, remote_sensing_image):
        """
        Maps images into latent space and computes cosine similarity between them as the logits. 
        We also have this in the loss. 
        Gets the CLIP matrix (one for each image input)
        """
        # see https://github.com/openai/CLIP/blob/a9b1bf5920416aaeaec965c25dd9e8f98c864f16/clip/model.py#LL362C9-L362C9. 
        
        # get latent features 
        ground_description_latent = self.encode_ground_description(ground_description)
        remote_sensing_latent = self.encode_remote_sensing_image(remote_sensing_image)
        
        # normalized features
        ground_description_latent = ground_description_latent / ground_description_latent.norm(dim=1, keepdim=True)
        remote_sensing_latent = remote_sensing_latent / remote_sensing_latent.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_ground_description = logit_scale * ground_description_latent @ remote_sensing_latent.t()
        logits_per_remote_sensing_image = logits_per_ground_description.t()

        return logits_per_ground_description, logits_per_remote_sensing_image
    

    def alternate_cosine_similarity_logits(self, ground_description, remote_sensing_image):
        """
        for debugging. the first version uses openai official repo code.
        This uses https://github.com/mlfoundations/open_clip/blob/6ee59e10510ec9761b8b9871b9fd1eeb8e28627d/src/open_clip/loss.py#L102. 
        Note: after some debugging, this function returns the same value 
        (Pdb) cos_sim_1
        (tensor([[-0.7614]], grad_fn=<MmBackward0>), tensor([[-0.7614]], grad_fn=<TBackward0>))
        (Pdb) cos_sim_1.shape
        *** AttributeError: 'tuple' object has no attribute 'shape'
        (Pdb) cos_sim_2
        (tensor([[-0.7614]], grad_fn=<MmBackward0>), tensor([[-0.7614]], grad_fn=<MmBackward0>))
        (Pdb) cos_sim_1 == cos_sim_2
        True
        """
        # get latent features 
        ground_description_latent = self.encode_ground_description(ground_description)
        remote_sensing_latent = self.encode_remote_sensing_image(remote_sensing_image)
        
        # normalized features
        ground_description_latent = ground_description_latent / ground_description_latent.norm(dim=1, keepdim=True)
        remote_sensing_latent = remote_sensing_latent / remote_sensing_latent.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_ground_description = logit_scale * ground_description_latent @ remote_sensing_latent.T
        logits_per_remote_sensing_image = logit_scale * remote_sensing_latent @ ground_description_latent.T
        
        return logits_per_ground_description, logits_per_remote_sensing_image
    
    
    def forward(self, ground_description, remote_sensing_image):
        """
        Forward pass through the model. Produces cosine sim logits. 
        """
        gr_logits, rs_logits = self.cosine_similarity_logits(ground_description, remote_sensing_image)
        return gr_logits, rs_logits

In [None]:
# ----------------- Loss class ----------------- #


class ClipLoss(nn.Module):
    """Symmetric cross-entropy (CLIP-style) loss over two logit matrices."""

    def __init__(self):
        super().__init__()

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        """Return the target indices ``0 .. num_logits - 1`` on ``device``.

        Row ``i`` of a logit matrix holds the similarity of item ``i`` to every
        item of the other modality in the batch. The positive pair sits at
        column ``i``, so using ``i`` as the class label pushes the diagonal of
        the matrix to be the largest entry in each row.
        """
        targets = torch.arange(num_logits, device=device, dtype=torch.long)
        return targets

    def forward(self, logits_per_ground_description, logits_per_remote_sensing_image):
        """Average the cross-entropy of both logit matrices against the diagonal targets."""
        batch_size = logits_per_ground_description.shape[0]
        targets = self.get_ground_truth(logits_per_ground_description.device, batch_size)

        ground_term = F.cross_entropy(logits_per_ground_description, targets)
        remote_term = F.cross_entropy(logits_per_remote_sensing_image, targets)

        return (ground_term + remote_term) / 2