## Installing Dependencies

In [1]:
!pip install transformers



In [2]:
!pip install pytorch==1.7.1 torchvision
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

[31mERROR: Could not find a version that satisfies the requirement pytorch==1.7.1 (from versions: 0.1.2, 1.0.2)[0m[31m
[0m[31mERROR: No matching distribution found for pytorch==1.7.1[0m[31m
[0mCollecting ftfy
  Downloading ftfy-6.2.0-py3-none-any.whl (54 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.4/54.4 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: ftfy
Successfully installed ftfy-6.2.0
Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-fuxlfutv
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-fuxlfutv
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->clip==1.0)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23

## Importing Libraries

In [3]:
import torch
import clip
from PIL import Image
import os
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.optim as optim

# Running Inference on Pre-Trained Model

## Loading Model

In [17]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

## Preparing Data

In [19]:
classes = [
    'forest',
    'permanent crop land',
    'residential buildings or homes or apartments',
    'river',
    'pasture land',
    'lake or sea',
    'brushland or shrubland',
    'annual crop land',
    'industrial buildings or commercial buildings',
    'highway or road',
]

dataPath = os.path.join(os.getcwd(), "test")
images = []
groundTruth_labels = []

for cls in classes:
  for img in os.listdir(os.path.join(dataPath, cls)):
    images.append(os.path.join(dataPath, cls, img))
    groundTruth_labels.append(f"a centered satellite photo of {cls}")

In [20]:
# Function to analyze an image
def analyze_image(image_path, descriptions):
    # Preprocess the image
    image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)

    # Encode the descriptions
    text = clip.tokenize(descriptions).to(device)

    # Get the image and text features
    with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text)

    # Calculate similarity
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity = (image_features @ text_features.T).squeeze()

    # Find the description with the highest similarity
    best_match_idx = similarity.argmax().item()
    best_description = descriptions[best_match_idx]
    return best_description

In [21]:
preTrained_predictions = []
for img in tqdm(images, desc="Analyzing images"):
  cls = analyze_image(img, classes)
  preTrained_predictions.append(f"a centered satellite photo of {cls}")

Analyzing images: 100%|██████████| 5000/5000 [01:48<00:00, 46.03it/s]


## Metrics

In [22]:
acc = accuracy_score(preTrained_predictions, groundTruth_labels)
print(acc)

0.4218


# Fine-Tuning CLIP

In [5]:
# loading model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

100%|████████████████████████████████████████| 338M/338M [00:03<00:00, 104MiB/s]


In [6]:
# defining custom dataset
class CustomDataset(Dataset):
  def __init__(self, images, descriptions):
    self.images = images
    self.descriptions = clip.tokenize(descriptions)

  def __len__(self):
    return len(self.images)

  def __getitem__(self, idx):
    image = preprocess(Image.open(self.images[idx]))
    description = self.descriptions[idx]
    return image, description

In [7]:
# defining descriptions
descriptions = ["annual crop land",
                "forest",
                "lake or sea",
                "pasture land",
                "permanent crop land",
                "river",
                "residential buildings or homes or apartments",
                "industrial buildings or commercial buildings",
                "highway or road",
                "brushland or shrubland"]

# preparing dataset
training_path = os.path.join(os.getcwd(), "train")

training_images = []
training_descriptions = []

for cls in descriptions:
  for img in os.listdir(os.path.join(training_path, cls)):
    training_images.append(os.path.join(training_path, cls, img))
    training_descriptions.append(f"a centered satellite photo of {cls}")

training_dataset = CustomDataset(training_images, training_descriptions)
training_dataloader = DataLoader(training_dataset, batch_size=32, shuffle=True)

In [8]:
# function to convert model's parameters to FP32 format
def convert_models_to_fp32(model):
    for p in model.parameters():
        p.data = p.data.float()
        p.grad.data = p.grad.data.float()

In [9]:
# defining optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, betas=(0.9,0.98), eps=1e-6, weight_decay=0.2)

# loss functions
loss_img = torch.nn.CrossEntropyLoss()
loss_descriptions = torch.nn.CrossEntropyLoss()

In [10]:
# training loop
num_epochs = 16
for epoch in range(num_epochs):
  total_loss = 0
  pbar = tqdm(training_dataloader, total = len(training_dataloader))
  #for images, texts in training_dataloader:
  for batch in pbar:
    # zero out gradients
    optimizer.zero_grad()

    images, texts = batch
    images = images.to(device)
    texts = texts.to(device)

    # forward pass
    logits_per_image, logits_per_text = model(images, texts)

    # compute loss
    ground_truth = torch.arange(len(images), dtype = torch.long, device = device)
    loss = (loss_img(logits_per_image, ground_truth) + loss_descriptions(logits_per_text, ground_truth)) / 2

    # backward pass
    loss.backward()
    if device == "cpu":
      optimizer.step()
    else:
      convert_models_to_fp32(model)
      optimizer.step()
      clip.model.convert_weights(model)

    total_loss += loss.item()
    avg_loss = total_loss / len(training_dataloader)
    pbar.set_description(f"Epoch {epoch}/{num_epochs}, Loss: {avg_loss:.4f}")

Epoch 0/16, Loss: 2.2051: 100%|██████████| 688/688 [02:11<00:00,  5.24it/s]
Epoch 1/16, Loss: 2.1951: 100%|██████████| 688/688 [02:09<00:00,  5.32it/s]
Epoch 2/16, Loss: 2.1397: 100%|██████████| 688/688 [02:08<00:00,  5.35it/s]
Epoch 3/16, Loss: 2.0760: 100%|██████████| 688/688 [02:08<00:00,  5.35it/s]
Epoch 4/16, Loss: 2.0547: 100%|██████████| 688/688 [02:08<00:00,  5.34it/s]
Epoch 5/16, Loss: 2.0288: 100%|██████████| 688/688 [02:08<00:00,  5.35it/s]
Epoch 6/16, Loss: 2.0133: 100%|██████████| 688/688 [02:08<00:00,  5.35it/s]
Epoch 7/16, Loss: 2.0014: 100%|██████████| 688/688 [02:08<00:00,  5.35it/s]
Epoch 8/16, Loss: 1.9903: 100%|██████████| 688/688 [02:08<00:00,  5.35it/s]
Epoch 9/16, Loss: 1.9782: 100%|██████████| 688/688 [02:08<00:00,  5.35it/s]
Epoch 10/16, Loss: 1.9664: 100%|██████████| 688/688 [02:08<00:00,  5.34it/s]
Epoch 11/16, Loss: 1.9684: 100%|██████████| 688/688 [02:08<00:00,  5.35it/s]
Epoch 12/16, Loss: 1.9593: 100%|██████████| 688/688 [02:08<00:00,  5.34it/s]
Epoch 13/

In [11]:
# saving the model
torch.save(model.state_dict(), "euroSATclip.pt")

In [23]:
# running predictions on test set
!unzip test.zip

classes = [
    'forest',
    'permanent crop land',
    'residential buildings or homes or apartments',
    'river',
    'pasture land',
    'lake or sea',
    'brushland or shrubland',
    'annual crop land',
    'industrial buildings or commercial buildings',
    'highway or road',
]

dataPath = os.path.join(os.getcwd(), "test")
testImages = []
groundTruth_Testlabels = []

for cls in classes:
  for img in os.listdir(os.path.join(dataPath, cls)):
    testImages.append(os.path.join(dataPath, cls, img))
    groundTruth_Testlabels.append(f"a centered satellite photo of {cls}")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: test/permanent crop land/PermanentCrop_2463.jpg  
  inflating: test/permanent crop land/PermanentCrop_2305.jpg  
  inflating: test/permanent crop land/PermanentCrop_2311.jpg  
  inflating: test/permanent crop land/PermanentCrop_2477.jpg  
  inflating: test/permanent crop land/PermanentCrop_2339.jpg  
  inflating: test/permanent crop land/PermanentCrop_2107.jpg  
  inflating: test/permanent crop land/PermanentCrop_2113.jpg  
  inflating: test/permanent crop land/PermanentCrop_2098.jpg  
  inflating: test/permanent crop land/PermanentCrop_2073.jpg  
  inflating: test/permanent crop land/PermanentCrop_2067.jpg  
  inflating: test/permanent crop land/PermanentCrop_2271.jpg  
  inflating: test/permanent crop land/PermanentCrop_2265.jpg  
  inflating: test/permanent crop land/PermanentCrop_2259.jpg  
  inflating: test/permanent crop land/PermanentCrop_2258.jpg  
  inflating: test/permanent crop land/PermanentCrop_2

In [13]:
# Function to analyze an image
def analyze_image(image_path, descriptions):
    # Preprocess the image
    image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)

    # Encode the descriptions
    text = clip.tokenize(descriptions).to(device)

    # Get the image and text features
    with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text)

    # Calculate similarity
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity = (image_features @ text_features.T).squeeze()

    # Find the description with the highest similarity
    best_match_idx = similarity.argmax().item()
    best_description = descriptions[best_match_idx]
    return best_description

In [24]:
model.load_state_dict(torch.load("euroSATclip.pt"))

fineTuned_predictions = []
for img in tqdm(testImages, desc="Analyzing images"):
  cls = analyze_image(img, classes)
  fineTuned_predictions.append(f"a centered satellite photo of {cls}")

Analyzing images: 100%|██████████| 5000/5000 [01:47<00:00, 46.35it/s]


In [25]:
acc = accuracy_score(fineTuned_predictions, groundTruth_labels)
print(acc)

0.7376
