<a href="https://colab.research.google.com/github/gremlin97/ToySatCLIP/blob/main/InferenceSatCLIPCaliMedianPrice.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install timm
!pip install transformers
!pip install torchgeo
!pip install geoclip
!pip install rasterio

Collecting timm
  Downloading timm-0.9.16-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m11.8 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->timm)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->timm)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->timm)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch->timm)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch->timm)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch->timm)
  Using cach

Collecting geoclip
  Downloading geoclip-1.2.0-py3-none-any.whl (40.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.3/40.3 MB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: geoclip
Successfully installed geoclip-1.2.0


In [229]:
import os
import cv2
import gc
import numpy as np
import pandas as pd
import itertools
from tqdm.autonotebook import tqdm
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
import numpy as np
import torchgeo.models
from torchgeo.models import ResNet18_Weights
import rasterio
import torchvision.transforms as transforms
import itertools
import shutil
from google.colab import drive

import torch
from torch import nn
import torch.nn.functional as F
import timm
from geoclip import LocationEncoder
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [230]:
class Params:
    debug = False
    batch_size = 32*10
    num_workers = 18
    head_lr = 1e-3
    image_encoder_lr = 1e-5
    location_encoder_lr = 1e-5
    weight_decay = 1e-3
    patience = 1
    factor = 0.8
    epochs = 4
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_name = 'resnet50'
    image_embedding = 512
    location_embedding = 512
    pretrained = True
    trainable = True
    temperature = 1.0
    size = 224
    num_projection_layers = 1
    projection_dim = 1024
    dropout = 0.1

In [231]:
class Resnet(nn.Module):
    def __init__(
        self, trainable=Params.trainable
    ):
        super().__init__()
        weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
        in_chans = weights.meta["in_chans"]
        self.visual = timm.create_model("resnet18", in_chans=in_chans, num_classes=0)
        self.visual.load_state_dict(weights.get_state_dict(progress=True), strict=False)
        self.visual.requires_grad_(False)
        self.visual.fc.requires_grad_(True)
        # for p in self.visual.parameters():
        #   p.requires_grad = trainable

    def forward(self, x):
        return self.visual(x)

In [232]:
class LE(nn.Module):

  def __init__(self):
    super().__init__()
    self.encoder = LocationEncoder()

    for p in self.encoder.parameters():
      p.requires_grad = Params.trainable

  def forward(self, a):
    c = a[0][0]
    d = a[0][1]
    merged_list = [[lat.item(), long.item()] for lat, long in zip(c, d)]
    merged_list = torch.tensor(merged_list)
    merged_list = merged_list.to(device)
    embed = self.encoder(merged_list)
    return embed

In [233]:
class Project(nn.Module):
    def __init__(self, input_dim, output_dim=Params.projection_dim, dropout_prob=Params.dropout):
        super().__init__()
        self.projection_layer = nn.Linear(input_dim, output_dim)
        self.fc_layer = nn.Linear(output_dim, output_dim)
        self.layer_norm = nn.LayerNorm(output_dim)
        self.gelu_activation = nn.GELU()
        self.dropout_layer = nn.Dropout(dropout_prob)

    def forward(self, x):
        projected = self.projection_layer(x)
        x = self.gelu_activation(projected)
        x = self.fc_layer(x)
        x = self.dropout_layer(x)
        x = x + projected
        x = self.layer_norm(x)
        return x

In [234]:
class SatCLIP(nn.Module):
    def __init__(self, temperature=Params.temperature, ie=Params.image_embedding, le=Params.location_embedding,):
        super().__init__()
        self.image_encoder = Resnet()
        self.location_encoder = LE()
        self.image_projection = Project(input_dim=ie)
        self.loc_projection = Project(input_dim=le)
        self.temperature = temperature

    def forward(self, batch):
        image_features = self.image_encoder(batch["image"])
        loc = [[batch['lat'], batch['lon']]]
        loc_features = self.location_encoder(loc)
        image_embeddings = self.image_projection(image_features)
        loc_embeddings = self.loc_projection(loc_features)

        logits = (loc_embeddings @ image_embeddings.T) / self.temperature
        image_similarity_matrix = image_embeddings @ image_embeddings.T
        loc_similarity_matrix = loc_embeddings @ loc_embeddings.T
        targets = F.softmax((image_similarity_matrix + loc_similarity_matrix) / 2 * self.temperature, dim=-1)
        loc_loss = calculate_cross_entropy(logits, targets)
        images_loss = calculate_cross_entropy(logits.T, targets.T)
        loss = (images_loss + loc_loss) / 2.0
        return loss.mean()

def calculate_cross_entropy(predictions, targets):
    log_softmax_func = nn.LogSoftmax(dim=-1)
    individual_losses = (-targets * log_softmax_func(predictions)).sum(1)
    return individual_losses

In [235]:
model = SatCLIP()
checkpoint = torch.load("/content/satclip.pt", map_location=Params.device)
model.load_state_dict(checkpoint)
model.eval()

SatCLIP(
  (image_encoder): Resnet(
    (visual): ResNet(
      (conv1): Conv2d(13, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (drop_block): Identity()
          (act1): ReLU(inplace=True)
          (aa): Identity()
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act2): ReLU(inplace=True)
        )
        (1): BasicBlock(
          (conv1): 

## Downstream Task
California Median Income Prediction

In [236]:
from sklearn.datasets import fetch_california_housing
california_housing = fetch_california_housing()
df_down = pd.DataFrame(california_housing.data, columns=california_housing.feature_names)
df_down['Price'] = california_housing.target
# df_down = pd.read_csv('/content/housing.csv')

In [237]:
df_down.head()

Unnamed: 0,MedInc,HouseAge,AveRooms,AveBedrms,Population,AveOccup,Latitude,Longitude,Price
0,8.3252,41.0,6.984127,1.02381,322.0,2.555556,37.88,-122.23,4.526
1,8.3014,21.0,6.238137,0.97188,2401.0,2.109842,37.86,-122.22,3.585
2,7.2574,52.0,8.288136,1.073446,496.0,2.80226,37.85,-122.24,3.521
3,5.6431,52.0,5.817352,1.073059,558.0,2.547945,37.85,-122.25,3.413
4,3.8462,52.0,6.281853,1.081081,565.0,2.181467,37.85,-122.25,3.422


In [238]:
len(df_down)

20640

In [239]:
class LEDataset(torch.utils.data.Dataset):
    def __init__(self, lat, lon):
        self.lat = list(lat)
        self.lon = list(lon)

    def __getitem__(self, idx):
        it = {}
        it['lat'] = self.lat[idx]
        it['lon'] = self.lon[idx]
        return it

    def __len__(self):
        return len(self.lat)

In [240]:
df_down = df_down[['Price','Latitude','Longitude']]

In [241]:
df_down

Unnamed: 0,Price,Latitude,Longitude
0,4.526,37.88,-122.23
1,3.585,37.86,-122.22
2,3.521,37.85,-122.24
3,3.413,37.85,-122.25
4,3.422,37.85,-122.25
...,...,...,...
20635,0.781,39.48,-121.09
20636,0.771,39.49,-121.21
20637,0.923,39.43,-121.22
20638,0.847,39.43,-121.32


In [242]:
df_down = df_down.dropna()

In [243]:
y = df_down['Price']
X = df_down.drop(['Price'],axis='columns')

In [278]:
train_size = 0.45
train_samples = int(len(X) * train_size)

X_train, X_val = X[:train_samples], X[train_samples:]
y_train, y_val = y[:train_samples], y[train_samples:]

print("Train set shape:", X_train.shape, y_train.shape)
print("Validation set shape:", X_val.shape, y_val.shape)

Train set shape: (9288, 2) (9288,)
Validation set shape: (11352, 2) (11352,)


In [279]:
dataset_train = LEDataset(X_train["Latitude"].values, X_train['Longitude'].values)
infer_data_loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=Params.batch_size)

dataset_val = LEDataset(X_val["Latitude"].values, X_val['Longitude'].values)
infer_data_loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=Params.batch_size)

In [280]:
infer_embeddings_train = []
infer_embeddings_val = []

with torch.no_grad():
    for batch in tqdm(infer_data_loader_train):
        loc = [[batch['lat'], batch['lon']]]
        loc_features = model.location_encoder(loc)
        loc_embeddings = model.loc_projection(loc_features)
        infer_embeddings_train.append(loc_embeddings)

with torch.no_grad():
    for batch in tqdm(infer_data_loader_val):
        loc = [[batch['lat'], batch['lon']]]
        loc_features = model.location_encoder(loc)
        loc_embeddings = model.loc_projection(loc_features)
        infer_embeddings_val.append(loc_embeddings)

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/36 [00:00<?, ?it/s]

In [281]:
concatenated_embeddings_train = torch.cat(infer_embeddings_train, dim=0)
concatenated_embeddings_val = torch.cat(infer_embeddings_val, dim=0)
flattened_embeddings_train = concatenated_embeddings_train.view(concatenated_embeddings_train.size(0), -1)
flattened_embeddings_val = concatenated_embeddings_val.view(concatenated_embeddings_val.size(0), -1)

In [282]:
embeddings_train = [tensor.numpy().astype(np.float64) for tensor in flattened_embeddings_train]
embeddings_val = [tensor.numpy().astype(np.float64) for tensor in flattened_embeddings_val]

In [283]:
print(embeddings_train[0], embeddings_train[0].shape, embeddings_val[0], embeddings_val[0].shape)

[-0.06119435 -0.30125538  0.08766108 ...  0.0005298   0.12854132
  0.0954475 ] (1024,) [-0.096979   -0.19978145  0.10853552 ... -0.02400015  0.16459656
  0.05200193] (1024,)


In [284]:
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

lr = LinearRegression()
lr.fit(embeddings_train, y_train)

# Making predictions on the test set
y_pred = lr.predict(embeddings_val)

# Calculating Mean Squared Error (MSE) on the val set
mse = mean_squared_error(y_val, y_pred)
print("Mean Squared Error:", mse)

Mean Squared Error: 2.847136248021505
